In [None]:
!pip install python-dateutil --upgrade
!pip install awscli --upgrade
!pip install greenlet --ignore-installed --upgrade
!pip install allennlp

In [None]:
%load_ext autoreload
%autoreload 2

FOLD = 0

import os
import sys
import random
import glob
import gc
import logging
import requests
import re

from typing import Dict, Tuple, List
from collections import OrderedDict
from overrides import overrides
from time import sleep

import cv2
import numpy as np
import pandas as pd

import mlcrate as mlc

from sklearn.model_selection import KFold

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import optim

import torchvision

import allennlp

from allennlp.common import Registrable, Params
from allennlp.common.util import START_SYMBOL, END_SYMBOL, JsonDict

from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import ArrayField, TextField
from allennlp.data.iterators import BucketIterator, MultiprocessIterator
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token, CharacterTokenizer
from allennlp.data.vocabulary import Vocabulary

from allennlp.models import Model

from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper # MIGHT USE FOR ABSTRACTION

from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.nn.beam_search import BeamSearch

from allennlp.training.metrics import F1Measure, BLEU
from allennlp.training import Trainer

sys.path.insert(0, './math_handwriting_recognition')

logger = logging.getLogger()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU

# Seed every RNG used by the pipeline so runs are reproducible
# (same value as the KFold random_state used for the train/val split).
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# IFTTT webhook key used by notify(). Read it from the environment rather than
# committing it to the notebook; the key that used to be hardcoded here has been
# published and should be revoked.
KEY = os.environ.get('IFTTT_KEY', '')

def notify(value_1, value_2, value_3='', key=KEY):
    """Send a best-effort IFTTT push notification and e-mail.

    Posts ``value1``/``value2``/``value3`` to the IFTTT webhook events
    ``notification`` and ``email``. Errors are logged, not raised, so a network
    hiccup cannot abort a long training run.

    Args:
        value_1, value_2, value_3: payload values forwarded to IFTTT.
        key: IFTTT webhook key; when empty the call is skipped.
    """
    if not key:
        logger.warning('notify: no IFTTT key configured, skipping notification')
        return

    report = {'value1': value_1, 'value2': value_2, 'value3': value_3}

    for event in ('notification', 'email'):
        try:
            # timeout so an unresponsive endpoint cannot hang the notebook
            requests.post(f'https://maker.ifttt.com/trigger/{event}/with/key/{key}', data=report, timeout=10)
        except requests.RequestException as exc:
            logger.warning('notify: %s webhook failed: %s', event, exc)

In [None]:
!mkdir logs
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip -o ngrok-stable-linux-amd64.zip
LOG_DIR = './logs' # Here you have to put your log directory
get_ipython().system_raw(
    'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'
    .format(LOG_DIR)
)
get_ipython().system_raw('./ngrok http 6006 &')

!curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"
        
temp = !curl -s http://localhost:4040/api/tunnels | python3 -c "import sys,json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"
print(temp[0])

notify('Tensorboard', f'Tensorboard at: {temp[0]}', temp[0])

!rm ngrok
!rm ngrok-stable-linux-amd64.zip

In [None]:
# #Alternative to ngrok
# !mkdir logs
# get_ipython().system_raw('tensorboard --logdir ./logs --host 0.0.0.0 --port 6006 &')
# !ssh -o "StrictHostKeyChecking no" -R 80:localhost:6006 serveo.net

In [None]:
!tar -xf ../input/crohme-2019-unofficial-processed/train.tgz -C ./
!tar -xf ../input/crohme-2019-unofficial-processed/val.tgz -C ./

!mkdir math_handwriting_recognition
!touch math_handwriting_recognition/__init__.py

In [None]:
# Split the official training set into train/validation CSVs using fold FOLD.
# A copy of the untouched CSV is kept as train_full.csv and always used as the
# source, so re-running this cell does not split an already-split train.csv.
FULL_TRAIN_CSV = './crohme-train/train_full.csv'
if not os.path.exists(FULL_TRAIN_CSV):
    pd.read_csv('./crohme-train/train.csv').to_csv(FULL_TRAIN_CSV, index=False)
full_train_df = pd.read_csv(FULL_TRAIN_CSV)

kfold = KFold(n_splits=10, shuffle=True, random_state=1337)
train_idx, val_idx = list(kfold.split(full_train_df))[FOLD]

# drop=True / index=False: don't leak 'index' and 'Unnamed: 0' columns into the CSVs.
train_df = full_train_df.iloc[train_idx].reset_index(drop=True)
val_df = full_train_df.iloc[val_idx].reset_index(drop=True)
train_df.to_csv('./crohme-train/train.csv', index=False)
val_df.to_csv('./crohme-train/val.csv', index=False)

In [None]:
%%writefile math_handwriting_recognition/dataset.py
import os
import random
from typing import Dict, Tuple, List
from overrides import overrides

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch

import spacy

import allennlp

from allennlp.common.util import START_SYMBOL, END_SYMBOL, get_spacy_model

from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import ArrayField, TextField, MetadataField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, CharacterTokenizer, WordTokenizer

@Tokenizer.register("latex")
class LatexTokenizer(Tokenizer):
    """Split a LaTeX string into symbol-level tokens.

    ``$`` signs are removed. Brackets, ``_ ^ + - / * = | ! ,`` and the digits
    0-9 become their own tokens. A ``\\`` starts a new token, so ``\\command``
    names stay whole. Any other multi-character run (e.g. ``xy``) is split
    into single characters.
    """
    def __init__(self) -> None:
        super().__init__()

    def _tokenize(self, text):
        # Every rule maps one character to itself plus spaces (or to nothing),
        # so applying them in any order gives the same result.
        text = text.replace('$', '')
        for symbol in '(){}_^+-/*=[]|!,0123456789':
            text = text.replace(symbol, f' {symbol} ')
        text = text.replace('\\', ' \\')

        pieces = []
        for word in text.split():
            if len(word) > 1 and not word.startswith('\\'):
                pieces.extend(word)  # one token per character
            else:
                pieces.append(word)

        return [Token(piece) for piece in pieces]

    @overrides
    def tokenize(self, text: str) -> List[Token]:
        return self._tokenize(text)
    
# From https://jdhao.github.io/2017/11/06/resize-image-to-square-with-padding/
def resize(im, desired_size):
    """Scale ``im`` so its longer side is ``desired_size``, then zero-pad it to a square.

    The aspect ratio is kept. Padding is split between opposite sides, and any
    odd pixel goes to the bottom/right.
    """
    height, width = im.shape[:2]
    scale = float(desired_size) / max(height, width)
    scaled_h = int(height * scale)
    scaled_w = int(width * scale)

    # cv2.resize expects the target size as (width, height).
    scaled = cv2.resize(im, (scaled_w, scaled_h))

    pad_h = desired_size - scaled_h
    pad_w = desired_size - scaled_w
    top = pad_h // 2
    left = pad_w // 2

    return cv2.copyMakeBorder(scaled, top, pad_h - top, left, pad_w - left,
                              cv2.BORDER_CONSTANT, value=[0, 0, 0])

@DatasetReader.register('CROHME')
class CROHMEDatasetReader(DatasetReader):
    """Read CROHME expression images (and, when present, their LaTeX labels).

    Each CSV row names an image ``<root_path>/<first path component of the CSV
    path>/data/<id>.png``. If the CSV has a ``label`` column, the label is
    tokenized and wrapped in START/END symbols.

    Args:
        root_path: directory that CSV paths passed to ``read`` are relative to.
        tokenizer: tokenizer applied to labels.
        height, width: size every image is resized to.
        lazy: forwarded to ``DatasetReader``.
        subset: if True, only read ``df.loc[:16]`` (a small debugging subset).
    """
    def __init__(self, root_path: str, tokenizer: Tokenizer, height: int = 512, width: int = 512, lazy: bool = True,
                 subset: bool = False) -> None:
        super().__init__(lazy)

        # NOTE(review): these normalization statistics are not applied anywhere in this class.
        self.mean = 0.4023
        self.std = 0.4864

        self.root_path = root_path
        self.height = height
        self.width = width
        self.subset = subset

        self._tokenizer = tokenizer
        self._token_indexer = {"tokens": SingleIdTokenIndexer()}

    @overrides
    def _read(self, file: str):
        df = pd.read_csv(os.path.join(self.root_path, file))
        if self.subset:
            # .loc slicing is label-based and inclusive: index labels 0..16 (17 rows for a RangeIndex).
            df = df.loc[:16]

        has_labels = 'label' in df.columns

        for _, row in df.iterrows():
            if has_labels:
                yield self.text_to_instance(file, row['id'], row['label'])
            else:
                yield self.text_to_instance(file, row['id'])

    @overrides
    def text_to_instance(self, file: str, img_id: int, label: str = None) -> Instance:
        """Build an Instance with fields ``metadata``, ``img`` and, if ``label`` is given, ``label``.

        The image is inverted (``1 - pixel``) and copied into 3 channels. It is then
        resized to ``(height, width)`` and rounded with ``np.rint``, giving shape
        ``(3, height, width)``.
        """
        sub_path = file.split('/')[0]
        path = os.path.join(self.root_path, sub_path, 'data', f'{img_id}.png')

        raw = plt.imread(path)
        # Grayscale PNGs load as 2-D (H, W); colour PNGs as (H, W, C). Use the first channel.
        if raw.ndim == 3:
            raw = raw[:, :, 0]
        img = 1 - raw

        # Shape: (3, H, W)
        img = np.stack((img, img, img))
        # cv2 works on (H, W, C) images and takes the target size as (width, height).
        img = cv2.resize(img.transpose(1, 2, 0), (self.width, self.height)).transpose(2, 0, 1)
        img = np.rint(img)

        fields = {}
        fields['metadata'] = MetadataField({'path': path})
        fields['img'] = ArrayField(img)

        if label is not None:
            label = self._tokenizer.tokenize(label)

            label.insert(0, Token(START_SYMBOL))
            label.append(Token(END_SYMBOL))

            fields['label'] = TextField(label, self._token_indexer)

        return Instance(fields)

In [None]:
%%writefile math_handwriting_recognition/metrics.py
import os
import random
import subprocess
from typing import Dict, Tuple
from overrides import overrides

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import allennlp

from allennlp.common import Registrable, Params
from allennlp.common.util import START_SYMBOL, END_SYMBOL

from allennlp.training.metrics import Metric, F1Measure, BLEU, BooleanAccuracy

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# From https://github.com/allenai/allennlp/blob/master/allennlp/training/metrics/boolean_accuracy.py
@Metric.register("exprate")
class Exprate(Metric):
    """Expression recognition rate: the fraction of sequences predicted exactly.

    Each row's target is compared with the prediction up to the target's first
    ``end_index`` token. The prediction must also stop there: if it is longer,
    its token at that position must be ``end_index`` as well. Otherwise a
    prediction that merely *starts* with the target would count as correct.

    Args:
        end_index: vocabulary index of the end-of-sequence token.
        vocab: stored on the metric; not otherwise used in this class.
    """
    def __init__(self, end_index: int, vocab) -> None:
        self._correct = 0.0
        self._total = 0.0
        self.vocab = vocab

        self._end_index = end_index

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor):
        """Accumulate exact-match counts for one batch.

        Both tensors are reshaped to ``(batch_size, -1)``. A target row without
        an end token is compared over its full length.
        """
        predictions, targets = self.unwrap_to_tensors(predictions, targets)
        batch_size = predictions.size(0)

        # Shape: (batch_size, -1)
        predictions = predictions.view(batch_size, -1)
        # Shape: (batch_size, -1)
        targets = targets.view(batch_size, -1)

        for i in range(batch_size):
            # Find the end token per row. A batch-wide nonzero() would shift every
            # later row if some row had zero or several end tokens.
            end_positions = (targets[i] == self._end_index).nonzero()
            end = int(end_positions[0, 0]) if end_positions.numel() > 0 else targets.size(1)

            target = targets[i, :end]
            prediction = predictions[i, :end]

            # The prediction must terminate where the target does.
            stops_at_end = predictions.size(1) <= end or bool(predictions[i, end] == self._end_index)

            if stops_at_end and torch.equal(prediction, target):
                self._correct += 1
            self._total += 1

    def get_metric(self, reset: bool = False):
        # Report 0.0 instead of dividing by zero before any batch has been seen.
        accuracy = self._correct / self._total if self._total > 0 else 0.0
        if reset:
            self.reset()
        return {'exprate': accuracy}

    @overrides
    def reset(self):
        self._correct = 0.0
        self._total = 0.0

In [None]:
%%writefile math_handwriting_recognition/encoder.py
import os
import random
from typing import Dict, Tuple, List
from overrides import overrides
from collections import OrderedDict

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision

import allennlp

from allennlp.common import Registrable, Params

from allennlp.data import Vocabulary

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# from https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65939
class CSE(nn.Module):
    """Channel squeeze-and-excitation gate.

    Averages over the last two dims to get one descriptor per channel. The
    descriptor goes through a bottleneck MLP (``in_ch -> in_ch // r -> in_ch``),
    and the input is multiplied by the sigmoid of the result.
    """
    def __init__(self, in_ch, r):
        super(CSE, self).__init__()

        hidden = in_ch // r
        self.linear_1 = nn.Linear(in_ch, hidden)
        self.linear_2 = nn.Linear(hidden, in_ch)

    def forward(self, x):
        descriptor = x.view(*(x.shape[:-2]), -1).mean(-1)
        gate = torch.sigmoid(self.linear_2(F.relu(self.linear_1(descriptor))))
        # Broadcast the per-channel gate over both spatial dims.
        return x * gate[..., None, None]

