In [None]:
# Mount Google Drive so checkpoints saved under drive/MyDrive/ocr_params outlive the Colab VM.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

In [None]:
!mkdir /content/drive/MyDrive/ocr_params -p
!cp /content/drive/MyDrive/datasets/ocr_data.zip .
# this lists checkpoints with the full path, easy to copy the path and paste
!find /content/drive/MyDrive/ocr_params

In [None]:
# Get utilities: the model/dataset package (`core/`), the trained tokenizers
# (`tokenizers/`) and the dataset images unpacked from ocr_data.zip.
import os
import platform


def _run_shell(command):
    """Run ``command`` through the system shell.

    Raises:
        RuntimeError: if the command exits with a non-zero status, so a failed
            clone / move / unzip stops the notebook instead of being ignored.
    """
    status = os.system(command)
    if status != 0:
        raise RuntimeError(f"shell command failed (status {status}): {command}")


if not os.path.exists("ocr_data.zip"):
    print("upload images")
elif platform.system() == "Linux":
    # Skip clones whose result is already present so the cell can be re-run.
    if not os.path.isdir("core"):
        _run_shell(
            "git clone https://github.com/n1teshy/sequence-transduction"
            " && mv sequence-transduction/core ."
            " && rm -rf sequence-transduction"
        )
    if not os.path.isdir("tokenizers"):
        _run_shell(
            "git clone https://github.com/n1teshy/cache"
            " && mv cache/ocr/tokenizers ."
            " && rm -rf cache"
        )
    # -o: overwrite without prompting if the archive was already extracted once.
    _run_shell("unzip -o ocr_data.zip -d . > /dev/null && rm ocr_data.zip")
else:
    # cmd.exe: chain with `&&` (not `&`) so later steps run only if earlier ones
    # succeeded, and use native backslash paths (cmd built-ins may parse "/" as a switch).
    if not os.path.isdir("core"):
        _run_shell(
            "git clone https://github.com/n1teshy/sequence-transduction"
            r" && move sequence-transduction\core ."
            " && rd /s /q sequence-transduction"
        )
    if not os.path.isdir("tokenizers"):
        _run_shell(
            "git clone https://github.com/n1teshy/cache"
            r" && move cache\ocr\tokenizers ."
            " && rd /s /q cache"
        )
    # -Force: overwrite files left over from a previous extraction.
    _run_shell(
        "powershell Expand-Archive -Path ocr_data.zip -DestinationPath . -Force > NUL"
        " && del ocr_data.zip"
    )

In [None]:
import torch
import torch.nn.functional as F
from core.models import OCR, resnet18
from core.datasets.image import OCRDataset
from core.tokenizers.regex import get_tokenizer
from torch.utils.data import DataLoader
from core.utils import get_param_count, kaiming_init
from core.config import device

In [2]:
# Data: tokenizer plus train/test datasets and their (shuffled) loaders.
TRAIN_FOLDER = "data/train"
BATCH_SIZE = 8
TEST_FOLDER = "data/test"
# NOTE(review): 384 is get_tokenizer's second argument — presumably the vocab size; confirm.
tokenizer = get_tokenizer("_.txt", 384, "tokenizers/en")

# Both splits share the label-mapping file name and the tokenizer; train is built first.
train_dataset, test_dataset = (
    OCRDataset(folder, mapping_file="meta/labels.txt", tokenizer=tokenizer)
    for folder in (TRAIN_FOLDER, TEST_FOLDER)
)

# The test loader also shuffles, so taking its first batch yields a random sample.
train_dataloader, test_dataloader = (
    DataLoader(split, collate_fn=split.collate, batch_size=BATCH_SIZE, shuffle=True)
    for split in (train_dataset, test_dataset)
)

