In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader

from data_processing.loadData import HMEDataset
from model.cnn import CNN
from model.decoder import Decoder
from model.encoder import Encoder
from model.endtoend import HME2LaTeX

In [None]:
# Data locations, built with os.path.join so the notebook runs on any OS
# (hard-coded '\\' separators only work on Windows).
DATA_DIR = os.path.join('..', 'data', 'CROHME2016_data')
LABELS_CSV = os.path.join(DATA_DIR, 'labels.csv')
IMAGES_DIR = os.path.join(DATA_DIR, 'formula_png')

dataset = HMEDataset(LABELS_CSV, IMAGES_DIR, problem_type='formula')
BATCH_SIZE = 10

In [None]:
# (old, new) substring pairs that pad LaTeX symbols with spaces so that each
# symbol becomes its own whitespace-separated token.
REPLACEMENTS = [
    ('(', '( '),
    ('{', '{ '),
    ('[', '[ '),
    (')', ' )'),
    ('}', ' }'),
    (']', ' ]'),
    ('=', ' = '),
    ('+', ' + '),
    ('-', ' - '),
    ('^', ' ^ '),
    ('*', ' * '),
    ('$', ' $ '),
    (',', ' , ')
]


def normalize(string, replacements=REPLACEMENTS):
    """Space-delimit LaTeX symbols so a label can be tokenized on whitespace.

    Each ``(old, new)`` pair in ``replacements`` is applied in order with
    ``str.replace``. Afterwards, runs of whitespace are collapsed to a single
    space and leading/trailing whitespace is stripped.

    Collapsing whitespace does two things:
    - It keeps the tokenizer from seeing empty tokens. Without it, ``"a+-b"``
      would become ``"a +  - b"`` and ``"-a"`` would become ``" - a"``.
    - It makes the function idempotent for replacement lists that only pad
      symbols with spaces (such as ``REPLACEMENTS``). Re-running a cell that
      normalizes labels in place is then harmless.

    Args:
        string: raw LaTeX label.
        replacements: sequence of ``(old, new)`` substring pairs; defaults to
            ``REPLACEMENTS``.

    Returns:
        The normalized, single-space-separated label.
    """
    for old, new in replacements:
        string = string.replace(old, new)
    return ' '.join(string.split())


In [None]:
# Rewrite the formula-label column (position 1) in place so every LaTeX symbol
# is separated by spaces before building the vocabulary.
formula_column = dataset.img_labels.iloc[:, 1]
dataset.img_labels.iloc[:, 1] = formula_column.apply(normalize, args=(REPLACEMENTS,))

In [None]:
# drop_last=True: the CNN/Encoder/Decoder/HME2LaTeX stack is constructed with a
# fixed BATCH_SIZE, so it presumably expects every batch to be full. Without
# this, the final batch is short whenever len(dataset) % BATCH_SIZE != 0.
train_dataloader = DataLoader(dataset, BATCH_SIZE, shuffle=True, drop_last=True)

In [None]:
# Below code is adapted from https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

SOS_token = 0
EOS_token = 1


class Lang:
    """Vocabulary that maps whitespace-separated tokens to integer ids.

    Ids ``SOS_token`` (0) and ``EOS_token`` (1) are reserved. Every new token
    receives the next free id in order of first appearance.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {SOS_token: "SOS", EOS_token: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        """Register every whitespace-separated token of ``sentence``."""
        # split() with no argument drops the empty strings that repeated,
        # leading or trailing spaces would otherwise produce as '' tokens.
        for word in sentence.split():
            self.addWord(word)

    def addWord(self, word):
        """Add ``word`` to the vocabulary, or bump its count if already known."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1


def indexesFromSentence(lang, sentence):
    """Return the vocabulary ids of ``sentence``'s tokens, in order.

    Tokenization matches ``Lang.addSentence``. Raises ``KeyError`` for a token
    that was never added to ``lang``.
    """
    return [lang.word2index[word] for word in sentence.split()]


