In [2]:
import os
import json
import gc

import cv2
import tensorflow as tf

from model.destr_model import ObjDetSplitTransformer, train_one_step, validate
from utils.data_loader import load_data_tfrecord

2024-07-06 20:47:00.672022: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-07-06 20:47:00.674367: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-07-06 20:47:00.681805: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-06 20:47:00.695828: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-06 20:47:00.695858: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-06 20:47:00.705723: I tensorflow/core/platform/cpu_feature_guard.cc:

In [9]:
# Experiment configuration: model shapes, training hyper-parameters and filesystem paths.
# JSON text is UTF-8 by specification, so decode it explicitly instead of relying on the
# platform's locale-dependent default encoding.
with open('./config.json', 'r', encoding='utf-8') as fin:
    config = json.load(fin)

In [None]:
# Build the DESTR detector as a functional Keras model: one float32 image input feeding
# ObjDetSplitTransformer, with all three of its outputs exposed by the model.
model_input_shape = config['model']['input_shape']
destr_block = ObjDetSplitTransformer(
    input_shape=model_input_shape,
    num_cls=config['model']['num_class'],
)

img = tf.keras.Input(shape=model_input_shape, dtype=tf.float32)
cls_output, reg_output, total_proposals = destr_block(img)
model = tf.keras.Model(
    inputs=img,
    outputs=[cls_output, reg_output, total_proposals],
)

In [None]:
# Optionally resume from the most recent checkpoint in config['paths']['to_checkpoint'].
# Only the model is tracked by this Checkpoint, so only model weights are restored.
load_from_ckpt = False
if load_from_ckpt:
    checkpoint = tf.train.Checkpoint(model=model)
    latest_ckpt = tf.train.latest_checkpoint(config['paths']['to_checkpoint'])
    # latest_checkpoint() returns None when no checkpoint exists; restore(None) would then
    # silently keep the freshly initialised weights, so fail loudly instead.
    if latest_ckpt is None:
        raise FileNotFoundError(
            f"load_from_ckpt=True but no checkpoint found in {config['paths']['to_checkpoint']!r}"
        )
    status = checkpoint.restore(latest_ckpt)
    print(f'Restored model weights from {latest_ckpt}')

In [None]:
# Reference locations of the TFRecord chunks on the host machine and inside the container.
# Neither is read below: the loader takes its path from config['paths']['to_dataset_local'].
container_dir = '/workspace/data/tfrecords'
local_dir = '/media/daniel/DatasetIMDB/imdb_chunks'

dataset_path = config['paths']['to_dataset_local']
full_dataset = load_data_tfrecord(path_to_tfrecord=dataset_path)

In [None]:
# Training loop: every epoch draws a fresh shuffle of `full_dataset`, trains on the first
# `num_train_samples` batches, validates on the following `num_valid_samples` batches,
# checkpoints the model and rewrites the loss log.
train_cfg = config['train']
# Both counts are numbers of *batches*: take()/skip() below are applied after batch().
num_train_batches = train_cfg['num_train_samples']
num_valid_batches = train_cfg['num_valid_samples']

loss_history = {'train_loss': [], 'valid_loss': []}
optimizer = tf.keras.optimizers.Adam(learning_rate=train_cfg['learning_rate'])


def _decode_batch(batch):
    """Convert one batch of parsed TFRecord fields into ``(images, targets)``.

    Assumes each batch is ``(raw_image_bytes, coord, label, oh_label)`` as yielded by
    ``load_data_tfrecord`` -- TODO confirm against that loader.
    """
    raw_image, coord, label, oh_label = batch
    images = tf.reshape(
        tf.cast(tf.io.decode_raw(raw_image, tf.uint8), tf.float32),
        shape=[-1] + config['model']['input_shape'],
    )
    # Target layout along the last axis: [label, one-hot label..., coord...]
    targets = tf.concat([label[..., tf.newaxis], oh_label, coord], axis=-1)
    return images, targets


