# **Physics-Informed Graph Neural Network (PIGNN)**  

In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_LAUNCH_BLOCKING"] = '1'

import gc
import csv
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch_geometric.nn as pyg_nn

from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import random_split
from torch_geometric.data import Data, Batch
from torch_geometric.utils import dense_to_sparse
from sklearn.metrics import roc_auc_score, roc_curve, auc, confusion_matrix

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
    device = "cuda"
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    gc.collect()
    print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
    print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
    print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))

else:
    print("No GPU available. Training will run on CPU.")
    device = "cpu"

def set_seed(seed):
    random.seed(seed)  
    np.random.seed(seed)  
    torch.manual_seed(seed)  
    torch.cuda.manual_seed(seed)  
    torch.cuda.manual_seed_all(seed)  
    torch.backends.cudnn.deterministic = True 
    torch.backends.cudnn.benchmark = False  

set_seed(seed=12345)

GPU: NVIDIA GeForce GTX 1660 Ti is available.
torch.cuda.memory_allocated: 0.000000GB
torch.cuda.memory_reserved: 0.000000GB
torch.cuda.max_memory_reserved: 0.000000GB


## 1. SAGEConv Layer and architecture definition

In [2]:
class GraphSAGELayer(nn.Module):
    
    """
    Implements a single GraphSAGE layer.
    
    - Applies message passing by aggregating neighbor embeddings.
    - Uses different linear transformations for each edge type.
    - Combines neighbor embeddings with self embeddings and applies a ReLU activation.
    
    Args:
        in_dim (int): Input feature dimension.
        out_dim (int): Output feature dimension.
        edge_dim (int): Number of edge types.
    """
    
    def __init__(self, in_dim: int, out_dim: int, edge_dim: int): 
        super().__init__()        
        self.lin_neighbors = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=True) for _ in range(edge_dim)])
        self.lin_self = nn.Linear(in_dim, out_dim, bias=True)
        self.act = nn.ReLU()

    def message_passing(self, x: torch.Tensor, adj_tensor: torch.Tensor):
    
        """
        Performs message passing by aggregating neighbor embeddings using adjacency matrices.
        
        Args:
            x (torch.Tensor): Node feature matrix of shape (batch_size, num_nodes, feature_dim).
            adj_tensor (torch.Tensor): Adjacency tensor with multiple edge types.
        Returns:
            torch.Tensor: Aggregated neighbor embeddings.
        """
        
        batch_size, num_nodes, _ = x.shape
        aggregated_neigh_embeds = []
    
        for i in range(adj_tensor.shape[3]):  
            adj_matrix = adj_tensor[:, :, :, i]  
            neigh_embeds_i = torch.bmm(adj_matrix, x) 
            neigh_embeds_i = self.lin_neighbors[i](neigh_embeds_i)
            aggregated_neigh_embeds.append(neigh_embeds_i)

        neigh_embeds = sum(aggregated_neigh_embeds)  
        return neigh_embeds

    def forward(self, x: torch.Tensor, adj_tensor: torch.Tensor):
       
        """
        Forward pass of the GraphSAGE layer.
        
        Args:
            x (torch.Tensor): Node feature matrix.
            adj_tensor (torch.Tensor): Adjacency tensor.   
        Returns:
            torch.Tensor: Output node representations.
        """
        
        neigh_embeds = self.message_passing(x, adj_tensor)
        x_self = self.lin_self(x)
        out = neigh_embeds + x_self  
        return self.act(out)  

class GraphSAGEModel(nn.Module):

    """
    GraphSAGE model with multiple layers for node representation learning.
    
    - Projects input features to a hidden space.
    - Applies three GraphSAGE layers with ReLU activation and dropout.
    - Outputs final node embeddings.
    
    Args:
        in_features (int): Input feature dimension.
        hidden_size (int): Hidden layer size.
        out_features (int): Output feature dimension.
        dropout (float): Dropout rate.
    """
    
    def __init__(self, in_features: int, hidden_size: int, out_features: int, dropout: float = 0.2):
        super().__init__()

        self.input_proj = nn.Linear(in_features, hidden_size, bias=True)        
        self.conv1 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size, edge_dim=16)
        self.conv2 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size, edge_dim=16)
        self.conv3 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size, edge_dim=16)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=dropout)        
        self.lin_out = nn.Linear(hidden_size, hidden_size, bias=True)  

    def forward(self, x: torch.Tensor, adj_tensor: torch.Tensor):
        
        """
        Forward pass of the GraphSAGE model.
        
        Args:
            x (torch.Tensor): Input node features.
            adj_tensor (torch.Tensor): Adjacency tensor.
        Returns:
            torch.Tensor: Node embeddings.
        """
        
        x = self.input_proj(x)
        
        x = self.conv1(x, adj_tensor)  
        x = self.act(x)
        x = self.drop(x)
        x = self.conv2(x, adj_tensor) 
        x = self.act(x)
        x = self.drop(x)
        x = self.lin_out(x)  
        return x 

class DNN(nn.Module):
    
    """
    Deep Neural Network (DNN) for path prediction.
    
    - Consists of five fully connected layers with ReLU activations.
    - Outputs a probability score using a sigmoid activation.
    
    Args:
        in_features (int): Input feature dimension.
        hidden_size (int): Hidden layer size.
        out_features (int): Output feature dimension.
    """
    
    def __init__(self, in_features, hidden_size, out_features):
        super(DNN, self).__init__()
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, hidden_size)
        self.fc5 = nn.Linear(hidden_size, out_features)

    def forward(self, x):
    
        """
        Forward pass of the DNN.
        
        Args:
            x (torch.Tensor): Input features.
        
        Returns:
            torch.Tensor: Output probabilities.
        """
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = torch.sigmoid(self.fc5(x))  
        return x

class GraphSAGEWithDNN(nn.Module):
    
    """
    Combines GraphSAGE and DNN for path prediction.
    
    - First extracts node embeddings using GraphSAGE.
    - Then applies DNN to predict paths.
    
    Args:
        in_features (int): Input feature dimension.
        hidden_size (int): Hidden layer size.
        out_features (int): Output feature dimension.
        dropout (float): Dropout rate.
    """
    
    def __init__(self, in_features, hidden_size, out_features, dropout=0):
        super().__init__()
        self.graphsage = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout)
        self.dnn = DNN(hidden_size, hidden_size, out_features)

    def forward(self, x, adj_tensor):
    
        """
        Forward pass of the combined model.
        
        Args:
            x (torch.Tensor): Input node features.
            adj_tensor (torch.Tensor): Adjacency tensor.
        
        Returns:
            torch.Tensor: Predicted paths.
        """
        
        node_embeddings = self.graphsage(x, adj_tensor)
        output = self.dnn(node_embeddings)
        return output

