In [None]:
import os

import cv2
import tensorflow as tf
import numpy as np
import pandas as pd

from model.destr_model import ObjDetSplitTransformer, train_one_step, validate
from utils.data_loader import load_data_tfrecord

In [None]:
# --- Run configuration -------------------------------------------------------

# Checkpointing: directory the training cell writes to, and whether the restore
# cell should load the latest checkpoint from it before training.
checkpoint_dir = '/workspace/models/checkpoints'
load_from_ckpt = False

# Model / schedule: number of classes handed to the detector, number of epochs
# the training loop runs, and examples per batch.
num_cls = 8
EPOCH_NUMS, BATCH_SIZE = 20, 8

# Split sizes. NOTE: these are counted in *batches* (each of BATCH_SIZE
# examples), not in individual examples.
train_size, valid_size = 5000, 500

In [None]:
# Wrap the detector block in a functional Keras model: a symbolic float32 image
# input is fed through the block and its three outputs become the model outputs.
image_shape = (224, 224, 3)  # shared by the block and by the symbolic input

destr_block = ObjDetSplitTransformer(input_shape=image_shape, num_cls=num_cls)
img = tf.keras.Input(shape=image_shape, dtype=tf.float32)

cls_output, reg_output, total_proposals = destr_block(img)
model = tf.keras.Model(
    inputs=img,
    outputs=[cls_output, reg_output, total_proposals],
)

In [None]:
# Optionally warm-start `model` from the most recent checkpoint in `checkpoint_dir`.

if load_from_ckpt:
    latest_ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt_path is None:
        # latest_checkpoint() returns None when the directory holds no
        # checkpoint; restore(None) would then leave the freshly initialised
        # weights in place without any error, so fail loudly instead.
        raise FileNotFoundError(
            f'load_from_ckpt is True but no checkpoint was found in {checkpoint_dir!r}'
        )

    checkpoint = tf.train.Checkpoint(model=model)
    status = checkpoint.restore(latest_ckpt_path)

In [None]:
# Two candidate TFRecord locations. Only `local_dir` is passed to the loader
# below; `container_dir` is defined but not used in this cell.
# NOTE(review): `local_dir` is an absolute, machine-specific path — consider
# making it configurable.
local_dir = '/media/daniel/DatasetIMDB/imdb_chunks'
container_dir = '/workspace/data/tfrecords'

# NOTE(review): presumably yields un-batched records of four fields (the
# training cell batches the dataset and unpacks four values per element) —
# confirm against utils.data_loader.load_data_tfrecord.
full_dataset = load_data_tfrecord(path_to_tfrecord=local_dir)

In [None]:
def _prepare_batch(batch):
    """Convert one batched dataset element into `(images, targets)` tensors.

    `batch` is unpacked into four fields: raw image bytes, box coordinates,
    label and one-hot label. The bytes are decoded as uint8, cast to float32 and
    reshaped to (-1, 224, 224, 3). The targets are the label (with a trailing
    axis added), the one-hot label and the coordinates concatenated along the
    last axis, in that order.
    """
    raw_image, coord, label, oh_label = batch
    images = tf.reshape(
        tf.cast(tf.io.decode_raw(raw_image, tf.uint8), tf.float32),
        (-1, 224, 224, 3),
    )
    targets = tf.concat([label[..., tf.newaxis], oh_label, coord], axis=-1)
    return images, targets


def _mean_epoch_losses(dataset, step_fn):
    """Run `step_fn(images, targets)` on every batch and average its two losses.

    Returns `(mean_mini_det_loss, mean_model_loss)`. The mean is taken over the
    number of batches actually processed, so a dataset shorter than the
    configured split size still yields a correct average. If `dataset` yields no
    batches, both values are NaN rather than a misleading 0.
    """
    md_loss_sum, model_loss_sum, n_steps = 0.0, 0.0, 0
    for batch in dataset:
        images, targets = _prepare_batch(batch)
        mini_det_loss, model_loss = step_fn(images, targets)
        md_loss_sum += mini_det_loss.numpy()
        model_loss_sum += model_loss.numpy()
        n_steps += 1

    if n_steps == 0:
        return float('nan'), float('nan')
    return md_loss_sum / n_steps, model_loss_sum / n_steps


loss_history = {'train_loss': [], 'valid_loss': []}
# NOTE(review): `optimizers` is not used by the loop below; only `optimizer` is
# passed to train_one_step. Kept so other cells that may reference it still work.
optimizers = {'mini_det': tf.keras.optimizers.Adam(learning_rate=0.00001), 'destr': tf.keras.optimizers.Adam(learning_rate=0.00001)}
optimizer = tf.keras.optimizers.Adam(learning_rate=0.000001)

# NOTE(review): epoch numbering starts at 10, presumably to continue the
# checkpoint numbering of an earlier run — confirm.
first_epoch_idx = 10

# train_size / valid_size are counted in batches, so convert to example counts
# and split the stream BEFORE shuffling. Splitting a shuffled pipeline with
# take()/skip() is unsafe: shuffle() reshuffles on every iteration by default,
# so the "validation" pass would be a different random draw that can contain
# examples the model was just trained on.
# Assumes `full_dataset` yields examples in a stable order across iterations —
# TODO confirm against load_data_tfrecord.
n_train_examples = train_size * BATCH_SIZE
n_valid_examples = valid_size * BATCH_SIZE

# The pipelines are re-iterable, so they are built once; the training pipeline
# is still reshuffled at the start of every epoch.
train_dataset = (
    full_dataset.take(count=n_train_examples)
    .shuffle(buffer_size=5000)
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
valid_dataset = (
    full_dataset.skip(count=n_train_examples)
    .take(count=n_valid_examples)
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)

for epoch_idx in range(first_epoch_idx, first_epoch_idx + EPOCH_NUMS):
    # Each history entry is a (mini_det_loss, model_loss) tuple of epoch means.
    loss_history['train_loss'].append(
        _mean_epoch_losses(
            train_dataset,
            lambda images, targets: train_one_step(model, optimizer, images, targets),
        )
    )
    loss_history['valid_loss'].append(
        _mean_epoch_losses(
            valid_dataset,
            lambda images, targets: validate(model, images, targets),
        )
    )

    # Save parameters after each epoch
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_prefix = os.path.join(checkpoint_dir, f'ckpt_{epoch_idx}')
    checkpoint.save(file_prefix=checkpoint_prefix)

    print(f'''epoch {epoch_idx+1:>2}: \n
          \t train_loss: {loss_history["train_loss"][-1][0]:.4f} {loss_history["train_loss"][-1][1]:.4f},
          \t valid loss: {loss_history["valid_loss"][-1][0]:.4f} {loss_history["valid_loss"][-1][1]:.4f}''')

In [None]:
#model.load_weights('/workspace/models/destr_20')
#weights = model.get_weights()
#model.set_weights(weights)