In [None]:
# === Stage 3 TGAT: self-contained verification cell ===


import os, json, math, random
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Locate repo root (folder that contains 'data/processed') ---
def find_repo_root(start=None, max_depth=8):
    """Return the nearest directory, walking upward, that contains ``data/processed``.

    Args:
        start: directory to begin the search from (``Path`` or ``str``).
            Defaults to the current working directory *at call time*.
        max_depth: how many directories to inspect in total (``start`` itself
            counts as the first one).

    Returns:
        The first directory ``d`` for which ``d / "data/processed"`` exists.

    Raises:
        RuntimeError: if no such directory is found within ``max_depth``
            levels or before reaching the filesystem root.
    """
    # Resolve the default here rather than in the signature: a default of
    # Path.cwd() would be frozen at definition time and go stale after chdir.
    d = Path.cwd() if start is None else Path(start)
    for _ in range(max_depth):
        if (d / "data" / "processed").exists():
            return d
        if d.parent == d:  # reached the filesystem root
            break
        d = d.parent
    raise RuntimeError("Could not find repo root containing data/processed/")
# Resolve project paths once for the rest of the cell.
ROOT = find_repo_root()
PROCESSED = ROOT / "data" / "processed"
print("Repo root:", ROOT)

# --- Inline model (must match training) ---
class TimeEncoder(nn.Module):
    """Map scalar timestamps to a ``dim``-wide learned feature vector.

    Each timestamp is multiplied by ``dim`` fixed scales ``10 ** (-k / dim)``
    for k = 0 .. dim-1 (from 1.0 down to just above 0.1), passed through
    ``sin`` and then a learned ``nn.Linear(dim, dim)``.  The ``lin``
    attribute name is part of the saved state-dict keys, so it must not change.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.lin = nn.Linear(dim, dim)

    def forward(self, t: torch.Tensor):
        """Encode ``t`` (expected shape [E, 1]) into an [E, dim] tensor."""
        t = t.to(torch.float32)
        exponents = torch.arange(self.dim, device=t.device, dtype=t.dtype) / float(self.dim)
        scales = 1.0 / (10.0 ** exponents)              # [dim]
        return self.lin(torch.sin(t * scales[None, :]))  # [E, dim]

class TemporalTGAT(nn.Module):
    """One-layer temporal graph-attention model over a user/item bipartite graph.

    Users and items have separate ``hidden``-wide embedding tables.  Node
    features are the looked-up user rows stacked above the item rows, and
    per-edge time features from ``TimeEncoder`` are passed to a 2-head
    ``TransformerConv`` as ``edge_attr``.  The attribute names ``user_emb``,
    ``item_emb``, ``time_enc`` and ``conv`` are the checkpoint's state-dict
    prefixes and must not change.

    NOTE(review): TransformerConv concatenates heads by default, so the output
    width is presumably ``2 * hidden`` rather than ``hidden`` — confirm
    against the installed torch_geometric version.
    """

    def __init__(self, num_users, num_items, hidden=64, time_dim=32):
        super().__init__()
        # Imported here, so defining the class does not require torch_geometric.
        from torch_geometric.nn import TransformerConv

        self.user_emb = nn.Embedding(num_users, hidden)
        self.item_emb = nn.Embedding(num_items, hidden)
        for table in (self.user_emb, self.item_emb):
            nn.init.normal_(table.weight, std=0.02)
        self.time_enc = TimeEncoder(time_dim)
        self.conv = TransformerConv(
            in_channels=hidden,
            out_channels=hidden,
            heads=2,
            dropout=0.1,
            edge_dim=time_dim,
        )

    def forward(self, u_unique, i_unique, edge_index, t_edge):
        """Run the conv on a local graph: rows [0, U) are users, [U, U+I) items.

        ``edge_index`` must index into that stacked local layout; ``t_edge``
        holds the per-edge timestamps (presumably shape [E, 1], as
        ``TimeEncoder`` expects).
        """
        node_x = torch.cat((self.user_emb(u_unique), self.item_emb(i_unique)), dim=0)
        edge_x = self.time_enc(t_edge)
        return self.conv(node_x, edge_index, edge_attr=edge_x)

# --- Load counts + checkpoint (architecture args must match training) ---
counts = json.loads((PROCESSED / "temporal_counts.json").read_text())
NUM_USERS = int(counts["num_users"])
NUM_ITEMS = int(counts["num_items"])

model = TemporalTGAT(NUM_USERS, NUM_ITEMS, hidden=64, time_dim=32).to(DEVICE)
state = torch.load(PROCESSED / "tgat_baseline.pt", map_location=DEVICE)
model.load_state_dict(state)
model.eval()
print("✅ Loaded TGAT checkpoint")

# --- Load validation data (already time-sorted by your prep script) ---
# Expected columns: user_id, movie_id, label, ts_norm.
df = pd.read_csv(PROCESSED / "temporal_val.csv")
df = df.astype({"user_id": int, "movie_id": int, "label": int})
print("Val size:", len(df))

# --- Per-user evaluation with negative sampling ---
def eval_per_user(df, K=10, neg_k=50, sample_users=200, *, seed=None,
                  net=None, device=None, num_items=None):
    """Return mean per-user (Precision@K, Recall@K) with sampled negatives.

    For each (optionally sub-sampled) user with at least one ``label == 1`` row
    in ``df``, the candidate set is that user's positive ``movie_id`` values
    plus up to ``neg_k`` random negatives *per positive*.  The negatives are
    item ids in ``[0, num_items)`` that are not among the user's positives in
    ``df``.  Candidates are scored by the scaled dot product of the model's
    user and item outputs on a one-user star graph.  The top-K by score are
    then compared with the labels.

    Args:
        df: frame with columns ``user_id``, ``movie_id``, ``label``, ``ts_norm``.
        K: cutoff for the top-K metrics.
        neg_k: negatives sampled per positive item, capped by the number of
            non-positive item ids available.
        sample_users: if truthy and smaller than the number of users, evaluate
            on this many users drawn with ``random_state=42``.
        seed: optional seed for negative sampling; ``None`` uses the global
            ``random`` module.
        net, device, num_items: overrides for the notebook-level ``model``,
            ``DEVICE`` and ``NUM_ITEMS`` globals (mainly for testing).

    Returns:
        ``(mean_precision, mean_recall)`` over evaluated users, or
        ``(0.0, 0.0)`` when no evaluated user has a positive.
    """
    # Fall back to notebook globals only when no override is given.
    net = model if net is None else net
    device = DEVICE if device is None else device
    num_items = NUM_ITEMS if num_items is None else num_items
    rng = random if seed is None else random.Random(seed)

    users = df["user_id"].unique()
    if sample_users and len(users) > sample_users:
        users = pd.Series(users).sample(sample_users, random_state=42).values

    total_prec, total_rec, user_count = 0.0, 0.0, 0

    for u in users:
        rows = df[df.user_id == u]
        pos_items = rows.loc[rows.label == 1, "movie_id"].tolist()
        if not pos_items:
            continue  # no positives for this user in val

        # Candidate set = all positives + up to neg_k negatives per positive.
        pos_set = set(pos_items)  # O(1) membership inside the rejection loop
        # Cap by the number of distinct non-positive ids so the loop terminates.
        n_neg = min(neg_k * len(pos_items), num_items - len(pos_set))
        neg_items = set()
        while len(neg_items) < n_neg:
            r = rng.randint(0, num_items - 1)
            if r not in pos_set:
                neg_items.add(r)
        neg_items = list(neg_items)

        cand_items = pos_items + neg_items
        labels = torch.tensor([1] * len(pos_items) + [0] * len(neg_items),
                              dtype=torch.float32, device=device)
        # The user's first ts_norm stands in for every candidate edge's time.
        ts = torch.tensor(rows["ts_norm"].iloc[:1].values, dtype=torch.float32, device=device)
        ts = ts.repeat(len(cand_items), 1)  # [E,1]

        # ---- Build local bipartite mini-graph (single user -> each candidate) ----
        u_tensor = torch.tensor([u] * len(cand_items), dtype=torch.long, device=device)
        i_tensor = torch.tensor(cand_items, dtype=torch.long, device=device)
        u_unique, u_inv = torch.unique(u_tensor, return_inverse=True)  # size U=1
        # torch.unique returns *sorted* ids; i_inv maps each candidate to its row.
        i_unique, i_inv = torch.unique(i_tensor, return_inverse=True)
        U = u_unique.size(0)
        edge_index = torch.stack([u_inv, i_inv + U], dim=0)  # [2,E]

        # ---- Forward ----
        with torch.no_grad():
            out = net(u_unique, i_unique, edge_index, ts)  # [U+I, D]
            h_user, h_item = out[:U], out[U:]
            item_scores = (h_user[0].unsqueeze(0) * h_item).sum(dim=-1) / math.sqrt(h_user.size(-1))
            # Re-align from sorted-unique order back to cand_items order so
            # each score lines up with its entry in `labels`.  Rank raw scores:
            # sigmoid is monotonic but can saturate into float32 ties.
            scores = item_scores[i_inv]

        # ---- Top-K metrics per user ----
        k = min(K, scores.numel())
        topk_labels = labels[torch.topk(scores, k=k).indices]
        total_prec += topk_labels.mean().item()
        total_rec += topk_labels.sum().item() / max(1, labels.sum().item())
        user_count += 1

    if user_count == 0:
        return 0.0, 0.0
    return total_prec / user_count, total_rec / user_count

# Sampled-negative metrics at several cutoffs, on 200 users for speed.
# NOTE(review): every call re-draws negatives, so the cutoffs are not scored
# on identical candidate sets.
for K in (5, 10, 20):
    p, r = eval_per_user(df, K=K, neg_k=50, sample_users=200)
    print(f"P@{K}={p:.4f}  R@{K}={r:.4f}")

# --- Show a few sample recommendations for a random user ---
def show_recs(u, topn=5, neg_k=200, *, seed=None, data=None, net=None,
              device=None, num_items=None):
    """Print the ``topn`` highest-scoring candidates for user ``u``.

    Candidates are the user's ``label == 1`` items in the validation frame
    plus up to ``neg_k`` distinct random item ids that are not among those
    positives.  Each printed row is ``item_id`` followed by ``1``/``0`` for
    whether the item was one of the user's positives.

    Args:
        u: user id to look up in ``data.user_id``.
        topn: number of candidates to print.
        neg_k: number of random negatives, capped by the ids available.
        seed: optional seed for negative sampling; ``None`` uses the global
            ``random`` module.
        data, net, device, num_items: overrides for the notebook-level ``df``,
            ``model``, ``DEVICE`` and ``NUM_ITEMS`` globals (mainly for testing).
    """
    # Fall back to notebook globals only when no override is given.
    data = df if data is None else data
    net = model if net is None else net
    device = DEVICE if device is None else device
    num_items = NUM_ITEMS if num_items is None else num_items
    rng = random if seed is None else random.Random(seed)

    rows = data[data.user_id == u]
    if rows.empty:
        print(f"User {u} not in validation slice")
        return
    pos_items = rows.loc[rows.label == 1, "movie_id"].tolist()
    seen = set(pos_items)

    # Distinct negatives; capping at the free ids guarantees termination.
    n_neg = min(neg_k, num_items - len(seen))
    neg_items = set()
    while len(neg_items) < n_neg:
        r = rng.randint(0, num_items - 1)
        if r not in seen:
            neg_items.add(r)
    cand_items = pos_items + list(neg_items)
    if not cand_items:
        print(f"User {u} has no candidates to rank")
        return
    is_pos = [1] * len(pos_items) + [0] * (len(cand_items) - len(pos_items))

    # The user's first ts_norm stands in for every candidate edge's time.
    ts = torch.tensor(rows["ts_norm"].iloc[:1].values, dtype=torch.float32,
                      device=device).repeat(len(cand_items), 1)  # [E,1]

    u_tensor = torch.tensor([u] * len(cand_items), dtype=torch.long, device=device)
    i_tensor = torch.tensor(cand_items, dtype=torch.long, device=device)
    u_unique, u_inv = torch.unique(u_tensor, return_inverse=True)
    # torch.unique returns *sorted* ids; i_inv maps each candidate to its row.
    i_unique, i_inv = torch.unique(i_tensor, return_inverse=True)
    U = u_unique.size(0)
    edge_index = torch.stack([u_inv, i_inv + U], dim=0)

    with torch.no_grad():
        out = net(u_unique, i_unique, edge_index, ts)
        h_user, h_item = out[:U], out[U:]
        item_scores = (h_user[0].unsqueeze(0) * h_item).sum(dim=-1) / math.sqrt(h_user.size(-1))
        # Back to cand_items order so indices match cand_items / is_pos.
        # Rank raw scores: sigmoid can saturate into float32 ties.
        scores = item_scores[i_inv].cpu()

    # top-n
    k = min(topn, scores.numel())
    top_idx = torch.topk(scores, k=k).indices.tolist()
    print(f"\nUser {u} top-{k} candidates (item_id, is_positive):")
    for i in top_idx:
        print(f"  {cand_items[i]:5d}  {is_pos[i]}")

# Pick one validation row reproducibly (random_state=7) and show that user's recs.
some_user = int(df.sample(n=1, random_state=7)["user_id"].iat[0])
show_recs(some_user, topn=5)


Repo root: c:\Users\hp\OneDrive\Desktop\Zine_project1\rec-temporal-gnn
✅ Loaded TGAT checkpoint
Val size: 100021
P@5=0.0163  R@5=0.0265
P@10=0.0245  R@10=0.0584
P@20=0.0360  R@20=0.0843

User 720 top-5 candidates (item_id, is_positive):
    348  1
   2258  1
   2430  1
    593  1
    899  1
