In [None]:
import torch, torch_geometric, torch_scatter, torch_sparse
# Query CUDA availability once and reuse the answer for both report lines.
cuda_ok = torch.cuda.is_available()
print("CUDA avail:", cuda_ok, "torch CUDA:", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0) if cuda_ok else "cpu")

In [None]:
# gnn_train_infer.py
import os, sys, json, pickle, random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler
import matplotlib.pyplot as plt

# ---- PyG ----
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv

# ========= config =========
ROOT = Path("./dataset/dataset_generation/data")  # holds <id>.pkl graphs and labels/<id>.json
MAX_ID_SCAN = 10000   # ids in [0, MAX_ID_SCAN) are probed on disk
EPOCHS = 60
HIDDEN = 128
BATCH_SIZE = 64
LR = 3e-4
WEIGHT_DECAY = 1e-4
PATIENCE = 8          # epochs without validation improvement before early stopping
SEED = 13

# Class index -> human-readable feature name; the list position is the label id.
feat_names = ['chamfer', 'through_hole', 'triangular_passage', 'rectangular_passage', '6sides_passage',
              'triangular_through_slot', 'rectangular_through_slot', 'circular_through_slot',
              'rectangular_through_step', '2sides_through_step', 'slanted_through_step', 'Oring', 'blind_hole',
              'triangular_pocket', 'rectangular_pocket', '6sides_pocket', 'circular_end_pocket',
              'rectangular_blind_slot', 'v_circular_end_blind_slot', 'h_circular_end_blind_slot',
              'triangular_blind_step', 'circular_blind_step', 'rectangular_blind_step', 'round', 'stock']

# Derived from the label list (currently 25) so the class count can never drift
# out of sync with the names when a feature is added or removed.
NUM_CLASSES = len(feat_names)

