# ViT Encoder + LSTM Decoder
Frozen `google/vit-base-patch16-224` encoder with a trainable LSTM decoder for handwritten math to LaTeX.

In [None]:
!pip install -q transformers datasets

In [None]:
import numpy as np
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from transformers import ViTModel
from datasets import load_dataset
from pathlib import Path

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)
print(f"Using device: {DEVICE}")

## 1. Load and preprocess dataset

In [None]:
# Load the `deepcopy/MathWriting-human` dataset and report the size of the
# three splits that the rest of the notebook reads ("train", "val", "test").
ds = load_dataset("deepcopy/MathWriting-human")
n_train, n_val, n_test = (len(ds[split]) for split in ("train", "val", "test"))
print(f"Train: {n_train}  Val: {n_val}  Test: {n_test}")

In [None]:
# Number of training examples to keep (taken from the start of the train
# split). Set to None to use ALL training data.
NUM_TRAIN = 2000

if NUM_TRAIN is None:
    ds_train = ds["train"]
else:
    ds_train = ds["train"].select(range(NUM_TRAIN))
print(f"Using {len(ds_train)} training samples")

In [None]:
# Side length (pixels) every training image is downsampled to before it is
# later upsampled to the ViT input resolution.
IMG_SIZE = 128

# Reserved token ids.
#   * PAD must be 0: sequences are right-padded with 0 and the training loss
#     is built with ignore_index=0.
#   * SOS/EOS must be 1 and 2: those are the defaults of
#     ViTLatexModel.generate(sos_idx=1, eos_idx=2).
# Previously ids 1 and 2 were ordinary characters and no start/end markers
# were added, so the model never learned to stop and the first character of
# each formula was never a prediction target.
PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2
SOS_TOKEN, EOS_TOKEN = "<sos>", "<eos>"

images, latex_strings = [], []
for sample in ds_train:
    img = sample["image"].convert("L").resize((IMG_SIZE, IMG_SIZE))
    # float32 from the start: the data ends up as a float32 tensor anyway, and
    # float64 would double the memory held by the image array.
    images.append(np.asarray(img, dtype=np.float32) / 255.0)
    latex_strings.append(sample["latex"])

images = np.array(images)
print(f"Images: {images.shape}")

# char-level tokenizer: one id per distinct character, starting after the
# reserved ids. The multi-character special-token keys cannot collide with
# the single-character keys.
chars = sorted(set("".join(latex_strings)))
word_index = {SOS_TOKEN: SOS_IDX, EOS_TOKEN: EOS_IDX}
word_index.update({ch: i + 3 for i, ch in enumerate(chars)})
# Reverse map for decoding predictions; special tokens are deliberately left
# out so they never leak into a decoded string.
idx_to_char = {word_index[ch]: ch for ch in chars}

# <sos> c1 c2 ... cn <eos>, right-padded with PAD_IDX to a common length.
sequences = [
    [SOS_IDX, *(word_index[ch] for ch in s), EOS_IDX] for s in latex_strings
]
max_len = max(len(s) for s in sequences)
padded = np.array([s + [PAD_IDX] * (max_len - len(s)) for s in sequences])

# +1 accounts for PAD (id 0), which has no entry in word_index.
VOCAB_SIZE = len(word_index) + 1
print(f"Vocab size: {VOCAB_SIZE}, Max seq len: {max_len}")

images_tensor = torch.tensor(images, dtype=torch.float32).unsqueeze(1)  # (N, 1, H, W)
tokens_tensor = torch.tensor(padded, dtype=torch.long)
print(f"Images tensor: {images_tensor.shape}")
print(f"Tokens tensor: {tokens_tensor.shape}")

## 2. Prepare for ViT (3-channel, 224x224)

In [None]:
# Upsampler to the 224x224 resolution used as the ViT input size.
resize = transforms.Resize((224, 224))

# Copy the single grayscale channel three times: (N, 1, H, W) -> (N, 3, H, W).
images_rgb = images_tensor.repeat(1, 3, 1, 1)
# Resize image by image, then stack the results back into one batch tensor.
images_resized = torch.stack(list(map(resize, images_rgb)))
print(f"ViT input shape: {images_resized.shape}")

# Pairs each resized image with its padded token sequence.
dataset = TensorDataset(images_resized, tokens_tensor)

## 3. Model definition

