# Phase 3: Evaluation — 4-Axis Metrics & Attention Maps

Load the trained model and the validation data, then compute exact match, tree edit
similarity, and compilation success, and plot attention heatmaps.

**Works on:** Colab (CUDA), Mac (MPS), or CPU.

In [None]:
# --- Clone repo on Colab ---
import os
if os.path.exists("/content") and not os.path.exists("/content/DeepPTX"):
    !git clone https://github.com/ns-1456/DeepPTX.git /content/DeepPTX
    %cd /content/DeepPTX
    !pip install -q tqdm pyarrow

In [None]:
import sys, os

# Resolve the repository root: the fixed clone location when /content exists
# (Colab), otherwise the parent of the current working directory.
IN_COLAB = os.path.exists("/content")
if IN_COLAB:
    REPO_ROOT = "/content/DeepPTX"
else:
    REPO_ROOT = os.path.abspath("..")

# Put the repo at the front of the import path, but only once, so re-running
# the cell does not add duplicate entries.
if sys.path.count(REPO_ROOT) == 0:
    sys.path.insert(0, REPO_ROOT)

import torch
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm

from ptx_decompiler.utils import get_device
from ptx_decompiler.model import PTXDecompilerModel
from ptx_decompiler.tokenizer import PTXTokenizer, ASTTokenizer
from ptx_decompiler.data import normalize_ptx, ast_to_cuda
from ptx_decompiler.training.metrics import exact_match_accuracy, compute_tree_edit_distance, compile_success_rate

# Select the compute device through the project helper and confirm that the
# setup cell ran to completion.
DEVICE = get_device()
print(f"Device: {DEVICE}", "Imports OK", sep="\n")

In [None]:
# ======================= Paths =======================
# The dataset and checkpoint live under the repo root; the same layout is used
# on Colab and locally, so no environment-specific branching is needed.
DATA_PATH = os.path.join(REPO_ROOT, "dataset_100k.parquet")
CKPT_PATH = os.path.join(REPO_ROOT, "checkpoints", "checkpoint_final.pt")

# Fail with an explicit error instead of `assert`, which is skipped entirely
# when Python runs with -O.
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError(f"Dataset not found: {DATA_PATH}")
if not os.path.exists(CKPT_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {CKPT_PATH}")

# Load the full dataset; later cells read columns from this frame.
df = pd.read_parquet(DATA_PATH)
print(f"Loaded {len(df):,} samples")

# Tokenizers: the PTX vocabulary is built from the dataset's normalized PTX
# column (max_vocab_size=2000); the AST tokenizer is used as constructed.
ptx_tok = PTXTokenizer(max_vocab_size=2000)
ptx_corpus = df["ptx_normalized"].tolist()
ptx_tok.build_vocab(ptx_corpus)
ast_tok = ASTTokenizer()

# Lookup table from PTX token id to AST token id for tokens that appear in
# both vocabularies; PTX ids without an AST counterpart stay at -1.
ptx_to_ast = torch.full((len(ptx_tok),), -1, dtype=torch.long)
for token, ptx_id in ptx_tok.vocab.items():
    if token not in ast_tok.vocab:
        continue
    ptx_to_ast[ptx_id] = ast_tok.vocab[token]

# Build the network with vocabulary sizes taken from the two tokenizers, then
# move it to the selected device.
model = PTXDecompilerModel(
    ptx_vocab_size=len(ptx_tok),
    ast_vocab_size=len(ast_tok),
    ptx_to_ast_map=ptx_to_ast,
)
model = model.to(DEVICE)

# Restore the trained weights stored under the checkpoint's "model" key.
ckpt = torch.load(CKPT_PATH, weights_only=True, map_location=DEVICE)
model.load_state_dict(ckpt["model"])
model.eval()  # evaluation mode for the metric computation below
print(f"Model loaded from epoch {ckpt.get('epoch', '?')}")

## Evaluate on Validation Set

In [None]:
from ptx_decompiler.data.dataset import PTXASTDataset, collate_pad_batch
from torch.utils.data import DataLoader

# Draw a reproducible 10% sample of the dataset to evaluate on.
# NOTE(review): the sample is taken from the full dataset — confirm it matches
# the held-out split used during training, otherwise it may include training rows.
val_df = df.sample(frac=0.1, random_state=42)

val_ptx = val_df["ptx_normalized"].tolist()
val_ast = val_df["ast_sexp"].tolist()
val_tiers = val_df["tier"].tolist()
val_ds = PTXASTDataset(
    ptx_strings=val_ptx,
    ast_strings=val_ast,
    ptx_tokenizer=ptx_tok,
    ast_tokenizer=ast_tok,
    tiers=val_tiers,
)

# Larger batches on CUDA, smaller ones on every other device type.
if DEVICE.type == "cuda":
    BATCH_SIZE = 64
else:
    BATCH_SIZE = 32


def _collate_val(samples):
    """Collate samples with collate_pad_batch using the tokenizers' pad ids."""
    return collate_pad_batch(samples, ptx_tok.pad_id, ast_tok.pad_id)


val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=_collate_val)

