In [2]:
# Minimal Stage-2 sanity check: load the graph, rebuild the identical model, load the trained weights, and run a tiny forward pass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv
import torch_geometric.transforms as T

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("✅ Device:", DEVICE)

# --- load the training graph and add reverse edges exactly like train_stage2.py ---
# weights_only=False is needed because the file holds a pickled HeteroData object,
# not just tensors; only load graph files you produced yourself.
g_train: HeteroData = T.ToUndirected()(  # ToUndirected adds the rev_* edge types
    torch.load("../data/processed/graph_train.pt", weights_only=False)
).to(DEVICE)

print("Node types:", g_train.node_types)
print("Edge types:", g_train.edge_types)

# --- rebuild the *same* model used in training ---

class FeatEncoder(nn.Module):
    """Project each node type's raw features into a common hidden width.

    One ``nn.Linear`` is built per node type of ``g``. Its input width is taken
    from that type's feature matrix (``g[ntype].x.size(-1)``). ``forward`` applies
    the matching projection followed by a ReLU.

    Args:
        g: graph whose ``node_types`` and ``x`` feature widths define the layers.
        hidden: output width shared by every projection.
    """

    def __init__(self, g: HeteroData, hidden: int):
        super().__init__()
        projections = {
            node_type: nn.Linear(g[node_type].x.size(-1), hidden)
            for node_type in g.node_types
        }
        self.proj = nn.ModuleDict(projections)

    def forward(self, x_dict):
        """Return ``{node_type: relu(proj[node_type](features))}`` for every entry of ``x_dict``."""
        encoded = {}
        for node_type, feats in x_dict.items():
            encoded[node_type] = F.relu(self.proj[node_type](feats))
        return encoded

class HeteroSAGE(nn.Module):
    """Two-layer heterogeneous GraphSAGE that mirrors the Stage-2 training model.

    Features are first projected per node type by ``FeatEncoder``. Two
    ``HeteroConv`` layers follow. Each holds one ``SAGEConv`` per edge type of
    ``g``, and messages arriving at the same node type are summed. A ReLU is
    applied after every layer, including the last one, as in training.

    Args:
        g: graph whose node/edge types define the submodules; must match the
            graph the checkpoint was trained on.
        hidden: embedding width used throughout.
    """

    def __init__(self, g: HeteroData, hidden: int = 64):
        super().__init__()
        self.encoder = FeatEncoder(g, hidden)

        def build_layer():
            # (-1, -1): PyG's convention for inferring src/dst input sizes lazily.
            convs = {edge_type: SAGEConv((-1, -1), hidden) for edge_type in g.edge_types}
            return HeteroConv(convs, aggr="sum")

        # Exactly two message-passing layers, as in train_stage2.py.
        self.layers = nn.ModuleList([build_layer() for _ in range(2)])

    def forward(self, x_dict, edge_index_dict):
        """Encode features, then run each conv layer followed by a ReLU."""
        h = self.encoder(x_dict)
        for layer in self.layers:
            h = {ntype: F.relu(emb) for ntype, emb in layer(h, edge_index_dict).items()}
        return h

    def decode(self, h_user, h_item, edge_label_index):
        """Score (user, item) pairs by the dot product of their embeddings.

        ``edge_label_index[0]`` indexes rows of ``h_user`` and
        ``edge_label_index[1]`` indexes rows of ``h_item``.
        """
        users, items = edge_label_index
        return torch.sum(h_user[users] * h_item[items], dim=-1)

# Instantiate on THIS graph: HeteroSAGE builds one conv per edge type, so the
# edge types (including the rev_* ones added by ToUndirected) must match training.
model = HeteroSAGE(g_train, hidden=64).to(DEVICE)

# --- load trained weights (names now match) ---
ckpt = "../data/processed/graphsage_baseline.pt"
# The loaded object goes straight into load_state_dict, so it should be a plain
# state_dict. weights_only=True restricts unpickling to tensors/containers instead
# of relying on torch.load's version-dependent default (False before torch 2.6).
state = torch.load(ckpt, map_location=DEVICE, weights_only=True)
model.load_state_dict(state)  # strict=True by default: any key mismatch raises
model.eval()
print("✅ Weights loaded from:", ckpt)

# --- tiny, fast forward on a small slice (no heavy work) ---
from torch_geometric.loader import LinkNeighborLoader

# Reuse `g_train` (already ToUndirected()'d and moved to DEVICE above).
# NOTE(review): PyG neighbor sampling runs on CPU-side graph structure; if this
# cell errors on CUDA, build the loader from a CPU copy of the graph and move
# only the sampled batch to DEVICE.
et = ("user", "interacted", "item")
num_seed_edges = 2048  # small seed set

# A single small, consistent mini-subgraph (no negatives; just to test forward).
loader = LinkNeighborLoader(
    g_train,
    num_neighbors={k: [5, 5] for k in g_train.edge_types},  # small, 2 hops
    batch_size=512,
    edge_label_index=(et, g_train[et].edge_index[:, :num_seed_edges]),
    neg_sampling_ratio=0.0,  # positives only: this is a forward smoke test
    shuffle=False,
)

batch = next(iter(loader)).to(DEVICE)
seed_edges = batch[et].edge_label_index  # seed edges in the batch's local node ids

with torch.no_grad():
    h = model(batch.x_dict, batch.edge_index_dict)
    # Exercise the link decoder too, not just the encoder/convs.
    scores = model.decode(h["user"], h["item"], seed_edges)

print("Embeddings shapes:", {k: tuple(v.shape) for k, v in h.items()})
print("Link scores shape:", tuple(scores.shape))

# Actually verify something, so the success message below is meaningful.
for ntype, emb in h.items():
    assert emb.size(0) == batch[ntype].num_nodes, (
        f"{ntype!r}: got {emb.size(0)} embeddings for {batch[ntype].num_nodes} nodes"
    )
    assert torch.isfinite(emb).all(), f"non-finite embeddings for {ntype!r}"
assert scores.shape == (seed_edges.size(1),), (
    f"expected one score per seed edge, got shape {tuple(scores.shape)}"
)
assert torch.isfinite(scores).all(), "non-finite link scores"

print("✅ Quick sanity passed.")



✅ Device: cpu
Node types: ['user', 'item', 'genre']
Edge types: [('user', 'interacted', 'item'), ('item', 'belongs_to', 'genre'), ('item', 'rev_interacted', 'user'), ('genre', 'rev_belongs_to', 'item')]
✅ Weights loaded from: ../data/processed/graphsage_baseline.pt
Embeddings shapes: {'item': (2073, 64), 'genre': (18, 64), 'user': (1515, 64)}
✅ Quick sanity passed.
