# Phase 3: Evaluation — 4-Axis Metrics & Attention Maps

Load trained model and validation data. Compute exact match, tree edit similarity, compilation success, and plot attention heatmaps.

In [None]:
import sys, os
# Put the project root at the front of sys.path so the ptx_decompiler imports
# below can resolve: use the "/content/Neural PTX Decompiler" checkout when that
# directory exists, otherwise fall back to the parent of the working directory.
sys.path.insert(0, os.path.abspath("/content/Neural PTX Decompiler" if os.path.exists("/content/Neural PTX Decompiler") else ".."))

import torch
import pandas as pd
from pathlib import Path

from ptx_decompiler.model import PTXDecompilerModel
from ptx_decompiler.tokenizer import PTXTokenizer, ASTTokenizer
from ptx_decompiler.data import normalize_ptx, ast_to_cuda
from ptx_decompiler.data.compiler import compile_cuda_to_ptx_silent
from ptx_decompiler.training.metrics import exact_match_accuracy, compute_tree_edit_distance, compile_success_rate

In [None]:
# Locations of the dataset and of the trained checkpoint to evaluate.
DATA_PATH = "/content/drive/MyDrive/NeuralPTX/dataset_100k.parquet"
CKPT_PATH = "/content/drive/MyDrive/NeuralPTX/checkpoints/checkpoint_epoch_29.pt"

# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

df = pd.read_parquet(DATA_PATH)

# The PTX vocabulary is rebuilt here from the dataset's normalized PTX column.
# NOTE(review): it presumably has to match the vocabulary used during training
# for the checkpoint weights to line up -- confirm.
ptx_tok = PTXTokenizer(max_vocab_size=2000)
ptx_tok.build_vocab(df["ptx_normalized"].tolist())
ast_tok = ASTTokenizer()

# Lookup table indexed by PTX token id that gives the AST token id of the same
# token string; PTX ids whose token is absent from the AST vocabulary keep -1.
ptx_to_ast = torch.full((len(ptx_tok),), -1, dtype=torch.long)
for token, ptx_id in ptx_tok.vocab.items():
    if token not in ast_tok.vocab:
        continue
    ptx_to_ast[ptx_id] = ast_tok.vocab[token]

# Build the model with the vocabulary sizes and the id map, move it to DEVICE,
# then restore the weights stored under the checkpoint's "model" key.
model = PTXDecompilerModel(
    ptx_vocab_size=len(ptx_tok),
    ast_vocab_size=len(ast_tok),
    ptx_to_ast_map=ptx_to_ast,
)
model = model.to(DEVICE)
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
state_dict = ckpt["model"]
model.load_state_dict(state_dict)
model.eval()

In [None]:
from ptx_decompiler.data.dataset import PTXASTDataset, collate_pad_batch
from torch.utils.data import DataLoader

# Reproducible 10% sample of the dataset used as the validation set.
# NOTE(review): the sample is drawn here, independently of training; confirm it
# matches the split used when the checkpoint was trained, otherwise these rows
# may overlap the training data.
val_df = df.sample(frac=0.1, random_state=42)
val_ds = PTXASTDataset(
    ptx_strings=val_df["ptx_normalized"].tolist(),
    ast_strings=val_df["ast_sexp"].tolist(),
    ptx_tokenizer=ptx_tok,
    ast_tokenizer=ast_tok,
    tiers=val_df["tier"].tolist(),
)
val_loader = DataLoader(
    val_ds,
    batch_size=64,
    collate_fn=lambda b: collate_pad_batch(b, ptx_tok.pad_id, ast_tok.pad_id),
)

# Running sums of the per-batch metric weighted by batch size, so the final
# ratio is a per-example average even when the last batch is smaller.
em_sum, ted_sum, n = 0.0, 0.0, 0
for batch in val_loader:
    ptx = batch["ptx_ids"].to(DEVICE)
    ast_in = batch["ast_input_ids"].to(DEVICE)
    ast_tgt = batch["ast_target_ids"]  # left on the CPU; pred is moved there too
    # Inverted ptx_mask; presumably ptx_mask is True on real tokens and the
    # model expects True on padding -- confirm against the model signature.
    pad_mask = (~batch["ptx_mask"]).to(DEVICE)
    with torch.no_grad():
        # NOTE(review): the model is fed ast_input_ids, so these look like
        # teacher-forced predictions rather than free-running decoding; the
        # scores may be optimistic -- confirm this is intended.
        logits, _ = model(ptx, ast_in, pad_mask)
        # Bring predictions back to the CPU so they live on the same device as
        # ast_tgt; comparing a CUDA tensor with a CPU tensor raises an error.
        pred = logits.argmax(dim=-1).cpu()
    bsz = ptx.size(0)
    em_sum += exact_match_accuracy(pred, ast_tgt, ast_tok.pad_id, ast_tok.eos_id) * bsz
    ted_sum += compute_tree_edit_distance(pred, ast_tgt, ast_tok.pad_id, ast_tok.eos_id) * bsz
    n += bsz

if n == 0:
    # Avoid a ZeroDivisionError when the loader yielded no batches.
    print("Validation set is empty; no metrics computed.")
else:
    print(f"Exact Match: {em_sum/n:.4f}")
    # NOTE(review): the label says "Sim" while the helper is named
    # compute_tree_edit_distance -- confirm which quantity it returns.
    print(f"Tree Edit Sim: {ted_sum/n:.4f}")

In [None]:
# Compilation check on reference ASTs (not on model predictions): the first 100
# validation ASTs are handed to compile_success_rate together with ast_to_cuda
# as the renderer and a predicate that treats any non-None result of
# compile_cuda_to_ptx_silent as a successful compile.
def _compiles(cuda):
    return compile_cuda_to_ptx_silent(cuda) is not None


gold_asts = val_df["ast_sexp"].head(100).tolist()
compile_ok = compile_success_rate(gold_asts, render_fn=ast_to_cuda, compile_fn=_compiles)
print(f"Compilation success (on gold AST): {compile_ok:.4f}")

In [None]:
# Number of validation rows with a non-null AST in each tier, as a one-column
# DataFrame ("count") indexed by tier.
tier_counts = val_df.groupby("tier")["ast_sexp"].count()
per_tier = tier_counts.to_frame("count")
print(per_tier)