In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import os
from datetime import datetime
from im2latex import build_training_model
from data_loader import get_dataset

# Library versions for reproducibility. The recorded run printed
# TensorFlow 2.8.2, tf.keras 2.8.0 and NumPy 1.21.6 (see output below).
print(tf.__version__, tf.keras.__version__, np.__version__, sep='\n')
tf.executing_eagerly()  # recorded output: True


2.8.2
2.8.0
1.21.6
2.8.2
2.8.0
1.21.6


True

# Load model

In [None]:
# Optional checkpoint passed to build_training_model; presumably None means
# "start from fresh weights" (depends on im2latex.build_training_model — verify).
# NOTE(review): ModelCheckpoint below writes to ./train_checkpoints/, not
# ./checkpoints/ — use the matching directory when resuming a run.
saved_model = None  # e.g. r'./checkpoints/cp-0015.ckpt'
im2latex_model = build_training_model(saved_model)

# Load data

In [None]:
train_batch = 20   # batch size, used for both the training and validation sets
max_seq_len = 150  # upper bound applied by filter_func


def filter_func(x, y):
    """Predicate handed to get_dataset: keep items whose sequence fits.

    Compares the size of the last axis of ``x[0][1]`` against the global
    ``max_seq_len`` (looked up when the predicate is called).
    NOTE(review): presumably x[0][1] is the token sequence fed to the decoder.
    This reads the *static* ``.shape``, so it assumes get_dataset supplies
    arrays/tensors whose last dimension is known — confirm in data_loader.
    """
    seq_len = x[0][1].shape[-1]
    return seq_len <= max_seq_len


# Same image directory, batch size and filter for both splits; only the bucket
# file differs. Built in order: training set first, then validation.
train_dataset, val_dataset = (
    get_dataset(npy_path=bucket_file,
                image_path=r'./images',
                batch_size=train_batch,
                filter_predicate=filter_func)
    for bucket_file in (r'./datasets/train_buckets.npy',
                        r'./datasets/valid_buckets.npy')
)

# Train model

In [None]:
# Save the model weights only (not the full model) once per epoch;
# ModelCheckpoint fills in {epoch:04d} in the file name.
checkpoint_path = r"./train_checkpoints/cp-{epoch:04d}.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=0,
    save_freq='epoch',
)

# TensorBoard logs written every 100 batches (integer update_freq); weight
# histograms, the graph, profiling and embeddings are all switched off.
tbcallback = tf.keras.callbacks.TensorBoard(
    log_dir='./tb_logs',
    histogram_freq=0,
    write_graph=False,
    write_images=True,
    update_freq=100,
    profile_batch=0,
    embeddings_freq=0,
    embeddings_metadata=None,
)


