In [1]:
import tensorflow as tf

In [2]:
import os
import glob

### Locations of the formula data used throughout the notebook
data_base_dir = "../data"
figs_base_dir = "../figs"

# Every path keeps its trailing slash because later cells append file names
# directly to these strings.
original_data_path = f"{data_base_dir}/original/formula/"
processed_data_path = f"{data_base_dir}/processed/formula/"
pickle_data_path = f"{data_base_dir}/pickle/formula/"

# Stop right away when the original data directory is missing.
assert os.path.exists(original_data_path), "Original data path does not exist."

In [3]:
# Collect the path of every training image.
# NOTE(review): glob returns paths in arbitrary filesystem order, while
# train_labels is ordered by the rows of train.matching.txt. The two are
# later zipped together by position -- confirm that images and labels are
# actually aligned.
training_images = glob.glob(f"{processed_data_path}/images/train/*.png")

In [140]:
import pandas as pd 
import numpy as np

# formulas = pd.read_csv(f"{processed_data_path}formulas.norm.lst", sep='~!!!~~~', header=None)
# One normalised formula per line of the label file.
with open(f"{processed_data_path}labels/train.formulas.norm.txt") as f:
    train_labels = np.array(f.read().splitlines())

# Space-separated matching file; column 1 of every row is used below as an
# index into train_labels.
# NOTE(review): column 0 presumably names the image file -- confirm, and check
# that the image list is ordered the same way as these rows.
train_matches = pd.read_csv(
    f"{processed_data_path}images/train/train.matching.txt",
    sep=' ',
    header=None,
).values

print(f"Found {len(train_labels)} training labels.")
print(f"Found {len(train_matches)} training matches.")

# Keep only the formulas referenced by the matching file, in matching order.
train_labels = train_labels[[match[1] for match in train_matches]]

print(f"Kept {len(train_labels)} training labels.")

Found 76322 training labels.
Found 76303 training matches.
Kept 76303 training labels.


In [5]:
# Sanity check: the dataset pairs images and labels by position, so this count
# has to equal the number of kept training labels.
print(f"Found {len(training_images)} training images.")

Found 76303 training images.


In [121]:
class Vocab(object):
    '''
    Token <-> integer id vocabulary loaded from a text file (one token per
    line, id = line number) and extended with four special tokens
    (<UNK>, <SOS>, <END>, <PAD>) placed after the highest id from the file.

    Attributes:
        token_index:   dict mapping token string -> integer id.
        reverse_index: dict mapping integer id -> token string.
        unk, start, end, pad: the special token strings.
    '''
    def __init__(self, vocab_path):
        self.build_vocab(vocab_path)

    def build_vocab(self, vocab_path):
        '''
        Builds the complete vocabulary, including special tokens
        '''
        self.unk   = "<UNK>"
        self.start = "<SOS>"
        self.end   = "<END>"
        self.pad   = "<PAD>"

        # First, load our vocab from disk & determine
        # highest index in mapping.
        vocab = self.load_vocab(vocab_path)
        max_index = max(vocab.values())

        # Compile special token mapping
        special_tokens = {
            self.unk : max_index + 1,
            self.start : max_index + 2,
            self.end : max_index + 3,
            self.pad : max_index + 4
        }

        # Merge dicts to produce final word index (special tokens win if the
        # file happens to contain one of them).
        self.token_index = {**vocab, **special_tokens}
        self.reverse_index = {v: k for k, v in self.token_index.items()}

    def load_vocab(self, vocab_path):
        '''
        Load vocabulary from file.

        Each line is stripped and mapped to its zero-based line number; when a
        token appears on several lines the last line number is kept.
        '''
        token_index = {}
        with open(vocab_path) as f:
            for idx, token in enumerate(f):
                token_index[token.strip()] = idx
        assert len(token_index) > 0, "Could not build word index"
        return token_index

    def tokenize_formula(self, formula):
        '''
        Converts a formula into a sequence of token ids using the vocabulary.

        The formula is stripped and split on single spaces; tokens missing
        from the vocabulary map to the <UNK> id.
        '''
        unk_id = self.token_index[self.unk]
        return [self.token_index.get(token, unk_id)
                for token in formula.strip().split(' ')]

    def pad_formula(self, formula, max_length):
        '''
        Appends the <END> id to a sequence of token ids and pads the result
        with the <PAD> id.

        Args:
            formula: sequence of integer token ids.
            max_length: maximum number of tokens before <END>.

        Returns:
            An int32 numpy array of length max_length + 1.

        Raises:
            IndexError: if formula is longer than max_length.
        '''
        # Extra space for the end token. Token ids are integers, so build an
        # integer array instead of numpy's default float64.
        padded_formula = self.token_index[self.pad] * np.ones(max_length + 1, dtype=np.int32)
        padded_formula[len(formula)] = self.token_index[self.end]
        padded_formula[:len(formula)] = formula
        return padded_formula

    @property
    def length(self):
        '''
        Size an embedding / output layer needs to cover every token id.

        This is the highest id + 1. It equals len(token_index) for a vocab
        file without repeated lines, but stays large enough when repeated
        tokens leave gaps in the id range.
        '''
        return max(self.token_index.values()) + 1
    
