# Notebook to showcase how to use the model

## CLEAR THE NOTEBOOK OUTPUT BEFORE COMMITTING/PUSHING

The paths may not work for you, as they are hardcoded to a sample of the data on my machine.

In [None]:
import torch
from torchvision.io import read_image
import torchvision.transforms as T
from cnn import CNN
from encoder import Encoder
from decoder import Decoder
import os
import csv
import re

In [None]:
# Directory (relative to this notebook) holding the sample PNG images and
# the iso_GT.txt ground-truth label file. Hard-coded: adjust for your machine.
FOLDER_PATH: str = '../../data/CROHME2016_data/data_png/subset/'

In [None]:
def sort_files(file):
    """Return a numeric sort key for a file name such as ``'iso12.png'``.

    Names of the form ``<non-digits><digits>.<anything>`` map to the integer
    value of the digit run (``'iso12.png'`` -> ``12``), so ``iso2`` sorts
    before ``iso10``. Any other name maps to ``float('inf')``. Those names
    then sort after every numbered file, instead of returning ``None``,
    which makes ``list.sort`` raise ``TypeError`` on a ``None``/``int``
    comparison.
    """
    match = re.match(r'\D*(?P<num>\d+)\..*', file)
    if match is None:
        return float('inf')
    return int(match.group('num'))

In [None]:
# Load every PNG in FOLDER_PATH, in numeric file order, into one float tensor.
# NOTE(review): the shape (11, 1, 304, 304) is hard-coded. It assumes at most
# 11 images, each broadcastable to (1, 304, 304). Slots beyond the number of
# files found stay zero.
batch = torch.zeros((11, 1, 304, 304))
files = sorted(
    (name for name in next(os.walk(FOLDER_PATH))[2] if name.endswith('.png')),
    key=sort_files,
)
if len(files) > batch.shape[0]:
    # Fail with a clear message rather than an IndexError mid-loop.
    raise ValueError(
        f"found {len(files)} PNG files in {FOLDER_PATH!r}, "
        f"but batch only has room for {batch.shape[0]}"
    )
for i, file in enumerate(files):
    # os.path.join avoids the '//' produced by appending '/' to a path
    # that already ends with one.
    batch[i] = read_image(os.path.join(FOLDER_PATH, file)).to(torch.float32)

In [None]:
# Map each row's first CSV column (an image identifier) to its second
# column (the LaTeX label) from iso_GT.txt.
# NOTE(review): keys presumably correspond to the PNG file names — confirm.
labels = {}
# newline='' is what the csv module requires for files it reads;
# os.path.join works whether or not FOLDER_PATH has a trailing slash.
with open(os.path.join(FOLDER_PATH, 'iso_GT.txt'), newline='', encoding='utf-8') as f:
    for row in csv.reader(f):
        if not row:  # blank line: csv yields [], which would IndexError on row[1]
            continue
        labels[row[0]] = row[1]

## Word embedding

In [None]:
# The code below is taken from the PyTorch seq2seq translation tutorial:
# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

# Reserved ids for the start-of-sequence and end-of-sequence markers;
# they match the two entries Lang pre-registers in index2word.
SOS_token, EOS_token = 0, 1

class Lang:
    """Vocabulary that maps tokens to integer ids (and back) and counts them.

    Ids 0 and 1 are reserved for the "SOS" and "EOS" markers; every new
    token receives the next consecutive id, starting at 2.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}  # token -> id
        self.word2count = {}  # token -> number of times added
        self.index2word = {0: "SOS", 1: "EOS"}  # id -> token
        # Next free id; SOS and EOS already occupy 0 and 1.
        self.n_words = 2

    def addSentence(self, sentence):
        """Register every token of ``sentence`` split on single spaces."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Count ``word``, assigning it a fresh id the first time it is seen."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        new_id = self.n_words
        self.word2index[word] = new_id
        self.index2word[new_id] = word
        self.word2count[word] = 1
        self.n_words = new_id + 1

def indexesFromSentence(lang, sentence):
    """Return the ids of the single-space-separated tokens of ``sentence``.

    Raises KeyError for any token not registered in ``lang.word2index``.
    """
    vocab = lang.word2index
    tokens = sentence.split(' ')
    return [vocab[tok] for tok in tokens]


