In [None]:
# Colab only: mount Google Drive at /mntDrive so the /mntDrive/MyDrive/...
# dataset and checkpoint paths used later in this notebook resolve.
from google.colab import drive 
drive.mount('/mntDrive')

In [None]:
%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path

from tqdm.notebook import tqdm

In [None]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Loading data

In [None]:
# Colab location (Drive mounted at /mntDrive); use the commented local path elsewhere.
in_dir = Path('/mntDrive/MyDrive/icdar-dataset-20220207')
#in_dir = Path('icdar-dataset-20220207')

# Missing cells become empty strings so that len() and character iteration work later.
train = pd.read_csv(in_dir / 'task2_train.csv', index_col=0).fillna('')
val = pd.read_csv(in_dir / 'task2_val.csv', index_col=0).fillna('')

In [None]:
# Report the number of rows in each split.
for split_name, split_df in (('train', train), ('val', val)):
    print(f'{split_name}:', split_df.shape[0], 'samples')

In [None]:
def add_lens(data: pd.DataFrame) -> pd.DataFrame:
    """Add ``len_ocr`` / ``len_gs`` columns holding the string lengths of ``ocr`` / ``gs``.

    The frame is modified in place and also returned, so calls can be chained
    or reassigned.
    """
    for source_col, length_col in (('ocr', 'len_ocr'), ('gs', 'len_gs')):
        data[length_col] = data[source_col].map(len)
    return data

# add_lens mutates and returns the same frame; reassigning keeps the names explicit.
train, val = add_lens(train), add_lens(val)

In [None]:
# Show the training frame, now including the len_ocr / len_gs columns added above.
train

In [None]:
from pathlib import Path
import torch
from torch.utils.data import Dataset

class Task2Dataset(Dataset):
    """OCR / gold-standard character pairs joined with precomputed Task 1 vectors.

    Rows whose ``ocr`` or ``gs`` string has ``max_len`` or more characters are
    dropped. Item ``idx`` is ``(ocr_chars, gs_chars, task1_vector)``. For a row
    whose original index label is ``i``, the vector is row ``i % batch_size`` of
    ``task2_task1_output_{i // batch_size}.pt`` in ``task1_data_dir``.

    Args:
        data: DataFrame with ``ocr``, ``gs``, ``len_ocr`` and ``len_gs`` columns.
            Its index labels locate the Task 1 vectors. This presumably must be the
            0-based row position used when the Task 1 files were written
            (TODO confirm).
        task1_data_dir: directory with the saved tensors (``str`` or ``Path``).
        batch_size: number of rows stored in each Task 1 output file.
        max_len: exclusive upper bound on the string lengths kept.
    """

    def __init__(self, data, task1_data_dir, batch_size, max_len=11):
        self.ds = (data.query(f'len_ocr < {max_len}')
                       .query(f'len_gs < {max_len}')
                       .copy())
        # Keep the original index labels as an 'index' column (used in __getitem__).
        self.ds = self.ds.reset_index(drop=False)

        # Normalise to Path: __getitem__ builds file names with `/`, which a plain
        # str does not support. This also keeps cache keys equal to iterdir() paths.
        self.task1_data_dir = Path(task1_data_dir)
        self.batch_size = batch_size

        # Eagerly cache every file in the directory, keyed by its Path.
        self.vectors_loaded = {}
        for in_file in tqdm(self.task1_data_dir.iterdir()):
            if in_file.is_file():
                self.vectors_loaded[in_file] = torch.load(in_file)

    def __len__(self):
        return self.ds.shape[0]

    def __getitem__(self, idx):
        sample = self.ds.loc[idx]
        original_idx = sample['index']

        file_index, index_in_file = divmod(original_idx, self.batch_size)
        in_file = self.task1_data_dir / f'task2_task1_output_{file_index}.pt'
        if in_file not in self.vectors_loaded:
            self.vectors_loaded[in_file] = torch.load(in_file)

        # Copy the slice so callers get an independent tensor. No requires_grad:
        # this is fixed input data, and gradients w.r.t. it are never used.
        task1_output = self.vectors_loaded[in_file][index_in_file].clone().detach()

        return list(sample.ocr), list(sample.gs), task1_output

In [None]:
from pathlib import Path

# Local alternative: out_dir = Path('icdar-dataset-20220207')
# Colab: the dataset folder on the Drive mounted at /mntDrive.
out_dir = Path('/mntDrive/MyDrive') / 'icdar-dataset-20220207'

In [None]:
from torchtext.vocab import build_vocab_from_iterator

vocab_transform = {}  # column name ('ocr' / 'gs') -> torchtext character vocab

def yield_tokens(data, col):
    """Yield every character of every string in column ``col`` of ``data``, in row order."""
    for word in data[col].to_list():
        yield from word

# Define special symbols and indices. The list order must match the indices,
# because special_first=True inserts the specials at the start of each vocab.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

# One character-level vocabulary per column, built from the training split only.
for column in ('ocr', 'gs'):
    column_vocab = build_vocab_from_iterator(yield_tokens(train, column),
                                             min_freq=1,
                                             specials=special_symbols,
                                             special_first=True)
    # Unseen characters map to UNK_IDX; without a default index the vocab
    # raises RuntimeError when a token is not found.
    column_vocab.set_default_index(UNK_IDX)
    vocab_transform[column] = column_vocab

In [None]:
from typing import Iterable, List

from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    """Compose callables left to right: ``sequential_transforms(f, g)(x) == g(f(x))``.

    With no transforms the returned function is the identity.
    """
    def func(txt_input):
        value = txt_input
        for step in transforms:
            value = step(value)
        return value
    return func

# Convert a list of token ids into a tensor with EOS appended (no BOS is added).
def tensor_transform(token_ids: List[int]):
    """Return ``token_ids + [EOS_IDX]`` as a 1-D int64 tensor.

    No BOS is prepended; the decoder in ICDARTask2Seq2seq starts from BOS_IDX itself.
    The explicit dtype matters for empty inputs (e.g. empty strings produced by
    ``fillna('')``): ``torch.tensor([])`` is float32, which would otherwise make
    the concatenated result float instead of int64.
    """
    return torch.cat((torch.tensor(token_ids, dtype=torch.long),
                      torch.tensor([EOS_IDX], dtype=torch.long)))

# src ('ocr') and tgt ('gs') pipelines: character list -> vocab indices
# -> tensor with EOS_IDX appended (no BOS).
text_transform = {
    column: sequential_transforms(vocab_transform[column], tensor_transform)
    for column in ('ocr', 'gs')
}


def collate_fn(batch):
    """Collate ``(ocr_chars, gs_chars, task1_vector)`` samples into batch tensors.

    Returns:
        src: int64 tensor (max_src_len, batch), padded with PAD_IDX.
        tgt: int64 tensor (max_tgt_len, batch), padded with PAD_IDX.
        hidden: the stacked Task 1 vectors with a leading dimension of 1 added,
            i.e. (1, batch, ...), ready to be used as an initial GRU hidden state.
    """
    samples = list(batch)
    src_seqs = [text_transform['ocr'](sample[0]) for sample in samples]
    tgt_seqs = [text_transform['gs'](sample[1]) for sample in samples]
    hidden = torch.stack([sample[2] for sample in samples])

    src = pad_sequence(src_seqs, padding_value=PAD_IDX).to(torch.int64)
    tgt = pad_sequence(tgt_seqs, padding_value=PAD_IDX).to(torch.int64)
    return src, tgt, hidden.unsqueeze(0)

In [None]:
class EncoderRNN(nn.Module):
    """Character embedding followed by a single-layer, batch_first GRU.

    ``forward(input, hidden)`` returns the GRU's ``(output, hidden)`` pair. Because
    the GRU is batch_first, ``input`` is read as (batch, seq) token ids, while
    ``hidden`` keeps the (1, batch, hidden_size) layout produced by ``initHidden``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input, hidden):
        return self.gru(self.embedding(input), hidden)

    def initHidden(self, batch_size):
        """All-zero initial hidden state of shape (1, batch_size, hidden_size)."""
        return torch.zeros(1, batch_size, self.hidden_size, device=device)

In [None]:
class AttnDecoderRNN(nn.Module):
    """One-step GRU decoder with attention over a fixed number of encoder positions.

    Attention weights over ``max_length`` positions are computed from the current
    input embedding and the previous hidden state. They are then applied to
    ``encoder_outputs``.
    """

    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=11):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: (batch, 1) token ids for the current step.
            hidden: (1, batch, hidden_size) previous GRU state.
            encoder_outputs: (batch, max_length, hidden_size).

        Returns:
            (log_probs (batch, output_size), new hidden (1, batch, hidden_size),
            attention weights (batch, max_length)).
        """
        embedded = self.dropout(self.embedding(input))   # (batch, 1, H)
        step_emb = torch.permute(embedded, (1, 0, 2))[0]  # (batch, H)

        attn_weights = F.softmax(
            self.attn(torch.cat((step_emb, hidden[0]), 1)), dim=1)

        # (batch, 1, max_length) x (batch, max_length, H) -> (batch, 1, H)
        attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)

        # squeeze(1), not squeeze(): a bare squeeze() also removes the batch
        # dimension when batch == 1, and the cat below then fails.
        output = torch.cat((step_emb, attn_applied.squeeze(1)), 1)
        output = F.relu(self.attn_combine(output).unsqueeze(0))  # (1, batch, H)

        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)

        return output, hidden, attn_weights

    def initHidden(self, batch_size):
        """All-zero initial hidden state of shape (1, batch_size, hidden_size)."""
        return torch.zeros(1, batch_size, self.hidden_size, device=device)

In [None]:
class ICDARTask2Seq2seq(nn.Module):
    """Character-level seq2seq corrector built from EncoderRNN and AttnDecoderRNN.

    ``forward`` returns per-example scores and the encoder outputs. Each score is
    the sum of the log-probabilities of the gold target characters over the
    non-PAD target positions.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout, max_length, teacher_forcing_ratio):
        super(ICDARTask2Seq2seq, self).__init__()

        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.encoder = EncoderRNN(input_size, hidden_size)
        self.decoder = AttnDecoderRNN(hidden_size, output_size,
                                      dropout_p=dropout, max_length=max_length)

    def forward(self, input, encoder_hidden, target, max_length):
        """Encode ``input`` and score ``target`` under the decoder.

        Args:
            input: (src_len, batch) source token ids; src_len must be <= max_length.
            encoder_hidden: (1, batch, hidden_size) initial encoder state.
            target: (tgt_len, batch) gold token ids, PAD_IDX-padded; tgt_len must be
                <= max_length.
            max_length: number of encoder positions the decoder attends over.

        Returns:
            (scores (batch,), encoder_outputs (batch, max_length, hidden_size)).
        """
        # The encoder is fed one step at a time: each input_tensor[ei] is (batch, 1).
        input_tensor = input.unsqueeze(2)
        input_length, batch_size = input.size(0), input.size(1)

        # --- Encoder ---
        encoder_outputs = torch.zeros(batch_size, max_length, self.encoder.hidden_size,
                                      device=device)
        # NOTE(review): PAD steps of shorter sequences are also fed through the
        # encoder and update encoder_hidden; masking them may be worth considering.
        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(input_tensor[ei], encoder_hidden)
            # encoder_output is (batch, 1, H). Keep every sample's own output;
            # indexing [0, 0] would broadcast sample 0's output to the whole batch.
            encoder_outputs[:, ei] = encoder_output[:, 0]

        # --- Decoder --- (starts from BOS; targets carry no BOS of their own)
        target_length = target.size(0)
        decoder_input = torch.full((batch_size, 1), BOS_IDX, dtype=torch.long, device=device)
        decoder_outputs = torch.zeros(batch_size, max_length, self.decoder.output_size,
                                      device=device)
        decoder_hidden = encoder_hidden

        # One teacher-forcing decision per batch.
        use_teacher_forcing = random.random() < self.teacher_forcing_ratio

        for di in range(target_length):
            decoder_output, decoder_hidden, _ = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            if use_teacher_forcing:
                # Feed the gold token as the next input.
                decoder_input = target[di, :].unsqueeze(1)
            else:
                # Feed the model's own greedy prediction, cut from the graph.
                _, topi = decoder_output.topk(1)
                decoder_input = topi.detach()
            decoder_outputs[:, di] = decoder_output

        # Log-probability of each gold character, with PAD positions zeroed out.
        target_masks = (target != PAD_IDX).float()
        gold_log_probs = torch.gather(
            decoder_outputs, dim=-1, index=target.transpose(0, 1).unsqueeze(-1)).squeeze(-1)
        scores = (gold_log_probs * target_masks.transpose(0, 1)).sum(dim=1)

        # NOTE(review): validate_model names this second value `decoder_ouputs`,
        # but it is the encoder outputs; confirm which is intended before using it
        # for decoding.
        return scores, encoder_outputs


# batch_size = 2
# hidden_size = 768
# model = ICDARTask2Seq2seq(len(vocab_transform['ocr']), 
#                           hidden_size, 
#                           len(vocab_transform['gs']), 
#                           0.1, 
#                           11, 
#                           teacher_forcing_ratio=1)

# input = torch.tensor([[  7,  50],
#         [  4,   9],
#         [  3, 171],
#         [  1,  70],
#         [  1,  41],
#         [  1,   3]])
# # Er moet een initiele hidden vector gemaakt worden voor elk sample in de batch
# # De size is dus batch_size x encoder_hidden_size
# encoder_hidden = model.encoder.initHidden(batch_size=batch_size)

# target = torch.tensor([[47, 42],
#         [ 4, 15],
#         [ 3, 18],
#         [ 1,  3]])

# model(input, encoder_hidden, target, 11)

## Training the Model

To train, we run the input word through the encoder one character at a time and keep track
of every output and the latest hidden state. The encoder's initial hidden state is the
precomputed Task 1 output vector for that sample. The decoder is then given the `<bos>`
token as its first input, and the last hidden state of the encoder as its first hidden state.

"Teacher forcing" is the concept of using the real target outputs as
each next input, instead of using the decoder's guess as the next input.
Using teacher forcing causes the model to converge faster, but [when the trained
network is exploited, it may exhibit
instability](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.378.4095&rep=rep1&type=pdf).

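You can observe outputs of teacher-forced networks that read with
coherent grammar but wander far from the correct translation. Intuitively,
the network has learned to represent the output grammar and can "pick
up" the meaning once the teacher tells it the first few words, but it
has not properly learned how to produce the output from the source sentence
in the first place. (This example comes from machine translation. For OCR
post-correction, the analogue is plausible-looking character sequences that
drift away from the gold-standard word.)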

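Because of the freedom PyTorch's autograd gives us, we can randomly
choose to use teacher forcing or not with a simple if statement. Turn
`teacher_forcing_ratio` up to use more of it. In this notebook the choice is
made once per batch (in `ICDARTask2Seq2seq.forward`). The model below is
trained with `teacher_forcing_ratio=0.0`, i.e. without teacher forcing.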




In [None]:
from torch.utils.data import DataLoader

#import edlib

def _levenshtein(a, b):
    """Unit-cost edit distance (insertions, deletions, substitutions) between a and b."""
    if len(a) < len(b):
        a, b = b, a  # keep the DP row as short as possible
    previous = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        current = [i]
        for j, cb in enumerate(b, start=1):
            current.append(min(previous[j] + 1,                # delete ca
                               current[j - 1] + 1,             # insert cb
                               previous[j - 1] + (ca != cb)))  # substitute / match
        previous = current
    return previous[-1]


def calculate_ed(src_strings, tgt_strings):
    """Mean and standard deviation of edit distances between paired strings.

    Pairs are formed with ``zip``, so surplus items of the longer input are ignored.
    ``edlib.align(...)['editDistance']`` is used when edlib is installed (its
    notebook-level import is commented out). Otherwise a pure-Python Levenshtein
    distance is used.

    Returns:
        ``(mean, std)`` via numpy, or ``(nan, nan)`` when there are no pairs.
    """
    try:
        import edlib
    except ImportError:
        edlib = None

    dist = []
    for src, tgt in zip(src_strings, tgt_strings):
        if edlib is not None:
            dist.append(edlib.align(src, tgt)['editDistance'])
        else:
            dist.append(_levenshtein(src, tgt))

    if not dist:
        # np.mean([]) would emit a RuntimeWarning and return nan; be explicit.
        return float('nan'), float('nan')
    return np.mean(dist), np.std(dist)


# https://pytorch.org/tutorials/beginner/translation_transformer.html

# Same device choice as `device` defined at the top of the notebook.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Encoder/decoder length: Task2Dataset keeps strings shorter than 11 chars, plus EOS.
MAX_LENGTH = 11
# Training batch size (the validation loader uses 4 * BATCH_SIZE).
BATCH_SIZE = 128


def indices2string(indices, itos):
    """Decode each row of token ids in ``indices`` into a string using ``itos``.

    UNK_IDX, PAD_IDX and BOS_IDX are skipped, and decoding of a row stops at
    its first EOS_IDX. Returns one string per row, in order.
    """
    ignored = (UNK_IDX, PAD_IDX, BOS_IDX)
    words = []
    for row in indices:
        chars = []
        for token in row:
            if token in ignored:
                continue
            if token == EOS_IDX:
                break
            chars.append(itos[token])
        words.append(''.join(chars))
    return words


def validate_model(model, dataloader, MAX_LENGTH):
    """Mean per-example loss of ``model`` over ``dataloader``, without gradients.

    Each batch is ``(src, tgt, input_hidden)`` with the batch on dim 1 of ``src``.
    The per-example loss is the negated first output of
    ``model(src, input_hidden, tgt, MAX_LENGTH)``. For ICDARTask2Seq2seq that is the
    negative log-likelihood of the gold characters. The model is switched to eval
    mode and put back into train mode afterwards if it was training.
    Raises ZeroDivisionError when ``dataloader`` yields no batches.
    """
    total_loss = 0
    total_examples = 0

    restore_train_mode = model.training
    model.eval()

    with torch.no_grad():
        for src, tgt, input_hidden in dataloader:
            n_in_batch = src.size(1)
            scores, _ = model(src.to(DEVICE), input_hidden.to(DEVICE),
                              tgt.to(DEVICE), MAX_LENGTH)
            total_loss += (-scores).sum().item()
            total_examples += n_in_batch

    if restore_train_mode:
        model.train()

    return total_loss / total_examples
            

def _restore_best_checkpoint(model, optimizer, model_save_path, lr):
    """Reload model/optimizer state from ``model_save_path``, then set every param group's lr to ``lr``."""
    checkpoint = torch.load(model_save_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # nn.Module.to moves parameters in place, so the caller's reference stays valid.
    model.to(device)
    # Override the lr that was just restored from the checkpoint.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def train_model(model=None, optimizer=None, num_epochs=5, valid_niter=5000, 
                model_save_path='model.rar', max_num_patience=5, max_num_trial=5, 
                lr_decay=0.5, task1_batch_size=128):
    """Train ``model`` on the global ``train`` split and validate on ``val``.

    Validation runs every ``valid_niter`` iterations (counted across epochs).
    - A validation loss lower than every previous one saves model and optimizer
      state to ``model_save_path`` and resets patience.
    - Otherwise patience grows. After ``max_num_patience`` consecutive
      non-improving validations, the best checkpoint is reloaded and the learning
      rate is multiplied by ``lr_decay``; this counts as one trial.
    - After ``max_num_trial`` trials training stops early (the function returns).

    Args:
        model: the ICDARTask2Seq2seq model to train.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over the training data.
        valid_niter: validation interval, in iterations.
        model_save_path: checkpoint file for the best model.
        max_num_patience, max_num_trial, lr_decay: the schedule described above.
        task1_batch_size: rows per saved Task 1 output file; passed to
            Task2Dataset as ``batch_size``. This is unrelated to BATCH_SIZE.
    """
    train_dataloader = DataLoader(
        Task2Dataset(train, out_dir/'task1_output'/'train', batch_size=task1_batch_size),
        batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
    val_dataloader = DataLoader(
        Task2Dataset(val, out_dir/'task1_output'/'val', batch_size=task1_batch_size),
        batch_size=4*BATCH_SIZE, collate_fn=collate_fn)

    num_iter = 0
    report_loss = 0
    report_examples = 0
    val_loss_hist = []
    num_trial = 0
    patience = 0

    model.train()

    for epoch in range(1, num_epochs+1):
        cum_loss = 0
        cum_examples = 0

        for src, tgt, input_hidden in tqdm(train_dataloader):
            num_iter += 1
            batch_size = src.size(1)

            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)
            encoder_hidden = input_hidden.to(DEVICE)

            # The model returns per-example gold log-likelihoods; negate for a loss.
            scores, _ = model(src, encoder_hidden, tgt, MAX_LENGTH)
            batch_loss = (-scores).sum()
            loss = batch_loss / batch_size

            bl = batch_loss.item()
            report_loss += bl
            report_examples += batch_size
            cum_loss += bl
            cum_examples += batch_size

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            if num_iter % valid_niter != 0:
                continue

            val_loss = validate_model(model, val_dataloader, MAX_LENGTH)
            print(f'Epoch {epoch}, iter {num_iter}, avg. train loss {report_loss/report_examples}, avg. val loss {val_loss}')
            report_loss = 0
            report_examples = 0

            better_model = len(val_loss_hist) == 0 or val_loss < min(val_loss_hist)
            if better_model:
                # Patience counts consecutive non-improving validations only.
                patience = 0
                print(f'Saving model and optimizer to {model_save_path}')
                torch.save({
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  }, model_save_path)
            elif patience < max_num_patience:
                patience += 1
                print(f'hit patience {patience}')

                if patience == max_num_patience:
                    num_trial += 1
                    print(f'hit #{num_trial} trial')
                    if num_trial == max_num_trial:
                        print('early stop!')
                        # return, not exit(0): in a notebook kernel exit() ends the
                        # whole session (and the trained model) rather than training.
                        return

                    # Decay the lr and restore the previously best checkpoint.
                    lr = optimizer.param_groups[0]['lr'] * lr_decay
                    print(f'load previously best model and decay learning rate to {lr}')
                    _restore_best_checkpoint(model, optimizer, model_save_path, lr)
                    patience = 0

            val_loss_hist.append(val_loss)

        if cum_examples:
            print(f'Epoch {epoch} finished, avg. train loss {cum_loss/cum_examples}')


def get_gold_tgt_words(dataloader, src_vocab, tgt_vocab):
    """Decode every batch of ``dataloader`` back into source and target strings.

    Batches built by collate_fn are ``(src, tgt, task1_hidden)``. Anything after
    ``(src, tgt)`` is ignored, so plain 2-tuple batches also work. ``src`` and
    ``tgt`` are (seq_len, batch) and are transposed so that each row is one word.

    Returns:
        ``(src_strings, tgt_strings)``: one string per example, in loader order.
    """
    src_itos = src_vocab.get_itos()
    tgt_itos = tgt_vocab.get_itos()
    src_strings = []
    tgt_strings = []

    for src, tgt, *_ in dataloader:
        src_strings.extend(indices2string(src.transpose(0, 1), src_itos))
        tgt_strings.extend(indices2string(tgt.transpose(0, 1), tgt_itos))
    return src_strings, tgt_strings


hidden_size = 768
model = ICDARTask2Seq2seq(
    input_size=len(vocab_transform['ocr']),
    hidden_size=hidden_size,
    output_size=len(vocab_transform['gs']),
    dropout=0.1,
    max_length=MAX_LENGTH,
    teacher_forcing_ratio=0.0,  # always feed the model's own predictions
)
model.to(device)
optimizer = optim.Adam(model.parameters())

# Checkpoint on the mounted Drive; use the local path when running outside Colab.
msp = '/mntDrive/MyDrive/model.rar'
#msp = './model.rar'

train_model(model=model, optimizer=optimizer, model_save_path=msp,
            num_epochs=25, valid_niter=1000,
            max_num_patience=5, max_num_trial=5, lr_decay=0.5)