# Data Exploration for CausalShapGNN

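This notebook explores the benchmark datasets used to evaluate CausalShapGNN. It uses MovieLens-100K as a quick example and works through four steps: it downloads and loads the data, plots per-user and per-item interaction-count distributions, reports summary statistics, and quantifies item-popularity bias with a Lorenz curve.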

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

from data import DataDownloader, DataPreprocessor

%matplotlib inline
plt.style.use('seaborn-whitegrid')

## 1. Download and Load Data

In [None]:
# Fetch the MovieLens-100K benchmark (small enough for quick exploration)
# with a DataDownloader rooted at ../data.
data_root, dataset_name = '../data', 'movielens-100k'
downloader = DataDownloader(data_root)
downloader.download(dataset_name)

In [None]:
# Load the dataset through the project preprocessor (result kept as `graph_data`)
# and report how many users, items and training interactions it contains.
preprocessor = DataPreprocessor('../data', 'movielens-100k')
graph_data = preprocessor.load_data()

print(
    f"Users: {graph_data.n_users}\n"
    f"Items: {graph_data.n_items}\n"
    f"Training interactions: {len(graph_data.train_interactions)}"
)

## 2. Analyze Interaction Distributions

In [None]:
# Count how many training interactions each user and each item takes part in.
user_degrees = defaultdict(int)
item_degrees = defaultdict(int)
for user_id, item_id in graph_data.train_interactions:
    user_degrees[user_id] += 1
    item_degrees[item_id] += 1

# One log-scaled histogram panel per side of the interaction data; each tuple is
# (degree counts, y-axis label, panel title, extra keyword arguments for hist).
panel_specs = [
    (user_degrees, 'Number of Users', 'User Interaction Distribution', {}),
    (item_degrees, 'Number of Items', 'Item Popularity Distribution', {'color': 'orange'}),
]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, (degrees, y_label, title, hist_kwargs) in zip(axes, panel_specs):
    ax.hist(list(degrees.values()), bins=50, alpha=0.7, **hist_kwargs)
    ax.set_xlabel('Number of Interactions')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.set_yscale('log')

plt.tight_layout()
plt.show()

## 3. Compute Statistics

In [None]:
stats = preprocessor.compute_statistics(graph_data)

# Floats are shown to four decimal places; any other value uses its default formatting.
for stat_name, stat_value in stats.items():
    shown = f"{stat_value:.4f}" if isinstance(stat_value, float) else f"{stat_value}"
    print(f"{stat_name}: {shown}")

## 4. Popularity Bias Analysis

In [None]:
# Lorenz curve for item popularity: sort items from least to most popular and
# plot the cumulative share of interactions against the cumulative share of items.
item_pops = sorted(item_degrees.values())
assert item_pops, "item_degrees is empty: no training interactions to draw a Lorenz curve from"

cumsum = np.cumsum(item_pops)
cumsum = cumsum / cumsum[-1]

# The k-th least popular item (1-based) belongs at x = k/n, and the curve must
# start at the origin. np.linspace(0, 1, n) would instead place it at
# (k-1)/(n-1), shifting the curve left/up toward the equality line and making
# popularity look more evenly spread than it is.
n_pops = len(cumsum)
lorenz_x = np.arange(n_pops + 1) / n_pops
lorenz_y = np.concatenate(([0.0], cumsum))

fig, ax = plt.subplots(figsize=(8, 8))
ax.plot(lorenz_x, lorenz_y, label='Lorenz Curve')
ax.plot([0, 1], [0, 1], 'k--', label='Perfect Equality')
ax.set_xlabel('Cumulative Share of Items')
ax.set_ylabel('Cumulative Share of Interactions')
# NOTE(review): item_degrees only contains items with at least one training
# interaction, while stats["item_gini"] comes from compute_statistics, whose
# item population is not visible here. Confirm both describe the same item set.
ax.set_title(f'Item Popularity Inequality (Gini = {stats["item_gini"]:.3f})')
ax.legend()
plt.show()