In [None]:
import os

import cv2
import tensorflow as tf
import numpy as np
import pandas as pd

from model.destr_model import ObjDetSplitTransformer, train_one_step
from utils.data_loader import load_data_tfrecord

In [None]:
# Experiment configuration.
num_cls = 8        # number of classes, passed to ObjDetSplitTransformer as num_cls
EPOCH_NUMS = 10    # number of epochs run by the training loop
BATCH_SIZE = 8     # NOTE(review): not referenced by any visible cell (load_data_tfrecord is not given it) -- confirm where batching is configured

# Checkpoint directory. Override via the DESTR_CHECKPOINT_DIR environment variable
# rather than editing the hard-coded default below.
checkpoint_dir = os.environ.get('DESTR_CHECKPOINT_DIR', '/workspace/models/checkpoints')

load_from_ckpt = False  # when True, restore the latest checkpoint in checkpoint_dir before training

In [None]:
# Wrap ObjDetSplitTransformer in a functional Keras model with three outputs.
# The same input shape is used for the transformer and for the Keras Input.
input_shape = (224, 224, 3)
destr_block = ObjDetSplitTransformer(input_shape=input_shape, num_cls=num_cls)

img = tf.keras.Input(shape=input_shape, dtype=tf.float32)
outputs = destr_block(img)
cls_output, reg_output, total_proposals = outputs

model = tf.keras.Model(
    inputs=img,
    outputs=[cls_output, reg_output, total_proposals],
)

In [None]:
# Optionally restore the model weights from the most recent checkpoint in checkpoint_dir.
if load_from_ckpt:
    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt is None:
        # latest_checkpoint returns None when nothing is found, and
        # Checkpoint.restore(None) restores nothing -- training would silently
        # start from freshly initialised weights. Fail loudly instead.
        raise FileNotFoundError(f'No checkpoint found in {checkpoint_dir!r}')
    checkpoint = tf.train.Checkpoint(model=model)
    status = checkpoint.restore(latest_ckpt)
    print(f'Restored weights from {latest_ckpt}')

In [None]:
# TFRecord data location. `local_dir` is the default. Set the DESTR_TFRECORD_DIR
# environment variable (e.g. to `container_dir` when running inside the
# container) to read from elsewhere without editing this cell.
local_dir = '/media/daniel/DatasetIMDB/imdb_chunks'
container_dir = '/workspace/data/tfrecords'
data_dir = os.environ.get('DESTR_TFRECORD_DIR', local_dir)

dataset = load_data_tfrecord(path_to_tfrecord=data_dir)

In [None]:
# Training loop: one pass over `dataset` per epoch. The two losses returned by
# train_one_step are accumulated, and their per-epoch means go into loss_history.
# Epoch numbering starts at `start_epoch` (10, as in the original hard-coded range),
# presumably to continue a previously checkpointed run -- TODO confirm / tie to restore.
start_epoch = 10
save_ckpt = False  # set True to write a checkpoint to checkpoint_dir after every epoch

loss_history = []  # one (mean mini_det loss, mean model loss) tuple per epoch
# A single optimizer is passed to train_one_step. The former unused
# `optimizers` dict of two extra Adam instances was removed.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00001)
img_shape = (224, 224, 3)  # must match the Input shape of `model`
ckpt_saver = tf.train.Checkpoint(model=model)  # created once, reused by every epoch's save

for epoch_idx in range(start_epoch, start_epoch + EPOCH_NUMS):
    total_md_loss, total_loss, cnt = 0, 0, 0
    for raw_images, coord, label, oh_label in dataset:
        # Images arrive as raw uint8 bytes: decode, cast to float32, restore image shape.
        images = tf.reshape(
            tf.cast(tf.io.decode_raw(raw_images, tf.uint8), tf.float32),
            (-1, *img_shape),
        )
        # Targets are [label, one-hot label, coord] concatenated on the last axis.
        targets = tf.concat([label[..., tf.newaxis], oh_label, coord], axis=-1)

        mini_det_loss, model_loss = train_one_step(model, optimizer, images, targets)
        total_md_loss += mini_det_loss.numpy()
        total_loss += model_loss.numpy()
        cnt += 1

    if cnt == 0:
        # Avoid an opaque ZeroDivisionError when the data path yields nothing.
        raise ValueError('dataset yielded no batches; check the TFRecord path')

    loss_history.append((total_md_loss / cnt, total_loss / cnt))

    # Save model parameters each epoch (disabled by default, as in the original).
    if save_ckpt:
        ckpt_saver.save(file_prefix=os.path.join(checkpoint_dir, f'ckpt_{epoch_idx}'))

    print(f'{epoch_idx+1}: {loss_history[-1]}')

In [None]:
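# Alternative (disabled): load weights from a saved Keras weights file instead of
# a tf.train checkpoint. Note the Keras method is `get_weights` (plural).
#model.load_weights('/workspace/models/destr_20')
#weights = model.get_weights()
#model.set_weights(weights)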