In [None]:
import torch
import numpy as np
import sys
import pickle
from pathlib import Path
from datasets import load_dataset

# Put the local "models" directory first on the import path; presumably this
# is where the model-definition modules imported further down live.
sys.path.insert(0, "models")

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Directories read further down in this notebook.
ARTIFACTS = Path("artifacts")           # model checkpoints (*.pt)
BACKEND = Path("../backend/artifacts")  # LSTM tokenizer pickle and vocab_size.txt
CACHE = Path("../cache")                # transformer vocabulary pickle

def normalized_edit_distance(s1, s2):
    """Return the Levenshtein distance between ``s1`` and ``s2`` scaled to [0, 1].

    The raw edit distance (unit-cost insertions, deletions and substitutions)
    is divided by the length of the longer input.

    Args:
        s1: First sequence (anything supporting ``len`` and iteration, e.g. ``str``).
        s2: Second sequence.

    Returns:
        ``0.0`` when both inputs are empty, ``1.0`` when exactly one is empty,
        otherwise ``distance / max(len(s1), len(s2))`` as a float.
    """
    len_a, len_b = len(s1), len(s2)
    if len_a == 0 or len_b == 0:
        # Two empty inputs are identical; a single empty input is maximally distant.
        return 0.0 if len_a == len_b else 1.0

    # Only the previous DP row is needed to fill in the current one.
    previous_row = list(range(len_b + 1))
    for row_idx, item_a in enumerate(s1, start=1):
        current_row = [row_idx]
        for col_idx, item_b in enumerate(s2, start=1):
            substitution_cost = 0 if item_a == item_b else 1
            current_row.append(
                min(
                    previous_row[col_idx] + 1,                      # deletion
                    current_row[col_idx - 1] + 1,                   # insertion
                    previous_row[col_idx - 1] + substitution_cost,  # match / substitution
                )
            )
        previous_row = current_row
    return previous_row[len_b] / max(len_a, len_b)

# Report the selected device so the run output records whether a GPU was used.
print(f"Device: {DEVICE}")

In [None]:
# Number of samples to score; both evaluation loops iterate over range(N).
N = 50
ds = load_dataset("deepcopy/MathWriting-human")
# Keep the first N rows of the "val" split (the variable is named ds_test,
# but the data comes from the validation split).
ds_test = ds["val"].select(range(N))
# Read the label column directly instead of iterating rows: row iteration
# materialises every field of each sample (including the image) just to
# extract one string per sample.
gt_latex = list(ds_test["latex"])
print(f"Loaded {N} samples from validation split")
print(f"Available splits: {list(ds.keys())}")

## Eval 1 — DINOv2 + LSTM (Bahdanau Attention)

In [None]:
from vit_lora_lstm_attn import ViTLatexModelLoRA as LSTMModel

# Tokenizer and special-token ids, read from the backend's artifact directory
# (the original note states this mirrors backend/app.py — not verifiable here).
# NOTE(review): pickle.load can execute arbitrary code; only use a trusted file.
with (BACKEND / "latex_tokenizer256.pkl").open("rb") as tokenizer_file:
    lstm_tok = pickle.load(tokenizer_file)
VOCAB_SIZE = int((BACKEND / "vocab_size.txt").read_text().strip())

# The last two ids of the vocabulary serve as the start / end markers.
# The original inline notes give 64 and 65, which holds only for VOCAB_SIZE == 66.
LSTM_END = VOCAB_SIZE - 1
LSTM_START = LSTM_END - 1

# Reverse lookup built by inverting the tokenizer's word_index: id -> token.
lstm_inv_vocab = {token_id: token for token, token_id in lstm_tok.word_index.items()}
print(f"vocab_size={VOCAB_SIZE}, START={LSTM_START}, END={LSTM_END}")

def lstm_decode(seq):
    """Convert a sequence of LSTM token ids into a string.

    Ids equal to 0, ``LSTM_START`` or ``LSTM_END`` are dropped (0 is presumably
    the padding id — confirm against the tokenizer). Ids that are missing from
    ``lstm_inv_vocab`` contribute an empty string. The remaining tokens are
    concatenated without a separator.
    """
    ignored_ids = (0, LSTM_START, LSTM_END)
    pieces = []
    for token_id in seq:
        if token_id in ignored_ids:
            continue
        pieces.append(lstm_inv_vocab.get(token_id, ""))
    return "".join(pieces)

