# Notebook to showcase how to use the model

## CLEAR OUTPUT OF THE NOTEBOOK BEFORE COMMITTING/PUSHING

The paths may not work for you, as they are hardcoded for a sample of the data on my machine.

In [1]:
import torch
from torchvision.io import read_image
import torchvision.transforms as T
from cnn import CNN
from encoder import Encoder
from decoder import Decoder
import os
import csv
import re

In [2]:
# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cpu


In [3]:
FOLDER_PATH = '../../data/CROHME2016_data/data_png/subset/'

In [4]:
def sort_files(file):
    """Sort key ordering file names by the number that precedes their extension.

    Parameters
    ----------
    file : str
        File name such as ``'iso12.png'``: an optional non-digit prefix, a run of
        digits, then a ``.`` and an extension.

    Returns
    -------
    int or float
        The parsed number. When the name does not match this pattern, returns
        ``float('inf')``, so unnumbered files sort after all numbered ones. (A
        ``None`` key would make ``list.sort`` raise ``TypeError``.)
    """
    match = re.match(r'\D*(?P<num>\d+)\..*', file)
    if match:
        return int(match.group('num'))
    return float('inf')

In [5]:
#loading the images in one tensor
# Image files are ordered by the number in their name (sort_files), so batch[i] holds
# the i-th numbered sample.
tree = next(os.walk(FOLDER_PATH))
files = [file for file in tree[-1] if file.lower().endswith('.png')]
files.sort(key=sort_files)
# Size the batch from the files actually found, instead of a hardcoded 11 that raises
# IndexError for more files and silently zero-pads for fewer.
# NOTE(review): the networks below are constructed with 11 samples of 1x304x304 in
# mind; confirm the folder matches.
batch = torch.zeros((len(files), 1, 304, 304), device=device)
for i, file in enumerate(files):
    # os.path.join avoids the doubled "/" of f"{FOLDER_PATH}/{file}".
    batch[i] = read_image(os.path.join(FOLDER_PATH, file)).float().to(device)

In [6]:
#load the labels in a dictionary
# Maps the first CSV column (sample id) to the rest of the row (the LaTeX label).
# NOTE(review): later cells use the insertion order of this dict as the sample order. That
# must match the numeric order of the image files in `batch`; confirm iso_GT.txt is sorted
# the same way.
labels = dict()
# newline='' is what the csv module requires for files passed to csv.reader.
with open(os.path.join(FOLDER_PATH, "iso_GT.txt"), newline='', encoding='utf-8') as f:
    for row in csv.reader(f):
        if not row:
            continue  # blank line: nothing to record
        # A label that is itself a comma (e.g. "iso7,,") is split by csv into empty
        # fields, so re-join everything after the id rather than taking only row[1].
        labels[row[0]] = ','.join(row[1:])

## Word indexing (vocabulary for the decoder)

In [7]:
#Below code is adapted from https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

SOS_token = 0  # reserved index of the start-of-sequence marker
EOS_token = 1  # reserved index of the end-of-sequence marker

class Lang:
    """Vocabulary that maps tokens to integer indices and back.

    Indices ``SOS_token`` and ``EOS_token`` are reserved; every new token gets the
    next consecutive index (starting at 2) in order of first appearance.

    Attributes
    ----------
    name : str
        Free-form name of the vocabulary.
    word2index : dict[str, int]
        Token -> index (SOS/EOS are not included).
    word2count : dict[str, int]
        Token -> number of times it has been added.
    index2word : dict[int, str]
        Index -> token, including the ``"SOS"`` / ``"EOS"`` placeholders.
    n_words : int
        Number of indices handed out so far, SOS and EOS included.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {SOS_token: "SOS", EOS_token: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        """Add every whitespace-separated token of ``sentence`` to the vocabulary.

        ``str.split()`` without an argument is used so that repeated, leading or
        trailing spaces do not produce empty-string tokens.
        """
        for word in sentence.split():
            self.addWord(word)

    def addWord(self, word):
        """Count ``word``, assigning it the next free index the first time it is seen."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

