In [10]:
!pip install torch torchvision torchaudio -q
!pip install torch-geometric -q
!pip install dgl -q  # generic DGL (CPU/GPU autodetect)
!pip install torchmetrics==1.4.0.post0 scikit-learn pandas numpy tqdm geopy haversine -q

In [11]:
# ============================================================
# Cell 1: Imports
# ============================================================
import pandas as pd
import numpy as np
import ast
from collections import defaultdict

import torch
from torch import nn
import torch.nn.functional as F

from tqdm import tqdm
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import svds
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import pairwise_distances

Data Loading

In [12]:
# ============================================================
# Cell 2: Load CSVs (keep all columns, parse lists/dicts)
# ============================================================
def load_reviews_csv(path: str) -> pd.DataFrame:
    df = pd.read_csv(path)

    def try_literal(x):
        # Safely convert stringified lists/dicts; leave others alone
        try:
            return ast.literal_eval(str(x))
        except Exception:
            return x

    df = df.applymap(try_literal)
    return df

train_df = load_reviews_csv("train_reviews.csv")
val_df = load_reviews_csv("val_reviews.csv")
test_df = load_reviews_csv("test_reviews.csv")

print("train:", train_df.shape)
print("val:",   val_df.shape)
print("test:",  test_df.shape)

  df = pd.read_csv(path)
  df = df.applymap(try_literal)
  df = df.applymap(try_literal)
  df = df.applymap(try_literal)


train: (278388, 36)
val: (34798, 36)
test: (34799, 36)


Feature Building

In [62]:
# ============================================================
# Cell 3: Build global ID mappings (SAFE) and add u/i columns
# ============================================================
all_users = pd.concat(
    [train_df["user_id"], val_df["user_id"], test_df["user_id"]],
    ignore_index=True
)
all_items = pd.concat(
    [train_df["gmap_id"], val_df["gmap_id"], test_df["gmap_id"]],
    ignore_index=True
)

user2idx = {u: idx for idx, u in enumerate(all_users.unique())}
item2idx = {i: idx for idx, i in enumerate(all_items.unique())}

num_users = len(user2idx)
num_items = len(item2idx)
print("num_users:", num_users, "num_items:", num_items)

for df in [train_df, val_df, test_df]:
    df["u"] = df["user_id"].map(user2idx)
    df["i"] = df["gmap_id"].map(item2idx)

# Train-only view for all feature construction (avoid leakage)
full_df_train = train_df.copy()

num_users: 25841 num_items: 2923


In [63]:
# ============================================================
# Cell 4: Hardened, Normalized Train-only Feature Builder
# ============================================================

import numpy as np
import pandas as pd

# ------------------------------------------------------------
# 4.1 Ensure category is always a list
# ------------------------------------------------------------
for df in [train_df, val_df, test_df]:
    df["category"] = df["category"].apply(lambda x: x if isinstance(x, list) else [])

# TRAIN-ONLY METADATA
df_train = full_df_train.copy()


# ============================================================
# 4.2 GEO FEATURES (latitude, longitude)
# ============================================================
df_train["latitude"]  = pd.to_numeric(df_train["latitude"], errors="coerce")
df_train["longitude"] = pd.to_numeric(df_train["longitude"], errors="coerce")

item_geo = (
    df_train.groupby("i")[["latitude", "longitude"]]
    .mean()                           # average per item
    .reindex(range(num_items))        # ensure full coverage
    .fillna(0.0)
)

# Normalize geo features (z-score)
item_geo = (item_geo - item_geo.mean()) / (item_geo.std() + 1e-6)
item_geo = item_geo.fillna(0.0).astype("float32")


# ============================================================
# 4.3 CATEGORY MULTI-HOT
# ============================================================
from sklearn.preprocessing import MultiLabelBinarizer

cat_lists = df_train.groupby("i")["category"].apply(
    lambda x: x.iloc[0] if isinstance(x.iloc[0], list) else []
)

mlb = MultiLabelBinarizer()
cat_matrix = mlb.fit_transform(cat_lists)

item_cat = (
    pd.DataFrame(cat_matrix, columns=[f"cat_{c}" for c in mlb.classes_], index=cat_lists.index)
    .reindex(range(num_items), fill_value=0)
    .astype("float32")
)

# categories are already 0/1; normalization optional but safe:
# (optional) item_cat  = (item_cat - item_cat.mean()) / (item_cat.std() + 1e-6)


# ============================================================
# 4.4 NUMERIC STATS (avg_rating, num_of_reviews)
# ============================================================

stats = (
    df_train.groupby("i")[["avg_rating", "num_of_reviews"]]
    .first()
    .reindex(range(num_items))
)

stats["avg_rating"]     = pd.to_numeric(stats["avg_rating"], errors="coerce").fillna(0)
stats["num_of_reviews"] = pd.to_numeric(stats["num_of_reviews"], errors="coerce").fillna(0)

# Normalize ratings: center around 3, scale by 2 → range ≈ [-1, +1]
stats["avg_rating"] = (stats["avg_rating"] - 3.0) / 2.0

# Normalize review counts: log transform for stability
stats["num_of_reviews"] = np.log1p(stats["num_of_reviews"])

# z-score normalize stats
stats = (stats - stats.mean()) / (stats.std() + 1e-6)
item_stats = stats.fillna(0).astype("float32")


# ============================================================
# 4.5 MISC TAG FEATURES
# ============================================================
misc_cols = [c for c in df_train.columns if c.startswith("MISC")]

# Extract vocabulary
misc_labels = set()
for col in misc_cols:
    for entry in df_train[col]:
        if isinstance(entry, list):
            misc_labels.update(entry)

misc_labels = sorted(misc_labels)
misc_label2idx = {tag: idx for idx, tag in enumerate(misc_labels)}

# Initialize (num_items, num_misc_labels)
item_misc_np = np.zeros((num_items, len(misc_labels)), dtype="float32")

# Fill MISC tags
for idx, row in df_train.iterrows():
    i = row["i"]
    if not (0 <= i < num_items):
        continue
    for col in misc_cols:
        tags = row[col]
        if isinstance(tags, list):
            for tag in tags:
                j = misc_label2idx.get(tag)
                if j is not None:
                    item_misc_np[i, j] = 1.0

