In [1]:
from functools import reduce
import os
import json

In [2]:
HEIGHT2STRIDE = {
    16: [2] * 4 + [1],
    32: [2] * 5,
    64: [4] + [2] * 4,
    128: [4] * 2 + [2] * 3,
    256: [4] * 3 + [2] * 2,
    512: [4] * 4 + [2] * 1, 
}

for k, v in HEIGHT2STRIDE.items():
    assert reduce(lambda x, y: x*y, v) == k

In [3]:
vocab = open('vocab.txt').read().splitlines()
vocab_size = len(vocab)
vocab_size

61

In [4]:
SIZE = 16

In [5]:
stride_list = HEIGHT2STRIDE[SIZE]
images_path = f'images_size/size={SIZE}/'
with open("size2labels.json", 'rb') as f:
    size2labels = json.load(f)
labels = size2labels[str(SIZE)]
val_labels = {k:v for k, v in labels.items() if int(k) < 100}
train_labels = {k:v for k, v in labels.items() if int(k) >= 100}

In [6]:
len(val_labels), len(train_labels)

(100, 899)

# Dataset

In [7]:
from torch.utils.data import Dataset, DataLoader
import torchvision as tv
from PIL import Image

In [8]:
class SimpleDataset(Dataset):
    def __init__(self, images_path, labels):
        self.images_path = images_path
        self.labels = labels
        self.keys = sorted(list(labels.keys()))

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        idx = self.keys[idx]
        image_path = os.path.join(self.images_path, f'{idx}.png')
        image = Image.open(image_path)
        image = image.convert('RGB')
        
        image = tv.transforms.ToTensor()(image)
        
        label = self.labels[str(idx)]
        return image, label

In [9]:
train_dataset = SimpleDataset(images_path, train_labels)
val_dataset = SimpleDataset(images_path, val_labels)

In [10]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Model

In [11]:
import torch

from model import CNN
from ctc import GreedyCTCDecoder
from metrics import compute_f1, compute_exact

In [12]:
decoder = GreedyCTCDecoder()

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
def prepare_batch(batch):
    images, texts = batch
    images = images.to(device)

    y = [[
        vocab.index(t) for t in txt
    ] for txt in texts]
    y = torch.tensor(y, dtype=torch.long).to(device)

    return images, texts, y

In [15]:
def get_ctc_loss(logits, y):
    logits = logits.permute(1, 0, 2).log_softmax(2)
    input_lengths = torch.full(
        size=(logits.size(1),),
        fill_value=logits.size(0),
        dtype=torch.int32,
    )
    target_lengths = torch.full(
        size=(y.size(0),),
        fill_value=y.size(1),
        dtype=torch.int32,
    )
    criterion = torch.nn.CTCLoss(zero_infinity=True)
    loss = criterion(logits, y, input_lengths, target_lengths)
    return loss

In [16]:
def get_predictions(logits):
    yp = logits.argmax(-1)
    pt = []
    for row in yp:
        predictions = decoder(row, None)
        pt.append(''.join(vocab[p] for p in predictions))
    return pt

In [17]:
model = CNN(stride_list, vocab_size)
_ = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [18]:
gradient_steps = 32
train_losses = []
val_epoch = {
    "loss": [],
    "f1": [],
    "em": [],
}
for epoch in range(100):

    model.train()
    for idx, batch in enumerate(train_loader):
        images, texts, y = prepare_batch(batch)
        logits = model(images)
        loss = get_ctc_loss(logits, y)
        loss.backward()
        if idx % gradient_steps == 0 or idx == len(train_loader) - 1:
            optimizer.step()
            optimizer.zero_grad()
        
        train_losses.append(loss.item())

    model.eval()
    with torch.no_grad():
        val_losses = []
        val_f1 = []
        val_em = []
        for batch in val_loader:
            images, texts, y = prepare_batch(batch)
            logits = model(images)
            pt = get_predictions(logits)
            loss = get_ctc_loss(logits, y)

            val_losses.append(loss.item())
            for t, p in zip(texts, pt):
                val_f1.append(compute_f1(t, p))
                val_em.append(compute_exact(t, p))
    val_epoch["loss"].append(sum(val_losses) / len(val_losses))
    val_epoch["f1"].append(sum(val_f1) / len(val_f1))
    val_epoch["em"].append(sum(val_em) / len(val_em))
    print(f'epoch: {epoch}, train loss: {sum(train_losses) / len(train_losses)}, val loss: {sum(val_losses) / len(val_losses)}, val f1: {sum(val_f1) / len(val_f1)}, val em: {sum(val_em) / len(val_em)}')

torch.save(model.state_dict(), f'cnn_size={SIZE}.pth')
history = val_epoch
history.update({
    "train_losses": train_losses
})
with open(f'cnn_size={SIZE}.json', 'w') as f:
    json.dump(val_epoch, f)

epoch: 0, train loss: 7.624668448334674, val loss: 6.784406604766846, val f1: 0.0, val em: 0.0
epoch: 1, train loss: 6.463653319934849, val loss: 6.827717645168304, val f1: 0.0, val em: 0.0
epoch: 2, train loss: 5.818309278898342, val loss: 5.03540125131607, val f1: 0.0, val em: 0.0
epoch: 3, train loss: 5.41338135847393, val loss: 5.21288158416748, val f1: 0.0, val em: 0.0
epoch: 4, train loss: 5.1282942871628405, val loss: 4.893181304931641, val f1: 0.0, val em: 0.0
epoch: 5, train loss: 4.912418022061172, val loss: 4.50979599237442, val f1: 0.0, val em: 0.0
epoch: 6, train loss: 4.736394246962308, val loss: 4.206970593929291, val f1: 0.0, val em: 0.0
epoch: 7, train loss: 4.584238116879087, val loss: 4.381089975833893, val f1: 0.003333333333333333, val em: 0.0
epoch: 8, train loss: 4.442851871307311, val loss: 4.0338801407814024, val f1: 0.0022222222222222222, val em: 0.0
epoch: 9, train loss: 4.301988635551147, val loss: 3.904453282356262, val f1: 0.012523809523809524, val em: 0.0
