In [1]:
# Import required modules
import random
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn import model_selection, metrics, preprocessing

import copy
import torch
from torch import nn, optim, Tensor
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree, structured_negative_sampling
from torch_geometric.data import download_url, extract_zip
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj


In [2]:
# Download the dataset
import os
from torch_geometric.data import download_url, extract_zip
'''
# Define dataset URL and file paths
url = "https://files.grouplens.org/datasets/movielens/ml-latest-small.zip"
dataset_path = "ml-latest-small.zip"
extract_path = "."

# Download and extract dataset
if not os.path.exists(dataset_path):
    download_url(url, extract_path)
extract_zip(dataset_path, extract_path)
'''
# Define file paths
movie_path = "./ml-latest-small/movies.csv"
rating_path = "./ml-latest-small/ratings.csv"
user_path = "./ml-latest-small/users.csv"

# Load dataset
import pandas as pd

rating_df = pd.read_csv(rating_path)

# Display first few rows
print(rating_df.head())

# Display unique movie and user counts
print(f"Number of unique movies: {len(rating_df['movieId'].unique())}")
print(f"Number of unique users: {len(rating_df['userId'].unique())}")




   userId  movieId  rating  timestamp
0       1        1     4.0  964982703
1       1        3     4.0  964981247
2       1        6     4.0  964982224
3       1       47     5.0  964983815
4       1       50     5.0  964982931
Number of unique movies: 9724
Number of unique users: 610


In [3]:
# Display summary statistics
rating_df.describe()

Unnamed: 0,userId,movieId,rating,timestamp
count,100836.0,100836.0,100836.0,100836.0
mean,326.127564,19435.295718,3.501557,1205946000.0
std,182.618491,35530.987199,1.042529,216261000.0
min,1.0,1.0,0.5,828124600.0
25%,177.0,1199.0,3.0,1019124000.0
50%,325.0,2991.0,3.5,1186087000.0
75%,477.0,8122.0,4.0,1435994000.0
max,610.0,193609.0,5.0,1537799000.0


In [4]:
lbl_user = preprocessing.LabelEncoder()
lbl_movie = preprocessing.LabelEncoder()

rating_df.userId = lbl_user.fit_transform(rating_df.userId.values)
rating_df.movieId = lbl_movie.fit_transform(rating_df.movieId.values)

In [5]:
print(rating_df.userId.max())
print(rating_df.movieId.max())

609
9723


In [6]:
rating_df.rating.value_counts()

rating
4.0    26818
3.0    20047
5.0    13211
3.5    13136
4.5     8551
2.0     7551
2.5     5550
1.0     2811
1.5     1791
0.5     1370
Name: count, dtype: int64

In [7]:


def load_edge_csv(df, src_index_col, dst_index_col, link_index_col, rating_threshold=3):
    """
    Loads a CSV containing edges between users and items.

    Args:
        df (pd.DataFrame): DataFrame containing user-item interactions.
        src_index_col (str): Column name for users.
        dst_index_col (str): Column name for items (movies).
        link_index_col (str): Column name for user-item interaction (ratings).
        rating_threshold (int, optional): Threshold to determine positive edges. Defaults to 3.

    Returns:
        list: Edge index (2xN matrix) containing the node IDs of N user-item edges.
    """
    
    print("Constructing COO format edge_index from input rating events...")

    # Get user IDs from rating events in order of occurrence
    src = df[src_index_col].tolist()
    
    # Get movie IDs from rating events in order of occurrence
    dst = df[dst_index_col].tolist()

    # Apply rating threshold to filter interactions
    edge_attr = torch.from_numpy(df[link_index_col].values).view(-1, 1).to(torch.long)
    
    edge_index = [[], []]  # COO format edge index (two lists for source and destination nodes)

    for i in range(edge_attr.shape[0]):
        if edge_attr[i] >= rating_threshold:
            edge_index[0].append(src[i])
            edge_index[1].append(dst[i])

    return edge_index


In [8]:

edge_index= load_edge_csv(
rating_df,
src_index_col='userId', dst_index_col='movieId', link_index_col="rating", rating_threshold=3.5,
)
print(f"{len(edge_index)} x {len(edge_index[0])}")

Constructing COO format edge_index from input rating events...
2 x 48580