item_misc = pd.DataFrame(item_misc_np, columns=[f"misc_{m}" for m in misc_labels])

# z-score normalize MISC features
# (optional: many models skip normalization for binary one-hot)
if item_misc.values.sum() > 0:
    item_misc = (item_misc - item_misc.mean()) / (item_misc.std() + 1e-6)
item_misc = item_misc.fillna(0).astype("float32")


# ============================================================
# 4.6 Combine ALL item features safely
# ============================================================
item_features_np = np.hstack([
    item_geo.values.astype("float32"),
    item_cat.values.astype("float32"),
    item_stats.values.astype("float32"),
    item_misc.values.astype("float32"),
])

# Final global normalization for stability
item_features_np = (item_features_np - item_features_np.mean(axis=0)) / \
                   (item_features_np.std(axis=0) + 1e-6)

item_features_np = np.nan_to_num(item_features_np, nan=0.0, posinf=0.0, neginf=0.0)

item_features = torch.tensor(item_features_np, dtype=torch.float32)
print("Final item_features shape:", item_features.shape)


Final item_features shape: torch.Size([2923, 538])


In [64]:
# ============================================================
# Cell 5: Train-only user features (home coord, count, avg rating) and normalize
# ============================================================

user_latlon = (
    df_train.groupby("u")[["latitude", "longitude"]]
    .mean()
    .reindex(range(num_users))
    .fillna(0.0)
)

# normalize latitude/longitude
user_latlon = (user_latlon - user_latlon.mean()) / (user_latlon.std() + 1e-6)
user_latlon = user_latlon.fillna(0)

user_review_counts = (
    df_train.groupby("u")["rating"]
    .count()
    .reindex(range(num_users))
    .fillna(0)
)

# log-normalize review counts
user_review_counts = np.log1p(user_review_counts)

# z-score normalize
user_review_counts = (user_review_counts - user_review_counts.mean()) / (user_review_counts.std() + 1e-6)

user_avg_rating = (
    df_train.groupby("u")["rating"]
    .mean()
    .reindex(range(num_users))
    .fillna(3.0)  # neutral rating
)

# normalize: center at 3, divide by 2
user_avg_rating = (user_avg_rating - 3.0) / 2.0

# z-score
user_avg_rating = (user_avg_rating - user_avg_rating.mean()) / (user_avg_rating.std() + 1e-6)

# combine
user_features_np = np.vstack([
    user_latlon["latitude"].values,
    user_latlon["longitude"].values,
    user_review_counts.values,
    user_avg_rating.values
]).T.astype("float32")

user_features_np = np.nan_to_num(user_features_np, nan=0.0, posinf=0.0, neginf=0.0)

user_features = torch.tensor(user_features_np, dtype=torch.float32)
print("Final user_features shape:", user_features.shape)

Final user_features shape: torch.Size([25841, 4])


Diagonstics

In [42]:
print("Any NaNs in item_features:", torch.isnan(item_features).any().item())
print("Any Infs in item_features:", torch.isinf(item_features).any().item())
print("Item feature max/min:", item_features.max(), item_features.min())

print("\nAny NaNs in user_features:", torch.isnan(user_features).any().item())
print("Any Infs in user_features:", torch.isinf(user_features).any().item())
print("User feature max/min:", user_features.max(), user_features.min())

Any NaNs in item_features: False
Any Infs in item_features: False
Item feature max/min: tensor(54.0580) tensor(-4.7510)

Any NaNs in user_features: False
Any Infs in user_features: False
User feature max/min: tensor(4.7317) tensor(-4.7366)


In [43]:
print("GEO NaNs:",    np.isnan(item_geo.values).any())
print("CAT NaNs:",    np.isnan(item_cat.values).any())
print("STATS NaNs:",  np.isnan(item_stats.values).any())
print("MISC NaNs:",   np.isnan(item_misc.values).any())

GEO NaNs: False
CAT NaNs: False
STATS NaNs: False
MISC NaNs: False


In [44]:
print("GEO shape:", item_geo.shape)
print("CAT shape:", item_cat.shape)
print("STATS shape:", item_stats.shape)
print("MISC shape:", item_misc.shape)
print("Expected num_items:", num_items)

GEO shape: (2923, 2)
CAT shape: (2923, 415)
STATS shape: (2923, 2)
MISC shape: (2923, 119)
Expected num_items: 2923


In [45]:
bad_item_mask = torch.isnan(item_features).any(dim=1) | torch.isinf(item_features).any(dim=1)
bad_items = bad_item_mask.nonzero(as_tuple=True)[0]

print("Bad item count:", len(bad_items))
print("Bad item indices:", bad_items[:20])

Bad item count: 0
Bad item indices: tensor([], dtype=torch.int64)


Baseline Model: LightGCN

In [65]:
# ============================================================
# Cell 6: Train graph (user–item bipartite, symmetric) & helpers
# ============================================================
# User nodes: 0..num_users-1
# Item nodes: num_users..num_users+num_items-1

u_train = torch.tensor(train_df["u"].values, dtype=torch.long)
i_train = torch.tensor(train_df["i"].values + num_users, dtype=torch.long)

# make edges bidirectional
edge_index = torch.stack([
    torch.cat([u_train, i_train]),
    torch.cat([i_train, u_train]),
], dim=0)

num_nodes = num_users + num_items
print("edge_index shape:", edge_index.shape, "num_nodes:", num_nodes)

# Positive items per user for TRAIN (for negative sampling)
user_pos_train = full_df_train.groupby("u")["i"].apply(set).to_dict()

def sample_negative(u: int, k: int = 1):
    """Random negative items for user u, using TRAIN positives only."""
    positives = user_pos_train.get(int(u), set())
    negatives = []
    while len(negatives) < k:
        j = np.random.randint(0, num_items)
        if j not in positives:
            negatives.append(j)
    return negatives


edge_index shape: torch.Size([2, 556776]) num_nodes: 28764


In [66]:
# ============================================================
# DEVICE SETUP
# ============================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Move data tensors to correct device
edge_index = edge_index.to(device)
item_features = item_features.to(device)
user_features = user_features.to(device)

