In [90]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os
import json
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TrOCRProcessor
from PIL import Image
import torchvision.transforms as T
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.amp import GradScaler, autocast
from transformers import VisionEncoderDecoderModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Mounted at /content/drive
Using device: cuda


In [84]:
def preprocess_image(image_path):
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Image not found/unreadable")
    blurred = cv2.medianBlur(img, 3)
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    black_pixels = np.sum(binary == 0)
    white_pixels = np.sum(binary == 255)
    if black_pixels < white_pixels:
        binary = 255 - binary
    rgb_image = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
    return rgb_image

In [85]:
def augment_image(np_image):
    pil_img = Image.fromarray(np_image)
    transform = T.Compose([
        T.RandomRotation(5),
        T.ColorJitter(brightness=0.1, contrast=0.1),
    ])
    augmented = transform(pil_img)
    return np.array(augmented)

In [86]:
class HandwrittenWordsDataset(Dataset):

    def __init__(self, images_dir, labels_path):
        self.images_dir = images_dir
        if labels_path.endswith('.json'):
            with open(labels_path, 'r') as f:
                all_labels = json.load(f)
        else:
            raise ValueError("Labels must be .json.")

        self.labels = {}
        for fname, text in all_labels.items():
            if not fname.strip() or fname.startswith("._"):
                continue
            img_path = os.path.join(images_dir, fname)
            if not os.path.exists(img_path):
                continue
            if cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) is None:
                continue
            self.labels[fname] = text

        try:
            self.image_files = sorted(self.labels.keys(), key=lambda x: int(os.path.splitext(x)[0]))
        except ValueError:
            self.image_files = sorted(self.labels.keys())

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        fname = self.image_files[idx]
        text = self.labels[fname]
        img_path = os.path.join(self.images_dir, fname)
        image_rgb = preprocess_image(img_path)
        if self.images_dir.endswith("training_words"):
            image_rgb = augment_image(image_rgb)
        return {"image": image_rgb, "text": text}

In [87]:
train_images_dir = "/content/drive/MyDrive/dataset/Training/training_words"
train_labels_path = "/content/drive/MyDrive/dataset/Training/training_labels.json"
val_images_dir = "/content/drive/MyDrive/dataset/Validation/validation_words"
val_labels_path = "/content/drive/MyDrive/dataset/Validation/validation_labels.json"

train_dataset = HandwrittenWordsDataset(
    images_dir=train_images_dir,
    labels_path=train_labels_path
)

val_dataset = HandwrittenWordsDataset(
    images_dir=val_images_dir,
    labels_path=val_labels_path
)

test_dataset = HandwrittenWordsDataset(
    images_dir="/content/drive/MyDrive/dataset/Testing/testing_words",
    labels_path="/content/drive/MyDrive/dataset/Testing/testing_labels.json"
)

In [88]:
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')

def collate_batch(batch):
    valid_batch = [item for item in batch if item is not None]
    if len(valid_batch) == 0:
        raise RuntimeError("No valid items in batch")
    for item in valid_batch:
        if not isinstance(item, dict) or "image" not in item or "text" not in item:
            raise RuntimeError(f"Invalid item in batch")
    images = [item["image"] for item in valid_batch]
    texts  = [item["text"]  for item in valid_batch]
    encodings = processor(
        images=images,
        text=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=16
    )
    labels = encodings.labels
    labels[labels == processor.tokenizer.pad_token_id] = -100
    return {
        "pixel_values": encodings.pixel_values.to(device),
        "labels": labels.to(device)
    }

In [89]:
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_batch,
    num_workers=0
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_batch,
    num_workers=0
)

test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_batch,
    num_workers=0
)

In [91]:
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten')
model.to(device)

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.48.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-23): 24 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=False)
              (key): Linear(in_features=1024, out_features=1024, bias=False)
              (value): Linear(in_features=1024, out_features=1024, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dens

In [92]:
model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

In [93]:
num_epochs = 10
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader) * num_epochs)
scaler = GradScaler("cuda")  # device cuda to absolve error

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, batch in enumerate(train_loader):
        if i == 0:
            print(f"Epoch {epoch+1}: pixel_values shape: {batch['pixel_values'].shape}, labels shape: {batch['labels'].shape}")
        optimizer.zero_grad()
        with autocast("cuda"):
            outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])
            loss = outputs.loss
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        running_loss += loss.item()

    avg_loss = running_loss / (i + 1)

    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for val_batch in val_loader:
            generated_ids = model.generate(pixel_values=val_batch["pixel_values"], num_beams=5, max_length=16)
            preds = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            label_ids = val_batch["labels"].clone()
            label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
            gt_texts = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            for pred, gt in zip(preds, gt_texts):
                if pred.strip() == gt.strip():
                    total_correct += 1
                total_samples += 1
    val_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {avg_loss:.4f} - Validation Accuracy: {val_accuracy:.4f}")

    model.train()

Epoch 1: pixel_values shape: torch.Size([16, 3, 384, 384]), labels shape: torch.Size([16, 6])
Epoch 1/10 - Avg Loss: 0.5483 - Validation Accuracy: 0.9231
Epoch 2: pixel_values shape: torch.Size([16, 3, 384, 384]), labels shape: torch.Size([16, 5])
Epoch 2/10 - Avg Loss: 0.0686 - Validation Accuracy: 0.9308
Epoch 3: pixel_values shape: torch.Size([16, 3, 384, 384]), labels shape: torch.Size([16, 6])
Epoch 3/10 - Avg Loss: 0.0300 - Validation Accuracy: 0.9526
Epoch 4: pixel_values shape: torch.Size([16, 3, 384, 384]), labels shape: torch.Size([16, 6])
Epoch 4/10 - Avg Loss: 0.0150 - Validation Accuracy: 0.9449
Epoch 5: pixel_values shape: torch.Size([16, 3, 384, 384]), labels shape: torch.Size([16, 6])
Epoch 5/10 - Avg Loss: 0.0051 - Validation Accuracy: 0.9551
Epoch 6: pixel_values shape: torch.Size([16, 3, 384, 384]), labels shape: torch.Size([16, 6])
Epoch 6/10 - Avg Loss: 0.0029 - Validation Accuracy: 0.9782
Epoch 7: pixel_values shape: torch.Size([16, 3, 384, 384]), labels shape: to

In [94]:
model.eval()
total_correct = 0
total_samples = 0
all_predictions = []

with torch.no_grad():
    for batch in test_loader:
        generated_ids = model.generate(pixel_values=batch["pixel_values"], num_beams=5, max_length=16)
        preds = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        label_ids = batch["labels"].clone()
        label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
        gt_texts = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        all_predictions.extend(zip(preds, gt_texts))

        for pred, gt in zip(preds, gt_texts):
            if pred.strip() == gt.strip():
                total_correct += 1
            total_samples += 1

test_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
print(test_accuracy)

0.9551282051282052
