# In [None]:
# rtg_ws_train.py
import numpy as np, pandas as pd, torch, torch.nn as nn, torch.nn.functional as F
import argparse, random, sys
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from sklearn.metrics import f1_score, precision_score, recall_score

# ------------------------ #
# CLI Argument Parsing
# ------------------------ #
# Command-line options. parse_known_args() is used, so arguments this
# parser does not define are returned separately and discarded here.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    default="cpu",
    help="cpu or cuda",
)
parser.add_argument(
    "--ws_keep",
    type=float,
    default=1.0,
    help="World Sync edge keep ratio (0~1)",
)
args, _ = parser.parse_known_args()

# torch device every tensor and the model are placed on.
DEVICE = torch.device(args.device)

# ------------------------ #
# Global Settings
# ------------------------ #
SEED = 42             # seed for torch / numpy / random and the pandas samplers
NEG_RATIO = 2         # negatives kept per positive edge when undersampling
USER_EMB_DIM = 8      # not referenced in the code shown in this file
HIDDEN_DIM = 48       # width of the GNN layers and of the edge MLP hidden layer
BLOCK_DIM = 8         # size of the block-type embedding
BIAS_SCALE = 1.0      # initial value of the learnable user-bias scale
DELTA_SCALE0 = 25.0   # initial value of the learnable delta scale
RANK_MARGIN = 1.0     # hinge margin of the delta ranking loss
LR = 4e-4             # Adam learning rate
NUM_EPOCHS = 80       # number of training epochs
BATCH_NODES = 8000    # NeighborLoader batch size during training
NEI_L1 = 10           # first entry of the NeighborLoader neighbour list
NEI_L2 = 5            # second entry of the NeighborLoader neighbour list
K_SAMPLES = 3         # max partners sampled per user for World Sync edges

# Seed the three random number generators used by this script, in the
# same order as before: torch, numpy, then the stdlib `random` module.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# On CUDA, request deterministic cuDNN kernels and turn off the autotuner.
if DEVICE.type == 'cuda':
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ------------------------ #
# 1. Load events
# ------------------------ #
# Read the raw interaction log and order it per visitor by time.
events = pd.read_csv(
    "events.csv",
    usecols=["visitorid", "timestamp", "event", "itemid"],
)
events = events.sort_values(["visitorid", "timestamp"]).reset_index(drop=True)

# Encode the event name as an integer block type (code 0 is never produced
# here; event names missing from the mapping become NaN).
_block_codes = {"view": 1, "addtocart": 2, "transaction": 3}
events["block_type"] = events["event"].map(_block_codes)

# A visitor's session id increases each time the gap to their previous
# event exceeds 3_600_000 timestamp units (presumably milliseconds, i.e.
# one hour — confirm against the data source).
gap = events.groupby("visitorid")["timestamp"].diff().fillna(0)
_session_break = (gap > 3600000).astype(int)
events["session_id"] = _session_break.groupby(events["visitorid"]).cumsum()

events["items"] = events["itemid"]

