In [27]:
!pip install torch-geometric
!pip install torch-scatter
!pip install torch-cluster



In [2]:
import math
from typing import List, Tuple, Optional

import torch
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
from torch_scatter import scatter
from torch_cluster import random_walk


In [3]:
import pandas as pd
import torch
from torch_geometric.utils import to_undirected
from sklearn.preprocessing import LabelEncoder

In [37]:
!wget http://files.grouplens.org/datasets/movielens/ml-100k.zip
!unzip -q ml-100k.zip


--2025-04-13 20:06:10--  http://files.grouplens.org/datasets/movielens/ml-100k.zip
Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152
Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4924029 (4.7M) [application/zip]
Saving to: ‘ml-100k.zip.2’


2025-04-13 20:06:10 (12.0 MB/s) - ‘ml-100k.zip.2’ saved [4924029/4924029]

replace ml-100k/allbut.pl? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace ml-100k/mku.sh? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace ml-100k/README? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace ml-100k/u.data? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace ml-100k/u.genre? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace ml-100k/u.info? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace ml-100k/u.item? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace ml-100k/u.occupation? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace ml-100k/u.user? [y]es, [n]o, [A]ll, 

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cpu


In [5]:
import pandas as pd
import torch
from torch_geometric.data import HeteroData
from torch_geometric.utils import to_undirected
from sklearn.preprocessing import LabelEncoder