# Build the vocabulary (file tokens plus the special tokens) from the
# vocab.txt file in the processed data directory.
vocab = Vocab(f"{processed_data_path}/vocab.txt")

In [7]:
# Hyperparameters
# Size of the shuffle buffer used when building the training dataset.
buffer_size = 1000
# Number of examples per training batch.
batch_size = 16
# Width of the decoder's token embedding (also passed to the encoder).
embedding_dim = 256
# Number of token ids, special tokens included.
vocab_size = vocab.length
# Number of units in the decoder GRU and its attention layer.
hidden_units = 256
# NOTE(review): num_steps is only used as the divisor for the epoch loss; the
# dataset itself is never truncated to num_datapoints -- confirm intent.
num_datapoints = 16384
num_steps = num_datapoints // batch_size
# Upper bound (exclusive) of the epoch range used by the training loop.
epochs = 30 
# When False, training resumes from the latest checkpoint if one exists.
train_new_model = True
# Largest image kept by the size filter, compared against the first two
# dimensions of the decoded image (height, width per tf.image conventions).
max_image_size=(50,200)
# Largest label length kept by the filter (the label includes <END>); labels
# are padded to this length when batched.
max_formula_length = 130

In [141]:
# Static token -> id table so vocabulary lookups can run inside the tf.data
# pipeline; tokens that are not in the vocabulary resolve to the <UNK> id.
# NOTE(review): the op name "class_weight" does not describe a vocab table --
# it looks copy-pasted.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(list(vocab.token_index)),
        tf.constant(list(vocab.token_index.values())),
    ),
    tf.constant(vocab.token_index[vocab.unk]),
    name="class_weight"
)

def load_and_decode_img(path):
    '''
    Load a PNG file and return it as a single-channel grayscale image.

    Args:
        path: location of the PNG file (string or scalar string tensor).

    Returns:
        The decoded image after tf.image.rgb_to_grayscale, i.e. with a last
        dimension of size 1.
    '''
    image = tf.io.read_file(path)
    # Force three channels: rgb_to_grayscale needs an RGB input, so a PNG that
    # is stored as grayscale or RGBA would otherwise fail here. PNGs that are
    # already RGB decode exactly as before.
    image = tf.image.decode_png(image, channels=3)
    return tf.image.rgb_to_grayscale(image)

@tf.function
def lookup_token(token):
    '''
    Lookup the given token(s) in the vocab.

    Args:
        token: string tensor holding one or more tokens.

    Returns:
        The id(s) stored in the module-level lookup table; tokens that are not
        in the table map to the table's default value (the <UNK> id).
    '''
    # A single lookup is enough; the result of the earlier duplicate call was
    # discarded.
    return table.lookup(token)

