In [3]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Conv2D, Input, MaxPool2D, BatchNormalization, LSTM, concatenate, Softmax, RNN, ReLU, Dense
from keras.layers import Lambda
import numpy as np
import random
import os
import cv2
from datetime import datetime

# Versions noted when this notebook was written: tensorflow 2.4.1, tf.keras 2.4.0, numpy 1.19.5
for _lib_version in (tf.__version__, tf.keras.__version__, np.__version__):
    print(_lib_version)

# Eager execution is expected to be enabled (displays True).
tf.executing_eagerly()

In [213]:
# Input image geometry: height, width, channels.
H, W, C = 512, 512, 1
# Token vocabulary size and embedding width.
vocab_size, embedding_dim = 504, 80
ENC_DIM = 256  # hidden state size of each encoder RNN direction
DEC_DIM = 512  # hidden state size of the decoder RNN

# **Define all layers in the model**

In [214]:
def _conv3x3_relu(filters):
    '''3x3, 'same'-padded, ReLU-activated Conv2D with `filters` output channels.'''
    return Conv2D(filters=filters, kernel_size=[3, 3], padding='same', activation='relu')


def _pool(window):
    '''MaxPool2D whose stride equals its window, so pooling windows do not overlap.'''
    return MaxPool2D(pool_size=window, strides=window)


# Encoder CNN layers, created in the same order build_model applies them.
layers = {
    'conv1': _conv3x3_relu(64),
    'maxpool1': _pool([2, 2]),
    'conv2': _conv3x3_relu(128),
    'maxpool2': _pool([2, 2]),
    'conv3': _conv3x3_relu(256),
    'bn1': BatchNormalization(),
    'conv4': _conv3x3_relu(256),
    'maxpool3': _pool([1, 2]),   # halves width only
    'conv5': _conv3x3_relu(512),
    'bn2': BatchNormalization(),
    'maxpool4': _pool([2, 1]),   # halves height only
    'conv6': _conv3x3_relu(512),
    'bn3': BatchNormalization(),
}


class EncoderCell(keras.layers.Layer):
    '''
    RNN cell that feeds one slice of the conv feature map (one index along dim 1,
    i.e. one row) per time step to the wrapped `encoder` layer.

    The recurrent state is not used: it is handed back unchanged, so `state_size`
    only exists to satisfy keras.layers.RNN.
    '''
    def __init__(self, encoder, state_size, output_size, **kwargs):
        # Attributes are set before Layer.__init__, as in the original implementation.
        self.encoder = encoder
        self.state_size = state_size
        self.output_size = output_size
        super().__init__(**kwargs)

    def build(self, input_shape):
        # The cell owns no weights of its own.
        self.built = True

    def call(self, inputs, states):
        row_features = self.encoder(inputs)
        return row_features, states


def _make_encoder_cell(**lstm_kwargs):
    '''EncoderCell around a sequence-returning LSTM(ENC_DIM); its 1-wide state is a placeholder.'''
    return EncoderCell(LSTM(ENC_DIM, return_sequences=True, **lstm_kwargs),
                       state_size=tf.TensorShape([1]),
                       output_size=tf.TensorShape([None, ENC_DIM]))


encoder_fw_cell = _make_encoder_cell()
# go_backwards=True: the inner LSTM processes each row's sequence in reverse.
encoder_bw_cell = _make_encoder_cell(go_backwards=True)

for _direction, _cell in (('encoder_fw', encoder_fw_cell), ('encoder_bw', encoder_bw_cell)):
    layers[_direction] = RNN(_cell, return_sequences=True)