## 2. Dataloader, Loss Functions, Evaluation Function

In [None]:
class GraphDataset(Dataset):
    
    """
    Custom dataset loader for graph data.
        - Loads adjacency tensor, node features (X_matrix), and target adjacency matrix (Y_matrix).
    """
    
    def __init__(self, data_dir):
        self.files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.pt')]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        data = torch.load(self.files[idx])
        return data["adj_tensor"], data["X_matrix"], data["Y_matrix"]

"""
Initializes dataset and splits it into training and testing sets.
Creates data loaders for efficient batch processing during training.
"""

data_dir = "_data_"
dataset = GraphDataset(data_dir)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, 
                              worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id), 
                              generator=torch.Generator().manual_seed(42))

test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, 
                             worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id), 
                             generator=torch.Generator().manual_seed(42))  

def degree_loss(B):
    
    """
    Computes the degree loss to enforce a single path structure in the adjacency matrix B.
    
    Args:
        B (torch.Tensor): Adjacency matrix of shape (B, N, N)
    Returns:
        torch.Tensor: Degree loss value
    """
    
    batch_size, num_nodes, _ = B.shape
    deg_out = B.sum(dim=-1)  
    deg_in = B.sum(dim=-2)  
    active_nodes = (deg_in > 0) | (deg_out > 0)  
    num_start_nodes = (deg_in == 0) & (deg_out > 0)  
    P_start = (num_start_nodes.sum(dim=-1) - 1) ** 2  
    num_end_nodes = (deg_out == 0) & (deg_in > 0) 
    P_end = (num_end_nodes.sum(dim=-1) - 1) ** 2  
    incorrect_intermediate = ((deg_in != 1) | (deg_out != 1)) & active_nodes 
    P_intermediate = incorrect_intermediate.sum(dim=-1) - 2 
    P_intermediate = torch.clamp(P_intermediate, min=0)      
    L_deg = P_start + P_end + P_intermediate
    
    return L_deg.float().mean()

def cycle_loss(B, K=10):
    
    """
    Penalizes cycles in the predicted adjacency matrix B by computing matrix powers up to K.

    Args:
        B (torch.Tensor): Adjacency matrix of shape (B, N, N).
        K (int, optional): Maximum power to compute. Defaults to 10.
    Returns:
        torch.Tensor: Cycle loss value.
    """
    
    batch_size, num_nodes, _ = B.shape
    B = B / (B.sum(dim=-1, keepdim=True) + 1e-6)
    cycle_penalty = torch.zeros(batch_size, device=B.device)
    B_power = torch.eye(num_nodes, device=B.device).unsqueeze(0).expand(batch_size, -1, -1)

    for k in range(1, K + 1):  
        B_power = torch.bmm(B_power, B)  
        diag_sum = torch.diagonal(B_power, dim1=-2, dim2=-1).sum(dim=-1)  
        cycle_penalty += diag_sum / k 

    return cycle_penalty.mean()

def connectivity_loss(B):
    
    """ 
    Encourages the graph to contain a single connected path structure, 
    verified via Laplacian eigenvalues and connected component analysis.
    
    Args:
        B (torch.Tensor): Adjacency matrix of shape (batch_size, N, N).

    Returns:
        torch.Tensor: Path structure loss value.
    """

    batch_size, num_nodes, _ = B.shape
    deg_out = B.sum(dim=-1)  
    D = torch.diag_embed(deg_out)
    L = D - B
    eigvals = torch.linalg.eigvals(L).real  
    zero_threshold = 1e-5  
    num_components = (eigvals.abs() < zero_threshold).sum(dim=-1)
    path_length = (B.sum(dim=(-1, -2)) / 2).long() + 1 
    expected_components = num_nodes - path_length + 1
    component_penalty = (num_components - expected_components) ** 2 / expected_components**2
    max_pl = path_length.max().item()
    max_pl = min(max_pl, num_nodes) 
    batch_path_eigvals = torch.zeros((batch_size, max_pl), device=B.device) 

    for idx, pl in enumerate(path_length):
        pl_val = min(int(pl.item()), num_nodes)  

        eigvals_path = 2 - 2 * torch.cos(torch.arange(pl_val, device=B.device) * torch.pi / pl_val)
        if pl_val > max_pl:
            print(f"Warning: pl_val ({pl_val}) exceeds max_pl ({max_pl}), skipping assignment")
            continue  

        batch_path_eigvals[idx, :pl_val] = eigvals_path

    sorted_eigvals = torch.sort(eigvals, dim=-1).values[:, :max_pl] 
    batch_path_eigvals = batch_path_eigvals[:, :max_pl]
    spectral_penalty = ((sorted_eigvals - batch_path_eigvals) ** 2).mean(dim=-1) 
    
    return (component_penalty + spectral_penalty).mean()

def masked_bce_loss(pred, target):

    """
    Computes a masked Binary Cross-Entropy (BCE) loss with class imbalance handling.
    
    This function applies BCE loss only to a subset of the target values:  
    - All positive (1) values are included.  
    - A small fraction of negative (0) values are randomly sampled to reduce class imbalance.  
    - A positive weight is applied to further adjust for the imbalance.  
    
    Args:
        pred (torch.Tensor): Predicted logits of shape (B, N, N).
        target (torch.Tensor): Ground truth labels of shape (B, N, N).

    Returns:
        torch.Tensor: The mean masked BCE loss value.
    """

    mask = ((target != 0) | (torch.rand_like(target) < 0.001)).float()
    target_w = target.clone()
    target_w[target == 1] = (1/2.6615810451242892e-03)
    target_w[target == 0] = 0.00001
    target_w = target_w.to(device)
    loss = F.binary_cross_entropy(pred, target, reduction='none', weight=target_w*mask)
    loss = loss * mask 
    return loss.sum() / mask.sum() 

