<a href="https://colab.research.google.com/github/n1teshy/colab_notebooks/blob/main/OCR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Extract the training and evaluation archives into the working directory.
# NOTE(review): assumes train.zip / eval.zip unpack to top-level `train/` and `eval/` directories — confirm.
!unzip train.zip -d .
!unzip eval.zip -d .
# Gather both splits under samples/ (-> samples/train, samples/eval) and delete the archives.
# NOTE(review): `mkdir` fails if samples/ already exists (e.g. when this cell is re-run).
!mkdir samples
!mv train eval samples
!rm train.zip eval.zip

In [None]:
# Install PyTorch.
# NOTE(review): the notebook also imports `transformers` and `PIL`, which this cell does not install;
# presumably they are preinstalled in the runtime — confirm.
!pip install torch

In [2]:
import os
import torch
import random

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, TrOCRProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig

In [4]:
def get_labels(path: str) -> dict:
    """Parse a label file into a ``{filename: label}`` mapping.

    Each line has the form ``filename:label``. Only the first colon separates
    the two parts, so a label may itself contain colons (e.g. ``"10:30 pm"``).
    Lines without any colon (such as blank lines) are skipped. If a file name
    appears more than once, the last occurrence wins.

    Args:
        path: Path of the UTF-8 encoded label file.

    Returns:
        Dict mapping each file name to its label text (not whitespace-stripped).
    """
    # `with` guarantees the handle is closed even if reading fails.
    with open(path, encoding="utf-8") as fh:
        lines = fh.read().splitlines()

    mappings = {}
    for line in lines:
        # partition splits on the FIRST colon only; `sep` is empty when there is none.
        name, sep, label = line.partition(":")
        if sep:
            mappings[name] = label
    return mappings


class ImageDataset(Dataset):
    """Image/label dataset for fine-tuning the OCR model.

    Every regular file directly inside ``directory`` is a candidate sample
    image. Labels come either from ``<directory>/meta/labels.txt`` (parsed by
    ``get_labels``) or, when ``use_label_file`` is False, from each file's name
    without its extension. Files that have no label are skipped with a warning.

    Relies on the notebook-level globals ``tokenizer``, ``processor`` and
    ``DEVICE``.

    Each item is ``(pixel_values, label_ids, length)``:
      * ``pixel_values``: the processor's output for the image, batch dim removed;
      * ``label_ids``: 1-D token-id tensor, right-padded to the longest label in
        this dataset with ``IGNORE_INDEX`` so padding is excluded from the loss;
      * ``length``: number of real (non-padding) label tokens.
    """

    # Label positions holding this value are ignored by the Hugging Face seq2seq
    # cross-entropy loss (and replaced by the pad id when building decoder inputs).
    IGNORE_INDEX = -100

    def __init__(self, directory: str, use_label_file: bool = True) -> None:
        self.directory = directory
        # Only the top level of `directory`; sub-directories such as meta/ are not walked.
        _, _, files = next(os.walk(directory))
        if use_label_file:
            labels = get_labels(os.path.join(self.directory, "meta", "labels.txt"))
        else:
            labels = {file: os.path.splitext(file)[0] for file in files}

        # Keep only labelled files: an unlabelled one would raise KeyError in
        # __getitem__. Sorting gives a deterministic order (os.walk's is not).
        self.files = sorted(file for file in files if file in labels)
        if len(self.files) != len(files) or len(labels) != len(files):
            print(
                f"number of files ({len(files)}) does not equal number of labels "
                f"({len(labels)}); {len(files) - len(self.files)} unlabeled file(s) skipped"
            )
        if not self.files:
            raise ValueError(f"no labelled image files found in {directory!r}")

        # Tokenize each label once, then pad on the right to the longest label.
        token_ids = {file: tokenizer(labels[file])["input_ids"] for file in self.files}
        max_label_tokens = max(len(ids) for ids in token_ids.values())
        self.lengths = {file: len(ids) for file, ids in token_ids.items()}
        self.tokenized_labels = {
            file: ids + [self.IGNORE_INDEX] * (max_label_tokens - len(ids))
            for file, ids in token_ids.items()
        }

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]
        # Context manager closes the underlying file; convert() returns a loaded copy.
        with Image.open(os.path.join(self.directory, file)) as img:
            image = img.convert("RGB")
        pixel_values = processor(images=image, return_tensors="pt").pixel_values.squeeze(0)
        return (
            pixel_values.to(DEVICE),
            # 1-D (seq_len,) so the default collate produces (batch, seq_len) labels,
            # the shape the model's `labels` argument expects.
            torch.tensor(self.tokenized_labels[file], device=DEVICE),
            torch.tensor(self.lengths[file], device=DEVICE),
        )