def get_test_loss(model, dataloader=None):
    """Return the cross-entropy loss of ``model`` on a single test batch.

    Only the first batch yielded by the loader is scored. Because
    ``test_dataloader`` shuffles, each call evaluates a different random batch.
    That keeps the call cheap enough to run after every training step; the
    caller smooths the noisy value with a moving average.

    The forward pass runs in eval mode, so dropout / BatchNorm layers (if any)
    behave as at inference. In particular, BatchNorm running statistics are not
    updated with test data. The model's previous train/eval mode is restored
    afterwards.

    Args:
        model: called as ``model(pixels, input_tokens)``; must return logits of
            shape (B, T, C).
        dataloader: iterable of ``(pixels, tokens)`` batches. Defaults to the
            module-level ``test_dataloader``.

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: if the loader yields no batches.
    """
    loader = test_dataloader if dataloader is None else dataloader
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for pixels, tokens in loader:
                # Teacher forcing: feed tokens[:-1], score predictions against tokens[1:].
                logits = model(pixels, tokens[:, :-1])
                B, T, C = logits.shape
                # NOTE(review): padding targets are not excluded (no ignore_index=PADDING_ID);
                # kept identical to the training loss so the two stay comparable.
                return F.cross_entropy(logits.reshape(B * T, C), tokens[:, 1:].reshape(-1))
    finally:
        model.train(was_training)
    raise ValueError("test dataloader yielded no batches")

In [3]:
# Model / training hyperparameters.
EPOCHS = 100
LEARNING_RATE = 0.001
# Passed to OCR.spawn as embedding_size and to resnet18 as num_classes, so the
# encoder's output width matches the decoder's embedding width.
EMBEDDING_SIZE = 256
VOCAB_SIZE = tokenizer.size
VOCAB_SZE = VOCAB_SIZE  # legacy misspelled name, kept so existing references keep working
MAX_LEN = 100  # passed to OCR.spawn as max_len — presumably the longest token sequence; confirm
DEC_LAYERS = 5
DEC_HEADS = 4
PADDING_ID = train_dataset.pad_id
ENCODER = resnet18(num_classes=EMBEDDING_SIZE)
MIN_PROGRESS = 0.1  # checkpoint only after the smoothed train loss drops by at least this much

In [None]:
# Optional checkpoint (path to a saved state_dict) to resume from; None trains from scratch.
# Only load files you trust: torch.load unpickles its input.
RESUME_FROM = None

model = OCR.spawn(
    encoder=ENCODER.to(device),
    out_vocab_size=VOCAB_SZE,
    embedding_size=EMBEDDING_SIZE,
    max_len=MAX_LEN,
    dec_layers=DEC_LAYERS,
    dec_heads=DEC_HEADS,
    tgt_pad_id=PADDING_ID,
)
# Initialise first, then load. kaiming_init is assumed to re-initialise the weights
# in place (its return value is unused), so loading before it would discard the
# checkpoint.
kaiming_init(model)
if RESUME_FROM is not None:
    model.load_state_dict(torch.load(RESUME_FROM, map_location=device))
print(f"{get_param_count(model)/1e6} mn params")

In [5]:
# AdamW over every model parameter at the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [6]:
# Running state for the training loop: smoothed train/test losses and the smoothed
# train loss at the most recent checkpoint. All start unset (None).
mean_train_loss = mean_test_loss = last_saved_at = None

In [None]:
def checkpoint_filename(train_loss, test_loss, embedding_size, learning_rate):
    """Build a checkpoint file name encoding the losses and hyperparameters.

    The result carries a single ``.pth`` extension, e.g.
    ``ocr_1.2346_2.5000_class_256_lr_0.0010.pth``. The ``class_`` field holds
    ``embedding_size``, which this notebook also passes to resnet18 as num_classes.
    """
    return "ocr_%.4f_%.4f_class_%d_lr_%.4f.pth" % (
        train_loss,
        test_loss,
        embedding_size,
        learning_rate,
    )


