# CausalShapGNN Explanation Demo

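This notebook demonstrates the multi-granularity explanations produced by CausalShapGNN: factor-level Shapley attributions for an individual recommendation, an aggregated user-profile report, and a side-by-side comparison of factor contributions across a user's top recommendations.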

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt

from config import get_default_config
from data import DataPreprocessor, BipartiteGraphProcessor
from models import CausalShapGNN
from explainers import FeatureShapley, PathShapley, UserProfileShapley
from explainers import ExplanationReport, ExplanationVisualizer
from utils import set_seed

# Seed all RNGs for a reproducible demo, then run on the GPU when one is available.
set_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## 1. Load Model and Data

In [None]:
import warnings

# Load the MovieLens-100K dataset from ../data.
preprocessor = DataPreprocessor('../data', 'movielens-100k')
graph_data = preprocessor.load_data()

# Setup config: dataset sizes come from the loaded data; the remaining
# hyper-parameters must match the checkpoint loaded below (if any).
config = get_default_config()
config['n_users'] = graph_data.n_users
config['n_items'] = graph_data.n_items
config['embed_dim'] = 64
config['n_factors'] = 4
config['n_layers'] = 3

# Build the user-item bipartite graph from the training interactions only.
graph_processor = BipartiteGraphProcessor(
    graph_data.n_users, graph_data.n_items,
    graph_data.train_interactions, device
)

# Path to a trained CausalShapGNN checkpoint. Explanations of a randomly
# initialised model say nothing about real behaviour, so set this before
# drawing conclusions from the cells below.
CHECKPOINT_PATH = None

model = CausalShapGNN(config, device)
if CHECKPOINT_PATH is not None:
    # NOTE(review): assumes the file holds a plain state_dict saved with
    # torch.save(model.state_dict(), ...) -- confirm against the training script.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
else:
    warnings.warn(
        "No checkpoint loaded: CausalShapGNN has random weights, so the "
        "recommendations and explanations below are for illustration only."
    )
model.eval()

## 2. Generate Recommendations

In [None]:
user_id = 42
top_k = 10  # number of recommendations to show

# Forward pass once; the causal embeddings are reused by every later cell.
with torch.no_grad():
    user_emb, item_emb, _ = model(graph_processor.norm_adj, use_causal_only=True)

# Score every item for this user by dot product with the user's embedding.
scores = torch.matmul(user_emb[user_id], item_emb.t())

# Mask training items so already-seen items are never recommended.
train_items = list(graph_processor.train_user_items[user_id])
if train_items:
    scores[train_items] = -float('inf')

# torch.topk raises if k exceeds the number of scores, and without clamping
# it would return masked (-inf) training items when few unseen items remain.
n_candidates = int(torch.isfinite(scores).sum().item())
k = min(top_k, n_candidates)
_, top_items = torch.topk(scores, k)
top_items = top_items.cpu().numpy().tolist()

print(f"Top {k} recommendations for User {user_id}:")
for rank, item in enumerate(top_items, start=1):
    print(f"  {rank}. Item {item}")

## 3. Feature-Level Explanations

In [None]:
feature_explainer = FeatureShapley(model, device)
# Precompute population-mean embeddings before calling compute().
# NOTE(review): this is a private method of FeatureShapley; presumably the means
# serve as the Shapley baseline -- confirm in explainers.
feature_explainer._compute_population_means(user_emb, item_emb)

# Explain the highest-ranked recommendation.
item_idx = top_items[0]
shapley = feature_explainer.compute(user_id, item_idx, user_emb, item_emb)

# Display labels, one per factor.
# NOTE(review): these names are illustrative; confirm what each learned factor encodes.
factor_names = ['Genre', 'Recency', 'Quality', 'Social']

# zip() below would silently drop values if the counts disagree
# (e.g. after changing config['n_factors']), so fail loudly instead.
if len(shapley) != len(factor_names):
    raise ValueError(
        f"Got {len(shapley)} Shapley values but {len(factor_names)} factor names; "
        "update factor_names to match config['n_factors']."
    )

print(f"\nFeature-level explanation for Item {item_idx}:")
for name, value in zip(factor_names, shapley):
    print(f"  {name}: {value:.4f}")

In [None]:
# Horizontal bar chart of per-factor Shapley values for the explained item:
# non-negative contributions are drawn green, negative ones red.
visualizer = ExplanationVisualizer(factor_names)  # NOTE(review): not used in this cell

colors = ['green' if contribution >= 0 else 'red' for contribution in shapley]

fig, ax = plt.subplots(figsize=(8, 4))
ax.barh(factor_names, shapley, color=colors)
ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
ax.set_xlabel('Shapley Value')
ax.set_title(f'Factor Contributions for Item {item_idx}')
fig.tight_layout()
plt.show()

## 4. User Profile Analysis

In [None]:
# Build a user-level profile from this user's top recommendations, reusing the
# feature-level explainer (presumably aggregating per-item factor Shapley
# values -- the aggregation rule lives in UserProfileShapley, not shown here).
profile_explainer = UserProfileShapley(feature_explainer)
user_profile = profile_explainer.compute(user_id, top_items, user_emb, item_emb)

# Produce a report labelled with factor_names and print it
# (presumably a formatted string -- confirm in ExplanationReport).
report_generator = ExplanationReport(model, device, factor_names)
report = report_generator.generate_user_profile_report(user_id, user_profile, top_items)
print(report)

## 5. Compare Explanations Across Items

In [None]:
def _as_numpy(values):
    """Return *values* as a NumPy array, moving torch tensors to the CPU first."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


# Get explanations for the top 5 items. Use distinct loop names so section 3's
# `item_idx` / `shapley` are not overwritten if its cells are re-run later.
explanations = []
for cmp_item in top_items[:5]:
    cmp_shapley = feature_explainer.compute(user_id, cmp_item, user_emb, item_emb)
    explanations.append({'item_idx': cmp_item, 'feature_shapley': cmp_shapley})

# Heatmap matrix: one row per item, one column per factor.
shapley_matrix = np.array([_as_numpy(e['feature_shapley']) for e in explanations])

# A diverging colormap is only meaningful if 0 sits at its neutral centre, so use
# symmetric colour limits (with RdBu_r, negative values are blue and positive red).
limit = float(np.abs(shapley_matrix).max()) if shapley_matrix.size else 0.0
limit = limit or 1.0  # avoid a degenerate [0, 0] colour range

plt.figure(figsize=(10, 6))
plt.imshow(shapley_matrix, cmap='RdBu_r', aspect='auto', vmin=-limit, vmax=limit)
plt.colorbar(label='Shapley Value')
plt.xticks(range(len(factor_names)), factor_names)
plt.yticks(range(len(explanations)), [f"Item {e['item_idx']}" for e in explanations])
plt.title('Factor Contributions Across Recommendations')
plt.tight_layout()
plt.show()