In [1]:
import torch
import torch.nn as nn
import torch.sparse as sparse

# Implement

In [2]:
class LightGCN(nn.Module):
    def __init__(self, n_users, n_items, embedding_dim, n_layers, user_ids, item_ids, interaction_scores=None):
        """
        LightGCN model
        Args:
            n_users (int): Number of users
            n_items (int): Number of items
            embedding_dim (int): Embedding dimension
            n_layers (int): Number of propagation layers
            user_ids (list or tensor): List of user indices
            item_ids (list or tensor): List of item indices
            interaction_scores (list or tensor, optional): Interaction scores (binary or weighted)
        """
        super(LightGCN, self).__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers

        # Initialize user and item embeddings using normal initialization
        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)
        nn.init.normal_(self.user_embedding.weight, std=0.1)
        nn.init.normal_(self.item_embedding.weight, std=0.1)

        # Build adjacency matrix from user-item interactions
        self.adj_matrix = self.build_adj_matrix(user_ids, item_ids, interaction_scores)

        # Softplus function for BPR loss
        self.softplus = nn.Softplus()

    def build_adj_matrix(self, user_ids, item_ids, interaction_scores):
        """
        Build the sparse adjacency matrix from user-item interactions.
        Args:
            user_ids (list or tensor): User indices for each interaction
            item_ids (list or tensor): Item indices for each interaction
            interaction_scores (list or tensor, optional): Interaction scores (optional, binary if None)
        Returns:
            adj_matrix (torch.sparse.FloatTensor): Symmetrically normalized sparse adjacency matrix
        """
        if interaction_scores is None:
            interaction_scores = torch.ones(len(user_ids))  # Default to binary interactions if no scores are provided

        # Number of total nodes (users + items)
        n_total_nodes = self.n_users + self.n_items

        # Prepare the adjacency matrix in coordinate (COO) format
        user_tensor = torch.tensor(user_ids, dtype=torch.long)
        item_tensor = torch.tensor(item_ids, dtype=torch.long) + self.n_users  # Shift item indices by n_users
        score_tensor = torch.tensor(interaction_scores, dtype=torch.float32)

        # Create user-item interaction edges (user-to-item and item-to-user)
        indices = torch.cat([user_tensor.unsqueeze(0), item_tensor.unsqueeze(0)], dim=0)
        values = score_tensor

        # Create the sparse user-item adjacency matrix
        adj_matrix = torch.sparse_coo_tensor(indices, values, (n_total_nodes, n_total_nodes))

        # Symmetric normalization
        row_sum = torch.sparse.sum(adj_matrix, dim=1).to_dense()  # Compute row sums (degree for each node)
        d_inv_sqrt = torch.pow(row_sum, -0.5)  # Compute D^-0.5
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0  # Handle divide by zero for isolated nodes

        # Normalize the adjacency matrix: A_hat = D^-0.5 * A * D^-0.5
        d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
        adj_matrix = torch.sparse.mm(d_mat_inv_sqrt, adj_matrix)  # D^-0.5 * A
        adj_matrix = torch.sparse.mm(adj_matrix, d_mat_inv_sqrt)  # (D^-0.5 * A) * D^-0.5

        return adj_matrix

    def propagate(self):
        """
        Perform embedding propagation based on the LightGCN propagation rule.
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight

        all_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)  # Stack user and item embeddings

        all_layer_embeddings = [all_embeddings]

        # Perform K-layer propagation
        for _ in range(self.n_layers):
            all_embeddings = torch.sparse.mm(self.adj_matrix, all_embeddings)
            all_layer_embeddings.append(all_embeddings)

        # Combine embeddings from all layers (mean aggregation)
        final_embeddings = torch.stack(all_layer_embeddings, dim=1).mean(dim=1)

        final_user_embeddings = final_embeddings[:self.n_users]  # First part is users
        final_item_embeddings = final_embeddings[self.n_users:]  # Second part is items

        return final_user_embeddings, final_item_embeddings

    def forward(self, users, items):
        """
        Perform forward pass to get final user and item embeddings.
        Args:
            users (tensor): User indices
            items (tensor): Item indices
        """
        user_embeddings, item_embeddings = self.propagate()

        # Get specific user and item embeddings
        return user_embeddings[users], item_embeddings[items]

    def predict(self, users, items):
        """
        Predict interaction score for given users and items.
        """
        user_emb, item_emb = self.forward(users, items)
        return torch.sum(user_emb * item_emb, dim=1)

    def bpr_loss(self, users, pos, neg):
        """
        Compute the BPR loss for the model.
        Args:
            users (tensor): User indices
            pos (tensor): Positive item indices
            neg (tensor): Negative item indices
        Returns:
            tuple: BPR loss and regularization loss
        """
        # Get embeddings for users, positive items, and negative items
        user_emb, pos_emb = self.forward(users, pos)
        _, neg_emb = self.forward(users, neg)

        # Compute the positive and negative scores (dot product between embeddings)
        pos_scores = torch.sum(user_emb * pos_emb, dim=1)
        neg_scores = torch.sum(user_emb * neg_emb, dim=1)

        # Compute BPR loss using Softplus
        bpr_loss = torch.mean(self.softplus(neg_scores - pos_scores))

        # Compute regularization loss (L2 norm of embeddings)
        reg_loss = (1 / 2) * (user_emb.norm(2).pow(2) + pos_emb.norm(2).pow(2) + neg_emb.norm(2).pow(2)) / float(len(users))

        return bpr_loss, reg_loss

# Sample training loop using the interaction table format
def train_lightgcn(model, train_data, n_items, epochs=10, lr=0.001, reg=1e-4):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for user, pos_item in train_data:
            # Randomly sample a negative item
            neg_item = torch.randint(0, n_items, (1,)).item()

            optimizer.zero_grad()
            
            # Get embeddings
            user_embeddings, pos_item_embeddings = model(torch.tensor([user]), torch.tensor([pos_item]))
            _, neg_item_embeddings = model(torch.tensor([user]), torch.tensor([neg_item]))

            # Compute BPR loss
            loss = bpr_loss(user_embeddings, pos_item_embeddings, neg_item_embeddings, reg)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")

# Test implementation

In [3]:
# Example data (assuming you have 3 users, 4 items, and binary interactions)
user_ids = [0, 0, 1, 2, 2]
item_ids = [0, 1, 2, 3, 1]
n_users = len(set(user_ids))
n_items = len(set(item_ids))
interaction_scores = [1, 1, 1, 1, 1]  # Binary interaction scores

# Instantiate the model
model = LightGCN(n_users=3, n_items=4, embedding_dim=64, n_layers=3, user_ids=user_ids, item_ids=item_ids, interaction_scores=interaction_scores)

# Example forward pass
users = torch.tensor([0, 1, 2])
items = torch.tensor([0, 1, 2])
predictions = model.predict(users, items)
print(predictions)

tensor([-0.0006,  0.0022, -0.0067], grad_fn=<SumBackward1>)


In [4]:
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader

# Fixing random seed for reproducibility
random.seed(42)
np.random.seed(42)

# Constants
n_users = 5
n_items = 10
n_interactions = 20  # Number of interactions
embedding_dim = 8
n_layers = 3
batch_size = 4

# Create mock user-item interactions (randomly generated)
user_ids = np.random.randint(0, n_users, size=n_interactions)
item_ids = np.random.randint(0, n_items, size=n_interactions)
interaction_scores = np.ones(n_interactions)  # All interactions are positive

# Display mock dataset
print("Mock User IDs:", user_ids.tolist())
print("Mock Item IDs:", item_ids.tolist())
print("Interaction Scores:", interaction_scores)


class InteractionDataset(Dataset):
    def __init__(self, user_ids, item_ids, n_items):
        """
        Args:
            user_ids (list or array): List of user indices.
            item_ids (list or array): List of item indices.
            n_items (int): Number of unique items for negative sampling.
        """
        self.user_ids = user_ids
        self.item_ids = item_ids
        self.n_items = n_items
    
    def __len__(self):
        return len(self.user_ids)
    
    def __getitem__(self, idx):
        user = self.user_ids[idx]
        pos_item = self.item_ids[idx]
        
        # Randomly sample a negative item for the user
        neg_item = np.random.randint(0, self.n_items)
        while neg_item == pos_item:
            neg_item = np.random.randint(0, self.n_items)
        
        return user, pos_item, neg_item

# Create dataset
interaction_dataset = InteractionDataset(user_ids, item_ids, n_items)

# Create DataLoader
dataloader = DataLoader(interaction_dataset, batch_size=batch_size, shuffle=True)

# Instantiate LightGCN model
model = LightGCN(n_users=n_users, n_items=n_items, embedding_dim=embedding_dim, n_layers=n_layers,
                 user_ids=user_ids, item_ids=item_ids, interaction_scores=interaction_scores)

# Define optimizer (e.g., Adam)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
n_epochs = 5

for epoch in range(n_epochs):
    total_loss = 0
    total_reg_loss = 0
    
    for batch_idx, (users, pos_items, neg_items) in enumerate(dataloader):
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass: Compute BPR loss and regularization loss
        bpr_loss, reg_loss = model.bpr_loss(users, pos_items, neg_items)
        
        # Total loss = BPR loss + regularization loss
        loss = bpr_loss + reg_loss
        
        # Backward pass
        loss.backward()
        
        # Optimizer step
        optimizer.step()
        
        total_loss += bpr_loss.item()
        total_reg_loss += reg_loss.item()
    
    print(f'Epoch {epoch+1}/{n_epochs}, Loss: {total_loss:.4f}, Reg Loss: {total_reg_loss:.4f}')

Mock User IDs: [3, 4, 2, 4, 4, 1, 2, 2, 2, 4, 3, 2, 4, 1, 3, 1, 3, 4, 0, 3]
Mock Item IDs: [9, 5, 8, 0, 9, 2, 6, 3, 8, 2, 4, 2, 6, 4, 8, 6, 1, 3, 8, 1]
Interaction Scores: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Epoch 1/5, Loss: 3.4653, Reg Loss: 0.0291
Epoch 2/5, Loss: 3.4652, Reg Loss: 0.0278
Epoch 3/5, Loss: 3.4669, Reg Loss: 0.0266
Epoch 4/5, Loss: 3.4651, Reg Loss: 0.0236
Epoch 5/5, Loss: 3.4659, Reg Loss: 0.0231
