In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoModel
from peft import LoraConfig, get_peft_model
from tensorflow.keras.preprocessing.text import Tokenizer as KerasTokenizer
from datasets import load_dataset

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

def normalized_edit_distance(s1, s2):
    """Return the Levenshtein distance between ``s1`` and ``s2`` scaled to [0, 1].

    The raw edit distance (unit-cost insertion, deletion and substitution) is
    divided by the length of the longer input.

    Args:
        s1: First sequence; anything supporting ``len`` and integer indexing
            (strings in this notebook).
        s2: Second sequence of the same kind.

    Returns:
        float: ``0.0`` when both inputs are empty, ``1.0`` when exactly one of
        them is empty, otherwise ``distance / max(len(s1), len(s2))``.
    """
    len1, len2 = len(s1), len(s2)
    if len1 == 0 and len2 == 0:
        return 0.0
    if len1 == 0 or len2 == 0:
        return 1.0

    # Rolling-row DP: ``prev_row[j]`` is the distance between the prefix of
    # ``s1`` handled so far and ``s2[:j]``. Row 0 is "insert j characters".
    prev_row = list(range(len2 + 1))
    for i in range(1, len1 + 1):
        left_item = s1[i - 1]
        # Column 0 is "delete i characters".
        row = [i]
        for j in range(1, len2 + 1):
            substitution = prev_row[j - 1] + (0 if left_item == s2[j - 1] else 1)
            deletion = prev_row[j] + 1
            insertion = row[j - 1] + 1
            row.append(min(deletion, insertion, substitution))
        prev_row = row

    return prev_row[len2] / max(len1, len2)

class BahdanauAttention(nn.Module):
    """Additive attention of a decoder state over a sequence of encoder vectors.

    The score of each memory position is ``v^T tanh(W_h h + W_e e)``; scores
    are softmax-normalised over the sequence axis and used to average the
    encoder memory.

    Args:
        hidden_dim: Size of the decoder state ``h`` and of the internal
            projection space.
        encoder_dim: Feature size of each encoder memory vector.
    """

    def __init__(self, hidden_dim, encoder_dim):
        super().__init__()
        # Attribute names are part of the checkpoint's state_dict keys.
        self.W_h = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_e = nn.Linear(encoder_dim, hidden_dim, bias=False)
        self.v   = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, h, enc_mem):
        """Compute the attention context for one decoder state.

        Args:
            h: Decoder state, ``(batch, hidden_dim)``.
            enc_mem: Encoder memory, ``(batch, seq, encoder_dim)``.

        Returns:
            tuple: ``(context, alpha)`` where ``context`` is
            ``(batch, encoder_dim)`` and ``alpha`` is ``(batch, seq)`` with
            rows summing to 1.
        """
        # Broadcast the projected query across every memory position.
        query_proj = self.W_h(h).unsqueeze(1)
        memory_proj = self.W_e(enc_mem)
        scores = self.v(torch.tanh(query_proj + memory_proj)).squeeze(-1)

        alpha = F.softmax(scores, dim=-1)
        # (batch, 1, seq) @ (batch, seq, encoder_dim) -> (batch, 1, encoder_dim)
        weighted = torch.bmm(alpha.unsqueeze(1), enc_mem)
        context = weighted.squeeze(1)
        return context, alpha