# Accumulate each batch's metric scaled by its batch size, so the running and
# final figures are per-sample averages (assuming each metric function returns
# a batch mean — TODO confirm).
# NOTE(review): the model receives the gold `ast_input_ids`, so these scores
# appear to be teacher-forced rather than from free-running decoding — confirm.
em_sum, ted_sum, n = 0.0, 0.0, 0
pbar = tqdm(val_loader, desc="Evaluating", unit="batch",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}] {postfix}")
for batch in pbar:
    ptx = batch["ptx_ids"].to(DEVICE)
    ast_in = batch["ast_input_ids"].to(DEVICE)
    ast_tgt = batch["ast_target_ids"]  # stays on CPU; predictions are moved to CPU below
    # The batch mask is inverted before being passed to the model; presumably
    # `ptx_mask` marks real tokens and the model expects a padding mask — confirm.
    pad_mask = (~batch["ptx_mask"]).to(DEVICE)
    with torch.no_grad():
        logits, _ = model(ptx, ast_in, pad_mask)
    pred = logits.argmax(dim=-1).cpu()
    batch_size = ptx.size(0)
    em_sum += exact_match_accuracy(pred, ast_tgt, ast_tok.pad_id, ast_tok.eos_id) * batch_size
    ted_sum += compute_tree_edit_distance(pred, ast_tgt, ast_tok.pad_id, ast_tok.eos_id) * batch_size
    n += batch_size
    pbar.set_postfix(em=f"{em_sum/n:.4f}", ted=f"{ted_sum/n:.4f}")

# If the loader yielded no batches, n is 0; report NaN instead of raising
# ZeroDivisionError.
em_avg = em_sum / n if n else float("nan")
ted_avg = ted_sum / n if n else float("nan")
print(f"\nExact Match: {em_avg:.4f}")
print(f"Tree Edit Sim: {ted_avg:.4f}")

In [None]:
# Compilation success rate — needs nvcc on PATH; the cell only prints a notice otherwise.
import shutil

if shutil.which("nvcc") is None:
    print("nvcc not available — skipping compilation success rate.")
    print("(This metric is available on Colab or machines with CUDA toolkit)")
else:
    from ptx_decompiler.data.compiler import compile_cuda_to_ptx_silent

    def _compiles(cuda):
        """Return True when the silent compile helper yields a non-None result."""
        return compile_cuda_to_ptx_silent(cuda) is not None

    # Only the first 100 gold ASTs of the validation sample are checked; this
    # measures the reference data, not model predictions.
    gold_asts = val_df["ast_sexp"].head(100).tolist()
    compile_ok = compile_success_rate(
        gold_asts,
        render_fn=ast_to_cuda,
        compile_fn=_compiles,
    )
    print(f"Compilation success (on gold AST): {compile_ok:.4f}")

In [None]:
# Number of validation rows (non-null `ast_sexp` values) in each tier.
tier_groups = val_df.groupby("tier")
per_tier = tier_groups.agg(count=("ast_sexp", "count"))
print(per_tier)