In [9]:
#Convert to Tensor
edge_index = torch.LongTensor(edge_index)
print(edge_index)
print(edge_index.size())

tensor([[   0,    0,    0,  ...,  609,  609,  609],
        [   0,    2,    5,  ..., 9443, 9444, 9445]])
torch.Size([2, 48580])


In [11]:
# Get total number of unique users and movies (before applying rating threshold)
num_users = rating_df['userId'].nunique()  # More efficient than len(unique())
num_movies = rating_df['movieId'].nunique()

print(f"Total unique users: {num_users}")
print(f"Total unique movies: {num_movies}")

Total unique users: 610
Total unique movies: 9724


In [12]:
# Calculate total interactions
num_interactions = edge_index.size(1)  # More precise than shape[1] for PyTorch tensors

# Split the edges using 80/10/10 train/validation/test split
all_indices = torch.arange(num_interactions)  # More efficient than list comprehension

# First split: 80% train, 20% temp
train_indices, temp_indices = train_test_split(
    all_indices,
    test_size=0.2,
    random_state=1
)

# Second split: 10% val, 10% test (50% of remaining 20%)
val_indices, test_indices = train_test_split(
    temp_indices,
    test_size=0.5,
    random_state=1
)

# Create edge splits
train_edge_index = edge_index[:, train_indices]
val_edge_index = edge_index[:, val_indices]
test_edge_index = edge_index[:, test_indices]

# Print statistics
print(f"Dataset statistics:")
print(f"- Total users: {num_users}")
print(f"- Total movies: {num_movies}")
print(f"- Total interactions: {num_interactions}")
print(f"- Training interactions: {train_edge_index.size(1)}")
print(f"- Validation interactions: {val_edge_index.size(1)}")
print(f"- Test interactions: {test_edge_index.size(1)}")
print(f"- Total nodes (users + movies): {num_users + num_movies}")
print(f"- Unique users in training: {torch.unique(train_edge_index[0]).size(0)}")
print(f"- Unique movies in training: {torch.unique(train_edge_index[1]).size(0)}")

Dataset statistics:
- Total users: 610
- Total movies: 9724
- Total interactions: 48580
- Training interactions: 38864
- Validation interactions: 4858
- Test interactions: 4858
- Total nodes (users + movies): 10334
- Unique users in training: 609
- Unique movies in training: 5676


In [13]:
def convert_r_mat_edge_index_to_adj_mat_edge_index(input_edge_index):
    R = torch.zeros(num_users, num_movies)
    for i in range(input_edge_index.size(1)):
        row_idx = input_edge_index[0][i]
        col_idx = input_edge_index[1][i]
        R[row_idx][col_idx] = 1

    R_transpose = torch.transpose(R, 0, 1)
    adj_mat = torch.zeros(num_users + num_movies, num_users + num_movies)
    adj_mat[:num_users, num_users:] = R.clone()
    adj_mat[num_users:, :num_users] = R_transpose.clone()
    adj_mat_coo = adj_mat.to_sparse_coo()
    adj_mat_coo = adj_mat_coo.indices()
    return adj_mat_coo

def convert_adj_mat_edge_index_to_r_mat_edge_index(input_edge_index):
    sparse_input_edge_index = SparseTensor(
        row=input_edge_index[0],
        col=input_edge_index[1],
        sparse_sizes=(num_users + num_movies, num_users + num_movies)
    )
    adj_mat = sparse_input_edge_index.to_dense()
    interact_mat = adj_mat[:num_users, num_users:]
    r_mat_edge_index = interact_mat.to_sparse_coo().indices()
    return r_mat_edge_index

In [14]:
# convert from r_mat (interaction matrix) edge index to adjacency matrix's edge index
# so we can feed it to model
train_edge_index = convert_r_mat_edge_index_to_adj_mat_edge_index(train_edge_index)
val_edge_index = convert_r_mat_edge_index_to_adj_mat_edge_index(val_edge_index)
test_edge_index = convert_r_mat_edge_index_to_adj_mat_edge_index(test_edge_index)

print("Train edge index:", train_edge_index)
print("Train edge index size:", train_edge_index.size())
print("Validation edge index:", val_edge_index)
print("Validation edge index size:", val_edge_index.size())
print("Test edge index:", test_edge_index)
print("Test edge index size:", test_edge_index.size())

