# Retraining TrOCR_Math_handwritten on Google Mathwriting Dataset (2024)

This notebook is a step-by-step guide to retraining and evaluating the [TrOCR_Math_handwritten](https://huggingface.co/fhswf/TrOCR_Math_handwritten) model on the Google MathWriting dataset, which is downloaded from <https://storage.googleapis.com/mathwriting_data/mathwriting-2024.tgz>.

## 1. Import Required Libraries

In [None]:
# Install the packages this notebook imports.
# `jiwer` is required by the `cer`/`wer` metrics loaded through `evaluate`, and
# `gitpython` provides the `git` module imported in the next cell.
!pip install transformers datasets tokenizers accelerate evaluate jiwer matplotlib tqdm pillow gitpython

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
from git import Repo

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import (
    TrOCRProcessor, 
    VisionEncoderDecoderModel, 
    Seq2SeqTrainer, 
    Seq2SeqTrainingArguments,
    default_data_collator
)
from transformers.modeling_outputs import Seq2SeqLMOutput
import evaluate
from datasets import load_dataset, Dataset as HFDataset

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Download and Prepare the Dataset

We'll download the mathwriting dataset from Google Cloud Storage and prepare it for training.

In [None]:
# Define the direct URL for the mathwriting dataset
mathwriting_url = "https://storage.googleapis.com/mathwriting_data/mathwriting-2024.tgz"
# The archive is saved under this name in the current working directory
dataset_filename = "mathwriting-2024.tgz"
# Extraction target; later cells also write the processed images and split files here
data_dir = "mathwriting_data"

# Download the dataset if not already present
# NOTE(review): only the existence of the file is checked, so an interrupted
# download would be treated as complete on the next run — delete the file to retry.
if not os.path.exists(dataset_filename):
    print(f"Downloading dataset from {mathwriting_url}...")
    !curl -L {mathwriting_url} -o {dataset_filename}
    print("Download complete!")
else:
    print(f"Dataset file {dataset_filename} already exists.")

# Create data directory if it doesn't exist
os.makedirs(data_dir, exist_ok=True)

# Extract the dataset if not already extracted
# An empty file named "extracted" inside data_dir acts as the "done" marker.
if not os.path.exists(os.path.join(data_dir, "extracted")):
    print("Extracting dataset...")
    !tar -xzf {dataset_filename} -C {data_dir}
    # Create a marker file to indicate extraction is complete
    # NOTE(review): the marker is written without checking whether tar succeeded.
    Path(os.path.join(data_dir, "extracted")).touch()
    print("Extraction complete!")
else:
    print("Dataset already extracted.")

# Check the structure of the extracted data
print("\nExploring dataset structure:")
!ls -la {data_dir}

In [None]:
# Function to explore dataset structure
def explore_dataset_structure(base_dir):
    """Print a summary of the files below ``base_dir``.

    The report lists the total number of files and sub-directories, the first
    ten sub-directories in sorted order, a per-extension file count (most
    frequent first) and, when both ``.png`` and ``.txt`` files are present, the
    first ``.png`` that has a ``.txt`` file with the same stem next to it,
    together with the first 100 characters of that text file.

    Args:
        base_dir: Root directory to inspect.

    Returns:
        None. Everything is reported through ``print``.
    """
    print(f"Exploring dataset structure in {base_dir}")

    # First pass: tally extensions and remember every sub-directory
    counts_by_ext = {}
    subdirs = set()
    file_total = 0
    for current, _children, names in os.walk(base_dir):
        relative = os.path.relpath(current, base_dir)
        if relative != ".":
            subdirs.add(relative)

        file_total += len(names)
        for name in names:
            suffix = os.path.splitext(name)[1].lower()
            counts_by_ext[suffix] = counts_by_ext.get(suffix, 0) + 1

    print(f"\nFound {file_total} total files in {len(subdirs)} directories")
    print("\nDirectory structure:")
    ordered_dirs = sorted(subdirs)
    for shown in ordered_dirs[:10]:  # only the first ten are listed
        print(f"- {shown}")
    hidden = len(ordered_dirs) - 10
    if hidden > 0:
        print(f"... and {hidden} more directories")

    print("\nFile extensions:")
    ranked = sorted(counts_by_ext.items(), key=lambda pair: pair[1], reverse=True)
    for suffix, how_many in ranked:
        print(f"- {suffix}: {how_many} files")

    # Without both file types there cannot be an image/annotation pair
    if '.png' not in counts_by_ext or '.txt' not in counts_by_ext:
        return

    def png_with_sibling_txt():
        """Yield (image, annotation) path candidates in walk order."""
        for current, _, names in os.walk(base_dir):
            for name in names:
                if name.endswith('.png'):
                    stem = os.path.splitext(name)[0]
                    yield os.path.join(current, name), os.path.join(current, stem + '.txt')

    first_pair = next(
        (pair for pair in png_with_sibling_txt() if os.path.exists(pair[1])),
        None,
    )
    if first_pair is None:
        return

    img_path, txt_path = first_pair
    print("\nFound a matching image-annotation pair:")
    print(f"Image: {img_path}")
    print(f"Annotation: {txt_path}")

    # Show a preview of the annotation; a read failure is reported, not raised
    try:
        with open(txt_path, 'r', encoding='utf-8') as f:
            content = f.read().strip()
        print(f"Annotation content: {content[:100]}{'...' if len(content) > 100 else ''}")
    except Exception as e:
        print(f"Error reading annotation: {e}")

# Explore only once the extraction marker file is in place
if Path(data_dir, "extracted").exists():
    explore_dataset_structure(data_dir)

In [None]:
# Process the mathwriting-2024 dataset structure
def process_mathwriting_dataset(data_dir, seed=None):
    """Pair images with their LaTeX annotations and write train/val/test splits.

    The directory tree below ``data_dir`` is searched for ``.png``/``.jpg``/
    ``.jpeg`` files. An image is kept when a file with the same stem and a
    ``.txt``, ``.tex`` or ``.latex`` extension sits next to it. Kept images are
    copied into ``<data_dir>/processed/images`` and the annotation text
    (stripped) becomes the formula. The pairs are shuffled and split
    80% / 10% / 10%, and each split is written to ``<data_dir>/<split>.json``
    as a list of ``{"image_path", "formula"}`` records.

    NOTE(review): only the three image extensions above are recognised; if the
    archive ships another format, no pairs will be found — confirm against the
    extracted contents.
    NOTE(review): copies are keyed by file name only, so two images with the
    same name in different directories share one destination file.

    Args:
        data_dir: Root directory that holds the extracted dataset.
        seed: Optional seed for a reproducible shuffle. With ``None`` (the
            default) the global NumPy random state is used.

    Returns:
        dict with the keys ``train``, ``val``, ``test`` (lists of copied image
        paths) and ``train_formulas``, ``val_formulas``, ``test_formulas``
        (lists of formulas aligned with the corresponding path list).
    """
    # Imported locally so the function does not depend on the notebook's import
    # cell, which does not import ``shutil``.
    import json
    import shutil

    print("Processing mathwriting-2024 dataset...")

    # Create directories for organized data
    processed_root = os.path.join(data_dir, "processed")
    images_dir = os.path.join(processed_root, "images")
    os.makedirs(images_dir, exist_ok=True)

    image_suffixes = ('.png', '.jpg', '.jpeg')
    annotation_suffixes = ('.txt', '.tex', '.latex')

    all_image_paths = []
    all_formulas = []

    skip_dir = os.path.abspath(processed_root)
    for root, dirs, files in os.walk(data_dir):
        # Do not descend into our own output directory, and visit the rest in
        # sorted order so a seeded shuffle gives the same split every time.
        dirs[:] = sorted(
            d for d in dirs
            if os.path.abspath(os.path.join(root, d)) != skip_dir
        )

        for file in sorted(files):
            if not file.endswith(image_suffixes):
                continue

            image_file = os.path.join(root, file)
            base_name = os.path.splitext(file)[0]

            # The first annotation extension that exists wins
            annotation_file = None
            for ext in annotation_suffixes:
                candidate = os.path.join(root, base_name + ext)
                if os.path.exists(candidate):
                    annotation_file = candidate
                    break
            if annotation_file is None:
                continue

            try:
                with open(annotation_file, 'r', encoding='utf-8') as f:
                    formula = f.read().strip()

                dest_image_path = os.path.join(images_dir, file)
                if not os.path.exists(dest_image_path):
                    shutil.copy(image_file, dest_image_path)
            except (OSError, UnicodeError) as e:
                # Unreadable files are skipped; programming errors are no
                # longer hidden by a blanket ``except Exception``.
                print(f"Error processing {image_file}: {e}")
                continue

            all_image_paths.append(dest_image_path)
            all_formulas.append(formula)

    print(f"Found {len(all_image_paths)} valid image-formula pairs")

    # Create train/val/test splits
    indices = list(range(len(all_image_paths)))
    if seed is None:
        np.random.shuffle(indices)
    else:
        np.random.default_rng(seed).shuffle(indices)

    # Split data: 80% train, 10% val, the remainder test
    train_size = int(0.8 * len(indices))
    val_size = int(0.1 * len(indices))
    split_indices = {
        'train': indices[:train_size],
        'val': indices[train_size:train_size + val_size],
        'test': indices[train_size + val_size:],
    }

    splits = {}
    for split_name, chosen in split_indices.items():
        splits[split_name] = [all_image_paths[i] for i in chosen]
        splits[f"{split_name}_formulas"] = [all_formulas[i] for i in chosen]

    # Save the splits info
    for split_name in split_indices:
        split_data = [
            {'image_path': img_path, 'formula': formula}
            for img_path, formula in zip(splits[split_name], splits[f"{split_name}_formulas"])
        ]
        with open(os.path.join(data_dir, f"{split_name}.json"), 'w', encoding='utf-8') as f:
            json.dump(split_data, f)

    print(f"Data splits created: Train {len(splits['train'])} / Val {len(splits['val'])} / Test {len(splits['test'])}")
    return splits

# Process the dataset
# On success the six split lists become notebook-level variables that later
# cells read. On any exception only a message is printed here; the fallback
# itself runs in the next cell, which checks whether `train_images` exists and
# is non-empty.
try:
    splits = process_mathwriting_dataset(data_dir)
    train_images, train_formulas = splits['train'], splits['train_formulas']
    val_images, val_formulas = splits['val'], splits['val_formulas']
    test_images, test_formulas = splits['test'], splits['test_formulas']
except Exception as e:
    print(f"Error processing dataset: {e}")
    print("Falling back to an alternative approach...")

In [None]:
# Fallback method in case the direct dataset processing fails
def process_mathwriting_fallback(data_dir):
    """Build train/val/test splits when the standard processing found nothing.

    Strategy:

    1. Collect every ``.png``/``.jpg``/``.jpeg`` file below ``data_dir``,
       ignoring the ``processed`` output directory so that a re-run does not
       try to copy an image onto itself.
    2. If no image exists, download and unpack the im2latex-100k archive into
       the current working directory and use it instead.
    3. Otherwise look for an annotation next to each image (``.txt``, ``.tex``,
       ``.latex``, ``.mml``), then in ``<data_dir>/annotations``, ``formulas``
       or ``labels`` (``.txt`` only). Images without an annotation are kept
       with an empty formula.

    The records are shuffled with the global NumPy random state, split
    80% / 10% / 10% and written to ``<data_dir>/{train,val,test}.json``.

    NOTE(review): the im2latex branch assumes the archive unpacks to
    ``formula_images_processed/``, ``im2latex_formulas.norm.lst`` and
    ``im2latex_train.lst`` with ``<image id> <formula index>`` lines — confirm
    against the actual archive contents.

    Args:
        data_dir: Root directory that holds the extracted dataset.

    Returns:
        dict with the keys ``train``, ``val``, ``test`` (image paths) and
        ``train_formulas``, ``val_formulas``, ``test_formulas`` (formulas
        aligned with the corresponding path list).

    Raises:
        subprocess.CalledProcessError: If the substitute download or its
            extraction fails.
        OSError: If a file cannot be read, copied or written.
    """
    import json
    import shutil
    import subprocess

    print("Using fallback method to process dataset...")

    # Create output directories
    processed_root = os.path.join(data_dir, "processed")
    processed_images_dir = os.path.join(processed_root, "images")
    os.makedirs(processed_images_dir, exist_ok=True)

    # One walk over the tree; our own output directory is pruned
    skip_dir = os.path.abspath(processed_root)
    image_files = []
    for root, dirs, files in os.walk(data_dir):
        dirs[:] = [
            d for d in dirs
            if os.path.abspath(os.path.join(root, d)) != skip_dir
        ]
        image_files.extend(
            os.path.join(root, file)
            for file in files
            if file.endswith(('.png', '.jpg', '.jpeg'))
        )

    print(f"Found {len(image_files)} image files")

    annotations = []
    if not image_files:
        print("No image files found. Downloading im2latex-100k dataset as a substitute...")
        # Plain subprocess calls instead of IPython shell magics: the function
        # is valid Python and a failed command raises instead of passing
        # silently.
        archive_name = "im2latex-100k.tgz"
        if not os.path.exists(archive_name):
            subprocess.run(
                ["wget", "https://zenodo.org/record/56198/files/im2latex-100k.tgz"],
                check=True,
            )
        subprocess.run(["tar", "xzf", archive_name], check=True)

        # Same effect as ``mv formula_images_processed/* <processed images dir>/``
        for name in os.listdir("formula_images_processed"):
            if name.startswith("."):
                continue  # the shell glob ``*`` skips hidden entries as well
            shutil.move(
                os.path.join("formula_images_processed", name),
                os.path.join(processed_images_dir, name),
            )

        # Process annotations
        with open('im2latex_formulas.norm.lst', 'r') as f:
            formulas = f.readlines()

        with open('im2latex_train.lst', 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 2:
                    image_id = parts[0]
                    formula_id = int(parts[1])
                    annotations.append({
                        'image_path': os.path.join(processed_images_dir, f"{image_id}.png"),
                        'formula': formulas[formula_id].strip()
                    })
    else:
        # Try to match images with annotations
        for img_path in image_files:
            img_dir = os.path.dirname(img_path)
            base_name = os.path.splitext(os.path.basename(img_path))[0]

            # Candidates in priority order: next to the image first, then the
            # dedicated annotation directories below data_dir.
            candidates = [
                os.path.join(img_dir, base_name + ext)
                for ext in ['.txt', '.tex', '.latex', '.mml']
            ]
            candidates.extend(
                os.path.join(data_dir, ann_dir, base_name + '.txt')
                for ann_dir in ['annotations', 'formulas', 'labels']
            )

            annotation_content = None
            for ann_path in candidates:
                if os.path.exists(ann_path):
                    with open(ann_path, 'r', encoding='utf-8', errors='ignore') as f:
                        annotation_content = f.read().strip()
                    break

            # Use empty string if still no annotation found
            if annotation_content is None:
                print(f"No annotation found for {img_path}, using empty string")
                annotation_content = ""

            # Copy image to processed directory (never onto itself)
            dest_path = os.path.join(processed_images_dir, os.path.basename(img_path))
            if os.path.abspath(img_path) != os.path.abspath(dest_path):
                shutil.copy(img_path, dest_path)

            annotations.append({
                'image_path': dest_path,
                'formula': annotation_content
            })

    # Create train/val/test splits
    np.random.shuffle(annotations)
    n = len(annotations)
    train_idx = int(0.8 * n)
    val_idx = int(0.9 * n)

    split_records = {
        'train': annotations[:train_idx],
        'val': annotations[train_idx:val_idx],
        'test': annotations[val_idx:],
    }

    # Save splits
    for split_name, records in split_records.items():
        with open(os.path.join(data_dir, f'{split_name}.json'), 'w') as f:
            json.dump(records, f)

    # Extract paths and formulas
    result = {}
    for split_name, records in split_records.items():
        result[split_name] = [item['image_path'] for item in records]
        result[f'{split_name}_formulas'] = [item['formula'] for item in records]

    print(f"Data splits created: Train {len(result['train'])} / Val {len(result['val'])} / Test {len(result['test'])}")

    return result

# Try the fallback method if needed
# The fallback runs when the previous cell never assigned `train_images` or
# left it empty.
if 'train_images' not in locals() or len(train_images) == 0:
    try:
        print("Using fallback processing method...")
        splits = process_mathwriting_fallback(data_dir)
        train_images, train_formulas = splits['train'], splits['train_formulas']
        val_images, val_formulas = splits['val'], splits['val_formulas']
        test_images, test_formulas = splits['test'], splits['test_formulas']
    except Exception as e:
        print(f"Fallback processing also failed: {e}")
        print("Please check the dataset structure manually.")

# Print statistics
# NOTE(review): these lines read the split variables unconditionally, so they
# raise NameError if neither processing path ever assigned them.
print("\nDataset Statistics:")
print(f"Total samples: {len(train_images) + len(val_images) + len(test_images)}")
print(f"Training samples: {len(train_images)}")
print(f"Validation samples: {len(val_images)}")
print(f"Test samples: {len(test_images)}")

# Show sample formulas
# Only the first three training formulas are shown, each cut to 100 characters.
print("\nSample LaTeX formulas:")
for i, formula in enumerate(train_formulas[:3]):
    print(f"Sample {i+1}: {formula[:100]}{'...' if len(formula) > 100 else ''}")

## 3. Load TrOCR Model

Load the pre-trained TrOCR_Math_handwritten model from Hugging Face.

In [None]:
# Hugging Face checkpoint that fine-tuning starts from
model_name = "fhswf/TrOCR_Math_handwritten"

# The processor prepares images and exposes the tokenizer used for the labels
processor = TrOCRProcessor.from_pretrained(model_name)

# Load the weights and place the model on the selected device in one step
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)

# Short summary of the loaded architecture
print(f"Model architecture: {type(model).__name__}")
print(f"Encoder: {type(model.encoder).__name__}")
print(f"Decoder: {type(model.decoder).__name__}")

## 4. Define Data Preprocessing

Create functions to preprocess the handwritten math images and their corresponding LaTeX labels.

In [None]:
# Define max dimensions to ensure consistent sizes
max_width = 384
max_height = 384

def preprocess_image(image_path):
    """Load an image as RGB and shrink it to fit ``max_width`` x ``max_height``.

    Images that already fit are returned at their original size. Larger images
    are scaled down with LANCZOS resampling while keeping the aspect ratio.

    Args:
        image_path: Path of the image file to load.

    Returns:
        A ``PIL.Image.Image`` in RGB mode. If the file cannot be opened or
        processed, the error is printed and a blank white
        ``max_width`` x ``max_height`` image is returned instead, so a single
        bad file does not stop a run.
    """
    try:
        # ``convert`` yields an independent in-memory image, so the file can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        # Resize while preserving aspect ratio
        width, height = image.size
        if width > max_width or height > max_height:
            ratio = min(max_width / width, max_height / height)
            # Clamp to one pixel: for extreme aspect ratios the shorter side
            # would otherwise round down to 0, which ``resize`` rejects.
            new_width = max(1, int(width * ratio))
            new_height = max(1, int(height * ratio))
            image = image.resize((new_width, new_height), Image.LANCZOS)

        return image
    except Exception as e:
        # Deliberate best effort: report the problem and hand back a placeholder
        print(f"Error processing image {image_path}: {e}")
        # Return a blank image as fallback
        return Image.new("RGB", (max_width, max_height), color="white")

def normalize_latex(formula):
    """Return ``formula`` with whitespace normalised.

    Leading and trailing whitespace is removed and every internal run of
    whitespace (spaces, tabs, newlines) becomes a single space.
    """
    # ``split()`` without arguments already discards surrounding whitespace
    tokens = formula.split()
    return " ".join(tokens)

## 5. Create Dataset and DataLoader

In [None]:
# Load the processed data
import json

def load_dataset_split(split_file, base_dir):
    """Load one split file and return the usable image paths and formulas.

    The JSON file is expected to hold a list of records with the keys
    ``image_path`` and ``formula``. Formulas are passed through
    ``normalize_latex``. Records whose image cannot be found are skipped with
    a warning.

    Path resolution for a relative ``image_path``: it is used as stored when it
    exists relative to the current working directory, and is joined to
    ``base_dir`` only otherwise. The split files written by the processing
    cells store paths that already start with the data directory, so joining
    them to ``base_dir`` unconditionally would duplicate that prefix and drop
    every record.

    Args:
        split_file: Path of the JSON split file.
        base_dir: Directory used to resolve relative image paths that do not
            exist as stored.

    Returns:
        Tuple ``(images, formulas)`` of two equally long lists.
    """
    with open(split_file, 'r') as f:
        data = json.load(f)

    images = []
    formulas = []

    for item in data:
        image_path = item['image_path']
        # Only fall back to base_dir when the stored path does not resolve
        if not os.path.isabs(image_path) and not os.path.exists(image_path):
            image_path = os.path.join(base_dir, image_path)

        if not os.path.exists(image_path):
            print(f"Warning: Image not found at {image_path}")
            continue

        images.append(image_path)
        formulas.append(normalize_latex(item['formula']))

    return images, formulas

# Reload the three splits from the JSON files in data_dir. On success this
# replaces the in-memory split lists from the processing cells with the loaded
# versions (formulas normalised, records with missing images dropped).
# If a file cannot be loaded, the error is printed and any split that was
# already reassigned keeps its new value.
try:
    # Try to load from JSON files created during dataset processing
    train_images, train_formulas = load_dataset_split(os.path.join(data_dir, 'train.json'), data_dir)
    val_images, val_formulas = load_dataset_split(os.path.join(data_dir, 'val.json'), data_dir)
    test_images, test_formulas = load_dataset_split(os.path.join(data_dir, 'test.json'), data_dir)
    
    print(f"Loaded {len(train_images)} training samples")
    print(f"Loaded {len(val_images)} validation samples")
    print(f"Loaded {len(test_images)} test samples")
    
    # Display a sample image path
    if train_images:
        print(f"Sample image path: {train_images[0]}")
except Exception as e:
    print(f"Error loading dataset splits: {e}")

In [None]:
# Create a PyTorch dataset
class MathDataset(Dataset):
    """Map-style dataset of (image, LaTeX formula) pairs.

    Each item is a dict with ``pixel_values`` (the processor output for the
    image, batch dimension removed) and ``labels`` (token ids of the formula,
    truncated or padded to 512 entries). Padding positions carry the
    tokenizer's own padding id; no loss masking is applied here.
    """

    def __init__(self, image_paths, formulas, processor):
        # The three arguments are stored unchanged; nothing is loaded eagerly
        self.processor = processor
        self.formulas = formulas
        self.image_paths = image_paths

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, preprocess and encode the sample at position ``idx``."""
        picture = preprocess_image(self.image_paths[idx])
        encoded_image = self.processor(picture, return_tensors="pt")

        encoded_text = self.processor.tokenizer(
            self.formulas[idx],
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )

        # Both encoders return a leading batch dimension of one; drop it
        return {
            "pixel_values": encoded_image.pixel_values.squeeze(0),
            "labels": encoded_text.input_ids.squeeze(0),
        }

# Build one MathDataset per split; all three share the same processor
train_dataset, val_dataset, test_dataset = (
    MathDataset(paths, texts, processor)
    for paths, texts in (
        (train_images, train_formulas),
        (val_images, val_formulas),
        (test_images, test_formulas),
    )
)

# Smoke test: fetch the first training item and report the tensor shapes
sample = train_dataset[0]
print(f"Pixel values shape: {sample['pixel_values'].shape}")
print(f"Labels shape: {sample['labels'].shape}")

# Round-trip the label ids through the tokenizer to compare with the source text
decoded_text = processor.tokenizer.decode(sample['labels'], skip_special_tokens=True)
print(f"Original formula: {train_formulas[0]}")
print(f"Decoded formula: {decoded_text}")

## 6. Set Up Training Configuration

In [None]:
# Set up training configuration
# Evaluation runs once per epoch and uses generation (predict_with_generate),
# so compute_metrics receives generated token ids rather than logits.
# NOTE(review): recent transformers releases are reported to rename
# `evaluation_strategy` to `eval_strategy` — confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir="./trocr_math_finetuned",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),  # Use mixed precision training if available
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",  # Disable wandb/tensorboard reporting
)

# Define metrics for evaluation
# Character error rate and word error rate, both loaded through `evaluate`
cer_metric = evaluate.load("cer")
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    """Compute character and word error rate for one evaluation run.

    Args:
        pred: Prediction object passed in by the trainer; ``pred.predictions``
            holds the generated token ids and ``pred.label_ids`` the reference
            token ids.

    Returns:
        dict with the keys ``"cer"`` and ``"wer"``.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Some configurations return extra outputs; the token ids come first
        pred_ids = pred_ids[0]

    # -100 marks ignored positions and cannot be decoded. It is replaced in the
    # labels and, defensively, in the predictions too (presumably the trainer
    # can pad gathered predictions with it — confirm). ``np.where`` builds new
    # arrays, so the arrays owned by the trainer are left untouched.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_token_id)
    labels_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_token_id)

    # Decode predictions and references
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    # Compute metrics
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {
        "cer": cer,
        "wer": wer,
    }