def pinn_loss(output, Y_matrix, alpha, beta, zeta):
    
    """ 
    Computes the total loss including masked-BCE loss and physics-inspired penalties. 
    """

    Y_matrix = Y_matrix.float()
    data_loss = masked_bce_loss(output, Y_matrix)    
    L_deg = degree_loss(output)
    L_cyc = cycle_loss(output)  
    L_con = connectivity_loss(output)      
    pinn_loss = alpha * L_deg + zeta*L_cyc + beta*L_con

    total_loss = data_loss + 0*pinn_loss # Psi = 0
    return total_loss
    
def evaluate_model(model, test_dataloader, device):
    
    """
    Evaluate the architecture using ROC-AUC Metric
    """
    
    model.eval()  
    all_outputs = []
    all_targets = []
    with torch.no_grad():  
        for adj_tensor, x_matrix, y_matrix in test_dataloader:
            adj_tensor, x_matrix, y_matrix = adj_tensor.to(device), x_matrix.to(device), y_matrix.to(device)
            output = model(x_matrix, adj_tensor)  
            all_outputs.append(output.cpu().numpy().flatten())
            all_targets.append(y_matrix.cpu().numpy().flatten())

    all_outputs = np.concatenate(all_outputs).astype(int)
    all_targets = np.concatenate(all_targets).astype(int)
    
    predictions = (all_outputs > 0.5).astype(int)
    auc_roc = roc_auc_score(all_targets, all_outputs)
    return auc_roc

## 3. Train path predictor without physics

In [6]:
"""
Defines model hyperparameters such as input/output dimensions, hidden layer size,
and dropout rate.
"""

in_features = dataset[0][1].shape[1]      
out_features = dataset[0][2].shape[1]    
hidden_size = 512
dropout = 0.1

"""
Sets the weights for the physics-inspired loss functions.
These control the influence of different physics constraints on training.
"""

alpha, beta, zeta = 0.1,0.1,0.1

"""
Creates an instance of the GraphSAGE model with a DNN and moves it to the specified device (CPU/GPU).
"""

model = GraphSAGEWithDNN(in_features, hidden_size, out_features, dropout)
model = model.to(device)

"""
Sets up the Adam optimizer with a learning rate of 0.01.
Uses an exponential learning rate scheduler to decay the learning rate over epochs.
"""

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) 
scheduler = ExponentialLR(optimizer, gamma=0.97)

"""
Trains the model for a specified number of epochs.
Iterates through the dataset in batches, computes loss, performs backpropagation,
and updates model weights using the optimizer.
Handles potential linear algebra errors during loss computation.
Tracks and prints average training loss per epoch.
"""

num_epochs = 10
train_losses = []

for epoch in range(num_epochs):
    model.train()  
    total_loss = 0

    for adj_tensor, X_matrix, Y_matrix in train_dataloader:
       
        adj_tensor = adj_tensor.to(device)
        X_matrix = X_matrix.to(device)
        Y_matrix = Y_matrix.to(device)

        optimizer.zero_grad()
        output = model(X_matrix, adj_tensor)  
        loss = pinn_loss(output, Y_matrix.float(), alpha, beta, zeta)  
        
        try:
            loss.backward()
        except torch.linalg.LinAlgError as e:
            print(f"LinAlgError: {e}")
        
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)
    scheduler.step()
    print(f"[i] Epoch {epoch+1}, \tTrain Loss: {avg_train_loss:.8f}")

""" 
Optional: Save model weights for future use.
- Transformer weights (adjacency encoder) are saved separately from the DNN weights.
"""

torch.save(model.graphsage.state_dict(), "Weights/no_pinn_graphsage_weights.pth")  
torch.save(model.dnn.state_dict(), "Weights/no_pinn_pinn_dnn_weights.pth")

"""
Evaluate the architecture using ROC-AUC Metric
"""

auc_roc = evaluate_model(model, test_dataloader, device)
print(f"[i] ROC-AUC: \t{auc_roc:.5f}")

[i] Epoch 1, 	Train Loss: 5.14408466
[i] Epoch 2, 	Train Loss: 4.37899471
[i] Epoch 3, 	Train Loss: 4.20036497
[i] Epoch 4, 	Train Loss: 4.15533141
[i] Epoch 5, 	Train Loss: 4.02720269
[i] Epoch 6, 	Train Loss: 4.01665365
[i] Epoch 7, 	Train Loss: 4.01995330
[i] Epoch 8, 	Train Loss: 4.05554997
[i] Epoch 9, 	Train Loss: 4.04662886
[i] Epoch 10, 	Train Loss: 4.03627675
[i] ROC-AUC: 	0.82979


## 4. Train path predictor with physics

In [14]:
def pinn_loss(output, Y_matrix, alpha, beta, zeta):
    
    """ 
    Computes the total loss including masked-BCE loss and physics-inspired penalties. 
    """

    Y_matrix = Y_matrix.float()
    data_loss = masked_bce_loss(output, Y_matrix)    
    L_deg = corrected_degree_loss(output)
    L_cyc = scaled_corrected_cycle_loss(output)  
    L_con = path_structure_loss(output)      
    pinn_loss = alpha * L_deg + zeta*L_cyc + beta*L_con

    total_loss = data_loss + pinn_loss # Psi = 1
    return total_loss
    
"""
Defines model hyperparameters such as input/output dimensions, hidden layer size,
and dropout rate.
"""

in_features = dataset[0][1].shape[1]      
out_features = dataset[0][2].shape[1]    
hidden_size = 512
dropout = 0.1

"""
Log training loss
"""

csv_filename = "Exports/warmup_training_log.csv"
with open(csv_filename, mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["Epoch", "Batch", "Loss"])  

"""
Sets the weights for the physics-inspired loss functions.
These control the influence of different physics constraints on training.
"""

alpha, beta, zeta = 1,0.0001,1 

"""
Creates an instance of the GraphSAGE model with a DNN and moves it to the specified device (CPU/GPU).
"""

model = GraphSAGEWithDNN(in_features, hidden_size, out_features, dropout)
model = model.to(device)

"""
Sets up the Adam optimizer with a learning rate of 0.01.
Uses an exponential learning rate scheduler to decay the learning rate over epochs.
"""

optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 
scheduler = ExponentialLR(optimizer, gamma=0.97)

"""
Trains the model for a specified number of epochs.
Iterates through the dataset in batches, computes loss, performs backpropagation,
and updates model weights using the optimizer.
Handles potential linear algebra errors during loss computation.
Tracks and prints average training loss per epoch.
"""

