In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.explain import Explainer, GNNExplainer
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import sys

sys.path.append('..')
from src.models.gnn_models import create_model
from src.faithfulness.ablation import NodeAblator, DirectionalAblator
from src.utils.metrics import evaluate_faithfulness
from src.utils.tracking import init_experiment, log_metrics, log_graph_example, finish_experiment

## 1. Setup & Configuration

In [None]:
# Configuration: every knob a reader might want to tune lives in this one dict.
CONFIG = dict(
    dataset='MUTAG',          # TUDataset name loaded below
    model='gcn',              # architecture key handed to create_model
    hidden_channels=64,
    num_layers=3,
    dropout=0.5,
    epochs=200,
    lr=0.01,                  # Adam learning rate
    batch_size=32,
    seed=42,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

# Seed torch so torch-driven randomness repeats across fresh runs.
torch.manual_seed(CONFIG['seed'])
print(f"Device: {CONFIG['device']}")

## 2. Load Dataset

In [None]:
# Load the TU benchmark named in CONFIG (MUTAG by default).
dataset = TUDataset(root='../data/raw', name=CONFIG['dataset'])

print(f"Dataset: {CONFIG['dataset']}")
print(f"  Graphs: {len(dataset)}")
print(f"  Features: {dataset.num_features}")
print(f"  Classes: {dataset.num_classes}")

# Peek at the first graph to give a feel for graph size.
# "Edges" is the number of columns in edge_index.
example_graph = dataset[0]
print("\nExample graph:")
print(f"  Nodes: {example_graph.num_nodes}")
print(f"  Edges: {example_graph.edge_index.size(1)}")
print(f"  Label: {example_graph.y.item()}")

## 3. Train GNN Model

In [None]:
# Split dataset 80 / 10 / 10 into train / val / test.
# Shuffle with a dedicated generator seeded from CONFIG instead of
# `dataset = dataset.shuffle()`. The old form reshuffled the already-shuffled
# data and consumed the global RNG, so re-running this cell silently produced
# a different split and could move training graphs into the test set. This
# version is idempotent: every run yields the same split.
split_rng = torch.Generator().manual_seed(CONFIG['seed'])
perm = torch.randperm(len(dataset), generator=split_rng)
dataset_shuffled = dataset[perm]

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))

train_dataset = dataset_shuffled[:train_size]
val_dataset = dataset_shuffled[train_size:train_size + val_size]
test_dataset = dataset_shuffled[train_size + val_size:]
# The three splits must partition the dataset exactly.
assert len(train_dataset) + len(val_dataset) + len(test_dataset) == len(dataset)

train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'])
test_loader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'])

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

In [None]:
# Create model: the CONFIG-selected architecture, sized to the dataset.
model_kwargs = dict(
    in_channels=dataset.num_features,      # node feature width
    hidden_channels=CONFIG['hidden_channels'],
    out_channels=dataset.num_classes,      # one output per class
    num_layers=CONFIG['num_layers'],
    dropout=CONFIG['dropout'],
)
model = create_model(CONFIG['model'], **model_kwargs).to(CONFIG['device'])

num_params = sum(param.numel() for param in model.parameters())
print(f"Model: {CONFIG['model'].upper()}")
print(f"Parameters: {num_params:,}")

In [None]:
# Training functions
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: GNN called as ``model(x, edge_index, batch)``; its output is
            scored with cross-entropy against the batch labels ``y``.
        loader: iterable of mini-batches that also exposes ``.dataset``,
            whose length normalises the returned metrics.
        optimizer: torch optimizer over ``model``'s parameters.
        device: device each mini-batch is moved to.

    Returns:
        ``(mean loss per graph, accuracy)`` over ``loader.dataset``. Both are
        taken from the training-mode forward pass, before each optimizer step.
    """
    model.train()
    loss_sum, n_correct = 0, 0
    for batch_data in loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        logits = model(batch_data.x, batch_data.edge_index, batch_data.batch)
        batch_loss = F.cross_entropy(logits, batch_data.y)
        batch_loss.backward()
        optimizer.step()
        # cross_entropy averages over the batch; rescale so the epoch mean is per graph.
        loss_sum += batch_loss.item() * batch_data.num_graphs
        n_correct += (logits.argmax(dim=1) == batch_data.y).sum().item()
    n_graphs = len(loader.dataset)
    return loss_sum / n_graphs, n_correct / n_graphs

@torch.no_grad()
def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` in eval mode without tracking gradients.

    Args:
        model: GNN called as ``model(x, edge_index, batch)``; its output is
            scored with cross-entropy against the batch labels ``y``.
        loader: iterable of mini-batches that also exposes ``.dataset``,
            whose length normalises the returned metrics.
        device: device each mini-batch is moved to.

    Returns:
        ``(mean cross-entropy per graph, accuracy)`` over ``loader.dataset``.
        Both are ``nan`` when the dataset is empty (e.g. a validation split
        that rounds down to zero graphs) instead of raising ZeroDivisionError.
    """
    model.eval()
    num_graphs = len(loader.dataset)
    if num_graphs == 0:
        return float('nan'), float('nan')

    total_loss = 0.0
    correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        # Summing per batch avoids the mean-then-rescale-by-num_graphs round trip.
        total_loss += F.cross_entropy(out, data.y, reduction='sum').item()
        correct += (out.argmax(dim=1) == data.y).sum().item()
    return total_loss / num_graphs, correct / num_graphs

In [None]:
# Train model
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])

