# DINOv2 Encoder (LoRA) + Transformer Decoder Training
LoRA-adapted `facebook/dinov2-base` encoder with a trainable Transformer decoder for handwritten math → LaTeX.

In [None]:
import numpy as np
import pickle
import random
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from transformers import ViTModel
from pathlib import Path
from datasets import load_from_disk, load_dataset
from peft import LoraConfig, get_peft_model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
from transformers import AutoModel

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

## 1. Load & preprocess dataset

In [None]:
# Build the training artifacts: resized RGB images (uint8), padded token-id
# sequences, the fitted character tokenizer and the vocabulary size.
print("Loading dataset...")
ds = load_dataset("deepcopy/MathWriting-human")
num_samples = 40000
ds_train = ds["train"].select(range(num_samples))

# 1. Pre-allocate the image array so samples are written in place instead of
#    being collected in a list and copied. uint8 keeps this at
#    num_samples * 224 * 224 * 3 bytes (about 6 GB for 40k samples).
print(f"Pre-allocating memory for {num_samples} images...")
images_array = np.zeros((num_samples, 224, 224, 3), dtype=np.uint8)
latex_strings = []

# 2. Process images and collect the LaTeX strings.
print("Processing images and LaTeX strings...")
for i, sample in enumerate(ds_train):
    # Convert and resize directly into the pre-allocated array.
    img = sample["image"].convert("RGB").resize((224, 224))
    images_array[i] = np.array(img, dtype=np.uint8)
    latex_strings.append(sample["latex"])

    if (i + 1) % 5000 == 0:
        print(f"Progress: {i + 1}/{num_samples}")

# 3. Fit the character-level tokenizer.
# lower=False is required: the Keras Tokenizer lowercases text by default,
# even in char-level mode, which would merge case-distinct LaTeX such as
# "X"/"x" or "\Delta"/"\delta" into the same token sequence.
print("Fitting tokenizer...")
tokenizer = Tokenizer(char_level=True, lower=False)
tokenizer.fit_on_texts(latex_strings)

# Append the special tokens after the fitted characters, so that
# <START> == vocab_size - 2 and <END> == vocab_size - 1.
tokenizer.word_index["<START>"] = len(tokenizer.word_index) + 1
tokenizer.word_index["<END>"] = len(tokenizer.word_index) + 1
tokenizer.index_word[tokenizer.word_index["<START>"]] = "<START>"
tokenizer.index_word[tokenizer.word_index["<END>"]] = "<END>"

START_ID = tokenizer.word_index["<START>"]
END_ID   = tokenizer.word_index["<END>"]

# 4. Tokenize, wrap with <START>/<END>, and right-pad with 0 to the longest
#    sequence in this split.
print("Tokenizing and padding sequences...")
sequences = tokenizer.texts_to_sequences(latex_strings)
sequences = [[START_ID] + seq + [END_ID] for seq in sequences]
padded_sequences = pad_sequences(sequences, padding="post")

# 5. Save the tokenizer and vocabulary size for the later cells.
print("Saving metadata...")
with open("/kaggle/working/latex_tokenizer.pkl", "wb") as f:
    pickle.dump(tokenizer, f)

# +1 because index 0 is reserved for padding.
vocab_size = len(tokenizer.word_index) + 1
with open("/kaggle/working/vocab_size.txt", "w") as f:
    f.write(str(vocab_size))

# 6. Convert to tensors and save.
print("Converting to Tensors...")
# torch.from_numpy shares memory with the NumPy array (no RAM copy);
# permute reorders (N, H, W, C) -> (N, C, H, W).
images_tensor = torch.from_numpy(images_array).permute(0, 3, 1, 2)
tokens_tensor = torch.tensor(padded_sequences, dtype=torch.long)

print("Saving tensors to disk (this takes a minute)...")
torch.save(images_tensor, "/kaggle/working/images_train.pt")
torch.save(tokens_tensor, "/kaggle/working/tokens_train.pt")