def process_label(label):
    '''
    Split to tokens, lookup & append <END> token.

    Args:
        label: scalar string tensor holding a space-separated formula.

    Returns:
        1-D tensor of token ids followed by the <END> id.
    '''
    tokens = tf.strings.split(label, " ")
    # The table lookup accepts the whole token vector at once, so there is no
    # need for a per-token tf.map_fn (whose `dtype` argument is deprecated).
    token_ids = lookup_token(tokens)
    return tf.concat([token_ids, [vocab.token_index[vocab.end]]], 0)

def process_datum(path, label):
    '''Turn one (image path, formula) pair into (decoded image, processed label).'''
    image = load_and_decode_img(path)
    processed_label = process_label(label)
    return image, processed_label

# Pair every image path with its formula by position, then run process_datum
# on each pair.
dataset = tf.data.Dataset.from_tensor_slices((training_images, train_labels))
dataset = dataset.map(process_datum)

In [142]:
# Static token -> id table so vocabulary lookups can run inside the tf.data
# pipeline; tokens that are not in the vocabulary resolve to the <UNK> id.
# NOTE(review): the op name "class_weight" does not describe a vocab table --
# it looks copy-pasted.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(list(vocab.token_index)),
        tf.constant(list(vocab.token_index.values())),
    ),
    tf.constant(vocab.token_index[vocab.unk]),
    name="class_weight"
)

def load_and_decode_img(path):
    '''
    Load a PNG file and return it as a single-channel grayscale image.

    Args:
        path: location of the PNG file (string or scalar string tensor).

    Returns:
        The decoded image after tf.image.rgb_to_grayscale, i.e. with a last
        dimension of size 1.
    '''
    image = tf.io.read_file(path)
    # Force three channels: rgb_to_grayscale needs an RGB input, so a PNG that
    # is stored as grayscale or RGBA would otherwise fail here. PNGs that are
    # already RGB decode exactly as before.
    image = tf.image.decode_png(image, channels=3)
    return tf.image.rgb_to_grayscale(image)

@tf.function
def lookup_token(token):
    '''
    Lookup the given token(s) in the vocab.

    Args:
        token: string tensor holding one or more tokens.

    Returns:
        The id(s) stored in the module-level lookup table; tokens that are not
        in the table map to the table's default value (the <UNK> id).
    '''
    # A single lookup is enough; the result of the earlier duplicate call was
    # discarded.
    return table.lookup(token)

def process_label(label):
    '''
    Split to tokens, lookup & append <END> token.

    Args:
        label: scalar string tensor holding a space-separated formula.

    Returns:
        1-D tensor of token ids followed by the <END> id.
    '''
    tokens = tf.strings.split(label, " ")
    # Convert the string tokens to ids and terminate the sequence. Returning
    # the raw strings (as this definition used to) leaves the dataset with
    # string labels, which cannot be padded with the integer <PAD> id or fed
    # to the loss.
    token_ids = lookup_token(tokens)
    return tf.concat([token_ids, [vocab.token_index[vocab.end]]], 0)

def process_datum(path, label):
    '''Turn one (image path, formula) pair into (decoded image, processed label).'''
    image = load_and_decode_img(path)
    processed_label = process_label(label)
    return image, processed_label

# Pair every image path with its formula by position, then run process_datum
# on each pair.
dataset = tf.data.Dataset.from_tensor_slices((training_images, train_labels))
dataset = dataset.map(process_datum)

In [146]:
# Show the labels of the first five examples, before any size filtering.
for _image, label in dataset.take(5):
    print(label)
    print("\n")