Train edge index: tensor([[    0,     0,     0,  ..., 10326, 10327, 10333],
        [  610,   612,   653,  ...,   183,   183,   330]])
Train edge index size: torch.Size([2, 77728])
Validation edge index: tensor([[    0,     0,     0,  ..., 10226, 10236, 10240],
        [  615,   794,  2010,  ...,   317,   204,   413]])
Validation edge index size: torch.Size([2, 9716])
Test edge index: tensor([[    0,     0,     0,  ..., 10301, 10302, 10329],
        [  811,  1086,  1095,  ...,   585,   585,   183]])
Test edge index size: torch.Size([2, 9716])


In [15]:
# Helper function for training and computing BPR loss
# Since this is self-supervised learning, we're relying on the graph structure itself
# We don't have labels other than the graph structure, so we need this function
# which randomly samples a mini-batch of positive and negative samples
def sample_mini_batch(batch_size, edge_index):
    """Randomly samples indices of a mini-batch given an adjacency matrix
    
    Args:
        batch_size (int): mini-batch size
        edge_index (torch.Tensor): 2 by N list of edges
        
    Returns:
        tuple: user_indices, positive_item_indices, negative_item_indices
    """
    # structured_negative_sampling is a PyG library
    # Samples a negative edge (i,k) for every positive edge (i,j) in the graph
    # and returns it as a tuple of the form (i,j,k)
    edges = structured_negative_sampling(edge_index)
    
    # Stack the edges into 3 x edge_index_len tensor
    edges = torch.stack(edges, dim=0)
    
    # Randomly sample batch_size indices with replacement
    indices = random.choices(
        [i for i in range(edges.shape[1])], 
        k=batch_size
    )
    
    batch = edges[:, indices]
    user_indices, pos_item_indices, neg_item_indices = batch[0], batch[1], batch[2]
    
    return user_indices, pos_item_indices, neg_item_indices

In [16]:
class LightGCN(MessagePassing):
    """LightGCN Model as proposed in https://arxiv.org/abs/2002.02126"""

    def __init__(self, num_users, num_items, embedding_dim=64, K=3, add_self_loops=False):
        """Initializes LightGCN Model

        Args:
            num_users (int): Number of users
            num_items (int): Number of items
            embedding_dim (int, optional): Dimensionality of embeddings. Defaults to 64.
            K (int, optional): Number of message passing layers. Defaults to 3.
            add_self_loops (bool, optional): Whether to add self loops for message passing. Defaults to False.
        """
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_dim = embedding_dim
        self.K = K
        self.add_self_loops = add_self_loops

        # Define user and item embedding for direct look up
        self.users_emb = nn.Embedding(num_embeddings=self.num_users, embedding_dim=self.embedding_dim)  # e_u^0
        self.items_emb = nn.Embedding(num_embeddings=self.num_items, embedding_dim=self.embedding_dim)  # e_i^0

        # Initialize embeddings as recommended in LightGCN paper
        nn.init.normal_(self.users_emb.weight, std=0.1)
        nn.init.normal_(self.items_emb.weight, std=0.1)

    def forward(self, edge_index: SparseTensor):
        """Forward propagation of LightGCN Model.

        Args:
            edge_index (SparseTensor): adjacency matrix

        Returns:
            tuple: (e_u_k, e_u_0, e_i_k, e_i_0) where:
                - e_u_k: final user embeddings after K layers
                - e_u_0: initial user embeddings
                - e_i_k: final item embeddings after K layers
                - e_i_0: initial item embeddings
        """
        # Normalize adjacency matrix
        edge_index_norm = gcn_norm(edge_index=edge_index, add_self_loops=self.add_self_loops)

        # Concatenate user and item embeddings
        emb_0 = torch.cat([self.users_emb.weight, self.items_emb.weight])  # E^0
        embs = [emb_0]  # Store layer 0 embeddings

        emb_k = emb_0  # Initialize embeddings for propagation

        # Propagate embeddings through K layers
        for _ in range(self.K):
            emb_k = self.propagate(edge_index=edge_index_norm[0], x=emb_k, norm=edge_index_norm[1])
            embs.append(emb_k)

        # Stack and average embeddings as per LightGCN paper
        embs = torch.stack(embs, dim=1)
        emb_final = torch.mean(embs, dim=1)  # E^K

        # Split into user and item embeddings
        users_emb_final, items_emb_final = torch.split(emb_final, [self.num_users, self.num_items])

        return users_emb_final, self.users_emb.weight, items_emb_final, self.items_emb.weight

    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:
        """Message passing operation
        
        Args:
            x_j (Tensor): Embeddings of neighboring nodes (shape: [edge_index_len, embedding_dim])
            norm (Tensor): Normalization values (shape: [edge_index_len])
            
        Returns:
            Tensor: Normalized messages to be propagated
        """
        return norm.view(-1, 1) * x_j