class AttnDecoder(nn.Module):
    """LSTM-cell decoder that attends over encoder memory at every step.

    Each step concatenates the current token embedding with an attention
    context computed from the previous hidden state, feeds that to an
    ``nn.LSTMCell``, and projects the new hidden state to vocabulary logits.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        embed_dim: Token embedding size.
        hidden_dim: LSTM hidden/cell state size.
        encoder_dim: Feature size of the encoder memory vectors.
        dropout: Dropout probability applied to the token embeddings.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, encoder_dim=768, dropout=0.1):
        super().__init__()
        # Attribute names double as state_dict keys, so they stay as they are.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.attn = BahdanauAttention(hidden_dim, encoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_dim + encoder_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        # Projections used to initialise (h, c) from the pooled encoder memory.
        self.enc_to_h = nn.Linear(encoder_dim, hidden_dim)
        self.enc_to_c = nn.Linear(encoder_dim, hidden_dim)

    def forward(self, input_ids, enc_mem, hidden_state=None):
        """Run the decoder over ``input_ids`` one time step at a time.

        Args:
            input_ids: Integer token ids, ``(batch, steps)``.
            enc_mem: Encoder memory, ``(batch, seq, encoder_dim)``.
            hidden_state: Optional ``(h, c)`` pair to resume from. When
                ``None`` the state is initialised from the mean of
                ``enc_mem`` over the sequence axis.

        Returns:
            tuple: ``(logits, (h, c))`` where ``logits`` is
            ``(batch, steps, vocab_size)`` and ``(h, c)`` is the state after
            the last step.
        """
        embedded = self.dropout(self.embedding(input_ids))

        if hidden_state is not None:
            h, c = hidden_state
        else:
            memory_mean = enc_mem.mean(dim=1)
            h = torch.tanh(self.enc_to_h(memory_mean))
            c = torch.tanh(self.enc_to_c(memory_mean))

        per_step_logits = []
        for step_embedding in embedded.unbind(dim=1):
            # Attention is queried with the state *before* this step's update.
            context, _ = self.attn(h, enc_mem)
            cell_input = torch.cat((step_embedding, context), dim=1)
            h, c = self.lstm_cell(cell_input, (h, c))
            per_step_logits.append(self.fc(h))

        return torch.stack(per_step_logits, dim=1), (h, c)

class ViTLatexModelLoRA(nn.Module):
    """Image-to-LaTeX model: LoRA-adapted DINOv2 encoder + attention LSTM decoder.

    The encoder is ``facebook/dinov2-base`` wrapped by PEFT with LoRA adapters
    on the ``query``/``key``/``value``/``dense``/``fc1``/``fc2`` modules. The
    decoder is an ``AttnDecoder`` whose ``encoder_dim`` is taken from the
    encoder's ``config.hidden_size``.

    Args:
        vocab_size: Number of token ids the decoder predicts over.
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: Dropout probability inside the LoRA adapters.
    """

    def __init__(self, vocab_size, lora_r=16, lora_alpha=32, lora_dropout=0.05):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("facebook/dinov2-base")
        lora_cfg = LoraConfig(
            r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, bias="none",
            target_modules=["query", "key", "value", "dense", "fc1", "fc2"],
            task_type="FEATURE_EXTRACTION",
        )
        self.encoder = get_peft_model(self.encoder, lora_cfg)
        self.decoder = AttnDecoder(vocab_size, encoder_dim=self.encoder.config.hidden_size)

    def _encode(self, pixel_values):
        """Return the encoder's token features with the first token removed."""
        # Index 0 is dropped so the decoder only attends over the remaining
        # tokens (presumably the CLS token precedes the patch tokens; confirm
        # against the DINOv2 model documentation).
        return self.encoder(pixel_values=pixel_values).last_hidden_state[:, 1:, :]

    def forward(self, images, input_tokens):
        """Compute decoder logits for ``input_tokens`` conditioned on ``images``.

        Args:
            images: Pixel tensor passed to the encoder as ``pixel_values``.
            input_tokens: Decoder input token ids, ``(batch, steps)``.

        Returns:
            Tensor of logits, ``(batch, steps, vocab_size)``.
        """
        enc_mem = self._encode(images)
        logits, _ = self.decoder(input_tokens, enc_mem)
        return logits

    @torch.no_grad()
    def generate(self, image, max_len=150, sos_idx=1, eos_idx=2):
        """Greedily decode a token sequence for a single image.

        Decoding runs in eval mode; the train/eval flag of every submodule is
        put back afterwards, so calling this in the middle of a training loop
        does not leave dropout disabled.

        Args:
            image: Pixel tensor holding exactly one image, ``(1, C, H, W)``.
            max_len: Maximum number of tokens to emit.
            sos_idx: Token id fed as the first decoder input.
            eos_idx: Token id that stops decoding; it is not included in the
                returned list.

        Returns:
            list[int]: Generated token ids, without the start/end tokens.

        Raises:
            ValueError: If ``image`` is not a 4-D tensor with batch size 1.
        """
        if image.dim() != 4 or image.shape[0] != 1:
            raise ValueError(
                "generate() decodes one image at a time and expects a tensor of "
                f"shape (1, C, H, W); got {tuple(image.shape)}"
            )

        # Snapshot per-module modes (parents come before children in
        # ``modules()``), so mixed train/eval setups are restored exactly.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            enc_mem = self._encode(image)
            token = torch.tensor([[sos_idx]], device=image.device)
            output_tokens, hidden = [], None
            for _ in range(max_len):
                logits, hidden = self.decoder(token, enc_mem, hidden_state=hidden)
                next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                # One host sync per step instead of two.
                next_id = next_token.item()
                if next_id == eos_idx:
                    break
                output_tokens.append(next_id)
                token = next_token
            return output_tokens
        finally:
            for module, was_training in previous_modes:
                module.train(was_training)