# Global time bucket of width 30*60*1000 timestamp units, counted from the
# earliest event in the log.
_elapsed = events["timestamp"] - events["timestamp"].min()
events["t_bucket"] = (_elapsed // (30*60*1000)).astype("int32")

# ------------------------ #
# 2. Blocks → Edges
# ------------------------ #
def build_blocks(df):
    """Collapse events into blocks and label each by the block that follows.

    Rows are grouped by (visitorid, session_id, timestamp, block_type),
    keeping the first ``items`` value per group; ``timestamp`` is renamed
    to ``start_time``. ``next_bt`` is the block type of the following
    block in the same visitor session, and ``label`` is 1 when that next
    block type equals 3, else 0. The last block of every session has no
    successor and is dropped.

    Args:
        df: DataFrame with columns visitorid, session_id, timestamp,
            block_type and items.

    Returns:
        DataFrame with columns visitorid, session_id, start_time,
        block_type, items, next_bt and label.
    """
    block_keys = ["visitorid", "session_id", "timestamp", "block_type"]
    grouped = df.groupby(block_keys, as_index=False).agg({"items": "first"})
    blks = grouped.rename(columns={"timestamp": "start_time"})

    per_session_bt = blks.groupby(["visitorid", "session_id"]).block_type
    blks["next_bt"] = per_session_bt.shift(-1)
    blks["label"] = (blks.next_bt == 3).astype(int)

    return blks.dropna(subset=["next_bt"])

blocks = build_blocks(events)

# Chronological split by session start time: sessions starting up to the
# 70th percentile go to train, those up to the 85th to validation, and the
# remainder to test.
_session_keys = ["visitorid", "session_id"]
s_start = blocks.groupby(_session_keys).start_time.min()
c1, c2 = s_start.quantile([.70, .85])

# Index the blocks by session so the per-session boolean masks select
# every block that belongs to a chosen session.
split = blocks.set_index(_session_keys)
_is_train = s_start <= c1
_is_val = (s_start > c1) & (s_start <= c2)
_is_test = s_start > c2

tr_blk = split.loc[_is_train].reset_index()
va_blk = split.loc[_is_val].reset_index()
te_blk = split.loc[_is_test].reset_index()

def blks2edges(blks):
    """Turn a block table into a user–item edge table.

    ``prev_bt`` is the block type of the preceding block in the same
    visitor session (0 for the first block of a session). The ``items``
    column is exploded so list-valued entries yield one row per item.

    The input frame is not modified: the ``prev_bt`` column is added to a
    copy rather than to the caller's DataFrame.

    Args:
        blks: DataFrame with columns visitorid, session_id, block_type,
            items and label.

    Returns:
        DataFrame with columns user, item, block_type, prev_bt, label and
        a fresh 0..n-1 index.
    """
    blks = blks.copy()
    per_session_bt = blks.groupby(["visitorid", "session_id"]).block_type
    blks["prev_bt"] = per_session_bt.shift(1).fillna(0).astype(int)

    edges = blks.explode("items", ignore_index=True)
    edges = edges.rename(columns={"visitorid": "user", "items": "item"})
    return edges[["user", "item", "block_type", "prev_bt", "label"]]

edge_tr_full = blks2edges(tr_blk)
edge_val = blks2edges(va_blk)
edge_test = blks2edges(te_blk)

# ------------------------ #
# 3. Undersample
# ------------------------ #
# Keep every positive edge and at most NEG_RATIO negatives per positive.
# The request is capped at the number of available negatives, because
# DataFrame.sample without replacement raises ValueError when asked for
# more rows than exist.
pos = edge_tr_full[edge_tr_full.label == 1]
_neg_pool = edge_tr_full[edge_tr_full.label == 0]
_n_neg = min(len(_neg_pool), len(pos) * NEG_RATIO)
neg = _neg_pool.sample(_n_neg, random_state=SEED)

# Shuffle positives and negatives together (original index labels are kept).
edge_train = pd.concat([pos, neg]).sample(frac=1, random_state=SEED)

# ------------------------ #
# 4. Attr & Mapping
# ------------------------ #
def onehot(x):
    """Return the rows of a 4x4 identity matrix selected by the codes in `x`."""
    identity = np.eye(4)
    return identity[x]


def attr(df):
    """Build the 8-column edge attribute matrix for `df`.

    The first four columns one-hot encode ``df.prev_bt`` and the last four
    one-hot encode ``df.block_type``.
    """
    prev_part = onehot(df.prev_bt)
    curr_part = onehot(df.block_type)
    return np.hstack([prev_part, curr_part])

tr_attr_base = torch.tensor(attr(edge_train), dtype=torch.float)
va_attr_base = attr(edge_val)
te_attr_base = attr(edge_test)

# Node id layout: users occupy [0, n_users) and items occupy
# [n_users, num_nodes). Items are numbered only AFTER every user (training
# users, the unknown-user placeholder and World Sync users) is registered.
# Numbering items earlier would hand later users ids already held by items.
uid2idx, it2idx = {}, {}


def add_u(u):
    """Return the node id of user `u`, registering it on first sight."""
    return uid2idx.setdefault(u, len(uid2idx))


ui_src = edge_train.user.map(add_u).to_numpy(np.int64)
unk_u = add_u("__UNK__")  # placeholder for users unseen at mapping time

# ------------------------ #
# 5. World Sync Edges
# ------------------------ #
# Users who touched the same item in the same time bucket are linked: each
# user gets directed edges to up to K_SAMPLES other users of that group,
# drawn without replacement.
ws_src, ws_dst = [], []
rng = np.random.default_rng(SEED)
for (_, _), users in events.groupby(["t_bucket", "itemid"])["visitorid"].unique().items():
    if len(users) < 2: continue
    for u in users:
        others = [v for v in users if v != u]
        for v in rng.choice(others, min(K_SAMPLES, len(others)), replace=False):
            ws_src.append(add_u(u)); ws_dst.append(add_u(v))
# Explicit integer dtype so an empty edge list does not become float64.
ws_src = np.array(ws_src, dtype=np.int64)
ws_dst = np.array(ws_dst, dtype=np.int64)
if args.ws_keep < 1.0:
    # Keep a random fraction of the World Sync edges.
    keep = rng.choice(len(ws_src), int(len(ws_src)*args.ws_keep), replace=False)
    ws_src, ws_dst = ws_src[keep], ws_dst[keep]

# All users are known now, so the item id range can start right after them.
_item_offset = len(uid2idx)


def add_i(i):
    """Return the node id of item `i`, registering it on first sight."""
    return it2idx.setdefault(i, _item_offset + len(it2idx))


ui_dst = edge_train.item.map(add_i).to_numpy(np.int64)
unk_it = add_i("__UNK_IT__")  # placeholder for items unseen in training

# User–item training edges first, World Sync user–user edges after them.
edge_src = np.concatenate([ui_src, ws_src])
edge_dst = np.concatenate([ui_dst, ws_dst])
edge_user = np.concatenate([ui_src, ws_src])
# World Sync edges carry all-zero attributes and block type 0.
edge_attr = torch.cat([
    tr_attr_base,
    torch.zeros((len(ws_src), 8), dtype=tr_attr_base.dtype, device=tr_attr_base.device)
])
prev_bt = torch.tensor(np.concatenate([edge_train.prev_bt.values, np.zeros(len(ws_src))]), dtype=torch.long)
curr_bt = torch.tensor(np.concatenate([edge_train.block_type.values, np.zeros(len(ws_src))]), dtype=torch.long)

# ------------------------ #
# 6. Graph
# ------------------------ #
# Count every registered node, including the two edge-less placeholders,
# so later lookups by unk_u / unk_it stay inside the node range.
num_nodes = len(uid2idx) + len(it2idx)
# NOTE(review): node features are all zeros; together with the bias-free
# node projection in EAGAT this makes the GNN output zero — confirm intent.
graph = Data(x=torch.zeros((num_nodes, 1)),
             edge_index=torch.tensor(np.array([edge_src, edge_dst]), dtype=torch.long),
             edge_attr=edge_attr,
             edge_user=torch.tensor(edge_user, dtype=torch.long),
             prev_bt=prev_bt,
             curr_bt=curr_bt)

# ------------------------ #
# 7. Model
# ------------------------ #
class EAGAT(MessagePassing):
    """Edge-aware attention layer with sum aggregation.

    Node features and edge attributes are projected to `out_c` dimensions
    by bias-free linear maps. Each edge gets a score from the
    concatenation [target, source, edge] and the scores are normalised
    with `softmax` over `index`. The message is the projected source
    feature scaled by that weight.

    NOTE(review): the edge projection only influences the attention
    weight, never the message itself. With all-zero input features and no
    bias in `lin_n`, every message (and so the layer output) is zero.
    """

    def __init__(self, in_c, out_c, e_c):
        super().__init__(aggr='add')
        self.lin_n = nn.Linear(in_c, out_c, bias=False)
        self.lin_e = nn.Linear(e_c, out_c, bias=False)
        # One attention vector over the [x_i, x_j, edge] concatenation.
        self.att = nn.Parameter(torch.randn(1, 3*out_c))

    def forward(self, x, ei, ea):
        """Project nodes and edges, then propagate along edge index `ei`."""
        node_feat = self.lin_n(x)
        edge_feat = self.lin_e(ea)
        return self.propagate(ei, x=node_feat, edge_attr=edge_feat)

    def message(self, x_j, x_i, edge_attr, index):
        """Return the attention-weighted source features, one row per edge."""
        stacked = torch.cat([x_i, x_j, edge_attr], dim=1)
        score = F.leaky_relu((stacked @ self.att.T).squeeze(-1), 0.2)
        alpha = softmax(score, index)
        return x_j * alpha.unsqueeze(-1)

class RTG(nn.Module):
    """Two-layer EAGAT encoder with an MLP edge scorer.

    Each edge's 8 base attributes are extended with one learned scalar
    ("delta"): the scaled Euclidean distance between the embeddings of
    the previous and current block type, plus a scaled per-user bias.
    """

    def __init__(self, num_users):
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded random
        # initialisation stays reproducible.
        self.block_emb = nn.Embedding(4, BLOCK_DIM, padding_idx=0)
        self.user_bias = nn.Embedding(num_users, 1)
        self.delta_scale = nn.Parameter(torch.tensor(DELTA_SCALE0))
        self.bias_scale = nn.Parameter(torch.tensor(BIAS_SCALE))
        # 9 = 8 one-hot attribute columns + the delta column.
        self.g1 = EAGAT(1, HIDDEN_DIM, 9)
        self.g2 = EAGAT(HIDDEN_DIM, HIDDEN_DIM, 9)
        self.edge_mlp = nn.Sequential(
            nn.Linear(HIDDEN_DIM*2 + 9, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, 1),
        )

    def _delta_hat(self, prev_bt, curr_bt, user_id):
        """Return the per-edge delta: scaled block distance + scaled user bias."""
        block_gap = self.block_emb(prev_bt) - self.block_emb(curr_bt)
        dist = block_gap.norm(p=2, dim=1)
        bias = self.user_bias(user_id).squeeze(-1)
        return self.delta_scale * dist + self.bias_scale * bias

    def _edge_full(self, ea, prev_bt, curr_bt, user_id):
        """Append the delta as an extra last column to the edge attributes."""
        delta_col = self._delta_hat(prev_bt, curr_bt, user_id).unsqueeze(-1)
        return torch.cat([ea, delta_col], 1)

    def forward(self, d):
        """Return node embeddings for graph (or sampled sub-graph) `d`."""
        ea = self._edge_full(d.edge_attr, d.prev_bt, d.curr_bt, d.edge_user)
        hidden = self.g1(d.x, d.edge_index, ea)
        return self.g2(hidden, d.edge_index, ea)

    def edge_pred(self, h, ei, ea, prev_bt, curr_bt, user_id):
        """Return one logit per edge of `ei` from endpoint embeddings and edge features."""
        full_ea = self._edge_full(ea, prev_bt, curr_bt, user_id)
        src, dst = ei
        pair_feat = torch.cat([h[src], h[dst], full_ea], 1)
        return self.edge_mlp(pair_feat).view(-1)

# ------------------------ #
# 8. Ranking Loss
# ------------------------ #
def delta_ranking_loss(model, edge_df, device):
    """Hinge penalty ordering the mean delta of selected transition types.

    Every edge is keyed by its (prev_bt, block_type) transition. For each
    configured (high, low) pair the loss is
    ``relu(mean_delta_low - mean_delta_high + RANK_MARGIN)``, so the high
    transition is pushed to average at least RANK_MARGIN above the low one.
    Pairs with no edges on either side are skipped.

    Args:
        model: object exposing ``_delta_hat(prev_bt, curr_bt, user_id)``.
        edge_df: DataFrame with columns prev_bt, block_type and user.
        device: device the temporary tensors are created on.

    Returns:
        The summed penalty as a tensor, or the float 0.0 when no pair
        could be evaluated.
    """
    prev_np = edge_df.prev_bt.values
    curr_np = edge_df.block_type.values

    # Single integer code per transition: prev_bt * 4 + block_type.
    code = torch.tensor(prev_np * 4 + curr_np, device=device)
    prev = torch.tensor(prev_np, device=device)
    curr = torch.tensor(curr_np, device=device)
    # Users missing from the mapping fall back to the unknown-user id.
    user_ids = edge_df.user.map(uid2idx).fillna(unk_u).to_numpy(np.int64)
    user = torch.tensor(user_ids, device=device)

    delta = model._delta_hat(prev, curr, user)

    # (high, low) transitions given as (prev_bt, block_type) pairs.
    ranked_pairs = [((1, 3), (1, 1))]
    loss = 0.0
    for (hi_prev, hi_curr), (lo_prev, lo_curr) in ranked_pairs:
        mask_hi = code == (hi_prev*4 + hi_curr)
        mask_lo = code == (lo_prev*4 + lo_curr)
        if not (mask_hi.any() and mask_lo.any()):
            continue
        mean_hi = delta[mask_hi].mean()
        mean_lo = delta[mask_lo].mean()
        loss = loss + F.relu(mean_lo - mean_hi + RANK_MARGIN)
    return loss

# ------------------------ #
# 9. Train
# ------------------------ #
model = RTG(len(uid2idx)).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)