class AttentionCell(keras.layers.Layer):
    '''
    Attention decoder cell used inside keras.layers.RNN (model of https://arxiv.org/abs/1609.04938).

    Per time step it:
      1. runs an LSTM on [token embedding, previous attentional output],
      2. scores every encoder position j with score_j = hs_j . (Wa h_t),
      3. forms the context c_t = sum_j softmax(score)_j * hs_j,
      4. computes the attentional output ht_bar = tanh(Wc [c_t; h_t]) and logits = ht_bar Ws.

    Constructor args (as used at the construction site in this notebook):
        input_embedding_size: TensorShape([embedding_dim]) of the per-step input.
        decoder_out_shape: TensorShape([None, DEC_DIM]); index 1 is the decoder LSTM size.
        state_size: [TensorShape([1, 3*DEC_DIM]), TensorShape([None, enc_width])]:
            the packed (h, c, attentional output) decoder state, and the encoder hidden
            states, which are passed through unchanged on every step.
        output_size: TensorShape([vocab_size]); each step returns (batch, vocab_size) logits.
    '''
    def __init__(self, input_embedding_size, decoder_out_shape, state_size, output_size, **kwargs):
        self.input_embedding_size = input_embedding_size
        self.decoder_out_shape = decoder_out_shape
        self.state_size = state_size
        self.output_size = output_size  # vocab_size
        super(AttentionCell, self).__init__(**kwargs)

    def build(self, input_shape):
        emb_dim = self.input_embedding_size[0]
        dec_dim = self.decoder_out_shape[1]
        # Width of the encoder hidden states (2*ENC_DIM here). The original hard-coded
        # DEC_DIM for it, which only worked because 2*ENC_DIM == DEC_DIM; shapes are
        # unchanged for that configuration, so existing checkpoints still load.
        enc_dim = tf.TensorShape(self.state_size[1])[-1]

        # A fresh initializer per weight, as in the original.
        self.gates = self.add_weight(shape=(emb_dim + dec_dim, 4 * dec_dim),
                                     initializer=tf.keras.initializers.GlorotUniform(),
                                     trainable=True,
                                     name='gates')  # LSTM i, f, o, g gates stacked
        self.gates_bias = self.add_weight(shape=(1, 4 * dec_dim),
                                          initializer='zeros',
                                          trainable=True,
                                          name='gates_bias')
        self.Wa = self.add_weight(shape=(enc_dim, dec_dim),
                                  initializer=tf.keras.initializers.GlorotUniform(),
                                  trainable=True,
                                  name='Wa')  # maps h_t into encoder-state space for scoring
        self.Wc = self.add_weight(shape=(enc_dim + dec_dim, dec_dim),
                                  initializer=tf.keras.initializers.GlorotUniform(),
                                  trainable=True,
                                  name='Wc')  # [context; h_t] -> attentional output
        self.Ws = self.add_weight(shape=(dec_dim, self.output_size[0]),
                                  initializer=tf.keras.initializers.GlorotUniform(),
                                  trainable=True,
                                  name='Ws')  # attentional output -> vocab logits
        self.built = True

    def call(self, inputs, states):
        hs = states[1]  # (batch, L, enc_dim) encoder hidden states
        # states[0] is (batch, 1, 3*dec_dim) per state_size[0]. Drop the singleton axis
        # so every per-step tensor is (batch, ...). Keeping it made c_t broadcast to
        # (batch, batch, dec_dim) and broke any batch size > 1.
        packed_state = tf.squeeze(states[0], axis=1)
        _, c_tm1, output_tm1 = tf.split(packed_state, num_or_size_splits=3, axis=-1)

        # Decoder LSTM step on [embedding, previous attentional output].
        xt = tf.concat([inputs, output_tm1], axis=-1)
        gates_out = tf.linalg.matmul(xt, self.gates) + self.gates_bias
        i_t, f_t, o_t, g_t = tf.split(gates_out, num_or_size_splits=4, axis=-1)
        c_t = tf.math.sigmoid(f_t) * c_tm1 + tf.math.sigmoid(i_t) * tf.math.tanh(g_t)
        h_t = tf.math.sigmoid(o_t) * tf.math.tanh(c_t)  # (batch, dec_dim)

        # score_j = hs_j . (Wa h_t); in row-vector form Wa h_t is h_t Wa^T.
        wa_ht = tf.linalg.matmul(h_t, self.Wa, transpose_b=True)              # (batch, enc_dim)
        score = tf.squeeze(tf.linalg.matmul(hs, tf.expand_dims(wa_ht, -1)), -1)  # (batch, L)
        # tf.nn.softmax instead of constructing a new Softmax layer on every step.
        at = tf.nn.softmax(score, axis=-1)
        ct = tf.squeeze(tf.linalg.matmul(tf.expand_dims(at, 1), hs), 1)      # (batch, enc_dim)

        ht_bar = tf.math.tanh(tf.linalg.matmul(tf.concat([ct, h_t], axis=-1), self.Wc))  # (batch, dec_dim)
        output = tf.linalg.matmul(ht_bar, self.Ws)  # (batch, vocab) logits, matching output_size

        # Re-pack as (batch, 1, 3*dec_dim) to match state_size[0].
        new_state = tf.expand_dims(tf.concat([h_t, c_t, ht_bar], axis=-1), axis=1)
        return output, [new_state, hs]
    

