In [None]:
import os, torch
from sklearn.model_selection import train_test_split
import pickle
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.nn.models import Node2Vec
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
# Stash GCNConv's original `propagate` under a private class attribute.
# NOTE(review): nothing visible in this notebook reads `_orig_propagate` — presumably
# a leftover from an earlier monkeypatching experiment; confirm before removing.
GCNConv._orig_propagate = GCNConv.propagate

import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from torch_geometric.explain import GNNExplainer, Explainer

# === Run configuration ===
model_name = 'old-womprat-13'  # Replace with your model name
# model_name = 'holographic-master-15'  # Replace with your model name
graph_num = 17
weight_prefix = 'best_loss'
random_seed = 100
# Bin edges (presumably the thresholds used to discretize the targets); the model's
# output layer has len(bins) + 1 classes.
bins = [int(i) for i in "400 800 1300 2100 3000 3700 4700 7020 9660".split(' ')]
dropout_p = 0.5
epochs = 100  # NOTE(review): not read anywhere in this notebook; kept for parity with training.

# === Device selection ===
# Decide the device once and reuse it everywhere below, instead of repeating the
# CUDA-availability check inline.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}", flush = True)
else:
    device = torch.device('cpu')
    print("Using CPU", flush = True)

bins = torch.tensor(bins, device=device)

# === Load the line-graph dataset ===
# pickle.load can execute arbitrary code — only load files you produced/trust.
with open(f'../data/graphs/{graph_num}/linegraph_tg.pkl', 'rb') as f:
    data = pickle.load(f)

# Ensure the tensors used by the model/explainer are contiguous in memory.
data.edge_index = data.edge_index.contiguous()
data.x = data.x.contiguous()
data.y = data.y.contiguous()

# --- Model Definitions ---
class GCN(torch.nn.Module):
    """Multi-layer GCN mapping node features to per-node class scores.

    Layer widths: data.num_features -> hidden_channels -> (num_layers more layers of
    hidden_channels) -> len(bins) + 1. Every layer except the last is followed by
    ReLU and dropout (rate ``dropout_p``, active only while ``self.training``).
    The output layer's scores are returned as-is (no softmax).

    Reads the notebook globals ``data``, ``bins`` and ``random_seed`` at construction
    time, and ``dropout_p`` on every forward call.
    """

    def __init__(self, hidden_channels, num_layers):
        super().__init__()
        # Re-seed the global torch RNG so weight initialisation is deterministic.
        torch.manual_seed(random_seed)

        num_classes = len(bins) + 1

        # Attribute names below are the prefixes of the state_dict keys — keep them.
        self.input_layer = GCNConv(data.num_features, hidden_channels, improved=True, cached=True)
        self.hidden_layers = torch.nn.ModuleList(
            GCNConv(hidden_channels, hidden_channels, improved=True, cached=True)
            for _ in range(num_layers)
        )
        self.output_layer = GCNConv(hidden_channels, num_classes, cached=True)

    def forward(self, x, edge_index):
        h = x
        # The input layer and each hidden layer share the same conv -> ReLU -> dropout step.
        for conv in (self.input_layer, *self.hidden_layers):
            h = conv(h, edge_index)
            h = F.dropout(F.relu(h), p=dropout_p, training=self.training)
        return self.output_layer(h, edge_index)

# Load the model with the GCN class.
# The first checkpoint is a fully pickled module (not a state_dict), so unpickling
# relies on the GCN class defined above and needs weights_only=False: since
# PyTorch 2.6, torch.load defaults to weights_only=True, which rejects arbitrary
# pickled objects. Full unpickling can execute code — only load trusted files.
model = torch.load(f'../data/graphs/{graph_num}/models/{model_name}.pt', map_location=device, weights_only=False)
model = model.to(device)

# Replace the pickled weights with the saved best-loss state_dict.
model.load_state_dict(torch.load(f'../data/graphs/{graph_num}/models/{model_name}_{weight_prefix}.pt', map_location=device))

# Inference only: switch off dropout (GCN.forward applies it when self.training is True).
model.eval()


In [None]:
import os
import pickle
from torch_geometric.explain import Explainer, GNNExplainer

# === Parameters ===
explainer_epochs = 50
# NOTE(review): not read in this notebook — the weights were already loaded above.
graph_model_path = f'../data/graphs/{graph_num}/models/{model_name}_{weight_prefix}.pt'
explanation_output_path = f'../data/graphs/{graph_num}/explanations/{model_name}/graph_level_explanation.pkl'
os.makedirs(os.path.dirname(explanation_output_path), exist_ok=True)


# === Wrap in Explainer ===
# GCN.forward returns the raw output of its last GCNConv (no softmax/log_softmax),
# so the outputs must be declared as 'raw' logits. Declaring them 'log_probs' makes
# the explainer's NLL loss treat unnormalized scores as log-probabilities.
# NOTE(review): the model produces per-node outputs without pooling; with
# task_level='graph' the explanation attributes features over all nodes jointly —
# confirm this global attribution is what is intended.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=explainer_epochs).to(device),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type=None,
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='raw',
    ),
)

# === Run graph-level explanation ===
# GNNExplainer starts from randomly initialised masks, and nothing else seeds torch
# here (the model was unpickled, so GCN.__init__ never ran) — seed for reproducibility.
torch.manual_seed(random_seed)
print("🔍 Running graph-level explanation...")
x_dev = data.x.to(device)
edge_index_dev = data.edge_index.to(device)
explanation = explainer(x_dev, edge_index_dev)

# === Save explanation to file ===
with open(explanation_output_path, 'wb') as f:
    pickle.dump(explanation, f)

print(f"✅ Graph-level explanation saved to: {explanation_output_path}")


In [None]:
# Reload the explanation written by the previous cell (trusted local pickle) so this
# cell can be re-run on its own without recomputing the explanation.
with open(explanation_output_path, 'rb') as fh:
    explanation = pickle.load(fh)

# Bar chart of per-feature importance from the node-feature mask.
explanation.visualize_feature_importance()  # for features