num_epochs = 20
train_losses = []
log_data = []

for epoch in range(num_epochs):
    model.train()  
    total_loss = 0

    for batch_idx, (adj_tensor, X_matrix, Y_matrix) in enumerate(train_dataloader):
       
        adj_tensor = adj_tensor.to(device)
        X_matrix = X_matrix.to(device)
        Y_matrix = Y_matrix.to(device)

        optimizer.zero_grad()
        output = model(X_matrix, adj_tensor)  
        loss = pinn_loss(output, Y_matrix.float(), alpha, beta, zeta)  
        
        try:
            loss.backward()
        except torch.linalg.LinAlgError as e:
            print(f"LinAlgError: {e}")
        
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)
    scheduler.step()
    log_data.append([epoch + 1, batch_idx + 1, loss.item()])
    print(f"[i] Epoch {epoch+1}, \tTrain Loss: {avg_train_loss:.8f}")

"""
Save logs
"""
with open(csv_filename, mode="a", newline="") as file:
    writer = csv.writer(file)
    writer.writerows(log_data)

""" 
Optional: Save model weights for future use.
- Transformer weights (adjacency encoder) are saved separately from the DNN weights.
"""

#torch.save(model.graphsage.state_dict(), "Weights/085-pinn_graphsage_weights.pth")  
#torch.save(model.dnn.state_dict(), "Weights/085-pinn_dnn_weights.pth")

"""
Evaluate the architecture using ROC-AUC Metric
"""

auc_roc = evaluate_model(model, test_dataloader, device)
print(f"[i] ROC-AUC: \t{auc_roc:.5f}")

[i] Epoch 1, 	Train Loss: 372.89824911
[i] Epoch 2, 	Train Loss: 368.66578381
[i] Epoch 3, 	Train Loss: 367.44918999
[i] Epoch 4, 	Train Loss: 367.17449130
[i] Epoch 5, 	Train Loss: 366.50590633
[i] Epoch 6, 	Train Loss: 366.07718776
[i] Epoch 7, 	Train Loss: 365.80567815
[i] Epoch 8, 	Train Loss: 365.78322308
[i] Epoch 9, 	Train Loss: 365.58628200
[i] Epoch 10, 	Train Loss: 365.52144623
[i] Epoch 11, 	Train Loss: 365.53299948
[i] Epoch 12, 	Train Loss: 365.02650158
[i] Epoch 13, 	Train Loss: 365.54958109
[i] Epoch 14, 	Train Loss: 364.73675772
[i] Epoch 15, 	Train Loss: 364.56448247
[i] Epoch 16, 	Train Loss: 364.57081252
[i] Epoch 17, 	Train Loss: 365.59575301
[i] Epoch 18, 	Train Loss: 365.14550546
[i] Epoch 19, 	Train Loss: 367.72257526
[i] Epoch 20, 	Train Loss: 372.68691078
[i] ROC-AUC: 	0.84181


## 5. Evaluate both models and plot ROC-AUC

In [None]:
model_nopinn = GraphSAGEWithDNN(in_features, hidden_size, out_features, dropout)
model_nopinn.graphsage.load_state_dict(torch.load("Weights/no_pinn_graphsage_weights.pth"))
model_nopinn.dnn.load_state_dict(torch.load("Weights/no_pinn_pinn_dnn_weights.pth"))
model_nopinn.to(device)
model_nopinn.eval()

model = GraphSAGEWithDNN(in_features, hidden_size, out_features, dropout)
model.graphsage.load_state_dict(torch.load("Weights/086-pinn_graphsage_weights.pth"))
model.dnn.load_state_dict(torch.load("Weights/086-pinn_dnn_weights.pth"))
model.to(device)
model.eval()

def compute_roc_auc(model, dataloader, device):
    """ Compute ROC-AUC values (FPR, TPR, AUC score) for a given model. """
    model.eval()
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for adj_tensor, X_matrix, Y_matrix in dataloader:
            adj_tensor = adj_tensor.to(device)
            X_matrix = X_matrix.to(device)
            Y_matrix = Y_matrix.to(device)

            outputs = model(X_matrix, adj_tensor)  
            probabilities = torch.sigmoid(outputs)  

            all_probs.append(probabilities.cpu().numpy().flatten())  
            all_labels.append(Y_matrix.cpu().numpy().flatten())  

    all_probs = np.concatenate(all_probs)
    all_labels = np.concatenate(all_labels)

    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    roc_auc = auc(fpr, tpr)

    return fpr, tpr, roc_auc


def plot_roc_auc_comparison(model1, model2, dataloader, device, label1="Model", label2="Model No PINN"):
    """ Plot ROC-AUC curves for two models on the same graph. """
    
    fpr1, tpr1, auc1 = compute_roc_auc(model1, dataloader, device)
    fpr2, tpr2, auc2 = compute_roc_auc(model2, dataloader, device)

    plt.figure(figsize=(8,8))
    plt.plot(fpr1, tpr1, color='blue', lw=1, label=f"{label1} (AUC = {auc1:.4f})")
    plt.plot(fpr2, tpr2, color='red', lw=1, label=f"{label2} (AUC = {auc2:.4f})")
    plt.plot([0, 1], [0, 1], color='black', linestyle="dotted", lw=1, label="Baseline")  

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate",  fontsize=10)
    plt.ylabel("True Positive Rate",  fontsize=10)
    plt.title("ROC Curve Comparison",  fontsize=12)
    plt.legend(loc='lower right', fontsize=10, frameon=True, framealpha=1)
    plt.grid(color='black', linestyle='-', linewidth=.5, alpha=.3)
    plt.draw()
    plt.box(False)    
    plt.savefig('Exports/model_roc_auc_compare.jpeg', dpi=400, bbox_inches='tight', transparent=True)
    plt.tight_layout()
    plt.show()

plot_roc_auc_comparison(model, model_nopinn, test_dataloader, device, label1=r'$\Psi = 1$'+f"\t", label2=r'$\Psi = 0$'+f"\t")

## 7. Encoder Architecture for start/end node prediction