# Token embedding, initialised from a normal distribution with stddev 1/sqrt(vocab_size).
_embedding_init = tf.keras.initializers.RandomNormal(stddev=1.0/np.sqrt(vocab_size))
layers['embedding'] = tf.keras.layers.Embedding(input_dim=vocab_size,
                                                output_dim=embedding_dim,
                                                embeddings_initializer=_embedding_init)

# Cell state = [packed (h, c, attentional output) of width 3*DEC_DIM,
#               encoder hidden states of width 2*ENC_DIM].
_attention_state_size = [tf.TensorShape([1, DEC_DIM*3]), tf.TensorShape([None, ENC_DIM*2])]
layers['attention_cell'] = AttentionCell(
    input_embedding_size=tf.TensorShape([embedding_dim]),
    decoder_out_shape=tf.TensorShape([None, DEC_DIM]),
    state_size=_attention_state_size,
    output_size=tf.TensorShape([vocab_size]),
)

layers['attention_layer'] = RNN(layers['attention_cell'], return_sequences=True)

# Build the model with the layers

In [215]:
def _flatten_grid(t):
    '''Collapse the two grid axes of a (batch, h, w, d) tensor into one (batch, h*w, d) sequence.'''
    return tf.reshape(t, [tf.shape(t)[0], -1, tf.shape(t)[-1]])


def build_model(image, latex_seq, decoder_initial_state, encoder_hid_st_input=None):
    '''
    Wire the shared layers in `layers` into Keras models.

    Args:
        image: Keras Input for the image batch, (batch, H, W, C).
        latex_seq: Keras Input with the token ids fed to the decoder.
        decoder_initial_state: Keras Input for the attention cell's first packed state,
            expected as (batch, 1, 3*DEC_DIM) per the cell's state_size.
        encoder_hid_st_input: optional Keras Input with precomputed encoder hidden states,
            (batch, L, 2*ENC_DIM).

    Returns:
        If encoder_hid_st_input is None: one end-to-end model
            [image, latex_seq, decoder_initial_state] -> logits.
        Otherwise: (encoder_model, decoder_model), where encoder_model maps image -> encoder
            hidden states, and decoder_model maps
            [latex_seq, decoder_initial_state, encoder_hid_st_input] -> logits.
    '''
    # encoder: scale pixels from [0, 255] to [-1, 1)
    img = (image - 128) / 128

    x = layers['conv1'](img)
    x = layers['maxpool1'](x)
    # x -> (H/2, W/2, 64)

    x = layers['conv2'](x)
    x = layers['maxpool2'](x)
    # x -> (H/4, W/4, 128)

    x = layers['conv3'](x)
    x = layers['bn1'](x)
    # x -> (H/4, W/4, 256)

    x = layers['conv4'](x)
    x = layers['maxpool3'](x)
    # x -> (H/4, W/8, 256)

    x = layers['conv5'](x)
    x = layers['bn2'](x)
    x = layers['maxpool4'](x)
    # x -> (H/8, W/8, 512)

    x = layers['conv6'](x)
    x = layers['bn3'](x)
    # x -> (H/8, W/8, 512)

    encoder_fw_hid_st = layers['encoder_fw'](x)
    # The backward encoder's inner LSTM uses go_backwards=True, which returns its output
    # sequence in reversed order. Flip it back along the width axis so forward and
    # backward features are concatenated for the same grid position.
    encoder_bw_hid_st = tf.reverse(layers['encoder_bw'](x), axis=[2])

    encoder_hid_st = concatenate([_flatten_grid(encoder_fw_hid_st),
                                  _flatten_grid(encoder_bw_hid_st)], axis=-1)

    # decoder
    latex_emb = layers['embedding'](latex_seq)
    if encoder_hid_st_input is None:
        logits = layers['attention_layer'](latex_emb,
                                           initial_state=[decoder_initial_state, encoder_hid_st])
        return keras.Model(inputs=[image, latex_seq, decoder_initial_state], outputs=logits)

    # Inference: the attention cell needs both its decoder state and the encoder states as
    # RNN state. The previous code used a non-existent layers['decoder'] and dropped
    # decoder_initial_state.
    latex_pred = layers['attention_layer'](latex_emb,
                                           initial_state=[decoder_initial_state, encoder_hid_st_input])
    encoder_model = keras.Model(inputs=image, outputs=encoder_hid_st)
    decoder_model = keras.Model(inputs=[latex_seq, decoder_initial_state, encoder_hid_st_input],
                                outputs=latex_pred)
    return encoder_model, decoder_model