In [5]:
# Hugging Face checkpoint that is fine-tuned below.
MODEL_NAME = "microsoft/trocr-large-handwritten"
# NOTE(review): not referenced by the visible cells, which hard-code "samples/train".
IMAGE_DIR = "samples/train"
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Tokenizer (text <-> token ids) and processor (image -> pixel_values) for the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)

# Start from the checkpoint's own config, then take the decoder start and padding
# token ids from the tokenizer (CLS and PAD respectively).
# NOTE(review): this overrides whatever decoder_start_token_id the checkpoint ships
# with — confirm it matches how the checkpoint was trained.
model_config = VisionEncoderDecoderConfig.from_pretrained(MODEL_NAME)
model_config.decoder_start_token_id, model_config.pad_token_id = (
    tokenizer.cls_token_id,
    tokenizer.pad_token_id,
)

# Load the weights with the adjusted config and move the model to DEVICE.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME, config=model_config).to(DEVICE)

In [8]:
# Training hyper-parameters.
EPOCHS = 10
BATCH_SIZE = 8
# AdamW fine-tuning of a large pretrained transformer normally uses a learning
# rate around 1e-5..1e-4; the previous 3e-3 is likely to destabilise training and
# overwrite the pretrained weights. 5e-5 is a common choice for TrOCR fine-tuning.
LEARNING_RATE = 5e-5
# Run a validation pass every this many training batches.
PROGRESS_CHECK_INTERVAL = 10
# Original (misspelled) name, kept as an alias because the training loop uses it.
PROGRESS_CHECK_ITERVAL = PROGRESS_CHECK_INTERVAL
# Per-batch training losses; the training loop appends to this list.
losses = []
# Initial value of the running average reported by the training loop.
average_loss = 0

In [9]:
# One dataset per split; each prepares its labels when constructed.
train_dataset = ImageDataset("samples/train")
eval_dataset = ImageDataset("samples/eval")

# Only the training split is shuffled; evaluation iterates in dataset order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=BATCH_SIZE)

# AdamW over every parameter of the model.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [10]:
def get_val_loss(loader=None, net=None):
    """Return the mean validation loss over a whole data loader.

    Args:
        loader: Iterable of ``(images, labels, lengths)`` batches. Defaults to
            the notebook-level ``eval_dataloader``.
        net: Model to evaluate. Defaults to the notebook-level ``model``. It is
            called as ``net(pixel_values=..., labels=...)`` and must return an
            object with a scalar ``.loss`` tensor.

    Returns:
        The batch-size-weighted mean of the per-batch losses as a float, or
        ``nan`` when ``loader`` yields no batches.
    """
    loader = eval_dataloader if loader is None else loader
    net = model if net is None else net

    # eval() disables dropout for the measurement; the previous mode is restored
    # afterwards so an ongoing training run is unaffected.
    was_training = net.training
    net.eval()
    total, count = 0.0, 0
    try:
        with torch.no_grad():
            for images, labels, _lengths in loader:
                # Hugging Face models return a ModelOutput, not an (out, loss) tuple.
                loss = net(pixel_values=images, labels=labels).loss
                batch_size = images.size(0)
                total += loss.item() * batch_size
                count += batch_size
    finally:
        net.train(was_training)
    return total / count if count else float("nan")

In [None]:
# Fine-tune the model: log every batch's loss and a running average, and run a
# validation pass every PROGRESS_CHECK_ITERVAL batches.
for epoch in range(EPOCHS):
    # Make sure dropout etc. are active, in case something left the model in eval mode.
    model.train()
    batch_iters = 0
    for images, labels, _lengths in train_dataloader:
        optimizer.zero_grad()
        # Hugging Face models return a ModelOutput, not an (out, loss) tuple.
        loss = model(pixel_values=images, labels=labels).loss
        cur_loss = loss.item()
        losses.append(cur_loss)
        # Incremental mean over every loss recorded in `losses` (assumes
        # average_loss started at 0 while `losses` was empty).
        average_loss += (cur_loss - average_loss) / len(losses)
        avg_loss = average_loss
        print(f"current: {cur_loss}, average: {avg_loss}")
        loss.backward()
        optimizer.step()
        batch_iters += 1
        if batch_iters % PROGRESS_CHECK_ITERVAL == 0:
            print(f"eval loss: {get_val_loss()}")