print("Done!")
print(f"Final Vocab Size: {vocab_size}")
print(f"Image Tensor Shape: {images_tensor.shape}")


In [None]:
# Build the validation tensors with the tokenizer fitted on the training split.
ds = load_dataset("deepcopy/MathWriting-human")

ds_val = ds["val"].select(range(5000))


def preprocess_image(img, target_size=(224, 224)):
    """Return *img* as an RGB array resized to *target_size*, scaled to [0, 1].

    Args:
        img: PIL image (anything exposing ``convert`` and ``resize``).
        target_size: ``(width, height)`` passed to ``resize``.

    Returns:
        Float NumPy array of shape ``(height, width, 3)`` with values in [0, 1].
    """
    img = img.convert("RGB")  # force three channels
    img = img.resize(target_size)
    img = np.array(img) / 255.0  # normalize to [0, 1]
    return img


# Write each image straight into a pre-allocated float32 array. Collecting
# float64 arrays in a list and stacking them afterwards needs several times
# the memory of the final tensor.
images = np.zeros((len(ds_val), 224, 224, 3), dtype=np.float32)
sequences = []

for i, sample in enumerate(ds_val):
    images[i] = preprocess_image(sample["image"])
    sequences.append(sample["latex"])

# NOTE: pickle executes arbitrary code on load; only load the tokenizer file
# written by the training preprocessing cell.
with open("/kaggle/working/latex_tokenizer.pkl", "rb") as f:
    tokenizer = pickle.load(f)

START_ID = tokenizer.word_index["<START>"]
END_ID   = tokenizer.word_index["<END>"]

# Tokenize, wrap with <START>/<END>, and right-pad with 0 to the longest
# sequence in this split.
seqs = tokenizer.texts_to_sequences(sequences)
seqs = [[START_ID] + s + [END_ID] for s in seqs]

padded_sequences = pad_sequences(seqs, padding="post")

# from_numpy shares memory with the array; permute gives (N, C, H, W).
images_tensor = torch.from_numpy(images).permute(0, 3, 1, 2)
tokens_tensor = torch.tensor(padded_sequences, dtype=torch.long)

torch.save(images_tensor, "/kaggle/working/images_val.pt")
torch.save(tokens_tensor, "/kaggle/working/tokens_val.pt")
print("Images:", images_tensor.shape)
print("Tokens:", tokens_tensor.shape)

