In [None]:
# transformers pinned to 4.44.2 - later versions have breaking changes with TrOCR
# See: https://discuss.huggingface.co/t/fine-tune-trocr-model/151014
%pip install accelerate jiwer tensorboard transformers==4.44.2

In [None]:
from dataclasses import dataclass
from pathlib import Path

import evaluate
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from datasets import Dataset
from PIL import Image
from transformers import (
    EarlyStoppingCallback,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)

# All data/output paths are resolved against the parent of the current working directory.
project_root = Path.cwd().parent

# Project-local helpers: an augmentation transform and the decoder-input shifting function.
from utils import RandomNoise, shift_tokens_right

# Single seed shared by the train/eval split, the augmentations and the trainer.
seed_value = 42

# Seed the global RNGs up front. The random augmentations are executed while the
# dataset is being mapped, i.e. before the trainer is created, so without this the
# augmented training set differs on every run. torchvision's random transforms draw
# from the global PyTorch RNG.
# NOTE(review): RandomNoise lives in utils.py; if it uses Python's `random` module
# it needs seeding as well - confirm.
torch.manual_seed(seed_value)
np.random.seed(seed_value)

In [None]:
# Pretrained checkpoint that is fine-tuned below.
base_model = "microsoft/trocr-base-printed"

# Input locations: image files and the CSV that maps each filename to its text.
image_dir = project_root / "data" / "images"
csv_path = project_root / "data" / "ground_truths.csv"

# Checkpoints, logs, eval error dumps and the final model are written here.
output_dir = project_root / "output"
output_dir.mkdir(parents=True, exist_ok=True)

if not csv_path.exists():
    raise FileNotFoundError(f"Could not find CSV at: {csv_path}.")

# Read the columns used downstream as plain strings and disable NA inference.
# With pandas' defaults, ground-truth labels such as "NA", "null" or "None" become
# NaN and purely numeric labels become numbers (losing leading zeros); both would
# later break tokenization, which requires str inputs.
df = pd.read_csv(csv_path, dtype={"filename": str, "text": str}, keep_default_na=False)

# Fail fast with a clear message instead of a KeyError in a later cell.
missing_columns = {"filename", "text"} - set(df.columns)
if missing_columns:
    raise ValueError(f"CSV at {csv_path} is missing required column(s): {sorted(missing_columns)}")

print(f"Loaded {len(df)} records from CSV")
df.head()

In [None]:
processor = TrOCRProcessor.from_pretrained(base_model)

# Dropout settings
# These must be set on the config *before* the model is instantiated: TrOCR's decoder
# modules copy the dropout probabilities from the config when they are constructed,
# so editing `model.decoder.config` after `from_pretrained` leaves the already-built
# layers on the checkpoint's original values.
model_config = VisionEncoderDecoderConfig.from_pretrained(base_model)
model_config.decoder.dropout = 0.01
model_config.decoder.attention_dropout = 0.01

model = VisionEncoderDecoderModel.from_pretrained(base_model, config=model_config)

# Configure model tokens
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Mirror the special tokens onto generation_config so that generation during
# evaluation starts, pads and stops with the same tokens that are used to build
# decoder_input_ids for training.
model.generation_config.decoder_start_token_id = model.config.decoder_start_token_id
model.generation_config.pad_token_id = model.config.pad_token_id
model.generation_config.eos_token_id = model.config.eos_token_id