def load_movielens_100k(path='ml-100k/u.data', user_path='ml-100k/u.user', item_path='ml-100k/u.item'):
    import pandas as pd
    from sklearn.preprocessing import LabelEncoder
    import torch
    from torch_geometric.data import HeteroData

    # ---- Load user features ----
    user_cols = ['user_id', 'age', 'gender', 'occupation', 'zip']
    user_df = pd.read_csv(user_path, sep='|', names=user_cols, encoding='latin-1')
    # Convert categorical features if needed (e.g., gender)
    le = LabelEncoder()
    user_df['gender'] = le.fit_transform(user_df['gender'])
    # For example, use age and gender as features
    user_features = user_df[['age', 'gender']]
    x_user = torch.tensor(user_features.values, dtype=torch.float)

    # ---- Load item features ----
    item_cols = ['item_id', 'title', 'release_date', 'video_release_date', 'IMDb_URL',
                 'unknown', 'Action', 'Adventure', 'Animation', 'Children', 'Comedy',
                 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'Musical',
                 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western']
    item_df = pd.read_csv(item_path, sep='|', names=item_cols, encoding='latin-1')
    genre_cols = item_cols[5:]  # use genre flags as features
    x_item = torch.tensor(item_df[genre_cols].values, dtype=torch.float)

    # ---- Create edge_index from ratings ----
    ratings_df = pd.read_csv(path, sep='\t', names=['user', 'item', 'rating', 'timestamp'])
    # Subtract 1 to convert from 1-indexed to 0-indexed
    user_tensor = torch.tensor(ratings_df['user'].values, dtype=torch.long) - 1
    item_tensor = torch.tensor(ratings_df['item'].values, dtype=torch.long) - 1 + x_user.size(0)
    edge_index = torch.stack([user_tensor, item_tensor], dim=0)

    # Number of users and items
    num_users = x_user.size(0)
    num_items = x_item.size(0)

    # ---------- make feature dims equal ----------
    feat_dim = max(x_user.size(1), x_item.size(1))
    if x_user.size(1) < feat_dim:
        x_user = torch.nn.functional.pad(x_user, (0, feat_dim - x_user.size(1)), "constant", 0)
    if x_item.size(1) < feat_dim:
        x_item = torch.nn.functional.pad(x_item, (0, feat_dim - x_item.size(1)), "constant", 0)

    # Create the HeteroData object
    data = HeteroData()
    data["user"].x = x_user
    data["item"].x = x_item
    data["user", "rates", "item"].edge_index = edge_index

    return data, num_users, num_items


In [6]:
######################################################################
# Utility: importance‑based neighbor sampling via random walks       #
######################################################################

def importance_sampling(
    edge_index: Tensor,
    num_nodes: int,
    seed_nodes: Tensor,
    walk_length: int = 2,
    num_walks: int = 200,
    top_k: int = 50,
    device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
    """Improved importance sampling with robust neighbor handling"""
    device = device or seed_nodes.device

    # Convert to undirected graph for better connectivity
    undir_edge_index = to_undirected(edge_index)

    # Launch random walks
    start = seed_nodes.repeat_interleave(num_walks)
    walks = random_walk(undir_edge_index[0], undir_edge_index[1], start, walk_length, num_nodes=num_nodes)

    # Get first-step neighbors and remove invalid ones
    nbrs = walks[:, 1]
    valid_mask = nbrs != -1
    nbrs = nbrs[valid_mask]

    # Fallback 1: If no neighbors found at all, use direct neighbors from edge_index
    if len(nbrs) == 0:
        # Get direct neighbors from original edge_index
        rows, cols = [], []
        for i, seed in enumerate(seed_nodes):
            mask = edge_index[0] == seed
            neighbors = edge_index[1][mask]
            if len(neighbors) > 0:
                rows.append(torch.full_like(neighbors, i))
                cols.append(neighbors)
            else:
                # Add self-loop if no neighbors found
                rows.append(torch.tensor([i], device=device))
                cols.append(torch.tensor([seed], device=device))

        sub_edge_index = torch.stack([torch.cat(rows), torch.cat(cols)], dim=0)
        edge_weight = torch.ones(sub_edge_index.size(1), device=device)
        return sub_edge_index, edge_weight

    # Count visit frequency
    counts = scatter(torch.ones_like(nbrs, dtype=torch.float), nbrs, dim=0,
                    dim_size=num_nodes, reduce='sum')
    counts = counts / counts.sum().clamp(min=1e-6)

    # Build neighbor lists with fallbacks
    rows, cols, weights = [], [], []
    for i, seed in enumerate(seed_nodes):
        # Get all possible neighbors
        mask = undir_edge_index[0] == seed
        neighbors = undir_edge_index[1][mask]

        # Fallback to direct neighbors if random walk found none
        if len(neighbors) == 0:
            mask = edge_index[0] == seed
            neighbors = edge_index[1][mask]

        # Final fallback to self-loop
        if len(neighbors) == 0:
            rows.append(torch.tensor([i], device=device))
            cols.append(torch.tensor([seed], device=device))
            weights.append(torch.tensor([1.0], device=device))
            continue

        # Get top-k neighbors by visit count
        neighbor_scores = counts[neighbors]
        if len(neighbors) > top_k:
            neighbor_scores, idx = torch.topk(neighbor_scores, top_k)
            neighbors = neighbors[idx]

        rows.append(torch.full_like(neighbors, i))
        cols.append(neighbors)
        weights.append(neighbor_scores)

    # Final assembly
    sub_edge_index = torch.stack([torch.cat(rows), torch.cat(cols)], dim=0)
    edge_weight = torch.cat(weights)

    return sub_edge_index.to(device), edge_weight.to(device)

In [7]:
######################################################################
# PinSage convolution layer                                          #
######################################################################

class PinSageConv(MessagePassing):
    """PinSage convolution layer with importance (weighted mean) pooling.

    Follows the formulation from Ying et al., KDD'18.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr='add')  # we will divide by weights later
        self.lin_neighbor = torch.nn.Linear(in_channels, out_channels, bias=True)
        self.lin_root = torch.nn.Linear(in_channels + out_channels, out_channels,
                                        bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin_neighbor.weight)
        torch.nn.init.zeros_(self.lin_neighbor.bias)
        torch.nn.init.xavier_uniform_(self.lin_root.weight)
        torch.nn.init.zeros_(self.lin_root.bias)

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor):
        """Args:
            x: Node features [N, F]. Assumes 0..B‑1 correspond to target nodes.
            edge_index: [2, E] with edges from *target nodes* to their sampled
                neighbors.
            edge_weight: Importance weight per edge, must be positive.
        Returns:
            Updated representations for the first B (target) nodes.
        """
        h = self.lin_neighbor(x)
        # propagate returns weighted sum; we pass weights via 'norm'.
        out = self.propagate(edge_index, x=h, norm=edge_weight)

        # Compute denominator per destination (row in COO).
        denom = scatter(edge_weight, edge_index[0], dim=0, dim_size=out.size(0),
                        reduce='sum').clamp(min=1e-6).unsqueeze(-1)
        out = out / denom  # weighted mean

        # concatenate with root features
        root = x[:out.size(0)]  # first B nodes are roots
        out = torch.cat([root, out], dim=-1)
        out = F.relu(self.lin_root(out))
        return F.normalize(out, p=2.0, dim=-1)

    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:  # noqa: N802
        return norm.view(-1, 1) * x_j


In [8]:
######################################################################
# PinSage model                                                      #
######################################################################

class PinSageModel(torch.nn.Module):
    """Two‑layer PinSage network (easily extendable)."""

    def __init__(self, in_channels: int, hidden_channels: int, num_layers: int = 2):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers):
            in_ch = in_channels if i == 0 else hidden_channels
            self.convs.append(PinSageConv(in_ch, hidden_channels))
        # final projection as in paper
        self.project = torch.nn.Linear(hidden_channels, hidden_channels)
        self.num_layers = num_layers
        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.project.weight)
        torch.nn.init.zeros_(self.project.bias)

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,
                batch_size: int) -> Tensor:
        """Computes embeddings for *batch_size* seed nodes.

        x contains features for *batch_size + batch_size*top_k* nodes;
        the first *batch_size* rows correspond to seeds.
        """
        for conv in self.convs:
            x = conv(x, edge_index, edge_weight)
        z = F.relu(self.project(x[:batch_size]))  # only seed embeddings
        return F.normalize(z, p=2.0, dim=-1)

In [9]:
######################################################################
# Mini‑batch training harness                                        #
######################################################################

class PinSageTrainer:
    """Utility class encapsulating training loop with max‑margin loss."""

    def __init__(
        self,
        model: PinSageModel,
        edge_index: Tensor,
        num_nodes: int,
        device: torch.device,
        walk_length: int = 2,
        num_walks: int = 200,
        top_k: int = 50,
        margin: float = 0.1,
    ):
        self.model = model.to(device)
        self.edge_index = edge_index.to(device)
        self.num_nodes = num_nodes
        self.device = device
        self.walk_length = walk_length
        self.num_walks = num_walks
        self.top_k = top_k
        self.margin = margin
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def batch(self, seeds: Tensor, x_full: Tensor) -> Tuple[Tensor, Tensor, Tensor, int]:
        """Prepares mini-batch graph and returns (x, edge_index, edge_weight, batch_size)."""
        seeds = seeds.to(self.device)
        batch_size = len(seeds)

        # Sample neighbors
        sub_edge_index, edge_weight = importance_sampling(
            self.edge_index, self.num_nodes, seeds,
            walk_length=self.walk_length, num_walks=self.num_walks,
            top_k=self.top_k, device=self.device)

        # Get all required nodes
        all_nodes = torch.cat([seeds, sub_edge_index[1]]).unique()

        # Map global node IDs to local IDs
        node_map = {int(node): i for i, node in enumerate(all_nodes.tolist())}

        # Remap edge indices to local IDs
        row = torch.tensor([node_map[int(seeds[idx])] for idx in sub_edge_index[0].tolist()],
                          device=self.device)
        col = torch.tensor([node_map[int(idx)] for idx in sub_edge_index[1].tolist()],
                          device=self.device)
        sub_edge_index_local = torch.stack([row, col], dim=0)

        # Get features for batch nodes
        x_batch = x_full[all_nodes]

        return x_batch, sub_edge_index_local, edge_weight, batch_size

    def max_margin_loss(self, z_q: Tensor, z_pos: Tensor, z_neg: Tensor) -> Tensor:
        pos_sim = (z_q * z_pos).sum(dim=-1)
        neg_sim = (z_q.unsqueeze(1) * z_neg).sum(dim=-1)  # [B, Kneg]
        loss = F.relu(neg_sim - pos_sim.unsqueeze(1) + self.margin).mean()
        return loss

    def bpr_loss(self, z_q: Tensor, z_pos: Tensor, z_neg: Tensor) -> Tensor:
        # z_neg: [B, Kneg, d]
        pos_sim = (z_q * z_pos).sum(dim=-1)                     # [B]
        neg_sim = (z_q.unsqueeze(1) * z_neg).sum(dim=-1)        # [B, Kneg]
        return -torch.log(torch.sigmoid(pos_sim.unsqueeze(1) - neg_sim)).mean()

    def train_step(self, seeds: Tensor, pos: Tensor, neg: Tensor,
               x_full: Tensor):
        self.model.train()
        seeds, pos, neg = [t.to(self.device) for t in (seeds, pos, neg)]

        # ------- build mini‑batch sub‑graph -------
        x_batch, sub_edge_index, edge_w, B = self.batch(
            seeds, x_full
        )

        # ------- forward -------
        z_all = self.model(x_batch, sub_edge_index, edge_w, B)  # [B, d]
        z_q   = z_all                                           # queries
        z_pos = self.model(x_full[pos],                         # reuse linear proj
                          torch.empty(2,0, dtype=torch.long, device=self.device),
                          torch.empty(0,  device=self.device),
                          pos.size(0))
        z_neg = self.model(x_full[neg.view(-1)]
                          .view(neg.shape[0], neg.shape[1], -1)
                          .reshape(-1, x_full.size(1)),
                          torch.empty(2,0, dtype=torch.long, device=self.device),
                          torch.empty(0,  device=self.device),
                          neg.numel()).view(neg.shape[0], neg.shape[1], -1)

        # inside train_step, just before backward()
        loss = self.bpr_loss(z_q, z_pos, z_neg)


        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return float(loss)


In [24]:
# ── NEW: get_all_embeddings ─────────────────────────────────────
@torch.no_grad()
def get_all_embeddings(model, x_full, edge_index, batch_size=4096):
    model.eval()
    num_nodes = x_full.size(0)
    z = torch.empty(num_nodes, model.project.out_features,
                    device=x_full.device)
    ptr = 0
    while ptr < num_nodes:
        batch_nodes = torch.arange(ptr, min(ptr+batch_size, num_nodes),
                                   device=x_full.device)
        # 1‑hop importance sampling (cheap) for inference
        sub_ei, ew = importance_sampling(edge_index, num_nodes,
                                         batch_nodes, top_k=50,
                                         num_walks=50, device=x_full.device)
        x_batch = x_full[torch.cat([batch_nodes, sub_ei[1]]).unique()]
        # remap as in trainer.batch()
        node_map = {int(n): i for i, n in enumerate(x_batch.unique().tolist())}
        row = torch.tensor([node_map[int(n)] for n in sub_ei[0].tolist()],
                           device=x_full.device)
        col = torch.tensor([node_map[int(n)] for n in sub_ei[1].tolist()],
                           device=x_full.device)
        z_batch = model(x_batch,
                        torch.stack([row, col], 0),
                        ew, batch_nodes.size(0))
        z[batch_nodes] = z_batch
        ptr += batch_size
    return F.normalize(z, dim=-1)


In [10]:
# data, num_users, num_items = load_movielens_100k()
# x_full = torch.cat([data["user"].x, data["item"].x], dim=0).to(device)
# edge_index = data["user", "rates", "item"].edge_index.to(device)

# evaluate_model(model, x_full, edge_index, num_users, num_items)


In [11]:
data, num_users, num_items = load_movielens_100k()
data = data.to(device)
num_nodes = num_users + num_items # Calculate num_nodes


In [12]:
# --- Create the full feature matrix ---
# Ensure features are created correctly (similar to the first train_model function)
# Example: Assuming 'data' has user and item features loaded correctly
x_user = data['user'].x
x_item = data['item'].x
# Pad features if necessary (as done in load_movielens_100k)
feat_dim = max(x_user.size(1), x_item.size(1))
if x_user.size(1) < feat_dim:
    x_user = torch.nn.functional.pad(x_user, (0, feat_dim - x_user.size(1)), "constant", 0)
if x_item.size(1) < feat_dim:
    x_item = torch.nn.functional.pad(x_item, (0, feat_dim - x_item.size(1)), "constant", 0)

x_full = torch.cat([x_user, x_item], dim=0).to(device)
print(f"Full feature matrix shape: {x_full.shape}")

Full feature matrix shape: torch.Size([2625, 19])


In [14]:
from torch_geometric.transforms import RandomLinkSplit

In [15]:
# Split edges into train/val/test
transform = RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=False, # Important if graph is directed (user->item)
    add_negative_train_samples=False, # Set to False if trainer handles neg sampling
    # Neg_sampling_ratio is only relevant if add_negative_train_samples=True
    edge_types=("user", "rates", "item"),
    rev_edge_types=None, # Adjust if you have reverse edges
)
# Ensure data is on the correct device before transform if needed
train_data, val_data, test_data = transform(data)
print("Data split completed.")
print("Train data:", train_data)
print("Validation data:", val_data)
print("Test data:", test_data)

Data split completed.
Train data: HeteroData(
  user={ x=[943, 19] },
  item={ x=[1682, 19] },
  (user, rates, item)={
    edge_index=[2, 80000],
    edge_label=[80000],
    edge_label_index=[2, 80000],
  }
)
Validation data: HeteroData(
  user={ x=[943, 19] },
  item={ x=[1682, 19] },
  (user, rates, item)={
    edge_index=[2, 80000],
    edge_label=[20000],
    edge_label_index=[2, 20000],
  }
)
Test data: HeteroData(
  user={ x=[943, 19] },
  item={ x=[1682, 19] },
  (user, rates, item)={
    edge_index=[2, 90000],
    edge_label=[20000],
    edge_label_index=[2, 20000],
  }
)


In [16]:
# Train model on train_data
# Make sure in_channels matches the feature dimension after padding
model = PinSageModel(in_channels=x_full.size(1), hidden_channels=128).to(device)

# Use the edge_index from train_data for the trainer
trainer = PinSageTrainer(model, train_data["user", "rates", "item"].edge_index,
                         num_nodes, device) # Pass num_nodes


In [17]:
# --- Build interaction dictionary based *only* on training edges ---
train_edge_index = train_data["user", "rates", "item"].edge_index.cpu() # Move to CPU for iteration
user_to_items_train = {}
for i in range(train_edge_index.size(1)):
    user = train_edge_index[0, i].item()
    item = train_edge_index[1, i].item()
    if user < num_users: # Ensure it's a user node
        if user not in user_to_items_train:
            user_to_items_train[user] = set() # Use set for faster lookups
        user_to_items_train[user].add(item)

print(f"Built training interaction dictionary for {len(user_to_items_train)} users.")


Built training interaction dictionary for 943 users.


In [21]:

import time

from sklearn.metrics import recall_score

import numpy as np

In [22]:
# --- Corrected Training loop ---
print("Starting training...")
start_time = time.time()
for epoch in range(1, 21): # Example: 20 epochs
    valid_users = list(user_to_items_train.keys())

    if not valid_users:
        print("No users with training interactions found. Stopping training.")
        break

    # --- Batch Generation ---
    batch_size = min(512, len(valid_users)) # Example batch size
    user_indices = np.random.choice(valid_users, batch_size, replace=False)

    seeds_list, pos_items_list, neg_items_list = [], [], []
    items_set = set(range(num_users, num_nodes)) # All item indices (global)

    for user_idx in user_indices:
        pos_candidates = list(user_to_items_train[user_idx]) # Items user interacted with in train set

        if not pos_candidates:
            continue # Skip user if no positive items in training set

        # Sample one positive item
        pos_item = np.random.choice(pos_candidates)

        # Sample negative items
        user_train_items = user_to_items_train[user_idx]
        neg_candidates = list(items_set - user_train_items) # Items user DID NOT interact with in train set

        if len(neg_candidates) < 5: # Need at least 5 negatives
             # print(f"Warning: User {user_idx} has only {len(neg_candidates)} neg candidates. Skipping.")
             continue

        neg_sample = np.random.choice(neg_candidates, 5, replace=False) # Sample 5 negatives

        # Add to batch lists
        seeds_list.append(user_idx)
        pos_items_list.append(pos_item)
        neg_items_list.append(neg_sample)

    if not seeds_list:
        print(f"Epoch {epoch:02d} | No valid training pairs generated in this batch. Skipping epoch.")
        continue

    # Convert lists to tensors
    seeds = torch.tensor(seeds_list, dtype=torch.long)
    pos = torch.tensor(pos_items_list, dtype=torch.long)
    neg = torch.tensor(neg_items_list, dtype=torch.long)

    # print(f"Epoch {epoch:02d} | Batch shapes: Seeds: {seeds.shape}, Pos: {pos.shape}, Neg: {neg.shape}")

    # --- Call train_step with all arguments ---
    try:
        loss = trainer.train_step(seeds, pos, neg, x_full)
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Time: {epoch_time:.2f}s | Batch Users: {len(seeds_list)}")
    except Exception as e:
         print(f"Error during training step in epoch {epoch}: {e}")
         # Optionally add more detailed error logging or debugging here
         # import traceback
         # traceback.print_exc()
         continue # Continue to next epoch or break if needed

print("Training finished.")


Starting training...


  neg = torch.tensor(neg_items_list, dtype=torch.long)


Epoch 01 | Loss: 0.0978 | Time: 1.94s | Batch Users: 512
Epoch 02 | Loss: 0.0868 | Time: 2.68s | Batch Users: 512
Epoch 03 | Loss: 0.0840 | Time: 3.78s | Batch Users: 512
Epoch 04 | Loss: 0.0788 | Time: 4.87s | Batch Users: 512
Epoch 05 | Loss: 0.0738 | Time: 5.80s | Batch Users: 512
Epoch 06 | Loss: 0.0776 | Time: 6.50s | Batch Users: 512
Epoch 07 | Loss: 0.0781 | Time: 7.21s | Batch Users: 512
Epoch 08 | Loss: 0.0731 | Time: 7.95s | Batch Users: 512
Epoch 09 | Loss: 0.0723 | Time: 8.69s | Batch Users: 512
Epoch 10 | Loss: 0.0719 | Time: 9.38s | Batch Users: 512
Epoch 11 | Loss: 0.0717 | Time: 10.10s | Batch Users: 512
Epoch 12 | Loss: 0.0743 | Time: 10.80s | Batch Users: 512
Epoch 13 | Loss: 0.0720 | Time: 11.50s | Batch Users: 512
Epoch 14 | Loss: 0.0725 | Time: 12.25s | Batch Users: 512
Epoch 15 | Loss: 0.0719 | Time: 13.01s | Batch Users: 512
Epoch 16 | Loss: 0.0739 | Time: 13.74s | Batch Users: 512
Epoch 17 | Loss: 0.0721 | Time: 14.48s | Batch Users: 512
Epoch 18 | Loss: 0.0670 

In [25]:
# --- Corrected Evaluation Function Definition ---
# Note: The original evaluate function had placeholders (...) for model calls.
# You need to implement how embeddings are generated during evaluation.
# This often involves calling the model's forward pass appropriately for users/items.
# The example below assumes the model's forward can handle empty edges for simple projection.

@torch.no_grad()
def evaluate(model, eval_data, train_data_edges, num_users, num_items, x_full, k=20, device=None):
    model.eval()
    if device is None:
        device = x_full.device

    print("Starting evaluation...")
    # Generate embeddings for all users and items using the *full* feature set (x_full)
    # This might require running inference in batches if the graph is large,
    # but for MovieLens-100k, we can often do it at once.
    # The way you call the model here depends on its forward implementation.
    # Assuming it can generate embeddings from features alone by passing empty edge info:
    empty_edge_index = torch.empty(2, 0, dtype=torch.long, device=device)
    empty_edge_weight = torch.empty(0, device=device)

    # --- Generate User Embeddings ---
    # Need to know how many nodes are in the user batch (num_users)
    # We pass the relevant section of x_full
    try:
        user_emb = model(x_full[:num_users], # User features
                         empty_edge_index,
                         empty_edge_weight,
                         batch_size=num_users) # Indicate the number of nodes being processed
        print(f"Generated user embeddings: {user_emb.shape}")
    except Exception as e:
        print(f"Error generating user embeddings: {e}")
        return None # Or handle error appropriately

    # --- Generate Item Embeddings ---
    # Need to know how many nodes are in the item batch (num_items)
    # We pass the relevant section of x_full
    try:
        item_emb = model(x_full[num_users:], # Item features
                         empty_edge_index,
                         empty_edge_weight,
                         batch_size=num_items) # Indicate the number of nodes being processed
        print(f"Generated item embeddings: {item_emb.shape}")
    except Exception as e:
        print(f"Error generating item embeddings: {e}")
        return None # Or handle error appropriately


    # Compute scores for all user-item pairs
    scores = user_emb @ item_emb.t() # [num_users, num_items]
    print(f"Computed scores matrix: {scores.shape}")

    # --- Masking Logic ---
    test_edge_index = eval_data["user", "rates", "item"].edge_index
    test_mask = torch.zeros(num_users, num_items, dtype=torch.bool, device=device)
    # Adjust item indices to be 0-based relative to the item set
    test_mask[test_edge_index[0], test_edge_index[1] - num_users] = True
    print(f"Test mask created. Positive test interactions: {test_mask.sum().item()}")


    # --- Exclude training items ---
    # Use the edge_index from the *original* training data split
    train_mask = torch.zeros(num_users, num_items, dtype=torch.bool, device=device)
    # Adjust item indices to be 0-based relative to the item set
    train_mask[train_data_edges[0], train_data_edges[1] - num_users] = True
    scores[train_mask] = -float("inf") # Set score to negative infinity for training items
    print(f"Masked out {train_mask.sum().item()} training interactions.")


    # Get top-k predictions
    try:
        _, top_k_indices = scores.topk(k, dim=1) # [num_users, k]
    except RuntimeError as e:
        print(f"Error during topk calculation: {e}")
        print("Check if k is larger than the number of items after masking.")
        return None


    # Calculate recall@k
    # Gather the test mask values corresponding to the top-k predictions
    # Ensure test_mask is expanded or indices are handled correctly if needed
    # For gather: top_k_indices should be LongTensor, test_mask BoolTensor
    hits = test_mask.gather(1, top_k_indices).sum().item()
    total_test_positives = test_mask.sum().item()

    if total_test_positives == 0:
        print("No positive edges in the test set!")
        recall = 0.0
    else:
        recall = hits / total_test_positives

    print(f"Recall@{k}: {recall:.4f} ({hits} hits / {total_test_positives} total test positives)")
    return recall

# --- Call Evaluation ---
# Pass the *training* edges to the evaluate function so it can mask them out
evaluate(model, test_data, train_data["user", "rates", "item"].edge_index, num_users, num_items, x_full, k=20, device=device)


Starting evaluation...
Error generating user embeddings: get_all_embeddings() got an unexpected keyword argument 'device'