# Per-epoch learning curves; the accuracy histories feed the results plot.
train_losses, train_accs = [], []
val_losses, val_accs = [], []

for epoch in range(1, CONFIG['epochs'] + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, CONFIG['device'])
    val_loss, val_acc = evaluate(model, val_loader, CONFIG['device'])

    for history, value in zip(
        (train_losses, val_losses, train_accs, val_accs),
        (train_loss, val_loss, train_acc, val_acc),
    ):
        history.append(value)

    # Only log every 20th epoch to keep the cell output short.
    if epoch % 20:
        continue
    print(f"Epoch {epoch:3d}: train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
          f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

# Final score on the held-out test split.
test_loss, test_acc = evaluate(model, test_loader, CONFIG['device'])
print(f"\nTest accuracy: {test_acc:.4f}")

## 4. Generate Explanations

In [None]:
# Create explainer.
# explanation_type='model' explains the model's *own* prediction. PyG ignores
# (and warns about) a `target` for this type, so none is passed below.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',   # one mask value per (node, feature)
    edge_mask_type='object',       # one mask value per edge
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='raw',
    ),
)

# Get explanation for a test graph.
test_data = test_dataset[0].to(CONFIG['device'])
explanation = explainer(test_data.x, test_data.edge_index)

# With node_mask_type='attributes' the node mask is (num_nodes, num_features).
# Calling topk on it directly would rank feature columns within each node.
# Collapse it to one score per node first so we rank nodes.
node_mask = explanation.node_mask
node_importance = node_mask.reshape(node_mask.size(0), -1).sum(dim=1)
top_k = min(5, node_importance.numel())  # graphs may have fewer than 5 nodes

print("Explanation generated:")
print(f"  Node importance shape: {explanation.node_mask.shape}")
print(f"  Edge importance shape: {explanation.edge_mask.shape}")
print(f"  Top-{top_k} important nodes: {torch.topk(node_importance, k=top_k).indices.tolist()}")

## 5. Faithfulness Testing

In [None]:
# Initialize ablators
node_ablator = NodeAblator(ablation_mode="zero")
dir_ablator = DirectionalAblator(node_ablator)  # not used in this cell yet

N_EVAL_GRAPHS = 20  # test graphs scored in the necessity test
TOP_K_ABLATE = 3    # most important nodes ablated per graph

# Run necessity test: if the explanation is faithful, ablating the nodes it
# ranks highest should lower the model's confidence in its own prediction.
print("Running necessity test...")
model.eval()  # dropout off regardless of which cell ran last
original_probs = []
ablated_probs = []

for data in test_dataset[:N_EVAL_GRAPHS]:
    data = data.to(CONFIG['device'])

    # Original prediction and its confidence (no gradients needed).
    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
    pred = out.argmax(dim=1).item()
    prob_orig = F.softmax(out, dim=1)[0, pred].item()
    original_probs.append(prob_orig)

    # Explain the model's own prediction. With explanation_type='model', PyG
    # ignores any `target` (with a warning), so none is passed.
    explanation = explainer(data.x, data.edge_index)

    # Under node_mask_type='attributes' (as configured above), node_mask holds
    # one value per (node, feature). Collapse it to one score per node so topk
    # ranks nodes rather than feature columns.
    node_mask = explanation.node_mask
    node_importance = node_mask.reshape(node_mask.size(0), -1).sum(dim=1)
    k = min(TOP_K_ABLATE, node_importance.numel())
    top_nodes = torch.topk(node_importance, k=k).indices.tolist()
    data_ablated = node_ablator.ablate(data, top_nodes)

    # Confidence in the *original* predicted class after ablation.
    with torch.no_grad():
        out_ablated = model(data_ablated.x, data_ablated.edge_index, data_ablated.batch)
    prob_ablated = F.softmax(out_ablated, dim=1)[0, pred].item()
    ablated_probs.append(prob_ablated)

# Necessity = relative confidence drop (orig - ablated) / orig; > 0 means ablation hurt.
necessity_scores = [(o - a) / o for o, a in zip(original_probs, ablated_probs) if o > 0]
print(f"\nNecessity score: {np.mean(necessity_scores):.4f} ± {np.std(necessity_scores):.4f}")

In [None]:
# Visualize results: learning curves (left) and the necessity scatter (right).
results_dir = Path('../results')
results_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create missing directories

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Training curves, plotted against the same 1-based epoch numbers the training loop logs.
epochs_axis = range(1, len(train_accs) + 1)
axes[0].plot(epochs_axis, train_accs, label='Train')
axes[0].plot(epochs_axis, val_accs, label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('Training Progress')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Faithfulness: one point per test graph. Points below y=x lost confidence
# in the predicted class after ablation.
axes[1].scatter(original_probs, ablated_probs, alpha=0.6)
axes[1].plot([0, 1], [0, 1], 'r--', alpha=0.5, label='y=x')
axes[1].set_xlabel('Original Probability')
axes[1].set_ylabel('Ablated Probability')
axes[1].set_title('Necessity Test (should drop below y=x)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig(results_dir / 'baseline_results.png', dpi=150, bbox_inches='tight')
plt.show()

## Summary

This notebook demonstrated:
- Training a GNN (GCN by default, selected via `CONFIG`) for graph classification on MUTAG
- Generating explanations with GNNExplainer
- Testing faithfulness with a necessity test
- Measuring the relative drop in predicted-class confidence after ablating the top-ranked nodes

Next steps:
- Implement sufficiency tests
- Add directionality analysis for edges
- Compare across multiple explainer methods
- Test on more complex datasets