# Dynamic max length calculation based on dataset:
# longest tokenized ground truth, rounded up to the next multiple of 8.
token_lengths = df["text"].apply(lambda x: len(processor.tokenizer(x).input_ids))
max_target_length = int((token_lengths.max() + 7) // 8 * 8)
model.generation_config.max_length = max_target_length

print(f"Max target length set to: {max_target_length}")

In [None]:
# Resolve every filename from the CSV to a path string inside image_dir.
df["image_path"] = [str(Path(image_dir) / name) for name in df["filename"]]

# Only the image path and its transcription are carried into the HF dataset.
dataset = Dataset.from_pandas(df[["image_path", "text"]])

# Hold out 10% of the records for evaluation; the split is seeded.
train_test = dataset.train_test_split(test_size=0.10, seed=seed_value)
train_ds, eval_ds = train_test["train"], train_test["test"]

# Augmentations used for training images only. Areas exposed by the rotation and
# shear are filled with white.
_white = (255, 255, 255)
_augmentations = [
    transforms.RandomRotation(degrees=10, fill=_white),
    transforms.RandomAffine(degrees=0, shear=5, fill=_white),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    RandomNoise(prob=0.25),
]
train_transform = transforms.Compose(_augmentations)


def process_data(batch, is_train=True):
    """
    Prepares a batch for TrOCR.

    1. Loads each image as RGB and applies ``train_transform`` (if training).
    2. Tokenizes text and masks padding tokens (-100) for loss calculation.
    3. Creates shifted decoder_input_ids for auto-regressive generation.

    Args:
        batch: Mapping with an ``"image_path"`` list and a ``"text"`` list of
            equal length, as produced by a batched ``Dataset.map`` call.
        is_train: When True the images are passed through ``train_transform``
            before being handed to the processor.

    Returns:
        dict with ``"pixel_values"``, ``"decoder_input_ids"`` and ``"labels"``
        tensors for the batch.
    """
    images = []
    for path in batch["image_path"]:
        # Image.open keeps the underlying file open until the image is closed.
        # convert() loads the pixel data into a new image, so the file can be
        # released immediately instead of leaking one handle per sample.
        with Image.open(path) as source_image:
            rgb_image = source_image.convert("RGB")
        images.append(train_transform(rgb_image) if is_train else rgb_image)

    pixel_values = processor(images=images, return_tensors="pt").pixel_values

    # Pad/truncate every transcription to the same fixed length so tensors can be stacked.
    encoding = processor.tokenizer(
        batch["text"], padding="max_length", truncation=True, max_length=max_target_length, return_tensors="pt"
    )

    # Padding positions are set to -100 in the labels.
    labels = encoding.input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # shift_tokens_right imported from utils.py; it receives the unmasked ids
    # (real pad ids, not -100) plus the pad and decoder start token ids.
    decoder_input_ids = shift_tokens_right(
        encoding.input_ids, processor.tokenizer.pad_token_id, model.config.decoder_start_token_id
    )

    return {"pixel_values": pixel_values, "decoder_input_ids": decoder_input_ids, "labels": labels}


# Materialise model inputs for both splits. The original columns are dropped so
# only the tensors returned by process_data remain.
# NOTE: because the training split is processed here with is_train=True, the random
# augmentations are applied once at map time and the stored result is what the
# trainer iterates over.
_tensor_columns = ["pixel_values", "decoder_input_ids", "labels"]

train_ds = train_ds.map(
    lambda batch: process_data(batch, is_train=True),
    batched=True,
    batch_size=8,
    remove_columns=train_ds.column_names,
)
eval_ds = eval_ds.map(
    lambda batch: process_data(batch, is_train=False),
    batched=True,
    batch_size=8,
    remove_columns=eval_ds.column_names,
)

# Have both datasets return torch tensors for the model input columns.
for _split in (train_ds, eval_ds):
    _split.set_format(type="torch", columns=_tensor_columns)

In [None]:
import datetime


# Metrics used by compute_metrics: character error rate, word error rate, exact match.
cer_metric, wer_metric, exact_match_metric = (
    evaluate.load(metric_name) for metric_name in ("cer", "wer", "exact_match")
)


def compute_metrics(pred):
    """
    Decode predictions and labels, score them, and log every mismatch.

    Args:
        pred: Evaluation output exposing ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids with -100 at masked positions).

    Returns:
        dict with ``"cer"``, ``"wer"`` and ``"exact_match"`` values.

    Side effects:
        Appends a section to ``eval_errors.txt`` in ``output_dir`` listing each
        prediction that differs from its reference, and prints how many were logged.
    """
    pred_str = processor.batch_decode(pred.predictions, skip_special_tokens=True)

    # -100 is not a valid token id, so swap it for the pad token before decoding.
    # np.where builds a new array, leaving the caller's label_ids untouched.
    labels_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    exact_match = exact_match_metric.compute(predictions=pred_str, references=label_str)

    mismatches = [(truth, guess) for guess, truth in zip(pred_str, label_str) if guess != truth]

    # Append mode: each evaluation adds its own section, so the file keeps a history.
    with open(output_dir / "eval_errors.txt", "a", encoding="utf-8") as f:
        f.write(f"\n{'='*20} EVAL {'='*20}\n")
        for truth, guess in mismatches:
            f.write(f"TRUTH: {truth}\nPRED : {guess}\n{'-'*30}\n")

    error_count = len(mismatches)
    if error_count > 0:
        # error_count is already an int; calling len() on it raised TypeError.
        print(f"Logged {error_count} errors to eval_errors.txt")

    return {"cer": cer, "wer": wer, "exact_match": exact_match["exact_match"]}

In [None]:
@dataclass
class DataCollator:
    """Collate pre-processed examples into one batch of stacked tensors.

    The ``processor`` field is stored on the instance but is not read by
    ``__call__``.
    """

    processor: TrOCRProcessor

    def __call__(self, features):
        """Stack the per-example tensors of ``features`` key by key.

        Args:
            features: Sequence of mappings, each holding ``"pixel_values"``,
                ``"decoder_input_ids"`` and ``"labels"`` tensors.

        Returns:
            dict with the same three keys, each a tensor stacked along a new
            leading dimension.
        """
        batch = {}
        for key in ("pixel_values", "decoder_input_ids", "labels"):
            batch[key] = torch.stack([example[key] for example in features])
        return batch


# Training configuration. Evaluation and checkpointing run on the same 250-step
# cadence so the best checkpoint (by exact_match) can be reloaded at the end.
training_args = Seq2SeqTrainingArguments(
    output_dir=str(output_dir),
    eval_strategy="steps",  # `evaluation_strategy` is deprecated in the pinned transformers release
    eval_steps=250,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,  # Effective batch size: 4 * 8 = 32
    learning_rate=5e-6,  # Conservative LR for fine-tuning pretrained model
    warmup_steps=200,
    optim="adamw_torch",
    weight_decay=0.03,
    max_grad_norm=1.0,
    fp16=torch.cuda.is_available(),  # fp16 mixed precision needs a GPU; hard-coding True fails on CPU-only machines
    predict_with_generate=True,
    generation_max_length=max_target_length,
    generation_num_beams=4,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="exact_match",
    greater_is_better=True,
    logging_dir=str(output_dir / "logs"),
    logging_steps=25,
    logging_first_step=True,
    report_to="tensorboard",
    disable_tqdm=False,
    seed=seed_value,
    num_train_epochs=10.0,
)

# Early stopping with a patience of 3, tracked on the metric configured in training_args.
_early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

# Batches are assembled by stacking the tensors produced during preprocessing.
_collator = DataCollator(processor)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=_collator,
    tokenizer=processor.tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[_early_stopping],
)

In [None]:
# Run fine-tuning with the trainer configured above.
trainer.train()

# Save final model: the trainer output and the processor files go into the same
# folder under output_dir.
_final_model_dir = output_dir / "model"
trainer.save_model(_final_model_dir)
processor.save_pretrained(_final_model_dir)