Using device: cuda


In [50]:
# ============================================================
# LightGCN (Device-Safe, Feature-Safe)
# ============================================================

class LightGCN(nn.Module):
    def __init__(self, num_users, num_items, user_feat_dim, item_feat_dim,
                 dim=64, n_layers=3):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.dim = dim
        self.n_layers = n_layers

        # Learnable embeddings
        self.user_emb = nn.Embedding(num_users, dim)
        self.item_emb = nn.Embedding(num_items, dim)

        # Feature projections (all trained)
        self.user_feat_proj = nn.Linear(user_feat_dim, dim, bias=False)
        self.item_feat_proj = nn.Linear(item_feat_dim, dim, bias=False)

        nn.init.xavier_uniform_(self.user_emb.weight)
        nn.init.xavier_uniform_(self.item_emb.weight)

    def propagate(self, x, edge_index):
        row, col = edge_index
        deg = torch.bincount(row, minlength=x.size(0)).float()
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        msg = x[col] * norm.unsqueeze(1)

        out = torch.zeros_like(x)
        out.index_add_(0, row, msg)
        return out

    def forward(self, edge_index):
        # Move embeddings + features to correct device
        user_x = self.user_emb.weight + self.user_feat_proj(user_features)
        item_x = self.item_emb.weight + self.item_feat_proj(item_features)

        x = torch.cat([user_x, item_x], dim=0)
        all_layers = [x]

        for _ in range(self.n_layers):
            x = self.propagate(x, edge_index)
            all_layers.append(x)

        x_final = torch.stack(all_layers, dim=1).mean(dim=1)

        user_gcn = x_final[:self.num_users]
        item_gcn = x_final[self.num_users:]
        return user_gcn, item_gcn

    def predict(self, u, i, user_gcn, item_gcn):
        return (user_gcn[u] * item_gcn[i]).sum(dim=-1)

In [51]:
# ============================================================
# Cell 8: BPR loss & training utilities (with tqdm)
# ============================================================
from tqdm.auto import tqdm
import wandb

def bpr_loss(model, edge_index, batch_users, batch_pos_items):
    user_gcn, item_gcn = model(edge_index)

    neg_items = torch.tensor(
        [sample_negative(int(u))[0] for u in batch_users],
        dtype=torch.long,
        device=device,
    )

    pos_scores = model.predict(batch_users, batch_pos_items, user_gcn, item_gcn)
    neg_scores = model.predict(batch_users, neg_items, user_gcn, item_gcn)

    return -torch.log(torch.sigmoid(pos_scores - neg_scores)).mean()


 def train_bpr_model(model, edge_index, train_df, epochs=10, batch_size=1024, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_users = torch.tensor(train_df["u"].values, dtype=torch.long, device=device)
    train_items = torch.tensor(train_df["i"].values, dtype=torch.long, device=device)

    epoch_losses, batch_losses = [], []

    for epoch in range(epochs):
        perm = torch.randperm(len(train_users), device=device)
        epoch_loss = 0.0

        pbar = tqdm(perm.split(batch_size), desc=f"LightGCN Epoch {epoch+1}/{epochs}")

        for step, idx in enumerate(pbar):
            batch_u = train_users[idx]
            batch_i = train_items[idx]

            loss = bpr_loss(model, edge_index, batch_u, batch_i)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            bl = loss.item()
            batch_losses.append(bl)
            epoch_loss += bl

            pbar.set_postfix({"batch_loss": f"{bl:.4f}", "epoch_loss": f"{epoch_loss:.4f}"})

        epoch_losses.append(epoch_loss)
        print(f"[LightGCN] Epoch {epoch+1}: total_loss={epoch_loss:.4f}")

    return model, epoch_losses, batch_losses


In [53]:
import wandb
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mleeij[0m ([33mleeisabella[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [54]:
# ============================================================
# Cell 9: Train baseline LightGCN with W&B tracking
# ============================================================
import wandb

wandb.init(
    project="restaurant-recsys",
    name="LightGCN-normalized",
    config={
        "model": "LightGCN",
        "embedding_dim": dim,
        "layers": layers,
        "epochs": 5,
        "batch_size": 2048,
        "learning_rate": 1e-3,
        "num_users": num_users,
        "num_items": num_items,
        "item_feature_dim": item_features.shape[1],
        "user_feature_dim": user_features.shape[1],
    }
)

lgcn, epoch_losses, batch_losses = train_bpr_model(
    lgcn,
    edge_index,
    train_df,
    epochs=5,
    batch_size=2048,
    lr=1e-3
)

wandb.finish()

Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCN] Epoch 1: total_loss = 28.6878


Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCN] Epoch 2: total_loss = 25.8112


Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCN] Epoch 3: total_loss = 23.8905


Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCN] Epoch 4: total_loss = 22.2806


Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCN] Epoch 5: total_loss = 21.0119


0,1
batch_loss,█▆▆▇▆▅▆▆▅▆▆▅▅▃▆▄▅▅▅▄▄▄▄▅▃▂▃▂▂▁▂▁▂▃▁▃▄▂▁▃
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆█████
epoch_loss,█▅▄▂▁
step,▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇██

0,1
batch_loss,0.17154
epoch,5.0
epoch_loss,21.01192
step,679.0


In [55]:
# ============================================================
# Cell 10: Evaluation metrics and generic evaluator
# ============================================================
def recall_at_k(target, ranked_list, k):
    return int(target in ranked_list[:k])

def ndcg_at_k(target, ranked_list, k):
    if target in ranked_list[:k]:
        rank = ranked_list.index(target) + 1
        return 1.0 / np.log2(rank + 1)
    return 0.0

def mrr_metric(target, ranked_list):
    if target in ranked_list:
        return 1.0 / (ranked_list.index(target) + 1)
    return 0.0

