# Step 1: Data Preparation

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from transformers import VisionEncoderDecoderModel, AutoProcessor


class MNISTDataset(Dataset):
    """
    A PyTorch Dataset that adapts MNIST digits for the Hugging Face TrOCR model.

    Each item is a dict with:
        pixel_values: the processor's image tensor for one image, without the batch dimension.
        labels: token ids of the digit's text (e.g. "7"), padded/truncated to
            ``max_target_length``, with pad positions set to -100 so the loss ignores them.
    """

    def __init__(self, dataset, processor, max_target_length=5):
        """
        Args:
            dataset: An indexable dataset of (image, label) pairs, e.g. a torchvision MNIST split.
                Images may be PIL images or image tensors. Tensors must hold values in [0, 1]
                (e.g. from ``transforms.ToTensor()`` alone): ``ToPILImage`` scales floats by 255,
                so negative values from an extra ``Normalize`` are corrupted.
            processor: The Hugging Face TrOCR processor (image processor + tokenizer).
            max_target_length: Maximum label length in tokens, including special tokens.
        """
        self.dataset = dataset
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, digit = self.dataset[idx]

        # TrOCR takes 3-channel images while MNIST is grayscale. Tensors are converted back to
        # PIL first; PIL images (dataset loaded without a transform) are used directly.
        if isinstance(image, torch.Tensor):
            image = transforms.ToPILImage()(image)
        image = image.convert("RGB")

        # The processor resizes/normalizes and returns a batch of one; drop only that batch dim.
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

        tokenizer = self.processor.tokenizer
        token_ids = tokenizer(
            str(digit),
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
        ).input_ids

        # -100 is the ignore index of the cross-entropy loss, so padding does not contribute.
        pad_id = tokenizer.pad_token_id
        label_ids = [-100 if tok == pad_id else tok for tok in token_ids]

        return {
            "pixel_values": pixel_values,
            "labels": torch.tensor(label_ids, dtype=torch.long),
        }


# Initialize Processor
processor = AutoProcessor.from_pretrained("microsoft/trocr-base-handwritten")

# Load MNIST Dataset.
# Only ToTensor() is applied here: the TrOCR processor does its own resizing and normalization.
# An extra Normalize((0.5,), (0.5,)) would push pixels into [-1, 1], which ToPILImage
# (called in MNISTDataset) cannot map back to uint8 correctly, corrupting the images.
MNIST_ROOT = "../../data"
mnist_transform = transforms.ToTensor()
train_mnist = datasets.MNIST(
    root=MNIST_ROOT, train=True, transform=mnist_transform, download=True
)
test_mnist = datasets.MNIST(
    root=MNIST_ROOT, train=False, transform=mnist_transform, download=True
)

# Prepare Custom Datasets
train_dataset = MNISTDataset(train_mnist, processor)
test_dataset = MNISTDataset(test_mnist, processor)

In [13]:
# Report how many examples each split contains.
for split_name, split in (("training", train_dataset), ("validation", test_dataset)):
    print(f"Number of {split_name} examples:", len(split))

Number of training examples: 60000
Number of validation examples: 10000


In [14]:
# Inspect one preprocessed example: print the shape of each returned tensor.
encoding = train_dataset[0]
for key, tensor in encoding.items():
    print(key, tensor.shape)

pixel_values torch.Size([3, 384, 384])
labels torch.Size([5])


We can also check the original image and decode the labels:

In [15]:
# Map the -100 loss-masking ids back to the pad id so the tokenizer can decode them.
# masked_fill returns a new tensor, so encoding["labels"] is left untouched and the cell is re-runnable.
labels = encoding["labels"].masked_fill(
    encoding["labels"] == -100, processor.tokenizer.pad_token_id
)
label_str = processor.decode(labels, skip_special_tokens=True)
print(label_str)

5


# Train a model
Here, we initialize the TrOCR model from its pretrained weights. Note that the weights of the language modeling head are already initialized from pre-training, as the model was already trained to generate text during its pre-training stage. Refer to the TrOCR paper (Li et al., 2021, "TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models") for details.

In [16]:
from transformers import VisionEncoderDecoderModel

# Pick the best available accelerator instead of hard-coding Apple's "mps" backend,
# so the notebook also runs on CUDA machines and CPU-only environments.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1").to(device)

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.48.0"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder

Importantly, we need to set a couple of attributes, namely:

- the attributes required for creating the decoder_input_ids from the labels (the model will automatically create the decoder_input_ids by shifting the labels one position to the right and prepending the decoder_start_token_id, as well as replacing ids which are -100 by the pad_token_id)
- the vocabulary size of the model (for the language modeling head on top of the decoder)
- beam-search related parameters which are used when generating text.

