In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import pandas as pd
from components.positional import add_timing_signal_nd

In [25]:
def embedding_initializer(shape, dtype):
    """Return a random tensor whose slices along the last axis have unit L2 norm.

    Args:
        shape: shape of the tensor to create.
        dtype: floating dtype forwarded to `tf.random.uniform`.

    Returns:
        A tensor of the requested shape, sampled uniformly from [-1, 1) and
        then L2-normalized along the last axis.
    """
    samples = tf.random.uniform(shape, minval=-1.0, maxval=1.0, dtype=dtype)
    return tf.nn.l2_normalize(samples, axis=-1)

In [22]:
class BahdanauAttention(tf.keras.Model):
    """Additive attention over a set of encoder feature vectors.

    Holds two `units`-wide dense projections (one for the features, one for
    the decoder state) and a 1-unit dense layer that turns their combined
    activation into a scalar score per feature location.
    """

    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        """Attend over `features` using `hidden` as the query.

        Args:
            features: encoder output; per the original author's note this is
                (batch_size, 64, embedding_dim) - TODO confirm against the encoder.
            hidden: decoder state; a length-1 axis is inserted at position 1
                so it broadcasts against the locations axis of `features`.

        Returns:
            (context_vector, attention_weights): `features` weighted by the
            attention weights and summed over axis 1, and the weights
            themselves (softmax over axis 1).
        """
        query = tf.expand_dims(hidden, 1)

        # Project features first, then the query, so the sub-layers are
        # built in the same order as before.
        projected_features = self.W1(features)
        projected_query = self.W2(query)
        energies = self.V(tf.nn.tanh(projected_features + projected_query))

        # Normalize the scores across axis 1 (the locations axis).
        attention_weights = tf.nn.softmax(energies, axis=1)
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)

        return context_vector, attention_weights

In [32]:
class Decoder(tf.keras.Model):
    """Attention-based LSTM decoder that scores one output token per call.

    Each call attends over the encoder features, embeds the previous token,
    runs a single `LSTMCell` step on the concatenation of the two, and
    projects the cell output to vocabulary logits.
    """

    def __init__(self, embedding_dim, units, vocab_size, id_end):
        """Create the decoder's sub-layers.

        Args:
            embedding_dim: width of the token embeddings.
            units: LSTM state width, also used for the attention projections
                and the first dense layer.
            vocab_size: number of token ids; width of the output logits.
            id_end: id of the end-of-sequence token (comes from Vocab).
        """
        super(Decoder, self).__init__()
        self.units = units
        # comes from Vocab
        self._id_end = id_end

        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm = tf.keras.layers.LSTMCell(self.units)

        self.fc1 = tf.keras.layers.Dense(self.units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

        self.attention = BahdanauAttention(self.units)

        # Trainable start-of-sequence embedding. `call` embeds the token ids
        # it is given and does not read this variable.
        self.start_token = tf.Variable(
            initial_value=embedding_initializer([embedding_dim], dtype=tf.float32),
            dtype=tf.float32, shape=[embedding_dim])

        # Kept as an attribute for inference experiments; `call` no longer
        # routes through it.
        # NOTE(review): as configured the beam search would rank candidates
        # over fc1's `units`-wide output rather than vocabulary logits, and no
        # embedding_fn is set - confirm/rework before using it for decoding.
        self.decoder = tfa.seq2seq.BeamSearchDecoder(
            self.lstm, beam_width=3, output_layer=self.fc1)

    def call(self, x, features, hidden):
        """Run one decoding step.

        The previous implementation invoked the BeamSearchDecoder without its
        required `end_token` and `initial_state` arguments (and with a float
        embedding as `start_tokens`), so it could not execute. This version
        performs a single LSTM-cell step instead.

        Args:
            x: previous token ids, expected as (batch_size, 1) so that the
                embedded input has a length-1 time axis.
            features: encoder output passed to the attention layer.
            hidden: either the previous LSTM output `h` as a single tensor
                (e.g. the result of `reset_state`) or the `[h, c]` pair
                returned as `state` by a previous call. When only `h` is
                given, the cell state `c` starts at zeros.

        Returns:
            (logits, state, attention_weights): vocabulary logits for this
            step, the new `[h, c]` LSTM state, and the attention weights.
        """
        if isinstance(hidden, (list, tuple)):
            h_state, c_state = hidden
        else:
            h_state, c_state = hidden, tf.zeros_like(hidden)

        # defining attention as a separate model; the query is the LSTM output h
        context_vector, attention_weights = self.attention(features, h_state)

        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(x)

        # x shape after concatenation == (batch_size, 1, embedding_dim + feature_dim)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # LSTMCell consumes a single time step, so drop the length-1 time axis
        step_input = tf.squeeze(x, axis=1)
        output, state = self.lstm(step_input, [h_state, c_state])

        # (batch_size, units) -> (batch_size, units) -> (batch_size, vocab_size)
        x = self.fc1(output)
        x = self.fc2(x)

        return x, state, attention_weights

    def reset_state(self, batch_size):
        """Return an all-zero (batch_size, units) tensor to start decoding from."""
        return tf.zeros((batch_size, self.units))

In [5]:
class CNN_Encoder(tf.keras.Model):
    """Convolutional image encoder producing one feature vector per location.

    The previous version flattened the whole feature map into a single vector
    per image, which removed the locations axis that the attention layer
    normalizes and sums over (axis 1). The map is now reshaped to
    (batch, locations, channels) so every location keeps its own vector.
    """

    def __init__(self, embedding_dim):
        """Build the convolutional stack.

        Args:
            embedding_dim: width of the per-location output vectors.
        """
        super(CNN_Encoder, self).__init__()
        # Channel count of the last convolution; the reshape below relies on it.
        conv_channels = 512
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(padding='same'),
            tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(padding='same'),
            tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
            tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(pool_size=(2, 1), strides=(2, 1), padding='same'),
            tf.keras.layers.Conv2D(512, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(pool_size=(1, 2), strides=(1, 2), padding='same'),
            tf.keras.layers.Conv2D(conv_channels, 3, activation='relu'),
            # Add the positional signal while the map is still
            # (batch, height, width, channels).
            # NOTE(review): assumes add_timing_signal_nd accepts
            # [batch, d1, ..., dn, channels] input - confirm in components.positional.
            tf.keras.layers.Lambda(add_timing_signal_nd),
            # Merge height and width into one locations axis:
            # (batch, height * width, channels).
            tf.keras.layers.Reshape((-1, conv_channels)),
            # Dense acts on the last axis, i.e. independently per location.
            tf.keras.layers.Dense(embedding_dim),
        ])

    def call(self, x):
        """Encode an image batch.

        Args:
            x: image batch fed to the convolutional stack.

        Returns:
            ReLU-activated features of shape (batch, locations, embedding_dim).
        """
        x = self.encoder(x)
        x = tf.nn.relu(x)

        return x