# Initialize model
layers = 3
model = LightGCN(
    num_users=num_users,
    num_items=num_movies,
    K=layers
)

In [17]:
def bpr_loss(users_emb_final, users_emb_0, pos_items_emb_final, pos_items_emb_0, 
             neg_items_emb_final, neg_items_emb_0, lambda_val=1e-5):
    """Bayesian Personalized Ranking Loss as described in https://arxiv.org/abs/1205.2618
    
    Args:
        users_emb_final (torch.Tensor): Final user embeddings after K layers
        users_emb_0 (torch.Tensor): Initial user embeddings
        pos_items_emb_final (torch.Tensor): Final positive item embeddings after K layers
        pos_items_emb_0 (torch.Tensor): Initial positive item embeddings
        neg_items_emb_final (torch.Tensor): Final negative item embeddings after K layers
        neg_items_emb_0 (torch.Tensor): Initial negative item embeddings
        lambda_val (float, optional): Regularization coefficient. Defaults to 1e-5.
        
    Returns:
        torch.Tensor: Scalar BPR loss value
    """
    # Calculate regularization loss
    reg_loss = lambda_val * (
        users_emb_0.norm(2).pow(2) + 
        pos_items_emb_0.norm(2).pow(2) + 
        neg_items_emb_0.norm(2).pow(2)
    )
    
    # Calculate positive and negative scores
    pos_scores = torch.sum(users_emb_final * pos_items_emb_final, dim=-1)
    neg_scores = torch.sum(users_emb_final * neg_items_emb_final, dim=-1)
    
    # Calculate BPR loss
    bpr_loss = -torch.mean(torch.nn.functional.logsigmoid(pos_scores - neg_scores))
    
    # Total loss
    total_loss = bpr_loss + reg_loss
    
    return total_loss

In [18]:
def get_user_positive_items(edge_index):
    """Generate dictionary of positive items for each user
    
    Args:
        edge_index (torch.Tensor): 2 by N list of edges
        
    Returns:
        dict: {user_id: list of positive item_ids}
    """
    user_pos_items = {}
    
    for i in range(edge_index.size(1)):
        user = edge_index[0][i].item()
        item = edge_index[1][i].item()
        
        if user not in user_pos_items:
            user_pos_items[user] = []  # Fixed: using square brackets
            
        user_pos_items[user].append(item)
    
    return user_pos_items

In [19]:
def RecallPrecision_ATM(groundTruth, r, k):
    """Computes recall@k and precision@k

    Args:
        groundTruth (list[list[int]]): List of lists containing ground truth item_ids for each user.
                                      Each sublist contains the true relevant items for that user.
        r (list[list[bool]]): List of lists indicating whether each top-k recommended item
                             is a ground truth item (true relevant) or not.
        k (int): The number of top items to consider for evaluation metrics.

    Returns:
        tuple: (recall@k, precision@k)
    """
    # Convert to tensors if they aren't already
    if not isinstance(r, torch.Tensor):
        r = torch.tensor(r, dtype=torch.float)
    if not isinstance(groundTruth, torch.Tensor):
        user_num_liked = torch.tensor([len(gt) for gt in groundTruth], dtype=torch.float)
    else:
        user_num_liked = torch.tensor([len(groundTruth[i]) for i in range(len(groundTruth))], dtype=torch.float)

    # Number of correctly predicted items per user
    num_correct_pred = torch.sum(r, dim=-1)  # Fixed dim=-1 instead of dim-1

    # Handle case where user has no liked items to avoid division by zero
    user_num_liked[user_num_liked == 0] = 1e-9

    recall = torch.mean(num_correct_pred / user_num_liked)
    precision = torch.mean(num_correct_pred / k)

    return recall.item(), precision.item()

