In [None]:
# Inference demo that auto-fixes _orig_mod.* keys from torch.compile checkpoints

import random, numpy as np
from pathlib import Path
import torch, torch.nn as nn, torch.nn.functional as F
from torch_geometric.nn import GCNConv
from dataset_loader import load_sample, get_available_ids

# ---- paths (relative, so they resolve against the current working dir) ----
ROOT = Path("./dataset") / "dataset_generation" / "data"
CKPT = Path("./checkpoints") / "gcn_facecls.pt"

# ---- model (matches your training code) ----
class GCNBlock(nn.Module):
    """One graph-convolution stage with an optional identity skip.

    Pipeline: GCNConv -> BatchNorm1d -> ReLU -> dropout. The input is added
    back onto the result only when ``in_ch == out_ch``.
    """

    def __init__(self, in_ch, out_ch, p_drop=0.3):
        super().__init__()
        # Submodule names ("conv", "bn") become the state_dict keys the
        # checkpoint has to match, so they must not be renamed.
        self.conv = GCNConv(in_ch, out_ch)
        self.bn = nn.BatchNorm1d(out_ch)
        self.res = in_ch == out_ch  # skip is only shape-valid for equal widths
        self.p = p_drop

    def forward(self, x, edge_index):
        h = F.relu(self.bn(self.conv(x, edge_index)), inplace=True)
        h = F.dropout(h, p=self.p, training=self.training)
        return h + x if self.res else h

class DeepGCN(nn.Module):
    """Per-node classifier: linear stem, a stack of GCNBlocks, and an MLP head.

    Args:
        in_dim: number of input features per node (columns of ``x``).
        hidden: width shared by the stem, every GCNBlock and the head.
        layers: number of GCNBlocks in the stack.
        out_dim: number of output logits per node.
        dropout: dropout probability used in the blocks and in the head.
    """

    def __init__(self, in_dim, hidden=256, layers=4, out_dim=25, dropout=0.3):
        super().__init__()
        self.in_lin = nn.Linear(in_dim, hidden)
        self.blocks = nn.ModuleList(
            GCNBlock(hidden, hidden, p_drop=dropout) for _ in range(layers)
        )
        # The positional order fixes the checkpoint keys (head.0, head.1,
        # head.4), so the layers must stay in exactly this sequence.
        head_layers = [
            nn.Linear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x, edge_index):
        h = self.in_lin(x)
        h = F.relu(h, inplace=True)
        for block in self.blocks:
            h = block(h, edge_index)
        return self.head(h)

def _clean_state_dict(sd: dict) -> dict:
    # handle torch.compile and (optionally) DataParallel prefixes
    new_sd = {}
    for k, v in sd.items():
        if k.startswith("_orig_mod."):
            k = k[len("_orig_mod."):]
        if k.startswith("module."):
            k = k[len("module."):]
        new_sd[k] = v
    return new_sd

# ---- load checkpoint ----
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(CKPT, map_location="cpu")
in_dim      = int(ckpt["in_dim"])
hidden      = int(ckpt["hidden"])
layers      = int(ckpt["layers"])
dropout     = float(ckpt.get("dropout", 0.3))
num_classes = int(ckpt["num_classes"])
# NOTE(review): the fallback builds one "class_i" name per output class, so this
# key presumably stores class labels despite its name — confirm vs. training code.
feat_names  = ckpt.get("feat_names", [f"class_{i}" for i in range(num_classes)])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepGCN(in_dim, hidden=hidden, layers=layers, out_dim=num_classes, dropout=dropout).to(device)

# Strip torch.compile ("_orig_mod.") / DataParallel ("module.") prefixes up front.
# This model's own keys (in_lin.*, blocks.*, head.*) carry neither prefix, so
# cleaning is a no-op for plain checkpoints. Loading once also means a genuine
# key/shape mismatch raises its own error instead of being masked by a retry.
model.load_state_dict(_clean_state_dict(ckpt["state_dict"]), strict=True)
model.eval()

# ---- pick a random valid sample ----
avail = get_available_ids(ROOT, max_id=20000, strict=True)
# Explicit check instead of `assert` (asserts are stripped under `python -O`);
# len() is used because the container type returned here is not known.
if len(avail) == 0:
    raise RuntimeError("No usable samples found under ROOT.")

In [None]:
pid = random.choice(avail)
print(f"\n--- Demo part id = {pid} ---")

# ---- load sample & run ----
d  = load_sample(ROOT, pid)
x  = d["x"].to(device)
ei = d["edge_index"].to(device)
y  = d["y"].cpu().numpy()

# fp16 autocast on CUDA only; with enabled=False the context is a no-op, so
# the CPU path runs in the model's own dtype through the same single call.
use_amp = device.type == "cuda"
with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
    logits = model(x, ei)
# Under CUDA autocast the final Linear emits float16 logits; upcast once so the
# softmax probabilities and their argsort below are computed in float32.
logits = logits.float()

# ---- decode & report ----
pred_idx = logits.argmax(dim=1).cpu().numpy()
# Guard the empty case: .mean() of an empty array warns and yields nan anyway.
acc = float((pred_idx == y).mean()) if y.size else float("nan")
print(f"faces={len(y)}  acc={acc:.3f}")

# show a few mismatches + top3
max_shown = 20
wrong = np.flatnonzero(pred_idx != y)
if wrong.size:
    probs = logits.softmax(dim=1).cpu().numpy()
    k = min(max_shown, wrong.size)
    print(f"\nFirst {k} mismatches (face_idx: pred (p) -> gt | top3):")
    for i in wrong[:k]:
        p_i = probs[i]
        pred = int(pred_idx[i])
        top3 = [(int(t), float(p_i[t])) for t in p_i.argsort()[-3:][::-1]]
        # Row layout matches the header printed above.
        print(f"  {i}: {pred} ({p_i[pred]:.2f}) -> {int(y[i])} | top3: {top3}")
else:
    print("No mismatches on this part 🎉")