def evaluate_lgcn_like(model, edge_index, eval_df, k=20):
    model.eval()
    with torch.no_grad():
        user_gcn, item_gcn = model(edge_index)

    recalls, ndcgs, mrrs = [], [], []

    for row in tqdm(eval_df.itertuples(), total=len(eval_df)):
        u = int(row.u)
        true_i = int(row.i)

        scores = (user_gcn[u] * item_gcn).sum(dim=1).cpu().numpy()
        ranked_items = list(np.argsort(-scores))

        recalls.append(recall_at_k(true_i, ranked_items, k))
        ndcgs.append(ndcg_at_k(true_i, ranked_items, k))
        mrrs.append(mrr_metric(true_i, ranked_items))

    return {
        "Recall@{}".format(k): float(np.mean(recalls)),
        "NDCG@{}".format(k): float(np.mean(ndcgs)),
        "MRR": float(np.mean(mrrs)),
    }

print("Baseline LightGCN on validation:")
metrics = evaluate_lgcn_like(lgcn, edge_index, val_df, k=20)
wandb.log(metrics)
metrics


Baseline LightGCN on validation:


  0%|          | 0/34798 [00:00<?, ?it/s]

{'Recall@20': 0.05612391516753836,
 'NDCG@20': 0.020208141149234488,
 'MRR': 0.01511325982326465}

Now we move to extended models!

LightGCL

In [95]:
# ============================================================
# LightGCL: SVD Global View + Contrastive Setup (Train-only)
# ============================================================

from scipy.sparse import coo_matrix
from scipy.sparse.linalg import svds
import torch.nn.functional as F

# --- 1. Build implicit matrix from TRAIN ONLY ---
rows = train_df["u"].values
cols = train_df["i"].values
data = np.ones_like(rows, dtype=np.float32)

M_train = coo_matrix((data, (rows, cols)), shape=(num_users, num_items))
print("M_train shape:", M_train.shape)

# --- 2. SVD rank MUST match LightGCN embedding dimension ---
# dim is the embedding dimension you used for LightGCN
k_svd = dim                       # <<< IMPORTANT: match model dimension >>>
k_svd = min(k_svd, min(num_users, num_items) - 1)

U, s, Vt = svds(M_train, k=k_svd)
idx = np.argsort(-s)
U, s, Vt = U[:, idx], s[idx], Vt[idx, :]

Sigma_sqrt = np.sqrt(s)

P_svd = U * Sigma_sqrt          # (num_users, k_svd)
Q_svd = (Vt.T) * Sigma_sqrt     # (num_items, k_svd)

svd_user_emb = torch.tensor(P_svd, dtype=torch.float32, device=device)
svd_item_emb = torch.tensor(Q_svd, dtype=torch.float32, device=device)

print("svd_user_emb:", svd_user_emb.shape)
print("svd_item_emb:", svd_item_emb.shape)

# --- 3. Normalize & detach SVD embeddings for stability ---
svd_u_norm = F.normalize(svd_user_emb, dim=-1).detach()
svd_i_norm = F.normalize(svd_item_emb, dim=-1).detach()

# --- 4. LightGCL hyperparameters (will be tuned later) ---
tau         = 0.03    # temperature (by tuning)
lambda_con  = 0.1    # contrastive weight (by tuning)
lambda_reg  = 1e-4   # L2 weight (on embeddings)


def contrastive_loss(z, z_tilde, tau=0.2):
    """
    InfoNCE loss: z, z_tilde are (N, d) with aligned rows.
    """
    z = F.normalize(z, dim=-1)
    z_tilde = F.normalize(z_tilde, dim=-1)

    logits = z @ z_tilde.T / tau    # (N, N)
    labels = torch.arange(z.size(0), device=z.device)
    return F.cross_entropy(logits, labels)

M_train shape: (25841, 2923)
svd_user_emb: torch.Size([25841, 64])
svd_item_emb: torch.Size([2923, 64])


In [96]:
# ============================================================
# LightGCL: Step function (BPR + Contrastive)
# ============================================================

def lightgcl_step(model, edge_index, batch_users, batch_pos_items):
    # GCN forward
    user_gcn, item_gcn = model(edge_index)

    # Base BPR loss (exactly as LightGCN)
    loss_bpr = bpr_loss(model, edge_index, batch_users, batch_pos_items)

    # Per-user alignment
    u_embed = user_gcn[batch_users]          # (B, d)
    u_svd   = svd_u_norm[batch_users]        # (B, d)

    # Per-item alignment
    i_embed = item_gcn[batch_pos_items]      # (B, d)
    i_svd   = svd_i_norm[batch_pos_items]    # (B, d)

    L_con_u = contrastive_loss(u_embed, u_svd, tau=tau)
    L_con_i = contrastive_loss(i_embed, i_svd, tau=tau)

    # L2 regularization on learnable embeddings
    reg = 0.5 * (
        model.user_emb.weight.norm(2)**2 +
        model.item_emb.weight.norm(2)**2
    )

    return loss_bpr + lambda_con * (L_con_u + L_con_i) + lambda_reg * reg


# ============================================================
# LightGCL: Training loop (no wandb inside)
# ============================================================

