In [None]:
# Get tokenizer files
!git clone https://github.com/n1teshy/cache && mv cache/tokenizers . && rm -rf cache
# Extract data
!unzip ocr_data.zip -d . > /dev/null && rm ocr_data.zip

In [None]:
import torch
import torch.nn.functional as F
from core.models import OCR, resnet18
from core.datasets.image import OCRDataset
from core.tokenizers.regex import get_tokenizer
from torch.utils.data import DataLoader
from core.utils import get_param_count
from core.config import device

In [2]:
# Image folders for the two splits and the batch size shared by both loaders.
BATCH_SIZE = 32
TRAIN_FOLDER = "data/textract_images/train"
TEST_FOLDER = "data/textract_images/test"

# Tokenizer loaded through the project helper; the same instance is handed to
# both datasets.
# NOTE(review): the meaning of the "_.txt" and 384 arguments is defined by
# core.tokenizers.regex.get_tokenizer -- confirm there.
tokenizer = get_tokenizer("_.txt", 384, "tokenizers/en")

# Both splits use the same relative mapping file name and the same tokenizer.
train_dataset = OCRDataset(
    TRAIN_FOLDER,
    tokenizer=tokenizer,
    mapping_file="meta/mapping.txt",
)
test_dataset = OCRDataset(
    TEST_FOLDER,
    tokenizer=tokenizer,
    mapping_file="meta/mapping.txt",
)

# Each loader batches with its own dataset's collate method. Both are shuffled,
# so the batch that get_test_loss sees changes from call to call.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=train_dataset.collate,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=test_dataset.collate,
)

def get_test_loss(model, max_batches=1):
    """Estimate the cross-entropy loss of ``model`` on ``test_dataloader``.

    The model is switched to evaluation mode for the duration of the call so
    that train-only behaviour (dropout, batch-norm statistic updates, if the
    model has such layers) does not distort the estimate or leak test data into
    the model. The previous train/eval mode is restored before returning.

    Args:
        model: A ``torch.nn.Module`` called as ``model(pixels, tokens[:, :-1])``
            and returning logits of shape ``(B, T, C)``.
        max_batches: Number of test batches to average over. The default of 1
            keeps the cost of a call at a single batch; pass ``None`` to use
            every batch the loader yields.

    Returns:
        A scalar tensor: the mean of the per-batch cross-entropy losses.

    Raises:
        ValueError: If ``max_batches`` is smaller than 1, or if
            ``test_dataloader`` yields no batches.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError("max_batches must be >= 1 or None")

    was_training = model.training
    model.eval()
    loss_sum, batches_seen = None, 0
    try:
        with torch.no_grad():
            for pixels, tokens in test_dataloader:
                # Inputs are tokens[:, :-1]; targets are the same sequence
                # shifted by one position, tokens[:, 1:].
                logits = model(pixels, tokens[:, :-1])
                B, T, C = logits.shape
                batch_loss = F.cross_entropy(
                    logits.reshape((B * T, C)), tokens[:, 1:].reshape(-1)
                )
                loss_sum = batch_loss if loss_sum is None else loss_sum + batch_loss
                batches_seen += 1
                if max_batches is not None and batches_seen >= max_batches:
                    break
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    if batches_seen == 0:
        raise ValueError("test_dataloader yielded no batches")
    return loss_sum / batches_seen

In [3]:
# --- Optimisation settings -------------------------------------------------
EPOCHS = 100            # number of passes the training loop makes over train_dataloader
LEARNING_RATE = 0.001   # learning rate handed to the AdamW optimizer

# --- Values taken from the data pipeline -----------------------------------
VOCAB_SZE = tokenizer.size           # passed to the model as out_vocab_size
PADDING_ID = train_dataset.pad_id    # passed to the model as tgt_pad_id

# --- Model dimensions (forwarded to OCR.spawn) -----------------------------
EMBEDDING_SIZE = 256
MAX_LEN = 100
DEC_LAYERS = 5
DEC_HEADS = 4

# The encoder's num_classes is set to the embedding size used by the model.
ENCODER = resnet18(num_classes=EMBEDDING_SIZE)

In [None]:
# Build the OCR model from the encoder (moved to `device`) and the decoder
# settings defined in the hyperparameter cell.
model = OCR.spawn(
    encoder=ENCODER.to(device),
    embedding_size=EMBEDDING_SIZE,
    out_vocab_size=VOCAB_SZE,
    tgt_pad_id=PADDING_ID,
    max_len=MAX_LEN,
    dec_layers=DEC_LAYERS,
    dec_heads=DEC_HEADS,
)
# Report the parameter count in millions.
print(f"{get_param_count(model)/1e6} mn params")

In [5]:
# AdamW over every parameter of the model, using the configured learning rate.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
)

In [6]:
# Running means of the train and test losses. They start as None so that the
# training loop can seed each one with the first loss it observes.
mean_train_loss = mean_test_loss = None

In [None]:
# Training loop: one optimizer step per training batch, followed by a test-loss
# estimate. Both losses are also tracked as exponential running means
# (0.99 * previous + 0.01 * current) to smooth the per-batch noise.
#
# NOTE(review): the loss is computed over every target position; if batches
# are padded with PADDING_ID, consider passing ignore_index=PADDING_ID to
# F.cross_entropy here and in get_test_loss -- confirm the intended objective.

# Make sure train-only behaviour is active even if an earlier cell left the
# model in evaluation mode.
model.train()

for epoch in range(1, EPOCHS + 1):
    for batch, (pixels, tokens) in enumerate(train_dataloader, start=1):
        # The model receives tokens[:, :-1] and is scored against the same
        # sequence shifted by one position, tokens[:, 1:].
        logits = model(pixels, tokens[:, :-1])
        B, T, C = logits.shape
        logits, targets = logits.reshape((B * T, C)), tokens[:, 1:].reshape(-1)
        train_loss = F.cross_entropy(logits, targets)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        test_loss = get_test_loss(model)

        # Read each scalar once; every .item() call copies the value to the host.
        train_loss_value = train_loss.item()
        test_loss_value = test_loss.item()

        # Seed each running mean with the first observation. Testing against
        # None (rather than truthiness) keeps a genuine 0.0 mean from being
        # mistaken for "not initialised yet".
        mean_train_loss = (
            train_loss_value
            if mean_train_loss is None
            else mean_train_loss * 0.99 + train_loss_value * 0.01
        )
        mean_test_loss = (
            test_loss_value
            if mean_test_loss is None
            else mean_test_loss * 0.99 + test_loss_value * 0.01
        )

        # epoch:batch -> train loss (running mean), test loss (running mean)
        print(
            "%d:%d -> %.4f(mean:%.4f), %.4f(mean:%.4f)"
            % (
                epoch,
                batch,
                train_loss_value,
                mean_train_loss,
                test_loss_value,
                mean_test_loss,
            )
        )