# Positive-class weight sqrt(NEG_RATIO), built once as a float32 tensor on
# DEVICE instead of being rebuilt as float64 on every loss call.
_pos_weight = torch.tensor(float(np.sqrt(NEG_RATIO)), dtype=torch.float32, device=DEVICE)


def crit(log, y):
    """Weighted binary cross-entropy on raw logits."""
    return F.binary_cross_entropy_with_logits(log, y, pos_weight=_pos_weight)


# Per-edge training tensors cached on DEVICE so each batch only slices them.
prev_bt_all = torch.tensor(edge_train.prev_bt.values, dtype=torch.long, device=DEVICE)
curr_bt_all = torch.tensor(edge_train.block_type.values, dtype=torch.long, device=DEVICE)

# Index tensors are explicitly int64: they index other tensors and feed
# nn.Embedding. They are built from one ndarray, not a list of ndarrays.
ui_edges = torch.tensor(np.array([ui_src, ui_dst]), dtype=torch.long, device=DEVICE)
ui_user  = torch.tensor(ui_src, dtype=torch.long, device=DEVICE)
ui_attr  = tr_attr_base.to(DEVICE)
ui_y     = torch.tensor(edge_train.label.values, dtype=torch.float32, device=DEVICE)

loader = NeighborLoader(graph, [NEI_L1, NEI_L2], batch_size=BATCH_NODES, shuffle=True)
for ep in range(1, NUM_EPOCHS + 1):
    model.train()
    tot = 0.0        # sum of (batch loss * edges scored in that batch)
    n_scored = 0     # total number of edges scored this epoch
    for batch_idx, batch in enumerate(loader):
        batch = batch.to(DEVICE); opt.zero_grad()
        h = model(batch)

        # Global -> local node id lookup for this sub-graph (-1 = absent).
        g2l = -torch.ones(int(num_nodes), dtype=torch.long, device=DEVICE)
        g2l[batch.n_id] = torch.arange(batch.num_nodes, device=DEVICE)
        src_local, dst_local = g2l[ui_edges[0]], g2l[ui_edges[1]]

        # Score every training edge whose two endpoints are both present.
        m = (src_local >= 0) & (dst_local >= 0)
        n_edges = int(m.sum().item())
        if n_edges == 0: continue
        ei = torch.stack([src_local[m], dst_local[m]], 0)

        log  = model.edge_pred(h, ei, ui_attr[m], prev_bt_all[m], curr_bt_all[m], ui_user[m])
        loss = crit(log, ui_y[m])
        if batch_idx == 0:          # one global ranking penalty per epoch
            rank_penalty = 0.1 * delta_ranking_loss(model, edge_train, DEVICE)
            loss = loss + rank_penalty
        loss.backward(); opt.step()

        tot += loss.item() * n_edges
        n_scored += n_edges
    # Average over the edges actually scored. An edge can be scored in
    # several batches or in none, so len(ui_src) is not the right divisor.
    print(f"Ep {ep:3d}  loss={tot/max(n_scored, 1):.4f}")