In [17]:
# Special tokens the model uses to build decoder_input_ids from the labels
# (labels are shifted right, decoder_start_token_id is prepended, -100 is replaced by pad_token_id).
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.eos_token_id

# The LM head predicts over the decoder's full vocabulary (the whole tokenizer vocabulary,
# not just the 10 digits).
model.config.vocab_size = model.config.decoder.vocab_size

# Generation settings belong on model.generation_config; setting them on model.config is
# deprecated and triggers warnings in generate() and save_pretrained().
generation_config = model.generation_config
generation_config.decoder_start_token_id = model.config.decoder_start_token_id
generation_config.pad_token_id = model.config.pad_token_id
generation_config.eos_token_id = model.config.eos_token_id
generation_config.max_length = 5  # matches the label length (MNISTDataset max_target_length=5)

# Beam search parameters
generation_config.early_stopping = True
generation_config.no_repeat_ngram_size = 3
generation_config.length_penalty = 2.0
generation_config.num_beams = 4


# Define Training Arguments

We will evaluate the model on the Character Error Rate (CER), which is available in the Hugging Face `evaluate` library (`evaluate.load("cer")`).



In [18]:
from evaluate import load
# Character Error Rate metric from the Hugging Face `evaluate` library (read by compute_metrics).
cer = load(path="cer")

def compute_metrics(pred):
    """Compute the Character Error Rate between generated text and reference labels.

    Args:
        pred: Trainer evaluation output with ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, -100 at ignored positions).

    Returns:
        dict: ``{"cer": <float>}``.
    """
    import numpy as np

    pad_id = processor.tokenizer.pad_token_id
    # The Trainer pads labels (and predictions gathered across batches) with -100, which the
    # tokenizer cannot decode. np.where builds new arrays, so the caller's arrays are not mutated.
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Use a new name: assigning to `cer` here would make the global metric a local variable
    # and raise UnboundLocalError on `cer.compute`.
    cer_score = cer.compute(predictions=pred_str, references=label_str)

    return {"cer": cer_score}

Let's train! We also provide the default_data_collator to the Trainer, which is used to batch together examples.

Note that evaluation becomes slow once generated text is scored (e.g. CER via `compute_metrics`). With `predict_with_generate=True`, the model then decodes with beam search, which requires several forward passes for each example.

In [19]:
# Compare the decoder's output vocabulary with the tokenizer's vocabulary.
decoder_vocab_size = model.decoder.config.vocab_size
tokenizer_vocab_size = len(processor.tokenizer)
print("Decoder vocab size:", decoder_vocab_size)
print("Tokenizer vocab size:", tokenizer_vocab_size)

Decoder vocab size: 50265
Tokenizer vocab size: 50265


In [20]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Seq2SeqTrainer hyperparameters. Evaluation and checkpointing both run once per epoch;
# load_best_model_at_end requires the two strategies to match.
training_args = Seq2SeqTrainingArguments(
    output_dir="./trained_model/",       # Checkpoints are written here
    eval_strategy="epoch",               # Evaluate after every epoch
    save_strategy="epoch",               # Save a checkpoint after every epoch
    save_total_limit=3,                  # Keep at most 3 checkpoints on disk
    per_device_train_batch_size=8,       # Batch size for training
    per_device_eval_batch_size=8,        # Batch size for evaluation
    gradient_accumulation_steps=4,       # Effective train batch = 8 * 4 = 32 per device
    learning_rate=5e-5,                  # Learning rate for the AdamW optimizer
    weight_decay=0.01,                   # Weight decay for regularization
    num_train_epochs=3,                  # Number of training epochs
    predict_with_generate=True,          # Use generate() when metrics need generated text
    logging_dir="./logs",                # Directory to save logs
    logging_steps=50,                    # Log every 50 steps
    fp16=False,                          # Mixed-precision (float16) training disabled
    remove_unused_columns=False,         # Pass every key returned by the dataset to the model
    report_to="none",                    # Disable reporting to external services
    load_best_model_at_end=True,         # Reload the best checkpoint when training finishes
    metric_for_best_model="eval_loss",   # Metric used to pick the best checkpoint
    greater_is_better=False,             # Lower eval loss is better
    save_on_each_node=False,
)

In [None]:
import os

# Disable the tokenizers library's internal parallelism (avoids fork-related warnings
# when the DataLoader starts worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import default_data_collator

# default_data_collator stacks the per-example dicts returned by MNISTDataset into batches.
# NOTE: compute_metrics is not passed here, so evaluation reports the loss only (no CER).
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=default_data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
trainer.train()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

# Step 6: Save the trained model

In [None]:
# Final export location (distinct from the Trainer's checkpoint directory "./trained_model/").
SAVE_DIR = "src/model/trained_model/"

model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)
print(f"Model and processor saved to {SAVE_DIR}")