In [20]:
def get_metrics(model, input_edge_index, input_exclude_edge_indices, k):
    """Computes evaluation metrics: recall, precision, and ndcg @ k

    Args:
        model (LightGCN): trained LightGCN model
        input_edge_index (torch.Tensor): 2 by N edge index (adjacency matrix based) for split to evaluate
        input_exclude_edge_indices (list[torch.Tensor]): list of edge indices to exclude from evaluation
        k (int): number of top items to consider for metrics

    Returns:
        tuple: (recall@k, precision@k, ndcg@k)
    """
    # Get embeddings
    user_embedding = model.users_emb.weight
    item_embedding = model.items_emb.weight

    # Convert edge indices to interaction matrix format
    edge_index = convert_adj_mat_edge_index_to_r_mat_edge_index(input_edge_index)
    exclude_edge_indices = [
        convert_adj_mat_edge_index_to_r_mat_edge_index(exclude_edge_index)
        for exclude_edge_index in input_exclude_edge_indices
    ]

    # Generate predicted interaction matrix
    r_mat_rating = torch.matmul(user_embedding, item_embedding.T)
    rating = r_mat_rating.clone()

    # Mask out excluded items
    for exclude_edge_index in exclude_edge_indices:
        user_pos_items = get_user_positive_items(exclude_edge_index)
        
        exclude_users = []
        exclude_items = []
        for user, items in user_pos_items.items():
            exclude_users.extend([user] * len(items))
            exclude_items.extend(items)
        
        # Set excluded entries to very low value
        rating[exclude_users, exclude_items] = -(1 << 10)

    # Get top-k items for each user
    _, top_K_items = torch.topk(rating, k=k)

    # Get ground truth positive items for test users
    users = edge_index[0].unique()
    test_user_pos_items = get_user_positive_items(edge_index)

    # Convert test_user_pos_items dictionary into list of lists
    test_user_pos_items_list = [test_user_pos_items[user.item()] for user in users]

    # Create relevance matrix
    r = []
    for i, user in enumerate(users):
        user = user.item()
        user_true_relevant_items = test_user_pos_items.get(user, [])
        # Create boolean list indicating whether top-k items are relevant
        relevant_items = [item.item() in user_true_relevant_items for item in top_K_items[i]]
        r.append(relevant_items)

    r = torch.tensor(r, dtype=torch.float32)

    # Compute metrics
    recall, precision = RecallPrecision_ATM(test_user_pos_items_list, r, k)
    ndcg = NDCGatK_r(test_user_pos_items_list, r, k)

    return recall, precision, ndcg

In [21]:
def evaluation(model, edge_index, exclude_edge_indices, k, lambda_val):
    """Evaluates model loss and metrics including recall, precision, ndcg @ k

    Args:
        model (LightGCN): LightGCN model
        edge_index (torch.Tensor): 2 by N edge index (adjacency matrix based) for split to evaluate
        exclude_edge_indices (list[torch.Tensor]): list of edge indices to exclude from evaluation
        k (int): determines the top k items to compute metrics on
        lambda_val (float): lambda value for BPR loss regularization

    Returns:
        tuple: (bpr_loss, recall@k, precision@k, ndcg@k)
    """
    # Get embeddings through forward pass
    users_emb_final, users_emb_0, items_emb_final, items_emb_0 = model(edge_index)

    # Convert to interaction matrix format for negative sampling
    r_mat_edge_index = convert_adj_mat_edge_index_to_r_mat_edge_index(edge_index)

    # Generate negative samples
    edges = structured_negative_sampling(
        r_mat_edge_index,
        contains_neg_self_loops=False
    )
    user_indices, pos_item_indices, neg_item_indices = edges

    # Get embeddings for sampled triplets
    users_emb_final = users_emb_final[user_indices]
    users_emb_0 = users_emb_0[user_indices]
    pos_items_emb_final = items_emb_final[pos_item_indices]
    pos_items_emb_0 = items_emb_0[pos_item_indices]
    neg_items_emb_final = items_emb_final[neg_item_indices]
    neg_items_emb_0 = items_emb_0[neg_item_indices]

    # Calculate BPR loss
    loss = bpr_loss(
        users_emb_final,
        users_emb_0,
        pos_items_emb_final,
        pos_items_emb_0,
        neg_items_emb_final,
        neg_items_emb_0,
        lambda_val
    ).item()

    # Calculate metrics
    recall, precision, ndcg = get_metrics(model, edge_index, exclude_edge_indices, k)

    return loss, recall, precision, ndcg