def train_lightgcl(model, edge_index, train_df,
                   epochs=5, batch_size=1024, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_users = torch.tensor(train_df["u"].values, dtype=torch.long, device=device)
    train_items = torch.tensor(train_df["i"].values, dtype=torch.long, device=device)

    epoch_losses, batch_losses = [], []

    for epoch in range(epochs):
        perm = torch.randperm(len(train_users), device=device)
        epoch_loss = 0.0

        pbar = tqdm(perm.split(batch_size), desc=f"LightGCL Epoch {epoch+1}/{epochs}")

        for step, idx in enumerate(pbar):
            batch_u = train_users[idx]
            batch_i = train_items[idx]

            loss = lightgcl_step(model, edge_index, batch_u, batch_i)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            bl = loss.item()
            batch_losses.append(bl)
            epoch_loss += bl

            pbar.set_postfix({
                "batch_loss": f"{bl:.4f}",
                "epoch_loss": f"{epoch_loss:.4f}"
            })

        epoch_losses.append(epoch_loss)
        print(f"[LightGCL] Epoch {epoch+1}: total_loss={epoch_loss:.4f}")

    return model, epoch_losses, batch_losses

In [97]:
# ============================================================
# Train LightGCL with W&B Tracking
# ============================================================

wandb.init(
    project="restaurant-recsys",
    name="LightGCL",
    config={
        "model": "LightGCL",
        "epochs": 5,
        "batch_size": 2048,
        "lr": 1e-3,
        "embedding_dim": dim,
        "layers": layers
    }
)

lgcl = LightGCN(
    num_users=num_users,
    num_items=num_items,
    user_feat_dim=user_features.shape[1],
    item_feat_dim=item_features.shape[1],
    dim=dim,
    n_layers=layers,
).to(device)

lgcl, lgcl_epoch_losses, lgcl_batch_losses = train_lightgcl(
    lgcl, edge_index, train_df,
    epochs=5, batch_size=2048, lr=1e-3
)

wandb.finish()

LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=227.9159


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=121.7744


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=98.0652


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=88.0762


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=82.2510


In [94]:
# ============================================================
# LightGCL Hyperparameter Tuning (grid search)
# ============================================================

import pandas as pd
import itertools

# define search grid
tau_grid        = [0.03, 0.05, 0.1]
lambda_con_grid = [0.03, 0.05, 0.1]
lambda_reg_val  = 1e-4

results = []

for tau_val, lambda_con_val in itertools.product(tau_grid, lambda_con_grid):
    print(f"\n=== Training LightGCL with tau={tau_val}, lambda_con={lambda_con_val} ===")

    # set globals used by lightgcl_step
    tau        = tau_val
    lambda_con = lambda_con_val
    lambda_reg = lambda_reg_val

    # (re)initialize model
    model = LightGCN(
        num_users=num_users,
        num_items=num_items,
        user_feat_dim=user_features.shape[1],
        item_feat_dim=item_features.shape[1],
        dim=dim,
        n_layers=layers,
    ).to(device)

    # train model
    model, epoch_losses, batch_losses = train_lightgcl(
        model, edge_index, train_df,
        epochs=5, batch_size=2048, lr=1e-3
    )

    # evaluate on validation set
    metrics_val = evaluate_lgcn_like(model, edge_index, val_df, k=20)

    run_result = {
        "tau": tau_val,
        "lambda_con": lambda_con_val,
        "Recall@20_val": metrics_val["Recall@20"],
        "NDCG@20_val": metrics_val["NDCG@20"],
        "MRR_val": metrics_val["MRR"],
    }
    results.append(run_result)

# build comparison table
lightgcl_tuning_df = pd.DataFrame(results)
display(lightgcl_tuning_df.sort_values("Recall@20_val", ascending=False))


=== Training LightGCL with tau=0.03, lambda_con=0.03 ===


LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=109.5646


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=72.2826


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=59.9697


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=53.2687


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=48.7914


  0%|          | 0/34798 [00:00<?, ?it/s]


=== Training LightGCL with tau=0.03, lambda_con=0.05 ===


LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=142.0616


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=86.3093


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=71.4835


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=63.7990


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=58.9265


  0%|          | 0/34798 [00:00<?, ?it/s]


=== Training LightGCL with tau=0.03, lambda_con=0.1 ===


LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=229.8677


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=122.3808


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=98.3393


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=87.8228


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=81.9738


  0%|          | 0/34798 [00:00<?, ?it/s]


=== Training LightGCL with tau=0.05, lambda_con=0.03 ===


LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=102.5536


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=73.2776


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=62.4268


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=56.1279


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=52.1921


  0%|          | 0/34798 [00:00<?, ?it/s]


=== Training LightGCL with tau=0.05, lambda_con=0.05 ===


LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=135.3748


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=91.5604


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=77.1370


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=69.4061


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=64.5469


  0%|          | 0/34798 [00:00<?, ?it/s]


=== Training LightGCL with tau=0.05, lambda_con=0.1 ===


LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=208.7431


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=127.9420


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=106.8830


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=98.5313


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=93.2357


  0%|          | 0/34798 [00:00<?, ?it/s]


=== Training LightGCL with tau=0.1, lambda_con=0.03 ===


LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=104.8561


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=83.0844


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=73.6698


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=68.0919


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=64.1894


  0%|          | 0/34798 [00:00<?, ?it/s]


=== Training LightGCL with tau=0.1, lambda_con=0.05 ===


LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=138.2165


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=107.2988


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=95.1586


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=88.2619


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=83.7568


  0%|          | 0/34798 [00:00<?, ?it/s]


=== Training LightGCL with tau=0.1, lambda_con=0.1 ===


LightGCL Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 1: total_loss=216.7168


LightGCL Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 2: total_loss=162.2701


LightGCL Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 3: total_loss=143.4396


LightGCL Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 4: total_loss=134.9018


LightGCL Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL] Epoch 5: total_loss=129.8106


  0%|          | 0/34798 [00:00<?, ?it/s]

Unnamed: 0,tau,lambda_con,Recall@20_val,NDCG@20_val,MRR_val
2,0.03,0.1,0.062475,0.022148,0.016099
5,0.05,0.1,0.059831,0.021619,0.016071
1,0.03,0.05,0.058768,0.020948,0.015437
0,0.03,0.03,0.058653,0.021207,0.015703
8,0.1,0.1,0.0569,0.020328,0.015284
3,0.05,0.03,0.05598,0.020099,0.01498
6,0.1,0.03,0.055118,0.019901,0.014938
4,0.05,0.05,0.055003,0.019647,0.014686
7,0.1,0.05,0.053997,0.019596,0.014872


In [98]:
# ============================================================
# Evaluate LightGCL on validation set + W&B logging
# ============================================================

import wandb

wandb.init(
    project="restaurant-recsys",
    name="LightGCL-eval",
    config={
        "stage": "evaluation",
        "model": "LightGCL"
    }
)

print("Evaluating LightGCL on validation set...\n")

metrics_lgcl_val = evaluate_lgcn_like(lgcl, edge_index, val_df, k=20)

wandb.log({
    "LightGCL/Recall@20_val": metrics_lgcl_val["Recall@20"],
    "LightGCL/NDCG@20_val":  metrics_lgcl_val["NDCG@20"],
    "LightGCL/MRR_val":      metrics_lgcl_val["MRR"],
})

print("LightGCL Validation Results:")
print(metrics_lgcl_val)

wandb.finish()

Evaluating LightGCL on validation set...



  0%|          | 0/34798 [00:00<?, ?it/s]

