In [1]:
import os

import cv2
import tensorflow as tf
import numpy as np
import pandas as pd

from model.destr_model import ObjDetSplitTransformer, train_one_step, validate
from utils.data_loader import load_data_tfrecord

2024-07-05 21:27:34.186582: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-07-05 21:27:34.189152: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-07-05 21:27:34.197051: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-05 21:27:34.211574: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-05 21:27:34.211604: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-05 21:27:34.222208: I tensorflow/core/platform/cpu_feature_guard.cc:

In [2]:
# ---- Run configuration -------------------------------------------------------
# Model / optimisation settings.
num_cls = 8          # forwarded to ObjDetSplitTransformer(num_cls=...)
EPOCH_NUMS = 20      # number of epochs executed by the training loop
BATCH_SIZE = 8       # elements per batch in the tf.data pipeline

# Directory used both for restoring and for saving checkpoints.
checkpoint_dir = '/workspace/models/checkpoints'

# Per-epoch budget; these are handed to take()/skip() on the *batched* dataset,
# so they count batches, not individual samples.
train_size, valid_size = 5000, 500

# When True, the latest checkpoint in `checkpoint_dir` is restored before training.
load_from_ckpt = False

In [3]:
# Wrap the detector block in a functional Keras model with three outputs.
# The block and the symbolic input are given the same image shape.
image_shape = (224, 224, 3)
destr_block = ObjDetSplitTransformer(input_shape=image_shape, num_cls=num_cls)

# Symbolic float32 image input; calling the block on it yields the three heads.
img = tf.keras.Input(shape=image_shape, dtype=tf.float32)
head_outputs = destr_block(img)
cls_output, reg_output, total_proposals = head_outputs

model = tf.keras.Model(
    inputs=img,
    outputs=[cls_output, reg_output, total_proposals],
)



In [4]:
# Optionally resume from the most recent checkpoint found in `checkpoint_dir`.

if load_from_ckpt:
    latest_ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt_path is None:
        # latest_checkpoint() returns None when it finds no checkpoint, and
        # restore(None) does not raise, so the run would otherwise continue
        # with freshly initialised weights while appearing to have resumed.
        raise FileNotFoundError(
            f'No checkpoint found in {checkpoint_dir!r}; '
            'set load_from_ckpt = False to train from scratch.'
        )
    checkpoint = tf.train.Checkpoint(model=model)
    # `status` is kept so the restore can be inspected from a later cell.
    status = checkpoint.restore(latest_ckpt_path)

In [5]:
# Candidate locations of the TFRecord chunks.
# NOTE(review): presumably `local_dir` is the host path and `container_dir` the
# path inside the container (checkpoints also live under /workspace) - confirm.
local_dir = '/media/daniel/DatasetIMDB/imdb_chunks'
container_dir = '/workspace/data/tfrecords'

# Use the first candidate that exists on this machine. If neither exists, fall
# back to `local_dir` so the loader sees the same path it was always given.
tfrecord_dir = next(
    (path for path in (local_dir, container_dir) if os.path.exists(path)),
    local_dir,
)

full_dataset = load_data_tfrecord(path_to_tfrecord=tfrecord_dir)

In [6]:
# Training driver: every epoch runs up to `train_size` training batches and up to
# `valid_size` validation batches, records the mean losses and writes a checkpoint.
loss_history = {'train_loss': [], 'valid_loss': []}
# NOTE(review): `optimizers` is never read in this cell (only `optimizer` is passed
# to train_one_step); it is kept in case another cell depends on it.
optimizers = {'mini_det': tf.keras.optimizers.Adam(learning_rate=0.00001), 'destr': tf.keras.optimizers.Adam(learning_rate=0.00001)}
optimizer = tf.keras.optimizers.Adam(learning_rate=0.000001)

# Index of the first epoch of this run; it feeds the checkpoint file name and the
# printed epoch number.
# NOTE(review): presumably 10 continues the numbering of an earlier run - confirm.
START_EPOCH = 10


