# Handwritten LaTeX OCR Training

Train the unified text spotting model on Google Colab with H100/A100.

In [1]:
# Show the attached GPU together with the driver and CUDA versions.
!nvidia-smi

Tue Jan 20 21:28:07 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA H100 80GB HBM3          Off |   00000000:04:00.0 Off |                    0 |
| N/A   36C    P0            114W /  700W |       0MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
from google.colab import drive
# Mount Google Drive; the training configuration further down writes its
# outputs under /content/drive/MyDrive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!git clone https://github.com/markm39/MobileTeXOCR.git
%cd MobileTeXOCR

fatal: destination path 'MobileTeXOCR' already exists and is not an empty directory.
/content/MobileTeXOCR


In [4]:
# Bring the existing checkout up to date with its remote.
!cd /content/MobileTeXOCR && git pull


Already up to date.


In [5]:
!pip install -q torch torchvision pillow numpy pyyaml

In [6]:
# Force a fresh clone of the project.
# WARNING: `rm -rf MobileTeXOCR` removes the whole checkout, including any
# ./data directory that the dataset cells below created inside it during a
# previous run (those cells use paths relative to /content/MobileTeXOCR).
%cd /content
!rm -rf MobileTeXOCR
!git clone https://github.com/markm39/MobileTeXOCR.git
%cd MobileTeXOCR

/content
Cloning into 'MobileTeXOCR'...
remote: Enumerating objects: 48108, done.[K
remote: Counting objects: 100% (206/206), done.[K
remote: Compressing objects: 100% (71/71), done.[K
remote: Total 48108 (delta 162), reused 144 (delta 135), pack-reused 47902 (from 3)[K
Receiving objects: 100% (48108/48108), 389.75 MiB | 62.27 MiB/s, done.
Resolving deltas: 100% (33526/33526), done.
/content/MobileTeXOCR


In [7]:
import sys

REPO_ROOT = '/content/MobileTeXOCR'

# Put the checkout first on the import path. Any earlier copies of the same
# entry are dropped, so re-running this cell does not stack duplicates.
sys.path[:] = [REPO_ROOT] + [entry for entry in sys.path if entry != REPO_ROOT]

## Dataset Setup

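Choose ONE option below:
- **Option A**: Create dummy data (for testing pipeline)
- **Option B**: Download real datasets (for actual training)

Do not run both: the dataset configuration cell loads every dataset directory it finds under `./data`, so dummy data left over from Option A would be mixed into real training.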

### Option A: Create Dummy Data (for testing)

In [8]:
# Create dummy dataset for testing the pipeline
import os
import json
from PIL import Image, ImageDraw, ImageFont

def create_dummy_dataset(base_dir, num_train=100, num_val=20):
    """Create dummy handwritten math images for testing.

    For each split ('train', 'val') this writes 384x384 white RGB PNGs to
    ``<base_dir>/hme100k/<split>/images/`` and a ``labels.json`` that maps
    every image file name to its LaTeX string to ``<base_dir>/hme100k/<split>/``.
    Expressions are taken round-robin from a fixed list; the text drawn on the
    image is the LaTeX string with backslashes, braces, ``_`` and ``^`` removed.

    Args:
        base_dir: Root directory; missing sub-directories are created.
        num_train: Number of samples written for the 'train' split.
        num_val: Number of samples written for the 'val' split.
    """
    # Function-scope import: done once per call instead of once per sample.
    import random

    expressions = [
        'x^2', 'y^2', 'x+y', 'a-b', '\\frac{1}{2}', '\\sqrt{x}',
        'x^2+y^2', 'a^2-b^2', '\\alpha', '\\beta', '\\gamma',
        '\\sum_{i=1}^{n}', '\\int_0^1', 'e^x', '\\pi r^2',
        '\\frac{a}{b}', 'x_1', 'y_2', 'z^n', '\\theta'
    ]

    # Characters removed from the LaTeX source to get the text that is drawn.
    strip_markup = str.maketrans('', '', '\\{}_^')

    # Load the font once instead of once per image. Fall back to Pillow's
    # built-in font when the DejaVu file cannot be read or FreeType support is
    # missing; other exceptions (e.g. KeyboardInterrupt) are no longer swallowed.
    try:
        font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', 40)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    for split, num_samples in [('train', num_train), ('val', num_val)]:
        split_dir = os.path.join(base_dir, 'hme100k', split)
        img_dir = os.path.join(split_dir, 'images')
        os.makedirs(img_dir, exist_ok=True)

        labels = {}
        for i in range(num_samples):
            # Create white image
            img = Image.new('RGB', (384, 384), 'white')
            draw = ImageDraw.Draw(img)

            expr = expressions[i % len(expressions)]

            # Private generator seeded with the sample index: yields the same
            # positions as seeding the global generator with i, but leaves the
            # caller's global random state untouched.
            rng = random.Random(i)
            x_start = rng.randint(50, 150)
            y_start = rng.randint(150, 200)

            # Simple text (in real data this would be actual handwriting)
            display_text = expr.translate(strip_markup)
            draw.text((x_start, y_start), display_text, fill='black', font=font)

            # Save image and remember its label
            img_name = f'sample_{i:04d}.png'
            img.save(os.path.join(img_dir, img_name))
            labels[img_name] = expr

        # Save labels
        labels_file = os.path.join(split_dir, 'labels.json')
        with open(labels_file, 'w') as f:
            json.dump(labels, f, indent=2)

        print(f'Created {num_samples} {split} samples in {img_dir}')

# Create dummy data: 100 train / 20 val images under ./data/hme100k/<split>/images,
# plus one labels.json per split.
create_dummy_dataset('./data', num_train=100, num_val=20)
print('Dummy dataset created!')

Created 100 train samples in ./data/hme100k/train/images
Created 20 val samples in ./data/hme100k/val/images
Dummy dataset created!


### Option B: Download Real Datasets (for actual training)

Run ONE or more of the cells below to download real data. MathWriting is recommended as the primary dataset.

In [9]:
# Download MathWriting dataset (230K human + 400K synthetic samples, 2.9GB)
# This is the largest handwritten math expression dataset
!mkdir -p data/mathwriting
!wget -q --show-progress https://storage.googleapis.com/mathwriting_data/mathwriting-2024.tgz -O mathwriting.tgz
!tar -xzf mathwriting.tgz -C data/
!rm mathwriting.tgz

# Check structure and reorganize if needed
import os
import shutil

# The tarball extracts to mathwriting-2024/, we need mathwriting/
if os.path.exists('data/mathwriting-2024') and not os.path.exists('data/mathwriting/train'):
    # Move contents
    for item in os.listdir('data/mathwriting-2024'):
        src = f'data/mathwriting-2024/{item}'
        dst = f'data/mathwriting/{item}'
        if os.path.exists(dst):
            shutil.rmtree(dst) if os.path.isdir(dst) else os.remove(dst)
        shutil.move(src, dst)
    os.rmdir('data/mathwriting-2024')

print('MathWriting directory structure:')
!ls -la data/mathwriting/ | head -20

MathWriting directory structure:
total 620228
drwxr-xr-x 7 root   root       4096 Jan 20 21:29 .
drwxr-xr-x 4 root   root       4096 Jan 20 21:29 ..
-rw-r----- 1 218859 89939      7780 Jan 31  2024 readme.md
drwxr-x--- 2 218859 89939    270336 Jan 31  2024 symbols
-rw-r----- 1 218859 89939    523063 Jan 31  2024 symbols.jsonl
drwxr-x--- 2 218859 89939  18432000 Jan 31  2024 synthetic
-rw-r----- 1 218859 89939 604019016 Jan 31  2024 synthetic-bboxes.jsonl
drwxr-x--- 2 218859 89939    339968 Jan 31  2024 test
drwxr-x--- 2 218859 89939  10747904 Jan 31  2024 train
drwxr-x--- 2 218859 89939    745472 Jan 31  2024 valid


In [10]:
# # Download CROHME from Kaggle (requires Kaggle API key)
# !pip install kaggle
# !mkdir -p ~/.kaggle
# # Upload your kaggle.json or set credentials
# !kaggle datasets download -d xainano/handwrittenmathsymbols
# !unzip -q handwrittenmathsymbols.zip -d data/crohme/

In [11]:
# # Alternative: Download from HuggingFace (if available)
# !pip install huggingface_hub
# from huggingface_hub import snapshot_download
# snapshot_download(repo_id="ybelkada/im2latex-100k", local_dir="./data/hme100k", repo_type="dataset")

## Model Setup

In [12]:
import torch
import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)


def _print_gpu_summary(index=0):
    """Print the name and total memory (in units of 1e9 bytes) of CUDA device ``index``."""
    print(f'GPU: {torch.cuda.get_device_name(index)}')
    total_bytes = torch.cuda.get_device_properties(index).total_memory
    print(f'Memory: {total_bytes / 1e9:.1f} GB')


# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    _print_gpu_summary()

Using device: cuda
GPU: NVIDIA H100 80GB HBM3
Memory: 85.2 GB


In [13]:
from models import HandwrittenLaTeXOCR, ModelConfig
from models.decoder.tokenizer import LaTeXTokenizer
from data import DatasetConfig, CombinedDataset, get_train_transforms, get_eval_transforms
from training import Trainer, TrainingConfig

# One tokenizer instance, passed to both datasets and to the Trainer further down.
tokenizer = LaTeXTokenizer()
print(f'Tokenizer vocab size: {tokenizer.vocab_size}')

Tokenizer vocab size: 1294


In [14]:
# Model configuration
# ENCODER_SIZE selects the preset: 'small' for quick pipeline tests, 'base' for real training.
ENCODER = 'fastvithd'
ENCODER_SIZE = 'base'  # Currently the full-training preset; set to 'small' for quick tests

model_config = ModelConfig(
    encoder_type=ENCODER,
    encoder_size=ENCODER_SIZE,
    image_size=384,
    # d_model / decoder layers follow the preset: 256 / 4 for 'small',
    # 384 / 6 for anything else.
    d_model=256 if ENCODER_SIZE == 'small' else 384,
    num_decoder_layers=4 if ENCODER_SIZE == 'small' else 6,
    freeze_encoder=True,
)

model = HandwrittenLaTeXOCR(model_config)
print(f'Model parameters: {model.count_parameters():,}')

Model parameters: 15,195,406


In [15]:
# Dataset configuration
dataset_config = DatasetConfig(data_dir='./data', image_size=384)

# Separate preprocessing pipelines for the training and the evaluation split.
train_transform = get_train_transforms(image_size=384, augment_strength='medium')
valid_transform = get_eval_transforms(image_size=384)

# Keep only the datasets whose directory exists under ./data
import os
available_datasets = [
    name
    for name in ('mathwriting', 'crohme', 'hme100k')
    if os.path.exists(f'./data/{name}')
]

print(f'Available datasets: {available_datasets}')

if len(available_datasets) == 0:
    raise RuntimeError('No datasets found! Run the dataset setup cells above first.')


def _build_split(split_name, split_transform):
    """Build the combined dataset for one split; only split and transform vary."""
    return CombinedDataset(
        dataset_config,
        split=split_name,
        transform=split_transform,
        tokenizer=tokenizer,
        datasets=available_datasets,
    )


train_dataset = _build_split('train', train_transform)
val_dataset = _build_split('val', valid_transform)

print(f'Train samples: {len(train_dataset)}')
print(f'Val samples: {len(val_dataset)}')

Available datasets: ['mathwriting', 'hme100k']
Train samples: 229964
Val samples: 15694


In [16]:
# Training configuration
# The batch size is chosen from the total GPU memory; CPU-only runs use a tiny batch.
if not torch.cuda.is_available():
    BATCH_SIZE = 4
else:
    gpu_memory = torch.cuda.get_device_properties(0).total_memory
    if gpu_memory > 70e9:
        BATCH_SIZE = 96
    elif gpu_memory > 40e9:
        BATCH_SIZE = 32
    else:
        BATCH_SIZE = 16

# Fewer than 1000 training samples (e.g. the dummy data) gets a shorter schedule
# with more frequent validation and logging.
if len(train_dataset) < 1000:
    schedule = dict(num_epochs=5, warmup_steps=100, validation_steps=100, log_steps=10)
else:
    schedule = dict(num_epochs=20, warmup_steps=2000, validation_steps=1000, log_steps=100)

# bfloat16 autocast is selected on GPUs with compute capability >= 8, float16 otherwise.
supports_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

training_config = TrainingConfig(
    output_dir='/content/drive/MyDrive/latex_ocr_outputs',
    experiment_name=f'latex_ocr_{ENCODER}_{ENCODER_SIZE}',
    batch_size=BATCH_SIZE,
    learning_rate=1e-4,
    weight_decay=0.01,
    gradient_accumulation_steps=2,
    use_amp=True,
    amp_dtype='bfloat16' if supports_bf16 else 'float16',
    save_steps=500,
    freeze_encoder_epochs=1,
    early_stopping_patience=5,
    **schedule,
)

print(f'Batch size: {BATCH_SIZE}')
print(f'Epochs: {training_config.num_epochs}')

Batch size: 96
Epochs: 20


In [17]:
# Create dataloaders
# Both loaders share batch size and worker count; only the sampling options differ.
loader_options = {'batch_size': BATCH_SIZE, 'num_workers': 2}

train_loader = train_dataset.get_dataloader(
    use_weighted_sampling=len(train_dataset) > 0,
    **loader_options,
)
val_loader = val_dataset.get_dataloader(
    shuffle=False,
    use_weighted_sampling=False,
    **loader_options,
)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

Train batches: 2396
Val batches: 164


## Training

In [None]:
import glob
import os

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=training_config,
    tokenizer=tokenizer
)

# Resume from latest checkpoint.
# The search pattern is derived from training_config, so it follows the
# configured output directory and experiment name instead of a hard-coded copy
# of them (which would point at the wrong run after changing ENCODER_SIZE).
# NOTE(review): assumes checkpoints live in
# <output_dir>/<experiment_name>/checkpoints/step_<N>.pt, as the previously
# hard-coded path did -- confirm against the Trainer.
checkpoint_pattern = os.path.join(
    training_config.output_dir,
    training_config.experiment_name,
    'checkpoints',
    'step_*.pt',
)
checkpoints = glob.glob(checkpoint_pattern)
if checkpoints:
    # Pick the file with the largest step number in its name (step_<N>.pt).
    latest = max(
        checkpoints,
        key=lambda path: int(os.path.basename(path).split('_')[-1].split('.')[0]),
    )
    print(f'Resuming from {latest}')
    trainer.load_checkpoint(latest)

print('Starting training...')
best_metric = trainer.train()
print(f'Training complete! Best metric: {best_metric:.4f}')

Starting training...
Starting training: 20 epochs, 2396 batches/epoch
Epoch 0: Encoder frozen
Step 100: loss=6.6145, lr=3.48e-06
Step 200: loss=5.9626, lr=5.95e-06
Step 300: loss=5.5109, lr=8.43e-06
Step 400: loss=5.2729, lr=1.09e-05
Step 500: loss=4.9977, lr=1.34e-05
Step 600: loss=4.7054, lr=1.59e-05
Step 700: loss=4.4679, lr=1.83e-05
Step 800: loss=4.3911, lr=2.08e-05
Step 900: loss=4.1543, lr=2.33e-05
Step 1000: loss=4.0816, lr=2.58e-05
Step 1100: loss=4.1243, lr=2.82e-05
Step 1200: loss=3.9988, lr=3.07e-05
Step 1300: loss=3.8649, lr=3.32e-05
Step 1400: loss=3.9985, lr=3.57e-05
Step 1500: loss=3.6639, lr=3.81e-05
Step 1600: loss=3.6628, lr=4.06e-05
Step 1700: loss=3.7225, lr=4.31e-05
Step 1800: loss=3.6716, lr=4.56e-05
Step 1900: loss=3.6063, lr=4.80e-05
Step 2000: loss=3.5979, lr=5.05e-05
Step 2100: loss=3.3311, lr=5.30e-05
Step 2200: loss=3.3845, lr=5.55e-05
Step 2300: loss=3.2461, lr=5.79e-05
Epoch 0 train: {'loss': 4.320017835035149}
Epoch 0 val: {'exp_rate': 0.0002548744743213

In [None]:
# Save final model
# Stored inside the training output directory, as <output_dir>/final_model.
save_path = f'{training_config.output_dir}/final_model'
model.save_pretrained(save_path)
print(f'Saved model to {save_path}')

## Test Inference

In [None]:
# Test on a sample
# The input below is moved to `device`, so make sure the model is there too;
# otherwise running this cell with the model still on another device fails
# with a device mismatch.
# NOTE(review): assumes `model` is a torch.nn.Module, for which .to() moves
# the parameters in place -- confirm.
model.to(device)
model.eval()
with torch.no_grad():
    # Get a sample from validation set and add a leading batch dimension
    sample = val_dataset[0]
    img = sample.image.unsqueeze(0).to(device)

    output = model(img)

    print(f'Ground truth: {sample.latex}')
    # predictions[0] holds the results for the single image in the batch.
    # NOTE(review): each entry is indexed with [1] to get the LaTeX string,
    # presumably a (region, latex) pair -- confirm against the model output type.
    if output.predictions and output.predictions[0]:
        first_prediction = output.predictions[0][0]
        pred_latex = first_prediction[1] if first_prediction else ''
        print(f'Predicted: {pred_latex}')
    else:
        print('No prediction generated')