LightGCL Validation Results:
{'Recall@20': 0.06247485487671705, 'NDCG@20': 0.022290753481703555, 'MRR': 0.016415805437697842}


0,1
LightGCL/MRR_val,▁
LightGCL/NDCG@20_val,▁
LightGCL/Recall@20_val,▁

0,1
LightGCL/MRR_val,0.01642
LightGCL/NDCG@20_val,0.02229
LightGCL/Recall@20_val,0.06247


LightGCL+Geo

In [99]:
# ============================================================
# Geo Precomputation for LightGCL+Geo
# ============================================================

from sklearn.metrics import pairwise_distances

# Raw coords from TRAIN (not normalized features)
df_train_geo = full_df_train.copy()
df_train_geo["latitude"]  = pd.to_numeric(df_train_geo["latitude"],  errors="coerce")
df_train_geo["longitude"] = pd.to_numeric(df_train_geo["longitude"], errors="coerce")

# Item coordinates (num_items, 2)
item_coord_df = (
    df_train_geo.groupby("i")[["latitude", "longitude"]]
    .mean()
    .reindex(range(num_items))
    .fillna(0.0)
)
item_coord = item_coord_df.values

# User "home" coordinates (num_users, 2)
user_coord_df = (
    df_train_geo.groupby("u")[["latitude", "longitude"]]
    .mean()
    .reindex(range(num_users))
    .fillna(0.0)
)
user_coord = user_coord_df.values

def geo_distance(u, i):
    du = user_coord[int(u)]
    di = item_coord[int(i)]
    return float(np.linalg.norm(du - di))

# Radius-aware negatives
R = 0.1        # tune as needed
alpha = 0.75   # popularity exponent

item_pop = (
    full_df_train["i"].value_counts()
    .reindex(range(num_items))
    .fillna(0)
    .values
)

user_geo_candidates = {}
for u in range(num_users):
    dists = np.linalg.norm(item_coord - user_coord[u], axis=1)
    mask = dists <= R
    cand = np.where(mask)[0].tolist()
    cand = [j for j in cand if j not in user_pos_train.get(u, [])]
    if not cand:
        cand = [j for j in range(num_items) if j not in user_pos_train.get(u, [])]
    user_geo_candidates[u] = np.array(cand, dtype=int)

def sample_geo_negative(u, k=1):
    cand = user_geo_candidates[int(u)]
    weights = item_pop[cand] ** alpha
    if weights.sum() > 0:
        probs = weights / weights.sum()
    else:
        probs = None
    return np.random.choice(cand, size=k, p=probs, replace=True)

# Geo-smooth neighbors
k_nn = 10
rho_r = 0.1

D_items = pairwise_distances(item_coord)
item_neighbors = []
item_neighbor_weights = []

for i in range(num_items):
    idx = np.argsort(D_items[i])[:k_nn + 1]
    idx = idx[1:]
    item_neighbors.append(idx)
    w = np.exp(-D_items[i, idx] / rho_r)
    item_neighbor_weights.append(w)

def geo_smooth_loss(item_gcn):
    loss = 0.0
    for i in range(num_items):
        neigh = item_neighbors[i]
        if len(neigh) == 0:
            continue
        qi = item_gcn[i]
        qj = item_gcn[neigh]
        w = torch.tensor(
            item_neighbor_weights[i], dtype=torch.float32, device=item_gcn.device
        )
        diff = qi.unsqueeze(0) - qj
        loss += (w * (diff ** 2).sum(dim=1)).sum()
    return loss / num_items

beta = 0.1
rho = 0.1  # decay scale for distance

def bpr_geo_terms(model, edge_index, batch_users, batch_pos_items):
    user_gcn, item_gcn = model(edge_index)

    # radius-aware negatives
    neg_items_np = np.concatenate(
        [sample_geo_negative(u.item(), 1) for u in batch_users]
    )
    neg_items = torch.tensor(
        neg_items_np, dtype=torch.long, device=batch_users.device
    )

    pos_base = (user_gcn[batch_users] * item_gcn[batch_pos_items]).sum(dim=-1)
    neg_base = (user_gcn[batch_users] * item_gcn[neg_items]).sum(dim=-1)

    pos_d = torch.tensor(
        [geo_distance(u, i) for u, i in zip(batch_users, batch_pos_items)],
        dtype=torch.float32,
        device=pos_base.device,
    )
    neg_d = torch.tensor(
        [geo_distance(u, j) for u, j in zip(batch_users, neg_items)],
        dtype=torch.float32,
        device=neg_base.device,
    )

    pos_score = pos_base + beta * torch.exp(-pos_d / rho)
    neg_score = neg_base + beta * torch.exp(-neg_d / rho)

    loss_bpr_geo = -torch.log(torch.sigmoid(pos_score - neg_score)).mean()
    loss_geo_smooth = geo_smooth_loss(item_gcn)

    return loss_bpr_geo, loss_geo_smooth, user_gcn, item_gcn

In [100]:
# ============================================================
# LightGCL+Geo Step (BPR_geo + Contrastive + Smoothness)
# ============================================================

def lightgcl_geo_step(model, edge_index, batch_users, batch_pos_items):
    # Geo-aware BPR part
    loss_bpr_geo, loss_geo_smooth, user_gcn, item_gcn = bpr_geo_terms(
        model, edge_index, batch_users, batch_pos_items
    )

    # Contrastive alignment (same as LightGCL)
    u_embed = user_gcn[batch_users]
    u_svd   = svd_u_norm[batch_users]

    i_embed = item_gcn[batch_pos_items]
    i_svd   = svd_i_norm[batch_pos_items]

    L_con_u = contrastive_loss(u_embed, u_svd, tau=tau)
    L_con_i = contrastive_loss(i_embed, i_svd, tau=tau)

    reg = 0.5 * (
        model.user_emb.weight.norm(2)**2 +
        model.item_emb.weight.norm(2)**2
    )

    return (
        loss_bpr_geo +
        lambda_con*(L_con_u + L_con_i) +
        mu_geo * loss_geo_smooth +
        lambda_reg * reg
    )


# ============================================================
# LightGCL+Geo Training Loop
# ============================================================