In [50]:
class AE(torch.nn.Module):
    
    """
    Autoencoder (AE) model
    
    Args:
        hidden_size (int): Number of hidden features in the GraphSAGE output.
        out_features (int): Number of features in the input X_matrix.
    """
    
    def __init__(self, hidden_size: int, out_features: int):
        super().__init__()
        
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 128),                  # Added (_,128) layer
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 9)  
        )
        
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(9, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, out_features)                  # Added (_,128) layer
        )
 
    def forward(self, x):
        
        """
        Forward pass for the autoencoder.
        
        Args:
            x (Tensor): Input tensor of shape (batch_size, num_nodes, hidden_size).  
        Returns:
            Tensor: Reconstructed output of shape (batch_size, num_nodes, out_features).
        """
        
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

class GraphSAGEWithAE(nn.Module):
    
    """
    GraphSAGE model combined with an Autoencoder (AE) for feature learning.
    
    Args:
        in_features (int): Number of input node features.
        hidden_size (int): Number of hidden units in GraphSAGE.
        out_features (int): Output feature size (same as X_matrix.shape[2]).
        dropout (float, optional): Dropout rate. Defaults to 0.
    """
    
    def __init__(self, in_features, hidden_size, out_features, dropout=0):
        super().__init__()
        self.graphsage = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout)
        self.autoencoder = AE(hidden_size, out_features)

    def forward(self, x, adj_tensor):
    
        """
        Forward pass through GraphSAGE followed by the autoencoder.
        
        Args:
            x (Tensor): Input feature matrix of shape (batch_size, num_nodes, in_features).
            adj_tensor (Tensor): Adjacency matrix of shape (batch_size, num_nodes, num_nodes).
        Returns:
            Tensor: Reconstructed feature matrix of shape (batch_size, num_nodes, out_features).
        """
        
        node_embeddings = self.graphsage(x, adj_tensor)
        reconstruction = self.autoencoder(node_embeddings)
        return reconstruction

class EncoderWithClassifier(nn.Module):
    
    """
    Encoder model with a classifier for binary node classification.
    
    Args:
        graphsage (nn.Module): Pretrained GraphSAGE model.
        pretrained_encoder (nn.Module): Pretrained encoder (Autoencoder's encoder part).
        latent_dim (int): Size of the latent representation.
        freeze (bool): If True, freezes GraphSAGE and encoder layers during training.
    """
    
    def __init__(self, graphsage: nn.Module, pretrained_encoder: nn.Module, latent_dim: int, freeze: bool):
        super().__init__()
        self.graphsage = graphsage 
        self.encoder = pretrained_encoder

        if freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False
            for param in self.graphsage.parameters():
                param.requires_grad = False

        self.classifier = nn.Sequential(
            torch.nn.Linear(latent_dim, latent_dim),      
            torch.nn.ReLU(),
            torch.nn.Linear(latent_dim, 1),   
            nn.Sigmoid()
        )
    
    def forward(self, x, adj_tensor):
        
        """
        Forward pass through GraphSAGE, encoder, and classifier.
        
        Args:
            x (Tensor): Input feature matrix of shape (batch_size, num_nodes, in_features).
            adj_tensor (Tensor): Adjacency matrix of shape (batch_size, num_nodes, num_nodes).
        
        Returns:
            Tensor: Classification probabilities of shape (batch_size, num_nodes).
        """
        
        batch_size, num_nodes, _ = x.shape  
        node_embeddings = self.graphsage(x, adj_tensor)
        node_embeddings = node_embeddings.view(batch_size * num_nodes, node_embeddings.shape[2])
        latent_repr = self.encoder(node_embeddings)
        classification_output = self.classifier(latent_repr)
        classification_output = classification_output.view(batch_size, num_nodes)
        return classification_output

## 8. Train encoder using self-supervised encoder-decoder

In [23]:
"""
Hyperparameter Initialization

Defines the key hyperparameters for the model, including:
- in_features: The number of input node features, extracted from the dataset.
- hidden_size: The size of the hidden layers in the model.
- dropout: Dropout rate to prevent overfitting.
"""

in_features = dataset[0][1].shape[1]  
hidden_size = 512
dropout = 0.01

"""
Model Initialization

- Initializes the GraphSAGEWithAE model using the defined hyperparameters.
- Moves the model to the appropriate computing device (CPU/GPU).
"""

model = GraphSAGEWithAE(in_features, hidden_size, in_features, dropout)
model = model.to(device)

"""
Optimizer and Loss Function

- Uses the Adam optimizer with a learning rate of 0.001 to train the model.
- Mean Squared Error (MSE) is used as the loss function to measure reconstruction quality.
"""

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

"""
Autoencoder Training Loop

- Trains the autoencoder for a specified number of epochs.
- Iterates over the training dataset, performs forward and backward passes, and updates model weights.
- Computes the reconstruction loss based on the difference between the original and reconstructed node features.
- Outputs the training loss for each epoch.
"""

num_epochs = 10
train_losses = []

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for adj_tensor, X_matrix, _ in train_dataloader: 
        adj_tensor = adj_tensor.to(device)
        X_matrix = X_matrix.to(device)

        optimizer.zero_grad()
        reconstruction = model(X_matrix, adj_tensor)  

        loss = criterion(reconstruction, X_matrix)  

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Train Loss: {total_loss/len(train_dataloader):.8f}")

"""
Model Weights Saving

- Saves the trained weights of the GraphSAGE and Autoencoder components.
- Weights are stored in the 'Weights' directory for later use in downstream tasks.
"""

torch.save(model.graphsage.state_dict(), "Weights/graphsage_weights.pth")  
torch.save(model.autoencoder.state_dict(), "Weights/autoencoder_weights.pth")

Epoch 1, Train Loss: 0.03863320
Epoch 2, Train Loss: 0.01139502
Epoch 3, Train Loss: 0.00975788
Epoch 4, Train Loss: 0.00734679
Epoch 5, Train Loss: 0.00401579
Epoch 6, Train Loss: 0.00293719
Epoch 7, Train Loss: 0.00708617
Epoch 8, Train Loss: 0.00161704
Epoch 9, Train Loss: 0.00110221
Epoch 10, Train Loss: 0.00071511


## 9. Fine tune SAGEConv+Encoder+Classifier for start node prediction

In [41]:
"""
Load Pretrained GraphSAGE + Encoder and Initialize Classifier

- Loads the pretrained GraphSAGE model from saved weights.
- Loads the pretrained autoencoder and extracts the encoder component.
- Initializes the classifier model with the pretrained encoder.
- Allows optional fine-tuning of the GraphSAGE and encoder by setting `freeze` to False.
"""