In [216]:
# Interactive switch: answer 'y' to resume from the saved checkpoint.
load_model = input('Load model? (y/n):') == 'y'
stage2_training_model = build_model(Input(shape=(None, None, C)),
                                    Input(shape=tf.TensorShape([None])),
                                    Input(shape=(None, DEC_DIM*3)),
                                    None)
if load_model:
    stage2_training_model.load_weights('model_checkpoints_1/cp-0001.ckpt')
    print("x----------Model loaded----------x")

lr = 0.1  # NOTE(review): far above Adam's Keras default of 0.001 — confirm this is intended
clipnorm = 5.0
optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=clipnorm)
# The model emits raw logits (no softmax after Ws), so the loss and the crossentropy
# metric both need from_logits=True. The bare 'crossentropy' string alias resolves to a
# crossentropy function that expects probabilities and reports misleading values here.
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
crossentropy_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True, name='crossentropy')
stage2_training_model.compile(optimizer=optimizer, loss=loss_func, metrics=['accuracy', crossentropy_metric])
stage2_training_model.optimizer.learning_rate.assign(lr)

print('lr:', stage2_training_model.optimizer.learning_rate.numpy())
print('clipnorm:', stage2_training_model.optimizer.clipnorm)
print('optimizer:', stage2_training_model.optimizer)

# Inference pair returned by build_model when encoder_hid_st_input is supplied.
# (The previous version of this call was missing the comma after the third Input.)
stage2_inference_encoder_model, stage2_inference_decoder_model = build_model(
    Input(shape=(H, W, C), batch_size=1),
    Input(shape=tf.TensorShape([5]), batch_size=1),
    Input(shape=(None, DEC_DIM*3)),
    Input(shape=(None, ENC_DIM*2), batch_size=1))  # ENC_DIM*2 == 512 encoder features
stage2_inference_encoder_model.compile()
stage2_inference_decoder_model.compile()
print("x----------Model built----------x")

Load model? (y/n):n
lr: 0.1
clipnorm: 5.0
optimizer: <keras.optimizer_v2.adam.Adam object at 0x000002231A988D30>
x----------Model built----------x


In [217]:
# stage2_training_model.summary()

# Load data

In [4]:
train_batch = 1

