# CausalShapGNN Training Demo

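This notebook demonstrates how to train CausalShapGNN on the MovieLens-100K benchmark. It tracks validation Recall@20 and NDCG@20 during training and reports test-set metrics at the end.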

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from config import get_default_config
from data import DataPreprocessor, BipartiteGraphProcessor, RecommendationDataset, collate_fn
from models import CausalShapGNN
from trainers import Trainer
from utils import set_seed

# Seed the RNGs through the project's helper before any model or data objects are built.
set_seed(42)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print(f"Using device: {device}")

## 1. Load Data

In [None]:
# Load the benchmark interactions; graph_data exposes n_users, n_items and the
# train/val/test interaction sets used by the cells below.
preprocessor = DataPreprocessor(
    '../data',         # data directory, relative to this notebook
    'movielens-100k',  # dataset identifier
)
graph_data = preprocessor.load_data()

## 2. Configure Model

In [None]:
config = get_default_config()

# Dataset sizes come from the loaded graph (presumably used by the model to size
# its user/item parameters).
config['n_users'] = graph_data.n_users
config['n_items'] = graph_data.n_items

# Architecture overrides for this demo.
config['embed_dim'] = 64
config['n_factors'] = 4
config['n_layers'] = 3

# Override only the training hyper-parameters this demo cares about. Updating the
# existing dict, rather than assigning a fresh one, keeps any other defaults that
# get_default_config() places under 'training' (settings the Trainer may read)
# instead of silently discarding them.
config.setdefault('training', {}).update({
    'lr': 0.001,
    'batch_size': 1024,
    'n_epochs': 50,
})

print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")

## 3. Build Graph, Data Loader, Model, and Trainer

In [None]:
# Graph: built from the training interactions only, so validation/test
# interactions are not part of the adjacency used during training.
graph_processor = BipartiteGraphProcessor(
    graph_data.n_users,
    graph_data.n_items,
    graph_data.train_interactions,
    device,
)

# Shuffled mini-batches over the training interactions.
train_dataset = RecommendationDataset(graph_processor, graph_data.train_interactions)
train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=config['training']['batch_size'],
)

# Model, plus a quick report of its size.
model = CausalShapGNN(config, device)
print(f"Model parameters: {sum(param.numel() for param in model.parameters()):,}")

# The trainer owns the optimisation and evaluation loop.
trainer = Trainer(model, graph_processor, config, device)

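## 4. Training Loop

Train for `n_epochs` epochs. After every 5th epoch, compute Recall@20 and NDCG@20 on the validation split and log them.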

In [None]:
train_losses = []  # total training loss, one entry per epoch
val_recalls = []   # validation Recall@20, one entry per evaluated epoch

n_epochs = config['training']['n_epochs']

for epoch in range(n_epochs):
    # One pass over the training data.
    epoch_losses = trainer.train_epoch(train_loader, graph_processor.norm_adj)
    train_losses.append(epoch_losses['total'])

    # Only every 5th epoch (5, 10, ...) is validated and logged.
    if (epoch + 1) % 5 != 0:
        continue

    val_metrics = trainer.evaluate(graph_processor.norm_adj, graph_data.val_interactions)
    val_recalls.append(val_metrics['recall@20'])

    print(f"Epoch {epoch+1}/{n_epochs}")
    print(f"  Loss: {epoch_losses['total']:.4f}")
    print(f"  Val R@20: {val_metrics['recall@20']:.4f}")
    print(f"  Val N@20: {val_metrics['ndcg@20']:.4f}")

## 5. Plot Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Plot losses against 1-indexed epochs so they line up with the "Epoch k/n"
# log lines and with the validation points in the right-hand panel.
loss_epochs = range(1, len(train_losses) + 1)
axes[0].plot(loss_epochs, train_losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')

# Validation ran after every 5th epoch. Derive the x positions from the number of
# evaluations actually recorded, so the plot still works (instead of raising a
# length-mismatch error) when the training cell was interrupted part-way.
val_epochs = range(5, 5 * len(val_recalls) + 1, 5)
axes[1].plot(val_epochs, val_recalls, marker='o')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Recall@20')
axes[1].set_title('Validation Recall')

plt.tight_layout()
plt.show()

## 6. Final Evaluation

Evaluate the model, as it stands after the last training epoch, on the held-out test split.

In [None]:
# Note: this uses the weights from the final epoch; no best-validation
# checkpoint is restored anywhere in this notebook.
test_metrics = trainer.evaluate(graph_processor.norm_adj, graph_data.test_interactions)

print("\nTest Set Results:")
for metric_name in sorted(test_metrics):
    print(f"  {metric_name}: {test_metrics[metric_name]:.4f}")