# Handwritten LaTeX OCR Training

Train the unified text-spotting model on Google Colab using an H100 or A100 GPU.

**Requirements:**
- Colab Pro or Pro+ (needed for H100/A100 access)
- Google Drive, mounted at `/content/drive`, for storing checkpoints

In [None]:
# Print the GPU, driver, and memory details of the attached runtime so the
# accelerator type can be confirmed before any training work starts.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive. The training configuration later in
# this notebook writes its outputs under /content/drive/MyDrive, so this cell
# must run before training.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!git clone https://github.com/YOUR_USERNAME/MobileTeXOCR.git
%cd MobileTeXOCR

In [None]:
# Install the third-party packages used by this notebook, then install the
# project itself in editable mode. `-e .` resolves against the current working
# directory, so the repository root must already be the working directory.
!pip install -q torch torchvision pillow numpy pyyaml wandb
!pip install -q -e .

In [None]:
# Create the dataset root (no error if it already exists). The dataset cell
# later in this notebook only picks up ./data/mathwriting, ./data/crohme and
# ./data/hme100k, so the datasets have to be placed under this directory.
!mkdir -p data

In [None]:
import sys
sys.path.insert(0, '.')

import torch
import logging

# Configure root logging so INFO-level messages are shown with a timestamp.
# force=True removes handlers already attached to the root logger first;
# without it basicConfig silently does nothing when handlers exist, which is
# the usual situation in a notebook kernel and on every re-run of this cell.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    # Report the name of device 0 so the runtime type is visible in the output.
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
from models import HandwrittenLaTeXOCR, ModelConfig
from models.decoder.tokenizer import LaTeXTokenizer
from data import DatasetConfig, CombinedDataset, get_train_transforms, get_eval_transforms
from training import Trainer, TrainingConfig

# Build the LaTeX tokenizer (constructed with no arguments) and print its
# vocabulary size. The same instance is passed to the datasets and the trainer
# in later cells.
tokenizer = LaTeXTokenizer()
print(f'Tokenizer vocab size: {tokenizer.vocab_size}')

In [None]:
# Encoder selection. ENCODER is also used to name the experiment in the
# training-configuration cell.
ENCODER, ENCODER_SIZE = 'fastvithd', 'base'

# Model hyperparameters: encoder choice first, then input/decoder dimensions.
model_config = ModelConfig(
    image_size=384,
    d_model=384,
    num_decoder_layers=6,
    encoder_type=ENCODER,
    encoder_size=ENCODER_SIZE,
    freeze_encoder=True,
)

# Instantiate the model and report its parameter count with thousands separators.
model = HandwrittenLaTeXOCR(model_config)
parameter_count = model.count_parameters()
print(f'Model parameters: {parameter_count:,}')

In [None]:
import os

# Dataset configuration plus the train/eval image pipelines, all at 384 px.
dataset_config = DatasetConfig(data_dir='./data', image_size=384)
train_transform = get_train_transforms(image_size=384, augment_strength='medium')
valid_transform = get_eval_transforms(image_size=384)

# Keep only the datasets whose directory exists under ./data, in a fixed order.
available_datasets = [
    dataset_name
    for dataset_name in ('mathwriting', 'crohme', 'hme100k')
    if os.path.exists(f'./data/{dataset_name}')
]
print(f'Available datasets: {available_datasets}')

# Train and validation splits share the config, tokenizer, and dataset list;
# only the split name and the transform differ.
train_dataset = CombinedDataset(
    dataset_config,
    split='train',
    transform=train_transform,
    tokenizer=tokenizer,
    datasets=available_datasets,
)
val_dataset = CombinedDataset(
    dataset_config,
    split='val',
    transform=valid_transform,
    tokenizer=tokenizer,
    datasets=available_datasets,
)

print(f'Train samples: {len(train_dataset)}')
print(f'Val samples: {len(val_dataset)}')

In [None]:
# Derive GPU-dependent settings without assuming a GPU is present. The setup
# cell falls back to the CPU, but torch.cuda.get_device_properties(0) and
# torch.cuda.get_device_capability() raise when CUDA is unavailable, so they
# are only queried behind an availability check.
if torch.cuda.is_available():
    gpu_memory_bytes = torch.cuda.get_device_properties(0).total_memory
    compute_capability_major = torch.cuda.get_device_capability()[0]
else:
    # Zero selects the conservative branch of both choices below.
    gpu_memory_bytes = 0
    compute_capability_major = 0

# Use the larger batch only when the card reports more than 70e9 bytes of memory.
BATCH_SIZE = 48 if gpu_memory_bytes > 70e9 else 32

# bfloat16 autocast needs compute capability 8.x or newer; otherwise use float16.
# NOTE(review): use_amp stays enabled on the CPU fallback as well; confirm that
# the Trainer handles mixed precision without CUDA.
AMP_DTYPE = 'bfloat16' if compute_capability_major >= 8 else 'float16'

training_config = TrainingConfig(
    # Outputs go to the mounted Google Drive.
    output_dir='/content/drive/MyDrive/latex_ocr_outputs',
    experiment_name=f'latex_ocr_{ENCODER}',
    num_epochs=20,
    batch_size=BATCH_SIZE,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=2000,
    gradient_accumulation_steps=2,
    use_amp=True,
    amp_dtype=AMP_DTYPE,
    save_steps=2000,
    validation_steps=1000,
    log_steps=100,
    freeze_encoder_epochs=1,
    early_stopping_patience=5,
)

print(f'Batch size: {BATCH_SIZE}')

In [None]:
# Training loader: weighted sampling enabled, four worker processes.
train_loader = train_dataset.get_dataloader(
    batch_size=BATCH_SIZE,
    num_workers=4,
    use_weighted_sampling=True,
)

# Validation loader: no shuffling and no weighted sampling.
val_loader = val_dataset.get_dataloader(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    use_weighted_sampling=False,
)

In [None]:
# Wire the model, both loaders, the training configuration, and the tokenizer
# into a Trainer.
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=training_config,
    tokenizer=tokenizer,
)

# Run training and keep the value returned by train() for later inspection.
best_metric = trainer.train()

In [None]:
# Build the destination once so the saved location and the printed message
# cannot drift apart. Formatting the path (rather than using `+`) also works
# if output_dir is a path-like object instead of a plain string.
final_model_dir = f'{training_config.output_dir}/final_model'

# Save the in-memory model under the experiment's output directory.
model.save_pretrained(final_model_dir)
print(f'Saved model to {final_model_dir}')