# Feature spec for the serialized tf.train.Example records: JPEG bytes plus the
# input / target token sequences, each parsed as a variable-length list.
tfr_description = {
    feature_name: tf.io.FixedLenSequenceFeature([], feature_dtype, allow_missing=True)
    for feature_name, feature_dtype in (
        ('image', tf.string),
        ('latex_seq_in', tf.int64),
        ('latex_seq_out', tf.int64),
    )
}


def _parse_image_function(example_proto):
    '''Parse one serialized tf.train.Example into a dict of tensors keyed like tfr_description.'''
    parsed = tf.io.parse_single_example(example_proto, tfr_description)
    return parsed


def filter_long_seqs(sample_input, sample_output):
    '''Dataset predicate: keep only samples whose target sequence is shorter than 120 tokens.'''
    target_len = tf.shape(sample_output)[-1]
    return target_len < 120

def get_data(tfr_dir, subset, batch_size, filter_func=None):
    '''
    Build a padded, batched tf.data pipeline over the '{subset}_100K_??of??.tfrecord' shards.

    Each element is ((image, latex_seq_in), latex_seq_out):
        image          float32 decoded JPEG with C channels, padded with 255
        latex_seq_in   float32 token ids, padded with 0
        latex_seq_out  int64 token ids, padded with 0

    Args:
        tfr_dir: directory containing the TFRecord shards.
        subset: shard prefix, e.g. 'train', 'val' or 'test'.
        batch_size: number of samples per padded batch.
        filter_func: optional predicate (inputs, latex_seq_out) -> bool, applied before batching.
    '''
    def _to_model_io(sample):
        # Force C channels so every image matches the model's (None, None, C) input, even
        # if a shard holds an RGB JPEG (decode_jpeg's default keeps the file's channel count).
        image = tf.image.decode_jpeg(sample['image'][0], channels=C)
        latex_seq_in = tf.cast(sample['latex_seq_in'], dtype=tf.float32)
        return (tf.cast(image, dtype=tf.float32), latex_seq_in), sample['latex_seq_out']

    pattern = os.path.join(tfr_dir, '{}_100K_??of??.tfrecord'.format(subset))
    dataset = tf.data.TFRecordDataset(tf.io.matching_files(pattern))
    dataset = dataset.map(_parse_image_function).map(_to_model_io)
    if filter_func is not None:
        dataset = dataset.filter(filter_func)

    padding_values = ((np.array(255, dtype=np.float32), np.array(0, dtype=np.float32)),
                      np.array(0, dtype=np.int64))
    return dataset.padded_batch(batch_size, padding_values=padding_values)

# os.path.join yields a separator valid on the current OS; the former r'.\100K_tfrecords_v3'
# only resolved on Windows.
TFR_DIR = os.path.join('.', '100K_tfrecords_v3')


def _add_zero_decoder_state(inputs, latex_seq_out):
    '''
    Append an all-zero initial decoder state to a batched ((image, latex_seq_in), latex_seq_out)
    element.

    The training model takes three inputs (image, latex_seq, decoder_initial_state), while
    get_data yields only the first two, so fit() could not feed the model. The state is
    shaped (batch, 1, DEC_DIM*3) to match the attention cell's state_size.
    '''
    image, latex_seq_in = inputs[0], inputs[1]
    zero_state = tf.zeros([tf.shape(image)[0], 1, DEC_DIM * 3], dtype=tf.float32)
    return (image, latex_seq_in, zero_state), latex_seq_out


train_dataset = get_data(TFR_DIR, 'train', batch_size=train_batch,
                         filter_func=filter_long_seqs).map(_add_zero_decoder_state)
val_dataset = get_data(TFR_DIR, 'val', batch_size=train_batch).map(_add_zero_decoder_state)
test_dataset = get_data(TFR_DIR, 'test', batch_size=train_batch).map(_add_zero_decoder_state)

# Check loaded data