def _prepare_batch(batch):
    """Convert one dataset element into the (images, targets) pair fed to the model.

    The element is unpacked as (raw image bytes, coord, label, oh_label). The bytes
    are decoded as uint8, cast to float32 and reshaped to (-1, 224, 224, 3); the
    targets are label (with a trailing axis added), oh_label and coord concatenated
    along the last axis.
    """
    raw_image, coord, label, oh_label = batch
    images = tf.reshape(tf.cast(tf.io.decode_raw(raw_image, tf.uint8), tf.float32), (-1, 224, 224, 3))
    targets = tf.concat([label[..., tf.newaxis], oh_label, coord], axis=-1)
    return images, targets


def _run_pass(batches, step_fn):
    """Apply `step_fn(images, targets)` to every batch and return the mean losses.

    Returns a (mean mini_det_loss, mean model_loss) tuple averaged over the number
    of batches actually processed, or (nan, nan) if `batches` yields nothing.
    """
    md_sum, model_sum, steps = 0, 0, 0
    for batch in batches:
        images, targets = _prepare_batch(batch)
        mini_det_loss, model_loss = step_fn(images, targets)
        md_sum += mini_det_loss.numpy()
        model_sum += model_loss.numpy()
        steps += 1
    if steps == 0:
        # Nothing was processed: report NaN rather than a misleading 0.0 loss.
        return float('nan'), float('nan')
    # Divide by the real step count so a short dataset does not deflate the mean.
    return md_sum / steps, model_sum / steps


# Built once: by default shuffle() draws a new order each time the dataset is
# iterated, so the pipeline does not have to be rebuilt for every epoch.
dataset = full_dataset.shuffle(buffer_size=5000).batch(batch_size=BATCH_SIZE, drop_remainder=True)

# take()/skip() already bound each pass, so no manual step counting is needed.
# NOTE(review): train and validation are carved out of the same shuffled stream and
# iterated separately, so the validation batches are presumably not guaranteed to be
# disjoint from the training batches - confirm whether a fixed held-out split is wanted.
epoch_dataset = dataset.take(count=train_size+valid_size)
train_dataset = epoch_dataset.take(count=train_size)
valid_dataset = epoch_dataset.skip(count=train_size)

for epoch_idx in range(START_EPOCH, EPOCH_NUMS+START_EPOCH):
    train_means = _run_pass(
        train_dataset,
        lambda images, targets: train_one_step(model, optimizer, images, targets),
    )
    loss_history['train_loss'].append(train_means)

    valid_means = _run_pass(
        valid_dataset,
        lambda images, targets: validate(model, images, targets),
    )
    loss_history['valid_loss'].append(valid_means)

    # Save parameters after each epoch. A fresh Checkpoint object is created on
    # purpose: save() numbers its files from the object's save counter, so reusing
    # one object would change the generated file names.
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_prefix = os.path.join(checkpoint_dir, f'ckpt_{epoch_idx}')
    checkpoint.save(file_prefix=checkpoint_prefix)

    print(f'''epoch {epoch_idx+1:>2}: \n
          \t train_loss: {loss_history["train_loss"][-1][0]:.4f} {loss_history["train_loss"][-1][1]:.4f},
          \t valid loss: {loss_history["valid_loss"][-1][0]:.4f} {loss_history["valid_loss"][-1][1]:.4f}''')

epoch 11: 

          	 train_loss: 1.1565 1.6737,
          	 valid loss: 1.0775 1.6343
epoch 12: 

          	 train_loss: 1.0712 1.5256,
          	 valid loss: 1.0024 1.4128
epoch 13: 

          	 train_loss: 0.9998 1.3594,
          	 valid loss: 0.9371 1.3440
epoch 14: 

          	 train_loss: 0.9215 1.3175,
          	 valid loss: 0.8864 1.3470
epoch 15: 

          	 train_loss: 0.8883 1.3861,
          	 valid loss: 0.8180 1.3234
epoch 16: 

          	 train_loss: 0.8222 1.3438,
          	 valid loss: 0.8059 1.2797
epoch 17: 

          	 train_loss: 0.7994 1.2916,
          	 valid loss: 0.7466 1.2310
epoch 18: 

          	 train_loss: 0.7840 1.1898,
          	 valid loss: 0.7262 1.2290
epoch 19: 

          	 train_loss: 0.7264 1.1851,
          	 valid loss: 0.7638 1.1876
epoch 20: 

          	 train_loss: 0.7203 1.2228,
          	 valid loss: 0.7063 1.1493


In [7]:
#model.load_weights('/workspace/models/destr_20')
#weights = model.get_weights()
#model.set_weights(weights)