## 7. Fine-tune the Model

In [None]:
# Collate function that hides padding from the loss
def collate_with_masked_padding(features):
    """Stack dataset items into a batch and mask label padding with -100.

    The dataset pads every label sequence to a fixed length with the
    tokenizer's padding id. Left as is, the loss would be dominated by those
    padding positions, so they are replaced with -100, the value the loss
    ignores.

    NOTE(review): this assumes the padding id is not also used for a real
    token such as end-of-sequence — confirm for this tokenizer.
    """
    batch = default_data_collator(features)
    pad_token_id = processor.tokenizer.pad_token_id
    if pad_token_id is not None:
        labels = batch["labels"]
        batch["labels"] = labels.masked_fill(labels == pad_token_id, -100)
    return batch

# Initialize trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor.tokenizer,
    data_collator=collate_with_masked_padding,
    compute_metrics=compute_metrics,
)

# Train the model
print("Starting training...")
trainer.train()

## 8. Evaluate Model Performance

In [None]:
# Score the held-out test split with the trained model
print("Evaluating on test set...")
results = trainer.evaluate(test_dataset)

# Report every returned metric with four decimals
print("Test results:")
for metric_name, metric_value in results.items():
    print(f"{metric_name}: {metric_value:.4f}")

In [None]:
# Function to visualize predictions
def visualize_predictions(model, processor, dataset, num_samples=5):
    """Plot randomly chosen samples with ground truth and model prediction.

    Args:
        model: Model whose ``generate`` method produces the prediction.
        processor: Processor whose tokenizer decodes token ids to text.
        dataset: Dataset that supports ``len()`` and indexing, yields dicts
            with ``pixel_values`` and ``labels``, and exposes ``image_paths``.
        num_samples: Number of samples to draw; capped at the dataset size.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # Sampling without replacement fails when more samples are requested than
    # exist, so cap the request and bail out early on an empty dataset.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("No samples available to visualize.")
        return

    model.eval()
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    pad_token_id = processor.tokenizer.pad_token_id

    plt.figure(figsize=(15, 4 * num_samples))

    for i, idx in enumerate(indices):
        idx = int(idx)  # plain int so list-backed datasets index cleanly

        # Get sample
        sample = dataset[idx]
        pixel_values = sample["pixel_values"].unsqueeze(0).to(device)

        # Generate prediction
        with torch.no_grad():
            generated_ids = model.generate(
                pixel_values,
                max_length=512,
                num_beams=4,
                early_stopping=True
            )

        # Decode prediction
        pred_str = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Get ground truth; -100 (ignored positions) cannot be decoded, so it
        # is mapped back to the padding id first.
        label_ids = sample["labels"]
        if pad_token_id is not None:
            label_ids = label_ids.masked_fill(label_ids == -100, pad_token_id)
        gt_str = processor.tokenizer.decode(label_ids, skip_special_tokens=True)

        # Reload image for display and close the file right away
        image_path = dataset.image_paths[idx]
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        # Display
        plt.subplot(num_samples, 1, i+1)
        plt.imshow(image)
        plt.title(f"Ground Truth: {gt_str}\nPrediction: {pred_str}", fontsize=12)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

# Visualize some predictions
# Three randomly chosen test samples are plotted with ground truth and prediction.
visualize_predictions(model, processor, test_dataset, num_samples=3)

## 9. Save the Fine-tuned Model

In [None]:
# Write the model first, then the processor, into the same directory
output_dir = "./trocr_math_finetuned_final"
for artifact in (model, processor):
    artifact.save_pretrained(output_dir)
print(f"Model and processor saved to {output_dir}")

In [None]:
# Push the model to HuggingFace Hub (optional)
# You need to be logged in to the HuggingFace Hub
# !pip install -q huggingface_hub
# from huggingface_hub import notebook_login
# notebook_login()

# # Define your HuggingFace Hub repository
# hf_repo_name = "your-username/trocr-math-finetuned"

# # Upload model and processor
# model.push_to_hub(hf_repo_name)
# processor.push_to_hub(hf_repo_name)
# print(f"Model pushed to {hf_repo_name}")

## 10. Test the Fine-tuned Model

In [None]:
# Reload processor and model from the directory written by the save step
finetuned_model_path = "./trocr_math_finetuned_final"
finetuned_processor = TrOCRProcessor.from_pretrained(finetuned_model_path)
finetuned_model = VisionEncoderDecoderModel.from_pretrained(finetuned_model_path)
finetuned_model = finetuned_model.to(device)

# Test on a custom image (you can provide a path to any handwritten math expression image)
def predict_from_image(image_path, processor, model):
    """Predict the LaTeX string for one image and display the result.

    Args:
        image_path: Path of the image to recognise.
        processor: Processor used to encode the image and decode the output.
        model: Model whose ``generate`` method produces the token ids.

    Returns:
        The decoded prediction string.
    """
    # Beam-search settings passed to ``generate``
    generation_options = {"max_length": 512, "num_beams": 4, "early_stopping": True}

    # Encode the image and move it to the active device
    image = preprocess_image(image_path)
    model_inputs = processor(image, return_tensors="pt")
    pixel_values = model_inputs.pixel_values.to(device)

    # Inference only, so gradient tracking is switched off
    with torch.no_grad():
        generated_ids = model.generate(pixel_values, **generation_options)

    pred_str = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Show the preprocessed image with the prediction as its title
    plt.figure(figsize=(10, 4))
    plt.imshow(image)
    plt.title(f"Predicted: {pred_str}", fontsize=12)
    plt.axis("off")
    plt.show()

    return pred_str

# Run the fine-tuned model on the first test image, if there is one
if test_images:
    sample_image_path = test_images[0]
    prediction = predict_from_image(sample_image_path, finetuned_processor, finetuned_model)
    print(f"Predicted LaTeX: {prediction}")

## 11. Comparison with Original Model

Compare the fine-tuned model with the original pre-trained model.

In [None]:
# Fresh copy of the published checkpoint, used as the comparison baseline
original_checkpoint = "fhswf/TrOCR_Math_handwritten"
original_processor = TrOCRProcessor.from_pretrained(original_checkpoint)
original_model = VisionEncoderDecoderModel.from_pretrained(original_checkpoint)
original_model = original_model.to(device)

# Function to compare models
def compare_models(image_path, original_model, finetuned_model, processor):
    """Run both models on one image and show the two predictions.

    Args:
        image_path: Path of the image to recognise.
        original_model: Baseline model.
        finetuned_model: Fine-tuned model.
        processor: Processor used to encode the image and to decode the output
            of both models.

    Returns:
        Tuple ``(original_str, finetuned_str)`` with the decoded predictions.
    """
    # Encode the image once; both models receive the same tensor
    image = preprocess_image(image_path)
    encoded = processor(image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    # Baseline first, fine-tuned second, with identical beam-search settings
    with torch.no_grad():
        generated = [
            candidate.generate(
                pixel_values,
                max_length=512,
                num_beams=4,
                early_stopping=True
            )
            for candidate in (original_model, finetuned_model)
        ]

    original_str, finetuned_str = (
        processor.tokenizer.decode(token_ids[0], skip_special_tokens=True)
        for token_ids in generated
    )

    # Show the image with both predictions in the title
    plt.figure(figsize=(10, 4))
    plt.imshow(image)
    plt.title(f"Original: {original_str}\nFine-tuned: {finetuned_str}", fontsize=12)
    plt.axis("off")
    plt.show()

    return original_str, finetuned_str

# Compare both models on the first three test images (skipped if fewer exist)
if len(test_images) >= 3:
    for i, sample_image_path in enumerate(test_images[:3]):
        original, finetuned = compare_models(sample_image_path, original_model, finetuned_model, processor)
        print(f"Sample {i+1}:")
        print(f"Original model: {original}")
        print(f"Fine-tuned model: {finetuned}")
        print("-" * 50)

## 12. Integration with Your Existing Code

Here's how to integrate the fine-tuned model with your existing `multimodal_ocr.py` code.

In [None]:
# Example integration code
'''
# Add this to your MODEL_OPTIONS in multimodal_ocr.py
MODEL_OPTIONS = {
    "Nanonets": "nanonets/Nanonets-OCR-s",
    "PrithivMLmods" : "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
    "Custom_TrOCR_Math": "./trocr_math_finetuned_final"  # Path to your fine-tuned model
}

# Add this to your model loading code
elif name == "Custom_TrOCR_Math":
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel
    models[name] = VisionEncoderDecoderModel.from_pretrained(
        model_id,
        torch_dtype=torch.float16
    ).to("cuda").eval()
    processors[name] = TrOCRProcessor.from_pretrained(model_id)
'''