In [219]:
# Step through validation images one at a time: any key shows the next one, 'q' quits.
for batch in val_dataset:
    model_inputs, latex_seq_out = batch[0], batch[1]
    img, latex_seq_in = model_inputs[0], model_inputs[1]
    frame = img.numpy()[0].astype(np.uint8)
    cv2.imshow('img', frame)
    key_press = cv2.waitKey(0)
    if key_press & 0xFF == ord('q'):
        cv2.destroyAllWindows()
        break

In [220]:
# Close any OpenCV window left open when the preview loop above ends without 'q' being pressed.
cv2.destroyAllWindows()

# Train model

In [22]:
initial_epoch = 0

# Weights-only checkpoints, written every 100 training batches (an int save_freq counts batches).
checkpoint_path = r"./model_checkpoints/cp-{epoch:04d}.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=0,
    save_weights_only=True,
    save_freq=100,
)

# TensorBoard logging every 100 batches; no graph, histograms, profiling or embeddings.
tbcallback = tf.keras.callbacks.TensorBoard(
    log_dir='./tb_logs',
    update_freq=100,
    write_graph=False,
    write_images=True,
    histogram_freq=0,
    profile_batch=0,
    embeddings_freq=0,
    embeddings_metadata=None,
)


class CustomCallback(keras.callbacks.Callback):
    '''
    Training monitor: prints epoch-level loss and perplexity, and halves the optimizer's
    learning rate whenever validation perplexity fails to improve on the best value so far.

    Attributes:
        train_losses / val_losses: 'loss' values from batch-level logs in the current epoch,
            reset at every epoch end.
        best_perp: best validation perplexity seen so far.
        step_count: number of training batches seen.
    '''

    def __init__(self, **kwargs):
        self.train_losses = []
        self.val_losses = []
        self.best_perp = np.inf  # no epoch evaluated yet
        self.step_count = 0
        super(CustomCallback, self).__init__(**kwargs)

    @staticmethod
    def _epoch_loss(logs, key, batch_losses):
        '''Epoch mean of `key` from the epoch logs, else the mean of collected batch values, else None.'''
        if key in logs:
            return float(logs[key])
        if batch_losses:
            return float(np.mean(batch_losses))
        return None

    def sum_of_gradients(self):
        '''Debug helper: summed gradient of the loss w.r.t. a random 512x512 input image.'''
        image = np.abs(np.random.normal(size=(1, 512, 512, 1)))
        img = tf.convert_to_tensor(image, dtype=tf.float32)
        latex_seq = tf.ones([1, 10], dtype=tf.float32)      # dummy decoder input tokens (id 1)
        target = tf.ones([1, 10], dtype=tf.int64)           # dummy target tokens (id 1)
        # The model takes three inputs; the previous version passed only two.
        init_state = tf.zeros([1, 1, DEC_DIM * 3], dtype=tf.float32)
        # The model outputs logits, so the loss must be computed from logits.
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        with tf.GradientTape() as tape:
            tape.watch(img)
            logits = self.model([img, latex_seq, init_state])
            loss = loss_fn(target, logits)
        grads = tape.gradient(loss, img, unconnected_gradients=tf.UnconnectedGradients.NONE)
        return tf.reduce_sum(grads).numpy()

    def on_epoch_begin(self, epoch, logs=None):
        print(datetime.now().strftime("%d/%m/%Y %H:%M:%S"))
        print('lr =', self.model.optimizer.learning_rate.numpy(), 'optimizer =', self.model.optimizer)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # In TF2 Keras the batch-level 'loss' is a running mean since the start of the epoch,
        # so averaging those values over-weights early batches. The epoch-level logs carry the
        # exact epoch means ('loss' / 'val_loss'); fall back to the batch values only if absent.
        mean_loss_train = self._epoch_loss(logs, 'loss', self.train_losses)
        mean_loss_val = self._epoch_loss(logs, 'val_loss', self.val_losses)

        if mean_loss_train is not None:
            print("Mean train loss:", mean_loss_train, ",Mean train perplexity:", np.exp(mean_loss_train))

        if mean_loss_val is None:
            # Previously np.mean([]) produced NaN here, which always triggered an lr decay.
            print("No validation loss this epoch; learning rate unchanged")
        else:
            mean_perp_val = np.exp(mean_loss_val)
            print("Mean val loss:", mean_loss_val, ",Mean val perplexity:", mean_perp_val)
            if mean_perp_val < self.best_perp:
                self.best_perp = mean_perp_val
            else:
                lr_var = self.model.optimizer.learning_rate
                lr_var.assign(lr_var.numpy() / 2)
        print("Best perplexity:", self.best_perp)
        self.train_losses = []
        self.val_losses = []

    def on_train_batch_end(self, batch, logs=None):
        if logs and 'loss' in logs:
            self.train_losses.append(logs['loss'])
        self.step_count += 1

    def on_test_batch_end(self, batch, logs=None):
        if logs and 'loss' in logs:
            self.val_losses.append(logs['loss'])

