In [None]:
%reload_ext autoreload
# Loading the extension alone only registers the magics; automatic reloading
# stays off until a mode is chosen. Mode 2 re-imports all modules (e.g. edits
# to text_recognizer) before each cell executes.
%autoreload 2

import os
import sys
import random

import torch
import numpy as np
import matplotlib.pyplot as plt

# Make `text_recognizer` importable: put the parent of the current working
# directory on sys.path, unless it is already there.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:

import text_recognizer.data
import text_recognizer.models
import text_recognizer.lit_models

from text_recognizer.metadata.iam_paragraphs import IMAGE_HEIGHT, IMAGE_WIDTH


### IAM Paragraphs dataset

In [None]:
# Build the IAM paragraphs data module, run its prepare_data/setup steps,
# then pull one validation batch (x: images, y: target token indices).
iam_paragraphs = text_recognizer.data.IAMParagraphs()
iam_paragraphs.prepare_data()
iam_paragraphs.setup()

val_loader = iam_paragraphs.val_dataloader()
x, y = next(iter(val_loader))

iam_paragraphs  # display the data module's repr

In [None]:
def show(y, mapping=None):
    """Decode a tensor of token indices into a string.

    Args:
        y: integer tensor of token indices; may live on an accelerator.
        mapping: sequence of token strings indexed by class id. Defaults to
            ``iam_paragraphs.mapping`` from the dataset cell.

    Returns:
        The tokens concatenated into one string with every ``"<P>"`` token
        (presumably padding) removed.
    """
    if mapping is None:
        mapping = iam_paragraphs.mapping
    indices = y.detach().cpu().numpy()  # bring back from accelerator if it's being used
    return "".join(np.array(mapping)[indices]).replace("<P>", "")

# Pick a random example from the validation batch. randrange excludes its upper
# bound; randint(0, len(x)) could return len(x) and raise an IndexError.
idx = random.randrange(len(x))

print(show(y[idx]))
# .cpu() keeps this cell working if re-run after x has been moved to the GPU
plt.imshow(x[idx].view(IMAGE_HEIGHT, IMAGE_WIDTH).cpu(), cmap='Greys_r')
plt.axis("off");

### ResNet Transformer

In [None]:
# Instantiate the (still untrained) ResnetTransformer, configured from the data module's config()
rnt = text_recognizer.models.ResnetTransformer(
    data_config=iam_paragraphs.config(),
)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the sampled batch and the model onto that device; the model is not
# reassigned, and the trailing `;` hides the repr returned by `.to`.
x, y = x.to(device), y.to(device)
rnt.to(device);

In [None]:
# Pass a single input through the encoder.
# ResNet is designed for RGB images, so replicate the input across 3 channels.
with torch.no_grad():  # visualization only: no autograd graph needed
    resnet_embedding, = rnt.resnet(x[idx:idx+1].repeat(1, 3, 1, 1))

# randrange excludes its upper bound, so every drawn channel index is valid
resnet_idx = random.randrange(len(resnet_embedding))  # re-execute to view a different channel
plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap="Greys_r");
plt.axis("off"); plt.colorbar(fraction=0.05);

In [None]:
# Inference only: skip building an autograd graph (show() detaches the result anyway).
# Unpack the single element of the size-1 batch.
with torch.no_grad():
    preds, = rnt(x[idx:idx+1])

In [None]:
# Prediction from the untrained model
print(show(preds.cpu()))
# x was moved to `device` earlier; matplotlib converts via NumPy, which needs a CPU tensor
plt.imshow(x[idx].view(IMAGE_HEIGHT, IMAGE_WIDTH).cpu(), cmap='Greys_r')
plt.axis("off");

### TransformerLitModel

In [None]:
import text_recognizer.lit_models

# Wrap the model in TransformerLitModel, which provides the teacher_forward used below
lit_rnt = text_recognizer.lit_models.TransformerLitModel(
    rnt,
)

In [None]:
# Teacher-forced pass on the sampled example: the ground-truth targets y are
# passed in alongside the image; unpack the lone element of the size-1 batch.
(forcing_outs,) = lit_rnt.teacher_forward(
    x[idx:idx + 1],
    y[idx:idx + 1],
)

In [None]:
# Turn the teacher-forced outputs into token indices via argmax over dim 0.
# NOTE(review): assumes forcing_outs is (num_classes, seq_len) after unpacking the batch — confirm.
forcing_preds = torch.argmax(forcing_outs, dim=0)

print(show(forcing_preds.cpu()))
# x lives on `device`; matplotlib converts via NumPy, which needs a CPU tensor
plt.imshow(x[idx].view(IMAGE_HEIGHT, IMAGE_WIDTH).cpu(), cmap='Greys_r')
plt.axis("off");

### Run Experiment - ResNet Transformer - IAM Paragraphs

In [None]:
# Run ../training/run_experiment.py inside this kernel (%run) with IAMParagraphs data,
# the ResnetTransformer model and the transformer loss.
# NOTE(review): the fast_dev_run / limit_*_batches / precision / accelerator flags are
# presumably forwarded to the PyTorch Lightning Trainer to keep this a ~one-batch
# smoke test — confirm in run_experiment.py. 'bf16' may also need hardware support.
%run ../training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \
  --fast_dev_run True --log_every_n_steps 1 --limit_test_batches 0 --accelerator 'auto' \
  --max_epochs 1 --batch_size 16 --precision 'bf16' \
  --limit_train_batches 1 --limit_val_batches 1 