# Handwritten LaTeX OCR Training

Train the unified text-spotting model on Google Colab. An H100 or A100 GPU is recommended; smaller GPUs (and CPU) also work with an automatically reduced batch size.

In [None]:
# Print the status of the attached NVIDIA GPU(s) to confirm a GPU runtime is active.
!nvidia-smi

In [None]:
# Mount Google Drive; the training outputs are written under /content/drive/MyDrive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!git clone https://github.com/markm39/MobileTeXOCR.git
%cd MobileTeXOCR

In [None]:
# Install the Python dependencies (-q suppresses most of pip's output).
!pip install -q torch torchvision pillow numpy pyyaml

In [None]:
import sys

# Make the cloned repository importable (the `models`, `data` and `training`
# packages used below). Guarded so that re-running the cell does not stack
# duplicate entries onto sys.path.
REPO_DIR = '/content/MobileTeXOCR'
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

## Dataset Setup

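Choose ONE option below:
- **Option A**: Create dummy data (for testing the pipeline)
- **Option B**: Download real datasets (for actual training)

Option A writes its images to `./data/hme100k`. The dataset cell further down automatically uses every dataset directory it finds under `./data`, so delete `./data/hme100k` before switching to real data. Otherwise the dummy samples will be mixed into training.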

### Option A: Create Dummy Data (for testing)

In [None]:
# Create dummy dataset for testing the pipeline
import os
import json
from PIL import Image, ImageDraw, ImageFont

def create_dummy_dataset(base_dir, num_train=100, num_val=20, seed=0):
    """Create a tiny synthetic math-expression dataset for smoke-testing the pipeline.

    Each image is a 384x384 white RGB canvas with one LaTeX expression drawn
    in black as plain text, with backslashes, braces, ``_`` and ``^`` removed.
    This is NOT real handwriting; it only exercises loading, training and
    evaluation end to end.

    Files written under ``base_dir``::

        hme100k/<split>/images/sample_XXXX.png
        hme100k/<split>/labels.json   # {"sample_XXXX.png": "<latex>", ...}

    Args:
        base_dir: Root data directory; missing directories are created.
        num_train: Number of images in the ``train`` split.
        num_val: Number of images in the ``val`` split.
        seed: Base seed for text placement. Each split draws from its own
            ``random.Random`` instance (rather than re-seeding the global RNG
            with the sample index). The output is reproducible, the global
            ``random`` state is left untouched, and ``val`` is no longer an
            exact copy of the first ``num_val`` training images.
    """
    import random

    expressions = [
        'x^2', 'y^2', 'x+y', 'a-b', '\\frac{1}{2}', '\\sqrt{x}',
        'x^2+y^2', 'a^2-b^2', '\\alpha', '\\beta', '\\gamma',
        '\\sum_{i=1}^{n}', '\\int_0^1', 'e^x', '\\pi r^2',
        '\\frac{a}{b}', 'x_1', 'y_2', 'z^n', '\\theta'
    ]

    # Load the font once. Fall back to PIL's built-in font when DejaVu Sans is
    # missing (OSError) or FreeType support is unavailable (ImportError).
    try:
        font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', 40)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    # Deletes LaTeX markup characters so the drawn text is roughly readable.
    strip_markup = str.maketrans('', '', '\\{}_^')

    for split, num_samples in [('train', num_train), ('val', num_val)]:
        split_dir = os.path.join(base_dir, 'hme100k', split)
        img_dir = os.path.join(split_dir, 'images')
        os.makedirs(img_dir, exist_ok=True)

        # Independent, reproducible RNG per split (str seeds are deterministic).
        rng = random.Random(f'{seed}:{split}')
        labels = {}
        for i in range(num_samples):
            expr = expressions[i % len(expressions)]

            img = Image.new('RGB', (384, 384), 'white')
            draw = ImageDraw.Draw(img)
            x_start = rng.randint(50, 150)
            y_start = rng.randint(150, 200)
            draw.text((x_start, y_start), expr.translate(strip_markup), fill='black', font=font)

            img_name = f'sample_{i:04d}.png'
            img.save(os.path.join(img_dir, img_name))
            labels[img_name] = expr

        labels_file = os.path.join(split_dir, 'labels.json')
        with open(labels_file, 'w', encoding='utf-8') as f:
            json.dump(labels, f, indent=2)

        print(f'Created {num_samples} {split} samples in {img_dir}')

# Generate the placeholder dataset: 100 training and 20 validation images.
DUMMY_DATA_DIR = './data'
create_dummy_dataset(DUMMY_DATA_DIR, num_train=100, num_val=20)
print('Dummy dataset created!')

### Option B: Download Real Datasets (for actual training)

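Uncomment and run the cells below to download real data. Each source needs its own credentials or tooling: gsutil authentication, a Kaggle API key, or the `huggingface_hub` package.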

In [None]:
# # Download MathWriting dataset (630K samples) - requires gsutil auth
# !pip install gsutil
# !mkdir -p data/mathwriting
# !gsutil -m cp -r gs://mathwriting_data/train ./data/mathwriting/
# !gsutil -m cp -r gs://mathwriting_data/val ./data/mathwriting/

In [None]:
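# # Download CROHME from Kaggle (requires Kaggle API key)
# # NOTE(review): xainano/handwrittenmathsymbols appears to contain isolated
# # symbol images extracted from CROHME, not whole expressions with LaTeX
# # labels; verify it matches what the 'crohme' loader expects.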
# !pip install kaggle
# !mkdir -p ~/.kaggle
# # Upload your kaggle.json or set credentials
# !kaggle datasets download -d xainano/handwrittenmathsymbols
# !unzip -q handwrittenmathsymbols.zip -d data/crohme/

In [None]:
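# # Alternative: Download from HuggingFace (if available)
# # NOTE(review): im2latex-100k contains rendered (printed) formulas, not
# # handwritten HME100K data; verify its layout matches what the 'hme100k'
# # loader expects before saving it to ./data/hme100k.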
# !pip install huggingface_hub
# from huggingface_hub import snapshot_download
# snapshot_download(repo_id="ybelkada/im2latex-100k", local_dir="./data/hme100k", repo_type="dataset")

## Model Setup

In [None]:
import torch
import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Use the GPU when the runtime provides one, otherwise fall back to CPU.
has_gpu = torch.cuda.is_available()
device = torch.device('cuda' if has_gpu else 'cpu')
print(f'Using device: {device}')
if has_gpu:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'Memory: {gpu_props.total_memory / 1e9:.1f} GB')

In [None]:
from models import HandwrittenLaTeXOCR, ModelConfig
from models.decoder.tokenizer import LaTeXTokenizer
from data import (
    CombinedDataset,
    DatasetConfig,
    get_eval_transforms,
    get_train_transforms,
)
from training import Trainer, TrainingConfig

# A single tokenizer instance is shared by the datasets and the trainer below.
tokenizer = LaTeXTokenizer()
print(f'Tokenizer vocab size: {tokenizer.vocab_size}')

In [None]:
# Model configuration.
# 'small' is intended for quick pipeline tests; switch to 'base' for real training.
ENCODER = 'fastvithd'
ENCODER_SIZE = 'small'  # Change to 'base' for full training

# d_model and decoder depth follow ENCODER_SIZE: 'small' gets the lighter
# setting, every other size gets the larger one.
if ENCODER_SIZE == 'small':
    model_dim, decoder_depth = 256, 4
else:
    model_dim, decoder_depth = 384, 6

model_config = ModelConfig(
    encoder_type=ENCODER,
    encoder_size=ENCODER_SIZE,
    image_size=384,
    d_model=model_dim,
    num_decoder_layers=decoder_depth,
    # NOTE(review): TrainingConfig also sets freeze_encoder_epochs=1; confirm
    # how this flag and that schedule interact inside the model/Trainer.
    freeze_encoder=True,
)

model = HandwrittenLaTeXOCR(model_config)
print(f'Model parameters: {model.count_parameters():,}')

In [None]:
# Dataset configuration
import os

# Single source of truth for the data root: used both for the config and for
# detecting which datasets are present.
DATA_DIR = './data'
dataset_config = DatasetConfig(data_dir=DATA_DIR, image_size=384)

train_transform = get_train_transforms(image_size=384, augment_strength='medium')
valid_transform = get_eval_transforms(image_size=384)

# Use every known dataset that has a directory under DATA_DIR. isdir (not
# exists) so that a stray file with the same name is not mistaken for a dataset.
KNOWN_DATASETS = ('mathwriting', 'crohme', 'hme100k')
available_datasets = [
    ds for ds in KNOWN_DATASETS if os.path.isdir(os.path.join(DATA_DIR, ds))
]

print(f'Available datasets: {available_datasets}')

if not available_datasets:
    raise RuntimeError('No datasets found! Run the dataset setup cells above first.')

train_dataset = CombinedDataset(
    dataset_config, split='train', transform=train_transform,
    tokenizer=tokenizer, datasets=available_datasets
)
val_dataset = CombinedDataset(
    dataset_config, split='val', transform=valid_transform,
    tokenizer=tokenizer, datasets=available_datasets
)

print(f'Train samples: {len(train_dataset)}')
print(f'Val samples: {len(val_dataset)}')

# Directories can exist yet contain no samples for a split (wrong layout,
# interrupted download). Fail here with a clear message instead of later inside
# the DataLoader/Trainer or at val_dataset[0] in the inference cell.
if len(train_dataset) == 0 or len(val_dataset) == 0:
    raise RuntimeError(
        f'Empty split(s) for datasets {available_datasets}: '
        f'train={len(train_dataset)}, val={len(val_dataset)}. '
        f'Check the directory layout under {DATA_DIR}.'
    )

In [None]:
# Training configuration
import math

if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory
    # >70 GB of device memory -> 48, >40 GB -> 32, otherwise 16.
    BATCH_SIZE = 48 if gpu_memory > 70e9 else (32 if gpu_memory > 40e9 else 16)
else:
    BATCH_SIZE = 4

GRAD_ACCUM_STEPS = 2

# Small datasets (e.g. the dummy data) get a short schedule and frequent logging.
small_run = len(train_dataset) < 1000
num_epochs = 5 if small_run else 20

# Rough estimate of optimizer steps for the whole run, assuming one pass over
# train_dataset per epoch. NOTE(review): confirm whether Trainer counts steps
# per optimizer update or per micro-batch; either way the cap below keeps
# warmup shorter than the run.
batches_per_epoch = max(1, math.ceil(len(train_dataset) / BATCH_SIZE))
steps_per_epoch = max(1, math.ceil(batches_per_epoch / GRAD_ACCUM_STEPS))
total_steps = steps_per_epoch * num_epochs

# Without a cap, the fixed warmup (100 / 2000) can exceed the entire run. For
# example, the dummy data gives ~10 steps, so the LR would never reach
# learning_rate. Cap warmup at ~10% of the estimated run.
warmup_steps = min(100 if small_run else 2000, max(1, total_steps // 10))

# bfloat16 autocast needs compute capability >= 8 (Ampere or newer).
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

training_config = TrainingConfig(
    output_dir='/content/drive/MyDrive/latex_ocr_outputs',
    experiment_name=f'latex_ocr_{ENCODER}_{ENCODER_SIZE}',
    num_epochs=num_epochs,  # Fewer epochs for dummy data
    batch_size=BATCH_SIZE,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=warmup_steps,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    use_amp=True,
    amp_dtype='bfloat16' if use_bf16 else 'float16',
    # NOTE(review): on tiny runs total_steps can be below save_steps and
    # validation_steps; confirm Trainer also validates/saves at epoch end.
    save_steps=500,
    validation_steps=100 if small_run else 1000,
    log_steps=10 if small_run else 100,
    freeze_encoder_epochs=1,
    early_stopping_patience=5,
)

print(f'Batch size: {BATCH_SIZE}')
print(f'Epochs: {training_config.num_epochs}')

In [None]:
# Build the train/val dataloaders with the batch size chosen above.
LOADER_WORKERS = 2
weighted_train_sampling = len(train_dataset) > 0
# NOTE(review): weighted_train_sampling is True for any non-empty dataset. If
# the intent is to rebalance across sources, `len(available_datasets) > 1` may
# have been meant — confirm with CombinedDataset.get_dataloader.

train_loader = train_dataset.get_dataloader(
    batch_size=BATCH_SIZE,
    num_workers=LOADER_WORKERS,
    use_weighted_sampling=weighted_train_sampling,
)
# Validation: fixed order, no weighted sampling.
val_loader = val_dataset.get_dataloader(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=LOADER_WORKERS,
    use_weighted_sampling=False,
)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

## Training

In [None]:
# Wire the model, loaders, config and shared tokenizer into the trainer.
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=training_config,
    tokenizer=tokenizer,
)

print('Starting training...')
best_metric = trainer.train()
# NOTE(review): the ':.4f' format assumes train() returns a number; confirm it
# cannot return None (e.g. when no validation pass ran).
print(f'Training complete! Best metric: {best_metric:.4f}')

In [None]:
# Save final model
import os

# os.path.join instead of string concatenation, so an output_dir given with a
# trailing slash does not produce a doubled separator.
# NOTE(review): these are the weights `model` holds after trainer.train();
# confirm whether Trainer restores the best checkpoint into `model` first.
save_path = os.path.join(training_config.output_dir, 'final_model')
model.save_pretrained(save_path)
print(f'Saved model to {save_path}')

## Test Inference

In [None]:
# Test on a sample from the validation set.
if len(val_dataset) == 0:
    print('Validation set is empty; skipping the inference test.')
else:
    model.eval()
    # Send the input to wherever the model's weights actually live, rather
    # than assuming the model was moved to the global `device`.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        sample = val_dataset[0]
        img = sample.image.unsqueeze(0).to(model_device)
        output = model(img)

    print(f'Ground truth: {sample.latex}')
    # NOTE(review): assumes predictions[batch][region] is a sequence whose
    # index 1 holds the LaTeX string — confirm against the model's output type.
    predictions = output.predictions
    if predictions and predictions[0]:
        first = predictions[0][0]
        pred_latex = first[1] if first else ''
        print(f'Predicted: {pred_latex}')
    else:
        print('No prediction generated')