# Handwritten LaTeX OCR Training

Train the unified text spotting model on Google Colab with H100/A100.

In [1]:
!nvidia-smi

Thu Jan 22 14:50:56 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   34C    P0             56W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!git clone https://github.com/markm39/MobileTeXOCR.git
%cd MobileTeXOCR

fatal: destination path 'MobileTeXOCR' already exists and is not an empty directory.
/content/MobileTeXOCR


In [4]:
!cd /content/MobileTeXOCR && git pull


remote: Enumerating objects: 13, done.[K
remote: Counting objects:   7% (1/13)[Kremote: Counting objects:  15% (2/13)[Kremote: Counting objects:  23% (3/13)[Kremote: Counting objects:  30% (4/13)[Kremote: Counting objects:  38% (5/13)[Kremote: Counting objects:  46% (6/13)[Kremote: Counting objects:  53% (7/13)[Kremote: Counting objects:  61% (8/13)[Kremote: Counting objects:  69% (9/13)[Kremote: Counting objects:  76% (10/13)[Kremote: Counting objects:  84% (11/13)[Kremote: Counting objects:  92% (12/13)[Kremote: Counting objects: 100% (13/13)[Kremote: Counting objects: 100% (13/13), done.[K
remote: Compressing objects:  50% (1/2)[Kremote: Compressing objects: 100% (2/2)[Kremote: Compressing objects: 100% (2/2), done.[K
remote: Total 8 (delta 5), reused 8 (delta 5), pack-reused 0 (from 0)[K
Unpacking objects:  12% (1/8)Unpacking objects:  25% (2/8)Unpacking objects:  37% (3/8)Unpacking objects:  50% (4/8)Unpacking objects:  62% (5/8)Unpacking obj

In [5]:
!pip install -q torch torchvision pillow numpy pyyaml

In [6]:
import sys
sys.path.insert(0, '/content/MobileTeXOCR')

## Dataset Setup

Choose ONE option below:
- **Option A**: Create dummy data (for testing pipeline)
- **Option B**: Download real datasets (for actual training)

### Option A: Create Dummy Data (for testing)

In [10]:
# Create dummy dataset for testing the pipeline
import os
import json
from PIL import Image, ImageDraw, ImageFont

def create_dummy_dataset(base_dir, num_train=100, num_val=20):
    """Create dummy handwritten math images for testing."""

    expressions = [
        'x^2', 'y^2', 'x+y', 'a-b', '\\frac{1}{2}', '\\sqrt{x}',
        'x^2+y^2', 'a^2-b^2', '\\alpha', '\\beta', '\\gamma',
        '\\sum_{i=1}^{n}', '\\int_0^1', 'e^x', '\\pi r^2',
        '\\frac{a}{b}', 'x_1', 'y_2', 'z^n', '\\theta'
    ]

    for split, num_samples in [('train', num_train), ('val', num_val)]:
        img_dir = os.path.join(base_dir, 'hme100k', split, 'images')
        os.makedirs(img_dir, exist_ok=True)

        labels = {}
        for i in range(num_samples):
            # Create white image
            img = Image.new('RGB', (384, 384), 'white')
            draw = ImageDraw.Draw(img)

            # Draw expression (simplified rendering)
            expr = expressions[i % len(expressions)]
            # Draw some random strokes to simulate handwriting
            import random
            random.seed(i)
            x_start = random.randint(50, 150)
            y_start = random.randint(150, 200)

            # Simple text (in real data this would be actual handwriting)
            try:
                font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', 40)
            except:
                font = ImageFont.load_default()

            # Draw display version
            display_text = expr.replace('\\', '').replace('{', '').replace('}', '').replace('_', '').replace('^', '')
            draw.text((x_start, y_start), display_text, fill='black', font=font)

            # Save image
            img_name = f'sample_{i:04d}.png'
            img.save(os.path.join(img_dir, img_name))
            labels[img_name] = expr

        # Save labels
        labels_file = os.path.join(base_dir, 'hme100k', split, 'labels.json')
        with open(labels_file, 'w') as f:
            json.dump(labels, f, indent=2)

        print(f'Created {num_samples} {split} samples in {img_dir}')

# Create dummy data
create_dummy_dataset('./data', num_train=100, num_val=20)
print('Dummy dataset created!')

Created 100 train samples in ./data/hme100k/train/images
Created 20 val samples in ./data/hme100k/val/images
Dummy dataset created!


### Option B: Download Real Datasets (for actual training)

Run ONE or more of the cells below to download real data. MathWriting is recommended as the primary dataset.

In [7]:
# Download MathWriting dataset (230K human + 400K synthetic samples, 2.9GB)
# This is the largest handwritten math expression dataset
!mkdir -p data/mathwriting
!wget -q --show-progress https://storage.googleapis.com/mathwriting_data/mathwriting-2024.tgz -O mathwriting.tgz
!tar -xzf mathwriting.tgz -C data/
!rm mathwriting.tgz

# Check structure and reorganize if needed
import os
import shutil

# The tarball extracts to mathwriting-2024/, we need mathwriting/
if os.path.exists('data/mathwriting-2024') and not os.path.exists('data/mathwriting/train'):
    # Move contents
    for item in os.listdir('data/mathwriting-2024'):
        src = f'data/mathwriting-2024/{item}'
        dst = f'data/mathwriting/{item}'
        if os.path.exists(dst):
            shutil.rmtree(dst) if os.path.isdir(dst) else os.remove(dst)
        shutil.move(src, dst)
    os.rmdir('data/mathwriting-2024')

print('MathWriting directory structure:')
!ls -la data/mathwriting/ | head -20

MathWriting directory structure:
total 620228
drwxr-xr-x 7 root   root       4096 Jan 22 14:09 .
drwxr-xr-x 5 root   root       4096 Jan 22 14:53 ..
-rw-r----- 1 218859 89939      7780 Jan 31  2024 readme.md
drwxr-x--- 2 218859 89939    299008 Jan 31  2024 symbols
-rw-r----- 1 218859 89939    523063 Jan 31  2024 symbols.jsonl
drwxr-x--- 2 218859 89939  18567168 Jan 31  2024 synthetic
-rw-r----- 1 218859 89939 604019016 Jan 31  2024 synthetic-bboxes.jsonl
drwxr-x--- 2 218859 89939    380928 Jan 31  2024 test
drwxr-x--- 2 218859 89939  10543104 Jan 31  2024 train
drwxr-x--- 2 218859 89939    741376 Jan 31  2024 valid


In [12]:
# # Download CROHME from Kaggle (requires Kaggle API key)
# !pip install kaggle
# !mkdir -p ~/.kaggle
# # Upload your kaggle.json or set credentials
# !kaggle datasets download -d xainano/handwrittenmathsymbols
# !unzip -q handwrittenmathsymbols.zip -d data/crohme/

In [13]:
# # Alternative: Download from HuggingFace (if available)
# !pip install huggingface_hub
# from huggingface_hub import snapshot_download
# snapshot_download(repo_id="ybelkada/im2latex-100k", local_dir="./data/hme100k", repo_type="dataset")

## Model Setup

In [8]:
import torch
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

Using device: cuda
GPU: NVIDIA A100-SXM4-80GB
Memory: 85.2 GB


In [9]:
from models import HandwrittenLaTeXOCR, ModelConfig
from models.decoder.tokenizer import LaTeXTokenizer
from data import DatasetConfig, CombinedDataset, get_train_transforms, get_eval_transforms
from training import Trainer, TrainingConfig

tokenizer = LaTeXTokenizer()
print(f'Tokenizer vocab size: {tokenizer.vocab_size}')

Tokenizer vocab size: 1294


In [10]:
# Model configuration
# Use 'small' for testing, 'base' for real training
ENCODER = 'fastvithd'
ENCODER_SIZE = 'base'  # Change to 'base' for full training

model_config = ModelConfig(
    encoder_type=ENCODER,
    encoder_size=ENCODER_SIZE,
    image_size=384,
    d_model=256 if ENCODER_SIZE == 'small' else 384,
    num_decoder_layers=4 if ENCODER_SIZE == 'small' else 6,
    freeze_encoder=True,
)

model = HandwrittenLaTeXOCR(model_config)
print(f'Model parameters: {model.count_parameters():,}')

Model parameters: 15,195,406


In [11]:
# Dataset configuration
dataset_config = DatasetConfig(data_dir='./data', image_size=384)

train_transform = get_train_transforms(image_size=384, augment_strength='medium')
valid_transform = get_eval_transforms(image_size=384)

# Check available datasets
import os
available_datasets = []
for ds in ['mathwriting', 'crohme', 'hme100k']:
    if os.path.exists(f'./data/{ds}'):
        available_datasets.append(ds)

print(f'Available datasets: {available_datasets}')

if not available_datasets:
    raise RuntimeError('No datasets found! Run the dataset setup cells above first.')

train_dataset = CombinedDataset(
    dataset_config, split='train', transform=train_transform,
    tokenizer=tokenizer, datasets=available_datasets
)
val_dataset = CombinedDataset(
    dataset_config, split='val', transform=valid_transform,
    tokenizer=tokenizer, datasets=available_datasets
)

print(f'Train samples: {len(train_dataset)}')
print(f'Val samples: {len(val_dataset)}')

Available datasets: ['mathwriting']
Train samples: 229864
Val samples: 15674


In [14]:
# Training configuration
if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory
    BATCH_SIZE = 96 if gpu_memory > 70e9 else (32 if gpu_memory > 40e9 else 16)
else:
    BATCH_SIZE = 4

training_config = TrainingConfig(
    output_dir='/content/drive/MyDrive/latex_ocr_outputs',
    experiment_name=f'latex_ocr_{ENCODER}_{ENCODER_SIZE}',
    num_epochs=5 if len(train_dataset) < 1000 else 20,  # Fewer epochs for dummy data
    batch_size=BATCH_SIZE,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=100 if len(train_dataset) < 1000 else 2000,
    gradient_accumulation_steps=2,
    use_amp=True,
    amp_dtype='bfloat16' if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else 'float16',
    save_steps=500,
    validation_steps=100 if len(train_dataset) < 1000 else 1000,
    log_steps=10 if len(train_dataset) < 1000 else 100,
    freeze_encoder_epochs=1,
    early_stopping_patience=5,
)

print(f'Batch size: {BATCH_SIZE}')
print(f'Epochs: {training_config.num_epochs}')

Batch size: 96
Epochs: 20


In [15]:
# Create dataloaders
train_loader = train_dataset.get_dataloader(
    batch_size=BATCH_SIZE,
    num_workers=2,
    use_weighted_sampling=len(train_dataset) > 0
)
val_loader = val_dataset.get_dataloader(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    use_weighted_sampling=False
)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

Train batches: 2395
Val batches: 164


## Training

In [20]:
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=training_config,
    tokenizer=tokenizer
)

# Resume from latest checkpoint
import glob
checkpoints = glob.glob('/content/drive/MyDrive/latex_ocr_outputs/latex_ocr_fastvithd_base/checkpoints/step_*.pt')
if checkpoints:
    latest = max(checkpoints, key=lambda x:
int(x.split('_')[-1].split('.')[0]))
    print(f'Resuming from {latest}')
    trainer.load_checkpoint(latest)

print('Starting training...')
best_metric = trainer.train()
print(f'Training complete! Best metric: {best_metric:.4f}')

Resuming from /content/drive/MyDrive/latex_ocr_outputs/latex_ocr_fastvithd_base/checkpoints/step_15500.pt
Loaded checkpoint: epoch 5, step 15500, best_metric 0.0022
Starting training...
Starting training: epochs 5-19, 2396 batches/epoch, step 15500
Step 15600: loss=2.2788, lr=9.62e-05
Step 15700: loss=2.1537, lr=9.61e-05
Step 15800: loss=2.2121, lr=9.60e-05
Step 15900: loss=2.0953, lr=9.60e-05
Step 16000: loss=2.0737, lr=9.59e-05
Step 16100: loss=2.2405, lr=9.58e-05
Step 16200: loss=2.2173, lr=9.58e-05
Step 16300: loss=2.0796, lr=9.57e-05
Step 16400: loss=2.1200, lr=9.56e-05
Step 16500: loss=2.2736, lr=9.55e-05
Step 16600: loss=2.1988, lr=9.55e-05
Step 16700: loss=2.0719, lr=9.54e-05
Step 16800: loss=2.1496, lr=9.53e-05
Step 16900: loss=2.2009, lr=9.53e-05
Step 17000: loss=2.1744, lr=9.52e-05
Step 17100: loss=2.2083, lr=9.51e-05
Step 17200: loss=2.2555, lr=9.50e-05
Step 17300: loss=2.1294, lr=9.50e-05
Step 17400: loss=2.1875, lr=9.49e-05
Step 17500: loss=2.2445, lr=9.48e-05
Step 17600:

In [21]:
# Save final model
save_path = training_config.output_dir + '/final_model'
model.save_pretrained(save_path)
print(f'Saved model to {save_path}')

Saved model to /content/drive/MyDrive/latex_ocr_outputs/final_model


## Test Inference

In [22]:
# Test on a sample
model.eval()
with torch.no_grad():
    # Get a sample from validation set
    sample = val_dataset[0]
    img = sample.image.unsqueeze(0).to(device)

    output = model(img)

    print(f'Ground truth: {sample.latex}')
    if output.predictions and output.predictions[0]:
        pred_latex = output.predictions[0][0][1] if output.predictions[0][0] else ''
        print(f'Predicted: {pred_latex}')
    else:
        print('No prediction generated')

Ground truth: r=g^{e}y(\mod p)
Predicted: m(x)=\int_{1}^{x}\frac{1}{t}dt


# Compare Greedy and Beam Search

In [None]:
from tqdm import tqdm
from training.metrics import compute_metrics

model.eval()
device = torch.device('cuda')
model = model.to(device)

def run_validation(model, val_loader, use_beam=False, beam_size=5):
    all_predictions = []
    all_targets = []

    for batch in tqdm(val_loader, desc=f"{'Beam' if use_beam else 'Greedy'} eval"):
        images = batch['images'].to(device)
        targets = batch['latex']

        with torch.no_grad():
            if use_beam:
                output = model.beam_search(images, beam_size=beam_size)
            else:
                output = model(images)  # greedy

        for pred_regions in output.predictions:
            if pred_regions:
                pred_latex = pred_regions[0][1] if pred_regions[0] else ""
            else:
                pred_latex = ""
            all_predictions.append(pred_latex)

        all_targets.extend(targets)

    return compute_metrics(all_predictions, all_targets)

  # Run both
print("Running greedy validation...")
greedy_metrics = run_validation(model, val_loader, use_beam=False)
print(f"Greedy: {greedy_metrics}")

print("\nRunning beam search validation (k=5)...")
beam_metrics = run_validation(model, val_loader, use_beam=True, beam_size=5)
print(f"Beam:   {beam_metrics}")

  # Compare
print("\n=== Comparison ===")
for key in greedy_metrics:
    diff = beam_metrics[key] - greedy_metrics[key]
    print(f"{key}: {greedy_metrics[key]:.4f} -> {beam_metrics[key]:.4f} ({'+'
if diff > 0 else ''}{diff:.4f})")

Running greedy validation...


Greedy eval: 100%|██████████| 164/164 [25:32<00:00,  9.34s/it]


Greedy: {'exp_rate': 0.0, 'symbol_accuracy': 0.021990507751278983, 'bleu': 0.0}

Running beam search validation (k=5)...


Beam eval:   5%|▍         | 8/164 [25:31<8:17:21, 191.29s/it]