# Build the model on the target device and restore its weights.
lstm_model = LSTMModel(vocab_size=VOCAB_SIZE, lora_r=16).to(DEVICE)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only load checkpoints that come from a trusted source.
lstm_ckpt_path = ARTIFACTS / "dinov2_attn_lora256.pt"
ckpt = torch.load(lstm_ckpt_path, map_location=DEVICE, weights_only=False)
lstm_model.load_state_dict(ckpt["model"])
lstm_model.eval()
print("LSTM model loaded.")

# Per-channel mean / std shaped (1, 3, 1, 1) so they broadcast over an
# (N, 3, H, W) batch. The original note says this matches training: RGB
# 224x224, pixels divided by 255, then ImageNet normalisation.
MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).reshape(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).reshape(1, 3, 1, 1)

# Convert each sample to RGB, resize to 224x224 and keep the raw uint8 pixels;
# scaling and normalisation are applied later, per image.
lstm_pixels = np.stack([
    np.array(sample["image"].convert("RGB").resize((224, 224)), dtype=np.uint8)
    for sample in ds_test
])
# Channels-last -> channels-first: uint8 tensor of shape (N, 3, 224, 224).
lstm_images = torch.from_numpy(lstm_pixels).permute(0, 3, 1, 2)
print(f"Preprocessed {N} images.")

In [None]:
# Score the LSTM model one sample at a time, accumulating the exact-match
# count and the sum of normalized edit distances.
exact = 0
total_ed = 0.0
print(f"Evaluating LSTM on {N} test samples...")
print("-" * 60)

for i in range(N):
    # Inference only: run preprocessing and generation with autograd disabled
    # so no graph is recorded through parameters that require grad.
    with torch.no_grad():
        # Slice keeps the batch dimension; scale uint8 pixels to [0, 1] and
        # apply the per-channel normalisation.
        img = lstm_images[i:i+1].to(DEVICE, dtype=torch.float32) / 255.0
        img = (img - MEAN) / STD
        pred_tokens = lstm_model.generate(img, max_len=150, sos_idx=LSTM_START, eos_idx=LSTM_END)
    prediction   = lstm_decode(pred_tokens)
    ground_truth = gt_latex[i]

    is_exact  = prediction == ground_truth
    edit_dist = normalized_edit_distance(prediction, ground_truth)
    if is_exact:
        exact += 1
    total_ed += edit_dist

    # Per-sample report; strings are truncated to 80 characters for display only.
    status = "EXACT" if is_exact else f"edit_dist={edit_dist:.4f}"
    print(f"  [{i+1:2d}/{N}] {status}")
    print(f"    GT:   {ground_truth[:80]}")
    print(f"    PRED: {prediction[:80]}")
    print("-" * 40)

# Aggregate metrics over all N samples.
print("=" * 60)
print(f"[LSTM] Exact match:         {exact/N:.2%} ({exact}/{N})")
print(f"[LSTM] Avg normalized edit: {total_ed/N:.4f}")

In [None]:
import importlib

# Transformer-decoder evaluation setup.
# If the module was imported earlier in this kernel, reload it so that edits
# to its source are picked up without a restart.
if "vit_transformer" in sys.modules:
    importlib.reload(sys.modules["vit_transformer"])
import vit_transformer as vt_mod

# Vocabulary: a pickled mapping inverted below into id -> token.
# NOTE(review): pickle.load can execute arbitrary code; only use a trusted file.
with (CACHE / "latex_vocab.pkl").open("rb") as vocab_file:
    tr_vocab = pickle.load(vocab_file)
tr_inv_vocab = {token_id: token for token, token_id in tr_vocab.items()}
TR_VOCAB_SIZE = len(tr_vocab)

# Start / end marker ids are hard-coded here.
# NOTE(review): assumes the vocabulary reserves ids 1 and 2 for them — confirm.
TR_START = 1
TR_END = 2