In [None]:
class Decoder(nn.Module):
    """Single-layer LSTM decoder that maps token ids to next-token logits.

    The initial LSTM hidden state can be derived from an encoder feature
    vector through a learned linear projection followed by tanh.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, encoder_dim=768):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        # Projects encoder features to the size of the LSTM hidden state.
        self.enc_to_h = nn.Linear(encoder_dim, hidden_dim)

    def forward(self, x, encoder_features=None, hidden_state=None):
        """Run the LSTM over ``x`` and return ``(logits, hidden)``.

        Args:
            x: token ids, batch first.
            encoder_features: optional per-sample feature vector used to
                build the initial hidden state (the cell state starts at
                zero). Ignored when ``hidden_state`` is given.
            hidden_state: optional ``(h, c)`` tuple to resume from; takes
                precedence over ``encoder_features``.

        Returns:
            Logits over the vocabulary for every input position, and the
            final ``(h, c)`` LSTM state.
        """
        embedded = self.embedding(x)

        if hidden_state is not None:
            state = hidden_state
        elif encoder_features is not None:
            # unsqueeze(0) adds the leading num_layers dimension the LSTM expects.
            h0 = torch.tanh(self.enc_to_h(encoder_features)).unsqueeze(0)
            state = (h0, torch.zeros_like(h0))
        else:
            # None lets nn.LSTM use its default zero-initialised state.
            state = None

        output, hidden = self.lstm(embedded, state)
        return self.fc(output), hidden


class ViTLatexModel(nn.Module):
    """Frozen pretrained ViT encoder feeding a trainable ``Decoder``."""

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
        super().__init__()
        self.encoder = ViTModel.from_pretrained("google/vit-base-patch16-224")
        # Freeze the encoder: only the decoder receives gradient updates.
        for frozen_param in self.encoder.parameters():
            frozen_param.requires_grad = False
        self.decoder = Decoder(
            vocab_size,
            embed_dim,
            hidden_dim,
            self.encoder.config.hidden_size,
        )

    def _image_features(self, images):
        """Return position 0 of the encoder's last hidden state (ViT's [CLS] token)."""
        return self.encoder(images).last_hidden_state[:, 0, :]

    def forward(self, images, targets):
        """Teacher-forced pass: logits for every position of ``targets``."""
        features = self._image_features(images)
        logits, _unused_state = self.decoder(targets, encoder_features=features)
        return logits

    @torch.no_grad()
    def generate(self, image, max_len=100, sos_idx=1, eos_idx=2):
        """Greedily decode token ids for a single image.

        Switches the module to eval mode. ``image`` must be a batch of one,
        because each predicted token is read back with ``.item()``.

        Args:
            image: encoder input for one image.
            max_len: maximum number of tokens to emit.
            sos_idx: id fed to the decoder as the first input token.
            eos_idx: id that stops decoding; it is not included in the result.

        Returns:
            List of predicted token ids (Python ints).
        """
        self.eval()
        features = self._image_features(image)
        current = torch.tensor([[sos_idx]], device=image.device)
        state = None
        decoded = []
        # Every iteration either appends exactly one token or stops, so the
        # length check bounds the number of decoding steps by max_len.
        while len(decoded) < max_len:
            if state is None:
                # First step: seed the LSTM state from the image features.
                step_logits, state = self.decoder(current, encoder_features=features)
            else:
                step_logits, state = self.decoder(current, hidden_state=state)
            current = step_logits[:, -1, :].argmax(dim=-1, keepdim=True)
            token_id = current.item()
            if token_id == eos_idx:
                break
            decoded.append(token_id)
        return decoded

## 4. Training

In [None]:
# Training hyperparameters.
BATCH_SIZE = 8
EPOCHS = 10
LEARNING_RATE = 1e-3

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = ViTLatexModel(vocab_size=VOCAB_SIZE).to(DEVICE)

# Report how much of the model is actually trained.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
trainable = sum(size for size, needs_grad in param_sizes if needs_grad)
total = sum(size for size, _ in param_sizes)
print(f"Trainable params: {trainable:,} / {total:,}")

# Target id 0 is the padding id, so padded positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# Only the decoder's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=LEARNING_RATE)

In [None]:
# Directory receiving one checkpoint per epoch plus the final weights.
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    num_batches = len(loader)

    for batch_idx, (imgs, seqs) in enumerate(loader):
        imgs, seqs = imgs.to(DEVICE), seqs.to(DEVICE)

        # Teacher forcing: the decoder reads every token but the last and is
        # scored against the same sequence shifted left by one position.
        input_tokens, target_tokens = seqs[:, :-1], seqs[:, 1:]

        optimizer.zero_grad()
        logits = model(imgs, input_tokens)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), target_tokens.reshape(-1))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        if batch_idx % 50 == 0:
            print(f"  Batch {batch_idx}/{num_batches} | Loss: {batch_loss:.4f}")

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}/{EPOCHS} | Avg Loss: {avg_loss:.4f}")

    # Per-epoch checkpoint: model and optimizer state plus the epoch's mean loss.
    checkpoint = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": avg_loss,
    }
    torch.save(checkpoint, checkpoint_dir / f"vit_epoch_{epoch+1}.pt")

# Final weights only (no optimizer state).
torch.save(model.state_dict(), checkpoint_dir / "vit_final.pt")
print("Training complete!")

## 5. Save tokenizer

In [None]:
class CharTokenizer:
    """Minimal picklable holder for the character-to-id vocabulary."""
    def __init__(self, word_index: dict) -> None:
        # Mapping from token string to integer id, stored as-is (not copied).
        # NOTE(review): unpickling an instance requires a class with this
        # name to be importable from the same module path it was pickled
        # under -- confirm how the consumer of latex_tokenizer.pkl loads it.
        self.word_index = word_index

# Pickle the tokenizer wrapper so the token -> id mapping can be restored later.
tokenizer = CharTokenizer(word_index)
with open("latex_tokenizer.pkl", "wb") as f:
    pickle.dump(tokenizer, f)

# Store the vocabulary size next to it as plain text.
with open("vocab_size.txt", "w") as f:
    f.write(f"{VOCAB_SIZE}")

print(f"Saved tokenizer and vocab_size={VOCAB_SIZE}")

## 6. Evaluation on test set

In [None]:
import random


def normalized_edit_distance(pred, target):
    """Return the Levenshtein distance of two sequences scaled to [0, 1].

    The raw edit distance (insertions, deletions, substitutions, each with
    cost 1) is divided by the length of the longer sequence. Two empty
    sequences have distance 0.0.
    """
    longest = max(len(pred), len(target))
    if longest == 0:
        return 0.0

    # previous_row[j] holds the distance between the processed prefix of
    # `pred` and the first j items of `target`.
    previous_row = list(range(len(target) + 1))
    for row_no, pred_item in enumerate(pred, start=1):
        current_row = [row_no]
        for col_no, target_item in enumerate(target, start=1):
            diagonal = previous_row[col_no - 1]
            if pred_item == target_item:
                cost = diagonal
            else:
                cost = 1 + min(previous_row[col_no], current_row[col_no - 1], diagonal)
            current_row.append(cost)
        previous_row = current_row

    return previous_row[-1] / longest


# Number of randomly drawn test samples and the seed that fixes the draw.
NUM_EVAL = 50
SEED = 42

test_ds = ds["test"]
random.seed(SEED)
eval_indices = random.sample(range(len(test_ds)), NUM_EVAL)
eval_samples = [test_ds[i] for i in eval_indices]

# Applied to a PIL image that has ALREADY been downsampled to IMG_SIZE (see
# the loop below). Converting to a tensor first and upsampling second
# reproduces the order used for the training data, where the 224x224 resize
# runs on tensors. Resizing the full-resolution image straight to 224x224, as
# was done before, fed the frozen encoder sharper inputs than it was trained on.
resize_eval = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])

model.eval()
exact_matches = 0
total_edit_dist = 0.0

print(f"Evaluating on {NUM_EVAL} test samples...")
print("-" * 60)

for i, sample in enumerate(eval_samples):
    ground_truth = sample["latex"]
    # Same first step as training: grayscale, then PIL resize to IMG_SIZE.
    img = sample["image"].convert("L").resize((IMG_SIZE, IMG_SIZE))
    # (1, 224, 224) -> three identical channels -> batch of one.
    img_tensor = resize_eval(img).repeat(3, 1, 1).unsqueeze(0).to(DEVICE)

    # Cap decoding at the training sequence length instead of generate()'s
    # hard-coded default of 100 tokens.
    pred_tokens = model.generate(img_tensor, max_len=max_len)
    prediction = "".join(idx_to_char.get(t, "?") for t in pred_tokens)

    is_exact = prediction == ground_truth
    edit_dist = normalized_edit_distance(prediction, ground_truth)

    if is_exact:
        exact_matches += 1
    total_edit_dist += edit_dist

    status = "EXACT" if is_exact else f"edit_dist={edit_dist:.4f}"
    print(f"  [{i+1}/{NUM_EVAL}] {status}")
    print(f"    GT:   {ground_truth[:80]}")
    print(f"    Pred: {prediction[:80]}")

accuracy = exact_matches / NUM_EVAL
avg_edit_dist = total_edit_dist / NUM_EVAL

print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"Model:                    ViT + LSTM")
print(f"Samples:                  {NUM_EVAL}")
print(f"Exact match accuracy:     {accuracy:.2%} ({exact_matches}/{NUM_EVAL})")
print(f"Avg normalized edit dist: {avg_edit_dist:.4f}")