tf.Tensor(
[480 498 473 507  21 509  35   4  20   9 507 213 507 496 478 493 498 428
 509 507 497 509 509   5 473 507 213 507  21 509 507  20   7 121 473 507
  21 509 509 509 248 480 497 473 507  21 509   7 497 473 507  21 509 480
 428 473 507  21 509   7 497 473 507  21 509 498 485 492 473 507  21 509
 428 480 446 473 507  21 509 355   9 507 213 507 480 499 473 507  21 509
 509 507   4  20   9 507 213 507 496 478 493 498 428 509 507 497 509 509
   5 473 507 213 507  21 509 507  20   7 121 473 507  21 509 509 509 509
 509  74  12 514], shape=(130,), dtype=int32)


tf.Tensor(
[465 215 474 507 290 507 484 493 494 482 509 509 392 415 474 507 492  36
  14 509 465 507  45 509 474 507 492 509 507 213 507   4   9 476   5 473
 507 492 509 509 507  21 473 507  21 492   9  20 509 509 509 514], shape=(52,), dtype=int32)


tf.Tensor(
[  4 507 162  50 509 474 507 476 509 483   5 474 507 485 487 509  35  14
   8  68  68  68  68   4 507 162  50 509 474 507 476 509  46   5 474 507
 485 487 488 509  35 

In [44]:
def filter_by_size(image, label):
    '''
    Predicate for Dataset.filter: keep an example only when both its image and
    its label fit the configured limits.

    The first two image dimensions must each be <= the matching entry of
    max_image_size, and the label may have at most max_formula_length entries.
    '''
    height_width = tf.shape(image)[:2]
    num_tokens = tf.shape(label)[0]

    # Both spatial dimensions have to fit.
    image_fits = tf.math.reduce_all(
        tf.math.less_equal(height_width, max_image_size)
    )
    # The label must not exceed the maximum formula length.
    label_fits = tf.math.reduce_all(
        tf.math.less_equal(num_tokens, max_formula_length)
    )
    return tf.math.logical_and(image_fits, label_fits)

# Drop every example whose image or label exceeds the configured limits.
data = dataset.filter(filter_by_size)

In [45]:
# Show the labels of the first five examples that survived the size filter.
for _image, label in data.take(5):
    print(label)
    print("\n")

tf.Tensor(
[480 498 473 507  21 509  35   4  20   9 507 213 507 496 478 493 498 428
 509 507 497 509 509   5 473 507 213 507  21 509 507  20   7 121 473 507
  21 509 509 509 248 480 497 473 507  21 509   7 497 473 507  21 509 480
 428 473 507  21 509   7 497 473 507  21 509 498 485 492 473 507  21 509
 428 480 446 473 507  21 509 355   9 507 213 507 480 499 473 507  21 509
 509 507   4  20   9 507 213 507 496 478 493 498 428 509 507 497 509 509
   5 473 507 213 507  21 509 507  20   7 121 473 507  21 509 509 509 509
 509  74 512 514], shape=(130,), dtype=int32)


tf.Tensor(
[483 474 507 485 487 509   4 504   5  35 507 213 507  20 509 507 476 473
 507  21 509 509 509  74 183 474 507 485 487 509   8 510 510 336 473 507
 476 509   4 504   5  35 336 473 507 476 509   8 351   4 476   8 336 473
 507 476 509  69  32 510 290 507 478 493 492 498 499  12 509 512 514], shape=(71,), dtype=int32)


tf.Tensor(
[362 474 507  50 509   4 496   5  35 415 474 507 490  35  20 509 473 507
  50 509  68  55 

In [116]:
# Cache, shuffle and batch (dropping any batches that are < BATCH_SIZE long)
# also pad each batch to the largest image size + formulas to
# max_formula_length.
shapes = (tf.TensorShape([None,None,1]),tf.TensorShape([max_formula_length]))
values = (tf.constant(255, dtype=tf.uint8), tf.constant(vocab.token_index[vocab.pad]))

# Order matters:
#   * cache() comes first so the decoded/filtered examples are stored once.
#     Caching after shuffle/batch would freeze the first epoch's batches and
#     replay the same order in every later epoch.
#   * shuffle() after the cache draws a new order on every pass.
#   * prefetch() goes last so batch preparation overlaps the training step.
# NOTE(review): images are padded to the per-batch maximum, so batch shapes
# vary; consider padding to max_image_size if tf.function retracing is costly.
dataset = (
    data
    .cache()
    .shuffle(buffer_size)
    .padded_batch(
        batch_size,
        padded_shapes=shapes,
        padding_values=values,
        drop_remainder=True
    )
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

In [48]:
from __future__ import division
import math
import numpy as np
from six.moves import xrange
import tensorflow as tf


# taken from https://github.com/tensorflow/tensor2tensor/blob/37465a1759e278e8f073cd04cd9b4fe377d3c740/tensor2tensor/layers/common_attention.py
def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4):
    """Adds a bunch of sinusoids of different frequencies to a Tensor.

    Each channel of the input Tensor is incremented by a sinusoid of a different
    frequency and phase in one of the positional dimensions.

    This allows attention to learn to use absolute and relative positions.
    Timing signals should be added to some precursors of both the query and the
    memory inputs to attention.

    The use of relative position is possible because sin(a+b) and cos(a+b) can
    be expressed in terms of b, sin(a) and cos(a).

    x is a Tensor with n "positional" dimensions, e.g. one dimension for a
    sequence or two dimensions for an image

    We use a geometric sequence of timescales starting with
    min_timescale and ending with max_timescale.  The number of different
    timescales is equal to channels // (n * 2). For each timescale, we
    generate the two sinusoidal signals sin(timestep/timescale) and
    cos(timestep/timescale).  All of these sinusoids are concatenated in
    the channels dimension.

    Args:
        x: a Tensor with shape [batch, d1 ... dn, channels]
        min_timescale: a float
        max_timescale: a float

    Returns:
        a Tensor the same shape as x.

    """
    static_shape = x.get_shape().as_list()
    num_dims = len(static_shape) - 2
    channels = tf.shape(x)[-1]
    num_timescales = channels // (num_dims * 2)
    # Clamp the denominator to 1: with a single timescale the original
    # `num_timescales - 1` is zero and the whole signal turns into NaN.
    # For two or more timescales the value is unchanged.
    log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            tf.maximum(tf.cast(num_timescales, tf.float32) - 1, 1))
    inv_timescales = min_timescale * tf.exp(
            tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
    for dim in range(num_dims):
        length = tf.shape(x)[dim + 1]
        position = tf.cast(tf.range(length), tf.float32)
        scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
                inv_timescales, 0)
        signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
        # Each positional dimension owns its own slice of the channels; the
        # rest of the channel axis is zero-padded.
        prepad = dim * 2 * num_timescales
        postpad = channels - (dim + 1) * 2 * num_timescales
        signal = tf.pad(signal, [[0, 0], [prepad, postpad]])
        # Insert singleton axes so the signal broadcasts against x: leading
        # axes for batch and earlier positional dims, then axes just before
        # the channels for the later positional dims.
        for _ in range(1 + dim):
            signal = tf.expand_dims(signal, 0)
        for _ in range(num_dims - 1 - dim):
            signal = tf.expand_dims(signal, -2)
        x += signal
    return x

In [49]:
from tensorflow.keras import metrics, layers, Model

class BahdanauAttention(layers.Layer):
    '''
    Additive attention over a 2-D feature map.

    Scores every spatial position with V(tanh(W1(features) + W2(hidden))),
    normalises the scores with a softmax over positions and returns the
    weighted sum of the features.
    '''
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        '''
        Args:
            features: rank-4 tensor (batch, height, width, channels).
            hidden: decoder state, (batch, hidden_size).

        Returns:
            context_vector: (batch, channels) attention-weighted feature sum.
            attention_weights: (batch, height * width, 1).

        Raises:
            NotImplementedError: if features is not a rank-4 tensor.
        '''
        # Check the rank on the static shape. `len(tf.shape(features))` only
        # works on eager tensors or through autograph's rewriting of `len`.
        if features.shape.rank != 4:
            raise NotImplementedError(f"Image shape not supported: {features.shape}.")

        # First, flatten the image (dynamic shape: height/width vary per batch)
        shape = tf.shape(features)
        num_examples = shape[0]
        img_height = shape[1]
        img_width  = shape[2]
        channels   = shape[3]
        features = tf.reshape(features, shape=(num_examples, img_height*img_width, channels))

        # features(CNN_encoder_output)
        # shape => (batch_size, flattened_image_size, channels)

        # hidden shape == (batch_size, hidden_size)
        # hidden_with_time_axis shape == (batch_size, 1, hidden_size)
        hidden_with_time_axis = tf.expand_dims(hidden, 1)

        # score
        # shape => (batch_size, flattened_image_size, units)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))

        # attention_weights
        # shape => (batch_size, flattened_image_size, 1)
        attention_weights = tf.nn.softmax(self.V(score), axis=1)

        # context_vector
        # shape after sum => (batch_size, channels)
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights


In [101]:
class CNNEncoder(tf.keras.Model):
    '''
    Convolutional image encoder: three 3x3 ReLU convolutions (64, 256 and 512
    filters) with 2x2 max pooling after the first two, followed by a
    positional timing signal added to the resulting feature map.
    '''
    def __init__(self, embedding_dim):
        super().__init__()

        # NOTE(review): this projection is created but never applied in call().
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def build(self, input_shape):
        '''Create the convolution / pooling layers once the input shape is known.'''
        height, width = input_shape[1], input_shape[2]

        self.cnn_1 = layers.Conv2D(64, (3, 3), activation='relu', input_shape=(height, width, 1))
        self.max_pool_1 = layers.MaxPooling2D((2, 2))

        self.cnn_2 = layers.Conv2D(256, (3, 3), activation='relu')
        self.max_pool_2 = layers.MaxPooling2D((2, 2))

        self.cnn_3 = layers.Conv2D(512, (3, 3), activation='relu')

    def call(self, images):
        '''Encode a batch of images into a feature map with position information.'''
        x = tf.cast(images, tf.float32)
        pipeline = (
            self.cnn_1,
            self.max_pool_1,
            self.cnn_2,
            self.max_pool_2,
            self.cnn_3,
        )
        for layer in pipeline:
            x = layer(x)
        return add_timing_signal_nd(x)


In [102]:
class RNNDecoder(tf.keras.Model):
    '''
    Attention-based GRU decoder that predicts one token per call.

    Each call attends over the encoder feature map using the previous hidden
    state, concatenates the context vector with the embedded input token and
    advances the GRU by one step.
    '''
    def __init__(self, embedding_dim, units, vocab_size):
        super(RNNDecoder, self).__init__()
        self.units = units

        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

        self.fc1 = tf.keras.layers.Dense(self.units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

        self.attention = BahdanauAttention(self.units)

    def call(self, x, features, hidden):
        '''
        Args:
            x: ids of the previous tokens, (batch, 1) as fed by train_step.
            features: encoder feature map, rank 4.
            hidden: previous decoder state, (batch, units).

        Returns:
            x: unnormalised scores over the vocab, (batch, vocab_size) for a
               single-step input.
            state: new GRU state, (batch, units).
            attention_weights: weights returned by the attention layer.
        '''
        # attend over the image features
        context_vector, attention_weights = self.attention(features, hidden)

        # convert our input vector to an embedding
        x = self.embedding(x)

        # concat the embedding and the context vector (from attention)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # pass to GRU, continuing from the previous step's state. Without
        # `initial_state` the GRU restarts from zeros on every call, so the
        # state returned here would never feed back into the recurrence.
        output, state = self.gru(x, initial_state=hidden)

        x = self.fc1(output)

        x = tf.reshape(x, (-1, x.shape[2]))

        # This produces a distribution over the vocab
        x = self.fc2(x)

        return x, state, attention_weights

    def reset_state(self, batch_size):
        '''Return the all-zero initial hidden state, (batch_size, units).'''
        return tf.zeros((batch_size, self.units))


In [103]:
# Instantiate the encoder/decoder pair from the hyperparameters.
encoder = CNNEncoder(embedding_dim=embedding_dim)
decoder = RNNDecoder(embedding_dim, hidden_units, vocab_size)

In [109]:
# Single optimizer that updates both the encoder and the decoder variables.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
# from_logits=True: the decoder returns unnormalised scores.
# reduction='none' keeps one loss value per example so that <PAD> positions
# can be masked before averaging.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def loss_function(real, pred):
    '''
    Cross-entropy for one decoding step that ignores <PAD> targets.

    Args:
        real: ground-truth token ids for this step.
        pred: unnormalised scores produced by the decoder for this step.

    Returns:
        Scalar loss. Padded positions contribute zero, but the mean is still
        taken over every entry, padded ones included.
    '''
    per_example_loss = loss_object(real, pred)

    # 1.0 where the target is a real token, 0.0 where it is <PAD>.
    is_real_token = tf.math.not_equal(real, vocab.token_index[vocab.pad])
    weights = tf.cast(is_real_token, dtype=per_example_loss.dtype)

    return tf.reduce_mean(per_example_loss * weights)

In [110]:
# Checkpointing: track encoder, decoder and optimizer state together and keep
# the ten most recent checkpoints.
checkpoint_path = "./checkpoints/train"
checkpoint = tf.train.Checkpoint(encoder=encoder,
                                 decoder=decoder,
                                 optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_path, max_to_keep=10)

# Attempt to restore from training checkpoint
start_epoch = 0
save_at_epoch = 5
if train_new_model is False and checkpoint_manager.latest_checkpoint:
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    # Checkpoint paths end in "-<N>". The training loop saves after every
    # epoch index that is a multiple of save_at_epoch (0, 5, 10, ...), so
    # checkpoint N was written once epoch index (N - 1) * save_at_epoch had
    # finished. Resume at the next index; N * save_at_epoch would skip epochs
    # that were never trained.
    checkpoint_number = int(checkpoint_manager.latest_checkpoint.split('-')[-1])
    start_epoch = (checkpoint_number - 1) * save_at_epoch + 1
    print(f"Restored from checkpoint: {checkpoint_manager.latest_checkpoint}.")
    print(f"Start epoch: {start_epoch}.")
else:
    print("Did not restore from a checkpoint -- training new model!")    

Did not restore from a checkpoint -- training new model!


In [111]:
# Per-epoch loss values, appended by the training loop.
epoch_losses = []

In [112]:
@tf.function
def train_step(img_tensor, target):
    '''
    Run one teacher-forced optimisation step on a single batch.

    Args:
        img_tensor: batch of images fed to the encoder.
        target: batch of padded token-id sequences, (batch, sequence_length).

    Returns:
        loss: sum of the per-step losses over the whole sequence.
        total_loss: `loss` divided by the sequence length.
    '''
    loss = 0

    # Take the batch size from the target itself so the start tokens and the
    # initial hidden state always agree with each other (and with the data),
    # instead of mixing it with the global batch_size.
    num_examples = target.shape[0]

    # reset the decoder state, since Latex is different for each image
    hidden = decoder.reset_state(batch_size=num_examples)

    # shape => (batch_size, 1)
    dec_input = tf.expand_dims([vocab.token_index[vocab.start]] * num_examples, 1)

    sequence_length = target.shape[1]
    with tf.GradientTape() as tape:
        features = encoder(img_tensor)

        for i in range(0, sequence_length):
            # passing the features through the decoder
            predictions, hidden, _ = decoder(dec_input, features, hidden)

            ground_truth_token = target[:, i]
            loss += loss_function(ground_truth_token, predictions)

            # Teacher forcing: feed the correct word in as the next input to the
            # decoder, to provide it with the proper context to predict
            # the following token in the sequence
            dec_input = tf.expand_dims(ground_truth_token, 1)

    total_loss = (loss / int(sequence_length))
    trainable_variables = encoder.trainable_variables + decoder.trainable_variables
    # Gradients are taken of the summed loss, not the per-token average.
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    return loss, total_loss

In [132]:
# Scratch check: look up the "." token and tokenize the first training
# formula. Only the value of the last expression is displayed by the notebook.
vocab.token_index["."]
vocab.tokenize_formula(train_labels[0])
# for (batch, (img_tensor, target)) in enumerate(dataset):
#     for label in target:
#         print([vocab.reverse_index[token.numpy()] for token in label])
#     break

[480,
 498,
 473,
 507,
 21,
 509,
 35,
 4,
 20,
 9,
 507,
 213,
 507,
 496,
 478,
 493,
 498,
 428,
 509,
 507,
 497,
 509,
 509,
 5,
 473,
 507,
 213,
 507,
 21,
 509,
 507,
 20,
 7,
 121,
 473,
 507,
 21,
 509,
 509,
 509,
 248,
 480,
 497,
 473,
 507,
 21,
 509,
 7,
 497,
 473,
 507,
 21,
 509,
 480,
 428,
 473,
 507,
 21,
 509,
 7,
 497,
 473,
 507,
 21,
 509,
 498,
 485,
 492,
 473,
 507,
 21,
 509,
 428,
 480,
 446,
 473,
 507,
 21,
 509,
 355,
 9,
 507,
 213,
 507,
 480,
 499,
 473,
 507,
 21,
 509,
 509,
 507,
 4,
 20,
 9,
 507,
 213,
 507,
 496,
 478,
 493,
 498,
 428,
 509,
 507,
 497,
 509,
 509,
 5,
 473,
 507,
 213,
 507,
 21,
 509,
 507,
 20,
 7,
 121,
 473,
 507,
 21,
 509,
 509,
 509,
 509,
 509,
 74,
 12]

In [113]:
import time
import datetime

print(f"[Started training at: {datetime.datetime.now()}. Training for {epochs} epochs.]")
print(f"[Starting epoch: {start_epoch}]")
for epoch in range(start_epoch, epochs):
    start = time.time()
    total_loss = 0
    # Count the batches actually processed. The dataset is not truncated to
    # num_datapoints, so dividing by num_steps would not give the mean loss.
    num_batches = 0

    for (batch, (img_tensor, target)) in enumerate(dataset):
        batch_loss, sequence_loss = train_step(img_tensor, target)
        total_loss += sequence_loss
        num_batches += 1

        if batch % 50 == 0:
            print(f"[Epoch: {epoch + 1} | Batch: {batch} | Loss: {sequence_loss:.4f}]")

    # Save epoch loss (mean per-batch loss; guard against an empty dataset)
    epoch_loss = total_loss / max(num_batches, 1)
    epoch_losses.append(epoch_loss)

    # Save checkpoint (if required): after epoch indices 0, save_at_epoch,
    # 2 * save_at_epoch, ...
    if epoch % save_at_epoch == 0:
        checkpoint_manager.save()

    print(f"[Epoch: {epoch + 1} | Epoch Loss: {epoch_loss}]")
    print(f"[Time elapsed for epoch: {format(time.time() - start)} seconds.] \n")


Tensor("img_tensor:0", shape=(16, 50, 200, 1), dtype=uint8)
[Epoch: 1 | Batch: 0 | Loss: 6.3455]


KeyboardInterrupt: 