custom_callback = CustomCallback()

In [1]:
epochs = 12
# steps_per_epoch / validation_steps = None: iterate until each dataset is exhausted.
stage2_training_model.fit(
    train_dataset,
    validation_data=val_dataset,
    initial_epoch=initial_epoch,
    epochs=initial_epoch + epochs,
    steps_per_epoch=None,
    validation_steps=None,
    callbacks=[cp_callback, tbcallback, custom_callback],
)
initial_epoch += epochs

In [2]:
def bleu_metric(y_true, y_pred):
    '''
    Corpus BLEU between target token sequences and the argmax of the model's predictions.

    Args:
        y_true: target tokens, either sparse ids (batch, T) or one-hot (batch, T, vocab).
        y_pred: logits or probabilities, (batch, T, ..., vocab); any singleton axes between
            T and vocab are flattened away.
    Returns:
        The value returned by corpus_bleu (one reference per hypothesis).
    '''
    # NOTE(review): `corpus_bleu` is not imported in this notebook; it presumably comes from
    # nltk.translate.bleu_score and must be imported before this function is called.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Dataset targets are sparse ids. Calling argmax on them, as the previous version did,
    # returned a time index instead of tokens. Only decode genuinely one-hot targets.
    if y_true.ndim > 2 and y_true.shape[-1] == y_pred.shape[-1]:
        y_true = np.argmax(y_true, axis=-1)

    pred_ids = np.argmax(y_pred, axis=-1).reshape(len(y_pred), -1)
    references = [[seq] for seq in y_true.reshape(len(y_true), -1).tolist()]
    return corpus_bleu(references, pred_ids.tolist())
    

# Teacher-forced BLEU on the test split: the model sees the ground-truth prefix
# (latex_seq_in), so this is an optimistic proxy for free-running decoding quality.
# Iterating the dataset directly replaces the former `test_size`-based loop
# (`test_size` is not defined in the cells shown).
bleu_list = []
for batch in test_dataset:
    test_images, test_seq_in = batch[0][0], batch[0][1]
    # The training model's third input is the initial decoder state, (batch, 1, 3*DEC_DIM).
    # The previous predict(batch[0]) call omitted it.
    zero_state = tf.zeros([tf.shape(test_images)[0], 1, DEC_DIM * 3])
    prediction = stage2_training_model.predict_on_batch((test_images, test_seq_in, zero_state))
    bleu_list.append(bleu_metric(batch[1].numpy(), prediction))

bleu = np.mean(bleu_list)
print(bleu)

# Save the inference models

In [None]:
# Export the image encoder as a TensorFlow SavedModel directory.
tf.saved_model.save(obj=stage2_inference_encoder_model, export_dir="stage2_inference_encoder_model")

In [None]:
# Convert the decoder inference model to a TFLite flatbuffer and write it to disk.
converter = tf.lite.TFLiteConverter.from_keras_model(stage2_inference_decoder_model)
stage2_inference_decoder_model_tflite = converter.convert()
with open('stage2_inference_decoder_model_tflite.tflite', mode='wb') as tflite_file:
    tflite_file.write(stage2_inference_decoder_model_tflite)