# TrOCR Fine-tuning — Transformer Test Reports

Fine-tunes `microsoft/trocr-base-handwritten` on labeled crop images from transformer test reports.

**Before running:**
1. Runtime → Change runtime type → **T4 GPU**
2. Upload `labels.csv` and the `crops/` folder (from your local `data/` directory) into a single Google Drive folder
3. Update `DRIVE_DATA_PATH` in the Config cell below

## 1. Check GPU

In [None]:
import torch

# Report whether a CUDA device is attached; if so, show its name and memory.
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("WARNING: No GPU detected. Go to Runtime → Change runtime type → T4 GPU")
else:
    # Both lines describe device 0, so emit them with a single print call.
    print(
        f"GPU: {torch.cuda.get_device_name(0)}",
        f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB",
        sep="\n",
    )

## 2. Mount Google Drive

In [None]:
# Colab-only: attach Google Drive so the dataset and output paths under
# /content/drive become visible to the rest of the notebook.
from google.colab import drive
drive.mount('/content/drive')
print("Drive mounted at /content/drive")

## 3. Config — Edit paths here

In [None]:
from pathlib import Path

# --- EDIT THESE ---
# Folder in your Google Drive that contains:
#   labels.csv
#   crops/   (subfolder with all .png crop images)
DRIVE_DATA_PATH = Path("/content/drive/MyDrive/text_recognizer_data")

# Where to save the fine-tuned model on Drive
DRIVE_MODEL_OUTPUT = Path("/content/drive/MyDrive/text_recognizer_model/trocr_finetuned")

# Training hyperparameters
EPOCHS    = 30     # upper bound on epochs; early stopping may end sooner
BATCH_SIZE = 8
LR        = 5e-5
PATIENCE  = 8      # epochs without val-loss improvement before stopping
VAL_SPLIT = 0.2    # fraction of samples held out for validation
# ------------------

# Derived paths: inputs live on Drive, the working checkpoint on the Colab VM disk.
LABELS_CSV   = DRIVE_DATA_PATH / "labels.csv"
CROPS_DIR    = DRIVE_DATA_PATH / "crops"
CHECKPOINT   = Path("/content/trocr_finetuned")

print(f"Labels CSV : {LABELS_CSV}")
print(f"Crops dir  : {CROPS_DIR}")
print(f"Output     : {DRIVE_MODEL_OUTPUT}")

# Validate with explicit exceptions rather than `assert`: asserts are removed
# when Python runs with -O, and is_file()/is_dir() also catch a path that
# exists but is the wrong kind of entry.
if not 0.0 < VAL_SPLIT < 1.0:
    raise ValueError(f"VAL_SPLIT must be between 0 and 1 (exclusive), got {VAL_SPLIT}")
if not LABELS_CSV.is_file():
    raise FileNotFoundError(f"labels.csv not found at {LABELS_CSV}")
if not CROPS_DIR.is_dir():
    raise FileNotFoundError(f"crops/ folder not found at {CROPS_DIR}")
print("All paths OK.")

## 4. Install dependencies

In [None]:
%%capture
# Install the extra packages the notebook imports. The %%capture cell magic
# hides this cell's output, so a failed install will not be visible here.
!pip install transformers sentencepiece opencv-contrib-python-headless Pillow

## 5. Clone your repo

In [None]:
import os

# Local clone location on the Colab VM (`Path` is imported in the Config cell).
REPO_DIR = Path("/content/text_recognizer")

# Update an existing clone in place; otherwise fetch a fresh copy.
if REPO_DIR.exists():
    !git -C {REPO_DIR} pull
else:
    !git clone https://github.com/et41/test_ocr.git {REPO_DIR}

# Switch into the repo and list its contents as a quick sanity check.
os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")
!ls

## 6. Preview training data

In [None]:
import csv

# Build the list of (image path, label) pairs used by every later cell.
samples = []
# utf-8-sig reads plain UTF-8 and also strips a leading BOM (common in CSVs
# exported from spreadsheet tools), which would otherwise corrupt the first
# header name and break the column lookup below.
with open(LABELS_CSV, newline="", encoding="utf-8-sig") as f:
    for row in csv.DictReader(f):
        # DictReader fills missing trailing fields with None, so guard before strip().
        value = (row["value"] or "").strip()
        image_ref = row["image_path"]
        if not value or not image_ref:
            continue
        # Remap image_path to Drive location: only the file name is kept.
        samples.append((str(CROPS_DIR / Path(image_ref).name), value))