class SSE(nn.Module):
    """Spatial squeeze-and-excitation gate.

    A 1x1 convolution reduces the channels to a single map. The input is
    multiplied by the sigmoid of that map at every spatial location.
    """
    def __init__(self, in_ch):
        super(SSE, self).__init__()

        self.conv = nn.Conv2d(in_ch, 1, kernel_size=1, stride=1)

    def forward(self, x):
        spatial_gate = torch.sigmoid(self.conv(x))
        return x * spatial_gate

class SCSE(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation: ``CSE(x) + SSE(x)``."""
    def __init__(self, in_ch, r):
        super(SCSE, self).__init__()

        self.cSE = CSE(in_ch, r)
        self.sSE = SSE(in_ch)

    def forward(self, x):
        channel_gated = self.cSE(x)
        spatial_gated = self.sSE(x)
        return channel_gated + spatial_gated

class WAPConv(nn.Module):
    """3x3 convolution (padding 1) -> BatchNorm -> ReLU, optionally followed by dropout.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dropout: apply dropout after the activation.
        dropout_p: dropout probability (0.2 by default).
    """
    def __init__(self, in_ch: int, out_ch: int, dropout: bool = False, dropout_p: float = 0.2):
        super(WAPConv, self).__init__()

        self._dropout = dropout
        self._dropout_p = dropout_p

        self.conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)

        if self._dropout:
            # F.dropout defaults to training=True; follow the module's train/eval
            # mode so inference is deterministic.
            x = F.dropout(x, self._dropout_p, training=self.training)

        return x

class WAPBlock(nn.Module):
    """Four stacked WAPConv layers followed by 2x2 max pooling (halves height and width)."""
    def __init__(self, in_ch: int, out_ch: int, dropout: bool = False):
        super(WAPBlock, self).__init__()

        # Registered as conv_1..conv_4 in order; only the first changes the channel count.
        for position, channels_in in enumerate((in_ch, out_ch, out_ch, out_ch), start=1):
            setattr(self, f'conv_{position}', WAPConv(channels_in, out_ch, dropout=dropout))

        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x: torch.Tensor):
        for layer in (self.conv_1, self.conv_2, self.conv_3, self.conv_4):
            x = layer(x)
        return self.pool(x)

# Can't be pretrained; param is only for compatibility
def WAPBackbone(pretrained: bool = False):
    """Build the WAP CNN: four WAPBlocks (3->32->64->64->128 channels), each halving H and W.

    Dropout is enabled only in the last block. ``pretrained`` is ignored.
    """
    stages = (
        (3, 32, False),
        (32, 64, False),
        (64, 64, False),
        (64, 128, True),
    )
    return nn.Sequential(*[WAPBlock(c_in, c_out, dropout=use_dropout) for c_in, c_out, use_dropout in stages])

# From https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
# Can't be pretrained; param is only for compatibility
def densenet(pretrained: bool = False):
    """Build a torchvision DenseNet with growth rate 24, three 32-layer blocks and 48 stem features.

    ``pretrained`` is ignored (kept so all backbone factories share a signature).
    """
    return torchvision.models.DenseNet(
        growth_rate=24,
        block_config=(32, 32, 32),
        num_init_features=48,
    )

class Encoder(nn.Module, Registrable):
    """Abstract, registrable image encoder.

    Subclasses turn an image batch into a sequence of feature vectors. They
    report the vector size (``get_output_dim``) and the number of positions
    (``get_feature_map_size``).
    """
    def __init__(self, pretrained: bool = False) -> None:
        super().__init__()
        self._pretrained = pretrained

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``; subclasses must override."""
        raise NotImplementedError

    def get_output_dim(self) -> int:
        """Size of each output feature vector; subclasses must override."""
        raise NotImplementedError

    def get_feature_map_size(self) -> int:
        """Number of output positions; subclasses must override."""
        raise NotImplementedError
        
@Encoder.register('backbone')
class BackboneEncoder(Encoder):
    """Image encoder: a CNN backbone followed by adaptive average pooling.

    The feature map is pooled to ``(encoder_height, encoder_width)`` and
    flattened, row-major over the grid, into ``encoder_height * encoder_width``
    vectors of size ``get_output_dim()``.

    Args:
        encoder_type: key of ``self._backbones`` ('vgg16', 'resnet18', 'resnet50',
            'densenet', 'WAP', 'Im2latex', 'smallResnet18').
        encoder_height: height of the pooled feature map.
        encoder_width: width of the pooled feature map.
        pretrained: forwarded to the backbone constructor.
        custom_in_conv: replace the backbone's ``conv1`` with a 3x3, stride-1,
            padding-1 convolution.
    """
    # Output channels of layer1..layer4 for each ResNet variant; an SCSE module of
    # matching width follows each of those layers.
    _RESNET_SCSE_CHANNELS = {
        'resnet18': (64, 128, 256, 512),
        'resnet50': (256, 512, 1024, 2048),
    }

    def __init__(self, encoder_type: str = 'resnet18', encoder_height: int = 4, encoder_width: int = 16, pretrained: bool = False, custom_in_conv: bool = False) -> None:
        super().__init__(pretrained=pretrained)

        self._encoder_type = encoder_type

        self._encoder_height = encoder_height
        self._encoder_width = encoder_width

        self._custom_in_conv = custom_in_conv

        self._backbones = {
            'vgg16': {
                'model': torchvision.models.vgg16,
                'encoder_dim': 512
            },
            'resnet18': { # 4 x 16
                'model': torchvision.models.resnet18,
                'encoder_dim': 512
            },
            'resnet50': { # 4 x 16
                'model': torchvision.models.resnet50,
                'encoder_dim': 2048
            },
            'densenet': { # 8 x 32
                'model': densenet,
                'encoder_dim': 1356
            },
            'WAP': { # 8 x 32
                'model': WAPBackbone,
                'encoder_dim': 128
            },
            'Im2latex': { # 14 x 62
                'model': Im2latexBackbone,
                'encoder_dim': 512
            },
            'smallResnet18': { # 8 x 32
                'model': torchvision.models.resnet18,
                'encoder_dim': 256
            },
        }

        self._backbone = self._backbones[self._encoder_type]['model'](pretrained=self._pretrained)
        self._encoder_dim = self._backbones[self._encoder_type]['encoder_dim']

        if self._custom_in_conv:
            self._backbone._modules['conv1'] = nn.Conv2d(3, 64, 3, padding=1)

        modules = list(self._backbone.children())

        if self._encoder_type == 'densenet':
            # Keep the `features` Sequential without its last module.
            modules = modules[0][:-1]
        elif self._encoder_type == 'vgg16':
            modules = modules[:-1]
        elif self._encoder_type == 'smallResnet18':
            modules = modules[:-3]
        elif self._encoder_type in self._RESNET_SCSE_CHANNELS:
            modules = modules[:-2]

            # Add SCSE between resnet blocks, sized for this ResNet variant
            # (resnet50's layers are 4x wider than resnet18's).
            scse_1, scse_2, scse_3, scse_4 = self._RESNET_SCSE_CHANNELS[self._encoder_type]
            modules = nn.Sequential(
                *modules[:5],
                SCSE(scse_1, 16),
                *modules[5],
                SCSE(scse_2, 16),
                *modules[6],
                SCSE(scse_3, 16),
                *modules[7],
                SCSE(scse_4, 16)
            )

        self._encoder = nn.Sequential(
            *modules,
            nn.AdaptiveAvgPool2d((self._encoder_height, self._encoder_width))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shape: (batch_size, encoder_dim, height, width)
        x = self._encoder(x)

        # Flatten the grid row-major and move channels last, so that x[:, p] is
        # the channel vector of grid cell p. (A plain view(batch, -1, channels)
        # would mix channel and spatial values in every vector.) Models trained
        # with that mixed layout need retraining.
        # Shape: (batch_size, height * width, encoder_dim)
        x = x.flatten(2).transpose(1, 2).contiguous()

        return x

    def get_output_dim(self) -> int:
        return self._encoder_dim

    def get_feature_map_size(self) -> int:
        return self._encoder_height * self._encoder_width

@Encoder.register('lstm')
class LstmEncoder(Encoder):
    """Run an LSTM over the feature sequence produced by a wrapped encoder.

    With ``bidirectional`` the forward and backward outputs are summed, so the
    output size is ``hidden_size`` either way.
    """
    # Don't set hidden_size manually: from_params uses the wrapped encoder's output dim.
    def __init__(self, encoder: Encoder, hidden_size: int = 512, layers: int = 1, bidirectional: bool = False) -> None:
        super().__init__(pretrained=False)

        self._encoder = encoder

        self._hidden_size = hidden_size
        self._layers = layers
        self._bidirectional = bidirectional

        self._lstm = nn.LSTM(input_size=self._encoder.get_output_dim(), hidden_size=self._hidden_size, num_layers=self._layers, batch_first=True,
                             bidirectional=self._bidirectional)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encode image
        # Shape: (batch_size, height * width, encoder_dim)
        x = self._encoder(x)

        # Encode encoded feature map with (bi)lstm
        # Shape: (batch_size, height * width, num_directions * hidden_size)
        x, _ = self._lstm(x)

        if self._bidirectional:
            # Shape: (batch_size, height * width, num_directions, hidden_size)
            x = x.view(-1, x.shape[1], 2, self._hidden_size)

            # Sum the two directions.
            # Shape: (batch_size, height * width, hidden_size)
            x = x[:, :, 0, :] + x[:, :, 1, :]

        return x

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'LstmEncoder':  # type: ignore
        encoder = Encoder.from_params(vocab=vocab, params=params.pop("encoder"))
        layers = params.pop("layers", 1)
        bidirectional = params.pop("bidirectional", False)

        # get_output_dim() works for any wrapped Encoder, not only those with `_encoder_dim`.
        return cls(encoder, encoder.get_output_dim(), layers, bidirectional)

    @overrides
    def get_output_dim(self) -> int:
        return self._hidden_size

    def get_feature_map_size(self) -> int:
        return self._encoder.get_feature_map_size()

class Im2latexBlock(nn.Module):
    """3x3 conv (stride 1; padding 1 if ``padding`` is truthy, else 0), optional BatchNorm, ReLU, optional max-pool.

    ``pool_stride`` is used as both kernel size and stride of the max-pool.
    """
    def __init__(self, in_ch: int, out_ch: int, padding: int, bn: bool = True, pool: bool = False, pool_stride: Tuple[int, int] = None):
        super(Im2latexBlock, self).__init__()

        self._bn = bn
        self._pool = pool

        self.conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1 if padding else 0)

        if self._bn:
            self.bn = nn.BatchNorm2d(out_ch)

        if self._pool:
            self.pool = nn.MaxPool2d(pool_stride, pool_stride)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)

        if self._bn:
            x = self.bn(x)

        x = F.relu(x)

        if self._pool:
            x = self.pool(x)

        return x

# Can't be pretrained; param is only for compatibility
def Im2latexBackbone(pretrained: bool = False):
    """CNN from the im2latex model: six conv blocks, 3 -> 512 channels."""
    model = nn.Sequential(
        Im2latexBlock(3, 64, 1, False, True, (2, 2)),
        Im2latexBlock(64, 128, 1, False, True, (2, 2)),
        Im2latexBlock(128, 256, 1, True, False),
        Im2latexBlock(256, 256, 1, False, True, (1, 2)),
        Im2latexBlock(256, 512, 1, True, True, (2, 1)),
        Im2latexBlock(512, 512, 0, True, False)
    )

    return model

