In [1]:
# Install PyTorch Geometric into the *kernel's* environment (%pip, not !pip)
# and pin the version this notebook was executed with so re-runs are repeatable.
%pip install torch_geometric==2.6.1

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.1/63.1 kB 3.4 MB/s eta 0:00:00
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 23.7 MB/s eta 0:00:00
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import os.path as osp
import torch
from tqdm import tqdm
from torch_geometric.datasets import AmazonBook
from torch_geometric.nn import LightGCN
from torch_geometric.utils import degree

# Seed PyTorch's global RNGs before anything random happens. DataLoader
# shuffling, the torch.randint negative sampling in train() and (presumably)
# the model's parameter initialisation all draw from them, so without a seed
# no two runs of this notebook are comparable.
# NOTE: some GPU kernels are non-deterministic, so CUDA runs may still differ slightly.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters (kept in one place so they are easy to tune)
EMBEDDING_DIM = 64     # Dimension of embeddings
NUM_LAYERS = 2         # Number of graph convolution layers
LEARNING_RATE = 0.001  # Adam step size

# Set device (GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set up data path and load Amazon Book dataset (downloaded on first use)
path = osp.join("./", 'data', 'Amazon')
dataset = AmazonBook(path)
data = dataset[0]

# Get the number of users and books
num_users, num_books = data['user'].num_nodes, data['book'].num_nodes

# Convert heterogeneous graph to homogeneous and move to device
data = data.to_homogeneous().to(device)

# Training configuration (also reused as the user batch size in test())
batch_size = 8192

# Create training edge indices (only using edges where source < target)
mask = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask]

# Create data loader for training: it yields shuffled batches of *column
# positions* into train_edge_label_index rather than the edges themselves
train_loader = torch.utils.data.DataLoader(
    range(train_edge_label_index.size(1)),
    shuffle=True,
    batch_size=batch_size,
)

# Initialize LightGCN model
model = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=EMBEDDING_DIM,
    num_layers=NUM_LAYERS,
).to(device)

# Initialize Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


Downloading https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/master/data/amazon-book/user_list.txt
Downloading https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/master/data/amazon-book/item_list.txt
Downloading https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/master/data/amazon-book/train.txt
Downloading https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/master/data/amazon-book/test.txt
Processing...
Done!


In [3]:
def train():
    """
    Run one optimisation pass over all training edges.

    Every mini-batch pairs its positive edges with an equal number of
    negatives that keep the positive source nodes and draw target node ids
    at random from [num_users, num_users + num_books).

    Returns:
        float: Mean loss over the epoch, weighted by positives per batch
    """
    loss_sum = 0
    num_seen = 0

    for batch_ids in tqdm(train_loader):
        # Positive edges selected by the sampled column positions
        positives = train_edge_label_index[:, batch_ids]

        # Negatives: same sources, randomly drawn targets
        random_targets = torch.randint(
            num_users, num_users + num_books,
            (batch_ids.numel(), ), device=device)
        negatives = torch.stack([positives[0], random_targets], dim=0)

        # Positives fill the first half of the columns, negatives the second,
        # which is what lets chunk(2) split the scores apart again below
        candidates = torch.cat([positives, negatives], dim=1)

        # Forward pass
        optimizer.zero_grad()
        scores = model(data.edge_index, candidates)
        pos_rank, neg_rank = scores.chunk(2)

        # Ranking loss over the batch
        loss = model.recommendation_loss(
            pos_rank, neg_rank, node_id=candidates.unique())

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its number of positive examples
        num_pos = pos_rank.numel()
        loss_sum += float(loss) * num_pos
        num_seen += num_pos

    return loss_sum / num_seen

@torch.no_grad()
def test(k: int):
    """
    Testing function that computes precision and recall@k
    Args:
        k (int): Number of top items to consider for metrics
    Returns:
        tuple: (precision@k, recall@k)
    """
    # Get embeddings for users and books
    emb = model.get_embedding(data.edge_index)
    user_emb, book_emb = emb[:num_users], emb[num_users:]

    precision = recall = total_examples = 0

    # Process users in batches
    for start in range(0, num_users, batch_size):
        end = start + batch_size
        # Compute recommendations for current batch
        logits = user_emb[start:end] @ book_emb.t()

        # Remove training edges from recommendations
        mask = ((train_edge_label_index[0] >= start) &
                (train_edge_label_index[0] < end))
        logits[train_edge_label_index[0, mask] - start,
               train_edge_label_index[1, mask] - num_users] = float('-inf')

        # Create ground truth matrix
        ground_truth = torch.zeros_like(logits, dtype=torch.bool)
        mask = ((data.edge_label_index[0] >= start) &
                (data.edge_label_index[0] < end))
        ground_truth[data.edge_label_index[0, mask] - start,
                    data.edge_label_index[1, mask] - num_users] = True

        # Count number of relevant items per user
        node_count = degree(data.edge_label_index[0, mask] - start,
                          num_nodes=logits.size(0))

        # Get top-k recommendations
        topk_index = logits.topk(k, dim=-1).indices
        isin_mat = ground_truth.gather(1, topk_index)

        # Compute metrics
        precision += float((isin_mat.sum(dim=-1) / k).sum())
        recall += float((isin_mat.sum(dim=-1) / node_count.clamp(1e-6)).sum())
        total_examples += int((node_count > 0).sum())

    return precision / total_examples, recall / total_examples

In [4]:
# Training loop: after every epoch, evaluate with test() and report the metrics.
NUM_EPOCHS = 10  # number of full passes over the training edges
TOP_K = 20       # cutoff for precision@k / recall@k; also used in the log labels

for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    precision, recall = test(k=TOP_K)
    # The labels are derived from TOP_K so they cannot drift from the k actually used
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Precision@{TOP_K}: '
          f'{precision:.4f}, Recall@{TOP_K}: {recall:.4f}')

100%|██████████| 291/291 [00:43<00:00,  6.65it/s]


Epoch: 001, Loss: 0.5135, Precision@20: 0.0051, Recall@20: 0.0104


100%|██████████| 291/291 [00:42<00:00,  6.80it/s]


Epoch: 002, Loss: 0.2993, Precision@20: 0.0062, Recall@20: 0.0129


100%|██████████| 291/291 [00:43<00:00,  6.65it/s]


Epoch: 003, Loss: 0.2434, Precision@20: 0.0070, Recall@20: 0.0150


100%|██████████| 291/291 [00:43<00:00,  6.69it/s]


Epoch: 004, Loss: 0.2115, Precision@20: 0.0077, Recall@20: 0.0166


100%|██████████| 291/291 [00:43<00:00,  6.67it/s]


Epoch: 005, Loss: 0.1902, Precision@20: 0.0082, Recall@20: 0.0178


100%|██████████| 291/291 [00:43<00:00,  6.66it/s]


Epoch: 006, Loss: 0.1744, Precision@20: 0.0086, Recall@20: 0.0186


100%|██████████| 291/291 [00:43<00:00,  6.68it/s]


Epoch: 007, Loss: 0.1622, Precision@20: 0.0087, Recall@20: 0.0191


100%|██████████| 291/291 [00:43<00:00,  6.66it/s]


Epoch: 008, Loss: 0.1524, Precision@20: 0.0090, Recall@20: 0.0197


100%|██████████| 291/291 [00:43<00:00,  6.66it/s]


Epoch: 009, Loss: 0.1444, Precision@20: 0.0092, Recall@20: 0.0201


100%|██████████| 291/291 [00:43<00:00,  6.65it/s]


Epoch: 010, Loss: 0.1375, Precision@20: 0.0093, Recall@20: 0.0205
