# Physics-Informed Graph Neural Network (PIGNN)

This version includes the following updates: 

- K-fold cross-validation
- SHAP Sensitivity Analysis

#### Imports

In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_LAUNCH_BLOCKING"] = '1'

import gc
import csv
import random

import shap
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch_geometric.nn as pyg_nn

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import random_split
from torch_geometric.data import Data, Batch
from torch_geometric.utils import dense_to_sparse
from sklearn.metrics import roc_auc_score, roc_curve, auc, confusion_matrix, precision_recall_fscore_support, f1_score
from sklearn.model_selection import KFold

if not torch.cuda.is_available():
    print("No GPU available. Training will run on CPU.")
    device = "cpu"
else:
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
    device = "cuda"
    # Free cached allocator blocks and CUDA IPC memory, run Python's garbage
    # collector, then print allocator statistics for device 0 in GiB.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    gc.collect()
    print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
    print("torch.cuda.memory_reserved: %fGB" % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024))
    print("torch.cuda.max_memory_reserved: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))

def set_seed(seed):
    """
    Seed every random number generator used by the notebook.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch's CPU and
    CUDA generators (current device and all devices). It then requests
    deterministic cuDNN algorithms and turns off cuDNN benchmark autotuning.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(seed=12345)

## 1. Architecture definitions

In [None]:
class GraphSAGELayer(nn.Module):
    """
    One relational GraphSAGE-style layer over a dense multi-relation adjacency.

    For each edge type ``r``, neighbour features are gathered with a batched
    matrix product ``A_r @ x`` and projected by a relation-specific linear map.
    The per-relation results are summed, added to a linear projection of each
    node's own features, and passed through ReLU:

        out = ReLU( sum_r W_r (A_r @ x) + W_self x )

    No degree normalisation is applied, so neighbours are combined as a
    weighted sum (weights taken from the adjacency), not a mean.

    Args:
        in_dim (int): Input feature dimension.
        out_dim (int): Output feature dimension.
        edge_dim (int): Number of edge types (one neighbour projection each).
    """

    def __init__(self, in_dim: int, out_dim: int, edge_dim: int):
        super().__init__()
        # Creation order (neighbour projections, then self projection) fixes
        # the parameter-initialisation RNG order and the state_dict layout.
        self.lin_neighbors = nn.ModuleList(
            nn.Linear(in_dim, out_dim, bias=True) for _ in range(edge_dim)
        )
        self.lin_self = nn.Linear(in_dim, out_dim, bias=True)
        self.act = nn.ReLU()

    def message_passing(self, x: torch.Tensor, adj_tensor: torch.Tensor):
        """
        Sum relation-specific projections of the neighbour aggregates.

        Args:
            x (torch.Tensor): Rank-3 node features (batch, num_nodes, feature_dim).
            adj_tensor (torch.Tensor): Adjacency stack; ``adj_tensor[:, :, :, r]``
                is the (batch, num_nodes, num_nodes) adjacency of edge type ``r``.
        Returns:
            torch.Tensor: ``sum_r lin_neighbors[r](adj_tensor[:, :, :, r] @ x)``.
        """
        _batch, _nodes, _feat = x.shape  # unpacking rejects non-rank-3 inputs
        per_relation = [
            self.lin_neighbors[rel](torch.bmm(adj_tensor[:, :, :, rel], x))
            for rel in range(adj_tensor.shape[3])
        ]
        return sum(per_relation)

    def forward(self, x: torch.Tensor, adj_tensor: torch.Tensor):
        """
        Apply the layer: ReLU(neighbour messages + self projection).

        Args:
            x (torch.Tensor): Node features (batch, num_nodes, in_dim).
            adj_tensor (torch.Tensor): Multi-relation adjacency (see ``message_passing``).
        Returns:
            torch.Tensor: Node representations (batch, num_nodes, out_dim).
        """
        return self.act(self.message_passing(x, adj_tensor) + self.lin_self(x))

class GraphSAGEModel(nn.Module):

    """
    Relational GraphSAGE encoder that maps node features to node embeddings.

    Forward pipeline:
        input_proj -> conv1 -> dropout -> conv2 -> dropout -> lin_out

    Each ``GraphSAGELayer`` already ends in ReLU, so no second activation is
    applied after it (a second ReLU would be a no-op).

    ``conv3`` is constructed but not used in ``forward``. It is kept on
    purpose so that parameter-initialisation order and ``state_dict`` keys
    (e.g. ``conv3.lin_self.weight``) match checkpoints saved from this class,
    and strict ``load_state_dict`` on existing weight files keeps working.

    Args:
        in_features (int): Input feature dimension.
        hidden_size (int): Hidden layer size.
        out_features (int): Output embedding size produced by ``lin_out``.
            The callers visible in this notebook pass ``hidden_size`` here.
        dropout (float): Dropout rate applied after conv1 and conv2.
        edge_dim (int): Number of edge types in the adjacency tensor
            (defaults to 16).
    """

    def __init__(self, in_features: int, hidden_size: int, out_features: int,
                 dropout: float = 0.2, edge_dim: int = 16):
        super().__init__()

        self.input_proj = nn.Linear(in_features, hidden_size, bias=True)
        self.conv1 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size, edge_dim=edge_dim)
        self.conv2 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size, edge_dim=edge_dim)
        # Unused in forward; kept for checkpoint / init-order compatibility.
        self.conv3 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size, edge_dim=edge_dim)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=dropout)
        # Honour out_features (previously hard-wired to hidden_size).
        self.lin_out = nn.Linear(hidden_size, out_features, bias=True)

    def forward(self, x: torch.Tensor, adj_tensor: torch.Tensor):

        """
        Forward pass of the GraphSAGE model.

        Args:
            x (torch.Tensor): Node features; last dimension ``in_features``.
            adj_tensor (torch.Tensor): Multi-relation adjacency passed to each
                GraphSAGELayer (sliced there as ``adj_tensor[:, :, :, r]``).
        Returns:
            torch.Tensor: Node embeddings with last dimension ``out_features``.
        """

        h = self.input_proj(x)
        h = self.drop(self.conv1(h, adj_tensor))
        h = self.drop(self.conv2(h, adj_tensor))
        return self.lin_out(h)

class DNN(nn.Module):
    """
    Five-layer fully connected head that squashes its output with a sigmoid.

    Layers fc1..fc4 are each followed by ReLU. fc5 is followed by a sigmoid,
    so every output element lies in [0, 1]. ``nn.Linear`` acts on the last
    dimension, so leading dimensions (e.g. batch, node) pass through.

    Args:
        in_features (int): Input feature dimension.
        hidden_size (int): Width of the four hidden layers.
        out_features (int): Output dimension.
    """

    def __init__(self, in_features, hidden_size, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, hidden_size)
        self.fc5 = nn.Linear(hidden_size, out_features)

    def forward(self, x):
        """
        Run the head.

        Args:
            x (torch.Tensor): Input features; last dimension ``in_features``.
        Returns:
            torch.Tensor: Sigmoid outputs; last dimension ``out_features``.
        """
        for hidden_layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = F.relu(hidden_layer(x))
        return torch.sigmoid(self.fc5(x))

class GraphSAGEWithDNN(nn.Module):
    """
    Path-prediction model: GraphSAGE node embeddings fed to a sigmoid DNN head.

    ``GraphSAGEModel`` produces ``hidden_size``-dim embeddings per node, and
    ``DNN`` maps each embedding to ``out_features`` values in [0, 1]. The
    training code compares this output element-wise with Y_matrix, so
    ``out_features`` must match Y_matrix's last dimension.

    Args:
        in_features (int): Input node feature dimension.
        hidden_size (int): Embedding / hidden width.
        out_features (int): Output width per node.
        dropout (float): Dropout rate inside GraphSAGE (default 0).
    """

    def __init__(self, in_features, hidden_size, out_features, dropout=0):
        super().__init__()
        self.graphsage = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout)
        self.dnn = DNN(hidden_size, hidden_size, out_features)

    def forward(self, x, adj_tensor):
        """
        Embed the nodes, then score them with the DNN head.

        Args:
            x (torch.Tensor): Node features; last dimension ``in_features``.
            adj_tensor (torch.Tensor): Multi-relation adjacency forwarded to GraphSAGE.
        Returns:
            torch.Tensor: Per-node outputs; last dimension ``out_features``.
        """
        return self.dnn(self.graphsage(x, adj_tensor))

class AE(torch.nn.Module):
    """
    Fully connected autoencoder with a 9-dimensional bottleneck.

    Encoder widths: hidden_size -> 128 -> 64 -> 36 -> 18 -> 9.
    Decoder widths: 9 -> 18 -> 36 -> 64 -> 128 -> out_features.
    ReLU sits between consecutive Linear layers; neither the code nor the
    reconstruction has an output activation. Both halves are
    ``torch.nn.Sequential`` with Linear layers at even indices, which keeps
    ``state_dict`` keys (``encoder.0.weight`` ... ``decoder.8.bias``) stable.

    Args:
        hidden_size (int): Width of the incoming (GraphSAGE) embedding.
        out_features (int): Width of the reconstruction.
    """

    def __init__(self, hidden_size: int, out_features: int):
        super().__init__()
        self.encoder = self._mlp([hidden_size, 128, 64, 36, 18, 9])
        self.decoder = self._mlp([9, 18, 36, 64, 128, out_features])

    @staticmethod
    def _mlp(widths):
        """Build Linear layers for consecutive ``widths`` pairs, separated by ReLU."""
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position:
                layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Linear(fan_in, fan_out))
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """
        Encode then decode ``x``.

        Args:
            x (Tensor): Input whose last dimension is ``hidden_size``.
        Returns:
            Tensor: Reconstruction whose last dimension is ``out_features``.
        """
        return self.decoder(self.encoder(x))

class GraphSAGEWithAE(nn.Module):
    """
    GraphSAGE encoder followed by an autoencoder that reconstructs node features.

    ``GraphSAGEModel`` turns node features into ``hidden_size``-dim embeddings.
    ``AE`` maps each embedding through its bottleneck to ``out_features``
    values per node.

    Args:
        in_features (int): Number of input node features.
        hidden_size (int): Embedding width produced by GraphSAGE and consumed by the AE.
        out_features (int): Reconstruction width (intended to equal X_matrix.shape[2]).
        dropout (float, optional): Dropout rate inside GraphSAGE. Defaults to 0.
    """

    def __init__(self, in_features, hidden_size, out_features, dropout=0):
        super().__init__()
        self.graphsage = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout)
        self.autoencoder = AE(hidden_size, out_features)

    def forward(self, x, adj_tensor):
        """
        Embed the nodes with GraphSAGE, then reconstruct with the autoencoder.

        Args:
            x (Tensor): Node features; last dimension ``in_features``.
            adj_tensor (Tensor): Multi-relation adjacency forwarded to GraphSAGE,
                whose layers slice it as ``adj_tensor[:, :, :, r]``. It is
                therefore 4-D, presumably (batch, nodes, nodes, edge_types).
        Returns:
            Tensor: Per-node reconstruction; last dimension ``out_features``.
        """
        return self.autoencoder(self.graphsage(x, adj_tensor))

class EncoderWithClassifier(nn.Module):
    """
    Binary node classifier: GraphSAGE -> (pretrained) encoder -> MLP + sigmoid.

    Args:
        graphsage (nn.Module): GraphSAGE-like module called as ``graphsage(x, adj_tensor)``
            and returning a rank-3 (batch, nodes, embed) tensor.
        pretrained_encoder (nn.Module): Encoder applied to the flattened node
            embeddings (e.g. the encoder half of ``AE``).
        latent_dim (int): Output width of the encoder / input width of the classifier.
        freeze (bool): If True, disable gradients for all GraphSAGE and encoder
            parameters so only the classifier head is trained.
    """

    def __init__(self, graphsage: nn.Module, pretrained_encoder: nn.Module, latent_dim: int, freeze: bool):
        super().__init__()
        self.graphsage = graphsage
        self.encoder = pretrained_encoder

        if freeze:
            for frozen_module in (self.encoder, self.graphsage):
                frozen_module.requires_grad_(False)

        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, adj_tensor):
        """
        Classify every node.

        Args:
            x (Tensor): Node features, rank 3 as (batch_size, num_nodes, in_features).
            adj_tensor (Tensor): Adjacency forwarded unchanged to ``graphsage``.
        Returns:
            Tensor: Per-node probabilities of shape (batch_size, num_nodes).
        """
        n_graphs, n_nodes, _ = x.shape
        embeddings = self.graphsage(x, adj_tensor)
        # One row per node so the encoder/classifier see 2-D input.
        flat_embeddings = embeddings.view(n_graphs * n_nodes, embeddings.shape[2])
        node_probs = self.classifier(self.encoder(flat_embeddings))
        return node_probs.view(n_graphs, n_nodes)

## 2. Data loading, losses, and evaluation

In [None]:
class GraphDataset(Dataset):

    """
    Map-style dataset over a directory of ``.pt`` graph samples.

    Each ``.pt`` file must contain a dict with keys ``"adj_tensor"``,
    ``"X_matrix"`` and ``"Y_matrix"``. ``__getitem__`` returns those three
    values as a tuple, in that order.

    File paths are sorted because ``os.listdir`` order is arbitrary. Sorting
    keeps sample indices stable across machines and runs, which is required
    for index-based splits (``random_split``, ``KFold``) to be reproducible.
    """

    def __init__(self, data_dir):
        self.files = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.pt')
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # NOTE(review): torch.load unpickles the file; only point data_dir at trusted data.
        data = torch.load(self.files[idx])
        return data["adj_tensor"], data["X_matrix"], data["Y_matrix"]

# Build the dataset from "_data_", split it 80/20 into train/test subsets, and
# wrap each subset in a DataLoader (batch size 16; only the train loader shuffles).
data_dir = "_data_"
dataset = GraphDataset(data_dir)

n_samples = len(dataset)
train_size = int(0.8 * n_samples)
test_size = n_samples - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
    generator=torch.Generator().manual_seed(42),
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
    generator=torch.Generator().manual_seed(42),
)

def degree_loss(B):
    """
    Count-based penalty for deviations of ``B`` from a single simple path.

    Out-degree is the row sum and in-degree the column sum of each graph in
    the batch. Per graph the penalty is:

    * (#start nodes - 1)^2, where a start node has in-degree 0 and out-degree > 0;
    * (#end nodes - 1)^2, where an end node has out-degree 0 and in-degree > 0;
    * max(#active nodes whose in- or out-degree != 1, minus 2, 0). A node is
      active if it has any in- or out-degree; the "- 2" allows for the path's
      two endpoints.

    Args:
        B (torch.Tensor): Rank-3 adjacency batch (batch, N, N).
    Returns:
        torch.Tensor: Scalar float tensor, the batch mean of the penalty.

    NOTE(review): every term is built from boolean comparisons, so the result
    carries no gradient w.r.t. ``B``. In training it only shifts the loss
    value. With continuous ``B`` (e.g. sigmoid outputs), exact tests such as
    ``deg == 1`` are rarely true.
    """
    _batch, _rows, _cols = B.shape  # unpacking rejects non-rank-3 inputs
    out_deg = B.sum(dim=-1)
    in_deg = B.sum(dim=-2)

    n_start = ((in_deg == 0) & (out_deg > 0)).sum(dim=-1)
    n_end = ((out_deg == 0) & (in_deg > 0)).sum(dim=-1)

    is_active = (in_deg > 0) | (out_deg > 0)
    n_irregular = (((in_deg != 1) | (out_deg != 1)) & is_active).sum(dim=-1)

    penalty = (n_start - 1) ** 2 + (n_end - 1) ** 2 + torch.clamp(n_irregular - 2, min=0)
    return penalty.float().mean()

def cycle_loss(B, K=10):
    """
    Penalise closed walks in ``B`` via traces of powers of its row-normalised form.

    ``B`` is row-normalised (row sums plus 1e-6 in the denominator). Then
    ``sum_{k=1..K} trace(P^k) / k`` is accumulated per graph, where ``P`` is
    the normalised matrix. Closed walks of length ``k`` add to ``trace(P^k)``,
    so an acyclic graph scores 0.

    Args:
        B (torch.Tensor): Rank-3 adjacency batch (batch, N, N).
        K (int, optional): Largest walk length considered. Defaults to 10.
    Returns:
        torch.Tensor: Scalar tensor, the batch mean of the penalty.
    """
    batch_size, num_nodes, _ = B.shape
    P = B / (B.sum(dim=-1, keepdim=True) + 1e-6)

    penalty = torch.zeros(batch_size, device=B.device)
    walk = torch.eye(num_nodes, device=B.device).unsqueeze(0).expand(batch_size, -1, -1)

    power = 0
    while power < K:
        power += 1
        walk = torch.bmm(walk, P)
        penalty += torch.diagonal(walk, dim1=-2, dim2=-1).sum(dim=-1) / power

    return penalty.mean()

def connectivity_loss(B):

    """
    Spectral / component penalty that pushes ``B`` towards a single path.

    With out-degrees ``d = B.sum(-1)`` and Laplacian ``L = diag(d) - B``:

    * Component term. The count of eigenvalues of ``L`` with ``|Re λ| < 1e-5``
      is used as the number of connected components. It is compared with
      ``num_nodes - path_length + 1``, squared, and divided by the square of
      that expected count. The denominator is floored at 1, so an expected
      count of 0 no longer yields inf/NaN.
    * Spectral term. The smallest ``max_pl`` sorted real eigenvalue parts of
      ``L`` are compared with ``2 - 2cos(πk/pl)``, k = 0..pl-1. These are the
      Laplacian eigenvalues of an undirected path on ``pl`` nodes; the target
      is zero-padded for graphs shorter than ``max_pl``.

    ``path_length`` is estimated as ``floor(sum(B) / 2) + 1``. This counts
    each edge once only if ``B`` is symmetric.
    NOTE(review): confirm whether ``B`` is meant to be symmetric; for a
    directed adjacency the ``/ 2`` halves the edge count.

    The component count and path length come from comparisons and integer
    casts, so only the spectral term is differentiable w.r.t. ``B``.

    Args:
        B (torch.Tensor): Rank-3 adjacency batch (batch_size, N, N).

    Returns:
        torch.Tensor: Scalar tensor, the batch mean of both penalties.
    """

    batch_size, num_nodes, _ = B.shape
    deg_out = B.sum(dim=-1)
    L = torch.diag_embed(deg_out) - B
    eigvals = torch.linalg.eigvals(L).real
    zero_threshold = 1e-5
    num_components = (eigvals.abs() < zero_threshold).sum(dim=-1)

    path_length = (B.sum(dim=(-1, -2)) / 2).long() + 1
    expected_components = num_nodes - path_length + 1
    # expected_components == 0 when path_length == num_nodes + 1; floor the
    # normalisation at 1 instead of dividing by zero.
    denom = (expected_components.float() ** 2).clamp_min(1.0)
    component_penalty = (num_components - expected_components) ** 2 / denom

    # Both max_pl and every pl_val are clamped to num_nodes, so pl_val <= max_pl always.
    max_pl = min(path_length.max().item(), num_nodes)
    batch_path_eigvals = torch.zeros((batch_size, max_pl), device=B.device)
    for idx, pl in enumerate(path_length):
        pl_val = min(int(pl.item()), num_nodes)
        batch_path_eigvals[idx, :pl_val] = 2 - 2 * torch.cos(
            torch.arange(pl_val, device=B.device) * torch.pi / pl_val
        )

    sorted_eigvals = torch.sort(eigvals, dim=-1).values[:, :max_pl]
    spectral_penalty = ((sorted_eigvals - batch_path_eigvals) ** 2).mean(dim=-1)

    return (component_penalty + spectral_penalty).mean()

def masked_bce_loss(pred, target):

    """
    Masked, class-weighted binary cross-entropy.

    Only a subset of elements contributes:
    - every element whose target is non-zero;
    - each zero-target element independently with probability 0.001
      (drawn with ``torch.rand_like``, i.e. the global torch RNG).

    Selected elements are weighted: target == 1 gets 1 / 2.6615810451242892e-03
    (about 375.7) and target == 0 gets 1e-5. The result is the weighted loss
    summed over selected elements and divided by the number of selected
    elements. If nothing is selected, the divisor is floored at 1 and the
    result is 0 instead of NaN.

    Args:
        pred (torch.Tensor): Predicted probabilities in [0, 1], same shape as
            ``target``. ``F.binary_cross_entropy`` expects probabilities, not logits.
        target (torch.Tensor): Float labels (1 = positive, 0 = negative).

    Returns:
        torch.Tensor: Scalar loss.
    """

    mask = ((target != 0) | (torch.rand_like(target) < 0.001)).float()
    # target_w is a clone of target, so it already lives on target's device.
    target_w = target.clone()
    target_w[target == 1] = (1/2.6615810451242892e-03)
    target_w[target == 0] = 0.00001
    loss = F.binary_cross_entropy(pred, target, reduction='none', weight=target_w * mask)
    loss = loss * mask
    # mask entries are 0/1, so clamp_min(1) only changes the all-unselected case (0/0).
    return loss.sum() / mask.sum().clamp_min(1)
    
def evaluate_model(model, test_dataloader, device):

    """
    ROC-AUC of ``model`` over every output element of ``test_dataloader``.

    Each element of the model output is used as a continuous score. It is
    paired with the corresponding element of Y_matrix, cast to int, as its label.

    Args:
        model: Module called as ``model(x_matrix, adj_tensor)``.
        test_dataloader: Iterable of (adj_tensor, x_matrix, y_matrix) batches.
        device: Device the batches are moved to.

    Returns:
        float: ROC-AUC score.
    """

    model.eval()
    all_outputs = []
    all_targets = []
    with torch.no_grad():
        for adj_tensor, x_matrix, y_matrix in test_dataloader:
            adj_tensor, x_matrix, y_matrix = adj_tensor.to(device), x_matrix.to(device), y_matrix.to(device)
            output = model(x_matrix, adj_tensor)
            all_outputs.append(output.cpu().numpy().flatten())
            all_targets.append(y_matrix.cpu().numpy().flatten())

    # Scores must stay continuous. Casting probabilities in [0, 1) to int made
    # every score 0, which pins the AUC at 0.5.
    scores = np.concatenate(all_outputs)
    labels = np.concatenate(all_targets).astype(int)
    return roc_auc_score(labels, scores)

def compute_roc_auc_with_threshold(model, dataloader, device):
    """
    ROC curve, AUC and Youden-J-optimal threshold against Y_matrix.

    Scores are ``torch.sigmoid(model(X, adj))``. The threshold that maximises
    ``tpr - fpr`` is returned in that sigmoid space.

    NOTE(review): the path models in this notebook already end in a sigmoid,
    so their outputs are squashed twice. AUC is unaffected (monotone
    transform), but the returned threshold lives in the double-sigmoid space.
    ``compute_f1_score`` applies the same sigmoid, so the two agree.
    NOTE(review): a function with this same name is defined again later in
    the notebook; after the cell runs, that later definition is the one in effect.

    Returns:
        tuple: (fpr, tpr, roc_auc, best_threshold).
    """
    model.eval()
    label_chunks, prob_chunks = [], []

    with torch.no_grad():
        for adj_tensor, X_matrix, Y_matrix in dataloader:
            raw_scores = model(X_matrix.to(device), adj_tensor.to(device))
            prob_chunks.append(torch.sigmoid(raw_scores).cpu().numpy().flatten())
            label_chunks.append(Y_matrix.cpu().numpy().flatten())

    fpr, tpr, thresholds = roc_curve(np.concatenate(label_chunks), np.concatenate(prob_chunks))
    best_threshold = thresholds[np.argmax(tpr - fpr)]
    return fpr, tpr, auc(fpr, tpr), best_threshold

def compute_roc_auc_AE_with_threshold(model, dataloader, device, target):

    """
    ROC curve, AUC and Youden-J-optimal threshold for the node classifiers.

    The label of each node is column ``-target`` of X_matrix. That column is
    zeroed in the model input before the forward pass, matching the masking
    applied during classifier training. The label column itself is therefore
    never fed to the model during evaluation.

    Args:
        model: Classifier called as ``model(X_matrix, adj_tensor)``; returns per-node probabilities.
        dataloader: Iterable of (adj_tensor, X_matrix, Y_matrix) batches (Y_matrix unused).
        device: Device the batches are moved to.
        target (int): Label column selector (column ``-target``).

    Returns:
        tuple: (fpr, tpr, roc_auc, best_threshold).
    """

    model.eval()
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for adj_tensor, X_matrix, _ in dataloader:
            adj_tensor = adj_tensor.to(device)
            X_matrix = X_matrix.to(device)
            labels = X_matrix[:, :, -target].float()
            # Hide the label column from the model (X_matrix itself, and so labels, stay intact).
            X_masked = X_matrix.clone()
            X_masked[:, :, -target] = 0
            classification_output = model(X_masked, adj_tensor)
            all_labels.extend(labels.cpu().numpy().flatten())
            all_probs.extend(classification_output.cpu().numpy().flatten())

    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)
    fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
    roc_auc = auc(fpr, tpr)
    best_threshold = thresholds[np.argmax(tpr - fpr)]

    return fpr, tpr, roc_auc, best_threshold

def compute_f1_score(model, dataloader, device, best_threshold):
    """
    Weighted F1 of ``sigmoid(model output) >= best_threshold`` against Y_matrix.

    The sigmoid mirrors ``compute_roc_auc_with_threshold``, so a threshold
    obtained there is applied in the same score space.

    Returns:
        float: Weighted F1 score.
    """
    model.eval()
    prob_chunks, label_chunks = [], []

    with torch.no_grad():
        for adj_tensor, X_matrix, Y_matrix in dataloader:
            raw_scores = model(X_matrix.to(device), adj_tensor.to(device))
            prob_chunks.append(torch.sigmoid(raw_scores).cpu().numpy().flatten())
            label_chunks.append(Y_matrix.cpu().numpy().flatten())

    predicted = (np.concatenate(prob_chunks) >= best_threshold).astype(int)
    return f1_score(np.concatenate(label_chunks), predicted, average="weighted")

def compute_f1_score_AE(model, dataloader, device, target, best_threshold):

    """
    Weighted F1 for the node classifiers at a fixed probability threshold.

    The label of each node is column ``-target`` of X_matrix. That column is
    zeroed in the model input before the forward pass, matching the masking
    applied during classifier training. Predictions are
    ``probability >= best_threshold``.

    Returns:
        float: Weighted F1 score.
    """

    model.eval()
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for adj_tensor, X_matrix, _ in dataloader:
            adj_tensor = adj_tensor.to(device)
            X_matrix = X_matrix.to(device)
            labels = X_matrix[:, :, -target].float()
            # Hide the label column from the model (labels come from the unmasked tensor).
            X_masked = X_matrix.clone()
            X_masked[:, :, -target] = 0
            classification_output = model(X_masked, adj_tensor)
            all_labels.extend(labels.cpu().numpy().flatten())
            all_probs.extend(classification_output.cpu().numpy().flatten())

    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)
    preds = (all_probs >= best_threshold).astype(int)
    return f1_score(all_labels, preds, average="weighted")

def compute_confusion_matrix(model, dataloader, device, threshold, is_autoencoder=False, target=1):

    """
    Confusion matrix and weighted F1 of ``model output > threshold``.

    * ``is_autoencoder=True``: labels are column ``-target`` of X_matrix. That
      column is zeroed in the model input, as during classifier training.
    * otherwise: labels are Y_matrix.

    The matrix is always computed over labels [0, 1], so it is 2x2 even if
    only one class occurs (``plot_confusion_matrix`` labels two classes).

    NOTE(review): the threshold is compared with the model's raw outputs,
    whereas ``compute_f1_score`` applies a sigmoid before thresholding. Pass a
    threshold in the same space as the model output — confirm at call sites.

    Returns:
        tuple: (confusion_matrix ndarray, weighted F1).
    """

    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for adj_tensor, X_matrix, Y_matrix in dataloader:
            adj_tensor = adj_tensor.to(device)
            X_matrix = X_matrix.to(device)

            if is_autoencoder:
                labels = X_matrix[:, :, -target]
                # Clone so the label view above is not zeroed along with the input.
                model_input = X_matrix.clone()
                model_input[:, :, -target] = 0
            else:
                labels = Y_matrix.to(device)
                model_input = X_matrix

            classification_output = model(model_input, adj_tensor)
            preds = (classification_output > threshold).float()

            all_preds.extend(preds.cpu().numpy().flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    f1 = f1_score(all_labels, all_preds, average="weighted")
    return cm, f1

def plot_confusion_matrix(cm, f1, title="Confusion Matrix", filename=None):
    """
    Draw a row-normalised confusion matrix heatmap with the F1 score in the title.

    Each row is divided by its total, so cells show per-true-class rates. The
    figure is saved to ``filename`` (300 dpi, transparent) when one is given,
    then shown.
    """
    per_class_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = cm.astype('float') / per_class_totals
    class_names = ["Negative", "Positive"]

    plt.figure(figsize=(4, 4))
    sns.heatmap(
        cm_normalized,
        cbar=False,
        annot=True,
        fmt=".2f",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        vmin=0,
        vmax=1,
    )
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title(f"{title}, F1={f1:.4f}")

    if filename:
        plt.savefig(filename, dpi=300, bbox_inches='tight', transparent=True)
    plt.show()

def pinn_loss(output, Y_matrix, alpha, beta, zeta):
    """
    Masked-BCE data loss plus weighted physics-inspired structure penalties.

    total = masked_bce_loss(output, Y)
            + alpha * degree_loss(output)
            + zeta  * cycle_loss(output)
            + beta  * connectivity_loss(output)
    """
    Y_matrix = Y_matrix.float()
    data_term = masked_bce_loss(output, Y_matrix)
    structure_term = (alpha * degree_loss(output)
                      + zeta * cycle_loss(output)
                      + beta * connectivity_loss(output))
    return data_term + structure_term

    
def path_pred_k_fold_cross_validation(model_class,
                                      name,
                                      dataset,
                                      pinnloss,
                                      k=5,
                                      device='cuda',
                                      alpha=1,
                                      beta=0.0001,
                                      lr=0.001,
                                      zeta=1):

    """
    K-fold cross-validation for the path-prediction models.

    For each of ``k`` folds (KFold, shuffled, random_state=42):

    * build ``model_class(in_features, hidden_size, out_features, dropout)``;
    * train it for 20 epochs with Adam(lr) and ExponentialLR(gamma=0.97),
      using ``pinn_loss`` when ``pinnloss is True`` and ``masked_bce_loss``
      otherwise;
    * after every epoch, score the held-out fold with ROC-AUC and F1 at the
      Youden-J threshold;
    * record the last epoch's scores and save the weights to
      ``Exports/5_{name}_fold_{fold}_weights.pth``.

    If ``loss.backward()`` raises ``torch.linalg.LinAlgError``, the error is
    printed, gradients are cleared, and that batch's optimizer step is skipped.

    NOTE(review): ``in_features``, ``hidden_size``, ``out_features`` and
    ``dropout`` are read from module-level globals that must exist before the call.

    Args:
        model_class: Model constructor (e.g. ``GraphSAGEWithDNN``).
        name (str): Tag used in the exported weight file names.
        dataset: Map-style dataset yielding (adj_tensor, X_matrix, Y_matrix).
        pinnloss (bool): ``True`` selects the PINN loss; anything else selects masked BCE.
        k (int): Number of folds.
        device: Training device.
        alpha, beta, zeta (float): Weights of the degree, connectivity and cycle penalties.
        lr (float): Adam learning rate.

    Returns:
        tuple: (mean_auc, std_auc, mean_f1, std_f1) across folds.
    """

    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    all_auc_scores = []
    all_f1_scores = []
    # torch.save does not create missing directories.
    os.makedirs("Exports", exist_ok=True)

    for fold, (train_idx, test_idx) in enumerate(kf.split(dataset)):

        print(f"Fold {fold+1}/{k}")

        if pinnloss is True:
            print('[i] Using PINN Loss...')
        else:
            print('[i] Using MWBCE Loss...')

        train_subset = torch.utils.data.Subset(dataset, train_idx)
        test_subset = torch.utils.data.Subset(dataset, test_idx)

        train_loader = DataLoader(train_subset,
                                  batch_size=64,
                                  shuffle=True,
                                  worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
                                  generator=torch.Generator().manual_seed(42))

        test_loader = DataLoader(test_subset,
                                 batch_size=64,
                                 shuffle=False,
                                 worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
                                 generator=torch.Generator().manual_seed(42))

        model = model_class(in_features, hidden_size, out_features, dropout)
        model = model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # LR -> 0.0001 for no-pinn, 0.001 for pinn
        scheduler = ExponentialLR(optimizer, gamma=0.97)

        for epoch in range(20):
            model.train()
            total_loss = 0

            for adj_tensor, X_matrix, Y_matrix in train_loader:
                adj_tensor = adj_tensor.to(device)
                X_matrix = X_matrix.to(device)
                Y_matrix = Y_matrix.to(device)

                optimizer.zero_grad()
                output = model(X_matrix, adj_tensor)

                if pinnloss is True:
                    loss = pinn_loss(output, Y_matrix.float(), alpha, beta, zeta)
                else:
                    loss = masked_bce_loss(output, Y_matrix.float())
                total_loss += loss.item()

                try:
                    loss.backward()
                except torch.linalg.LinAlgError as e:
                    # Presumably raised by the eigendecomposition in connectivity_loss.
                    # A failed backward may leave partial gradients; discard them
                    # and skip this update instead of stepping on them.
                    print(f"LinAlgError: {e}")
                    optimizer.zero_grad()
                    continue

                optimizer.step()

            avg_train_loss = total_loss / len(train_loader)
            scheduler.step()

            fpr, tpr, auc_roc, best_threshold = compute_roc_auc_with_threshold(model, test_loader, device)
            f1 = compute_f1_score(model, test_loader, device, best_threshold)
            print(f"[i] Epoch {epoch+1}, \tTrain Loss: {avg_train_loss:.8f}, \tAUC: {auc_roc:.4f}, \tF1: {f1:.4f}")

        # The model is not updated after the last epoch's evaluation, so those
        # scores are the fold's final scores (re-evaluating would repeat identical work).
        all_auc_scores.append(auc_roc)
        all_f1_scores.append(f1)

        print(f"Fold {fold+1} - AUC: {auc_roc:.4f}, F1: {f1:.4f}")
        torch.save(model.state_dict(), f"Exports/5_{name}_fold_{fold+1}_weights.pth")

    mean_auc = np.mean(all_auc_scores)
    std_auc = np.std(all_auc_scores)
    mean_f1 = np.mean(all_f1_scores)
    std_f1 = np.std(all_f1_scores)

    print("Final K-Fold Results:")
    print(f"Mean AUC: {mean_auc:.4f} ± {std_auc:.4f}")
    print(f"Mean F1 Score: {mean_f1:.4f} ± {std_f1:.4f}")

    return mean_auc, std_auc, mean_f1, std_f1

def node_pred_k_fold_cross_validation(dataset, mode, k=5, device='cuda'):

    """
    K-fold cross-validation for the node classifier.

    The classifier is GraphSAGE + pretrained encoder + classifier head. The
    label for each node is column ``-int(mode)`` of X_matrix. During training
    that column is zeroed in the model input, so the classifier never sees
    its own label.

    For each fold (KFold, shuffled, random_state=42):

    * load pretrained GraphSAGE and autoencoder weights from ``Weights/``
      (mapped onto ``device``);
    * build an unfrozen ``EncoderWithClassifier``;
    * train it for 10 epochs with Adam (lr=1e-4), ExponentialLR(gamma=0.9)
      and BCELoss;
    * score the held-out fold with ROC-AUC and F1 at the Youden-J threshold,
      both on the same label column;
    * save the three sub-modules' weights under ``Exports/``.

    NOTE(review): reads the module-level globals ``in_features``,
    ``hidden_size`` and ``dropout``.

    Args:
        dataset: Map-style dataset yielding (adj_tensor, X_matrix, Y_matrix).
        mode: Label selector; column ``-int(mode)`` of X_matrix is the target.
        k (int): Number of folds.
        device: Device for models and tensors.

    Returns:
        tuple: (mean_auc, std_auc, mean_f1, std_f1) across folds.
    """

    target = int(mode)
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    all_auc_scores = []
    all_f1_scores = []
    # torch.save does not create missing directories.
    os.makedirs("Exports", exist_ok=True)

    for fold, (train_idx, test_idx) in enumerate(kf.split(dataset)):

        print(f"Fold {fold+1}/{k}")
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        test_subset = torch.utils.data.Subset(dataset, test_idx)

        train_loader = DataLoader(train_subset,
                                  batch_size=64,
                                  shuffle=True,
                                  worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
                                  generator=torch.Generator().manual_seed(42))

        test_loader = DataLoader(test_subset,
                                 batch_size=64,
                                 shuffle=False,
                                 worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
                                 generator=torch.Generator().manual_seed(42))

        # map_location lets GPU-saved checkpoints load on any device.
        graphsage_model = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout).to(device)
        graphsage_model.load_state_dict(torch.load("Weights/graphsage_weights.pth", map_location=device))  # Pre-trained
        pretrained_ae = AE(hidden_size, in_features).to(device)
        pretrained_ae.load_state_dict(torch.load("Weights/autoencoder_weights.pth", map_location=device))  # Pre-trained
        pretrained_encoder = pretrained_ae.encoder

        latent_dim = 9
        freeze = False
        classifier_model = EncoderWithClassifier(graphsage_model, pretrained_encoder, latent_dim, freeze).to(device)

        optimizer = torch.optim.Adam(classifier_model.parameters(), lr=0.0001)
        scheduler = ExponentialLR(optimizer, gamma=0.9)
        criterion = torch.nn.BCELoss()

        num_epochs = 10

        for epoch in range(num_epochs):
            classifier_model.train()
            total_loss = 0

            for adj_tensor, X_matrix, _ in train_loader:

                adj_tensor = adj_tensor.to(device)
                labels = X_matrix[:, :, -target].float()
                labels = labels.to(device)

                # Zero the label column in a copy so the model cannot read it.
                X_matrix_mask = X_matrix.clone()
                X_matrix_mask[:, :, -target] = 0
                X_matrix_mask = X_matrix_mask.to(device)

                optimizer.zero_grad()
                classification_output = classifier_model(X_matrix_mask, adj_tensor)
                loss = criterion(classification_output, labels)

                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            scheduler.step()
            print(f"Epoch {epoch+1}, Classifier Train Loss: {total_loss / len(train_loader):.8f}")

        # Score AUC and F1 on the same label column (F1 used a hard-coded target=1 before).
        fpr, tpr, auc_roc, best_threshold = compute_roc_auc_AE_with_threshold(classifier_model, test_loader, device, target=target)
        f1 = compute_f1_score_AE(classifier_model, test_loader, device, target=target, best_threshold=best_threshold)
        all_auc_scores.append(auc_roc)
        all_f1_scores.append(f1)
        print(f"Fold {fold+1} - AUC: {auc_roc:.4f}, F1: {f1:.4f}")

        print('[+] Saving end-to-end model weights...')
        torch.save(classifier_model.graphsage.state_dict(), f"Exports/fold_{fold+1}_sn_graphsage_ae_classifier_weights.pth")
        torch.save(classifier_model.encoder.state_dict(), f"Exports/fold_{fold+1}_sn_encoder_ae_classifier_weights.pth")
        torch.save(classifier_model.classifier.state_dict(), f"Exports/fold_{fold+1}_sn_classifier_ae_classifier_weights.pth")

    mean_auc = np.mean(all_auc_scores)
    std_auc = np.std(all_auc_scores)
    mean_f1 = np.mean(all_f1_scores)
    std_f1 = np.std(all_f1_scores)

    print("Final K-Fold Results:")
    print(f"Mean AUC: {mean_auc:.4f} ± {std_auc:.4f}")
    print(f"Mean F1 Score: {mean_f1:.4f} ± {std_f1:.4f}")

    return mean_auc, std_auc, mean_f1, std_f1

def plot_roc_auc_comparison(model1, 
                            model2, 
                            model3,
                            threshold3,
                            model4, 
                            threshold4,
                            dataloader, 
                            device, 
                            label1, 
                            label2, 
                            label3,
                            label4):
    """
    Plot the ROC curves of four models on one figure and save it.

    Models 1 and 2 are scored with ``compute_roc_auc_with_threshold`` (continuous
    scores). Models 3 and 4 are scored with ``compute_roc_auc_AE`` on label
    columns ``-1`` (target=1) and ``-2`` (target=2). That helper binarises the
    scores at ``threshold3`` / ``threshold4``, so those curves describe a single
    operating point.

    Args:
        model1, model2: path-prediction models.
        model3, threshold3: AE classifier for target=1 and its decision threshold.
        model4, threshold4: AE classifier for target=2 and its decision threshold.
        dataloader: iterable of ``(adj_tensor, X_matrix, Y_matrix)`` batches.
        device: device the models live on.
        label1..label4: legend labels for models 1..4.

    Side effects:
        Writes ``Exports/all_models_roc_auc_compare.jpeg`` and shows the figure.
    """
    fpr1, tpr1, auc1, _ = compute_roc_auc_with_threshold(model1, dataloader, device)
    fpr2, tpr2, auc2, _ = compute_roc_auc_with_threshold(model2, dataloader, device)
    fpr3, tpr3, auc3 = compute_roc_auc_AE(model3, dataloader, device, target=1, threshold=threshold3)
    fpr4, tpr4, auc4 = compute_roc_auc_AE(model4, dataloader, device, target=2, threshold=threshold4)

    plt.figure(figsize=(5, 5))
    plt.plot(fpr3, tpr3, color='red', lw=1, label=f"{label3} (AUC = {auc3:.4f})")
    plt.plot(fpr4, tpr4, color='red', lw=1, linestyle='--', label=f"{label4} (AUC = {auc4:.4f})")
    plt.plot(fpr1, tpr1, color='blue', lw=1, label=f"{label1} (AUC = {auc1:.4f})")
    plt.plot(fpr2, tpr2, color='blue', linestyle="--", lw=1, label=f"{label2} (AUC = {auc2:.4f})")
    plt.plot([0, 1], [0, 1], color='black', linestyle="dotted", lw=1, label="Baseline")

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate", fontsize=10)
    plt.ylabel("True Positive Rate", fontsize=10)
    plt.legend(loc='lower right', fontsize=10, frameon=True, framealpha=1)
    plt.grid(color='black', linestyle='-', linewidth=.5, alpha=.3)
    plt.draw()
    plt.box(False)
    # Finalise the layout BEFORE saving. Calling tight_layout after savefig only
    # changes the on-screen figure, not the exported file.
    plt.tight_layout()
    plt.savefig('Exports/all_models_roc_auc_compare.jpeg', dpi=400, bbox_inches='tight', transparent=True)
    plt.show()

def compute_roc_auc_with_threshold(model, dataloader, device):
    """
    Evaluate a path-prediction model and pick a threshold with Youden's J.

    For every ``(adj_tensor, X_matrix, Y_matrix)`` batch, the model is run as
    ``model(X_matrix, adj_tensor)`` under ``torch.no_grad()``. Its outputs go
    through a sigmoid and are flattened, and the targets are flattened the same
    way. The ROC curve is computed on the values pooled over all batches.

    Args:
        model: path-prediction model; switched to eval mode.
        dataloader: iterable of ``(adj_tensor, X_matrix, Y_matrix)``.
        device: device the model lives on.

    Returns:
        ``(fpr, tpr, roc_auc, best_threshold)``. ``best_threshold`` is the
        ``roc_curve`` threshold that maximises ``tpr - fpr``, expressed on the
        sigmoid (probability) scale.
    """
    model.eval()
    score_parts = []
    label_parts = []

    with torch.no_grad():
        for adj_batch, x_batch, y_batch in dataloader:
            raw = model(x_batch.to(device), adj_batch.to(device))
            score_parts.append(torch.sigmoid(raw).cpu().numpy().flatten())
            label_parts.append(y_batch.cpu().numpy().flatten())

    fpr, tpr, cutoffs = roc_curve(np.concatenate(label_parts), np.concatenate(score_parts))
    youden = tpr - fpr
    best_threshold = cutoffs[youden.argmax()]
    return fpr, tpr, auc(fpr, tpr), best_threshold

def compute_roc_auc_AE_with_threshold(model, dataloader, device, target):
    """
    Compute the ROC curve of an AE classifier and pick its Youden-J threshold.

    The per-node label is read from column ``-target`` of ``X_matrix``
    (target=1 is the last column, target=2 the second-to-last). That column is
    zeroed in a copy of the input before the forward pass, matching the masking
    in the classifier training loop, so the model never sees the label it must
    predict.

    Args:
        model: classifier called as ``model(X_matrix, adj_tensor)``; switched to eval mode.
        dataloader: iterable of ``(adj_tensor, X_matrix, _)`` batches.
        device: device for the forward pass.
        target: offset from the end of the feature axis of the label column.

    Returns:
        ``(fpr, tpr, roc_auc, best_threshold)``. ``best_threshold`` maximises
        ``tpr - fpr`` and is on the model's raw output scale; no sigmoid is
        applied here.
    """
    model.eval()
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for adj_tensor, X_matrix, _ in dataloader:
            adj_tensor = adj_tensor.to(device)
            X_matrix = X_matrix.to(device)
            labels = X_matrix[:, :, -target].float()
            # Hide the label column from the model, exactly as during training.
            # Clone first: `labels` may be a view of X_matrix, and on CPU
            # `.to(device)` can return the loader's own tensor.
            X_masked = X_matrix.clone()
            X_masked[:, :, -target] = 0
            classification_output = model(X_masked, adj_tensor)
            all_labels.append(labels.cpu().numpy().ravel())
            all_probs.append(classification_output.cpu().numpy().ravel())

    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
    roc_auc = auc(fpr, tpr)
    # Youden's J statistic: the threshold maximising TPR - FPR.
    best_threshold = thresholds[np.argmax(tpr - fpr)]

    return fpr, tpr, roc_auc, best_threshold

def compute_roc_auc_AE(model, dataloader, device, target, threshold):
    """
    Compute a single-operating-point ROC curve for an AE classifier.

    Labels come from column ``-target`` of ``X_matrix``. That column is zeroed
    in a copy of the input before the forward pass, matching the masking used
    during classifier training. Scores are binarised with ``score >= threshold``,
    the same convention sklearn's ``roc_curve`` uses for the thresholds it
    returns. A threshold taken from ``roc_curve`` therefore selects the same
    operating point it was chosen at, given identical scores.

    Because the curve is built from hard 0/1 predictions, it has at most three
    points. The returned AUC equals ``(1 + TPR - FPR) / 2`` (balanced accuracy
    at ``threshold``), not the threshold-free ROC-AUC.

    Args:
        model: classifier called as ``model(X_matrix, adj_tensor)``; switched to eval mode.
        dataloader: iterable of ``(adj_tensor, X_matrix, _)`` batches.
        device: device for the forward pass.
        target: offset from the end of the feature axis of the label column.
        threshold: decision threshold on the model's raw output scale.

    Returns:
        ``(fpr, tpr, roc_auc)``.
    """
    model.eval()
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for adj_tensor, X_matrix, _ in dataloader:
            adj_tensor = adj_tensor.to(device)
            X_matrix = X_matrix.to(device)
            labels = X_matrix[:, :, -target].float()
            # Hide the label column from the model, as during training. Clone so
            # neither `labels` nor the loader's tensor is modified.
            X_masked = X_matrix.clone()
            X_masked[:, :, -target] = 0
            classification_output = model(X_masked, adj_tensor)
            all_labels.append(labels.cpu().numpy().ravel())
            all_probs.append(classification_output.cpu().numpy().ravel())

    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    # `>=` (not `>`) matches roc_curve's threshold semantics. With `>`, samples
    # scoring exactly at a Youden-selected threshold would flip to negative.
    preds = (all_probs >= threshold).astype(float)
    fpr, tpr, _ = roc_curve(all_labels, preds)
    roc_auc = auc(fpr, tpr)

    return fpr, tpr, roc_auc

## PIGNN K-Fold Training & Cross Validation

In [None]:
"""
First we train the PIGNN (using L_pinn), then we train the regular GNN.
Only the learning rates differ; each is tuned for its architecture.
All other hyperparameters are shared.
"""

# Hyperparameters shared by both path-prediction runs; only `lr` differs.
shared_hparams = dict(k=5, device=device, alpha=1, beta=0.0001, zeta=1)

# PIGNN ('GraphSAGEWithDNNPinn', flag argument True)
mean_auc1, std_auc1, mean_f1, std_f1 = path_pred_k_fold_cross_validation(GraphSAGEWithDNN,
                                                                         'GraphSAGEWithDNNPinn',
                                                                         dataset,
                                                                         True,
                                                                         lr=0.001,
                                                                         **shared_hparams)

# Regular GNN ('GraphSAGEWithDNNNoPinn', flag argument False)
mean_auc2, std_auc2, mean_f2, std_f2 = path_pred_k_fold_cross_validation(GraphSAGEWithDNN,
                                                                         'GraphSAGEWithDNNNoPinn',
                                                                         dataset,
                                                                         False,
                                                                         lr=0.0001,
                                                                         **shared_hparams)

# Fine-tune Initial Access Classifier: label is feature column -1 ("Start Node")
mean_auc3, std_auc3, mean_f3, std_f3 = k_fold_cross_validation(dataset,
                                                             mode=1,  # Start node
                                                             k=5,
                                                             device=device)

# Fine-tune Impact Classifier: label is feature column -2 ("End Node")
mean_auc4, std_auc4, mean_f4, std_f4 = k_fold_cross_validation(dataset,
                                                             mode=2,  # End node
                                                             k=5,
                                                             device=device)

## Find optimal thresholds using Youden's J index

In [None]:
"""
Start by loading the four models: path prediction with and without the PINN loss, then the two classification blocks.
Don't forget to load the architectures & data loaders first (Blocks 1, 2, 7, 13, 14a-b).
"""

in_features = dataset[0][1].shape[1]
out_features = dataset[0][2].shape[1]
hidden_size = 512
dropout = 0.1

# map_location=device lets checkpoints written on a GPU machine be restored on a
# CPU-only machine (and vice versa).
model = GraphSAGEWithDNN(in_features, hidden_size, out_features, dropout)
model.load_state_dict(torch.load("Exports/pinn_PIGNN.pth", map_location=device))
model.to(device)
model.eval()

model_nopinn = GraphSAGEWithDNN(in_features, hidden_size, out_features, dropout)
model_nopinn.load_state_dict(torch.load("Exports/nopinn_PIGNN.pth", map_location=device))
model_nopinn.to(device)
model_nopinn.eval()

graphsage_model_1 = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout)
pretrained_ae_1 = AE(hidden_size, in_features)
pretrained_encoder_1 = pretrained_ae_1.encoder
ae_sn = EncoderWithClassifier(graphsage_model_1, pretrained_encoder_1, latent_dim=9, freeze=False)
ae_sn.graphsage.load_state_dict(torch.load("Exports/fold_1_sn_graphsage_ae_classifier_weights.pth", map_location=device))
ae_sn.encoder.load_state_dict(torch.load("Exports/fold_1_sn_encoder_ae_classifier_weights.pth", map_location=device))
ae_sn.classifier.load_state_dict(torch.load("Exports/fold_1_sn_classifier_ae_classifier_weights.pth", map_location=device))
ae_sn.to(device)
ae_sn.eval()

# NOTE(review): the k-fold classifier loop in this notebook saves every fold with
# the "fold_{n}_sn_*" prefix regardless of `mode`. The "fold_1_en_*" files below
# exist only if they were produced some other way; verify before running.
graphsage_model_2 = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout)
pretrained_ae_2 = AE(hidden_size, in_features)
pretrained_encoder_2 = pretrained_ae_2.encoder
ae_en = EncoderWithClassifier(graphsage_model_2, pretrained_encoder_2, latent_dim=9, freeze=False)
ae_en.graphsage.load_state_dict(torch.load("Exports/fold_1_en_graphsage_ae_classifier_weights.pth", map_location=device))
ae_en.encoder.load_state_dict(torch.load("Exports/fold_1_en_encoder_ae_classifier_weights.pth", map_location=device))
ae_en.classifier.load_state_dict(torch.load("Exports/fold_1_en_classifier_ae_classifier_weights.pth", map_location=device))
ae_en.to(device)
ae_en.eval()

# ROC curves and Youden-J thresholds on the test set.
fpr1, tpr1, auc1, best_threshold1 = compute_roc_auc_with_threshold(model, test_dataloader, device)
fpr2, tpr2, auc2, best_threshold2 = compute_roc_auc_with_threshold(model_nopinn, test_dataloader, device)
fpr3, tpr3, auc3, best_threshold3 = compute_roc_auc_AE_with_threshold(ae_sn, test_dataloader, device, target=1)
fpr4, tpr4, auc4, best_threshold4 = compute_roc_auc_AE_with_threshold(ae_en, test_dataloader, device, target=2)

print(f"[+] Model 1 Best Threshold: {best_threshold1:.8f} -> {auc1}")
print(f"[+] Model 2 Best Threshold: {best_threshold2:.8f} -> {auc2}")
print(f"[+] Model 3 Best Threshold: {best_threshold3:.8f} -> {auc3}")
print(f"[+] Model 4 Best Threshold: {best_threshold4:.8f} -> {auc4}")

# The thresholds below were previously misspelled (`best_treshold3/4`), which
# raised a NameError.
plot_roc_auc_comparison(model,
                        model_nopinn,
                        ae_sn,
                        best_threshold3,
                        ae_en,
                        best_threshold4,
                        test_dataloader,
                        device,
                        label1=r'$\mathcal{M}_1,\quad\Psi = 1$'+f"\t",
                        label2=r'$\mathcal{M}_1,\quad\Psi = 0$'+f"\t",
                        label3=r'$\mathcal{M}_2,\quad\Psi = 0$'+f"\t",
                        label4=r'$\mathcal{M}_3,\quad\Psi = 0$'+f"\t")

# Compute F1 scores using these thresholds
f1_1 = compute_f1_score(model, test_dataloader, device, best_threshold1)
f1_2 = compute_f1_score(model_nopinn, test_dataloader, device, best_threshold2)
f1_3 = compute_f1_score_AE(ae_sn, test_dataloader, device, target=1, best_threshold=best_threshold3)
f1_4 = compute_f1_score_AE(ae_en, test_dataloader, device, target=2, best_threshold=best_threshold4)

# Print F1 scores
print(f"Model 1 (Path Prediction, PINN) - F1 Score: {f1_1:.4f}")
print(f"Model 2 (Path Prediction, No PINN) - F1 Score: {f1_2:.4f}")
print(f"Model 3 (AE Classifier SN) - F1 Score: {f1_3:.4f}")
print(f"Model 4 (AE Classifier EN) - F1 Score: {f1_4:.4f}")

# Compute confusion matrices & F1 scores
cm1, f1_1 = compute_confusion_matrix(model, test_dataloader, device, best_threshold1)
cm2, f1_2 = compute_confusion_matrix(model_nopinn, test_dataloader, device, best_threshold2)
cm3, f1_3 = compute_confusion_matrix(ae_sn, test_dataloader, device, best_threshold3, is_autoencoder=True, target=1)
cm4, f1_4 = compute_confusion_matrix(ae_en, test_dataloader, device, best_threshold4, is_autoencoder=True, target=2)

# Plot confusion matrices with F1 scores
plot_confusion_matrix(cm1, f1_1, title=r'$\mathcal{M}_1,\quad\Psi=1$', filename='Exports/kf_cm_m1_pinn.png')
plot_confusion_matrix(cm2, f1_2, title=r'$\mathcal{M}_1,\quad\Psi=0$', filename='Exports/kf_cm_m1_nopinn.png')
plot_confusion_matrix(cm3, f1_3, title=r'$\mathcal{M}_2,\quad\Psi=0$', filename='Exports/kf_cm_m3.png')
plot_confusion_matrix(cm4, f1_4, title=r'$\mathcal{M}_3,\quad\Psi=0$', filename='Exports/kf_cm_m4.png')

## Sensitivity Analysis using SHAP

In [None]:
class SimpleGraphSHAPExplainer:

    """
    Lightweight KernelSHAP wrapper for graph models such as GraphSAGEWithDNN.

    Node-feature attributions perturb the flattened feature matrix of the first
    graph in the batch. Edge-type attributions multiply the last axis of the
    adjacency tensor by 0/1 masks. In both cases, a model output with more than
    two dimensions is reduced to one scalar per sample by averaging over
    dims 1 and 2.
    """

    def __init__(self, model, device=None):
        """
        Move ``model`` to ``device`` and switch it to eval mode.

        Args:
            model: graph model called as ``model(X_matrix, adj_tensor)``.
            device: target device; defaults to CUDA when available, otherwise CPU.
        """
        self.device = device if device is not None else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model
        self.model.to(self.device)
        self.model.eval()

    def explain_feature_importance(self, adj_tensor, X_matrix, background_samples=10):
        """
        Compute KernelSHAP values for every (node, feature) entry of graph 0.

        Args:
            adj_tensor: adjacency tensor. Only ``adj_tensor[0:1]`` is used and it is
                repeated along four axes, so a 4-D layout is required.
            X_matrix: node features ``(batch, num_nodes, feature_dim)``; only
                ``X_matrix[0]`` is explained.
            background_samples: number of Gaussian-noise (sigma 0.1) copies of the
                explained graph used as the SHAP background.

        Returns:
            numpy.ndarray of shape ``(num_nodes, feature_dim)``.
        """
        adj_tensor = adj_tensor.to(self.device)
        X_matrix = X_matrix.to(self.device)
        _, num_nodes, feature_dim = X_matrix.shape
        n_flat = num_nodes * feature_dim

        def predict(flat_rows):
            """KernelSHAP callback: run the model on rows of flattened feature matrices."""
            rows = torch.tensor(flat_rows, dtype=torch.float32).to(self.device)
            chunk = min(8, rows.shape[0])
            pieces = []
            for start in range(0, rows.shape[0], chunk):
                feats = rows[start:start + chunk].reshape(-1, num_nodes, feature_dim)
                adj_rep = adj_tensor[0:1].repeat(feats.shape[0], 1, 1, 1)
                with torch.no_grad():
                    out = self.model(feats, adj_rep)
                if out.dim() > 2:
                    out = out.mean(dim=(1, 2))
                pieces.append(out.cpu().numpy())
            return np.concatenate(pieces, axis=0)

        base = X_matrix[0].cpu().numpy()
        background = np.array([
            (base + np.random.normal(0, 0.1, size=base.shape)).flatten()
            for _ in range(background_samples)
        ])

        print("Initializing SHAP Kernel explainer...")
        explainer = shap.KernelExplainer(lambda x: predict(x.reshape(-1, n_flat)), background)
        print("Calculating SHAP values...")
        values = explainer.shap_values(base.reshape(1, -1))
        if isinstance(values, list):
            values = values[0]
        return values.reshape(num_nodes, feature_dim)

    def explain_edge_importance(self, adj_tensor, X_matrix, background_samples=10):
        """
        Compute KernelSHAP values for the edge-type channels of graph 0.

        Each SHAP input is a 0/1 vector with one entry per edge type, multiplied
        into the last axis of ``adj_tensor[0:1]``. The explained inputs are the
        rows of the identity matrix (one edge type switched on at a time).

        Args:
            adj_tensor: 4-D adjacency tensor whose axis 3 indexes edge types.
            X_matrix: node features; only ``X_matrix[0:1]`` is used.
            background_samples: number of random 0/1 masks in the background;
                the first one is forced to all-ones.

        Returns:
            numpy.ndarray of edge-type SHAP values.
        """
        adj_tensor = adj_tensor.to(self.device)
        X_matrix = X_matrix.to(self.device)
        num_edge_types = adj_tensor.shape[3]

        def predict(mask_rows):
            """KernelSHAP callback: run the model with edge-type masks applied."""
            masks = torch.tensor(mask_rows, dtype=torch.float32).to(self.device)
            chunk = min(8, masks.shape[0])
            pieces = []
            for start in range(0, masks.shape[0], chunk):
                block = masks[start:start + chunk]
                masked_adj = torch.cat(
                    [adj_tensor[0:1] * m.reshape(1, 1, 1, -1) for m in block], dim=0)
                feats = X_matrix[0:1].repeat(len(block), 1, 1)
                with torch.no_grad():
                    out = self.model(feats, masked_adj)
                if out.dim() > 2:
                    out = out.mean(dim=(1, 2))
                pieces.append(out.cpu().numpy())
            return np.concatenate(pieces, axis=0)

        background = np.random.randint(0, 2, size=(background_samples, num_edge_types)).astype(float)
        background[0] = np.ones(num_edge_types)
        print("Initializing SHAP Kernel explainer for edge types...")
        explainer = shap.KernelExplainer(predict, background)
        print("Calculating SHAP values for edge types...")
        values = explainer.shap_values(np.eye(num_edge_types))
        return values[0] if isinstance(values, list) else values

    def visualize_results(self, feature_shap, edge_shap):
        """
        Save two bar charts of mean |SHAP| values to the working directory.

        ``node_feature_importance.png`` shows the 10 feature columns with the
        largest mean |SHAP| over nodes. ``edge_type_importance.png`` shows every
        edge type ranked by |SHAP|, averaged over rows when ``edge_shap`` is 2-D.

        Args:
            feature_shap: array of shape ``(num_nodes, feature_dim)``.
            edge_shap: 1-D or 2-D array of edge-type SHAP values.
        """
        # Node features: top 10 by mean |SHAP| across nodes.
        plt.figure(figsize=(10, 6))
        feat_scores = np.abs(feature_shap).mean(axis=0)
        feat_order = np.argsort(-feat_scores)
        feat_labels = [f"Feature {k}" for k in feat_order]
        plt.barh(feat_labels[:10], feat_scores[feat_order][:10])
        plt.xlabel('Average |SHAP Value|')
        plt.ylabel('Feature')
        plt.title('Top 10 Node Features by Importance')
        plt.tight_layout()
        plt.savefig('node_feature_importance.png')
        plt.close()

        # Edge types: all of them, ranked by |SHAP|.
        plt.figure(figsize=(10, 6))
        edge_abs = np.abs(edge_shap)
        edge_scores = edge_abs.mean(axis=0) if edge_shap.ndim > 1 else edge_abs
        edge_order = np.argsort(-edge_scores)
        plt.barh([f"Edge Type {k}" for k in edge_order], edge_scores[edge_order])
        plt.xlabel('Average |SHAP Value|')
        plt.ylabel('Edge Type')
        plt.title('Edge Types by Importance')
        plt.tight_layout()
        plt.savefig('edge_type_importance.png')
        plt.close()

        print("Visualizations saved to 'node_feature_importance.png' and 'edge_type_importance.png'")

def explain_graph_model_simple(model, dataloader):
    """
    Explain ``model`` on the first graph of ``dataloader`` with SimpleGraphSHAPExplainer.

    Computes node-feature and edge-type SHAP values, writes the two summary bar
    charts via ``visualize_results``, and returns the raw values.

    Args:
        model: graph model called as ``model(X_matrix, adj_tensor)``.
        dataloader: iterable of ``(adj_tensor, X_matrix, _)`` batches; only the
            first graph of the first batch is explained.

    Returns:
        dict with keys ``'feature_shap'`` and ``'edge_shap'``.
    """
    explainer = SimpleGraphSHAPExplainer(model)
    adj_batch, x_batch, _ = next(iter(dataloader))
    # Keep only the first graph; slicing (not indexing) preserves the batch axis.
    adj_one, x_one = adj_batch[:1], x_batch[:1]
    print(f"adj_tensor shape: {adj_one.shape}")
    print(f"X_matrix shape: {x_one.shape}")

    print("Explaining feature importance...")
    feature_shap = explainer.explain_feature_importance(adj_one, x_one)
    print("Explaining edge type importance...")
    edge_shap = explainer.explain_edge_importance(adj_one, x_one)
    print("Visualizing results...")
    explainer.visualize_results(feature_shap, edge_shap)

    return dict(feature_shap=feature_shap, edge_shap=edge_shap)

class ClassifierGraphSHAPExplainer:

    """
    KernelSHAP explainer for node-level classifiers (e.g. EncoderWithClassifier).

    Only a subset of nodes is explained: ``target_nodes`` if given, otherwise up
    to ``_MAX_NODES`` evenly spaced node indices. The model outputs at those nodes
    are SHAP's (multi-)outputs. KernelExplainer results are normalised with
    ``_as_output_last_array``, so older shap releases (one array per output in a
    list) and newer ones (a single array with outputs on the last axis) give the
    same layout.

    NOTE(review): ``X_matrix`` is passed to the model unmasked. The classifier
    training loop zeroes its label column before the forward pass, so that column
    is visible here unless the caller masks it first. Confirm this is intended.
    """

    # Rows per forward pass inside the KernelSHAP callbacks.
    _SHAP_BATCH = 4
    # Maximum number of nodes explained when target_nodes is None.
    _MAX_NODES = 50

    def __init__(self, model, device=None):
        """
        Args:
            model: classifier called as ``model(X_matrix, adj_tensor)``; it is
                moved to ``device`` and switched to eval mode.
            device: target device; defaults to CUDA when available, otherwise CPU.
        """
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.model.eval()

    @staticmethod
    def _as_output_last_array(shap_values):
        """
        Normalise KernelExplainer output to a single ndarray.

        A list (older shap: one ``(n_samples, n_features)`` array per output) is
        stacked on a new trailing axis, giving ``(n_samples, n_features, n_outputs)``.
        An ndarray is returned unchanged.
        """
        if isinstance(shap_values, list):
            return np.stack([np.asarray(v) for v in shap_values], axis=-1)
        return np.asarray(shap_values)

    @classmethod
    def _select_nodes(cls, num_nodes, target_nodes):
        """Return the node indices to explain: ``target_nodes`` or evenly spaced defaults."""
        if target_nodes is None:
            return np.linspace(0, num_nodes - 1, min(num_nodes, cls._MAX_NODES), dtype=int)
        return np.array(target_nodes)

    def explain_feature_importance(self, adj_tensor, X_matrix, target_nodes=None, background_samples=5, n_samples=50):
        """
        SHAP values for the features of the selected nodes of graph 0.

        Only the selected nodes' feature rows are perturbed; every other node
        keeps its value from ``X_matrix[0]``. The background consists of
        ``background_samples`` copies of the explained features with N(0, 0.05)
        noise added.

        Args:
            adj_tensor: 4-D adjacency tensor; only ``adj_tensor[0:1]`` is used.
            X_matrix: node features ``(batch, num_nodes, feature_dim)``.
            target_nodes: node indices to explain, or None for the defaults.
            background_samples: size of the SHAP background.
            n_samples: KernelSHAP ``nsamples``.

        Returns:
            ``(num_nodes, feature_dim)`` array. Rows of non-selected nodes are
            zero; selected rows are averaged over model outputs when there are
            several.
        """
        adj_tensor, X_matrix = adj_tensor.to(self.device), X_matrix.to(self.device)
        _, num_nodes, feature_dim = X_matrix.shape
        selected_node_indices = self._select_nodes(num_nodes, target_nodes)
        step = self._SHAP_BATCH

        def model_wrapper(features):
            """KernelSHAP callback: write perturbed node rows into graph 0 and score it."""
            features_tensor = torch.tensor(features, dtype=torch.float32).to(self.device)
            batch_results = []
            for i in range(0, features_tensor.shape[0], step):
                batch_features = features_tensor[i:i + step]
                # repeat() copies, so the in-place writes below never touch X_matrix.
                batch_X = X_matrix[0:1].repeat(batch_features.shape[0], 1, 1)
                for j, node_idx in enumerate(selected_node_indices):
                    batch_X[:, node_idx, :] = batch_features[:, j * feature_dim:(j + 1) * feature_dim]
                batch_adj = adj_tensor[0:1].repeat(batch_features.shape[0], 1, 1, 1)

                with torch.no_grad():
                    outputs = self.model(batch_X, batch_adj)
                batch_results.append(outputs[:, selected_node_indices].cpu().numpy())

            return np.concatenate(batch_results, axis=0)

        base_features = X_matrix[0, selected_node_indices].cpu().numpy().reshape(-1)
        background = [base_features + np.random.normal(0, 0.05, base_features.shape) for _ in range(background_samples)]

        explainer = shap.KernelExplainer(model_wrapper, np.array(background), link="identity")
        shap_values = self._as_output_last_array(
            explainer.shap_values(base_features.reshape(1, -1), nsamples=n_samples, l1_reg="num_features(10)"))

        # shap_values: (1, n_selected * feature_dim) or (1, n_selected * feature_dim, n_outputs)
        full_shap_values = np.zeros((num_nodes, feature_dim))
        for i, node_idx in enumerate(selected_node_indices):
            node_shap_values = shap_values[0, i * feature_dim:(i + 1) * feature_dim]
            if node_shap_values.ndim > 1:
                node_shap_values = node_shap_values.mean(axis=-1)
            full_shap_values[node_idx, :] = node_shap_values

        return full_shap_values

    def explain_edge_importance(self, adj_tensor, X_matrix, target_nodes=None, background_samples=5, n_samples=500):
        """
        SHAP values for the edge-type channels (last axis of ``adj_tensor``).

        Each SHAP input is a 0/1 mask over edge types, multiplied into
        ``adj_tensor[0:1]``. The explained inputs are the rows of the identity
        matrix (one edge type switched on at a time). The background's first row
        is all ones; the other rows are random masks with P(1)=0.7 and at least
        one active type.

        Args:
            adj_tensor: 4-D adjacency tensor ``(batch, N, N, num_edge_types)``.
            X_matrix: node features; only ``X_matrix[0:1]`` is used.
            target_nodes: node indices whose outputs are explained, or None.
            background_samples: size of the SHAP background.
            n_samples: KernelSHAP ``nsamples``.

        Returns:
            ndarray ``(num_edge_types, num_edge_types)`` for a single output, or
            ``(num_edge_types, num_edge_types, n_outputs)`` for several outputs.
            The layout is the same for old and new shap versions.
        """
        adj_tensor, X_matrix = adj_tensor.to(self.device), X_matrix.to(self.device)
        _, num_nodes, _, num_edge_types = adj_tensor.shape
        selected_node_indices = self._select_nodes(num_nodes, target_nodes)
        step = self._SHAP_BATCH

        def model_wrapper(edge_masks):
            """KernelSHAP callback: apply edge-type masks to graph 0 and score it."""
            edge_masks = torch.tensor(edge_masks, dtype=torch.float32).to(self.device)
            batch_results = []
            for i in range(0, edge_masks.shape[0], step):
                batch_masks = edge_masks[i:i + step]
                batch_adjs = [adj_tensor[0:1] * mask.reshape(1, 1, 1, -1) for mask in batch_masks]
                stacked_adj = torch.cat(batch_adjs, dim=0)
                batch_X = X_matrix[0:1].repeat(len(batch_masks), 1, 1)

                with torch.no_grad():
                    outputs = self.model(batch_X, stacked_adj)
                batch_results.append(outputs[:, selected_node_indices].cpu().numpy())

            return np.concatenate(batch_results, axis=0)

        background = np.zeros((background_samples, num_edge_types))
        background[0] = np.ones(num_edge_types)
        for i in range(1, background_samples):
            background[i] = np.random.choice([0, 1], size=num_edge_types, p=[0.3, 0.7])
            if np.sum(background[i]) == 0:
                # Guarantee at least one active edge type per background mask.
                background[i][np.random.randint(0, num_edge_types)] = 1

        explainer = shap.KernelExplainer(model_wrapper, background, link="identity")
        shap_values = explainer.shap_values(np.eye(num_edge_types), nsamples=n_samples, l1_reg="num_features(2)")

        return self._as_output_last_array(shap_values)

# Display labels for the node-feature columns. plot_importance_values maps
# names[i] to column i of the importance matrix, so the order must match the
# column order of X_matrix (TODO confirm against the dataset builder).
# The last two entries are presumably the label columns the classifiers read via
# X_matrix[:, :, -2] ("End Node") and X_matrix[:, :, -1] ("Start Node").
# NOTE(review): "Windows Server 2008" appears twice (indices 7 and 12); one of
# them is probably a different OS version. Confirm against the feature encoding.
feature_names = ['Computer', 'OU', 'User', 'Group', 'GPO', 'Domain', "Windows Server 2003", "Windows Server 2008", 
                "Windows 7", "Windows 10", "Windows XP", "Windows Server 2012", "Windows Server 2008", 
                "enabled", "hasspn", "highvalue", "is_vulnerable", "End Node", "Start Node"]

# Display labels for the edge-type channels (last axis of adj_tensor), in channel
# order (assumed to match adj_tensor[..., k]; TODO confirm against the dataset builder).
edge_types = ["AdminTo", "AllowedToDelegate", "CanRDP", "Contains", "DCSync", "ExecuteDCOM", 
              "GenericAll", "GetChanges", "GetChangesAll", "GpLink", "HasSession", "MemberOf", 
              "Open", "Owns", "WriteDacl", "WriteOwner"]

def plot_importance_values(importance_matrix, 
                           names, 
                           title="Importance Values", 
                           figsize=(12, 10), 
                           scale_method="normalize", 
                           max_features=20):
    """
    Plot per-item (feature or edge-type) importance values as a strip plot.

    Each column of ``importance_matrix`` is one item and each row one
    observation (e.g. one node). Items are ranked by their signed mean over rows,
    and the top ``max_features`` are drawn at y positions 0, 1, ... . Position 0,
    at the bottom of the axis, holds the highest mean. Every observation is a dot,
    red if its scaled value is >= 0 and blue otherwise. A black tick marks each
    item's mean.

    Args:
        importance_matrix: 1-D ``(n_items,)`` or 2-D ``(n_obs, n_items)`` array.
        names: item labels; missing trailing labels become ``"Unknown i"``.
        title: axes title.
        figsize: matplotlib figure size.
        scale_method: one of
            ``"normalize"``: divide by the global max |value|;
            ``"percentile"``: piecewise-linear map of positive and negative
            magnitudes separately, using their 25/50/75/95th percentiles;
            ``"robust"``: divide by the IQR of |values|, keeping the sign;
            anything else: raw values.
        max_features: maximum number of items to plot.

    Returns:
        ``(fig, ax)`` of the new figure.

    Raises:
        ValueError: if ``importance_matrix`` has more than two dimensions.
    """
    importance_matrix = np.asarray(importance_matrix)
    if importance_matrix.ndim == 1:
        importance_matrix = importance_matrix.reshape(1, -1)
    if importance_matrix.ndim != 2:
        raise ValueError(
            f"importance_matrix must be 1-D or 2-D, got shape {importance_matrix.shape}")

    # Must be defined for 2-D input too. It used to be set only inside the 1-D
    # branch, so any 2-D matrix raised UnboundLocalError.
    n_items = importance_matrix.shape[1]
    if len(names) < n_items:
        names = list(names) + [f"Unknown {i}" for i in range(len(names), n_items)]

    # Descending signed mean.
    mean_importance = importance_matrix.mean(axis=0)
    top_n = min(max_features, n_items)
    sorted_indices = np.argsort(-mean_importance)[:top_n]
    plot_data = [importance_matrix[:, idx] for idx in sorted_indices]
    item_names = [names[idx] for idx in sorted_indices]

    all_values = np.concatenate(plot_data)

    if scale_method == "normalize":
        max_abs = np.max(np.abs(all_values))
        scaled_plot_data = [values / max_abs for values in plot_data] if max_abs > 0 else plot_data
        x_label = "Importance"

    elif scale_method == "percentile":
        pos_values = all_values[all_values > 0]
        neg_values = all_values[all_values < 0]

        def make_scaler(values):
            """Build a map of magnitudes onto [0, 1] from the percentiles of `values`.

            The percentiles and max are computed once here instead of for every point.
            """
            if len(values) == 0:
                return lambda x: x
            p25, p50, p75, p95 = np.percentile(values, [25, 50, 75, 95])
            vmax = np.max(values)

            def scale(x):
                # Each denominator is > 0 whenever its branch is reached.
                if x <= p25:
                    return x / p25 * 0.25
                elif x <= p50:
                    return 0.25 + (x - p25) / (p50 - p25) * 0.25
                elif x <= p75:
                    return 0.5 + (x - p50) / (p75 - p50) * 0.25
                elif x <= p95:
                    return 0.75 + (x - p75) / (p95 - p75) * 0.2
                else:
                    return 0.95 + (x - p95) / (vmax - p95) * 0.05
            return scale

        scale_pos = make_scaler(pos_values)
        scale_neg = make_scaler(np.abs(neg_values))

        scaled_plot_data = []
        for values in plot_data:
            scaled_values = np.zeros_like(values, dtype=float)
            for i, v in enumerate(values):
                if v > 0:
                    scaled_values[i] = scale_pos(v)
                elif v < 0:
                    scaled_values[i] = -scale_neg(abs(v))
            scaled_plot_data.append(scaled_values)

        x_label = "Percentile-based Importance"

    elif scale_method == "robust":
        q1, q3 = np.percentile(np.abs(all_values), [25, 75])
        iqr = q3 - q1
        if iqr == 0:
            iqr = 1.0
        # Keep the sign. The previous `v / iqr * sign(v)` equalled |v| / iqr,
        # which drew every point as positive (red).
        scaled_plot_data = [np.asarray(values, dtype=float) / iqr for values in plot_data]
        x_label = "Robust Scaled Importance (IQR normalized)"

    else:
        scaled_plot_data = plot_data
        x_label = "Raw Importance Value"

    fig, ax = plt.subplots(figsize=figsize)
    positions = np.arange(len(item_names))
    for i, values in enumerate(scaled_plot_data):
        y_pos = np.ones_like(values) * positions[i]
        pos_mask = values >= 0
        neg_mask = values < 0
        if np.any(pos_mask):
            ax.scatter(values[pos_mask], y_pos[pos_mask], color='red', alpha=0.25, s=100)
        if np.any(neg_mask):
            ax.scatter(values[neg_mask], y_pos[neg_mask], color='blue', alpha=0.25, s=100)

    for i, values in enumerate(scaled_plot_data):
        mean_val = np.mean(values)
        ax.plot([mean_val, mean_val], [positions[i] - 0.4, positions[i] + 0.4], color='black', linewidth=1.5)

    ax.set_yticks(positions)
    ax.set_yticklabels(item_names, fontsize=20)
    ax.set_xlabel(x_label, fontsize=20)
    ax.set_title(title, fontsize=20)

    ax.axvline(x=0, color='gray', linestyle='-', alpha=0.3)
    ax.grid(axis='x', linestyle='--', alpha=0.3)
    if scale_method == "normalize":
        ax.set_xlim(-1.1, 1.1)
        ax.set_xticks([-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1])
        ax.tick_params(axis='x', labelsize=20)

    elif scale_method == "percentile":
        ax.set_xlim(-1.1, 1.1)
        ax.set_xticks([-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1])

    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='red', label='Positive Target', markersize=12),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', label='Negative Target', markersize=12),
        Line2D([0], [0], color='black', label='Mean')
    ]
    ax.legend(handles=legend_elements, loc='lower right', fontsize=12)

    plt.tight_layout()
    return fig, ax



In [None]:
# Run the SHAP explainer for the PIGNN (M1), then plot the SHAP values for
# node features and edge types.
# Requires `model`, `test_dataloader`, `feature_names`, `edge_types`,
# `explain_graph_model_simple` and `plot_importance_values` from earlier cells.

results = explain_graph_model_simple(model, test_dataloader)
feature_importance = results['feature_shap']
edge_importance = results['edge_shap']

# Rank features by mean absolute SHAP value over axis 0.
# NOTE(review): axis 0 is presumably the sample axis -- confirm against
# explain_graph_model_simple.
feature_avg = np.abs(feature_importance).mean(axis=0)
top_features = np.argsort(-feature_avg)[:5]
print("Top 5 features:", top_features)

# edge_shap may be a 2-D array (averaged over axis 0) or a 1-D array
# (used as-is); rank edge types by absolute SHAP value either way.
edge_avg = np.abs(edge_importance).mean(axis=0) if len(edge_importance.shape) > 1 else np.abs(edge_importance)
top_edges = np.argsort(-edge_avg)[:5]
print("Top 5 edge types:", top_edges)

# np.save and plt.savefig do not create missing directories; without this the
# cell fails with FileNotFoundError when 'Exports' does not exist yet.
os.makedirs('Exports', exist_ok=True)

# Save the raw SHAP arrays for later analysis.
np.save('Exports/shap_nodefeature.npy', feature_importance)
np.save('Exports/shap_edgetype.npy', edge_importance)

fig1, ax1 = plot_importance_values(
    feature_importance, 
    names=feature_names,
    title=r'$\mathcal{M}_1\,(\Psi=1)$'+"  Feature SHAP Values",
    scale_method="normalize"
)
plt.draw()
plt.box(False)    
plt.tight_layout()
plt.savefig("Exports/pinn_feature_shap.png", dpi=400, bbox_inches='tight')
plt.show()

# plot_importance_values is called with a 2-D array in both branches; a 1-D
# edge_importance is promoted to a single row first.
if len(edge_importance.shape) > 1:
    fig2, ax2 = plot_importance_values(
        edge_importance, 
        names=edge_types,
        title=r'$\mathcal{M}_1\,(\Psi=1)$'+"  Edge Type SHAP Values",
        scale_method="normalize")
else:
    edge_imp_reshaped = edge_importance.reshape(1, -1)
    fig2, ax2 = plot_importance_values(
        edge_imp_reshaped,
        names=edge_types,
        title=r'$\mathcal{M}_1\,(\Psi=1)$'+"  Edge Type Importance Distribution",
        scale_method="normalize")
    
plt.draw()
plt.box(False)    
plt.tight_layout()
plt.savefig("Exports/pinn_edge_shap.png", dpi=400, bbox_inches='tight')
plt.show()

# Compute and plot SHAP values (node features and edge types) for M2 and M3,
# starting with the Impact Node (EN) model, plotted as M3.
# NOTE(review): only the first batch of test_dataloader is explained here,
# whereas the M1 explainer call receives the whole test_dataloader -- confirm
# this difference is intended.

explainer_en = ClassifierGraphSHAPExplainer(ae_en, device)
adj_tensor, X_matrix, _ = next(iter(test_dataloader))

feature_importance_en = explainer_en.explain_feature_importance(
    adj_tensor, 
    X_matrix)

edge_importance_en = explainer_en.explain_edge_importance(
    adj_tensor, 
    X_matrix)

# Save the raw SHAP arrays for later analysis.
np.save('feature_importance_en.npy', feature_importance_en)
np.save('edge_importance_en.npy', edge_importance_en)

fig, ax = plot_importance_values(feature_importance_en, names=feature_names, title=r'$\mathcal{M}_3$'+"  Feature SHAP Values")
plt.draw()
plt.box(False)    
# tight_layout must run before savefig; called afterwards it only affects the
# on-screen figure from plt.show(), not the saved file.
plt.tight_layout()
plt.savefig('en_feature_shap.jpeg', dpi=400, bbox_inches='tight', transparent=True)
plt.show()

# Average over axis 2 to obtain a 2-D array for plotting.
# NOTE(review): assumes explain_edge_importance returns a 3-D array -- confirm
# what axis 2 represents.
edge_importance_en_avg = edge_importance_en.mean(axis=2)  
fig, ax = plot_importance_values(edge_importance_en_avg, names=edge_types, title=r'$\mathcal{M}_3$'+"  Edge Type SHAP Values")
plt.draw()
plt.box(False)    
plt.tight_layout()
plt.savefig('en_edge_shap.jpeg', dpi=400, bbox_inches='tight', transparent=True)
plt.show()

# Now for the Initial Access Model (SN), plotted as M2.
# Reuses the batch (adj_tensor, X_matrix) drawn for the EN model above.
# NOTE(review): unlike the EN arrays, the SN SHAP arrays are not saved to
# disk -- add np.save calls if they are needed later.

explainer_sn = ClassifierGraphSHAPExplainer(ae_sn, device)

feature_importance_sn = explainer_sn.explain_feature_importance(
    adj_tensor, 
    X_matrix)

edge_importance_sn = explainer_sn.explain_edge_importance(
    adj_tensor, 
    X_matrix)

fig, ax = plot_importance_values(feature_importance_sn, names=feature_names, title=r'$\mathcal{M}_2$'+"  Feature SHAP Values")
plt.draw()
plt.box(False)    
# tight_layout must run before savefig; called afterwards it only affects the
# on-screen figure from plt.show(), not the saved file.
plt.tight_layout()
plt.savefig('sn_feature_shap.jpeg', dpi=400, bbox_inches='tight', transparent=True)
plt.show()

# Average over axis 2 to obtain a 2-D array for plotting.
# NOTE(review): assumes explain_edge_importance returns a 3-D array; the
# resulting shape depends on the explainer output -- confirm.
edge_importance_sn_avg = edge_importance_sn.mean(axis=2)
fig, ax = plot_importance_values(edge_importance_sn_avg, names=edge_types, title=r'$\mathcal{M}_2$'+"  Edge Type SHAP Values")
plt.draw()
plt.box(False)    
plt.tight_layout()
plt.savefig('sn_edge_shap.jpeg', dpi=400, bbox_inches='tight', transparent=True)
plt.show()