def tensorFromSentence(lang, sentence):
    """Encode ``sentence`` as a column of token ids terminated by EOS.

    Returns a ``torch.long`` tensor of shape ``(n_tokens + 1, 1)``. No SOS
    token is prepended.
    """
    ids = indexesFromSentence(lang, sentence) + [EOS_token]
    column = torch.tensor(ids, dtype=torch.long)
    return column.reshape(-1, 1)


In [None]:
# Build the LaTeX vocabulary from the ground-truth labels.
# addSentence splits on single spaces exactly like tensorFromSentence does,
# so every token that is later encoded has been registered here.
# (For single-token labels this is identical to addWord.)
latex = Lang('latex')
for label in labels.values():
    latex.addSentence(label)
# Smoke test: encode one symbol. Raises KeyError if '\sum' is absent
# from the sample labels.
tensorFromSentence(latex, '\\sum')

## Training the end-to-end system

In [None]:
# Build the three sub-networks of the end-to-end model.
# NOTE(review): the meaning of these positional arguments is defined in
# cnn.py / encoder.py / decoder.py (not visible here) — confirm there.
_grid = 27 * 24  # shared by Encoder (3rd arg) and Decoder (4th arg); presumably a flattened feature-map size — TODO confirm
_n = 11          # shared last argument; equals the 11 slots in `batch`
net = CNN()
encoder = Encoder(512, 256, _grid, _n)
decoder = Decoder(1, 512, latex.n_words, _grid, _n)

In [None]:
# Target sequences, laid out as (position, sample, 1): SOS, label token, EOS.
# torch.zeros gives a float tensor; ids are cast back to long for the loss.
words = torch.zeros((3, 11, 1))
sos_column = torch.tensor([[SOS_token]])
for col, symbol in enumerate(labels.values()):
    target = torch.cat((sos_column, tensorFromSentence(latex, symbol)))
    words[:, col, :] += target
# NOTE(review): column order follows iso_GT.txt, while `batch` follows the
# numeric file sort — confirm the two orders agree.
# NOTE(review): requires_grad on integer targets looks unnecessary; confirm
# HME2LaTeX does not rely on it before removing.
words.requires_grad = True

In [None]:
from endtoend import HME2LaTeX

In [None]:
# Assemble the end-to-end model from the three sub-networks.
# NOTE(review): the six trailing integers are positional HME2LaTeX arguments
# defined in endtoend.py (not visible here). 3 and 11 presumably match the
# sequence length and sample count of `words` — confirm there.
model = HME2LaTeX(
    net, encoder, decoder,
    3, 11, 10, 1, 0, 3,
)

In [None]:
# Cross-entropy over the vocabulary at each decoding step.
# NOTE(review): CrossEntropyLoss applies log-softmax itself and expects
# unnormalised logits. If the model already returns softmax probabilities
# (the training loop calls them `probs`), they get normalised twice — confirm.
loss = torch.nn.CrossEntropyLoss()
# Plain SGD over every model parameter with a fixed learning rate.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
# Train with teacher forcing: the model receives the target sequences.
n_epochs = 5
seq_len = words.shape[0]  # decoding steps (3 here: SOS, symbol, EOS)
for epoch in range(n_epochs):
    optimizer.zero_grad()
    probs = model(batch, words)
    # Sum the cross-entropy of every step. reshape(-1) flattens the
    # (n_samples, 1) target column without hard-coding the sample count.
    # NOTE(review): assumes probs[i] is (n_samples, vocab) for step i — confirm.
    l = sum(
        loss(probs[i].type(torch.float32), words[i].reshape(-1).type(torch.long))
        for i in range(seq_len)
    )
    l.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss {l.item():.4f}")

In [None]:
# Inference: no target sequence is passed (second argument None).
# no_grad skips building an autograd graph that is never backpropagated.
# NOTE(review): if HME2LaTeX uses dropout or batch-norm, call model.eval() first.
with torch.no_grad():
    pred = model(batch, None)

In [None]:
# Greedy decode: index of the top-scoring token at each position, reshaped
# to (3, 11) and transposed so each row is one sample's predicted sequence.
# (Assumes pred covers 3 steps x 11 samples — TODO confirm.)
pred.topk(k=1).indices.view(3, 11).t()