# Training 

In [23]:
ITERATIONS = 1000
EPOCHS = 10
BATCH_SIZE = 1024
LR = 1e-3
ITEMS_PER_EVAL = 200
ITEMS_PER_LR_DECAY = 200
K = 20
LAMBDA = 1e-6

In [24]:
# Setup device and model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move model to device
model = model.to(device)
model.train()  # Set model to training mode

# Initialize optimizer and learning rate scheduler
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

# Move data tensors to device
edge_index = edge_index.to(device)
train_edge_index = train_edge_index.to(device)
val_edge_index = val_edge_index.to(device)
test_edge_index = test_edge_index.to(device) if 'test_edge_index' in locals() else None

Using device: cpu


In [25]:
def get_embs_for_bpr(model, input_edge_index):
    """Prepares embeddings for BPR loss calculation
    
    Args:
        model (LightGCN): trained LightGCN model
        input_edge_index (torch.Tensor): adjacency matrix based edge index
        
    Returns:
        tuple: (users_emb_final, users_emb_0, 
                pos_items_emb_final, pos_items_emb_0,
                neg_items_emb_final, neg_items_emb_0)
    """
    # Get all embeddings from model
    users_emb_final, users_emb_0, items_emb_final, items_emb_0 = model(input_edge_index)
    
    # Convert to interaction matrix format
    edge_index_to_use = convert_adj_mat_edge_index_to_r_mat_edge_index(input_edge_index)
    
    # Sample mini-batch for BPR loss
    user_indices, pos_item_indices, neg_item_indices = sample_mini_batch(BATCH_SIZE, edge_index_to_use)
    
    # Move indices to correct device
    user_indices = user_indices.to(device)
    pos_item_indices = pos_item_indices.to(device)
    neg_item_indices = neg_item_indices.to(device)
    
    # Get embeddings for sampled triplets
    users_emb_final = users_emb_final[user_indices]  # Fixed: using [] for indexing
    users_emb_0 = users_emb_0[user_indices]
    pos_items_emb_final = items_emb_final[pos_item_indices]
    pos_items_emb_0 = items_emb_0[pos_item_indices]
    neg_items_emb_final = items_emb_final[neg_item_indices]
    neg_items_emb_0 = items_emb_0[neg_item_indices]
    
    return (users_emb_final, users_emb_0, 
            pos_items_emb_final, pos_items_emb_0,
            neg_items_emb_final, neg_items_emb_0)

In [None]:
# Initialize training trackers
train_losses = []
val_losses = []
val_recall_at_ks = []

for iter in tqdm(range(ITERATIONS), desc="Training"):
    # Forward propagation and get embeddings for BPR
    users_emb_final, users_emb_0, pos_items_emb_final, pos_items_emb_0, neg_items_emb_final, neg_items_emb_0 = \
        get_embs_for_bpr(model, train_edge_index)

    # Loss computation
    train_loss = bpr_loss(
        users_emb_final,
        users_emb_0,
        pos_items_emb_final,
        pos_items_emb_0,
        neg_items_emb_final,
        neg_items_emb_0,
        LAMBDA
    )

    # Backpropagation
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    # Store training loss
    train_losses.append(train_loss.item())

    # Validation and logging
    if iter % ITERS_PER_EVAL == 0:
        model.eval()
        
        with torch.no_grad():
            val_loss, recall, precision, ndcg = evaluation(
                model,
                val_edge_index,
                [train_edge_index],  # List of edge indices to exclude
                K,
                LAMBDA
            )
            
            val_losses.append(val_loss)
            val_recall_at_ks.append(round(recall, 5))

        print(f"Iteration {iter}/{ITERATIONS}: "
              f"train_loss: {train_loss.item():.5f}, "
              f"val_loss: {val_loss:.5f}, "
              f"recall@{K}: {recall:.5f}, "
              f"precision@{K}: {precision:.5f}, "
              f"ndcg@{K}: {ndcg:.5f}")
        
        model.train()  # Set back to training mode

    # Learning rate scheduling
    if iter % ITERS_PER_LR_DECAY == 0 and iter != 0:
        scheduler.step()