In [6]:
from model.utils.data_generator import DataGenerator
from model.utils.general import Config
from model.utils.text import Vocab
from model.utils.image import greyscale

In [7]:
# Paths to the JSON config files for the small experiment, plus the results
# directory. The four JSON paths are passed together to Config in the next cell.
# NOTE(review): `vocab` is rebound to a Vocab instance in a later cell, so
# re-running the Config cell after that point would pass the object instead
# of this path - consider distinct names for paths and objects.
data = "configs/data_small.json"
vocab = "configs/vocab_small.json"
training = "configs/training_small.json"
model = "configs/model.json"
output = "results/small/"

In [8]:
# Build one Config object from the data, vocab, training and model config paths.
config = Config([data, vocab, training, model])

In [9]:
# Build the vocabulary from the config. This rebinds `vocab`, which until now
# held the vocab config path string.
vocab = Vocab(config)

In [10]:
# Training data generator: formulas, images and matching file come from the
# train-split entries of `config`; images go through `greyscale` and formulas
# through the vocabulary's `form_prepro`.
train_set = DataGenerator(path_formulas=config.path_formulas_train, dir_images=config.dir_images_train,
                         img_prepro=greyscale, max_iter=config.max_iter, bucket=config.bucket_train,
                         path_matching=config.path_matching_train, max_len=config.max_length_formula,
                         form_prepro=vocab.form_prepro)

Loaded 10 formulas from data/small.formulas.norm.txt
Bucketing the dataset...
- done.


In [11]:
# Validation data generator, built like the training one but from the
# val-split entries of `config`.
# NOTE(review): the recorded output shows this loading the same formulas file
# as the training set - confirm the small config intends train and val to
# share data.
val_set = DataGenerator(path_formulas=config.path_formulas_val,
                       dir_images=config.dir_images_val, img_prepro=greyscale,
                       max_iter=config.max_iter, bucket=config.bucket_val,
                       path_matching=config.path_matching_val, max_len=config.max_length_formula,
                       form_prepro=vocab.form_prepro)

Loaded 10 formulas from data/small.formulas.norm.txt
Bucketing the dataset...
- done.


In [12]:
# Sanity check: size of the training generator (10 in the recorded run).
len(train_set)

10

In [13]:
# Inspect the loaded formulas. This reads a private attribute and prints the
# whole mapping, which is fine for the 10-item sample but should be sliced
# on a full dataset.
train_set._formulas

{0: '\\alpha + \\beta',
 1: '\\frac { 1 } { 2 }',
 2: '\\frac { \\alpha } { \\beta }',
 3: '1 + 2',
 4: '\\alpha + \\beta',
 5: '\\frac { 1 } { 2 }',
 6: '\\frac { \\alpha } { \\beta }',
 7: '1 + 2',
 8: '\\alpha + \\beta',
 9: '\\frac { 1 } { 2 }'}

In [14]:
# Instantiate the image encoder; its output width is the configured embedding size.
encoder = CNN_Encoder(config.attn_cell_config['dim_embeddings'])

In [33]:
# Instantiate the decoder with the configured embedding size and unit count,
# the vocabulary size (number of entries in id_to_tok) and the end-token id.
decoder = Decoder(config.attn_cell_config['dim_embeddings'], config.attn_cell_config['num_units'], len(vocab.id_to_tok), vocab.id_end)


In [34]:
# Adam optimizer with default hyper-parameters.
optimizer = tf.keras.optimizers.Adam()
# Cross-entropy on logits; reduction='none' keeps the per-element losses so
# loss_function can mask them before averaging.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def loss_function(real, pred, pad_id=0):
    """Sparse categorical cross-entropy that ignores padded target positions.

    Args:
        real: integer target token ids.
        pred: logits for the same positions (the module-level `loss_object`
            is created with from_logits=True).
        pad_id: target id treated as padding; positions equal to it
            contribute zero loss. Defaults to 0, the value that was
            previously hard-coded.
            NOTE(review): confirm that 0 really is the padding id of the
            Vocab in use; pass the vocabulary's pad id otherwise.

    Returns:
        Scalar tensor: the mean of the masked per-position losses. The mean
        is taken over all positions, padded ones included, exactly as before.
    """
    per_position = loss_object(real, pred)

    # 1.0 where the target is a real token, 0.0 where it is padding.
    keep = tf.cast(tf.math.not_equal(real, pad_id), dtype=per_position.dtype)

    return tf.reduce_mean(per_position * keep)