@Encoder.register('Im2latex')
class Im2latexEncoder(Encoder):
    """Row encoder in the im2latex style: a GRU runs over each row of the feature map.

    A learned embedding of the row index initialises the GRU for each row, and
    bidirectional outputs are summed. The output is ``height * width`` vectors
    of size ``hidden_size``, row-major. The wrapped encoder must provide
    ``_encoder_height``/``_encoder_width`` and return row-major
    ``(batch, height * width, channels)`` features, as ``BackboneEncoder`` does.
    """
    # Don't set hidden_size manually: from_params uses the wrapped encoder's output dim.
    def __init__(self, encoder: Encoder, hidden_size: int = 512, layers: int = 1, bidirectional: bool = False) -> None:
        super().__init__(pretrained=False)

        self._hidden_size = hidden_size
        self._layers = layers
        self._bidirectional = bidirectional

        self._num_directions = 2 if self._bidirectional else 1

        self._encoder = encoder

        self._row_encoder = nn.GRU(input_size=self._encoder.get_output_dim(), hidden_size=self._hidden_size, num_layers=self._layers, batch_first=True,
                                    bidirectional=self._bidirectional)

        self._positional_embeddings = nn.Embedding(self._encoder._encoder_height, self._hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        height = self._encoder._encoder_height
        width = self._encoder._encoder_width

        # Shape: (batch_size, height * width, encoder_dim), row-major positions
        x = self._encoder(x)
        # Shape: (batch_size, height, width, encoder_dim)
        x = x.reshape(batch_size, height, width, -1)

        encoded_rows = []
        for i in range(height):
            # Shape: (batch_size, width, encoder_dim)
            row = x[:, i]

            # The row's positional embedding is the initial hidden state of every layer/direction.
            # Created on x's device so CPU and GPU runs both work regardless of the global `device`.
            # Shape: (layers * num_directions, batch_size, hidden_size)
            row_index = torch.tensor([i], dtype=torch.long, device=x.device)
            initial_hidden = self._positional_embeddings(row_index).view(1, 1, self._hidden_size).repeat(self._layers * self._num_directions, batch_size, 1)

            # Shape: (batch_size, width, num_directions * hidden_size)
            encoded_row, _ = self._row_encoder(row, initial_hidden)

            if self._bidirectional:
                # Sum the two directions.
                # Shape: (batch_size, width, hidden_size)
                encoded_row = encoded_row.view(batch_size, width, 2, self._hidden_size)
                encoded_row = encoded_row[:, :, 0, :] + encoded_row[:, :, 1, :]

            encoded_rows.append(encoded_row)

        # Stack rows and flatten the grid row-major, keeping hidden units last.
        # Shape: (batch_size, height * width, hidden_size)
        return torch.stack(encoded_rows, dim=1).reshape(batch_size, height * width, self._hidden_size)

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'Im2latexEncoder':  # type: ignore
        encoder = Encoder.from_params(vocab=vocab, params=params.pop("encoder"))
        layers = params.pop("layers", 1)
        bidirectional = params.pop("bidirectional", False)

        return cls(encoder, encoder.get_output_dim(), layers, bidirectional)

    def get_output_dim(self) -> int:
        return self._hidden_size

    def get_feature_map_size(self) -> int:
        return self._encoder._encoder_height * self._encoder._encoder_width

# from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias; channel counts are cast to int."""
    return nn.Conv2d(
        in_channels=int(in_planes),
        out_channels=int(out_planes),
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without bias; channel counts are cast to int (callers may pass floats)."""
    in_channels, out_channels = int(in_planes), int(out_planes)
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 conv+BN layers plus a shortcut, then ReLU.

    ``downsample`` (if given) maps the input to the residual branch's shape.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # conv1 (and downsample, if any) apply the stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)
    
# From https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False)),
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)

class _DenseBlock(nn.Sequential):
    """``num_layers`` chained _DenseLayers; the k-th (1-based) sees num_input_features + (k-1) * growth_rate channels."""
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for index in range(1, num_layers + 1):
            in_features = num_input_features + (index - 1) * growth_rate
            self.add_module(f'denselayer{index}', _DenseLayer(in_features, growth_rate, bn_size, drop_rate))

class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features, pool=True):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        if pool == True:
            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))

# pretrained is only for compatibility
class MultiscaleDenseNet(nn.Module):
    """DenseNet trunk with a main branch and a higher-resolution multiscale branch.

    ``features`` is a stem (7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool).
    For each entry of ``block_config`` it then adds a dense block, a transition
    that halves the channels (and the resolution, except after the last block)
    and an SCSE module. Two heads read the resulting feature map:

    * ``main_branch``: 2x2 average pooling, a 32-layer dense block, SCSE;
    * ``multiscale_branch``: a 16-layer dense block and SCSE at the trunk's resolution.

    ``forward`` returns ``[main_features, multiscale_features]``. ``num_classes``
    and ``pretrained`` are unused and kept only for signature compatibility.
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0.2, num_classes=1000, pretrained=False):

        super(MultiscaleDenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            # The last transition keeps the resolution; the branches below set it.
            pool = i != len(block_config) - 1

            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate

            trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, pool=pool)
            self.features.add_module('transition%d' % (i + 1), trans)
            num_features = num_features // 2

            # Add SCSE
            scse = nn.Sequential(SCSE(num_features, 16))
            self.features.add_module('scse%d' % (i + 1), scse)

        # Multiscale branch
        # A dense block adds num_layers * growth_rate channels. Derive the SCSE widths
        # from that instead of hard-coding 1356 / 972, which only match
        # growth_rate=24 on a 588-channel trunk and break every other configuration.
        main_layers = 32
        multiscale_layers = 16
        main_channels = num_features + main_layers * growth_rate
        multiscale_channels = num_features + multiscale_layers * growth_rate

        self.main_branch = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            _DenseBlock(num_layers=main_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate),
            SCSE(main_channels, 16)
        )

        self.multiscale_branch = nn.Sequential(
            _DenseBlock(num_layers=multiscale_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate),
            SCSE(multiscale_channels, 16)
        )

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        main_features = self.main_branch(features)
        multiscale_features = self.multiscale_branch(features)

        return [main_features, multiscale_features]

@Encoder.register('multiscale')
class MultiscaleEncoder(Encoder):
    """Encoder returning a main feature map and a higher-resolution "multiscale" map, both flattened.

    With ``encoder_type`` ``'resnet18'`` or ``'resnet50'`` a torchvision ResNet is split into:
      * a shared trunk: every backbone child except the last three (for torchvision ResNets,
        everything up to and including ``layer3``);
      * a main branch: the third-to-last child (the last conv stage, ``layer4``);
      * a multiscale branch: two freshly initialised ``BasicBlock``s that widen the trunk's
        channels to ``encoder_dim`` without the main branch's downsampling.
    Any other ``encoder_type`` (``'densenet'``) uses ``MultiscaleDenseNet``, which returns both maps.

    Args:
        encoder_type: ``'resnet18'`` (default), ``'resnet50'`` or ``'densenet'``.
        encoder_height: height of the main feature map; only used by ``get_feature_map_size``.
        encoder_width: width of the main feature map; only used by ``get_feature_map_size``.
        pretrained: forwarded to ``Encoder``; the ResNet backbones are built with ``self._pretrained``.
    """
    # The default used to be the typo 'renset18', which is not a key of ``_backbones`` and made
    # the default constructor raise KeyError.
    def __init__(self, encoder_type: str = 'resnet18', encoder_height: int = 4, encoder_width: int = 16, pretrained: bool = False) -> None:
        super().__init__(pretrained=pretrained)

        self._encoder_type = encoder_type

        self._encoder_height = encoder_height
        self._encoder_width = encoder_width

        self._backbones = {
            'resnet18': {
                'model': torchvision.models.resnet18,
                'encoder_dim': [512]
            },
            'resnet50': {
                'model': torchvision.models.resnet50,
                'encoder_dim': [2048]
            },
            'densenet': {
                'model': MultiscaleDenseNet,
                'encoder_dim': [1356, 972]
            }
        }

        if self._encoder_type in ('resnet18', 'resnet50'):
            self._backbone = self._backbones[self._encoder_type]['model'](pretrained=self._pretrained)
            self._encoder_dim = self._backbones[self._encoder_type]['encoder_dim'][0]
            # The trunk's output has half of encoder_dim channels. Integer division: nn.Conv2d
            # cannot be built with a float channel count (``/ 2`` made this path always crash).
            trunk_dim = self._encoder_dim // 2

            # Common conv blocks
            self._encoder = nn.Sequential(
                *list(self._backbone.children())[:-3]
            )

            # Last conv block
            self._main_branch = nn.Sequential(*list(self._backbone.children())[-3])

            # Uses 1x1 convs to convert identity to correct num of channels
            self._identity_conv = nn.Sequential(
                conv1x1(trunk_dim, self._encoder_dim),
                nn.BatchNorm2d(self._encoder_dim),
            )

            # Last conv block without pool and not pretrained
            self._multiscale_branch = nn.Sequential(
                BasicBlock(trunk_dim, self._encoder_dim, downsample=self._identity_conv),
                BasicBlock(self._encoder_dim, self._encoder_dim)
            )
        else:
            self._backbone = self._backbones[self._encoder_type]['model'](growth_rate=24, block_config=(32, 32), num_init_features=48)
            self._encoder_dim = self._backbones[self._encoder_type]['encoder_dim'][0]

            self._encoder = self._backbone

    @staticmethod
    def _flatten(features: torch.Tensor) -> torch.Tensor:
        """Flatten a (batch, channels, height, width) map to (batch, height * width, channels)."""
        # Move channels last before flattening so each row is one spatial position's feature
        # vector (row-major over height, width). A plain ``view`` on the channels-first tensor
        # would mix values from different channels and positions into each row.
        return features.permute(0, 2, 3, 1).reshape(features.shape[0], -1, features.shape[1])

    def forward(self, x: torch.Tensor):
        """Encode images into ``[main_features, multiscale_features]``.

        Each returned tensor has shape (batch_size, positions, channels); the multiscale map has
        roughly 4x as many positions as the main map.
        """
        # Shape: (batch_size, channels, height * 2, width * 2) for ResNets; a [main, multiscale]
        # pair for the DenseNet backbone
        x = self._encoder(x)

        if self._encoder_type in ('resnet18', 'resnet50'):
            # Shape: (batch_size, channels, height, width)
            main_features = self._main_branch(x)
            # Shape: (batch_size, channels, height * 2, width * 2)
            multiscale_features = self._multiscale_branch(x)
        else:
            main_features, multiscale_features = x[0], x[1]

        # Shape: (batch_size, height * width, encoder_dim), (batch_size, height * 2 * width * 2, encoder_dim)
        return [self._flatten(main_features), self._flatten(multiscale_features)]

    def get_output_dim(self) -> int:
        """Channel size of the main feature map."""
        return self._encoder_dim

    def get_feature_map_size(self) -> int:
        """Number of positions in the main feature map (``encoder_height * encoder_width``)."""
        return self._encoder_height * self._encoder_width

#DEPRECATED
@Encoder.register('multiscale-lstm')
class LstmEncoder(Encoder):
    """DEPRECATED: runs an (optionally bidirectional) LSTM over each output of a multiscale encoder.

    The wrapped ``encoder`` must return ``[main_features, multiscale_features]`` with 1356 and 972
    features per position respectively; those are the two LSTMs' hard-coded input sizes.

    Args:
        encoder: wrapped encoder producing the two flattened feature maps.
        hidden_size: kept for backward compatibility only. Each LSTM's hidden size equals its input
            size (1356 / 972), so this value does not affect computation. Don't set it manually.
        layers: number of stacked LSTM layers.
        bidirectional: if True, the two directions' outputs are summed, so widths stay 1356 / 972.
    """
    def __init__(self, encoder: Encoder, hidden_size: int = 256, layers: int = 1, bidirectional: bool = False) -> None:
        super().__init__(pretrained=False)

        self._encoder = encoder

        self._hidden_size = hidden_size
        self._layers = layers
        self._bidirectional = bidirectional

        self._lstm = nn.LSTM(input_size=1356, hidden_size=1356, num_layers=self._layers, batch_first=True,
                             bidirectional=self._bidirectional)

        self._lstm2 = nn.LSTM(input_size=972, hidden_size=972, num_layers=self._layers, batch_first=True,
                             bidirectional=self._bidirectional)

    @staticmethod
    def _sum_directions(output: torch.Tensor, hidden_size: int) -> torch.Tensor:
        """(batch, seq, 2 * hidden_size) -> (batch, seq, hidden_size) by adding both directions."""
        # nn.LSTM concatenates [forward | backward] on the last axis, with the backward direction
        # already aligned per timestep, so no manual reversal is needed.
        output = output.view(-1, output.shape[1], 2, hidden_size)
        return output[:, :, 0, :] + output[:, :, 1, :]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode images and run each LSTM over its feature map; returns ``[x_1, x_2]``."""
        # Encode image
        # Shape: [(batch_size, h1 * w1, 1356), (batch_size, h2 * w2, 972)]
        x = self._encoder(x)

        # Shape: (batch_size, positions, num_directions * lstm_hidden_size)
        x_1, _ = self._lstm(x[0])
        x_2, _ = self._lstm2(x[1])

        if self._bidirectional:
            # Reshape with each LSTM's own hidden size. Using self._hidden_size for both failed for
            # the 972-wide branch whenever hidden_size was 1356, which is what from_params passes.
            # Shape: (batch_size, positions, lstm_hidden_size)
            x_1 = self._sum_directions(x_1, self._lstm.hidden_size)
            x_2 = self._sum_directions(x_2, self._lstm2.hidden_size)

        return [x_1, x_2]

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'LstmEncoder':  # type: ignore
        """Build from config keys ``encoder`` (a nested Encoder config), ``layers`` and ``bidirectional``."""
        encoder = params.pop("encoder")
        layers = params.pop("layers")
        bidirectional = params.pop("bidirectional")

        encoder = Encoder.from_params(vocab=vocab, params=encoder)

        return cls(encoder, encoder._encoder_dim, layers, bidirectional)

    def get_output_dim(self) -> int:
        """Main feature size, delegated to the wrapped encoder."""
        return self._encoder._encoder_dim

    def get_feature_map_size(self) -> int:
        """Number of main-map positions, delegated to the wrapped encoder."""
        return self._encoder._encoder_height * self._encoder._encoder_width

In [None]:
%%writefile math_handwriting_recognition/attention.py
import os
import random
from typing import Dict, Tuple
from overrides import overrides

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision

import allennlp

from allennlp.common import Registrable, Params
from allennlp.data.vocabulary import Vocabulary

# Run on the first GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class CaptioningAttention(nn.Module, Registrable):
    """Abstract, registrable base for the attention modules used by the captioning decoders.

    Subclasses implement ``forward`` (returning at least the attended context and the attention
    weights) and ``get_output_dim`` (size of the context vector).
    """

    def __init__(self) -> None:
        super(CaptioningAttention, self).__init__()

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over encoder features ``x`` given decoder state ``h``. Must be overridden."""
        raise NotImplementedError()

    def get_output_dim(self) -> int:
        """Size of the context vector produced by ``forward``. Must be overridden."""
        raise NotImplementedError()