# ---- speed knobs for A40 ----
torch.set_float32_matmul_precision("high")
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Seed the RNGs used by this script (NumPy Generator, Python random, torch).
rng = np.random.default_rng(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# ========= data utils (from your dataset_loader, simplified) =========
def get_available_ids(root: Path, max_id=1000):
    """Return, in ascending order, the ids in ``[0, max_id)`` that are complete on disk.

    An id counts as complete when both ``<root>/<id>.pkl`` and
    ``<root>/labels/<id>.json`` exist.
    """
    base = Path(root)
    label_dir = base / "labels"
    return [
        sample_id
        for sample_id in range(max_id)
        if (base / f"{sample_id}.pkl").exists()
        and (label_dir / f"{sample_id}.json").exists()
    ]

def split_ids(ids, seed=13, frac=(0.8, 0.1, 0.1)):
    """Shuffle *ids* deterministically and cut them into train/val/test lists.

    The train and val sizes are ``int(frac[0] * n)`` and ``int(frac[1] * n)``;
    the test list receives whatever remains, so ``frac[2]`` is not read.
    """
    shuffled = np.array(ids, dtype=int)
    np.random.default_rng(seed).shuffle(shuffled)

    total = len(shuffled)
    train_end = int(frac[0] * total)
    val_end = train_end + int(frac[1] * total)

    return (
        shuffled[:train_end].tolist(),
        shuffled[train_end:val_end].tolist(),
        shuffled[val_end:].tolist(),
    )

def _one_hot(idx, n):
    v = np.zeros(n, dtype=np.float32); v[int(idx)] = 1.0; return v

def build_node_features(face_feats: dict, use_type_onehot=False, num_types=None):
    """Assemble the per-face feature matrix from a ``face_features`` dict.

    Columns, in order: area (z-normalized per graph), adj/degree, loops,
    centroid (3, scaled to unit length per face), convexity one-hot (3), and
    the surface type, either as a single scalar column (default, D=10) or as
    a one-hot block when ``use_type_onehot`` is true.

    Args:
        face_feats: dict with per-face sequences under the keys ``area``,
            ``adj``, ``loops``, ``centroid``, ``convexity`` and ``type``.
        use_type_onehot: one-hot encode ``type`` instead of using it as a scalar.
        num_types: width of the type one-hot block. When ``None`` the width is
            ``max(type) + 1`` of *this* graph, which can differ between graphs;
            pass a fixed value to get a consistent feature dimension.

    Returns:
        float32 array of shape ``[num_faces, D]``. A graph with zero faces
        yields an empty ``[0, D]`` array instead of raising.

    Raises:
        IndexError: if a convexity value is >= 3 or a type value is >= the
            one-hot width.
    """
    n = len(face_feats['area'])
    area     = np.asarray(face_feats['area'], dtype=np.float32).reshape(n, 1)
    deg      = np.asarray(face_feats['adj'], dtype=np.float32).reshape(n, 1)
    loops    = np.asarray(face_feats['loops'], dtype=np.float32).reshape(n, 1)
    # Explicit reshape keeps the [n, 3] layout even when n == 0.
    centroid = np.asarray(face_feats['centroid'], dtype=np.float32).reshape(n, 3)
    conv     = np.asarray(face_feats['convexity'], dtype=np.int64).reshape(n)
    # Row lookup in an identity matrix is a vectorized one-hot encoding.
    conv_oh  = np.eye(3, dtype=np.float32)[conv]                     # [n, 3]

    parts = [area, deg, loops, centroid, conv_oh]

    if use_type_onehot:
        stype = np.asarray(face_feats['type'], dtype=np.int64).reshape(n)
        if num_types is not None:
            width = int(num_types)
        else:
            width = int(stype.max()) + 1 if n else 0
        parts.append(np.eye(width, dtype=np.float32)[stype])        # [n, width]
    else:
        parts.append(np.asarray(face_feats['type'], dtype=np.float32).reshape(n, 1))

    x = np.concatenate(parts, axis=1).astype(np.float32, copy=False)
    if n == 0:
        # Nothing to normalize; also avoids mean/std warnings on empty slices.
        return x

    # light normalization
    x[:, 0] = (x[:, 0] - x[:, 0].mean()) / (x[:, 0].std() + 1e-6)   # area z-norm
    x[:, 3:6] = x[:, 3:6] / (np.linalg.norm(x[:, 3:6], axis=1, keepdims=True) + 1e-6)  # centroid direction-ish
    return x

def load_sample(root: Path, idx: int):
    """Load one part's graph and labels from disk and return them as tensors.

    Reads ``<root>/<idx>.pkl`` (a pickled dict providing ``num_nodes``,
    ``face_features`` and ``edge_index``) and ``<root>/labels/<idx>.json``
    (providing ``per_face_labels``).

    Returns:
        dict with ``x`` (float32 features), ``edge_index`` (int64, shape
        ``[2, E]``), ``y`` (int64 per-face labels) and ``idx``.

    Raises:
        FileNotFoundError: if either file is missing.
        ValueError: if the label count differs from ``num_nodes``, or if
            ``edge_index`` is not ``[2, E]`` or references a node outside
            ``[0, num_nodes)``.
    """
    root = Path(root)
    pkl_path  = root / f"{idx}.pkl"
    json_path = root / "labels" / f"{idx}.json"
    missing = [str(p) for p in (pkl_path, json_path) if not p.exists()]
    if missing:
        # A single message argument: passing (idx, path, path) positionally would
        # be interpreted by OSError as (errno, strerror, filename).
        raise FileNotFoundError(f"sample {idx}: missing file(s): {', '.join(missing)}")

    # NOTE(security): pickle.load can execute arbitrary code. Only load graph
    # files that were generated locally or come from a trusted source.
    with open(pkl_path, "rb") as f:
        G = pickle.load(f)
    with open(json_path, "r", encoding="utf-8") as f:
        J = json.load(f)

    num_nodes = int(G['num_nodes'])
    y = np.array(J['per_face_labels'], dtype=np.int64)
    if len(y) != num_nodes:
        raise ValueError(f"label/face mismatch for {idx}: {len(y)} vs {num_nodes}")

    edge_index = np.asarray(G['edge_index'], dtype=np.int64)  # [2, E]
    if edge_index.size == 0:
        # An edgeless graph may be stored as [] or [[], []]; normalize to [2, 0].
        edge_index = edge_index.reshape(2, 0)
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError(f"edge_index for {idx} must have shape [2, E], got {edge_index.shape}")
    if edge_index.size and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        # Fail here with a clear message rather than later inside a GPU kernel.
        raise ValueError(f"edge_index for {idx} references a node outside [0, {num_nodes})")

    x = build_node_features(G['face_features'])               # [n, D]

    return {
        "x": torch.from_numpy(x),
        "edge_index": torch.from_numpy(edge_index),
        "y": torch.from_numpy(y),
        "idx": idx
    }

# ========= PyG dataset with lazy cache =========
class GraphDataset(torch.utils.data.Dataset):
    """Map-style dataset of per-part graphs backed by ``load_sample``.

    Every requested id is loaded once at construction time to check that it is
    usable; ids that fail are left out of ``self.ids`` and recorded in
    ``self.skipped``.

    Args:
        root: directory passed to ``load_sample``.
        ids: iterable of part ids to index.
        cache_mode: ``"all"`` keeps every graph in memory from construction,
            ``"lazy"`` stores a graph the first time it is requested, any other
            value (e.g. ``"none"``) reloads from disk on every access.
        strict: additionally drop samples whose label count differs from the
            number of feature rows.

    Attributes:
        ids: usable part ids, in the order given.
        skipped: list of ``(part_id, reason)`` for every rejected id.

    Raises:
        RuntimeError: if no id is usable.
    """

    def __init__(self, root: Path, ids, cache_mode="lazy", strict=True):
        import warnings
        from tqdm import tqdm
        self.root = Path(root)
        self.ids = []
        self.skipped = []
        self._cache = {} if cache_mode in ("lazy", "all") else None
        self._lazy = (cache_mode == "lazy")
        eager = (cache_mode == "all")

        for i in tqdm(ids, desc="Indexing graphs"):
            try:
                d = load_sample(self.root, i)
                if strict and (d["y"].numel() != d["x"].shape[0]):
                    self.skipped.append((i, "label/feature row count mismatch"))
                    continue
                # Convert before registering the id so a conversion failure
                # cannot leave an unusable id in self.ids.
                g = self._to_data(d) if eager else None
            except Exception as exc:
                # Best-effort indexing: keep going, but remember why it failed.
                self.skipped.append((i, f"{type(exc).__name__}: {exc}"))
                continue
            self.ids.append(i)
            if g is not None:
                self._cache[i] = g

        if self.skipped:
            first_id, first_reason = self.skipped[0]
            warnings.warn(
                f"GraphDataset skipped {len(self.skipped)} of "
                f"{len(self.ids) + len(self.skipped)} graphs "
                f"(first: id {first_id}: {first_reason})",
                RuntimeWarning,
                stacklevel=2,
            )
        if not self.ids:
            raise RuntimeError("No usable graphs after filtering.")

    @staticmethod
    def _to_data(d):
        """Wrap a ``load_sample`` dict in a PyG ``Data`` object."""
        return Data(x=d["x"].float(),
                    edge_index=d["edge_index"].long(),
                    y=d["y"].long())

    def __len__(self): return len(self.ids)

    def __getitem__(self, idx):
        """Return the graph at position *idx* of ``self.ids`` (cached when enabled)."""
        pid = self.ids[idx]
        if self._cache is not None and pid in self._cache:
            # The cached object itself is returned, so callers should not mutate it.
            return self._cache[pid]
        d = load_sample(self.root, pid)
        g = self._to_data(d)
        if self._cache is not None and self._lazy:
            self._cache[pid] = g
        return g

# ========= model =========
class GCN(nn.Module):
    """Two ``GCNConv`` layers followed by a linear head that emits per-node logits.

    Dropout with probability ``dropout`` is applied after the first layer only,
    and only while the module is in training mode.
    """

    def __init__(self, in_dim, hidden=128, out_dim=NUM_CLASSES, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        # Submodule names and creation order are kept: they define the
        # state_dict keys and the order in which parameters are initialized.
        self.c1 = GCNConv(in_dim, hidden)
        self.c2 = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, out_dim)

    def forward(self, x, edge_index):
        """Return logits with one row per node and ``out_dim`` columns."""
        h = self.c1(x, edge_index).relu()
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.c2(h, edge_index).relu()
        return self.lin(h)

# ========= training helpers =========
def class_weights_from_dataset(ds, num_classes=None):
    """Compute inverse-frequency class weights from the labels in *ds*.

    Args:
        ds: iterable of graphs, each exposing a 1-D integer label tensor ``y``.
        num_classes: number of classes; defaults to the module-level
            ``NUM_CLASSES``.

    Returns:
        float32 tensor of length ``num_classes``. Classes that occur get a
        weight proportional to ``1 / count``, rescaled so the weights sum to
        ``num_classes``. Classes that never occur get weight 0: giving them
        ``1 / epsilon`` would dominate the normalization and push every real
        weight towards zero, which underflows in half precision. If *ds* holds
        no labels at all, uniform weights of 1 are returned.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    counts = torch.zeros(num_classes, dtype=torch.long)
    for g in ds:
        counts += torch.bincount(g.y, minlength=num_classes)

    present = counts > 0
    if not bool(present.any()):
        return torch.ones(num_classes, dtype=torch.float32)

    w = torch.zeros(num_classes, dtype=torch.float32)
    w[present] = 1.0 / counts[present].float()
    w *= (num_classes / w.sum())
    return w

@torch.no_grad()
def evaluate(model, device, loader):
    """Run *model* over *loader* without gradients and return ``(mean_loss, accuracy)``.

    The loss is the unweighted cross-entropy, averaged over all labels seen;
    accuracy is the fraction of labels whose argmax prediction is correct.
    An empty loader yields ``(0.0, 0.0)``.

    Args:
        model: callable as ``model(x, edge_index)``; switched to eval mode.
        device: ``torch.device``; fp16 autocast is enabled only for CUDA.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``y`` and ``to``.
    """
    model.eval()
    use_amp = (device.type == "cuda")
    total_loss, total_correct, total_labels = 0.0, 0, 0
    for batch in loader:
        batch = batch.to(device, non_blocking=True)
        # torch.amp.autocast("cuda", ...) replaces the deprecated torch.cuda.amp.autocast(...).
        with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.float16):
            logits = model(batch.x, batch.edge_index)
            loss = F.cross_entropy(logits, batch.y)
        n_labels = batch.y.numel()
        # The batch loss is a mean, so weight it by the label count before summing.
        total_loss += float(loss) * n_labels
        total_correct += (logits.argmax(dim=1) == batch.y).sum().item()
        total_labels += n_labels
    return total_loss / max(1, total_labels), total_correct / max(1, total_labels)

def train_one_epoch(model, device, loader, opt, scaler, weight):
    """Train *model* for one pass over *loader* and return ``(mean_loss, accuracy)``.

    Args:
        model: callable as ``model(x, edge_index)``; switched to train mode.
        device: ``torch.device``; fp16 autocast is enabled only for CUDA.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``y`` and ``to``.
        opt: optimizer stepped through ``scaler``.
        scaler: gradient scaler providing ``scale``, ``step`` and ``update``.
        weight: optional per-class weight tensor for the cross-entropy, or ``None``.

    Returns:
        The class-weighted training loss averaged over labels (so it is not
        directly comparable to the unweighted loss from ``evaluate`` when
        *weight* is given) and the fraction of correctly predicted labels.
    """
    model.train()
    use_amp = (device.type == "cuda")
    # Move the weights once, and keep them in float32: casting them to the fp16
    # logits dtype can round small inverse-frequency weights down to zero.
    w = weight.to(device=device, dtype=torch.float32) if weight is not None else None
    total_loss, total_correct, total_labels = 0.0, 0, 0
    for batch in loader:
        batch = batch.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        # torch.amp.autocast("cuda", ...) replaces the deprecated torch.cuda.amp.autocast(...).
        with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.float16):
            logits = model(batch.x, batch.edge_index)
            loss = F.cross_entropy(logits, batch.y, weight=w)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        n_labels = batch.y.numel()
        total_loss += float(loss) * n_labels
        total_correct += (logits.argmax(dim=1) == batch.y).sum().item()
        total_labels += n_labels
    return total_loss / max(1, total_labels), total_correct / max(1, total_labels)

# ========= inference utility =========
@torch.no_grad()
def predict_part(model, device, root: Path, part_id: int):
    """Predict a feature name for every face of one part.

    Loads the part with ``load_sample``, runs a forward pass in eval mode
    (so dropout is disabled regardless of the mode the caller left the model
    in), and maps each argmax class index to ``feat_names``. The model's
    previous train/eval mode is restored afterwards.

    Returns:
        list of feature-name strings, one per face.
    """
    d = load_sample(root, part_id)
    x = d["x"].to(device)
    ei = d["edge_index"].to(device)

    was_training = model.training
    model.eval()
    try:
        logits = model(x, ei)
    finally:
        model.train(was_training)

    yhat = logits.argmax(dim=1).tolist()
    return [feat_names[i] for i in yhat]

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Discover the parts that are complete on disk and refuse to train on too few.
avail = get_available_ids(ROOT, max_id=MAX_ID_SCAN)
print(f"usable parts on disk: {len(avail)}")
if len(avail) < 50:
    raise RuntimeError("Too few usable samples; generate more.")

ds_tr_ids, ds_va_ids, ds_te_ids = split_ids(avail, seed=SEED)
print(f"split -> train {len(ds_tr_ids)}, val {len(ds_va_ids)}, test {len(ds_te_ids)}")

# Loader worker processes are turned off on Windows and inside notebooks to
# avoid multiprocessing stalls; memory pinning is only used together with workers.
IS_WINDOWS = sys.platform == "win32"
IN_NOTEBOOK = "ipykernel" in sys.modules
NUM_WORKERS = 4 if not (IS_WINDOWS or IN_NOTEBOOK) else 0
pin = torch.cuda.is_available() and NUM_WORKERS > 0

# Train/val graphs are cached on first access; test graphs are always re-read.
ds_tr = GraphDataset(ROOT, ds_tr_ids, cache_mode="lazy", strict=True)
ds_va = GraphDataset(ROOT, ds_va_ids, cache_mode="lazy", strict=True)
ds_te = GraphDataset(ROOT, ds_te_ids, cache_mode="none", strict=True)

# Settings shared by all three loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=pin)
tr_loader = DataLoader(ds_tr, shuffle=True, **_loader_kwargs)
va_loader = DataLoader(ds_va, shuffle=False, **_loader_kwargs)
te_loader = DataLoader(ds_te, shuffle=False, **_loader_kwargs)

# The input width is read off the first training graph.
g0 = ds_tr[0]
in_dim = g0.x.shape[1]
print(f"in_dim: {in_dim}")

model = GCN(in_dim, hidden=HIDDEN, out_dim=NUM_CLASSES, dropout=0.2)
model = model.to(device)
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scaler = GradScaler(enabled=(device.type == "cuda"))
W = class_weights_from_dataset(ds_tr)

# Training with early stopping on validation loss.
hist = {key: [] for key in ("tr_loss", "va_loss", "tr_acc", "va_acc")}
best_val, best_state, bad = float("inf"), None, 0

for epoch in range(1, EPOCHS + 1):
    tr_loss, tr_acc = train_one_epoch(model, device, tr_loader, opt, scaler, W)
    va_loss, va_acc = evaluate(model, device, va_loader)

    hist["tr_loss"].append(tr_loss)
    hist["va_loss"].append(va_loss)
    hist["tr_acc"].append(tr_acc)
    hist["va_acc"].append(va_acc)

    print(f"ep{epoch:03d} | train {tr_loss:.4f}/{tr_acc:.3f} | val {va_loss:.4f}/{va_acc:.3f}")

    # Only an improvement larger than 1e-4 resets the patience counter.
    if va_loss < best_val - 1e-4:
        best_val = va_loss
        # Snapshot the weights on CPU so later updates cannot overwrite them.
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        bad = 0
        continue

    bad += 1
    if bad >= PATIENCE:
        print("Early stopping.")
        break

# Restore the best validation snapshot, if any epoch produced one.
if best_state is not None:
    model.load_state_dict(best_state)

# Persist the weights together with what is needed to rebuild the model
# and to reproduce the data split.
ckpt_dir = Path("./checkpoints")
ckpt_dir.mkdir(parents=True, exist_ok=True)
ckpt_path = ckpt_dir / "gcn_facecls.pt"
checkpoint = {
    "state_dict": model.state_dict(),
    "in_dim": in_dim,
    "hidden": HIDDEN,
    "num_classes": NUM_CLASSES,
    "feat_names": feat_names,
    "train_ids": ds_tr.ids,
    "val_ids": ds_va.ids,
    "test_ids": ds_te.ids,
}
torch.save(checkpoint, ckpt_path)
print(f"Saved checkpoint -> {ckpt_path}")

# Final score on the held-out test split.
te_loss, te_acc = evaluate(model, device, te_loader)
print(f"TEST  loss/acc: {te_loss:.4f}/{te_acc:.3f}")

# Loss curves for the train and validation splits, one point per epoch.
plt.figure(figsize=(7, 4.5))
for hist_key, curve_label in (("tr_loss", "train"), ("va_loss", "val")):
    plt.plot(hist[hist_key], label=curve_label)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("GCN per-face classification")
plt.legend()
plt.tight_layout()
plt.show()

# Quick demo: predict the face labels of one randomly chosen test part.
demo_id = random.choice(ds_te.ids)
pred_names = predict_part(model, device, ROOT, demo_id)
print(f"\nDemo predictions for part {demo_id}:")
# Slicing already stops at the end of the list, so no explicit min() is needed.
print(pred_names[:20], f"... (total faces={len(pred_names)})")


In [None]:
# Quick demo: predict the face labels of one randomly chosen test part.
demo_id = random.choice(ds_te.ids)
pred_names = predict_part(model, device, ROOT, demo_id)
print(f"\nDemo predictions for part {demo_id}:")
# Slicing already stops at the end of the list, so no explicit min() is needed.
print(pred_names[:20], f"... (total faces={len(pred_names)})")

In [None]:
# --- quick demo inference + flat GT comparison (no functions) ---
demo_id = random.choice(ds_te.ids)
print(f"\n--- Demo part {demo_id} ---")

# load sample + move to device
d  = load_sample(ROOT, demo_id)          # also reads ground-truth JSON
x  = d["x"].to(device)
ei = d["edge_index"].to(device)
y  = d["y"].cpu().numpy()                # ground-truth indices

# forward pass (eval mode disables dropout; no gradients are needed)
model.eval()
with torch.no_grad():
    # torch.amp.autocast("cuda", ...) replaces the deprecated
    # torch.cuda.amp.autocast(...); it stays disabled when running on CPU.
    with torch.amp.autocast("cuda", enabled=(device.type == "cuda"), dtype=torch.float16):
        logits = model(x, ei)            # [num_faces, NUM_CLASSES]

# predictions and names
pred_idx   = logits.argmax(dim=1).cpu().numpy()
pred_names = [feat_names[i] for i in pred_idx]
gt_names   = [feat_names[i] for i in y]

# accuracy over the faces of this one part
correct = int((pred_idx == y).sum())
total   = int(y.size)
acc     = correct / max(1, total)        # max() guards against a part with no faces
print(f"faces={total}  acc={acc:.3f}  ({correct}/{total})")

# List up to max_show misclassified faces, each with its three most probable classes.
wrong = np.flatnonzero(pred_idx != y)
max_show = 20
if wrong.size == 0:
    print("No mismatches on this part 🎉")
else:
    probs = logits.softmax(dim=1).cpu().numpy()
    k = min(max_show, wrong.size)
    print(f"First {k} mismatches (face_idx: pred (p) -> gt | top3):")
    for i in wrong[:k]:
        # Sort ascending, reverse, and keep the first three: the top-3 classes, best first.
        top3 = probs[i].argsort()[::-1][:3]
        top3_str = ", ".join(f"{feat_names[t]}({probs[i][t]:.2f})" for t in top3)
        predicted = pred_idx[i]
        print(f"  {i:4d}: {feat_names[predicted]} ({probs[i][predicted]:.2f})"
              f" -> {feat_names[y[i]]} | {top3_str}")

# Optional per-part classification report; skipped with a note when
# scikit-learn is unavailable or the report cannot be produced.
try:
    from sklearn.metrics import classification_report
    print("\nClassification report (this part only):")
    report = classification_report(
        y,
        pred_idx,
        labels=list(range(NUM_CLASSES)),
        target_names=feat_names,
        digits=3,
        zero_division=0,
    )
    print(report)
except Exception as e:
    print(f"[sklearn report skipped] {e}")

# Side-by-side peek at the first predicted and ground-truth names.
print("\nPred (first 20):", pred_names[:20])
print("GT   (first 20):", gt_names[:20])