def train_lightgcl_geo(model, edge_index, train_df, epochs=5, batch_size=1024, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_users = torch.tensor(train_df["u"].values, dtype=torch.long, device=device)
    train_items = torch.tensor(train_df["i"].values, dtype=torch.long, device=device)

    epoch_losses, batch_losses = [], []

    for epoch in range(epochs):

        perm = torch.randperm(len(train_users), device=device)
        epoch_loss = 0.0

        pbar = tqdm(perm.split(batch_size), desc=f"LightGCL+Geo Epoch {epoch+1}/{epochs}")

        for step, idx in enumerate(pbar):
            batch_u = train_users[idx]
            batch_i = train_items[idx]

            loss = lightgcl_geo_step(model, edge_index, batch_u, batch_i)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            bl = loss.item()
            batch_losses.append(bl)
            epoch_loss += bl

            pbar.set_postfix({"batch_loss": f"{bl:.4f}", "epoch_loss": f"{epoch_loss:.4f}"})

            wandb.log({
                "LightGCL_Geo/batch_loss": bl,
                "LightGCL_Geo/epoch": epoch + 1
            })

        epoch_losses.append(epoch_loss)
        wandb.log({"LightGCL_Geo/epoch_loss": epoch_loss})

        print(f"[LightGCL+Geo] Epoch {epoch+1}: total_loss={epoch_loss:.4f}")

    return model, epoch_losses, batch_losses

In [102]:
# ============================================================
# Train LightGCL+Geo with W&B Tracking
# ============================================================

import wandb

wandb.init(
    project="restaurant-recsys",
    name="LightGCL_Geo",
    config={
        "model": "LightGCL_Geo",
        "embedding_dim": dim,
        "layers": layers,
        "epochs": 5,
        "batch_size": 2048,
        "learning_rate": 1e-3,
        "tau": tau,
        "lambda_con": lambda_con,
        "lambda_reg": lambda_reg,
        "mu_geo": mu_geo,
    }
)

# initialize model
lgcl_geo = LightGCN(
    num_users=num_users,
    num_items=num_items,
    user_feat_dim=user_features.shape[1],
    item_feat_dim=item_features.shape[1],
    dim=dim,
    n_layers=layers,
).to(device)

# train model
lgcl_geo, lgcl_geo_epoch_losses, lgcl_geo_batch_losses = train_lightgcl_geo(
    lgcl_geo,
    edge_index,
    train_df,
    epochs=5,
    batch_size=2048,
    lr=1e-3
)

wandb.finish()

LightGCL+Geo Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 1: total_loss=242.4651


LightGCL+Geo Epoch 2/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 2: total_loss=131.9122


LightGCL+Geo Epoch 3/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 3: total_loss=110.5099


LightGCL+Geo Epoch 4/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 4: total_loss=101.6959


LightGCL+Geo Epoch 5/5:   0%|          | 0/136 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 5: total_loss=96.8846


0,1
LightGCL_Geo/batch_loss,█▅▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
LightGCL_Geo/epoch,▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆█████
LightGCL_Geo/epoch_loss,█▃▂▁▁

0,1
LightGCL_Geo/batch_loss,0.68576
LightGCL_Geo/epoch,5.0
LightGCL_Geo/epoch_loss,96.88462


In [104]:
# ============================================================
# Evaluate LightGCL+Geo (lgcl_geo) + W&B logging
# ============================================================

import wandb

wandb.init(
    project="restaurant-recsys",
    name="LightGCLgeo-eval",
    config={
        "stage": "evaluation",
        "model": "LightGCL_Geo"
    }
)

print("Evaluating LightGCL+Geo on validation set...\n")

metrics_lgcl_geo_val = evaluate_lgcn_like(lgcl_geo, edge_index, val_df, k=20)

wandb.log({
    "LightGCL_Geo/Recall@20_val": metrics_lgcl_geo_val["Recall@20"],
    "LightGCL_Geo/NDCG@20_val":  metrics_lgcl_geo_val["NDCG@20"],
    "LightGCL_Geo/MRR_val":      metrics_lgcl_geo_val["MRR"],
})

print("LightGCL+Geo Validation Results:")
print(metrics_lgcl_geo_val)

Evaluating LightGCL+Geo on validation set...



  0%|          | 0/34798 [00:00<?, ?it/s]

LightGCL+Geo Validation Results:
{'Recall@20': 0.057359618368871775, 'NDCG@20': 0.021662990584033513, 'MRR': 0.01639937582505021}


In [None]:
# ============================================================
# Model Comparison Table (Auto-Logged)
# ============================================================

comparison_df = pd.DataFrame({
    "LightGCL":      metrics_lgcl_val,
    "LightGCL+Geo":  metrics_lgcl_geo_val,
}).T  # transpose so rows = models

print("\nModel Comparison Table:")
print(comparison_df)

# Log to W&B as a table
wandb.log({"model_comparison": wandb.Table(dataframe=comparison_df)})

In [None]:
# Hyperparam tuning

In [105]:
# ============================================================
# W&B Sweep: LightGCL+Geo Hyperparameter Search
# ============================================================

import wandb

sweep_config = {
    "method": "bayes",   # or "grid", "random"
    "metric": {
        "name": "LightGCL_Geo/Recall@20_val",
        "goal": "maximize"
    },
    "parameters": {
        "tau": {
            "values": [0.01, 0.05, 0.1, 0.2]
        },
        "lambda_con": {
            "values": [0.05, 0.1, 0.2]
        },
        "mu_geo": {
            "values": [1e-5, 1e-4, 5e-4]
        },
        "beta": {
            "values": [0.05, 0.1, 0.2]
        },
        "rho": {
            "values": [0.05, 0.1, 0.2]
        },
        "learning_rate": {
            "values": [1e-3, 5e-4]
        },
        "batch_size": {
            "values": [1024, 2048]
        }
    }
}

sweep_id = wandb.sweep(sweep_config, project="restaurant-recsys")
print("Sweep ID:", sweep_id)

Create sweep with ID: 194m6z4i
Sweep URL: https://wandb.ai/leeisabella/restaurant-recsys/sweeps/194m6z4i
Sweep ID: 194m6z4i


In [106]:
# ============================================================
# Sweep Training Function for LightGCL+Geo
# ============================================================

def sweep_train_lightgcl_geo():
    wandb.init()
    cfg = wandb.config

    # set global hyperparameters
    global tau, lambda_con, mu_geo, beta, rho
    tau = cfg.tau
    lambda_con = cfg.lambda_con
    mu_geo = cfg.mu_geo
    beta = cfg.beta
    rho = cfg.rho

    # Initialize fresh model
    model = LightGCN(
        num_users=num_users,
        num_items=num_items,
        user_feat_dim=user_features.shape[1],
        item_feat_dim=item_features.shape[1],
        dim=dim,
        n_layers=layers,
    ).to(device)

    # Train model
    model, epoch_losses, batch_losses = train_lightgcl_geo(
        model,
        edge_index,
        train_df,
        epochs=5,
        batch_size=cfg.batch_size,
        lr=cfg.learning_rate
    )

    # Validation evaluation
    metrics_val = evaluate_lgcn_like(model, edge_index, val_df, k=20)

    wandb.log({
        "LightGCL_Geo/Recall@20_val": metrics_val["Recall@20"],
        "LightGCL_Geo/NDCG@20_val":  metrics_val["NDCG@20"],
        "LightGCL_Geo/MRR_val":      metrics_val["MRR"]
    })

    wandb.finish()

In [None]:
wandb.agent(sweep_id, function=sweep_train_lightgcl_geo, count=20)

[34m[1mwandb[0m: Agent Starting Run: ebtty0fq with config:
[34m[1mwandb[0m: 	batch_size: 1024
[34m[1mwandb[0m: 	beta: 0.05
[34m[1mwandb[0m: 	lambda_con: 0.1
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	mu_geo: 0.0005
[34m[1mwandb[0m: 	rho: 0.1
[34m[1mwandb[0m: 	tau: 0.05


LightGCL+Geo Epoch 1/5:   0%|          | 0/272 [00:00<?, ?it/s]

Exception in thread ChkStopThr:
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.12/dist-packages/wandb/sdk/wandb_run.py", line 309, in check_stop_status
    self._loop_check_status(
  File "/usr/local/lib/python3.12/dist-packages/wandb/sdk/wandb_run.py", line 237, in _loop_check_status
    local_handle = request()
                   ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/wandb/sdk/interface/interface.py", line 985, in deliver_stop_status
    return self._deliver_stop_status(status)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/wandb/sdk/interface/interface_shared.py", line 480, in _deliver_stop_status
    return self._deliver(record)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packag

[LightGCL+Geo] Epoch 1: total_loss=356.2547


LightGCL+Geo Epoch 2/5:   0%|          | 0/272 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 2: total_loss=229.7992


LightGCL+Geo Epoch 3/5:   0%|          | 0/272 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 3: total_loss=215.9876


LightGCL+Geo Epoch 4/5:   0%|          | 0/272 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 4: total_loss=212.1434


LightGCL+Geo Epoch 5/5:   0%|          | 0/272 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 5: total_loss=210.6144


  0%|          | 0/34798 [00:00<?, ?it/s]

0,1
LightGCL_Geo/MRR_val,▁
LightGCL_Geo/NDCG@20_val,▁
LightGCL_Geo/Recall@20_val,▁
LightGCL_Geo/batch_loss,█▇▆▄▄▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
LightGCL_Geo/epoch,▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆█████
LightGCL_Geo/epoch_loss,█▂▁▁▁

0,1
LightGCL_Geo/MRR_val,0.01603
LightGCL_Geo/NDCG@20_val,0.02216
LightGCL_Geo/Recall@20_val,0.06147
LightGCL_Geo/batch_loss,0.75551
LightGCL_Geo/epoch,5.0
LightGCL_Geo/epoch_loss,210.61438


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 055ljwr4 with config:
[34m[1mwandb[0m: 	batch_size: 1024
[34m[1mwandb[0m: 	beta: 0.2
[34m[1mwandb[0m: 	lambda_con: 0.05
[34m[1mwandb[0m: 	learning_rate: 0.0005
[34m[1mwandb[0m: 	mu_geo: 0.0005
[34m[1mwandb[0m: 	rho: 0.2
[34m[1mwandb[0m: 	tau: 0.05


LightGCL+Geo Epoch 1/5:   0%|          | 0/272 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 1: total_loss=316.5239


LightGCL+Geo Epoch 2/5:   0%|          | 0/272 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 2: total_loss=225.8782


LightGCL+Geo Epoch 3/5:   0%|          | 0/272 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 3: total_loss=201.2762


LightGCL+Geo Epoch 4/5:   0%|          | 0/272 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 4: total_loss=191.2918


LightGCL+Geo Epoch 5/5:   0%|          | 0/272 [00:00<?, ?it/s]

[LightGCL+Geo] Epoch 5: total_loss=186.2520


  0%|          | 0/34798 [00:00<?, ?it/s]

0,1
LightGCL_Geo/MRR_val,▁
LightGCL_Geo/NDCG@20_val,▁
LightGCL_Geo/Recall@20_val,▁
LightGCL_Geo/batch_loss,█▆▅▅▄▃▃▃▃▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
LightGCL_Geo/epoch,▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆█████
LightGCL_Geo/epoch_loss,█▃▂▁▁

0,1
LightGCL_Geo/MRR_val,0.01724
LightGCL_Geo/NDCG@20_val,0.02321
LightGCL_Geo/Recall@20_val,0.06242
LightGCL_Geo/batch_loss,0.64399
LightGCL_Geo/epoch,5.0
LightGCL_Geo/epoch_loss,186.25201


[34m[1mwandb[0m: Agent Starting Run: or4kpee3 with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	beta: 0.1
[34m[1mwandb[0m: 	lambda_con: 0.2
[34m[1mwandb[0m: 	learning_rate: 0.0005
[34m[1mwandb[0m: 	mu_geo: 0.0001
[34m[1mwandb[0m: 	rho: 0.05
[34m[1mwandb[0m: 	tau: 0.05


LightGCL+Geo Epoch 1/5:   0%|          | 0/136 [00:00<?, ?it/s]

Contrasting Models