def save_model(folder):
    """Save ``model.state_dict()`` into ``folder`` and return the written path.

    The file name comes from the current smoothed losses (``mean_train_loss``,
    ``mean_test_loss``) and the ``EMBEDDING_SIZE`` / ``LEARNING_RATE`` settings;
    both losses must already be set. ``folder`` is created if it does not exist.
    """
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(
        folder,
        checkpoint_filename(mean_train_loss, mean_test_loss, EMBEDDING_SIZE, LEARNING_RATE),
    )
    torch.save(model.state_dict(), path)
    return path

In [None]:
def _smooth(previous, value):
    """Exponential moving average with decay 0.99; the first value seeds it."""
    return value if previous is None else previous * 0.99 + value * 0.01


# Training: one optimizer step per batch, a one-batch test-loss estimate after each
# step, and a checkpoint whenever the smoothed train loss has fallen by at least
# MIN_PROGRESS since the last checkpoint.
model.train()
for epoch in range(1, EPOCHS + 1):
    for batch, (pixels, tokens) in enumerate(train_dataloader, start=1):
        # Teacher forcing: feed tokens[:-1], score predictions against tokens[1:].
        logits = model(pixels, tokens[:, :-1])
        B, T, C = logits.shape
        targets = tokens[:, 1:].reshape(-1)
        # NOTE(review): padding targets are not excluded (no ignore_index=PADDING_ID); confirm intended.
        train_loss = F.cross_entropy(logits.reshape((B * T, C)), targets)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        train_value = train_loss.item()
        mean_train_loss = _smooth(mean_train_loss, train_value)
        test_value = get_test_loss(model).item()
        mean_test_loss = _smooth(mean_test_loss, test_value)
        print(
            "%d:%d -> %.4f(mean:%.4f), %.4f(mean:%.4f)"
            % (epoch, batch, train_value, mean_train_loss, test_value, mean_test_loss)
        )

        if last_saved_at is None:
            # Fresh run: record a baseline (subtracting from None would raise TypeError).
            last_saved_at = mean_train_loss
        elif last_saved_at - mean_train_loss >= MIN_PROGRESS:
            save_model("drive/MyDrive/ocr_params")
            print(f"saved model at train loss {mean_train_loss}")
            last_saved_at = mean_train_loss

In [None]:
def predict(image, bos_id, eos_id, max_len=None):
    """Autoregressively sample a token sequence for one image.

    Decoding starts from ``[bos_id]``. At each step the next token is drawn with
    ``torch.multinomial`` from the softmax over the last position's logits, so
    repeated calls can return different sequences. Decoding stops when
    ``eos_id`` is sampled or the context reaches ``max_len`` tokens.

    Args:
        image: input passed straight to ``model``; the context is built with a
            batch size of 1, so presumably a single image with a leading batch
            dimension of 1 — confirm.
        bos_id: token that seeds the context.
        eos_id: token that ends decoding (it is not included in the result).
        max_len: maximum context length, including ``bos_id``. Defaults to
            ``MAX_LEN``, the value passed to OCR.spawn as ``max_len``.

    Returns:
        List of token ids, starting with ``bos_id``.
    """
    limit = MAX_LEN if max_len is None else max_len
    was_training = model.training
    model.eval()  # inference behaviour for dropout / BatchNorm, restored below
    context = torch.tensor([[bos_id]], device=device)
    try:
        with torch.no_grad():  # no autograd graph is needed while decoding
            while context.size(1) < limit:
                logits = model(image, context)  # (1, T, vocab)
                probs = F.softmax(logits[:, -1, :], dim=-1)  # only the last position matters
                next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
                if next_token.item() == eos_id:
                    break
                context = torch.cat((context, next_token), dim=1)
    finally:
        model.train(was_training)
    return context[0].tolist()

In [None]:
# Spot-check: for the first image of every test batch, print the ground-truth text
# followed by a sampled prediction.
for batch_images, batch_targets in test_dataloader:
    first_image = batch_images[0].unsqueeze(0)  # restore a batch dimension of 1
    predicted_ids = predict(first_image, test_dataset.bos_id, test_dataset.eos_id)
    expected_text = tokenizer.decode(batch_targets[0].tolist())
    print(expected_text, tokenizer.decode(predicted_ids))