# ------------------------ #
# 10. Eval
# ------------------------ #
model.eval()
h_all = torch.zeros((num_nodes, HIDDEN_DIM), device=DEVICE)

# Compute an embedding for every node, one batch of seed nodes at a time.
# NeighborLoader puts the seed nodes first in each sampled sub-graph
# (the first `batch.batch_size` rows). The remaining rows are neighbours
# sampled only as context, so their embeddings are incomplete. Only the
# seed rows are stored, so context rows never overwrite another batch.
# NOTE(review): one hop is sampled here although the model stacks two
# message-passing layers — confirm whether [-1, -1] was intended.
with torch.no_grad():
    for batch in NeighborLoader(graph, [-1], batch_size=4000):
        batch = batch.to(DEVICE)
        n_seed = batch.batch_size
        h_all[batch.n_id[:n_seed]] = model(batch)[:n_seed]


def eval_set(df, attr_np, thresholds):
    """Score the edges in `df` and return metrics at the best-F1 threshold.

    Users and items missing from the training mappings fall back to the
    unknown-user / unknown-item node ids. Edge probabilities come from
    `model.edge_pred` on the precomputed node embeddings `h_all`.

    Args:
        df: edge table with columns user, item, prev_bt, block_type, label.
        attr_np: edge attribute matrix for `df` (converted to float32).
        thresholds: iterable of probability cut-offs to try.

    Returns:
        Tuple ``(f1, threshold, precision, recall)`` for the threshold
        with the highest F1 (the first one on ties), or None when
        `thresholds` is empty.
    """
    u = torch.tensor(df.user.map(uid2idx).fillna(unk_u).to_numpy(np.int64),
                     device=DEVICE)
    i = torch.tensor(df.item.map(it2idx).fillna(unk_it).to_numpy(np.int64),
                     device=DEVICE)
    ei = torch.stack([u, i], 0)
    ea = torch.tensor(attr_np, dtype=torch.float32, device=DEVICE)
    prev_bt = torch.tensor(df.prev_bt.values, dtype=torch.long, device=DEVICE)
    curr_bt = torch.tensor(df.block_type.values, dtype=torch.long, device=DEVICE)

    with torch.no_grad():
        logits = model.edge_pred(h_all, ei, ea, prev_bt, curr_bt, u)
        # .cpu() is required before .numpy() when DEVICE is a CUDA device.
        prob = torch.sigmoid(logits).cpu().numpy()

    y = df.label.values
    best = None
    for th in thresholds:
        pred = prob > th
        f1 = f1_score(y, pred, zero_division=0)
        if best is None or f1 > best[0]:
            best = (
                f1,
                th,
                precision_score(y, pred, zero_division=0),
                recall_score(y, pred, zero_division=0),
            )
    return best

# Pick the decision threshold on the validation set: 19 evenly spaced
# candidates from 0.05 to 0.95.
thresholds = np.linspace(0.05, 0.95, 19)
val = eval_set(edge_val, va_attr_base, thresholds)

# Evaluate the test set at the single threshold chosen on validation.
_best_thr = val[1]
test = eval_set(edge_test, te_attr_base, [_best_thr])

# Report F1 / precision / recall for both splits.
print(f"[Val ] F1 {val[0]:.4f}  P {val[2]:.4f} R {val[3]:.4f} thr {val[1]:.2f}")
print(f"[Test] F1 {test[0]:.4f} P {test[2]:.4f} R {test[3]:.4f}")