print(f"Total labeled samples: {len(samples)}")

# Check each distinct file once and reuse the result for both the preview
# and the summary.
missing_paths = {p for p in {path for path, _ in samples} if not Path(p).exists()}

print("\nFirst 10 samples:")
for path, val in samples[:10]:
    exists = "MISSING" if path in missing_paths else "OK"
    print(f"  [{exists}] {Path(path).name} → '{val}'")

missing = sum(1 for p, _ in samples if p in missing_paths)
if missing:
    print(f"\nWARNING: {missing} image(s) not found in {CROPS_DIR}")
    # Drop unreadable entries so training never pairs a real label with a
    # placeholder image.
    samples = [(p, v) for p, v in samples if p not in missing_paths]
    print(f"Dropped them; {len(samples)} usable sample(s) remain.")
else:
    print(f"\nAll {len(samples)} images found.")

## 7. Dataset & DataLoader

In [None]:
import csv
import random
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Always re-resolve the repo path so this cell works even after a kernel restart
REPO_DIR = Path("/content/text_recognizer")
# Put the cloned repo on the import path (once) so `pipeline` resolves below.
if str(REPO_DIR) not in sys.path:
    sys.path.insert(0, str(REPO_DIR))

# Project-local augmentation helper; TrOCRDataset applies it when augment=True.
from pipeline.dataset import augment_image

# Characters kept when cleaning decoded predictions (digits plus . , - +).
ALLOWED_CHARS = set("0123456789.,-+")