def indexesFromSentence(lang, sentence):
    """Return the index in ``lang`` of every whitespace-separated token of ``sentence``.

    Tokenizes exactly like ``Lang.addSentence``. Raises ``KeyError`` for a token
    that was never added to ``lang``.
    """
    return [lang.word2index[word] for word in sentence.split()]


def tensorFromSentence(lang, sentence):
    """Encode ``sentence`` as the column tensor ``[SOS, tok_1, ..., tok_n, EOS]``.

    Token indices come from ``indexesFromSentence(lang, sentence)``. Returns a
    ``torch.long`` tensor of shape ``(n + 2, 1)`` placed on the global ``device``.
    """
    token_ids = [SOS_token, *indexesFromSentence(lang, sentence), EOS_token]
    return torch.tensor(token_ids, dtype=torch.long, device=device).view(-1, 1)



# (old, new) pairs applied in order by `normalize`: they pad brackets and operators
# with spaces so a LaTeX string can be tokenized on whitespace.
REPLACEMENTS = [
    ('(', '( '),
    ('{', '{ '),
    ('[', '[ '),
    (')', ' )'),
    ('}', ' }'),
    (']', ' ]'),
    ('=', ' = '),
    ('+', ' + '),
    ('-', ' - '),
    ('^', ' ^ '),
    ('*', ' * '),
    ('$', ' $ '),
    (',', ' , ')
]

def normalize(string, replacements=REPLACEMENTS):
    """Space out the tokens of a LaTeX string.

    Applies each ``(old, new)`` pair of ``replacements`` in order with
    ``str.replace``. It then collapses whitespace, so the result has exactly one
    space between tokens and none at either end. Without this, e.g. ``'a+-b'``
    would become ``'a +  - b'`` and ``split(' ')`` would yield empty tokens.

    Parameters
    ----------
    string : str
        Raw LaTeX text.
    replacements : sequence of (str, str), optional
        Substitutions to apply; defaults to ``REPLACEMENTS``.

    Returns
    -------
    str
        The single-space-separated token string.
    """
    for old, new in replacements:
        string = string.replace(old, new)
    return ' '.join(string.split())

In [8]:
# Build the LaTeX vocabulary from the ground-truth labels. Each label is registered
# whole as one token (addWord, not addSentence).
latex = Lang('latex')
for symbol in labels.values():
    latex.addWord(symbol)
tensorFromSentence(latex, '\\sum')

tensor([[0],
        [8],
        [1]])

## Training the end-to-end system

In [9]:
# Instantiate the three sub-networks and move them to the selected device.
# NOTE(review): the meaning of the positional constants is defined in cnn.py /
# encoder.py / decoder.py and is not visible here. Presumably 11 is the number of
# samples (it matches the hardcoded image batch size) and 32*31 is the flattened
# size of the CNN feature map; confirm, and prefer keyword arguments.
net = CNN(device).to(device)
encoder = Encoder(512, 256, 32*31, 11).to(device)
# latex.n_words is the vocabulary size (SOS, EOS and every label added above),
# presumably the decoder's output dimension; confirm against decoder.py.
decoder = Decoder(1,512, latex.n_words, 32*31, 11, device).to(device)

In [10]:
# Token tensor of shape (seq_len, n_samples, 1). Column i holds tensorFromSentence(label_i),
# i.e. [SOS, label_i, EOS], in the insertion order of `labels`. It is stored as float32, as
# the model call and loss below have used it. The shape is derived from the data instead
# of the hardcoded (3, 11, 1).
# NOTE(review): that order must match the numeric order of the images in `batch`; confirm.
words = torch.cat(
    [tensorFromSentence(latex, label) for label in labels.values()], dim=1
).unsqueeze(-1).float()
# requires_grad is deliberately not set: these are token indices, not parameters.
# The loss casts them to long, which cuts any gradient anyway, so tracking them in
# autograd only costs memory.