graphsage_model = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout).to(device)
graphsage_model.load_state_dict(torch.load("Weights/graphsage_weights.pth"))  

pretrained_ae = AE(hidden_size, in_features).to(device)  
pretrained_ae.load_state_dict(torch.load("Weights/autoencoder_weights.pth"))  
pretrained_encoder = pretrained_ae.encoder 

latent_dim = 9           
freeze = False  
classifier_model = EncoderWithClassifier(graphsage_model, pretrained_encoder, latent_dim, freeze).to(device)

"""
Define Optimizer, Scheduler, and Loss Function

- Uses Adam optimizer with an initial learning rate of 0.01.
- Applies an exponential learning rate decay with a gamma value of 0.9.
- Binary Cross-Entropy Loss (BCELoss) is used for classification.
"""

optimizer = torch.optim.Adam(classifier_model.parameters(), lr=0.0001)
scheduler = ExponentialLR(optimizer, gamma=0.9)
criterion = torch.nn.BCELoss()

num_epochs = 10
train_losses = []

"""
Training Loop for Graph Embeddings -> Encoder -> Classifier

- Iterates over the training dataset for a specified number of epochs.
- Extracts graph adjacency tensors and node feature matrices.
- Retrieves start node labels from the last feature in X_matrix.
- Performs forward propagation, computes loss, and updates model parameters.
- Applies a learning rate scheduler for gradual decay.
"""

for epoch in range(num_epochs):
    classifier_model.train()
    total_loss = 0

    for adj_tensor, X_matrix, _ in train_dataloader:  
        adj_tensor = adj_tensor.to(device)
        
        labels = X_matrix[:, :, -1].float()  
        labels = labels.to(device)         
        
        X_matrix_mask = X_matrix.clone()  
        X_matrix_mask[:, :, -1] = 0 
        X_matrix_mask = X_matrix_mask.to(device)  
        
        optimizer.zero_grad()
        classification_output = classifier_model(X_matrix_mask, adj_tensor)  
        loss = criterion(classification_output, labels)  

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    scheduler.step()
    print(f"Epoch {epoch+1}, Classifier Train Loss: {total_loss / len(train_dataloader):.8f}")

"""
Save weights
"""
torch.save(classifier_model.graphsage.state_dict(), "Weights/sn_graphsage_ae_classifier_weights.pth")
torch.save(classifier_model.encoder.state_dict(), "Weights/sn_encoder_ae_classifier_weights.pth")
torch.save(classifier_model.classifier.state_dict(), "Weights/sn_classifier_ae_classifier_weights.pth")

Epoch 1, Classifier Train Loss: 0.51988231
Epoch 2, Classifier Train Loss: 0.07938496
Epoch 3, Classifier Train Loss: 0.02257822
Epoch 4, Classifier Train Loss: 0.01578102
Epoch 5, Classifier Train Loss: 0.01362000
Epoch 6, Classifier Train Loss: 0.01253690
Epoch 7, Classifier Train Loss: 0.01197894
Epoch 8, Classifier Train Loss: 0.01153727
Epoch 9, Classifier Train Loss: 0.01129567
Epoch 10, Classifier Train Loss: 0.01112559


## 10. Find optimal threshold for reconstruction error indicating start node

In [47]:
"""
Find the optimal threshold for classification.

- Evaluate the classifier on the training set without updating weights.
- Store true labels and model predictions.
- Iterate over a range of thresholds to find the one with the highest score.
- Compute final evaluation metrics using the best threshold.
"""

classifier_model.eval()  
all_labels = []
all_outputs = []

with torch.no_grad():
    for adj_tensor, X_matrix, _ in test_dataloader:
        adj_tensor = adj_tensor.to(device)
        X_matrix = X_matrix.to(device)

        labels = X_matrix[:, :, -1].float() 
        classification_output = classifier_model(X_matrix, adj_tensor)  

        all_labels.extend(labels.cpu().numpy().flatten())  
        all_outputs.extend(classification_output.cpu().numpy().flatten())  

all_labels = np.array(all_labels)
all_outputs = np.array(all_outputs)

best_threshold = 0
best_s = 0
thresholds = np.arange(0, 0.9, 0.0001)

for threshold in thresholds:
    preds = (all_outputs > threshold).astype(float)  
    s = roc_auc_score(all_labels, preds) 

    if s > best_s:  
        best_s = s
        best_threshold = threshold

"""
Evaluate the architecture using ROC-AUC Metric
"""

final_preds = (all_outputs > best_threshold).astype(float)  
roc_auc = roc_auc_score(all_labels, final_preds)

print(f"Optimal Threshold  : {best_threshold:.5f}")
print(f"Test ROC AUC       : {roc_auc:.4f}")

Optimal Threshold  : 0.00220
Test ROC AUC       : 0.9571


## 11. Fine tune SAGEConv+Encoder+Classifier for end node prediction

In [48]:
"""
Load Pretrained GraphSAGE + Encoder and Initialize Classifier

- Loads the pretrained GraphSAGE model from saved weights.
- Loads the pretrained autoencoder and extracts the encoder component.
- Initializes the classifier model with the pretrained encoder.
- Allows optional fine-tuning of the GraphSAGE and encoder by setting `freeze` to False.
"""

graphsage_model = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout).to(device)
graphsage_model.load_state_dict(torch.load("Weights/graphsage_weights.pth"))  

pretrained_ae = AE(hidden_size, in_features).to(device)  
pretrained_ae.load_state_dict(torch.load("Weights/autoencoder_weights.pth"))  
pretrained_encoder = pretrained_ae.encoder  

latent_dim = 9           
freeze = False   
classifier_model = EncoderWithClassifier(graphsage_model, pretrained_encoder, latent_dim, freeze).to(device)

"""
Define Optimizer, Scheduler, and Loss Function

- Uses Adam optimizer with an initial learning rate of 0.001.
- Applies an exponential learning rate decay with a gamma value of 0.9.
- Binary Cross-Entropy Loss (BCELoss) is used for classification.
"""

optimizer = torch.optim.Adam(classifier_model.parameters(), lr=0.001)
scheduler = ExponentialLR(optimizer, gamma=0.9)
criterion = torch.nn.BCELoss()

num_epochs = 10
train_losses = []