class TrOCRDataset(Dataset):
    """Map (image path, label text) pairs to TrOCR training examples.

    Args:
        samples: Sequence of ``(image_path, value)`` tuples.
        processor: Object providing the image preprocessing call and a
            ``tokenizer`` attribute (the notebook passes a ``TrOCRProcessor``).
        augment: When true, each image is passed through ``augment_image``
            before preprocessing.
        max_length: Fixed token length labels are padded/truncated to.
            Defaults to 20, the value previously hard-coded.
    """

    def __init__(self, samples, processor, augment=False, max_length=20):
        self.samples  = samples
        self.processor = processor
        self.augment  = augment
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict with ``pixel_values``, ``labels`` and raw ``text`` for item *idx*."""
        image_path, value = self.samples[idx]
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            # Keep the best-effort blank fallback so a single bad file does not
            # abort an epoch, but make it visible instead of silent.
            import warnings

            warnings.warn(
                f"Could not read image {image_path!r}; substituting a blank 32x64 crop",
                stacklevel=2,
            )
            image = np.zeros((32, 64), dtype=np.uint8)
        if self.augment:
            image = augment_image(image)
        # The processor is fed a 3-channel PIL image built from the grayscale crop.
        image_rgb  = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        pil_img    = Image.fromarray(image_rgb)
        pixel_vals = self.processor(images=pil_img, return_tensors="pt").pixel_values.squeeze(0)
        labels = self.processor.tokenizer(
            text=value, return_tensors="pt",
            padding="max_length", max_length=self.max_length, truncation=True,
        ).input_ids.squeeze(0)
        # Replace padding ids with -100 (the ignore index of PyTorch/Hugging Face
        # cross-entropy) so padded positions do not contribute to the loss.
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        return {"pixel_values": pixel_vals, "labels": labels, "text": value}


print("Dataset class defined.")

## 8. Load base TrOCR model

In [None]:
# Hugging Face checkpoint used as the starting point for fine-tuning.
TROCR_BASE = "microsoft/trocr-base-handwritten"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# The base checkpoint weighs in at roughly 1.3 GB, not ~400 MB as the message
# used to say.
print(f"\nDownloading {TROCR_BASE} (~1.3 GB, only on first run)...")
processor = TrOCRProcessor.from_pretrained(TROCR_BASE)
model     = VisionEncoderDecoderModel.from_pretrained(TROCR_BASE)

# Copy the tokenizer's special-token ids into the model config, which the
# encoder-decoder wrapper reads for training and generation.
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
model.config.pad_token_id           = processor.tokenizer.pad_token_id
model.config.eos_token_id           = processor.tokenizer.eos_token_id
model = model.to(device)

print("Model loaded.")

## 9. Build train / val splits

In [None]:
# Fixed seed so the train/val partition is identical every time this cell runs.
# An unseeded in-place shuffle would move validation images into the training
# set on a re-run, inflating later accuracy numbers.
SPLIT_SEED = 42

if len(samples) < 2:
    raise ValueError(
        f"Need at least 2 labeled samples to build train/val splits, got {len(samples)}"
    )

# Shuffle a copy; `samples` itself keeps its original order.
shuffled_samples = list(samples)
random.Random(SPLIT_SEED).shuffle(shuffled_samples)

# At least one validation sample, and always at least one training sample.
val_size     = min(max(1, int(len(shuffled_samples) * VAL_SPLIT)), len(shuffled_samples) - 1)
val_samples  = shuffled_samples[:val_size]
train_samples = shuffled_samples[val_size:]

print(f"Train: {len(train_samples)} samples")
print(f"Val  : {len(val_samples)} samples")

# Augmentation is applied to the training split only.
train_set = TrOCRDataset(train_samples, processor, augment=True)
val_set   = TrOCRDataset(val_samples,   processor, augment=False)

# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runtimes.
use_pinned_memory = torch.cuda.is_available()
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=use_pinned_memory)
val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=use_pinned_memory)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

## 10. Training loop

In [None]:
# AdamW over all model parameters, using the learning rate from the Config cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


def compute_accuracy(preds, targets, allowed_chars=None):
    """Return the fraction of predictions that exactly match their target.

    Each prediction is stripped and reduced to the characters in
    *allowed_chars* before comparison; targets are compared as given.

    Args:
        preds: Sequence of decoded prediction strings.
        targets: Sequence of ground-truth strings, aligned with *preds*.
        allowed_chars: Container of characters to keep in predictions.
            Defaults to the module-level ``ALLOWED_CHARS``.

    Returns:
        A float in [0, 1]; 0.0 when *targets* is empty.
    """
    keep = ALLOWED_CHARS if allowed_chars is None else allowed_chars
    correct = sum(
        "".join(c for c in p.strip() if c in keep) == t
        for p, t in zip(preds, targets)
    )
    # max(..., 1) avoids a division by zero on an empty evaluation set.
    return correct / max(len(targets), 1)


# Early-stopping state and per-epoch metrics.
best_val_loss  = float("inf")
epochs_no_impr = 0
history        = []

CHECKPOINT.mkdir(parents=True, exist_ok=True)

print(f"Starting fine-tuning for up to {EPOCHS} epochs...\n")

for epoch in range(1, EPOCHS + 1):
    # --- Train ---
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        pv     = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        loss   = model(pixel_values=pv, labels=labels).loss
        optimizer.zero_grad()
        loss.backward()
        # Cap the gradient norm at 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # --- Validate ---
    model.eval()
    val_loss   = 0.0
    all_preds  = []
    all_targets = []
    with torch.no_grad():
        for batch in val_loader:
            pv     = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            # Teacher-forced loss drives checkpointing; free-running generation
            # drives the exact-match accuracy.
            val_loss += model(pixel_values=pv, labels=labels).loss.item()
            gen_ids  = model.generate(pv, max_new_tokens=20)
            preds    = processor.batch_decode(gen_ids, skip_special_tokens=True)
            all_preds.extend(preds)
            all_targets.extend(batch["text"])
    val_loss /= len(val_loader)
    acc = compute_accuracy(all_preds, all_targets)

    history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss, "acc": acc})
    print(f"Epoch {epoch:3d} | Train: {train_loss:.4f} | Val: {val_loss:.4f} | Acc: {acc:.2%}")

    if val_loss < best_val_loss:
        # New best: persist model and processor together so the checkpoint
        # directory can be loaded on its own.
        best_val_loss  = val_loss
        epochs_no_impr = 0
        model.save_pretrained(str(CHECKPOINT))
        processor.save_pretrained(str(CHECKPOINT))
        print(f"  -> Saved best model (val loss {best_val_loss:.4f})")
    else:
        epochs_no_impr += 1
        if epochs_no_impr >= PATIENCE:
            print(f"\nEarly stopping after {PATIENCE} epochs without improvement.")
            break

print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")

# At this point `model` holds the weights of the LAST epoch, which can be up
# to PATIENCE epochs past the best one. Copy the best checkpoint back in so
# anything that uses `model` afterwards evaluates the same weights that were
# saved. Skipped when no epoch ever improved on the initial loss (e.g. NaN).
if best_val_loss < float("inf"):
    best_model = VisionEncoderDecoderModel.from_pretrained(str(CHECKPOINT))
    model.load_state_dict(best_model.state_dict())
    del best_model
    model.eval()
    print("Restored best checkpoint weights into the in-memory model.")

## 11. Plot training curves

In [None]:
import matplotlib.pyplot as plt

# Split the per-epoch history records into one list per metric.
epochs_x, train_loss, val_loss, acc = (
    [record[key] for record in history]
    for key in ("epoch", "train_loss", "val_loss", "acc")
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: training vs. validation loss.
for series, label in ((train_loss, "Train loss"), (val_loss, "Val loss")):
    ax1.plot(epochs_x, series, label=label)
ax1.set_title("Loss curves")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.grid(True)
ax1.legend()

# Right panel: validation accuracy as a percentage.
ax2.plot(epochs_x, [100 * a for a in acc], color="green")
ax2.set_title("Validation accuracy")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy (%)")
ax2.grid(True)

plt.tight_layout()
# Write the figure to disk before showing it.
plt.savefig("/content/training_curves.png", dpi=120)
plt.show()
print("Saved training_curves.png")

## 12. Quick inference test

In [None]:
# Spot-check the model on the first ten validation samples.
model.eval()
print(f"{'Image':<40} {'Ground truth':<15} {'Predicted':<15} {'Match'}")
print("-" * 80)

for img_path, gt in val_samples[:10]:
    gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        # Unreadable file: nothing to predict on.
        continue

    # Convert to a 3-channel PIL image and preprocess it for the model.
    pil_img    = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
    pixel_vals = processor(images=pil_img, return_tensors="pt").pixel_values.to(device)

    with torch.no_grad():
        gen_ids = model.generate(pixel_vals, max_new_tokens=20)

    # Decode, trim, then keep only whitelisted characters.
    decoded = processor.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
    pred = "".join(filter(ALLOWED_CHARS.__contains__, decoded))

    match = "FAIL" if pred != gt else "OK"
    print(f"{Path(img_path).name:<40} {gt:<15} {pred:<15} {match}")

## 13. Save model to Google Drive

In [None]:
import shutil

DRIVE_MODEL_OUTPUT.mkdir(parents=True, exist_ok=True)

# Copy checkpoint to Drive. shutil.copy2 only handles regular files, so any
# stray sub-directory is reported and skipped instead of aborting the export.
for f in CHECKPOINT.iterdir():
    if f.is_file():
        shutil.copy2(f, DRIVE_MODEL_OUTPUT / f.name)
    else:
        print(f"Skipping non-file entry in checkpoint dir: {f.name}")

# Also copy training curves. The image only exists if the plotting cell ran;
# its absence must not stop the model export or the listing below.
curves_png = Path("/content/training_curves.png")
if curves_png.is_file():
    shutil.copy2(curves_png, DRIVE_MODEL_OUTPUT.parent / "training_curves.png")
else:
    print(f"Note: {curves_png} not found; skipping training-curves copy.")

print(f"Model saved to Google Drive: {DRIVE_MODEL_OUTPUT}")
print("\nFiles saved:")
for f in sorted(DRIVE_MODEL_OUTPUT.iterdir()):
    size_mb = f.stat().st_size / 1e6
    print(f"  {f.name:<35} {size_mb:.1f} MB")

## 14. Download model to your PC (alternative to Drive)

Run this cell if you prefer to download the model directly instead of using Drive.

In [None]:
import shutil
from google.colab import files

# Zip the checkpoint. make_archive appends ".zip" to the base name and returns
# the resulting archive path, which is reused for the message and the download.
zip_path = shutil.make_archive("/content/trocr_finetuned", "zip", str(CHECKPOINT))
print(f"Zipped model: {zip_path}")

# Trigger browser download
files.download(zip_path)