def tensorFromSentence(lang, sentence):
    """Encode ``sentence`` as a ``(len + 2, 1)`` long tensor: [SOS, ids..., EOS]."""
    seq = [SOS_token]
    seq.extend(indexesFromSentence(lang, sentence))
    seq.append(EOS_token)
    return torch.tensor(seq, dtype=torch.long).view(-1, 1)


In [None]:
# Build the LaTeX token vocabulary from every (normalized) formula label.
# The loop variable is `formula`, not `labels`, so it does not clobber other
# notebook-level names on re-run.
latex = Lang('latex')
for formula in dataset.img_labels.iloc[:, 1]:
    latex.addSentence(formula)

In [None]:
# One (len, 1) id tensor per label, in dataset row order: [SOS, ids..., EOS].
label_sequences = [tensorFromSentence(latex, formula) for formula in dataset.img_labels.iloc[:, 1]]
# pad_sequence defaults to batch_first=False, so `words` is laid out as
# (max_len, num_labels, 1); shorter sequences are padded with -1.
words = torch.nn.utils.rnn.pad_sequence(label_sequences, padding_value=-1)

In [None]:
# Build the three components and wrap them in the end-to-end HME2LaTeX model.
# NOTE(review): the constructor signatures live in model/*.py and are not visible
# here; the notes on the literal arguments below are assumptions to confirm.
cnn = CNN()
# NOTE(review): 512 / 256 look like input / hidden sizes and 27*24 like a flattened
# CNN feature-map length — confirm against Encoder's signature.
encoder = Encoder(512, 256, 27*24, BATCH_SIZE)
# latex.n_words is the vocabulary size (SOS + EOS + every token seen in the labels).
decoder = Decoder(1, 512, latex.n_words,27*24,BATCH_SIZE)
# words.shape[0] is the padded (maximum) label length, including SOS and EOS.
# NOTE(review): the literal `1, 0` may be EOS_token / SOS_token (which are 1 and 0) — confirm.
# All components receive BATCH_SIZE at construction, so they presumably expect full batches.
model  = HME2LaTeX(cnn, encoder, decoder,words.shape[0],BATCH_SIZE, latex.n_words, 1, 0, words.shape[0])

In [None]:
# Token-level classification criterion; applied per sample in the training loop.
loss = torch.nn.CrossEntropyLoss()
# Plain SGD over every model parameter with a fixed learning rate.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)

In [None]:
NUM_EPOCHS = 1
# Debug limit: stop each epoch after this many batches. Set to None to train on
# the full epoch.
MAX_BATCHES_PER_EPOCH = 1

for epoch in range(NUM_EPOCHS):
    for batch_idx, (img, _labels, indices) in enumerate(train_dataloader):
        optimizer.zero_grad()
        # Padded token ids for this batch: (max_len, batch, 1), padding value -1.
        target = words[:, indices, :]
        pred = model(img.to(torch.float32), target.to(torch.float32))
        total_loss = torch.zeros(1)

        # Iterate over the batch's real length, not BATCH_SIZE: a short (last)
        # batch would otherwise index past the end of `target`.
        for sample in range(len(indices)):
            sentence = target[:, sample, :]
            # pred is indexed as (step, batch, vocab) — assumed model output layout.
            prob = pred[:, sample, :]
            unpacked_sentence = sentence[sentence != -1]  # strip padding
            unpacked_prob = prob[:unpacked_sentence.shape[0], :]
            # Shift by one: the prediction at step t is scored against token t+1.
            sample_loss = loss(unpacked_prob[:-1, :], unpacked_sentence[1:].type(torch.long))
            total_loss += sample_loss
        total_loss.backward()
        optimizer.step()
        if MAX_BATCHES_PER_EPOCH is not None and batch_idx + 1 >= MAX_BATCHES_PER_EPOCH:
            break

In [None]:
# Checkpoint: model and optimizer state, the loss of the last trained batch,
# and the DataLoader itself.
# NOTE(review): storing the DataLoader pickles its whole dataset object; loading it
# back likely needs torch.load(..., weights_only=False) on recent torch — confirm.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': total_loss,
    'data_loader': train_dataloader,
}
torch.save(checkpoint, './trainedmodel.tar')