In [None]:
# predict_from_step.py
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# your pieces
# from dataset_loader import build_node_features   # uses scalar "type" (D=10)
from convert_step_to_graph import convert_step_to_graph     # <-- adjust import if named differently

# Class-id -> label-name table. Used as the fallback when a checkpoint does
# not embed its own "feat_names" entry; list position is the class index.
FEAT_NAMES_DEFAULT = [
    'chamfer',
    'through_hole',
    'triangular_passage',
    'rectangular_passage',
    '6sides_passage',
    'triangular_through_slot',
    'rectangular_through_slot',
    'circular_through_slot',
    'rectangular_through_step',
    '2sides_through_step',
    'slanted_through_step',
    'Oring',
    'blind_hole',
    'triangular_pocket',
    'rectangular_pocket',
    '6sides_pocket',
    'circular_end_pocket',
    'rectangular_blind_slot',
    'v_circular_end_blind_slot',
    'h_circular_end_blind_slot',
    'triangular_blind_step',
    'circular_blind_step',
    'rectangular_blind_step',
    'round',
    'stock',
]

# One output class per label above (25).
NUM_CLASSES = len(FEAT_NAMES_DEFAULT)

# Small GCN node classifier; the original author notes it is the same network
# used for training, so the layout must not drift from the training script.
class GCN(nn.Module):
    """Two ``GCNConv`` layers followed by a linear head giving per-node logits.

    The attribute names ``c1``, ``c2`` and ``lin`` become the state-dict keys,
    so they have to stay as they are for saved checkpoints to load.
    """

    def __init__(self, in_dim, hidden=128, out_dim=NUM_CLASSES, dropout=0.2):
        # Imported lazily so torch_geometric is only required once a model is
        # actually constructed.
        from torch_geometric.nn import GCNConv

        super().__init__()
        self.c1 = GCNConv(in_dim, hidden)
        self.c2 = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return the classifier logits for node features ``x`` on ``edge_index``."""
        h = self.c1(x, edge_index).relu()
        # Dropout sits between the two convolutions only and is a no-op in
        # eval mode because it follows ``self.training``.
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.c2(h, edge_index).relu()
        return self.lin(h)

def _load_model(ckpt_path: str | Path, device: torch.device):
    """Rebuild the GCN described by a checkpoint and return it in eval mode.

    The checkpoint is read as a mapping. ``"state_dict"`` is mandatory;
    ``"in_dim"``, ``"hidden"``, ``"num_classes"`` and ``"feat_names"`` are
    optional and fall back to 10, 128, ``NUM_CLASSES`` and
    ``FEAT_NAMES_DEFAULT`` respectively.

    Args:
        ckpt_path: Location of the checkpoint file.
        device: Device the model is moved to before the weights are loaded.

    Returns:
        ``(model, feat_names)``.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed here. Tensors are first materialised on the CPU.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    # Integer hyper-parameters with the defaults applied when they are absent.
    fallbacks = {"in_dim": 10, "hidden": 128, "num_classes": NUM_CLASSES}
    hparams = {
        key: int(checkpoint.get(key, fallback))
        for key, fallback in fallbacks.items()
    }
    label_names = checkpoint.get("feat_names", FEAT_NAMES_DEFAULT)
    weights = checkpoint["state_dict"]

    net = GCN(
        hparams["in_dim"],
        hidden=hparams["hidden"],
        out_dim=hparams["num_classes"],
    )
    net = net.to(device)
    net.load_state_dict(weights)
    net.eval()
    return net, label_names

@torch.no_grad()
def predict_from_step(
    step_path: str | Path,
    ckpt_path: str | Path = "./checkpoints/gcn_facecls.pt",
    device: str | torch.device | None = None,
    extractor_kwargs: dict | None = None,
    use_amp: bool = True,
):
    """Predict a class label for every graph node (face) of a STEP model.

    Args:
        step_path: Path to the STEP file to analyse.
        ckpt_path: Checkpoint handed to ``_load_model``.
        device: Device to run on. ``None`` selects CUDA when it is available
            and CPU otherwise; a string is converted with ``torch.device``.
        extractor_kwargs: Extra keyword arguments forwarded to
            ``BRepGraphExtractor``. ``None`` means no extras.
        use_amp: Run the forward pass under float16 autocast. Only has an
            effect when the selected device is CUDA.

    Returns:
      {
        'labels_idx':  [num_faces] int list,
        'labels_name': [num_faces] str list,
        'num_faces':   int
      }
    """
    step_path = Path(step_path)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)
    extractor_kwargs = extractor_kwargs or {}

    # 1) STEP -> graph (only "face_features" and "edge_index" are read below)
    # tip: to speed up, you can pass n_samples_proximity=16 or even 0 if you made that safe
    # NOTE(review): BRepGraphExtractor and build_node_features are not imported
    # in the visible part of this file; presumably they are defined in an
    # earlier cell/module -- confirm before running this standalone.
    extractor = BRepGraphExtractor(str(step_path), **extractor_kwargs)
    graph = extractor.extract_features()

    # 2) Build node features exactly like training (scalar 'type').
    # The model parameters are float32, so the features are coerced to float32
    # here; a float64 array would otherwise fail with a dtype mismatch on the
    # non-autocast path. This is a no-op when the array already is float32.
    x_np = np.asarray(
        build_node_features(graph["face_features"], use_type_onehot=False),
        dtype=np.float32,
    )  # expected shape [N, in_dim] (10 per the original author) -- TODO confirm
    ei_np = np.asarray(graph["edge_index"], dtype=np.int64)  # expected [2, E]
    if ei_np.size == 0:
        # An empty edge list (e.g. []) converts to a 1-D array; give it the
        # [2, 0] layout so it is still a valid, edge-less edge_index.
        ei_np = ei_np.reshape(2, 0)
    x = torch.from_numpy(x_np).to(device)
    ei = torch.from_numpy(ei_np).to(device)

    # 3) Load model
    model, feat_names = _load_model(ckpt_path, device)

    # 4) Predict
    if device.type == "cuda" and use_amp:
        # torch.autocast is the non-deprecated spelling of torch.cuda.amp.autocast.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            logits = model(x, ei)
    else:
        logits = model(x, ei)

    pred_idx = logits.argmax(dim=1).tolist()
    pred_names = [feat_names[i] for i in pred_idx]

    return {"labels_idx": pred_idx, "labels_name": pred_names, "num_faces": len(pred_idx)}

# --- example ---
if __name__ == "__main__":
    # `device` is deliberately left unset: predict_from_step then uses CUDA
    # when it is available and falls back to CPU, so the example also runs on
    # machines without a GPU instead of failing on a hard-coded "cuda".
    res = predict_from_step(
        step_path="./dataset/dataset_generation/data/0.stp",
        ckpt_path="./checkpoints/gcn_facecls.pt",
        # Fewer proximity samples to speed up extraction; per the original
        # author these are not used by the model anyway.
        extractor_kwargs={"n_samples_proximity": 16},
    )
    print("faces:", res["num_faces"])
    # Show only the first 20 predicted labels to keep the output short.
    print(res["labels_name"][:20], "...")