"""
Training Loop for Graph Embeddings -> Encoder -> Classifier

- Iterates over the training dataset for a specified number of epochs.
- Extracts graph adjacency tensors and node feature matrices.
- Retrieves end node labels from the second-to-last feature in X_matrix.
- Performs forward propagation, computes loss, and updates model parameters.
- Applies a learning rate scheduler for gradual decay.
"""

for epoch in range(num_epochs):
    classifier_model.train()
    total_loss = 0

    for adj_tensor, X_matrix, _ in train_dataloader:  
        adj_tensor = adj_tensor.to(device)
        
        labels = X_matrix[:, :, -2].float()  
        labels = labels.to(device)         
        
        X_matrix_mask = X_matrix.clone()  
        X_matrix_mask[:, :, -2] = 0 
        X_matrix_mask = X_matrix_mask.to(device)  
        
        optimizer.zero_grad()
        classification_output = classifier_model(X_matrix_mask, adj_tensor)  
        loss = criterion(classification_output, labels)  

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    scheduler.step()
    print(f"Epoch {epoch+1}, Classifier Train Loss: {total_loss / len(train_dataloader):.8f}")

"""
Save Weights
"""
torch.save(classifier_model.graphsage.state_dict(), "Weights/en_graphsage_ae_classifier_weights.pth")
torch.save(classifier_model.encoder.state_dict(), "Weights/en_encoder_ae_classifier_weights.pth")
torch.save(classifier_model.classifier.state_dict(), "Weights/en_classifier_ae_classifier_weights.pth")

Epoch 1, Classifier Train Loss: 0.14209690
Epoch 2, Classifier Train Loss: 0.05622134
Epoch 3, Classifier Train Loss: 0.05223998
Epoch 4, Classifier Train Loss: 0.05210115
Epoch 5, Classifier Train Loss: 0.05192966
Epoch 6, Classifier Train Loss: 0.05175074
Epoch 7, Classifier Train Loss: 0.05223931
Epoch 8, Classifier Train Loss: 0.05189640
Epoch 9, Classifier Train Loss: 0.05188802
Epoch 10, Classifier Train Loss: 0.05164744


## 12. Find optimal threshold for reconstruction error indicating end node

In [49]:
"""
Find the optimal threshold for classification.

- Evaluate the classifier on the training set without updating weights.
- Store true labels and model predictions.
- Iterate over a range of thresholds to find the one with the highest score.
- Compute final evaluation metrics using the best threshold.
"""

classifier_model.eval()  
all_labels = []
all_outputs = []

with torch.no_grad():
    for adj_tensor, X_matrix, _ in test_dataloader:
        adj_tensor = adj_tensor.to(device)
        X_matrix = X_matrix.to(device)

        labels = X_matrix[:, :, -2].float() 
        classification_output = classifier_model(X_matrix, adj_tensor)  

        all_labels.extend(labels.cpu().numpy().flatten())  
        all_outputs.extend(classification_output.cpu().numpy().flatten())  

all_labels = np.array(all_labels)
all_outputs = np.array(all_outputs)

best_threshold = 0
best_s = 0
thresholds = np.arange(0, 0.9, 0.0001)

for threshold in thresholds:
    preds = (all_outputs > threshold).astype(float)  
    s = roc_auc_score(all_labels, preds) 

    if s > best_s:  
        best_s = s
        best_threshold = threshold

"""
Evaluate the architecture using ROC-AUC Metric
"""

final_preds = (all_outputs > best_threshold).astype(float)  
roc_auc = roc_auc_score(all_labels, final_preds)

print(f"Optimal Threshold  : {best_threshold:.5f}")
print(f"Test ROC AUC       : {roc_auc:.4f}")

Optimal Threshold  : 0.00020
Test ROC AUC       : 0.9591


## 13. Plot four ROC-AUC curves and confusion matrix from paper

In [None]:
"""
Start by loading the four models : Path prediction with and without PINN loss, then the 2 classification blocks
"""

model = GraphSAGEWithDNN(in_features, hidden_size, out_features, dropout)
model.graphsage.load_state_dict(torch.load("Weights/086-pinn_graphsage_weights.pth"))
model.dnn.load_state_dict(torch.load("Weights/086-pinn_dnn_weights.pth"))
model.to(device)
model.eval()

model_nopinn = GraphSAGEWithDNN(in_features, hidden_size, out_features, dropout)
model_nopinn.graphsage.load_state_dict(torch.load("Weights/no_pinn_graphsage_weights.pth"))
model_nopinn.dnn.load_state_dict(torch.load("Weights/no_pinn_pinn_dnn_weights.pth"))
model_nopinn.to(device)
model_nopinn.eval()

graphsage_model_1 = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout)  
pretrained_ae_1 = AE(hidden_size, in_features)
pretrained_encoder_1 = pretrained_ae_1.encoder   

ae_sn = EncoderWithClassifier(graphsage_model_1, pretrained_encoder_1, latent_dim=9, freeze = False)
ae_sn.graphsage.load_state_dict(torch.load("Weights/sn_graphsage_ae_classifier_weights.pth"))  
ae_sn.encoder.load_state_dict(torch.load("Weights/sn_encoder_ae_classifier_weights.pth"))
ae_sn.classifier.load_state_dict(torch.load("Weights/sn_classifier_ae_classifier_weights.pth"))
ae_sn.to(device)
ae_sn.eval()

graphsage_model_2 = GraphSAGEModel(in_features, hidden_size, hidden_size, dropout) 
pretrained_ae_2 = AE(hidden_size, in_features)
pretrained_encoder_2 = pretrained_ae_2.encoder   

ae_en = EncoderWithClassifier(graphsage_model_2, pretrained_encoder_2, latent_dim=9, freeze = False)
ae_en.graphsage.load_state_dict(torch.load("Weights/en_graphsage_ae_classifier_weights.pth"))  
ae_en.encoder.load_state_dict(torch.load("Weights/en_encoder_ae_classifier_weights.pth"))
ae_en.classifier.load_state_dict(torch.load("Weights/en_classifier_ae_classifier_weights.pth"))
ae_en.to(device)
ae_en.eval()