In [None]:
# Number of test samples to evaluate.
N = 50
ds = load_dataset("deepcopy/MathWriting-human")
ds_train = ds["train"].select(range(40000))
ds_test  = ds["test"].select(range(N))

# Rebuild the same Keras char-level tokenizer used during training.
# Read the "latex" column directly: iterating the dataset row by row would
# also decode all 40,000 images just to get at the text.
# NOTE(review): KerasTokenizer defaults to lower=True, so the vocabulary
# contains no uppercase characters; confirm this matches the training setup.
train_latex = list(ds_train["latex"])
tokenizer = KerasTokenizer(char_level=True)
tokenizer.fit_on_texts(train_latex)
# The special tokens are appended after the character vocabulary.
tokenizer.word_index["<START>"] = len(tokenizer.word_index) + 1
tokenizer.word_index["<END>"]   = len(tokenizer.word_index) + 1
tokenizer.index_word = {v: k for k, v in tokenizer.word_index.items()}

VOCAB_SIZE   = len(tokenizer.word_index) + 1  # 66 (index 0 is reserved for padding)
START_TOKEN  = tokenizer.word_index["<START>"]  # 64
END_TOKEN    = tokenizer.word_index["<END>"]    # 65
print(f"vocab_size={VOCAB_SIZE}, START={START_TOKEN}, END={END_TOKEN}")

# Preprocess test images: grayscale 256x256, [0,1] float
image_arrays, gt_latex = [], []
for sample in ds_test:
    gray = sample["image"].convert("L").resize((256, 256))
    image_arrays.append(np.array(gray, dtype=np.float32) / 255.0)
    gt_latex.append(sample["latex"])
# Add a channel axis so each sample is a single-channel image.
images = torch.tensor(np.array(image_arrays)).unsqueeze(1)  # (N, 1, 256, 256)

def decode(seq):
    """Convert a sequence of token ids back into a LaTeX string.

    Ids equal to 0, ``START_TOKEN`` or ``END_TOKEN`` are skipped, and ids that
    are missing from ``tokenizer.index_word`` contribute an empty string.

    Args:
        seq: Iterable of integer token ids.

    Returns:
        str: The concatenated characters for the remaining ids.
    """
    skipped_ids = (0, START_TOKEN, END_TOKEN)
    pieces = []
    for token_id in seq:
        if token_id in skipped_ids:
            continue
        pieces.append(tokenizer.index_word.get(token_id, ""))
    return "".join(pieces)

In [None]:
# SECURITY: torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load("dinov2_attn_lora256.pt", map_location=DEVICE)

# Rebuild the architecture with the same vocabulary size and LoRA rank, then
# restore the weights stored under the checkpoint's "model" key.
model = ViTLatexModelLoRA(vocab_size=VOCAB_SIZE, lora_r=16)
model = model.to(DEVICE)
model.load_state_dict(checkpoint["model"])
model = model.eval()  # inference mode: disables dropout
print("Model loaded.")

In [None]:
# Running totals over the N evaluated samples.
exact_matches = 0
total_edit_dist = 0.0

print(f"Evaluating on {N} test samples...")
print("-" * 60)

for idx in range(N):
    # grayscale [0,1] -> repeat to RGB (no ImageNet norm, matches original eval)
    gray_batch = images[idx:idx+1]
    rgb_batch = gray_batch.repeat(1, 3, 1, 1).to(DEVICE)

    pred_tokens = model.generate(
        rgb_batch, max_len=150, sos_idx=START_TOKEN, eos_idx=END_TOKEN
    )
    prediction = decode(pred_tokens)
    ground_truth = gt_latex[idx]

    edit_dist = normalized_edit_distance(prediction, ground_truth)
    total_edit_dist += edit_dist

    # Exact match is a plain string comparison against the reference LaTeX.
    if prediction == ground_truth:
        exact_matches += 1
        status = "EXACT"
    else:
        status = f"edit_dist={edit_dist:.4f}"

    # Long strings are truncated to 80 characters for display only.
    print(f"  [{idx+1}/{N}] {status}")
    print(f"    GT:   {ground_truth[:80]}")
    print(f"    PRED: {prediction[:80]}")
    print("-" * 40)

accuracy = exact_matches / N
mean_edit_dist = total_edit_dist / N

print("=" * 60)
print(f"Exact match accuracy:     {accuracy:.2%} ({exact_matches}/{N})")
print(f"Avg normalized edit dist: {mean_edit_dist:.4f}")