class CustomCallback(keras.callbacks.Callback):
    """Learning-rate control driven by validation perplexity (reduce on plateau).

    * At the start of every epoch the optimizer's learning rate is set to
      ``self.learning_rate``, so that attribute is the source of truth.
    * At the end of every epoch the train / validation loss and perplexity
      (``exp(mean loss)``) are printed. If the validation perplexity, rounded
      to 2 decimals, does not beat ``self.best_perp``, the learning rate for the
      next epoch is multiplied by ``decay_factor``.

    Args:
        learning_rate: learning rate used from the first epoch. When resuming,
            pass (or assign to ``.learning_rate``) the last value used.
        decay_factor: multiplier applied when validation perplexity does not
            improve; the default 0.5 halves the learning rate.
        best_perp: best validation perplexity seen so far. When resuming, pass
            the previous run's value. Defaults to +inf (nothing seen yet).
        **kwargs: forwarded to ``keras.callbacks.Callback``.
    """

    def __init__(self, learning_rate=0.1, decay_factor=0.5, best_perp=np.inf, **kwargs):
        super(CustomCallback, self).__init__(**kwargs)
        self.learning_rate = learning_rate
        self.decay_factor = decay_factor
        # Per-batch ``logs['loss']`` values, kept for inspection only. Keras
        # reports metric results aggregated "up until this batch" here, so
        # these are running means, not individual batch losses, and are NOT
        # used for the epoch summary (the epoch-level logs are used instead).
        self.train_losses = []
        self.val_losses = []
        # Best validation perplexity (rounded to 2 decimals) seen so far.
        # +inf instead of an int32-max sentinel so any finite value improves on it.
        self.best_perp = best_perp
        self.step_count = 0  # training batches seen across all epochs

    def on_epoch_begin(self, epoch, logs=None):
        """Print a timestamp and push ``self.learning_rate`` into the optimizer."""
        print(datetime.now().strftime("\n%d/%m/%Y %H:%M:%S"))
        # NOTE(review): assumes optimizer.learning_rate is a variable supporting
        # .assign() (not a LearningRateSchedule) — confirm in build_training_model.
        self.model.optimizer.learning_rate.assign(self.learning_rate)
        print('lr =', self.model.optimizer.learning_rate.numpy(), 'optimizer =', self.model.optimizer)

    def on_epoch_end(self, epoch, logs=None):
        """Report epoch loss/perplexity and decay the learning rate on a plateau."""
        logs = logs or {}
        # Epoch-level logs hold the mean loss over the whole epoch ('loss') and
        # over the whole validation pass ('val_loss'). Perplexity is exp(mean loss).
        mean_loss_train = logs.get('loss', np.nan)
        mean_perp_train = np.exp(mean_loss_train)
        print("Mean train loss:", mean_loss_train, ",Mean train perplexity:", mean_perp_train)

        mean_loss_val = logs.get('val_loss')
        if mean_loss_val is None:
            # No validation pass this epoch: nothing to compare against, so the
            # learning rate is left alone (previously nan forced a decay).
            print("No validation loss this epoch; learning rate unchanged")
        else:
            mean_perp_val = np.round(np.exp(mean_loss_val), 2)
            print("Mean val loss:", mean_loss_val, ",Mean val perplexity:", mean_perp_val)
            if mean_perp_val < self.best_perp:
                self.best_perp = mean_perp_val
            else:
                self.learning_rate = self.model.optimizer.learning_rate.numpy() * self.decay_factor
                print("learning rate reduced to", self.learning_rate)
        print("Best perplexity:", self.best_perp)
        self.train_losses = []
        self.val_losses = []

    def on_train_batch_end(self, batch, logs=None):
        """Record the running training loss and count the step."""
        if logs and 'loss' in logs:
            self.train_losses.append(logs['loss'])
        self.step_count += 1

    def on_test_batch_end(self, batch, logs=None):
        """Record the running validation loss."""
        if logs and 'loss' in logs:
            self.val_losses.append(logs['loss'])


custom_callback = CustomCallback()

In [None]:
# Resuming an interrupted run? Before re-running the training cell:
#   1. set custom_callback.learning_rate to the last learning rate used
#      (it overwrites the optimizer's rate at the start of every epoch);
#   2. set initial_epoch below to the first epoch still to train, and check epochs;
#   3. set custom_callback.best_perp to the best validation perplexity so far;
#   4. point saved_model at the most recent checkpoint file.
initial_epoch = 0

In [None]:
epochs = 15  # epochs to train per execution of this cell
im2latex_model.fit(
    train_dataset,
    steps_per_epoch=None,   # None -> iterate until the dataset is exhausted
    # With initial_epoch, Keras treats `epochs` as the final epoch index, not a count.
    epochs=initial_epoch + epochs,
    initial_epoch=initial_epoch,
    validation_data=val_dataset,
    validation_steps=None,  # None -> iterate until the dataset is exhausted
    callbacks=[custom_callback, cp_callback, tbcallback],
)
# Advance the counter so the next execution of this cell trains the following epochs.
initial_epoch += epochs