In [2]:
import os
import json
import gc

import tensorflow as tf

from .model.destr_model import ObjDetSplitTransformer, train_one_step, validate
from .utils.data_loader import load_data_tfrecord

ImportError: attempted relative import with no known parent package

In [None]:
with open('./config.json', 'r') as fin:
    config = json.load(fin)

In [None]:
#tf.keras.mixed_precision.set_global_policy("mixed_float16")

In [None]:
destr_block = ObjDetSplitTransformer(input_shape=config['model']['input_shape'], num_cls=config['model']['num_class'])

img = tf.keras.Input(shape=config['model']['input_shape'], dtype=tf.float32)
cls_output, reg_output, total_proposals = destr_block(img)

model = tf.keras.Model(inputs=img, outputs=[cls_output, reg_output, total_proposals])

In [None]:
# Load the weights from latest checkpoint

checkpoint = tf.train.Checkpoint(model=model)
checkpoint_manager = tf.train.CheckpointManager(checkpoint=checkpoint, directory=config['paths']['to_checkpoint'], max_to_keep=1)

load_from_ckpt = False
if load_from_ckpt:
    status = checkpoint.restore(checkpoint_manager.latest_checkpoint)

In [None]:
loss_history = {'train_loss': (0, 0), 'valid_loss': (0, 0)}
optimizer = tf.keras.optimizers.Adam(learning_rate=config['train']['learning_rate'])
full_dataset = load_data_tfrecord(path_to_tfrecord=config['paths']['to_dataset_local'])

train_progress_bar = tf.keras.utils.Progbar(config['train']['num_train_samples'])
valid_progress_bar = tf.keras.utils.Progbar(config['train']['num_valid_samples'])

for epoch_idx in range(config['train']['num_epochs']):
    dataset = full_dataset.shuffle(buffer_size=config['train']['shuffle_buffer']).batch(batch_size=config['train']['batch_size'], drop_remainder=True)

    train_dataset = dataset.take(count=config['train']['num_train_samples']).prefetch(buffer_size=tf.data.AUTOTUNE)
    valid_dataset = dataset.skip(count=config['train']['num_train_samples']).prefetch(buffer_size=tf.data.AUTOTUNE)

    total_md_loss, total_loss, step = 0, 0, 0
    for batch in train_dataset:
        logits, coord, label, oh_label = batch
        
        mini_det_loss, model_loss = train_one_step(
                model, optimizer, 
                tf.reshape(tf.cast(tf.io.decode_raw(logits, tf.uint8), tf.float32), shape=[-1]+config['model']['input_shape']), 
                tf.concat([label[..., tf.newaxis], oh_label, coord], axis=-1)
            )
        total_md_loss += mini_det_loss.numpy()
        total_loss += model_loss.numpy()

        step += 1
        train_progress_bar.update(step)
        if step == config['train']['num_train_samples']:
            avg_mini_det_loss = total_md_loss / config['train']['num_train_samples']
            avg_model_loss = total_loss / config['train']['num_train_samples']
            break
    loss_history['train_loss'] = (avg_mini_det_loss, avg_model_loss)

    del dataset, train_dataset
    gc.collect()

    total_md_loss, total_loss, step = 0, 0, 0
    for batch in valid_dataset:
        logits, coord, label, oh_label = batch
        
        mini_det_loss, model_loss = validate(
                model, 
                tf.reshape(tf.cast(tf.io.decode_raw(logits, tf.uint8), tf.float32), shape=[-1] + config['model']['input_shape']), 
                tf.concat([label[..., tf.newaxis], oh_label, coord], axis=-1)
            )
        total_md_loss += mini_det_loss.numpy()
        total_loss += model_loss.numpy()

        step += 1
        valid_progress_bar.update(step)
        if step == config['train']['num_valid_samples']:
            avg_mini_det_loss = total_md_loss / config['train']['num_valid_samples']
            avg_model_loss = total_loss / config['train']['num_valid_samples']
            break
    loss_history['valid_loss'] = (avg_mini_det_loss, avg_model_loss)
    
    # Save parameters after each epoch
    checkpoint_manager.save()
    
    del valid_dataset
    tf.keras.backend.clear_session()
    gc.collect()

    print(f'''epoch {epoch_idx+1:>2}: \n
          \t train_loss: {loss_history["train_loss"][0]:.4f} {loss_history["train_loss"][1]:.4f},
          \t valid loss: {loss_history["valid_loss"][0]:.4f} {loss_history["valid_loss"][1]:.4f}''')
    with open(config['paths']['to_loss_records'], mode='w') as fout:
        print(f'''epoch {epoch_idx+1:>2}: \n
          \t train_loss: {loss_history["train_loss"][0]:.4f} {loss_history["train_loss"][1]:.4f},
          \t valid loss: {loss_history["valid_loss"][0]:.4f} {loss_history["valid_loss"][1]:.4f}''', file=fout)

In [None]:
#model.load_weights('/workspace/models/destr_20')
#weights = model.get_weight()
#model.set_weights(weights)