@CaptioningAttention.register('image-captioning')
class ImageCaptioningAttention(CaptioningAttention):
    """Additive soft attention over flattened encoder features, with an optional sigmoid gate.

    Each position gets the score ``w^T tanh(W_e x + W_d h)``; a softmax over positions gives the
    weights, and the context is the weighted sum of ``x``. When ``doubly_stochastic_attention`` is
    True, the context is also scaled element-wise by ``sigmoid(f_beta(h))``.

    Args:
        encoder_dim: feature size of ``x`` and of the returned context vector.
        decoder_dim: size of the decoder hidden state ``h``.
        attention_dim: size of the shared attention space.
        doubly_stochastic_attention: enables the sigmoid gate on the context vector.
    """
    def __init__(self, encoder_dim: int = 512, decoder_dim: int = 256, attention_dim: int = 256, doubly_stochastic_attention: bool = True) -> None:
        super().__init__()

        self._encoder_dim = encoder_dim
        self._decoder_dim = decoder_dim
        self._attention_dim = attention_dim
        self._doubly_stochastic_attention = doubly_stochastic_attention

        # Projection of features, projection of state, scalar score, and the optional gate
        self._encoder_attention = nn.Linear(encoder_dim, attention_dim)
        self._decoder_attention = nn.Linear(decoder_dim, attention_dim)
        self._attention = nn.Linear(attention_dim, 1)
        if doubly_stochastic_attention:
            self._f_beta = nn.Linear(decoder_dim, encoder_dim)

    @overrides
    def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``x`` (batch_size, h * w, encoder_dim) given ``h`` (batch_size, decoder_dim).

        Returns:
            ``(context, attention_weights)`` of shapes (batch_size, encoder_dim) and
            (batch_size, h * w, 1).
        """
        projected_features = self._encoder_attention(x)
        # Singleton position axis so the projected state broadcasts over every position
        projected_state = self._decoder_attention(h).unsqueeze(1)

        scores = self._attention(torch.tanh(projected_features + projected_state)).squeeze(2)
        # Every encoder position is valid, so a plain (unmasked) softmax is enough
        weights = torch.softmax(scores, dim=1).unsqueeze(2)

        context = (x * weights).sum(dim=1)
        if self._doubly_stochastic_attention:
            context = torch.sigmoid(self._f_beta(h)) * context

        return context, weights

    @overrides
    def get_output_dim(self) -> int:
        """Context vector size, equal to ``encoder_dim``."""
        return self._encoder_dim

@CaptioningAttention.register('WAP')
class WAPAttention(CaptioningAttention):
    """Additive attention with a coverage term built from the running sum of past attention maps.

    Each position gets the score ``w^T tanh(W_e x + W_d h + W_c conv(coverage))``, where ``coverage``
    is ``sum_attention_weights`` laid back out on the 2-D feature-map grid.

    Args:
        encoder_dim: feature size of ``x`` and of the returned context vector.
        decoder_dim: size of the decoder hidden state ``h``.
        attention_dim: size of the shared attention space and of the coverage conv's output.
        kernel_size: kernel size of the coverage convolution.
        padding: padding of the coverage convolution. It must preserve the map's spatial size
            (``kernel_size // 2`` for odd kernels) so the coverage term lines up with the positions.
    """
    def __init__(self, encoder_dim: int = 512, decoder_dim: int = 256, attention_dim: int = 256, kernel_size: int = 5, padding: int=2) -> None:
        super().__init__()

        self._encoder_dim = encoder_dim
        self._decoder_dim = decoder_dim
        self._attention_dim = attention_dim
        self._kernel_size = kernel_size

        self._encoder_attention = nn.Linear(self._encoder_dim, self._attention_dim)
        self._decoder_attention = nn.Linear(self._decoder_dim, self._attention_dim)

        # If kernel size is changed, padding needs to also change
        # Not sure if original uses padding; needed here since need same dimension inputs to attention
        self._coverage = nn.Conv2d(1, self._attention_dim, kernel_size, padding=padding)
        self._coverage_attention = nn.Linear(self._attention_dim, self._attention_dim)

        self._attention = nn.Linear(self._attention_dim, 1)

    @overrides
    def forward(self, x: torch.Tensor, h: torch.Tensor, sum_attention_weights: torch.Tensor, height: int = 8) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Attend over ``x`` given decoder state ``h`` and the attention accumulated so far.

        Args:
            x: encoder features, (batch_size, height * width, encoder_dim). ``batch_size`` may be
                smaller than ``sum_attention_weights.shape[0]`` when only the leading (still
                active) rows of the batch are being decoded.
            h: decoder hidden state, (batch_size, decoder_dim).
            sum_attention_weights: running sum of attention maps, (full_batch_size, height * width).
                Its first ``batch_size`` rows are updated in place.
            height: feature-map height. The width is inferred as
                ``sum_attention_weights.shape[1] // height``; that is 4 * height for the 4:1 maps
                this was originally hard-coded for.

        Returns:
            ``(context, attention_weights, sum_attention_weights)`` with shapes
            (batch_size, encoder_dim), (batch_size, height * width, 1) and
            (full_batch_size, height * width).

        Raises:
            ValueError: if the number of positions is not a multiple of ``height``.
        """
        batch_size = x.shape[0]
        num_positions = sum_attention_weights.shape[1]
        width = num_positions // height
        if width * height != num_positions:
            raise ValueError('Cannot lay out %d attention positions on a grid of height %d' % (num_positions, height))

        # Shape: (batch_size, height * width, attention_dim)
        encoder_attention = self._encoder_attention(x)

        # Shape: (batch_size, 1, attention_dim)
        decoder_attention = self._decoder_attention(h).unsqueeze(1)

        # Coverage over the active rows only, as a 1-channel image. Clone because
        # sum_attention_weights is modified by an in-place operation further down.
        # Shape: (batch_size, 1, height, width)
        coverage_map = sum_attention_weights[:batch_size].view(-1, 1, height, width).clone()
        # The conv output is channels-first; move channels last before flattening so each row is
        # one position's feature vector. A plain ``view`` (previously used here) would interleave
        # channels and positions.
        # Shape: (batch_size, height * width, attention_dim)
        coverage = self._coverage(coverage_map).permute(0, 2, 3, 1).reshape(batch_size, -1, self._attention_dim)
        coverage_attention = self._coverage_attention(coverage)

        # Shape: (batch_size, height * width)
        attention = self._attention(torch.tanh(encoder_attention + decoder_attention + coverage_attention)).squeeze(2)

        # No need for masked softmax since all encoder pixels are available and hidden state of rnn isn't masked
        # Shape: (batch_size, height * width, 1)
        attention_weights = torch.softmax(attention, dim=1).unsqueeze(2)

        # Update only the active rows of the running sum
        # Shape: (full_batch_size, height * width)
        sum_attention_weights[:attention_weights.shape[0]] += attention_weights.view(-1, attention_weights.shape[1])

        # Shape: (batch_size, encoder_dim)
        attention = (x * attention_weights).sum(dim=1)

        return attention, attention_weights, sum_attention_weights

    @overrides
    def get_output_dim(self) -> int:
        """Context vector size, equal to ``encoder_dim``."""
        return self._encoder_dim

@CaptioningAttention.register('multiscale')
class MultiscaleAttention(CaptioningAttention):
    """Runs one coverage attention per feature map and concatenates the two context vectors.

    Both sub-attentions are called as ``attention(features, h, sum_weights, height=...)`` and must
    return ``(context, weights, sum_weights)``, as ``WAPAttention`` does.

    Args:
        main_attention: attention over the main feature map (``x[0]``).
        multiscale_attention: attention over the multiscale feature map (``x[1]``).
        height_1: feature-map height passed to ``main_attention``.
        height_2: feature-map height passed to ``multiscale_attention``.
    """
    def __init__(self, main_attention: CaptioningAttention, multiscale_attention: CaptioningAttention, height_1: int = 4, height_2: int = 8) -> None:
        super().__init__()

        self._main_attention = main_attention
        self._multiscale_attention = multiscale_attention

        self._height_1 = height_1
        self._height_2 = height_2

    @overrides
    def forward(self, x: torch.Tensor, h: torch.Tensor, sum_attention_weights_0: torch.Tensor, sum_attention_weights_1: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        """Attend over both feature maps in ``x = [main_features, multiscale_features]``.

        Returns:
            ``(context, (main_weights, multiscale_weights), sum_attention_weights_0,
            sum_attention_weights_1)``. ``context`` concatenates the two contexts along dim 1.
        """
        main_features, multiscale_features = x[0], x[1]

        main_attention, main_attention_weights, sum_attention_weights_0 = self._main_attention(main_features, h, sum_attention_weights_0, height=self._height_1)
        multiscale_attention, multiscale_attention_weights, sum_attention_weights_1 = self._multiscale_attention(multiscale_features, h, sum_attention_weights_1, height=self._height_2)

        attention = torch.cat([main_attention, multiscale_attention], dim=1)

        return attention, (main_attention_weights, multiscale_attention_weights), sum_attention_weights_0, sum_attention_weights_1

    @overrides
    def get_output_dim(self) -> int:
        """Size of the concatenated context: the sum of both sub-attentions' output sizes."""
        # Use the public API rather than each sub-module's private ``_encoder_dim``
        return self._main_attention.get_output_dim() + self._multiscale_attention.get_output_dim()

In [None]:
%%writefile math_handwriting_recognition/decoder.py
import os
import random
from typing import Dict, Tuple
from overrides import overrides

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision

import allennlp

from allennlp.common import Registrable, Params

from allennlp.data.vocabulary import Vocabulary

from allennlp.modules.token_embedders import Embedding

from math_handwriting_recognition.attention import CaptioningAttention

# Run on the first GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class CaptioningDecoder(nn.Module, Registrable):
    """Abstract, registrable base for single-step caption decoders.

    Subclasses implement one decoding step in ``forward`` and report their vocabulary size
    (``get_output_dim``) and hidden-state size (``get_input_dim``).
    """

    def __init__(self, vocab: Vocabulary):
        super().__init__()

        # Subclasses read this to size their embedding and output layers
        self.vocab = vocab

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor, predicted_indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one decoding step. Must be overridden."""
        raise NotImplementedError()

    def get_output_dim(self) -> int:
        """Number of output logits per step (vocabulary size). Must be overridden."""
        raise NotImplementedError()

    def get_input_dim(self) -> int:
        """Size of the decoder state (h and c) expected as input. Must be overridden."""
        raise NotImplementedError()

@CaptioningDecoder.register('image-captioning')
class ImageCaptioningDecoder(CaptioningDecoder):
    """Single LSTM decoding step with attention.

    The previous hidden state queries ``attention`` over the encoder features. The attended context
    and the previous token's embedding are concatenated and fed to an ``LSTMCell``, and the new
    hidden state is projected to vocabulary logits.

    Args:
        vocab: output vocabulary; its size sets the embedding and output layers.
        attention: attention module called as ``attention(x, h)``.
        embedding_dim: token embedding size.
        decoder_dim: LSTM hidden/cell size.
    """
    def __init__(self, vocab: Vocabulary, attention: CaptioningAttention, embedding_dim:int = 256, decoder_dim:int = 256):
        super(ImageCaptioningDecoder, self).__init__(vocab=vocab)

        vocab_size = self.vocab.get_vocab_size()
        self._vocab_size = vocab_size
        self._embedding_dim = embedding_dim
        self._decoder_dim = decoder_dim

        self._embedding = Embedding(vocab_size, embedding_dim)
        self._attention = attention
        lstm_input_dim = self._embedding.get_output_dim() + attention.get_output_dim()
        self._decoder_cell = nn.LSTMCell(lstm_input_dim, decoder_dim)
        self._linear = nn.Linear(decoder_dim, vocab_size)

    @overrides
    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor, predicted_indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """One step: returns ``(h, c, logits, attention_weights)``.

        Shapes: (batch_size, decoder_dim), (batch_size, decoder_dim), (batch_size, vocab_size) and
        whatever the attention module returns for its weights.
        """
        embedded = self._embedding(predicted_indices).float().view(-1, self._embedding_dim)

        context, weights = self._attention(x, h)

        cell_input = torch.cat([context, embedded], dim=1)
        h, c = self._decoder_cell(cell_input, (h, c))

        logits = self._linear(h)
        return h, c, logits, weights

    @overrides
    def get_output_dim(self) -> int:
        """Vocabulary size (number of logits per step)."""
        return self._vocab_size

    @overrides
    def get_input_dim(self) -> int:
        """LSTM hidden/cell size."""
        return self._decoder_dim

@CaptioningDecoder.register('WAP')
class WAPDecoder(CaptioningDecoder):
    """Single GRU decoding step with coverage attention.

    The previous hidden state queries ``attention`` (which also updates the running attention sum).
    ``[context, previous-token embedding]`` is fed to a ``GRUCell``, and the new state is projected
    to vocabulary logits.

    Args:
        vocab: output vocabulary; its size sets the embedding and output layers.
        attention: attention called as ``attention(x, h, sum_attention_weights)`` and returning
            ``(context, weights, sum_attention_weights)``.
        embedding_dim: token embedding size.
        decoder_dim: GRU hidden size.
    """
    def __init__(self, vocab: Vocabulary, attention: CaptioningAttention, embedding_dim:int = 256, decoder_dim:int = 256):
        super(WAPDecoder, self).__init__(vocab=vocab)

        vocab_size = self.vocab.get_vocab_size()
        self._vocab_size = vocab_size
        self._embedding_dim = embedding_dim
        self._decoder_dim = decoder_dim

        self._embedding = Embedding(vocab_size, embedding_dim)
        self._attention = attention
        gru_input_dim = self._embedding.get_output_dim() + attention.get_output_dim()
        self._decoder_cell = nn.GRUCell(gru_input_dim, decoder_dim)
        self._linear = nn.Linear(decoder_dim, vocab_size)

    @overrides
    def forward(self, x: torch.Tensor, h: torch.Tensor, predicted_indices: torch.Tensor, sum_attention_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """One step: returns ``(h, logits, attention_weights, sum_attention_weights)``."""
        embedded = self._embedding(predicted_indices).float().view(-1, self._embedding_dim)

        context, weights, sum_attention_weights = self._attention(x, h, sum_attention_weights)

        new_h = self._decoder_cell(torch.cat([context, embedded], dim=1), h)
        logits = self._linear(new_h)

        return new_h, logits, weights, sum_attention_weights

    @overrides
    def get_output_dim(self) -> int:
        """Vocabulary size (number of logits per step)."""
        return self._vocab_size

    @overrides
    def get_input_dim(self) -> int:
        """GRU hidden size."""
        return self._decoder_dim

@CaptioningDecoder.register('multiscale')
class MultiscaleDecoder(CaptioningDecoder):
    """Two-GRU decoding step over multiscale attention.

    A first ``GRUCell`` updates the hidden state from the (dropped-out) previous-token embedding.
    That updated state queries ``attention``. A second ``GRUCell`` updates the state again from the
    attended context, and a linear layer maps it to vocabulary logits.

    Args:
        vocab: output vocabulary; its size sets the embedding and output layers.
        attention: attention called as ``attention(x, h, sum_0, sum_1)`` and returning
            ``(context, weights, sum_0, sum_1)``.
        embedding_dim: token embedding size.
        decoder_dim: hidden size of both GRU cells.
        dropout: dropout probability on the token embedding. The default 0.1 is the value that
            used to be hard-coded.
    """
    def __init__(self, vocab: Vocabulary, attention: CaptioningAttention, embedding_dim: int = 256, decoder_dim:int = 256, dropout: float = 0.1):
        super(MultiscaleDecoder, self).__init__(vocab=vocab)

        self._vocab_size = self.vocab.get_vocab_size()
        self._embedding_dim = embedding_dim
        self._decoder_dim = decoder_dim

        self._embedding = Embedding(self._vocab_size, self._embedding_dim)
        self._dropout = nn.Dropout(dropout)
        # Output size of state cell must be decoder dim since state is transformed by the state cell
        self._state_cell = nn.GRUCell(self._embedding.get_output_dim(), self._decoder_dim)

        self._attention = attention
        self._decoder_cell = nn.GRUCell(self._attention.get_output_dim(), self._decoder_dim)

        self._linear = nn.Linear(self._decoder_dim, self._vocab_size)

    @overrides
    def forward(self, x: torch.Tensor, h: torch.Tensor, predicted_indices: torch.Tensor, sum_attention_weights_0: torch.Tensor, sum_attention_weights_1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        """One step: returns ``(h, logits, attention_weights, sum_attention_weights_0, sum_attention_weights_1)``.

        ``attention_weights`` is whatever the attention module returns for its weights (for
        ``MultiscaleAttention``, a ``(main, multiscale)`` pair).
        """
        # Shape: (batch_size, embedding_dim)
        embedding = self._embedding(predicted_indices).float().view(-1, self._embedding_dim)
        embedding = self._dropout(embedding)

        # First state update from the previous token
        # Shape: (batch_size, decoder_dim)
        h = self._state_cell(embedding, h)

        # Shape: (batch_size, attention_output_dim)
        attention, attention_weights, sum_attention_weights_0, sum_attention_weights_1 = self._attention(x, h, sum_attention_weights_0, sum_attention_weights_1)

        # Second state update from the attended context
        # Shape: (batch_size, decoder_dim)
        h = self._decoder_cell(attention, h)

        # Get output predictions (one per character in vocab)
        # Shape: (batch_size, vocab_size)
        preds = self._linear(h)

        return h, preds, attention_weights, sum_attention_weights_0, sum_attention_weights_1

    @overrides
    def get_output_dim(self) -> int:
        """Vocabulary size (number of logits per step)."""
        return self._vocab_size

    @overrides
    def get_input_dim(self) -> int:
        """GRU hidden size."""
        return self._decoder_dim

In [None]:
%%writefile math_handwriting_recognition/model.py
import os
import random
from typing import Dict, Tuple
from overrides import overrides

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision

import allennlp

from allennlp.common import Registrable, Params
from allennlp.common.util import START_SYMBOL, END_SYMBOL

from allennlp.data.vocabulary import Vocabulary

from allennlp.models import Model

from allennlp.modules.token_embedders import Embedding

from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.nn.beam_search import BeamSearch

from allennlp.training.metrics import F1Measure, BLEU

from math_handwriting_recognition.metrics import Exprate
from math_handwriting_recognition.encoder import Encoder
from math_handwriting_recognition.decoder import CaptioningDecoder

# Run on the first GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

@Model.register('image-captioning')
class ImageCaptioning(Model):
    """Captioning model: an ``Encoder`` followed by a step-wise ``CaptioningDecoder`` with attention.

    When ``label`` is passed to ``forward``, the model decodes step by step (``_decode``). It then
    computes a masked sequence cross-entropy loss plus an attention regulariser. When the module is
    not in training mode, it also runs beam search (``_beam_search``). If labels were given in that
    mode, it updates the BLEU and Exprate metrics.
    """
    def __init__(self, vocab: Vocabulary, encoder: Encoder, decoder: CaptioningDecoder, max_timesteps: int = 75, teacher_forcing: bool = True, scheduled_sampling_ratio: float = 1, beam_size: int = 10) -> None:
        """
        Args:
            vocab: vocabulary used to look up the start, end and '@@PADDING@@' token indices.
            encoder: image encoder. Its ``get_output_dim()`` sizes the initial-state projections,
                and ``get_feature_map_size()`` sizes the stored attention maps.
            decoder: single-step decoder, called as ``decoder(x, h, c, predicted_indices)``.
            max_timesteps: maximum number of beam-search decoding steps.
            teacher_forcing: if True, ``_decode`` feeds the ground-truth next token at every step.
                Otherwise it feeds the argmax of the current step's logits.
            scheduled_sampling_ratio: stored, but not read anywhere in this class.
            beam_size: beam width used by beam search.
        """
        super().__init__(vocab)

        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        self._pad_index = self.vocab.get_token_index('@@PADDING@@')

        self._max_timesteps = max_timesteps
        self._teacher_forcing = teacher_forcing
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._beam_size = beam_size

        self._encoder = encoder
        self._decoder = decoder

        # Project the mean encoder feature to the decoder's initial hidden and cell states
        self._init_h = nn.Linear(self._encoder.get_output_dim(), self._decoder.get_input_dim())
        self._init_c = nn.Linear(self._encoder.get_output_dim(), self._decoder.get_input_dim())

        self.beam_search = BeamSearch(self._end_index, self._max_timesteps, self._beam_size)

        # Start, end and padding indices are passed as exclude_indices to BLEU
        self._bleu = BLEU(exclude_indices={self._start_index, self._end_index, self._pad_index})
        self._exprate = Exprate(self._end_index, self.vocab)

        # Attention maps collected by _beam_search_step during the most recent beam search
        self._attention_weights = None

    def _init_hidden(self, encoder: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initial decoder state from the encoder features averaged over dim 1 (the position axis).

        Returns:
            ``(initial_h, initial_c)``, each (batch_size, decoder_dim).
        """
        # Shape: (batch_size, encoder_dim)
        mean_encoder = encoder.mean(dim=1)

        # Shape: (batch_size, decoder_dim)
        initial_h = self._init_h(mean_encoder)
        # Shape: (batch_size, decoder_dim)
        initial_c = self._init_c(mean_encoder)

        return initial_h, initial_c

    def _decode(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Step-by-step decoding of the labelled batch in ``state`` (training / validation).

        Reads ``metadata``, ``x``, ``h``, ``c``, ``label`` and ``mask`` from ``state``. The batch is
        reordered by descending label length, so at step ``t`` only the leading ``batch_offset``
        rows are decoded.

        Writes back:
          * the reordered ``metadata`` / ``x`` / ``label`` / ``mask``;
          * the final step's ``h`` / ``c``, which hold only the rows active at that step;
          * ``attention_weights`` of shape (batch, timesteps, feature_map_size);
          * ``logits`` of shape (batch, timesteps, vocab_size).
        Here ``timesteps`` is the longest label length minus one. Rows stay zero past each
        sequence's last decoded step.
        """
        # Get data from state
        metadata = state['metadata']
        x = state['x']
        h = state['h']
        c = state['c']
        label = state['label']
        mask = state['mask']

        # Get actual size of current batch
        local_batch_size = x.shape[0]

        # Sort data to be able to only compute relevant parts of the batch at each timestep
        # Shape: (batch_size)
        lengths = mask.sum(dim=1)
        # Shape: (batch_size) (batch_size)
        sorted_lengths, indices = lengths.sort(dim=0, descending=True)
        # Computing last timestep isn't necessary with labels since last timestep is eos token or pad token
        timesteps = sorted_lengths[0] - 1

        # Reorder everything consistently with sorted_lengths
        # Shape: (batch_size, ?)
        # Shape: (batch_size, height * width, encoder_dim)
        # Shape: (batch_size, decoder_dim)
        # Shape: (batch_size, decoder_dim)
        # Shape: (batch_size, timesteps)
        # Shape: (batch_size, timesteps)
        metadata = [metadata[i] for i in indices]
        x = x[indices]
        h = h[indices]
        c = c[indices]
        label = label[indices]
        mask = mask[indices]

        # Every sequence starts from the start token
        # Shape: (batch_size, 1)
        predicted_indices = torch.LongTensor([[self._start_index]] * local_batch_size).to(device).view(-1, 1)

        # Shape: (batch_size, timesteps, vocab_size)
        predictions = torch.zeros(local_batch_size, timesteps, self._decoder.get_output_dim(), device=device)
        # Shape: (batch_size, timesteps, feature_map_size)
        attention_weights = torch.zeros(local_batch_size, timesteps, self._encoder.get_feature_map_size(), device=device)

        for t in range(timesteps):
            # Number of sequences longer than t; since the batch is sorted these are the first rows.
            # NOTE(review): rows with length exactly t + 1 are still decoded here. Their target at
            # t + 1 is padding, which forward() masks out via mask[:, 1:], so this costs extra compute
            # but does not change the loss.
            batch_offset = sum([l > t for l in sorted_lengths.tolist()])

            # Only compute data in valid timesteps
            # Shape: (batch_offset, height * width, encoder_dim)
            # Shape: (batch_offset, decoder_dim)
            # Shape: (batch_offset, decoder_dim)
            # Shape: (batch_offset, 1)
            x_t = x[:batch_offset]
            h_t = h[:batch_offset]
            c_t = c[:batch_offset]
            predicted_indices_t = predicted_indices[:batch_offset]

            # Decode timestep; h and c shrink to batch_offset rows from here on
            # Shape: (batch_offset, decoder_dim) (batch_offset, decoder_dim) (batch_offset, vocab_size), (batch_offset, feature_map_size, 1)
            h, c, preds, attention_weight = self._decoder(x_t, h_t, c_t, predicted_indices_t)

            # Get new predicted indices to pass into model at next timestep
            # Use teacher forcing if chosen
            if self._teacher_forcing:
                # Send next timestep's label to next timestep
                # Shape: (batch_offset, 1)
                predicted_indices = label[:batch_offset, t + 1].view(-1, 1)
            else:
                # Greedy: feed back this step's most likely token
                # Shape: (batch_offset, 1)
                predicted_indices = torch.argmax(preds, dim=1).view(-1, 1)

            # Save preds (rows past batch_offset keep their zero initialisation)
            predictions[:batch_offset, t, :] = preds
            attention_weights[:batch_offset, t, :] = attention_weight.view(-1, self._encoder.get_feature_map_size())

        # Update state and add logits
        state['metadata'] = metadata
        state['x'] = x
        state['h'] = h
        state['c'] = c
        state['label'] = label
        state['mask'] = mask
        state['attention_weights'] = attention_weights
        state['logits'] = predictions

        return state

    def _beam_search_step(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """One decoding step in the step-function format expected by ``BeamSearch.search``.

        Runs the decoder on ``state['x']``, ``state['h']``, ``state['c']`` and ``last_predictions``.
        It stores the new ``h``/``c`` back into ``state`` and appends attention to
        ``self._attention_weights``. It returns log-probabilities over the vocabulary together with
        the updated state.
        """
        # Group_size is batch_size * beam_size except for first decoding timestep where it is batch_size
        # Shape: (group_size, decoder_dim) (group_size, decoder_dim) (group_size, vocab_size)
        h, c, predictions, attention_weights = self._decoder(state['x'], state['h'], state['c'], last_predictions)

        # Record one attention map per batch element per step: the whole (batch_size) group on the
        # first step, then beam slot 0 on later steps.
        # NOTE(review): beam slot 0 at an intermediate step is presumably not guaranteed to lie on the
        # finally selected hypothesis, so these maps are only an approximation (e.g. for visualisation).
        if self._attention_weights is not None:
            # Shape: (batch_size, beam_size, 1, feature_map_size)
            attention_weights = attention_weights.view(-1, self._beam_size, 1, self._encoder.get_feature_map_size())
            # Shape: (batch_size, steps_so_far + 1, feature_map_size)
            self._attention_weights = torch.cat([self._attention_weights, attention_weights[:, 0, :, :]], dim=1)
        else:
            # Shape: (batch_size, 1, feature_map_size)
            attention_weights = attention_weights.view(-1, 1, self._encoder.get_feature_map_size())
            self._attention_weights = attention_weights

        # Update state
        # Shape: (group_size, decoder_dim)
        state['h'] = h
        # Shape: (group_size, decoder_dim)
        state['c'] = c

        # Run log_softmax over logit predictions
        # Shape: (group_size, vocab_size)
        log_preds = F.log_softmax(predictions, dim=1)

        return log_preds, state

    def _beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Beam-search decode from ``state['x']`` / ``state['h']`` / ``state['c']``.

        Resets ``self._attention_weights`` before searching. Returns the top-ranked beam's token
        indices, shape (batch_size, timesteps), despite the ``Dict`` annotation.
        """
        # Get data from state
        x = state['x']
        h = state['h']
        c = state['c']

        # Get actual size of current batch
        local_batch_size = x.shape[0]

        # Beam search wants initial preds of shape: (batch_size)
        # Shape: (batch_size)
        initial_indices = torch.LongTensor([[self._start_index]] * local_batch_size).to(device).view(-1)

        # Only the tensors the step function needs; BeamSearch expands them across beams
        state = {'x': x, 'h': h, 'c': c}

        # Timesteps returned aren't necessarily max_timesteps
        # Shape: (batch_size, beam_size, timesteps), (batch_size, beam_size)

        self._attention_weights = None

        predictions, log_probabilities = self.beam_search.search(initial_indices, state, self._beam_search_step)

        # Only keep best predictions from beam search
        # Shape: (batch_size, timesteps)
        predictions = predictions[:, 0, :].view(local_batch_size, -1)

        return predictions

    @overrides
    def forward(self, metadata: object, img: torch.Tensor, label: Dict[str, torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode ``img`` and decode it.

        Args:
            metadata: per-example metadata; reordered together with the batch when labels are given.
            img: input images passed to the encoder.
            label: optional token dict whose ``'tokens'`` entry holds the target sequences.

        Returns:
            The decoding ``state`` dict. ``'out'`` holds the logits in training mode (which therefore
            requires ``label``) and beam-search token indices otherwise. With labels, the dict also
            holds ``'loss'``, ``'logits'``, ``'target'`` and ``'mask'``. These, like ``'x'`` and
            ``'metadata'``, are in descending-label-length order, not input order.
        """
        # Encode the image
        # Shape: (batch_size, height * width, encoder_dim)
        x = self._encoder(img)

        state = {'metadata': metadata, 'x': x}
        # Compute loss on train and val
        if label is not None:
            # Initialize h and c
            # Shape: (batch_size, decoder_dim)
            state['h'], state['c'] = self._init_hidden(x)

            # Convert label dict to tensor since label isn't an input to the model and get mask
            # Shape: (batch_size, timesteps)
            state['mask'] = get_text_field_mask(label).to(device)
            # Shape: (batch_size, timesteps)
            state['label'] = label['tokens']

            # Decode encoded image and get loss on train and val (reorders state by length)
            state = self._decode(state)

            # Loss shouldn't be computed on start token
            state['mask'] = state['mask'][:, 1:].contiguous()
            state['target'] = state['label'][:, 1:].contiguous()

            # Compute cross entropy loss
            state['loss'] = sequence_cross_entropy_with_logits(state['logits'], state['target'], state['mask'])
            # Doubly stochastic regularization: sum attention over timesteps (dim 1) and push each
            # position's total towards 1
            state['loss'] += ((1 - torch.sum(state['attention_weights'], dim=1)) ** 2).mean()

        # Decode encoded image with beam search on val and test
        if not self.training:
            # (Re)initialize h and c; state['x'] is in sorted order if _decode ran, so beam-search
            # outputs line up with state['target']
            state['h'], state['c'] = self._init_hidden(state['x'])

            # Run beam search
            state['out'] = self._beam_search(state)

            # Save attention weights
            state['attention_weights'] = self._attention_weights

            # Compute validation scores
            if 'label' in state:
                self._bleu(state['out'], state['target'])
                self._exprate(state['out'], state['target'])

        # Set out to logits while training
        else:
            state['out'] = state['logits']

        return state

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """BLEU and Exprate metrics; empty while the module is in training mode."""
        metrics: Dict[str, float] = {}

        # Return Bleu score if possible
        if not self.training:
            metrics.update(self._bleu.get_metric(reset))
            metrics.update(self._exprate.get_metric(reset))

        return metrics

    def _trim_predictions(self, predictions: torch.Tensor) -> torch.Tensor:
        """Cut each predicted sequence after its first end token, modifying ``predictions`` in place.

        The last timestep of every row is first overwritten with the end index, so every row
        contains at least one. A row that never emitted an end token therefore loses its final token.
        Everything after the first end token is replaced by the padding index. Returns the same tensor.
        """
        for b in range(predictions.shape[0]):
            # Integer indexing gives a view, so writes below modify `predictions` directly
            # Shape: (timesteps)
            predicted_index = predictions[b]
            # Set last predicted index to eos token in case there are no predicted eos tokens
            predicted_index[-1] = self._end_index

            # Get index of first eos token
            # Shape: (timesteps)
            mask = predicted_index == self._end_index
            # Work around for pytorch not having an easy way to get the first non-zero index
            eos_token_idx = list(mask.cpu().numpy()).index(1)

            # Set prediction at eos token's timestep to eos token
            predictions[b, eos_token_idx] = self._end_index
            # Replace all timesteps after first eos token with pad token
            predictions[b, eos_token_idx + 1:] = self._pad_index

        return predictions

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Trim ``output_dict['out']`` (beam-search indices) at the first end token, in place."""
        # Trim test preds to first eos token
        # Shape: (batch_size, timesteps)
        output_dict['out'] = self._trim_predictions(output_dict['out'])

        return output_dict

@Model.register('WAP')
class WAP(ImageCaptioning):
    """Captioning model whose decoder carries a running sum of attention weights.

    Compared to the base ImageCaptioning flow, the decoder is called with (and returns)
    ``sum_attention_weights`` of shape (batch_size, height * width) at every step, and
    the initial hidden state is derived from the mean of the encoder features.
    """

    def __init__(self, vocab: Vocabulary, encoder: Encoder, decoder: CaptioningDecoder, max_timesteps: int = 75, teacher_forcing: bool = True, scheduled_sampling_ratio: float = 1, beam_size: int = 10) -> None:
        super().__init__(vocab, encoder, decoder, max_timesteps, teacher_forcing, scheduled_sampling_ratio, beam_size)

    @overrides
    def _init_hidden(self, encoder: torch.Tensor) -> torch.Tensor:
        """Compute the initial decoder hidden state from the mean encoder feature.

        Assumes ``encoder`` is (batch_size, height * width, encoder_dim). TODO confirm
        against the encoder implementation.
        """
        # Shape: (batch_size, encoder_dim)
        mean_encoder = encoder.mean(dim=1)

        # Shape: (batch_size, decoder_dim)
        initial_h = self._init_h(mean_encoder)

        return initial_h

    @overrides
    def _decode(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Decode a labelled batch step by step (loss path for training/validation).

        The batch is sorted by label length (descending), so at timestep ``t`` only the
        first ``batch_offset`` rows, whose labels are still running, are decoded.
        ``metadata``, ``x``, ``h``, ``label`` and ``mask`` are written back to ``state``
        in that sorted order. ``state`` also gains ``logits`` (batch_size, timesteps,
        vocab_size) and ``attention_weights`` (batch_size, timesteps, height * width).

        The next input token is the gold token (teacher forcing) with probability
        ``scheduled_sampling_ratio`` when teacher forcing is enabled. Otherwise it is the
        argmax of the current prediction.
        """
        # Get data from state
        metadata = state['metadata']
        x = state['x']
        h = state['h']
        label = state['label']
        mask = state['mask']

        # Get actual size of current batch
        local_batch_size = x.shape[0]
        feature_map_size = self._encoder.get_feature_map_size()

        # Sort by label length so the still-active examples always form a prefix of the batch
        # Shape: (batch_size)
        lengths = mask.sum(dim=1)
        # Shape: (batch_size) (batch_size)
        sorted_lengths, indices = lengths.sort(dim=0, descending=True)
        # Hoisted out of the loop: one device -> host copy instead of one per timestep
        sorted_lengths_list = sorted_lengths.tolist()
        # Computing last timestep isn't necessary with labels since last timestep is eos token or pad token
        timesteps = sorted_lengths_list[0] - 1

        metadata = [metadata[i] for i in indices]
        # Shape: (batch_size, height * width, encoder_dim)
        x = x[indices]
        # Shape: (batch_size, decoder_dim)
        h = h[indices]
        # Shape: (batch_size, timesteps)
        label = label[indices]
        mask = mask[indices]

        # Shape: (batch_size, 1)
        predicted_indices = torch.full((local_batch_size, 1), self._start_index, dtype=torch.long, device=device)

        # Shape: (batch_size, timesteps, vocab_size)
        predictions = torch.zeros(local_batch_size, timesteps, self._decoder.get_output_dim(), device=device)
        # Shape: (batch_size, timesteps, height * width)
        attention_weights = torch.zeros(local_batch_size, timesteps, feature_map_size, device=device)
        # Shape: (batch_size, height * width)
        sum_attention_weights = torch.zeros(local_batch_size, feature_map_size, device=device)

        for t in range(timesteps):
            # Number of examples whose label is still running at timestep t
            batch_offset = sum(length > t for length in sorted_lengths_list)

            # Only compute data in valid timesteps
            # Shape: (batch_offset, height * width, encoder_dim)
            x_t = x[:batch_offset]
            # Shape: (batch_offset, decoder_dim)
            h_t = h[:batch_offset]
            # Shape: (batch_offset, 1)
            predicted_indices_t = predicted_indices[:batch_offset]
            # The attention accumulator must shrink together with the rest of the batch;
            # without this its row count no longer matches x_t / h_t once batch_offset drops
            # Shape: (batch_offset, height * width)
            sum_attention_weights_t = sum_attention_weights[:batch_offset]

            # Decode timestep
            h, preds, attention_weight, sum_attention_weights = self._decoder(x_t, h_t, predicted_indices_t, sum_attention_weights_t)

            # Next input token: gold label (teacher forcing / scheduled sampling) or own prediction.
            # Same rule as Multiscale._decode, so scheduled_sampling_ratio is honoured here too.
            if self._teacher_forcing and np.random.random() < self._scheduled_sampling_ratio:
                # Shape: (batch_offset, 1)
                predicted_indices = label[:batch_offset, t + 1].view(-1, 1)
            else:
                # Shape: (batch_offset, 1)
                predicted_indices = torch.argmax(preds, dim=1).view(-1, 1)

            # Save preds
            predictions[:batch_offset, t, :] = preds
            attention_weights[:batch_offset, t, :] = attention_weight.view(-1, feature_map_size)

        # Update state (in sorted order) and add logits
        state['metadata'] = metadata
        state['x'] = x
        state['h'] = h
        state['label'] = label
        state['mask'] = mask
        state['attention_weights'] = attention_weights
        state['logits'] = predictions

        return state

    def _beam_search_step(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Single decoding step used as the BeamSearch step function.

        ``group_size`` is batch_size on the first call and batch_size * beam_size after
        that. Appends the attention of beam 0 to ``self._attention_weights`` and returns
        log-probabilities (group_size, vocab_size) together with the updated state.
        """
        # Shape: (group_size, decoder_dim) (group_size, vocab_size) (?) (group_size, height * width)
        h, predictions, attention_weights, sum_attention_weights = self._decoder(state['x'], state['h'], last_predictions, state['sum_attention_weights'])
        feature_map_size = self._encoder.get_feature_map_size()

        if self._attention_weights is None:
            # First step: one row per example
            # Shape: (batch_size, 1, height * width)
            self._attention_weights = attention_weights.view(-1, 1, feature_map_size)
        else:
            # NOTE(review): beam 0 at an intermediate step is not necessarily an ancestor of the
            # final best beam (BeamSearch reorders beams), so these weights are approximate and
            # only suitable for visualization.
            attention_weights = attention_weights.view(-1, self._beam_size, 1, feature_map_size)
            self._attention_weights = torch.cat([self._attention_weights, attention_weights[:, 0, :, :]], dim=1)

        # Update state
        # Shape: (group_size, decoder_dim)
        state['h'] = h
        # Shape: (group_size, height * width)
        state['sum_attention_weights'] = sum_attention_weights

        # Shape: (group_size, vocab_size)
        log_preds = F.log_softmax(predictions, dim=1)

        return log_preds, state

    @overrides
    def _beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run beam search from the start token and return the best beam's token indices.

        Reads ``state['x']`` and ``state['h']``. Returns a (batch_size, timesteps) index
        tensor, where ``timesteps`` is whatever the search produced (not necessarily
        ``max_timesteps``). The per-step attention is collected in ``self._attention_weights``.
        """
        x = state['x']
        h = state['h']

        # Get actual size of current batch
        local_batch_size = x.shape[0]

        # BeamSearch wants start predictions of shape (batch_size,)
        initial_indices = torch.full((local_batch_size,), self._start_index, dtype=torch.long, device=device)
        # Shape: (batch_size, height * width)
        sum_attention_weights = torch.zeros(local_batch_size, self._encoder.get_feature_map_size(), device=device)

        beam_state = {'x': x, 'h': h, 'sum_attention_weights': sum_attention_weights}

        # Filled in by _beam_search_step
        self._attention_weights = None

        # Shape: (batch_size, beam_size, timesteps), (batch_size, beam_size)
        predictions, _ = self.beam_search.search(initial_indices, beam_state, self._beam_search_step)

        # Only keep best predictions from beam search
        # Shape: (batch_size, timesteps)
        return predictions[:, 0, :].view(local_batch_size, -1)

    @overrides
    def forward(self, metadata: object, img: torch.Tensor, label: Dict[str, torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode ``img``, compute the loss when ``label`` is given and beam-search when not training.

        With a label, ``state`` gains ``mask``, ``label``, ``target``, ``logits`` and ``loss``.
        Its per-example entries are sorted by label length (see ``_decode``).
        In eval mode, ``out`` is the best beam and ``attention_weights`` the beam-search
        attention. BLEU/ExpRate are updated when a label is present.
        In training mode, ``out`` is the logits.
        """
        # Encode the image
        # Shape: (batch_size, height * width, encoder_dim)
        x = self._encoder(img)

        state = {'metadata': metadata, 'x': x}
        # Compute loss on train and val
        if label is not None:
            # Shape: (batch_size, decoder_dim)
            state['h'] = self._init_hidden(x)

            # Shape: (batch_size, timesteps)
            state['mask'] = get_text_field_mask(label).to(device)
            # Shape: (batch_size, timesteps)
            state['label'] = label['tokens']

            # Decode encoded image (reorders the batch by label length)
            state = self._decode(state)

            # Loss shouldn't be computed on start token
            state['mask'] = state['mask'][:, 1:].contiguous()
            state['target'] = state['label'][:, 1:].contiguous()

            # Compute cross entropy loss; WAP uses no doubly stochastic attention regularization
            state['loss'] = sequence_cross_entropy_with_logits(state['logits'], state['target'], state['mask'])

        # Decode encoded image with beam search on val and test
        if not self.training:
            # (Re)initialize h from the (possibly re-sorted) features
            state['h'] = self._init_hidden(state['x'])

            # Run beam search
            state['out'] = self._beam_search(state)

            # Save attention weights
            state['attention_weights'] = self._attention_weights

            # Compute validation scores
            if 'label' in state:
                self._bleu(state['out'], state['target'])
                self._exprate(state['out'], state['target'])

        # Set out to logits while training
        else:
            state['out'] = state['logits']

        return state
    
@Model.register('multiscale')
class Multiscale(ImageCaptioning):
    """Captioning model over two feature maps: a main branch and a multiscale branch.

    The encoder output ``x`` is a pair ``[main_features, multiscale_features]``. The
    decoder keeps one running attention sum per branch. The multiscale branch has
    ``get_feature_map_size() * 2 * 2`` positions.
    """

    def __init__(self, vocab: Vocabulary, encoder: Encoder, decoder: CaptioningDecoder, max_timesteps: int = 75, teacher_forcing: bool = True, scheduled_sampling_ratio: float = 1, beam_size: int = 10) -> None:
        super().__init__(vocab, encoder, decoder, max_timesteps, teacher_forcing, scheduled_sampling_ratio, beam_size)

    @overrides
    def _init_hidden(self, encoder: torch.Tensor) -> torch.Tensor:
        """Compute the initial decoder hidden state from the mean of the main-branch features.

        ``encoder`` is the encoder output pair. Only ``encoder[0]`` (main features) is used.
        """
        # Shape: (batch_size, encoder_dim)
        mean_encoder = encoder[0].mean(dim=1)

        # Shape: (batch_size, decoder_dim)
        initial_h = self._init_h(mean_encoder)

        return initial_h

    @overrides
    def _decode(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Decode a labelled batch step by step (loss path for training/validation).

        The batch is sorted by label length (descending), so at timestep ``t`` only the
        first ``batch_offset`` rows are decoded. ``state`` is updated in sorted order.
        ``logits`` (batch_size, timesteps, vocab_size) is added, and so is
        ``attention_weights``, a tuple of (main, multiscale) per-step attention tensors.
        """
        # Get data from state
        metadata = state['metadata']
        x = state['x']
        h = state['h']
        label = state['label']
        mask = state['mask']

        # x is [main, multiscale]; use main features to find the current batch size
        local_batch_size = x[0].shape[0]
        feature_map_size = self._encoder.get_feature_map_size()
        # The multiscale branch has 2 * 2 times as many positions as the main branch
        multiscale_map_size = feature_map_size * 2 * 2

        # Sort by label length so the still-active examples always form a prefix of the batch
        # Shape: (batch_size)
        lengths = mask.sum(dim=1)
        # Shape: (batch_size) (batch_size)
        sorted_lengths, indices = lengths.sort(dim=0, descending=True)
        # Hoisted out of the loop: one device -> host copy instead of one per timestep
        sorted_lengths_list = sorted_lengths.tolist()
        # Computing last timestep isn't necessary with labels since last timestep is eos token or pad token
        timesteps = sorted_lengths_list[0] - 1

        metadata = [metadata[i] for i in indices]
        # Sort each branch separately
        x = [x[0][indices], x[1][indices]]
        # Shape: (batch_size, decoder_dim)
        h = h[indices]
        # Shape: (batch_size, timesteps)
        label = label[indices]
        mask = mask[indices]

        # Shape: (batch_size, 1)
        predicted_indices = torch.full((local_batch_size, 1), self._start_index, dtype=torch.long, device=device)

        # Shape: (batch_size, timesteps, vocab_size)
        predictions = torch.zeros(local_batch_size, timesteps, self._decoder.get_output_dim(), device=device)
        # Tuple of (main, multiscale) attention, one row per timestep
        attention_weights = (torch.zeros(local_batch_size, timesteps, feature_map_size, device=device),
                             torch.zeros(local_batch_size, timesteps, multiscale_map_size, device=device))
        sum_attention_weights_0 = torch.zeros(local_batch_size, feature_map_size, device=device)
        sum_attention_weights_1 = torch.zeros(local_batch_size, multiscale_map_size, device=device)

        for t in range(timesteps):
            # Number of examples whose label is still running at timestep t
            batch_offset = sum(length > t for length in sorted_lengths_list)

            # Only compute data in valid timesteps
            x_t = [x[0][:batch_offset], x[1][:batch_offset]]
            # Shape: (batch_offset, decoder_dim)
            h_t = h[:batch_offset]
            # Shape: (batch_offset, 1)
            predicted_indices_t = predicted_indices[:batch_offset]
            # Both attention accumulators must shrink together with the rest of the batch
            sum_attention_weights_0_t = sum_attention_weights_0[:batch_offset]
            sum_attention_weights_1_t = sum_attention_weights_1[:batch_offset]

            # Decode timestep
            h, preds, attention_weight, sum_attention_weights_0, sum_attention_weights_1 = self._decoder(x_t, h_t, predicted_indices_t, sum_attention_weights_0_t, sum_attention_weights_1_t)

            # Next input token: gold label (teacher forcing / scheduled sampling) or own prediction
            if self._teacher_forcing and np.random.random() < self._scheduled_sampling_ratio:
                # Shape: (batch_offset, 1)
                predicted_indices = label[:batch_offset, t + 1].view(-1, 1)
            else:
                # Shape: (batch_offset, 1)
                predicted_indices = torch.argmax(preds, dim=1).view(-1, 1)

            # Save preds
            predictions[:batch_offset, t, :] = preds
            attention_weights[0][:batch_offset, t, :] = attention_weight[0].view(-1, feature_map_size)
            attention_weights[1][:batch_offset, t, :] = attention_weight[1].view(-1, multiscale_map_size)

        # Update state (in sorted order) and add logits
        state['metadata'] = metadata
        state['x'] = x
        state['h'] = h
        state['label'] = label
        state['mask'] = mask
        state['attention_weights'] = attention_weights
        state['logits'] = predictions

        return state

    def _beam_search_step(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Single decoding step used as the BeamSearch step function.

        ``group_size`` is batch_size on the first call and batch_size * beam_size after
        that. Appends beam 0's (main, multiscale) attention to ``self._attention_weights``
        and returns log-probabilities (group_size, vocab_size) with the updated state.
        Both attention accumulators are carried forward in ``state``.
        """
        # Recombine main and multiscale features (stored separately because BeamSearch needs tensors)
        x = [state['x_0'], state['x_1']]
        h, predictions, attention_weights, sum_attention_weights_0, sum_attention_weights_1 = self._decoder(x, state['h'], last_predictions, state['sum_attention_weights_0'], state['sum_attention_weights_1'])

        feature_map_size = self._encoder.get_feature_map_size()
        multiscale_map_size = feature_map_size * 2 * 2

        if self._attention_weights is None:
            # First step: one row per example
            self._attention_weights = (attention_weights[0].view(-1, 1, feature_map_size),
                                       attention_weights[1].view(-1, 1, multiscale_map_size))
        else:
            # NOTE(review): beam 0 at an intermediate step is not necessarily an ancestor of the
            # final best beam (BeamSearch reorders beams), so these weights are approximate.
            main_step = attention_weights[0].view(-1, self._beam_size, 1, feature_map_size)[:, 0, :, :]
            multiscale_step = attention_weights[1].view(-1, self._beam_size, 1, multiscale_map_size)[:, 0, :, :]
            self._attention_weights = (torch.cat([self._attention_weights[0], main_step], dim=1),
                                       torch.cat([self._attention_weights[1], multiscale_step], dim=1))

        # Update state
        # Shape: (group_size, decoder_dim)
        state['h'] = h
        # Both accumulators must be carried forward, otherwise the multiscale
        # coverage is reset to zero at every beam-search step
        state['sum_attention_weights_0'] = sum_attention_weights_0
        state['sum_attention_weights_1'] = sum_attention_weights_1

        # Shape: (group_size, vocab_size)
        log_preds = F.log_softmax(predictions, dim=1)

        return log_preds, state

    @overrides
    def _beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run beam search from the start token and return the best beam's token indices.

        Reads ``state['x']`` (the [main, multiscale] pair) and ``state['h']``. Returns a
        (batch_size, timesteps) index tensor. Per-step attention is collected in
        ``self._attention_weights`` as a (main, multiscale) tuple.
        """
        x = state['x']
        h = state['h']

        # x is a list; use main features to get the actual size of the current batch
        local_batch_size = x[0].shape[0]
        feature_map_size = self._encoder.get_feature_map_size()

        # BeamSearch wants start predictions of shape (batch_size,)
        initial_indices = torch.full((local_batch_size,), self._start_index, dtype=torch.long, device=device)
        sum_attention_weights_0 = torch.zeros(local_batch_size, feature_map_size, device=device)
        sum_attention_weights_1 = torch.zeros(local_batch_size, feature_map_size * 2 * 2, device=device)

        # Beam search requires tensors, not lists
        beam_state = {'x_0': x[0], 'x_1': x[1], 'h': h, 'sum_attention_weights_0': sum_attention_weights_0, 'sum_attention_weights_1': sum_attention_weights_1}

        # Filled in by _beam_search_step
        self._attention_weights = None

        # Timesteps returned aren't necessarily max_timesteps
        # Shape: (batch_size, beam_size, timesteps), (batch_size, beam_size)
        predictions, _ = self.beam_search.search(initial_indices, beam_state, self._beam_search_step)

        # Only keep best predictions from beam search
        # Shape: (batch_size, timesteps)
        return predictions[:, 0, :].view(local_batch_size, -1)

    @overrides
    def forward(self, metadata: object, img: torch.Tensor, label: Dict[str, torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode ``img``, compute the loss when ``label`` is given and beam-search when not training.

        In eval mode the beam-search attention is exposed as two tensors,
        ``main_attention_weights`` and ``multiscale_attention_weights``, because the
        predictor needs tensors rather than a tuple.
        """
        # Encode the image; x is [main_features, multiscale_features]
        x = self._encoder(img)

        state = {'metadata': metadata, 'x': x}
        # Compute loss on train and val
        if label is not None:
            # Shape: (batch_size, decoder_dim)
            state['h'] = self._init_hidden(x)

            # Shape: (batch_size, timesteps)
            state['mask'] = get_text_field_mask(label).to(device)
            # Shape: (batch_size, timesteps)
            state['label'] = label['tokens']

            # Decode encoded image (reorders the batch by label length)
            state = self._decode(state)

            # Loss shouldn't be computed on start token
            state['mask'] = state['mask'][:, 1:].contiguous()
            state['target'] = state['label'][:, 1:].contiguous()

            # Compute cross entropy loss.
            # Doubly stochastic regularization isn't used with multiscale features.
            state['loss'] = sequence_cross_entropy_with_logits(state['logits'], state['target'], state['mask'])

        # Decode encoded image with beam search on val and test
        if not self.training:
            # (Re)initialize h from the (possibly re-sorted) features
            state['h'] = self._init_hidden(state['x'])

            # Run beam search
            state['out'] = self._beam_search(state)

            # Predictor needs tensors, not a tuple
            state['main_attention_weights'] = self._attention_weights[0]
            state['multiscale_attention_weights'] = self._attention_weights[1]

            # Compute validation scores
            if 'label' in state:
                self._bleu(state['out'], state['target'])
                self._exprate(state['out'], state['target'])

        # Set out to logits while training
        else:
            state['out'] = state['logits']

        return state

In [None]:
%%writefile math_handwriting_recognition/predictor.py
import os
import random
from typing import Dict, Tuple
from overrides import overrides
import json

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import skimage
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision

import allennlp

from allennlp.common import Registrable, Params
from allennlp.common.util import START_SYMBOL, END_SYMBOL, JsonDict

from allennlp.data import DatasetReader
from allennlp.data.vocabulary import Vocabulary

from allennlp.models import Model

from allennlp.predictors.predictor import Predictor

from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper # MIGHT USE FOR ABSTRACTION

from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.nn.beam_search import BeamSearch

from allennlp.training.metrics import F1Measure, BLEU

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Attention figures are saved for this many predictions per predictor instance
_MAX_VISUALIZATIONS = 10
# (width, height) the input image is resized to before overlaying attention maps;
# same values as the dataset reader's width/height in config.json
_VISUALIZATION_SIZE = (512, 128)


def _tokens_before_end(tokens: list) -> list:
    """Return the tokens before the first END_SYMBOL (all of them if there is none)."""
    return tokens[:tokens.index(END_SYMBOL)] if END_SYMBOL in tokens else tokens


def _visualization_path(image_path: str, counter: int, suffix: str = '') -> str:
    """Build the file name of an attention figure for ``image_path``."""
    # NOTE(review): assumes image_path has at least three '/'-separated parts
    # (e.g. './<dir>/<file>'); confirm against the dataset reader's paths
    return 'visualization_' + image_path.split('/')[2] + f'_{counter}{suffix}.png'


def _load_visualization_image(image_path: str) -> np.ndarray:
    """Read the input image and resize it to the visualization size."""
    return cv2.resize(plt.imread(image_path), _VISUALIZATION_SIZE)


def _save_attention_figure(img: np.ndarray, attention_weights: np.ndarray, tokens: list, grid_shape: Tuple[int, int],
                           upscale: int, save_path: str, columns: int = 8, min_rows: int = 10) -> None:
    """Overlay each timestep's attention map on ``img`` in a grid of subplots and save it.

    :param attention_weights: (timesteps, positions) array; each row is reshaped to ``grid_shape``.
    :param tokens: token shown as the title of each timestep's subplot.
    :param upscale: factor passed to ``pyramid_expand`` so the map covers ``img``.
    """
    timesteps = attention_weights.shape[0]
    # Keep the original 10 x 8 layout, but grow it instead of failing on longer sequences
    rows = max(min_rows, -(-timesteps // columns))

    fig = plt.figure(figsize=(20, 20))
    for step in range(timesteps):
        ax = fig.add_subplot(rows, columns, step + 1)
        ax.set_title(f'{tokens[step]}')
        ax.imshow(img)

        attention_map = attention_weights[step].reshape(grid_shape)
        attention_map = skimage.transform.pyramid_expand(attention_map, upscale=upscale, sigma=8)
        ax.imshow(attention_map, alpha=0.8)

    # tight_layout only has an effect once the axes exist
    fig.tight_layout()
    fig.savefig(save_path)
    # Release the figure; pyplot otherwise keeps every figure alive
    plt.close(fig)


def _format_logits(vocab: Vocabulary, logits) -> str:
    """Render the argmax token of every timestep of ``logits``."""
    logits = np.array(logits)
    return str([vocab.get_token_from_index(np.argmax(logits[i])) for i in range(logits.shape[0])])


@Predictor.register('CROHME')
class MathPredictor(Predictor):
    """Writes the predicted expression and, if present, the gold expression, one per line.

    Also saves attention figures for the first ``_MAX_VISUALIZATIONS`` predictions.
    """

    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        super().__init__(model, dataset_reader)

        # Random offset used in figure file names (presumably to reduce collisions between runs)
        self._start_idx = np.random.randint(0, 100)
        self._counter = self._start_idx

    def dump_line(self, outputs: JsonDict) -> str:
        vocab = self._model.vocab
        beam_search_preds = [vocab.get_token_from_index(i) for i in outputs['out']]
        out = ' '.join(_tokens_before_end(beam_search_preds)) + '\n'

        if 'label' in outputs:
            label_tokens = [vocab.get_token_from_index(i) for i in outputs['label']]
            # Gold labels include the start symbol; predictions do not
            if label_tokens[:1] == [START_SYMBOL]:
                label_tokens = label_tokens[1:]
            out += ' '.join(_tokens_before_end(label_tokens)) + '\n'

        if self._counter - self._start_idx < _MAX_VISUALIZATIONS:
            image_path = outputs['metadata']['path']
            _save_attention_figure(_load_visualization_image(image_path), np.array(outputs['attention_weights']),
                                   beam_search_preds, grid_shape=(4, 16), upscale=32,
                                   save_path=_visualization_path(image_path, self._counter))
            self._counter += 1

        return out


@Predictor.register('WAP')
class WAPPredictor(Predictor):
    """Writes 'Pred:', optional 'Logits:' and 'Gold:' lines for each instance.

    Also saves attention figures for the first ``_MAX_VISUALIZATIONS`` predictions.
    """

    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        super().__init__(model, dataset_reader)

        # Random offset used in figure file names (presumably to reduce collisions between runs)
        self._start_idx = np.random.randint(0, 100)
        self._counter = self._start_idx

    def dump_line(self, outputs: JsonDict) -> str:
        vocab = self._model.vocab
        beam_search_preds = [vocab.get_token_from_index(i) for i in outputs['out']]
        out = '\n\nPred: ' + ' '.join(beam_search_preds) + '\n'

        if 'logits' in outputs:
            out += 'Logits: ' + _format_logits(vocab, outputs['logits']) + '\n'

        if 'label' in outputs:
            out += 'Gold: ' + ' '.join([vocab.get_token_from_index(i) for i in outputs['label']])

        if self._counter - self._start_idx < _MAX_VISUALIZATIONS:
            image_path = outputs['metadata']['path']
            _save_attention_figure(_load_visualization_image(image_path), np.array(outputs['attention_weights']),
                                   beam_search_preds, grid_shape=(8, 32), upscale=16,
                                   save_path=_visualization_path(image_path, self._counter))
            self._counter += 1

        return out


@Predictor.register('multiscale')
class MultiscalePredictor(Predictor):
    """Like WAPPredictor, but saves separate figures for the main and multiscale attention branches.

    (Previously also named ``MathPredictor``, which silently shadowed the CROHME predictor
    at module level. The registry name 'multiscale' is unchanged.)
    """

    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        super().__init__(model, dataset_reader)

        # Random offset used in figure file names (presumably to reduce collisions between runs)
        self._start_idx = np.random.randint(0, 100)
        self._counter = self._start_idx

    def dump_line(self, outputs: JsonDict) -> str:
        vocab = self._model.vocab
        beam_search_preds = [vocab.get_token_from_index(i) for i in outputs['out']]
        out = '\n\nPred: ' + ' '.join(beam_search_preds) + '\n'

        if 'logits' in outputs:
            out += 'Logits: ' + _format_logits(vocab, outputs['logits']) + '\n'

        if 'label' in outputs:
            out += 'Gold: ' + ' '.join([vocab.get_token_from_index(i) for i in outputs['label']])

        if self._counter - self._start_idx < _MAX_VISUALIZATIONS:
            image_path = outputs['metadata']['path']
            img = _load_visualization_image(image_path)

            # Main branch: 8 x 32 attention grid
            _save_attention_figure(img, np.array(outputs['main_attention_weights']), beam_search_preds,
                                   grid_shape=(8, 32), upscale=16,
                                   save_path=_visualization_path(image_path, self._counter, '_main_branch'))
            # Multiscale branch: 16 x 64 attention grid
            _save_attention_figure(img, np.array(outputs['multiscale_attention_weights']), beam_search_preds,
                                   grid_shape=(16, 64), upscale=8,
                                   save_path=_visualization_path(image_path, self._counter, '_multiscale_branch'))
            self._counter += 1

        return out

In [None]:
%%writefile config.json
{
    "dataset_reader": {
        "type": "CROHME",
        "root_path": "./",
        "height": 128,
        "width": 512,
        "lazy": true,
        "subset": false,
        "tokenizer": {
            "type": "latex"
        }
    },
    "train_data_path": "crohme-train/train.csv",
    "validation_data_path": "crohme-train/val.csv",
    "model": {
        "type": "image-captioning",
        "encoder": {
            "type": "lstm",
            "encoder": {
                "type": 'backbone',
                "encoder_type": 'resnet18',
                "encoder_height": 4,
                "encoder_width": 16,
                "pretrained": true,
                "custom_in_conv": false
            },
            "layers": 1,
            "bidirectional": false
        },
        "decoder": {
            "type": "image-captioning",
            "attention": {
                "type": 'image-captioning',
                "encoder_dim": 512, # Must be encoder dim of chosen encoder
                "decoder_dim": 256, # Must be same as decoder's decoder_dim
                "attention_dim": 256,
                "doubly_stochastic_attention": true
            },
            "embedding_dim": 256,
            "decoder_dim": 256
        },
        "max_timesteps": 75,
        "beam_size": 10,
        "teacher_forcing": true,
        "scheduled_sampling_ratio": 1,
    },
    "iterator": {
        "type": "bucket",
        "sorting_keys":[["label", "num_tokens"]],
        "batch_size": 16
    },
    "trainer": {
        "num_epochs": 20,
        "cuda_device": 0,
        "optimizer": {
            "type": "sgd",
            "lr": 0.1,
            "momentum": 0.9
        },
        "grad_clipping": 5,
        "validation_metric": "+exprate",
        "learning_rate_scheduler": {
            "type": "reduce_on_plateau",
            "factor": 0.5,
            "patience": 5
        },
        "num_serialized_models_to_keep": 1,
        "summary_interval": 10,
        "histogram_interval": 100,
        "should_log_parameter_statistics": true,
        "should_log_learning_rate": true
    }
}

In [None]:
# # Save vocabulary in advance
# !allennlp make-vocab -s ./ --include-package math_handwriting_recognition config.json

# # Find best learning rate
# !allennlp find-lr -s ./logs --start-lr 0.001 --end-lr 10 --num-batches=100 --include-package math_handwriting_recognition config.json
# x = plt.imread('./logs/lr-losses.png')
# fig, ax = plt.subplots(figsize=(10, 10))
# ax.imshow(x, interpolation='nearest')
# !rm -rf logs/*

# # Use Allennlp's online configuration tool
# get_ipython().system_raw('allennlp configure --include-package math_handwriting_recognition &')
# !ssh -o "StrictHostKeyChecking no" -R 80:localhost:8123 serveo.net

# # Dry run configuration
# !allennlp dry-run -s ./logs --include-package math_handwriting_recognition config.json
# !rm -rf ./logs/*

# # Predict with last checkpoint
# !allennlp predict --output-file ./out.txt --weights-file ./logs/model_state_epoch_9.th --batch-size 64 --silent --cuda-device 0 --use-dataset-reader --predictor CROHME --include-package math_handwriting_recognition ./logs/model.tar.gz test.csv

In [None]:
!allennlp train config.json -s ./logs --include-package math_handwriting_recognition
# !rm -rf logs/*

In [None]:
!allennlp evaluate --cuda-device 0 --include-package math_handwriting_recognition ./logs/model.tar.gz crohme-train/train.csv
!allennlp evaluate --cuda-device 0 --include-package math_handwriting_recognition ./logs/model.tar.gz crohme-train/val.csv

In [None]:
!allennlp predict --output-file ./out.txt --batch-size 64 --cuda-device 0 --use-dataset-reader --predictor CROHME --include-package math_handwriting_recognition --silent ./logs/model.tar.gz crohme-train/train.csv

In [None]:
!head -10 out.txt

In [None]:
!allennlp predict --output-file ./out.txt --batch-size 64 --cuda-device 0 --use-dataset-reader --predictor CROHME --include-package math_handwriting_recognition --silent ./logs/model.tar.gz crohme-train/val.csv

In [None]:
!head -10 out.txt

In [None]:
from tensorboardX import SummaryWriter

# Push the saved attention visualizations into TensorBoard under ./logs.
# NOTE(review): the predictors save 'visualization_*.png' into the working directory;
# confirm the files are moved to ./logs/math-recognition/ before this cell runs.
writer = SummaryWriter('./logs')
try:
    # Sorted so the images are logged in a deterministic order
    for img_name in sorted(glob.glob('./logs/math-recognition/visualization_*.png')):
        # Keep the first three channels and move channels first: (H, W, C) -> (C, H, W)
        img = torch.from_numpy(plt.imread(img_name)[:, :, :3].transpose(2, 0, 1))

        writer.add_image(img_name, img)
finally:
    # Flush pending events to disk and release the event file
    writer.close()

In [None]:
!cat logs/metrics.json

In [None]:
!rm -rf crohme-train
!rm -rf crohme-val

In [None]:
# Send the final training metrics through the notification hook
with open('./logs/metrics.json', 'r') as metrics_file:
    metrics = metrics_file.read()

notify('Metrics: ', metrics)