In [11]:
from endtoend import HME2LaTeX

In [12]:
# Assemble the end-to-end model from the CNN, encoder and decoder built above.
# NOTE(review): the positional constants (3, 11, 10, 1, 0, 3) are defined by
# endtoend.HME2LaTeX and their meaning is not visible here. Presumably 3 is the
# target sequence length (first dim of `words`) and 11 the number of samples.
# The 1 and 0 may be EOS_token / SOS_token. Confirm against the constructor
# signature and prefer keyword arguments.
model = HME2LaTeX(net, encoder, decoder, 3, 11, 10, 1, 0, 3, device)

In [13]:
# Classification criterion (a callable object; the per-epoch loss value is kept in `l`)
# and SGD with momentum over all model parameters.
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9)

In [14]:
# Resume from a previously saved checkpoint when one exists, so a fresh
# "Restart & Run All" does not crash on a missing file.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and
# vice versa); without it torch.load restores tensors to the device they were saved on.
if os.path.exists('model.tar'):
    checkpoint = torch.load('model.tar', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    l = checkpoint['loss']
else:
    print("No checkpoint found at 'model.tar'; starting from freshly initialised weights.")
model.train();

In [None]:
# Train on the single batch for a fixed number of epochs.
# Decoding step t is scored against token t+1 of `words` (the tokens after SOS).
# NOTE(review): CrossEntropyLoss expects unnormalised logits. If the model already returns
# softmax probabilities (as the name `probs` suggests), softmax is applied twice; confirm.
n_steps = words.shape[0] - 1  # number of target positions after SOS
for epoch in range(30):
    optimizer.zero_grad()
    probs = model(batch, words)
    l = torch.zeros(1, device=device)
    for step in range(n_steps):
        target = words[step + 1].reshape(-1).long().to(device)
        l += loss(probs[step].float().to(device), target)
    l.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss {l.item():.4f}")

In [None]:
# Persist model and optimizer state together with the last loss value, using the same
# keys the checkpoint-loading cell reads back.
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=l,
    ),
    './model.tar',
)

In [15]:
# Run inference on the batch. torch.no_grad() skips building the autograd graph,
# which prediction does not need and which would otherwise keep activations alive.
# NOTE(review): the model is still in train mode (model.train() in the checkpoint cell).
# If it contains dropout or batch-norm layers, consider model.eval() first; confirm.
with torch.no_grad():
    pred = model(batch, words)

In [16]:
# Greedy predictions: index of the largest score along the last dimension (presumably the
# vocabulary) for the first two decoding steps, transposed to one row per sample.
pred[:2].topk(k=1, dim=-1).indices.view(2, 11).T

tensor([[2, 1],
        [6, 1],
        [2, 1],
        [5, 1],
        [6, 1],
        [7, 1],
        [6, 1],
        [2, 1],
        [9, 1],
        [6, 1],
        [5, 1]])

In [17]:
#accuracy
# Percentage of target positions (every step after SOS, for every sample) where the greedy
# prediction equals the target token. Both operands are moved to the CPU so the
# comparison also works when `device` is CUDA; mixing a CUDA and a CPU tensor raises.
# Sizes are derived from `words` instead of the hardcoded 2 / 11 / 22.
# NOTE(review): in the recorded output the EOS column is always predicted correctly, which
# likely inflates this figure; symbol-only accuracy is the first column of the comparison.
n_steps, n_samples = words.shape[0] - 1, words.shape[1]
targets = words[1:].reshape(n_steps, n_samples).T.long().cpu()
predictions = pred[:n_steps].topk(1)[1].reshape(n_steps, n_samples).T.cpu()
torch.count_nonzero(targets == predictions) / targets.numel() * 100

tensor(86.3636)