## 2. Model definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerDecoder(nn.Module):
    """Causal Transformer decoder that attends over encoder features.

    Token ids are embedded, summed with a learned positional embedding, run
    through a stack of ``nn.TransformerDecoderLayer`` blocks with
    cross-attention over the (optionally projected) encoder memory, and
    projected to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_dim=256, encoder_dim=768, nhead=8, num_layers=4, dim_feedforward=512, dropout=0.1, max_len=150):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim)
        # Learned positions; sequences longer than max_len cannot be indexed.
        self.pos_embedding = nn.Embedding(max_len, embed_dim)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=embed_dim,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,  # inputs are (B, T, D)
            ),
            num_layers=num_layers,
        )

        # Map encoder features to the decoder width only when they differ.
        needs_projection = encoder_dim != embed_dim
        self.enc_proj = nn.Linear(encoder_dim, embed_dim) if needs_projection else nn.Identity()

        self.fc_out = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.max_len = max_len

    def forward(self, x, enc_mem, tgt_mask=None, enc_mask=None):
        """Return vocabulary logits for every target position.

        Args:
            x: Token ids of shape ``(B, T)``.
            enc_mem: Encoder output; its last dimension must be ``encoder_dim``.
            tgt_mask: Optional target attention mask; a causal mask is built
                when omitted.
            enc_mask: Optional mask over encoder positions where 0 marks
                positions to ignore.

        Returns:
            Tensor of shape ``(B, T, vocab_size)``.
        """
        batch_size, seq_len = x.shape
        position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, seq_len)

        hidden = self.dropout(self.embed_tokens(x) + self.pos_embedding(position_ids))
        memory = self.enc_proj(enc_mem)

        # Causal mask: each position may attend only to itself and earlier ones.
        if tgt_mask is None:
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)

        # The decoder expects True where memory positions should be skipped.
        memory_padding = None if enc_mask is None else (enc_mask == 0)

        hidden = self.transformer_decoder(
            hidden,
            memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_padding,
        )
        return self.fc_out(hidden)

In [None]:
class ViTLatexModelLoRA(nn.Module):
    """Image-to-LaTeX model: LoRA-adapted DINOv2 encoder + Transformer decoder.

    The encoder is ``facebook/dinov2-base`` wrapped with PEFT LoRA adapters.
    Its token features are projected to the decoder width and used as
    cross-attention memory for a causal ``nn.TransformerDecoder``.
    """

    def __init__(self, vocab_size, embed_dim=256, nhead=8, num_layers=6,
                 lora_r=16, lora_alpha=32, lora_dropout=0.05, dim_feedforward=512, dropout=0.1):
        super().__init__()

        # Encoder: pretrained backbone wrapped with LoRA adapters.
        backbone = AutoModel.from_pretrained("facebook/dinov2-base")
        adapter_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias="none",
            target_modules=["query", "key", "value", "dense", "fc1", "fc2"],
            task_type="FEATURE_EXTRACTION",  # safe default for encoder-only usage
        )
        self.encoder = get_peft_model(backbone, adapter_config)
        encoder_dim = self.encoder.config.hidden_size

        # Decoder input: token embedding plus a learned positional table
        # (zero-initialised, at most 1000 positions).
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = nn.Parameter(torch.zeros(1, 1000, embed_dim))

        # Decoder stack.
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=embed_dim,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        # Hidden state -> vocabulary logits.
        self.fc_out = nn.Linear(embed_dim, vocab_size)

        # Bring encoder features to the decoder width when they differ.
        needs_projection = encoder_dim != embed_dim
        self.enc_proj = nn.Linear(encoder_dim, embed_dim) if needs_projection else nn.Identity()

    def _decode(self, tokens, memory, device):
        """Run the causal decoder over *tokens* and return hidden states (B, T, E)."""
        steps = tokens.size(1)
        embedded = self.embedding(tokens) + self.pos_encoder[:, :steps, :]
        causal_mask = nn.Transformer.generate_square_subsequent_mask(steps).to(device)
        return self.transformer_decoder(tgt=embedded, memory=memory, tgt_mask=causal_mask)

    def forward(self, images, input_tokens):
        """Return teacher-forced logits.

        Args:
            images: ``(B, 3, H, W)`` pixel values passed to the encoder.
            input_tokens: ``(B, T)`` decoder input token ids.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.
        """
        features = self.encoder(pixel_values=images).last_hidden_state  # (B, N, D)
        memory = self.enc_proj(features)  # (B, N, E)
        hidden = self._decode(input_tokens, memory, images.device)  # (B, T, E)
        return self.fc_out(hidden)

    @torch.no_grad()
    def generate(self, image, max_len=150, sos_idx=1, eos_idx=2):
        """Greedily decode token ids for a single image.

        Switches the module to eval mode. Decoding starts from ``sos_idx`` and
        stops when ``eos_idx`` is predicted or after ``max_len`` steps. Only a
        batch of one image is supported, since the stop test uses ``.item()``.

        Returns:
            List of generated token ids, without the start and end tokens.
        """
        self.eval()
        device = image.device

        features = self.encoder(pixel_values=image).last_hidden_state
        memory = self.enc_proj(features)

        generated = torch.tensor([[sos_idx]], device=device)
        for _ in range(max_len):
            hidden = self._decode(generated, memory, device)
            # Only the last position predicts the next token.
            next_token = self.fc_out(hidden[:, -1, :]).argmax(dim=-1, keepdim=True)
            if next_token.item() == eos_idx:
                break
            generated = torch.cat([generated, next_token], dim=1)

        return generated[0, 1:].tolist()

## 3. Training

In [None]:
# Train the LoRA-adapted encoder and the Transformer decoder with teacher
# forcing, then save the model weights.

# Load vocab size
with open("/kaggle/working/vocab_size.txt") as f:
    VOCAB_SIZE = int(f.read().strip())

# Hyperparameters
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE_DEC = 1e-3  # decoder side (embeddings, decoder stack, projections)
LEARNING_RATE_ENC = 1e-4  # LoRA adapters in encoder

# Load pre-processed tensors. Images stay uint8 here and are converted to
# float per batch to keep host memory low.
images_tensor = torch.load("/kaggle/working/images_train.pt")  # (N, 3, 224, 224)
tokens_tensor = torch.load("/kaggle/working/tokens_train.pt")  # (N, seq_len)

print(f"Images: {images_tensor.shape}, Tokens: {tokens_tensor.shape}, Vocab size: {VOCAB_SIZE}")

# Create dataset and loader
dataset = TensorDataset(images_tensor, tokens_tensor)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model
model = ViTLatexModelLoRA(vocab_size=VOCAB_SIZE).to(DEVICE)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,}")

# Index 0 is the padding id produced by pad_sequences; exclude it from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

# Every trainable parameter must be handed to the optimizer. Passing only
# model.transformer_decoder would leave the token embedding, positional table,
# output projection and encoder projection at their initial values.
encoder_params = [p for p in model.encoder.parameters() if p.requires_grad]
decoder_params = [
    p
    for name, p in model.named_parameters()
    if p.requires_grad and not name.startswith("encoder.")
]

optimizer = torch.optim.AdamW(
    [
        {"params": decoder_params, "lr": LEARNING_RATE_DEC},
        {"params": encoder_params, "lr": LEARNING_RATE_ENC},  # LoRA adapters
    ],
    weight_decay=0.01
)

# Per-channel ImageNet statistics, shaped to broadcast over (B, 3, H, W).
mean = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1,3,1,1)
std  = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1,3,1,1)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for batch_idx, (imgs, seqs) in enumerate(loader):
        # Scale uint8 pixels to [0, 1], then apply channel-wise normalization.
        imgs = imgs.to(DEVICE, dtype=torch.float32) / 255.0
        imgs = (imgs - mean) / std

        seqs = seqs.to(DEVICE)

        # Teacher forcing: predict token t+1 from tokens up to t.
        input_tokens = seqs[:, :-1]   # (B, seq_len-1)
        target_tokens = seqs[:, 1:]   # (B, seq_len-1)

        optimizer.zero_grad()

        # Forward pass
        logits = model(imgs, input_tokens)

        # Compute loss
        loss = criterion(
            logits.reshape(-1, VOCAB_SIZE),  # (B * (seq_len-1), vocab_size)
            target_tokens.reshape(-1)         # (B * (seq_len-1))
        )

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        total_loss += loss.item()

        # Print every 100 batches
        if batch_idx % 100 == 0:
            print(f"  Batch {batch_idx}/{len(loader)} | Loss: {loss.item():.4f}")

    print(f"Epoch {epoch + 1}/{EPOCHS} | Avg Loss: {total_loss / len(loader):.4f}")

# Save model
SAVE_PATH = "/kaggle/working/dinov2_attn_lora.pt"
torch.save({
    "model": model.state_dict()
}, SAVE_PATH)

print(f"Model saved to {SAVE_PATH}")

## 4. Evals

In [None]:
def normalized_edit_distance(s1, s2):
    """Return the Levenshtein distance of *s1* and *s2* divided by the longer length.

    Two empty inputs give 0.0; exactly one empty input gives 1.0. Otherwise
    the result lies in [0, 1], where 0.0 means the sequences are identical.
    """
    len_a, len_b = len(s1), len(s2)
    if len_a == 0 and len_b == 0:
        return 0.0
    if len_a == 0 or len_b == 0:
        return 1.0

    # Only the previous DP row is needed to compute the current one.
    previous = list(range(len_b + 1))
    for row in range(1, len_a + 1):
        current = [row]
        for col in range(1, len_b + 1):
            substitution_cost = 0 if s1[row - 1] == s2[col - 1] else 1
            current.append(
                min(
                    previous[col] + 1,                      # deletion
                    current[col - 1] + 1,                   # insertion
                    previous[col - 1] + substitution_cost,  # match / substitution
                )
            )
        previous = current

    return previous[len_b] / max(len_a, len_b)

In [None]:
import torch
from pickle import load

# Evaluate the trained model on the saved validation tensors with greedy
# decoding, reporting exact-match accuracy and mean normalized edit distance.
DATA = "/kaggle/input/datasets/martinvu7/vit-transformer2"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open("/kaggle/working/vocab_size.txt") as f:
    VOCAB_SIZE = int(f.read().strip())

# The preprocessing cell appends <START> and <END> as the last two ids.
START_TOKEN = VOCAB_SIZE - 2
END_TOKEN = VOCAB_SIZE - 1
MAX_LEN = 150

# Load model
model = ViTLatexModelLoRA(vocab_size=VOCAB_SIZE).to(DEVICE)
checkpoint = torch.load(f"{DATA}/dinov2_attn_lora.pt", map_location=DEVICE)
model.load_state_dict(checkpoint["model"])
model.eval()

# Load validation data (images are stored as float values in [0, 1]).
images = torch.load("/kaggle/working/images_val.pt")
tokens = torch.load("/kaggle/working/tokens_val.pt")

# NOTE: pickle executes arbitrary code on load; only load the tokenizer file
# written by the preprocessing cell.
with open("/kaggle/working/latex_tokenizer.pkl", "rb") as f:
    tokenizer = load(f)

inv_vocab = {v: k for k, v in tokenizer.word_index.items()}

# Same channel statistics the training loop applies. Without this step the
# encoder would see inputs on a different scale than it was trained on.
EVAL_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
EVAL_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)


def decode(seq):
    """Turn token ids into a string, dropping padding, start and end tokens."""
    filtered = [t for t in seq if t != START_TOKEN and t != END_TOKEN and t != 0]
    return "".join(inv_vocab.get(t, "") for t in filtered)


# Inference
N = min(5000, images.shape[0])  # never index past the available samples
exact_matches = 0
total_edit_dist = 0.0
print(f"Evaluating on {N} test samples...")
print("-" * 60)

for i in range(N):
    img = images[i:i+1].to(DEVICE, dtype=torch.float32)  # (1, 3, 224, 224)
    img = (img - EVAL_MEAN) / EVAL_STD
    gt_tokens = tokens[i]

    pred_tokens = model.generate(img, max_len=MAX_LEN, sos_idx=START_TOKEN, eos_idx=END_TOKEN)

    ground_truth = decode(gt_tokens.tolist())
    prediction = decode(pred_tokens)

    is_exact = prediction == ground_truth
    edit_dist = normalized_edit_distance(prediction, ground_truth)

    if is_exact:
        exact_matches += 1
    total_edit_dist += edit_dist

    # Uncomment for per-sample debugging output:
    # status = "EXACT" if is_exact else f"edit_dist={edit_dist:.4f}"
    # print(f"  [{i+1}/{N}] {status}")
    # print(f"    GT:   {ground_truth[:80]}")
    # print(f"    PRED: {prediction[:80]}")
    # print("-"*40)

accuracy = exact_matches / N
avg_edit_dist = total_edit_dist / N

print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"Model:                    DinoV2 + Transformer")
print(f"Samples:                  {N}")
print(f"Exact match accuracy:     {accuracy:.2%} ({exact_matches}/{N})")
print(f"Avg normalized edit dist: {avg_edit_dist:.4f}")