# The module reads a global `inv_vocab`; per the original note it is used for
# brace-depth tracking, so it is injected before any generation happens.
vt_mod.inv_vocab = tr_inv_vocab
print(f"Transformer vocab_size={TR_VOCAB_SIZE}")

def tr_decode(seq):
    """Convert a sequence of transformer token ids into a string.

    Ids equal to 0, ``TR_START`` or ``TR_END`` are dropped (0 is presumably the
    padding id — confirm against the vocabulary). Ids that are missing from
    ``tr_inv_vocab`` contribute an empty string. The remaining tokens are
    concatenated without a separator.
    """
    ignored_ids = (0, TR_START, TR_END)
    pieces = []
    for token_id in seq:
        if token_id in ignored_ids:
            continue
        pieces.append(tr_inv_vocab.get(token_id, ""))
    return "".join(pieces)

# Build the transformer-decoder model on the target device and restore its weights.
tr_model = vt_mod.ViTLatexModelLoRA(
    vocab_size=TR_VOCAB_SIZE,
    decoder_type="latex_transformer",
).to(DEVICE)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only load checkpoints that come from a trusted source.
tr_ckpt_path = ARTIFACTS / "dinov2_latex_transformer.pt"
tr_ckpt = torch.load(tr_ckpt_path, map_location=DEVICE, weights_only=False)
# Accept either a wrapper dict holding the weights under "model" or a bare
# state dict saved directly.
if isinstance(tr_ckpt, dict) and "model" in tr_ckpt:
    tr_state = tr_ckpt["model"]
else:
    tr_state = tr_ckpt
tr_model.load_state_dict(tr_state)
tr_model.eval()
print("Transformer model loaded.")

# Per-channel mean / std shaped (1, 3, 1, 1) so they broadcast over an
# (N, 3, H, W) batch; they are applied per image in the evaluation loop.
TR_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).reshape(1, 3, 1, 1)
TR_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).reshape(1, 3, 1, 1)

# Convert each sample to RGB, resize to 224x224 and keep the raw uint8 pixels.
tr_pixels = np.stack([
    np.array(sample["image"].convert("RGB").resize((224, 224)), dtype=np.uint8)
    for sample in ds_test
])
# Channels-last -> channels-first: uint8 tensor of shape (N, 3, 224, 224).
tr_images = torch.from_numpy(tr_pixels).permute(0, 3, 1, 2)
print(f"Preprocessed {N} images.")

In [None]:
# Score the transformer model one sample at a time, accumulating the
# exact-match count and the sum of normalized edit distances.
tr_exact = 0
tr_total_ed = 0.0
print(f"Evaluating Transformer on {N} test samples...")
print("-" * 60)

for i in range(N):
    # Inference only: run preprocessing and generation with autograd disabled
    # so no graph is recorded through parameters that require grad.
    with torch.no_grad():
        # Slice keeps the batch dimension; scale uint8 pixels to [0, 1] and
        # apply the per-channel normalisation.
        img = tr_images[i:i+1].to(DEVICE, dtype=torch.float32) / 255.0
        img = (img - TR_MEAN) / TR_STD
        pred_tokens = tr_model.generate(img, max_len=128, sos_idx=TR_START, eos_idx=TR_END)
    prediction   = tr_decode(pred_tokens)
    ground_truth = gt_latex[i]

    is_exact  = prediction == ground_truth
    edit_dist = normalized_edit_distance(prediction, ground_truth)
    if is_exact:
        tr_exact += 1
    tr_total_ed += edit_dist

    # Per-sample report; strings are truncated to 80 characters for display only.
    status = "EXACT" if is_exact else f"edit_dist={edit_dist:.4f}"
    print(f"  [{i+1:2d}/{N}] {status}")
    print(f"    GT:   {ground_truth[:80]}")
    print(f"    PRED: {prediction[:80]}")
    print("-" * 40)

# Aggregate metrics over all N samples.
print("=" * 60)
print(f"[Transformer] Exact match:         {tr_exact/N:.2%} ({tr_exact}/{N})")
print(f"[Transformer] Avg normalized edit: {tr_total_ed/N:.4f}")