def compute_roc_auc(model, dataloader, device):
    
    """ 
    Compute ROC-AUC values (FPR, TPR, AUC score) for a given model. 
    """
    
    model.eval()
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for adj_tensor, X_matrix, Y_matrix in dataloader:
            adj_tensor = adj_tensor.to(device)
            X_matrix = X_matrix.to(device)
            Y_matrix = Y_matrix.to(device)

            outputs = model(X_matrix, adj_tensor)  
            probabilities = torch.sigmoid(outputs)  

            all_probs.append(probabilities.cpu().numpy().flatten())  
            all_labels.append(Y_matrix.cpu().numpy().flatten())  

    all_probs = np.concatenate(all_probs)
    all_labels = np.concatenate(all_labels)

    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    roc_auc = auc(fpr, tpr)

    return fpr, tpr, roc_auc

def compute_roc_auc_AE(model, dataloader, device, target, threshold): 
    
    """ 
    Compute ROC-AUC values (FPR, TPR, AUC score) for a given model. 
    """
    
    model.eval()
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for adj_tensor, X_matrix, _ in dataloader:
            adj_tensor = adj_tensor.to(device)
            X_matrix = X_matrix.to(device)
            labels = X_matrix[:, :, -target].float() 
            classification_output = model(X_matrix, adj_tensor)  
            all_labels.extend(labels.cpu().numpy().flatten())  
            all_probs.extend(classification_output.cpu().numpy().flatten())  
    
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)
    preds = (all_probs > threshold).astype(float)
    fpr, tpr, _ = roc_curve(all_labels, preds)
    roc_auc = auc(fpr, tpr)
    
    return fpr, tpr, roc_auc

    
    
def plot_roc_auc_comparison(model1, 
                            model2, 
                            model3, 
                            model4, 
                            dataloader, 
                            device, 
                            label1, 
                            label2, 
                            label3,
                            label4):

    
    """ 
    Plot ROC-AUC curves for two models on the same graph. 
    """
    
    fpr1, tpr1, auc1 = compute_roc_auc(model1, dataloader, device)
    fpr2, tpr2, auc2 = compute_roc_auc(model2, dataloader, device)
    
    fpr3, tpr3, auc3 = compute_roc_auc_AE(model3, dataloader, device, target=1, threshold=0.003)
    fpr4, tpr4, auc4 = compute_roc_auc_AE(model4, dataloader, device,  target=2, threshold=0.00020)

    plt.figure(figsize=(5,5))
    plt.plot(fpr3, tpr3, color='red', lw=1, label=f"{label3} (AUC = {auc3:.4f})")
    plt.plot(fpr4, tpr4, color='red', lw=1, linestyle='--', label=f"{label4} (AUC = {auc4:.4f})")
    plt.plot(fpr1, tpr1, color='blue', lw=1, label=f"{label1} (AUC = {auc1:.4f})")
    plt.plot(fpr2, tpr2, color='blue', linestyle="--", lw=1, label=f"{label2} (AUC = {auc2:.4f})")
    plt.plot([0, 1], [0, 1], color='black', linestyle="dotted", lw=1, label="Baseline")  # Random classifier line

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate",  fontsize=10)
    plt.ylabel("True Positive Rate",  fontsize=10)
    plt.legend(loc='lower right', fontsize=10, frameon=True, framealpha=1)
    plt.grid(color='black', linestyle='-', linewidth=.5, alpha=.3)
    plt.draw()
    plt.box(False)    
    plt.savefig('Exports/ae_4model_roc_auc_compare.jpeg', dpi=400, bbox_inches='tight', transparent=True)
    plt.tight_layout()
    plt.show()

plot_roc_auc_comparison(model, 
                        model_nopinn, 
                        ae_sn, 
                        ae_en, 
                        test_dataloader, 
                        device, 
                        label1=r'$\mathcal{M}_1,\quad\Psi = 1$'+f"\t", 
                        label2=r'$\mathcal{M}_1,\quad\Psi = 0$'+f"\t", 
                        label3=r'$\mathcal{M}_2,\quad\Psi = 0$'+f"\t", 
                        label4=r'$\mathcal{M}_3,\quad\Psi = 0$'+f"\t")

In [None]:
def compute_confusion_matrix(model, dataloader, device, threshold=0.5, is_autoencoder=False, target=1):
    
    """ 
    Compute and print the confusion matrix for a given model. 
    """
    
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for adj_tensor, X_matrix, Y_matrix in dataloader:
            adj_tensor = adj_tensor.to(device)
            X_matrix = X_matrix.to(device)

            if is_autoencoder:
                labels = X_matrix[:, :, -target]
                classification_output = model(X_matrix, adj_tensor)
                preds = (classification_output > threshold).float()
            else:
                Y_matrix = Y_matrix.to(device)
                labels = Y_matrix
                classification_output = model(X_matrix, adj_tensor)
                preds = (classification_output> threshold).float() 

            all_preds.extend(preds.cpu().numpy().flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    cm = confusion_matrix(all_labels, all_preds)
    return cm

def plot_confusion_matrix(cm, title="Confusion Matrix", filename=None):
    cm_normalized = cm.astype('float') / cm.sum(axis=1, keepdims=True)

    plt.figure(figsize=(4,4))
    sns.heatmap(cm_normalized, cbar=False, annot=True, fmt=".2f", cmap="Blues", 
                xticklabels=["Negative", "Positive"], 
                yticklabels=["Negative", "Positive"], 
                vmin=0, vmax=1)  

    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title(title)
    if filename:
        plt.savefig(filename, dpi=300, bbox_inches='tight', transparent=True)  
    plt.show()

cm1 = compute_confusion_matrix(model, test_dataloader, device, threshold=0.5)
cm2 = compute_confusion_matrix(model_nopinn, test_dataloader, device, threshold=0.5)
cm3 = compute_confusion_matrix(ae_sn, test_dataloader, device, threshold=0.003, is_autoencoder=True, target=1)
cm4 = compute_confusion_matrix(ae_en, test_dataloader, device, threshold=0.00020, is_autoencoder=True, target=2)

plot_confusion_matrix(cm1, title=r'$\mathcal{M}_1,\quad\Psi=1$', filename='Exports/cm_m1_pinn.png')
plot_confusion_matrix(cm2, title=r'$\mathcal{M}_1,\quad\Psi=0$', filename='Exports/cm_m1_nopinn.png')
plot_confusion_matrix(cm3, title=r'$\mathcal{M}_2,\quad\Psi=0$', filename='Exports/cm_m3.png')
plot_confusion_matrix(cm4, title=r'$\mathcal{M}_3,\quad\Psi=0$', filename='Exports/cm_m4.png')