def _run_pass(dataset, step_fn, progress_bar):
    """Apply ``step_fn(images, targets)`` to every batch of ``dataset``.

    ``step_fn`` must return ``(mini_det_loss, model_loss)`` tensors. Returns the mean of
    each loss over the batches actually processed, so the averages stay correct even when
    the dataset yields fewer batches than configured.
    """
    total_md_loss, total_loss, step = 0.0, 0.0, 0
    for batch in dataset:
        images, targets = _decode_batch(batch)
        mini_det_loss, model_loss = step_fn(images, targets)
        total_md_loss += mini_det_loss.numpy()
        total_loss += model_loss.numpy()
        step += 1
        progress_bar.update(step)
    num_steps = max(step, 1)  # avoid ZeroDivisionError on an empty split
    return total_md_loss / num_steps, total_loss / num_steps


def _format_epoch_summary(idx):
    """Render the recorded train/valid losses of epoch ``idx`` (0-based) as text."""
    train_md, train_total = loss_history["train_loss"][idx]
    valid_md, valid_total = loss_history["valid_loss"][idx]
    return f'''epoch {idx+1:>2}: \n
          \t train_loss: {train_md:.4f} {train_total:.4f},
          \t valid loss: {valid_md:.4f} {valid_total:.4f}'''


for epoch_idx in range(train_cfg['num_epochs']):
    train_progress_bar = tf.keras.utils.Progbar(num_train_batches)
    valid_progress_bar = tf.keras.utils.Progbar(num_valid_batches)

    # train_dataset and valid_dataset are two separate iterations over this one shuffled
    # stream. With the default reshuffle_each_iteration=True each iteration would see a
    # different permutation, so validation batches would overlap training batches.
    # Freezing the order per dataset object keeps the two splits disjoint; a new shuffle
    # is still drawn every epoch because the dataset is rebuilt here.
    # NOTE(review): samples validated on in one epoch may be trained on in a later epoch;
    # a fixed held-out split would have to be carved out of full_dataset before this loop.
    dataset = full_dataset.shuffle(
        buffer_size=train_cfg['shuffle_buffer'],
        reshuffle_each_iteration=False,
    ).batch(batch_size=train_cfg['batch_size'], drop_remainder=True)

    epoch_dataset = dataset.take(count=num_train_batches + num_valid_batches)
    train_dataset = epoch_dataset.take(count=num_train_batches).prefetch(buffer_size=tf.data.AUTOTUNE)
    valid_dataset = epoch_dataset.skip(count=num_train_batches).prefetch(buffer_size=tf.data.AUTOTUNE)

    loss_history['train_loss'].append(
        _run_pass(
            train_dataset,
            lambda images, targets: train_one_step(model, optimizer, images, targets),
            train_progress_bar,
        )
    )

    gc.collect()

    loss_history['valid_loss'].append(
        _run_pass(
            valid_dataset,
            lambda images, targets: validate(model, images, targets),
            valid_progress_bar,
        )
    )

    # Save parameters after each epoch. Only the model is tracked, so optimizer (Adam)
    # state is not part of the checkpoint.
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_prefix = os.path.join(config['paths']['to_checkpoint'], f'ckpt_{epoch_idx}')
    checkpoint.save(file_prefix=checkpoint_prefix)

    # Memory housekeeping between epochs. The weights just saved are already held by
    # `model`, so they are not reloaded from disk here.
    gc.collect()
    tf.keras.backend.clear_session()

    print(_format_epoch_summary(epoch_idx))
    # Rewrite the complete history each epoch: writing only the latest epoch with mode='w'
    # would leave just the final epoch in the file.
    with open(config['paths']['to_loss_records'], mode='w') as fout:
        for recorded_idx in range(len(loss_history['train_loss'])):
            print(_format_epoch_summary(recorded_idx), file=fout)

In [None]:
# Scratch: manual weight loading/copying through the Keras weights API, kept commented out.
#model.load_weights('/workspace/models/destr_20')
#weights = model.get_weights()
#model.set_weights(weights)