# ViT Encoder + LSTM Decoder Training
Frozen `google/vit-base-patch16-224` encoder with a trainable LSTM decoder for handwritten math â†’ LaTeX.

In [32]:
!pip install -q transformers datasets

^C
[31mERROR: Operation cancelled by user[0m[31m
[0m

In [11]:
import numpy as np
import pickle
import random
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from transformers import ViTModel
from pathlib import Path
from datasets import load_from_disk
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


## 1. Load & preprocess dataset

In [2]:
for f in sorted(Path("/kaggle/input/datasets").rglob("*")):
    print(f)

/kaggle/input/datasets/dustinp11
/kaggle/input/datasets/dustinp11/mathwriting
/kaggle/input/datasets/dustinp11/mathwriting/dataset_dict.json
/kaggle/input/datasets/dustinp11/mathwriting/images_train.pt
/kaggle/input/datasets/dustinp11/mathwriting/test
/kaggle/input/datasets/dustinp11/mathwriting/test/data-00000-of-00001.arrow
/kaggle/input/datasets/dustinp11/mathwriting/test/dataset_info.json
/kaggle/input/datasets/dustinp11/mathwriting/test/state.json
/kaggle/input/datasets/dustinp11/mathwriting/tokens_train.pt
/kaggle/input/datasets/dustinp11/mathwriting/train
/kaggle/input/datasets/dustinp11/mathwriting/train/data-00000-of-00003.arrow
/kaggle/input/datasets/dustinp11/mathwriting/train/data-00001-of-00003.arrow
/kaggle/input/datasets/dustinp11/mathwriting/train/data-00002-of-00003.arrow
/kaggle/input/datasets/dustinp11/mathwriting/train/dataset_info.json
/kaggle/input/datasets/dustinp11/mathwriting/train/state.json
/kaggle/input/datasets/dustinp11/mathwriting/val
/kaggle/input/datase

In [48]:
print("Loading dataset...")
ds = load_from_disk("/kaggle/input/datasets/dustinp11/mathwriting")
num_samples = 40000
ds_train = ds["train"].select(range(num_samples))

# 2. Pre-allocate Image Array (Saves RAM by avoiding copies)
# 50,000 * 256 * 256 * 4 bytes = ~13.1 GB
print(f"Pre-allocating memory for {num_samples} images...")
images_array = np.zeros((num_samples, 256, 256), dtype=np.float32)
latex_strings = []

# 3. Process Images & Collect Strings
print("Processing images and LaTeX strings...")
for i in range(num_samples):
    sample = ds_train[i]
    # Convert and resize directly into the array
    img = sample["image"].convert("L").resize((256, 256))
    images_array[i] = np.array(img, dtype=np.float32) / 255.0
    latex_strings.append(sample["latex"])
    
    if (i + 1) % 5000 == 0:
        print(f"Progress: {i + 1}/{num_samples}")

# 4. Setup Tokenizer
print("Fitting tokenizer...")
tokenizer = Tokenizer(char_level=True)
tokenizer.fit_on_texts(latex_strings)

# Add special tokens
tokenizer.word_index["<START>"] = len(tokenizer.word_index) + 1
tokenizer.word_index["<END>"] = len(tokenizer.word_index) + 1
tokenizer.index_word[tokenizer.word_index["<START>"]] = "<START>"
tokenizer.index_word[tokenizer.word_index["<END>"]] = "<END>"

START_ID = tokenizer.word_index["<START>"]
END_ID   = tokenizer.word_index["<END>"]

# 5. Sequence Padding
print("Tokenizing and padding sequences...")
sequences = tokenizer.texts_to_sequences(latex_strings)
sequences = [[START_ID] + seq + [END_ID] for seq in sequences]
padded_sequences = pad_sequences(sequences, padding="post")

# 6. Save Tokenizer and Vocab Info
print("Saving metadata...")
with open("/kaggle/working/latex_tokenizer256.pkl", "wb") as f:
    pickle.dump(tokenizer, f)

vocab_size = len(tokenizer.word_index) + 1
with open("/kaggle/working/vocab_size.txt", "w") as f:
    f.write(str(vocab_size))

# 7. Convert to Tensors and Save (Disk usage check: ~13.5GB total)
print("Converting to Tensors...")
# torch.from_numpy avoids a RAM copy
images_tensor = torch.from_numpy(images_array).unsqueeze(1) 
tokens_tensor = torch.tensor(padded_sequences, dtype=torch.long)

print("Saving tensors to disk (this takes a minute)...")
torch.save(images_tensor, "/kaggle/working/images_train256.pt")
torch.save(tokens_tensor, "/kaggle/working/tokens_train256.pt")

print("Done!")
print(f"Final Vocab Size: {vocab_size}")
print(f"Image Tensor Shape: {images_tensor.shape}")




Loading dataset...
Pre-allocating memory for 40000 images...
Processing images and LaTeX strings...
Progress: 5000/40000
Progress: 10000/40000
Progress: 15000/40000
Progress: 20000/40000
Progress: 25000/40000
Progress: 30000/40000
Progress: 35000/40000
Progress: 40000/40000
Fitting tokenizer...
Tokenizing and padding sequences...
Saving metadata...
Converting to Tensors...
Saving tensors to disk (this takes a minute)...
Done!
Final Vocab Size: 66
Image Tensor Shape: torch.Size([40000, 1, 256, 256])


In [3]:
ds = load_from_disk("/kaggle/input/datasets/dustinp11/mathwriting")

ds_val = ds["val"].select(range(5000))

images, sequences = [], []

def preprocess_image(img, target_size=(256, 256)):
    img = img.convert("L")  # convert to grayscale
    img = img.resize(target_size)
    img = np.array(img) / 255.0  # normalize to [0, 1]
    return img

for sample in ds_val:
    img = preprocess_image(sample["image"])
    images.append(img)
    sequences.append(sample["latex"])

images = np.array(images)
with open("/kaggle/working/latex_tokenizer256.pkl", "rb") as f:
    tokenizer = pickle.load(f)

START_ID = tokenizer.word_index["<START>"]
END_ID   = tokenizer.word_index["<END>"]

seqs = tokenizer.texts_to_sequences(sequences)
seqs = [[START_ID] + s + [END_ID] for s in seqs]

padded_sequences = pad_sequences(seqs, padding="post")
images = images[..., np.newaxis]

images_tensor = torch.tensor(images, dtype=torch.float32).permute(0, 3, 1, 2)
tokens_tensor = torch.tensor(padded_sequences, dtype=torch.long)

torch.save(images_tensor, "/kaggle/working/images_val256.pt")
torch.save(tokens_tensor, "/kaggle/working/tokens_val256.pt")
print(1)

1


## 2. Model definition

In [4]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, encoder_dim=768):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.enc_to_h = nn.Linear(encoder_dim, hidden_dim)

    def forward(self, x, encoder_features=None, hidden_state=None):
        x = self.embedding(x)
        if hidden_state is None:
            if encoder_features is not None:
                h0 = torch.tanh(self.enc_to_h(encoder_features)).unsqueeze(0)
                c0 = torch.zeros_like(h0)
                output, hidden = self.lstm(x, (h0, c0))
            else:
                output, hidden = self.lstm(x)
        else:
            output, hidden = self.lstm(x, hidden_state)
        logits = self.fc(output)
        return logits, hidden

In [5]:
class ViTLatexModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
        super().__init__()
        from transformers import MobileViTModel
        self.encoder = MobileViTModel.from_pretrained("apple/mobilevit-small")
        for param in self.encoder.parameters():
            param.requires_grad = False
        encoder_dim = self.encoder.config.neck_hidden_sizes[-1]  # 640
        self.decoder = Decoder(vocab_size, embed_dim, hidden_dim, encoder_dim)
    
    def forward(self, images, targets):
        encoder_out = self.encoder(images).pooler_output  # Use pooler output instead
        logits, _ = self.decoder(targets, encoder_features=encoder_out)
        return logits
    
    @torch.no_grad()
    def generate(self, image, max_len=100, sos_idx=1, eos_idx=2):
        self.eval()
        encoder_out = self.encoder(image).pooler_output
        token = torch.tensor([[sos_idx]], device=image.device)
        output_tokens = []
        hidden = None
        for i in range(max_len):
            if i == 0:
                logits, hidden = self.decoder(token, encoder_features=encoder_out)
            else:
                logits, hidden = self.decoder(token, hidden_state=hidden)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            if next_token.item() == eos_idx:
                break
            output_tokens.append(next_token.item())
            token = next_token
        return output_tokens

## 3. Training

In [6]:
# Load vocab size
with open("/kaggle/working/vocab_size.txt") as f:
    VOCAB_SIZE = int(f.read().strip())

# Hyperparameters
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 1e-3

# Load pre-processed tensors (just load, don't convert yet)
images_tensor = torch.load("/kaggle/working/images_train256.pt")  # (40000, 1, 256, 256)
tokens_tensor = torch.load("/kaggle/working/tokens_train256.pt")  # (40000, seq_len)

print(f"Images: {images_tensor.shape}, Tokens: {tokens_tensor.shape}, Vocab size: {VOCAB_SIZE}")

# Create dataset and loader
dataset = TensorDataset(images_tensor, tokens_tensor)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model
model = ViTLatexModel(vocab_size=VOCAB_SIZE).to(DEVICE)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,}")

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=LEARNING_RATE)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    
    for batch_idx, (imgs, seqs) in enumerate(loader):
        # Convert grayscale to RGB
        imgs = imgs.repeat(1, 3, 1, 1)  # (B, 1, 256, 256) -> (B, 3, 256, 256)
        imgs = imgs.to(DEVICE)
        seqs = seqs.to(DEVICE)
        
        # Teacher forcing
        input_tokens = seqs[:, :-1]   # (B, seq_len-1)
        target_tokens = seqs[:, 1:]   # (B, seq_len-1)
        
        optimizer.zero_grad()
        
        # Forward pass
        logits = model(imgs, input_tokens)  # (B, seq_len-1, vocab_size)
        
        # Compute loss
        loss = criterion(
            logits.reshape(-1, VOCAB_SIZE),  # (B * (seq_len-1), vocab_size)
            target_tokens.reshape(-1)         # (B * (seq_len-1))
        )
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Print every 100 batches
        if batch_idx % 100 == 0:
            print(f"  Batch {batch_idx}/{len(loader)} | Loss: {loss.item():.4f}")
    
    print(f"Epoch {epoch + 1}/{EPOCHS} | Avg Loss: {total_loss / len(loader):.4f}")

# Save model
SAVE_PATH = "/kaggle/working/mobilevit_model256.pt"
torch.save({
    "model": model.state_dict()
}, SAVE_PATH)

print(f"Model saved to {SAVE_PATH}")

Images: torch.Size([40000, 1, 256, 256]), Tokens: torch.Size([40000, 159]), Vocab size: 66
Trainable params: 1,955,906 / 6,893,538
  Batch 0/2500 | Loss: 4.1920
  Batch 100/2500 | Loss: 1.4481
  Batch 200/2500 | Loss: 1.2579
  Batch 300/2500 | Loss: 1.4113
  Batch 400/2500 | Loss: 1.2067
  Batch 500/2500 | Loss: 1.0980
  Batch 600/2500 | Loss: 1.1920
  Batch 700/2500 | Loss: 1.3405
  Batch 800/2500 | Loss: 1.3882
  Batch 900/2500 | Loss: 1.1913
  Batch 1000/2500 | Loss: 1.3010
  Batch 1100/2500 | Loss: 1.0069
  Batch 1200/2500 | Loss: 1.1402
  Batch 1300/2500 | Loss: 1.0790
  Batch 1400/2500 | Loss: 1.1403
  Batch 1500/2500 | Loss: 1.0927
  Batch 1600/2500 | Loss: 1.0478
  Batch 1700/2500 | Loss: 1.2688
  Batch 1800/2500 | Loss: 1.0919
  Batch 1900/2500 | Loss: 1.1032
  Batch 2000/2500 | Loss: 1.0559
  Batch 2100/2500 | Loss: 0.9818
  Batch 2200/2500 | Loss: 1.1037
  Batch 2300/2500 | Loss: 1.0657
  Batch 2400/2500 | Loss: 1.3239
Epoch 1/5 | Avg Loss: 1.1597
  Batch 0/2500 | Loss: 1.23

## 4. Evals

In [9]:
def normalized_edit_distance(s1, s2):
    if len(s1) == 0 and len(s2) == 0:
        return 0.0
    if len(s1) == 0 or len(s2) == 0:
        return 1.0
    
    # Levenshtein distance
    d = [[0] * (len(s2) + 1) for _ in range(len(s1) + 1)]
    for i in range(len(s1) + 1):
        d[i][0] = i
    for j in range(len(s2) + 1):
        d[0][j] = j
    
    for i in range(1, len(s1) + 1):
        for j in range(1, len(s2) + 1):
            cost = 0 if s1[i-1] == s2[j-1] else 1
            d[i][j] = min(d[i-1][j] + 1, d[i][j-1] + 1, d[i-1][j-1] + cost)
    
    return d[len(s1)][len(s2)] / max(len(s1), len(s2))

In [10]:
import torch
from pickle import load

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open("/kaggle/working/vocab_size.txt") as f:
    VOCAB_SIZE = int(f.read().strip())

START_TOKEN = VOCAB_SIZE - 2
END_TOKEN = VOCAB_SIZE - 1
MAX_LEN = 150

# Load model
model = ViTLatexModel(vocab_size=VOCAB_SIZE).to(DEVICE)
checkpoint = torch.load("/kaggle/working/mobilevit_model256.pt", map_location=DEVICE)
model.load_state_dict(checkpoint["model"])
model.eval()

# Load validation data
images = torch.load("/kaggle/working/images_val256.pt")
tokens = torch.load("/kaggle/working/tokens_val256.pt")

with open("/kaggle/working/latex_tokenizer256.pkl", "rb") as f:
    tokenizer = load(f)

inv_vocab = {v: k for k, v in tokenizer.word_index.items()}

def decode(seq):
    # Filter out start and end tokens
    filtered = [t for t in seq if t != START_TOKEN and t != END_TOKEN and t != 0]
    return "".join(inv_vocab.get(t, "") for t in filtered)

# Inference
N = 50
exact_matches = 0
total_edit_dist = 0.0

print(f"Evaluating on {N} test samples...")
print("-" * 60)

for i in range(N):
    img = images[i:i+1]  # (1, 1, 256, 256)
    img = img.repeat(1, 3, 1, 1).to(DEVICE)  # (1, 3, 256, 256)
    gt_tokens = tokens[i]
    
    pred_tokens = model.generate(img, max_len=MAX_LEN, sos_idx=START_TOKEN, eos_idx=END_TOKEN)
    
    ground_truth = decode(gt_tokens.tolist())
    prediction = decode(pred_tokens)
    
    is_exact = prediction == ground_truth
    edit_dist = normalized_edit_distance(prediction, ground_truth)
    
    if is_exact:
        exact_matches += 1
    total_edit_dist += edit_dist
    
    status = "EXACT" if is_exact else f"edit_dist={edit_dist:.4f}"
    print(f"  [{i+1}/{N}] {status}")
    print(f"    GT:   {ground_truth[:80]}")
    print(f"    PRED: {prediction[:80]}")
    print("-"*40)

accuracy = exact_matches / N
avg_edit_dist = total_edit_dist / N

print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"Model:                    MobileViT + LSTM")
print(f"Samples:                  {N}")
print(f"Exact match accuracy:     {accuracy:.2%} ({exact_matches}/{N})")
print(f"Avg normalized edit dist: {avg_edit_dist:.4f}")

Evaluating on 50 test samples...
------------------------------------------------------------
  [1/50] edit_dist=0.2703
    GT:   \frac{\partial\psi}{\partial t}=p\psi
    PRED: \frac{\partial f}{\partial z}=0
----------------------------------------
  [2/50] edit_dist=0.7586
    GT:   c_{0}=d/v_{blood}
    PRED: \frac{d}{dt}n_{2}=-c_{1}n_{1}
----------------------------------------
  [3/50] edit_dist=0.9600
    GT:   e^{2}/\hbar c\approx1/137
    PRED: \{\sqrt{\pi},2\pi\sigma\}
----------------------------------------
  [4/50] edit_dist=0.8214
    GT:   nassoc(s,\overline{s})
    PRED: \tilde{\phi}(x)=\tilde{x}(t)
----------------------------------------
  [5/50] edit_dist=0.8065
    GT:   eq.1\frac{dv}{dx}=w
    PRED: \frac{\partial f}{\partial z}=0
----------------------------------------
  [6/50] edit_dist=0.7692
    GT:   \underline{r}
    PRED: \mathbb{n}
----------------------------------------
  [7/50] edit_dist=0.8000
    GT:   \hat{k}_{g}
    PRED: c_{\tilde{\nu}}
-----------