# This notebook is used to test building the Graph Attention Network

##### Dependencies: PyTorch (`torch`, `torch.nn`); the `LayerType` constants from `utils` are not used in this notebook

In [2]:
import torch
import torch.nn as nn

# Base GAT layer: parameters, initialization and shared output post-processing

In [3]:
class GATLayerBase(torch.nn.Module):
    """
    Base class for a multi-head Graph Attention (GAT) layer.

    Holds the trainable parameters shared by concrete GAT implementations: the linear projection,
    the per-head source/target scoring vectors, an optional bias and an optional skip projection.
    Also provides the common output post-processing in ``skip_concat_bias``: skip connection,
    head concatenation/averaging, bias and activation. Subclasses implement the attention
    computation and neighborhood aggregation themselves.

    Parameters
    ----------
    num_in_features : int
        FIN, number of input features per node.
    num_out_features : int
        FOUT, number of output features per node *per head*.
    num_of_heads : int
        NH, number of attention heads.
    concat : bool
        If True, head outputs are concatenated (width NH*FOUT); otherwise they are averaged (width FOUT).
    activation : nn.Module or None
        Applied to the final output; None disables it.
    dropout_prob : float
        Probability used by ``self.dropout``.
    add_skip_connection : bool
        Whether ``skip_concat_bias`` adds a residual connection.
    bias : bool
        Whether to learn an additive output bias.
    log_attention_weights : bool
        If True, ``skip_concat_bias`` caches the attention coefficients in ``self.attention_weights``.
    """

    # Axis conventions for node-feature tensors shaped (N, NH, FOUT). They are defined here so that
    # skip_concat_bias also works on the base class; subclasses may override them.
    nodes_dim = 0  # node dimension/axis
    head_dim = 1   # attention head dimension/axis

    def __init__(self, num_in_features, num_out_features, num_of_heads, concat=True, activation=nn.ELU(),
                 dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False):
        super().__init__()
        self.num_of_heads = num_of_heads
        self.num_out_features = num_out_features
        self.concat = concat  # attention heads aggregation method (concatenation/mean)
        self.add_skip_connection = add_skip_connection

        # Trainable parameters: linear projection, per-head scoring vectors, bias and skip projection.

        # Linear projection FIN -> NH*FOUT: one FOUT-wide projection per attention head,
        # computed as a single matrix multiplication.
        self.linear_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False)

        # Instead of concatenating [x, y] (x/y are node feature vectors) and taking a dot product with
        # the attention vector "a", we take x . a_left and y . a_right separately and sum them up.
        self.scoring_fn_target = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))
        self.scoring_fn_source = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))

        if bias and concat:
            self.bias = nn.Parameter(torch.Tensor(num_of_heads * num_out_features))
        elif bias and not concat:
            self.bias = nn.Parameter(torch.Tensor(num_out_features))
        else:
            self.register_parameter('bias', None)

        if add_skip_connection:
            self.skip_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False)
        else:
            self.register_parameter('skip_proj', None)

        self.leakyReLU = nn.LeakyReLU(0.2)  # negative slope 0.2 as in the paper of Veličković P. et al. (2018)
        self.softmax = nn.Softmax(dim=-1)  # softmax over the last dimension
        self.activation = activation  # chosen by user; None disables the output non-linearity
        self.dropout = nn.Dropout(p=dropout_prob)

        self.log_attention_weights = log_attention_weights  # whether we should log the attention weights
        self.attention_weights = None  # cached attention weights for later visualization

        self.init_params()

    def init_params(self, layer_type=None):
        """
        Initialize the weights.

        Uses Xavier-uniform for the linear projection and the two scoring tensors, and zeros for the
        bias (when present). ``layer_type`` is unused. It is optional so that ``__init__`` can call
        ``init_params()`` without arguments; callers that pass it still work.
        """
        nn.init.xavier_uniform_(self.linear_proj.weight)
        nn.init.xavier_uniform_(self.scoring_fn_target)
        nn.init.xavier_uniform_(self.scoring_fn_source)

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def skip_concat_bias(self, attention_coefficients, in_nodes_features, out_nodes_features):
        """
        Shared output post-processing for a GAT layer.

        Steps: optional attention logging, skip connection, head concatenation/averaging, bias and
        activation.

        Parameters
        ----------
        attention_coefficients : Tensor
            Cached in ``self.attention_weights`` when ``log_attention_weights`` is True.
        in_nodes_features : Tensor, expected (N, FIN)
        out_nodes_features : Tensor, expected (N, NH, FOUT)
            If it is already contiguous, the skip connection modifies it in place.

        Returns
        -------
        Tensor of shape (N, NH*FOUT) if ``concat`` else (N, FOUT), passed through ``activation`` when set.
        """
        if self.log_attention_weights:  # keep the weights for later visualization
            self.attention_weights = attention_coefficients

        # view() below needs contiguous memory; some implementations hand over non-contiguous results.
        if not out_nodes_features.is_contiguous():
            out_nodes_features = out_nodes_features.contiguous()

        if self.add_skip_connection:  # add skip or residual connection
            if out_nodes_features.shape[-1] == in_nodes_features.shape[-1]:  # if FIN == FOUT
                # (N, 1, FIN) broadcasts across all NH heads
                out_nodes_features += in_nodes_features.unsqueeze(1)
            else:
                # FIN != FOUT: project the input to NH*FOUT first
                out_nodes_features += self.skip_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)

        if self.concat:
            # shape = (N, NH, FOUT) -> (N, NH*FOUT)
            out_nodes_features = out_nodes_features.view(-1, self.num_of_heads * self.num_out_features)
        else:
            # shape = (N, NH, FOUT) -> (N, FOUT)
            out_nodes_features = out_nodes_features.mean(dim=self.head_dim)

        if self.bias is not None:
            out_nodes_features += self.bias

        return out_nodes_features if self.activation is None else self.activation(out_nodes_features)
        

## Concrete GAT layer, following the implementation of Gordić (2020) and the theory of Veličković P. et al. (2018)

In [4]:
class Gat(GATLayerBase):
    """
    Concrete edge-list GAT layer, following Gordić (2020) and Veličković P. et al. (2018).

    Parameters, initialization and the skip/concat/bias post-processing are inherited from
    ``GATLayerBase``; the constructor signature is identical to the base class.

    The layer consumes ``data = (in_nodes_features, edge_index)``:
    - ``in_nodes_features`` has shape (N, FIN);
    - ``edge_index`` has shape (2, E), with row 0 = source nodes and row 1 = target nodes.

    It returns ``(out_nodes_features, edge_index)``, so layers can be chained.
    """
    src_nodes_dim = 0  # position of source nodes in edge index
    trg_nodes_dim = 1  # position of target nodes in edge index

    nodes_dim = 0      # node dimension/axis
    head_dim = 1       # attention head dimension/axis

    def __init__(self, num_in_features, num_out_features, num_of_heads, concat=True, activation=nn.ELU(),
                 dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False):
        # Delegate initialization to the base class. This requires Gat to actually inherit from it;
        # with a bare `class Gat():` super() resolves to object.__init__, which rejects these arguments.
        super().__init__(num_in_features, num_out_features, num_of_heads, concat, activation, dropout_prob,
                         add_skip_connection, bias, log_attention_weights)

    def forward(self, data):
        """nn.Module entry point so the layer can be called as ``layer(data)``; delegates to fit_forward."""
        return self.fit_forward(data)

    def fit_forward(self, data):
        """Run one GAT layer on ``(in_nodes_features, edge_index)`` and return ``(out_nodes_features, edge_index)``."""
        # Step 1: linear projection + dropout regularization

        in_nodes_features, edge_index = data  # unpack data
        num_of_nodes = in_nodes_features.shape[self.nodes_dim]
        assert edge_index.shape[0] == 2, f'Expected edge index with shape=(2,E) got {edge_index.shape}'

        # shape = (N, FIN) where N - number of nodes in the graph, FIN - number of input features per node
        in_nodes_features = self.dropout(in_nodes_features)
        # shape = (N, FIN) -> (N, NH, FOUT)
        nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)
        nodes_features_proj = self.dropout(nodes_features_proj)  # the official GAT implementation also applies dropout here

        # Step 2: edge attention calculation

        # Scoring function (* is the element-wise, a.k.a. Hadamard, product)
        # shape = (N, NH, FOUT) * (1, NH, FOUT) -> (N, NH) because sum reduces the last dimension
        scores_source = (nodes_features_proj * self.scoring_fn_source).sum(dim=-1)
        scores_target = (nodes_features_proj * self.scoring_fn_target).sum(dim=-1)

        # Copy ("lift") the per-node scores onto the edges listed in edge_index. Only the scores that
        # are actually used get materialized, instead of all N*N node pairs.
        # scores shape = (E, NH), nodes_features_proj_lifted shape = (E, NH, FOUT), E - number of edges
        scores_source_lifted, scores_target_lifted, nodes_features_proj_lifted = self.lift(scores_source, scores_target, nodes_features_proj, edge_index)
        scores_per_edge = self.leakyReLU(scores_source_lifted + scores_target_lifted)

        # shape = (E, NH, 1)
        attentions_per_edge = self.neighborhood_attention_softmax(scores_per_edge, edge_index[self.trg_nodes_dim], num_of_nodes)
        # Add stochasticity to neighborhood aggregation
        attentions_per_edge = self.dropout(attentions_per_edge)

        # Step 3: neighborhood aggregation

        # shape = (E, NH, FOUT) * (E, NH, 1) -> (E, NH, FOUT), the trailing 1 broadcasts over FOUT
        nodes_features_proj_lifted_weighted = nodes_features_proj_lifted * attentions_per_edge

        # Sum the weighted, projected neighbor feature vectors into every target node: shape = (N, NH, FOUT)
        out_nodes_features = self.aggregate_neighbors(nodes_features_proj_lifted_weighted, edge_index, in_nodes_features, num_of_nodes)

        # Step 4: residual/skip connections, concat and bias
        out_nodes_features = self.skip_concat_bias(attentions_per_edge, in_nodes_features, out_nodes_features)
        return (out_nodes_features, edge_index)

    # Helper functions

    def neighborhood_attention_softmax(self, scores_per_edge, trg_index, num_of_nodes):
        """Softmax of the edge scores over each target node's incoming edges; returns shape (E, NH, 1)."""
        # Subtracting the global max keeps exp() from overflowing. Softmax is invariant to a constant
        # shift, so every neighborhood's result is unchanged.
        scores_per_edge = scores_per_edge - scores_per_edge.max()
        exp_scores_per_edge = scores_per_edge.exp()

        neighborhood_attention_denominator = self.sum_edge_scores_neighborhood_attention(exp_scores_per_edge, trg_index, num_of_nodes)
        # 1e-16 guards against division by zero
        attentions_per_edge = exp_scores_per_edge / (neighborhood_attention_denominator + 1e-16)
        return attentions_per_edge.unsqueeze(-1)

    def sum_edge_scores_neighborhood_attention(self, exp_scores_per_edge, trg_index, num_of_nodes):
        """Per-edge value: the sum of exp-scores over all edges that share this edge's target node."""
        trg_index_broadcasted = self.explicit_broadcast(trg_index, exp_scores_per_edge)
        size = list(exp_scores_per_edge.shape)  # list so the node axis can be reassigned
        size[self.nodes_dim] = num_of_nodes
        neighborhood_sums = torch.zeros(size, dtype=exp_scores_per_edge.dtype, device=exp_scores_per_edge.device)
        neighborhood_sums.scatter_add_(self.nodes_dim, trg_index_broadcasted, exp_scores_per_edge)
        # Map the per-node sums back onto the edges
        return neighborhood_sums.index_select(self.nodes_dim, trg_index)

    def aggregate_neighbors(self, nodes_features_proj_lifted_weighted, edge_index, in_nodes_features, num_of_nodes):
        """Scatter-add edge messages (E, NH, FOUT) into their target nodes, giving shape (N, NH, FOUT)."""
        size = list(nodes_features_proj_lifted_weighted.shape)
        size[self.nodes_dim] = num_of_nodes
        out_nodes_features = torch.zeros(size, dtype=in_nodes_features.dtype, device=in_nodes_features.device)
        trg_index_broadcasted = self.explicit_broadcast(edge_index[self.trg_nodes_dim], nodes_features_proj_lifted_weighted)
        out_nodes_features.scatter_add_(self.nodes_dim, trg_index_broadcasted, nodes_features_proj_lifted_weighted)
        return out_nodes_features

    def lift(self, scores_source, scores_target, nodes_features_matrix_proj, edge_index):
        """Gather source/target scores and source projected features for every edge."""
        src_nodes_index = edge_index[self.src_nodes_dim]
        trg_nodes_index = edge_index[self.trg_nodes_dim]
        scores_source = scores_source.index_select(self.nodes_dim, src_nodes_index)
        scores_target = scores_target.index_select(self.nodes_dim, trg_nodes_index)
        nodes_features_matrix_proj_lifted = nodes_features_matrix_proj.index_select(self.nodes_dim, src_nodes_index)
        return scores_source, scores_target, nodes_features_matrix_proj_lifted

    def explicit_broadcast(self, this, other):
        """Append trailing singleton dims to ``this`` until it has ``other``'s rank, then expand it to ``other``'s shape."""
        for _ in range(this.dim(), other.dim()):
            this = this.unsqueeze(-1)
        return this.expand_as(other)

# References
1. Gordić, A. (2020). *pytorch-GAT*. GitHub repository (publisher: GitHub). https://github.com/gordicaleksa/pytorch-GAT — BibTeX key `Gordić2020PyTorchGAT`.
2. Veličković P. et al. (2018). Graph Attention Networks. ICLR 2018.
    

### Building on Causal Attention Learning (CAL; Wang X. et al., 2022)
### We build a GAT with Causal Attention Learning: the model separates causal attention from trivial attention and uses the causal part to predict stroke

##### CALGAT Implementation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
import os
from torch_geometric.nn import GATConv  # Import PyG's GATConv

##### Step 1: Define the GNN encoder with PyTorch Geometric's `GATConv`, following the CAL architecture diagram of Wang et al. (2022)

In [2]:
class GNNEncoder(nn.Module):
    """
    Two-layer GAT encoder built on PyG's GATConv.

    - The first layer concatenates its ``num_of_heads`` heads (width ``num_hidden_features * num_of_heads``).
    - The second layer averages its heads (``concat=False``), so the encoder emits ``num_hidden_features``
      channels per node. This is the width the downstream CAL_GAT modules are built for, and it mirrors
      the head-averaging output layer of the original GAT (Veličković P. et al., 2018).

    ``dropout_prob`` is used both for GATConv's attention dropout and for the feature dropout in ``forward``.
    """
    def __init__(self, num_in_features, num_hidden_features, num_of_heads, dropout_prob, alpha):
        super(GNNEncoder, self).__init__()
        self.dropout_prob = dropout_prob
        self.out_dim = num_hidden_features  # per-node output width of forward()
        self.gat1 = GATConv(num_in_features, num_hidden_features, heads=num_of_heads,
                            dropout=dropout_prob, negative_slope=alpha)
        # concat=False averages the heads. With the default concat=True the output would be
        # num_hidden_features * num_of_heads wide and would not fit the modules sized for num_hidden_features.
        self.gat2 = GATConv(num_hidden_features * num_of_heads, num_hidden_features, heads=num_of_heads,
                            concat=False, dropout=dropout_prob, negative_slope=alpha)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over ``edge_index``; returns node embeddings of width ``self.out_dim``."""
        h = F.dropout(x, p=self.dropout_prob, training=self.training)  # dropout on input features
        h = self.gat1(h, edge_index)
        h = F.elu(h)  # ELU activation as in the original GAT paper (Veličković P. et al., 2018)
        h = F.dropout(h, p=self.dropout_prob, training=self.training)  # dropout after the first GAT layer
        h = self.gat2(h, edge_index)
        return h

##### Step 2: Define Node Attention Functions and Edge Attention Functions

In [3]:
class NodeAttention(nn.Module):
    """
    Node-level soft mask that splits every node into a causal and a trivial part.

    A two-layer MLP maps each node embedding to two logits. A softmax across those two logits
    yields complementary weights that sum to 1 for every node.
    """
    def __init__(self, input_dim, dropout_prob=0.6):
        super().__init__()
        bottleneck = input_dim // 2
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, bottleneck),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(bottleneck, 2),  # logit 0 -> causal, logit 1 -> trivial
        )

    def forward(self, node_features):
        """Return ``(causal_weights, trivial_weights)``, one weight per node (index 0 / 1 of dim 1)."""
        probabilities = F.softmax(self.mlp(node_features), dim=1)
        causal_weights = probabilities.select(1, 0)
        trivial_weights = probabilities.select(1, 1)
        return causal_weights, trivial_weights


class EdgeAttention(nn.Module):
    """
    Edge-level soft mask that splits every edge into a causal and a trivial part.

    Each edge is represented by the concatenation [source embedding, destination embedding].
    An MLP maps it to two logits, and a softmax turns them into complementary weights.
    """
    def __init__(self, input_dim, dropout_prob=0.6):
        super().__init__()
        pair_dim = input_dim * 2
        self.mlp = nn.Sequential(
            nn.Linear(pair_dim, input_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(input_dim, 2),  # logit 0 -> causal, logit 1 -> trivial
        )

    def forward(self, node_features, edge_index):
        """Return ``(causal_weights, trivial_weights)``, one weight per column of ``edge_index``."""
        src_index, dst_index = edge_index[0], edge_index[1]
        # [num_edges, 2 * input_dim]
        pair_features = torch.cat((node_features[src_index], node_features[dst_index]), dim=1)
        probabilities = F.softmax(self.mlp(pair_features), dim=1)
        return probabilities.select(1, 0), probabilities.select(1, 1)

##### Step 3: Define a simple graph convolution layer (plain PyTorch, not PyTorch Geometric)

In [4]:
class GraphConv(nn.Module):
    """
    Minimal graph convolution used by the causal and trivial branches.

    The output is ``linear(x)``. When ``edge_attr`` is given, every edge e = (src -> dst) also adds
    ``edge_proj(edge_attr[e]) * x[src]`` to ``out[dst]``. The added term uses the *unprojected*
    source features, so weighted message passing requires ``in_features == out_features``.
    """
    def __init__(self, in_features, out_features):
        super(GraphConv, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.edge_proj = nn.Linear(1, 1, bias=False)  # learnable scalar rescaling of each edge weight

    def forward(self, x, edge_index, edge_attr=None):
        """
        Parameters
        ----------
        x : Tensor [num_nodes, in_features]
        edge_index : LongTensor [2, num_edges] (row 0 = source, row 1 = destination)
        edge_attr : Tensor holding one scalar per edge ([num_edges] or [num_edges, 1]), or None

        Returns
        -------
        Tensor [num_nodes, out_features]

        Raises
        ------
        ValueError
            If ``edge_attr`` does not contain exactly one value per edge.
        """
        out = self.linear(x)
        if edge_attr is None:
            return out

        src, dst = edge_index
        edge_weights = self.edge_proj(edge_attr.reshape(-1, 1)).reshape(-1)
        if edge_weights.numel() != src.numel():
            raise ValueError(
                f"GraphConv expects one edge weight per edge, got {edge_weights.numel()} weights "
                f"for {src.numel()} edges (edge_attr shape {tuple(edge_attr.shape)})."
            )
        # Vectorized scatter-add of the weighted source features into their destination rows. This
        # replaces a Python loop doing one in-place update per edge; duplicate destinations accumulate.
        messages = edge_weights.unsqueeze(1) * x[src]
        return out.index_add(0, dst, messages)

##### Alternative to `torch_scatter.scatter_mean` (avoids the `torch_scatter` dependency)

In [5]:
def scatter_mean_alternative(gated_x, batch):
    """
    Per-graph mean of node features, without the ``torch_scatter`` dependency.

    Args:
        gated_x: Tensor of node features, shape [N, ...].
        batch: Integer tensor of shape [N] assigning each node to a graph, or None.

    Returns:
        Tensor [num_graphs, ...]. Row k is the mean over the nodes whose batch id is the k-th
        smallest id present in ``batch``; ids that never occur produce no row. With ``batch=None``,
        a single row [1, ...] averaging all nodes is returned. Empty input yields a [0, ...] tensor.
    """
    if batch is None:
        return gated_x.mean(dim=0, keepdim=True)  # mean over all nodes if no batch info

    # Map every batch id to a dense group index 0..G-1, ordered like torch.unique's sorted output.
    # Then sum and count per group in one pass, instead of building one boolean mask per graph (O(N*G)).
    unique_batches, group_index = torch.unique(batch, return_inverse=True)
    num_groups = unique_batches.numel()

    sums = gated_x.new_zeros((num_groups,) + tuple(gated_x.shape[1:]))
    sums = sums.index_add(0, group_index, gated_x)  # out-of-place, so autograd flows to gated_x

    counts = torch.bincount(group_index, minlength=num_groups).to(gated_x.dtype)
    # Reshape the counts to [G, 1, ..., 1] so they broadcast over the feature dimensions
    counts = counts.view((-1,) + (1,) * (gated_x.dim() - 1))
    return sums / counts

##### Step 4: Define the readout function (gated mean pooling, inspired by the GAT architecture of Veličković P. et al., 2018)

In [6]:
class ReadoutFunction(nn.Module):
    """
    Gated mean-pooling readout that turns node embeddings into graph embeddings.

    Each node embedding is multiplied element-wise by a sigmoid gate computed from itself. The gated
    embeddings are then averaged per graph (according to ``batch``), or over all nodes when ``batch`` is None.
    """
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, input_dim)  # NOTE(review): registered but not used in forward
        self.gate = nn.Sequential(nn.Linear(input_dim, input_dim), nn.Sigmoid())

    def forward(self, x, batch=None):
        """Return pooled embeddings: [num_graphs, dim] with ``batch``, else [1, dim]."""
        gated_x = x * self.gate(x)
        if batch is None:
            # Single graph: one pooled row, kept 2-D for the classifier
            return torch.mean(gated_x, dim=0, keepdim=True)
        return scatter_mean_alternative(gated_x, batch)

##### Step 5: Define Classifier for graph classification based on CAL paper

In [7]:
class Classifier(nn.Module):
    """Two-layer MLP head that maps a graph embedding to class logits (hidden width in_features // 2)."""
    def __init__(self, in_features, num_classes):
        super().__init__()
        hidden = in_features // 2
        layers = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized logits of shape [..., num_classes]."""
        logits = self.mlp(x)
        return logits


class CAL_GAT(nn.Module):
    """
    Graph Attention Network with Causal Attention Learning (CAL, Wang X. et al., 2022).

    Pipeline:
    1. A GAT encoder produces node embeddings.
    2. Node and edge attention split the graph into a causal part and a trivial part.
    3. Each part goes through its own GraphConv and gated readout.
    4. Classifiers produce the causal, trivial and combined predictions.

    During training, a causal intervention pairs each graph's causal embedding with another graph's
    trivial embedding (``shuffled_h_G_t``) before the combined classifier.

    Parameters
    ----------
    num_in_features : int       node input feature width
    num_hidden_features : int   node embedding width expected from the encoder (hidden_dim)
    num_out_features : int      number of classes (logit width)
    num_of_heads : int          attention heads in the encoder
    dropout_prob : float        dropout probability passed to the submodules
    alpha : float               LeakyReLU negative slope of the encoder's GATConv layers
    lambda1 : float             weight of the uniform (trivial-branch) loss
    lambda2 : float             weight of the causal-intervention loss
    """
    def __init__(self, num_in_features, num_hidden_features, num_out_features, num_of_heads=8,
                 dropout_prob=0.6, alpha=0.2, lambda1=0.1, lambda2=0.1):
        super(CAL_GAT, self).__init__()

        # GNN Encoder with PyG's GATConv
        self.gnn_encoder = GNNEncoder(num_in_features, num_hidden_features, num_of_heads, dropout_prob, alpha)

        # All downstream modules assume the encoder emits num_hidden_features channels per node
        hidden_dim = num_hidden_features

        # Node and edge attention
        self.node_attention = NodeAttention(hidden_dim, dropout_prob)
        self.edge_attention = EdgeAttention(hidden_dim, dropout_prob)

        # GraphConv layers for causal and trivial branches
        self.graph_conv_causal = GraphConv(hidden_dim, hidden_dim)
        self.graph_conv_trivial = GraphConv(hidden_dim, hidden_dim)

        # Readout functions
        self.readout_causal = ReadoutFunction(hidden_dim)
        self.readout_trivial = ReadoutFunction(hidden_dim)

        # Classifiers. The combined classifier consumes [causal ; trivial] concatenated, i.e. 2 * hidden_dim.
        self.classifier_causal = Classifier(hidden_dim, num_out_features)
        self.classifier_trivial = Classifier(hidden_dim, num_out_features)
        self.classifier_combined = Classifier(hidden_dim * 2, num_out_features)

        # Hyperparameters for loss functions
        self.lambda1 = lambda1  # Controls disentanglement strength
        self.lambda2 = lambda2  # Controls causal intervention strength

    @staticmethod
    def _weight_edges(edge_attr, edge_weights):
        """Scale ``edge_attr`` (leading dim = edges) by per-edge weights ``edge_weights`` [num_edges]; None passes through."""
        if edge_attr is None:
            return None
        # Reshape the weights to [num_edges, 1, ..., 1] so they broadcast along edge_attr's feature dims.
        # For a 1-D edge_attr, `edge_attr * w.unsqueeze(1)` would instead form an [E, E] outer product.
        shape = (-1,) + (1,) * (edge_attr.dim() - 1)
        return edge_attr * edge_weights.view(shape)

    def forward(self, data, shuffled_h_G_t=None):
        """
        Parameters
        ----------
        data : PyG Data/Batch
            Must provide ``x`` and ``edge_index``; ``edge_attr`` and ``batch`` are optional.
        shuffled_h_G_t : Tensor [num_graphs, hidden_dim] or None
            Trivial graph embeddings taken from other graphs (training-time intervention). When None,
            the combined classifier sees each graph's own trivial embedding.

        Returns
        -------
        (z_G_c, z_G_t, z_G_prime, h_G_c, h_G_t)
            Causal, trivial and combined logits, followed by the causal and trivial graph embeddings.
        """
        x, edge_index = data.x, data.edge_index
        edge_attr = getattr(data, 'edge_attr', None)
        batch = getattr(data, 'batch', None)

        # Step 1: Encode graph with GNN
        H = self.gnn_encoder(x, edge_index)

        # Step 2: Compute node and edge attention scores
        alpha_c, alpha_t = self.node_attention(H)  # Node-level attention scores
        beta_c, beta_t = self.edge_attention(H, edge_index)  # Edge-level attention scores

        # Step 3: Create attended representations (node masking + edge re-weighting)
        H_c = H * alpha_c.unsqueeze(1)  # Causal attended nodes
        H_t = H * alpha_t.unsqueeze(1)  # Trivial attended nodes
        edge_attr_c = self._weight_edges(edge_attr, beta_c)  # beta_c as edge weights for causal
        edge_attr_t = self._weight_edges(edge_attr, beta_t)  # beta_t as edge weights for trivial

        # Step 4: Process through GraphConv layers
        G_c = self.graph_conv_causal(H_c, edge_index, edge_attr_c)
        G_t = self.graph_conv_trivial(H_t, edge_index, edge_attr_t)

        # Step 5: Readout functions to get graph-level representations
        h_G_c = self.readout_causal(G_c, batch)
        h_G_t = self.readout_trivial(G_t, batch)

        # Step 6: Predictions (causal and trivial branches)
        z_G_c = self.classifier_causal(h_G_c)  # Causal prediction
        z_G_t = self.classifier_trivial(h_G_t)  # Trivial prediction

        # Step 7: Causal intervention and combined prediction. classifier_combined takes 2 * hidden_dim
        # inputs, so the training path (shuffled trivial features) and the inference path (the graph's
        # own trivial features) both concatenate. Adding them would produce hidden_dim and crash.
        trivial_part = h_G_t if shuffled_h_G_t is None else shuffled_h_G_t
        z_G_prime = self.classifier_combined(torch.cat([h_G_c, trivial_part], dim=1))

        return z_G_c, z_G_t, z_G_prime, h_G_c, h_G_t

    @staticmethod
    def _supervised_loss(logits, labels):
        """Cross-entropy for class-index labels ([B] or [B, 1]); BCE-with-logits for any other label shape."""
        if len(labels.shape) == 1 or labels.shape[1] == 1:
            return F.cross_entropy(logits, labels.view(-1))
        return F.binary_cross_entropy_with_logits(logits, labels.float())

    def compute_losses(self, z_G_c, z_G_t, z_G_prime, labels, num_classes):
        """
        Compute the CAL objective ``total = L_sup + lambda1 * L_unif + lambda2 * L_caus``.

        - L_sup:  supervised loss on the causal logits ``z_G_c``.
        - L_unif: KL(uniform || softmax(z_G_t)), which pushes the trivial branch towards uninformative predictions.
        - L_caus: supervised loss on the intervened/combined logits ``z_G_prime``.

        Returns (total_loss, L_sup, L_unif, L_caus).
        """
        L_sup = self._supervised_loss(z_G_c, labels)

        # The KL target must be a probability distribution: 1/num_classes per class. An all-ones target
        # would make the value num_classes * CE(uniform, p), whose minimum is num_classes*log(num_classes), not 0.
        uniform_target = torch.full((z_G_t.size(0), num_classes), 1.0 / num_classes,
                                    dtype=z_G_t.dtype, device=z_G_t.device)
        L_unif = F.kl_div(F.log_softmax(z_G_t, dim=1), uniform_target, reduction='batchmean')

        L_caus = self._supervised_loss(z_G_prime, labels)

        total_loss = L_sup + self.lambda1 * L_unif + self.lambda2 * L_caus
        return total_loss, L_sup, L_unif, L_caus

##### End of CALGAT

##### Execution and helper functions

In [8]:
def prepare_graph_data(adjacency_matrix, covariance_features):
    """
    Convert one graph's adjacency matrix and per-node covariance matrices into a PyG ``Data`` object.

    Parameters
    ----------
    adjacency_matrix : numpy.ndarray, 2-D
        Every strictly positive entry (i, j) becomes a directed edge i -> j, whose ``edge_attr`` is that
        entry. Zero and negative entries are dropped.
    covariance_features : array-like
        ``covariance_features[i]`` is node i's matrix; assumes shape (num_nodes, F, F) — TODO confirm
        against the data files.

    Returns
    -------
    Data with:
    - ``x`` [num_nodes, F + F*(F-1)/2]: the diagonal followed by the strict upper triangle, row-major;
    - ``edge_index`` [2, E] (long);
    - ``edge_attr`` [E] (float).
    """
    # Edges are the positive entries of the adjacency matrix
    edges = np.where(adjacency_matrix > 0)
    edge_index = torch.tensor(np.vstack(edges), dtype=torch.long)
    edge_attr = torch.tensor(adjacency_matrix[edges], dtype=torch.float)

    covariance_features = np.asarray(covariance_features)
    feature_dim = covariance_features.shape[1]
    # The strict upper-triangle indices are the same for every node, so compute them once
    # and extract all nodes at once instead of looping node by node.
    rows, cols = np.triu_indices(feature_dim, k=1)
    diag_features = np.diagonal(covariance_features, axis1=1, axis2=2)  # variances, [num_nodes, F]
    triu_features = covariance_features[:, rows, cols]                  # covariances, [num_nodes, F*(F-1)/2]
    node_features = np.concatenate([diag_features, triu_features], axis=1)
    x = torch.tensor(node_features, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def prepare_dataset(adjacency_matrices, covariance_features_list, labels):
    """
    Build a list of PyG ``Data`` graphs, one per (adjacency, covariance, label) triple.

    Parameters
    ----------
    adjacency_matrices : sequence of numpy.ndarray
    covariance_features_list : sequence of numpy.ndarray
    labels : sequence of int-like
        One graph label per graph; stored as ``data.y = LongTensor([label])``.

    Raises
    ------
    ValueError
        If the three inputs do not have the same length. Otherwise ``zip`` would silently drop extra
        labels, or indexing would fail, leaving graphs and labels misaligned.
    """
    num_graphs = len(adjacency_matrices)
    if len(covariance_features_list) != num_graphs or len(labels) != num_graphs:
        raise ValueError(
            f"Expected one covariance array and one label per graph, got {num_graphs} adjacency matrices, "
            f"{len(covariance_features_list)} covariance arrays and {len(labels)} labels."
        )

    dataset = []
    for adj, cov, label in zip(adjacency_matrices, covariance_features_list, labels):
        data = prepare_graph_data(adj, cov)
        data.y = torch.tensor([label], dtype=torch.long)  # graph-level label
        dataset.append(data)
    return dataset


def load_local_data(data_path):
    """
    Read graph inputs and labels from ``.npy`` files in a directory.

    Expected files inside ``data_path``:
    - ``labels.npy``: array of graph labels (loaded first; ``np.load`` raises if it is missing);
    - ``adj_matrix_graph_{i}.npy`` and ``covar_features_graph_{i}.npy`` for i = 0, 1, 2, ...

    Graphs are read in index order until the first i for which either file is missing.

    Parameters
    ----------
    data_path : str
        Directory containing the data files.

    Returns
    -------
    (adjacency_matrices, covariance_features_list, labels)
        Two lists of numpy arrays and the labels array.

    Raises
    ------
    FileNotFoundError
        If no complete graph (i = 0) is found.
    """
    labels = np.load(os.path.join(data_path, 'labels.npy'))

    adjacency_matrices, covariance_features_list = [], []
    graph_index = 0
    while True:
        adj_file = os.path.join(data_path, f'adj_matrix_graph_{graph_index}.npy')
        cov_file = os.path.join(data_path, f'covar_features_graph_{graph_index}.npy')
        both_present = os.path.exists(adj_file) and os.path.exists(cov_file)
        if not both_present:
            break  # first gap ends the scan
        adjacency_matrices.append(np.load(adj_file))
        covariance_features_list.append(np.load(cov_file))
        graph_index += 1

    if len(adjacency_matrices) == 0:
        raise FileNotFoundError(f"No adjacency or covariance feature files found in: {data_path}. "
                                f"Make sure files are named as 'adj_matrix_graph_i.npy' and "
                                f"'covar_features_graph_i.npy' and 'labels.npy' are in the directory.")
    return adjacency_matrices, covariance_features_list, labels


# ------------------------------------------------------------------------
# Training and evaluation functions (modified for causal intervention)
# ------------------------------------------------------------------------
def train_cal_gat(model, train_loader, optimizer, device, num_classes):
    """
    Run one training epoch of CAL_GAT, with a causal intervention in every mini-batch.

    For each batch:
    1. A first forward pass yields the trivial graph embeddings.
    2. These embeddings are randomly permuted across the batch.
    3. A second forward pass uses the permuted embeddings for the combined (intervened) prediction.
    The loss of the second pass is back-propagated.

    Returns the epoch loss averaged per graph (sum of batch_loss * num_graphs / dataset size).
    """
    model.train()
    running_loss = 0

    for graph_batch in train_loader:
        graph_batch = graph_batch.to(device)
        optimizer.zero_grad()

        # First pass: only the trivial graph embeddings (5th output) are needed here
        trivial_embeddings = model(graph_batch)[4]

        # Intervention: pair each causal embedding with the trivial embedding of a random graph in the batch
        permutation = torch.randperm(graph_batch.num_graphs)
        intervened_trivial = trivial_embeddings[permutation]

        # Second pass: predictions with the shuffled trivial embeddings
        z_causal, z_trivial, z_intervened, _, _ = model(graph_batch, shuffled_h_G_t=intervened_trivial)

        loss = model.compute_losses(z_causal, z_trivial, z_intervened, graph_batch.y, num_classes)[0]

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * graph_batch.num_graphs

    return running_loss / len(train_loader.dataset)


def evaluate_cal_gat(model, loader, device, num_classes):
    """
    Evaluate CAL_GAT without causal intervention.

    The loss is the full CAL objective from ``model.compute_losses``, averaged per graph. Predictions
    and accuracy use the causal branch (``argmax`` of ``z_G_c``).

    Returns (mean_loss, accuracy, predictions, labels); the last two are flat lists over all graphs.
    """
    model.eval()
    predictions = []
    targets = []
    summed_loss = 0

    with torch.no_grad():
        for graph_batch in loader:
            graph_batch = graph_batch.to(device)

            # No shuffled trivial features at evaluation time
            z_causal, z_trivial, z_combined, _, _ = model(graph_batch)

            batch_loss = model.compute_losses(z_causal, z_trivial, z_combined, graph_batch.y, num_classes)[0]
            summed_loss += batch_loss.item() * graph_batch.num_graphs

            # Causal-branch prediction (use z_combined here to evaluate the combined classifier instead)
            predictions.extend(z_causal.argmax(dim=1).cpu().numpy())
            targets.extend(graph_batch.y.cpu().numpy().flatten())

    accuracy = accuracy_score(targets, predictions)
    return summed_loss / len(loader.dataset), accuracy, predictions, targets


# ------------------------------------------------------------------------
# Main execution
# ------------------------------------------------------------------------

def main(data_path='path/to/your/local/data'):
    """
    End-to-end CAL_GAT pipeline.

    Steps: load the data, build the dataset, train with validation-based model selection, evaluate
    the best checkpoint on the test split, and plot the learning curves.

    Parameters
    ----------
    data_path : str
        Directory with ``labels.npy`` and the per-graph ``adj_matrix_graph_i.npy`` /
        ``covar_features_graph_i.npy`` files (see ``load_local_data``). The default is a placeholder.
    """
    # Set random seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Parameters
    num_classes = 3
    hidden_dim = 32
    heads = 8  # number of attention heads in the GATConv encoder
    dropout = 0.6  # as in the original GAT paper
    alpha = 0.2  # negative slope of LeakyReLU in GATConv
    learning_rate = 0.005
    epochs = 50
    batch_size = 32

    # 1. Load data from local files
    print(f"Loading data from: {data_path}")
    try:
        adj_matrices, covar_features, labels = load_local_data(data_path)
    except FileNotFoundError as e:
        print(f"Error loading data: {e}")
        return

    # 2. Prepare dataset
    print("Preparing dataset...")
    dataset = prepare_dataset(adj_matrices, covar_features, labels)

    # Get feature dimensions
    input_dim = dataset[0].x.size(1)
    print(f"Input feature dimension: {input_dim}")

    # 3. Split data into train/val/test (64% / 16% / 20%)
    train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42)
    train_dataset, val_dataset = train_test_split(train_dataset, test_size=0.2, random_state=42)

    print(f"Train set: {len(train_dataset)}, Validation set: {len(val_dataset)}, Test set: {len(test_dataset)}")

    # 4. Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # 5. Initialize the model
    print("Initializing CAL_GAT model...")
    model = CAL_GAT(
        num_in_features=input_dim,
        num_hidden_features=hidden_dim,
        num_out_features=num_classes,
        num_of_heads=heads,
        dropout_prob=dropout,
        alpha=alpha,
        lambda1=0.5,  # Weight for uniform loss
        lambda2=1.0  # Weight for causal intervention loss
    ).to(device)

    # 6. Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-5
    )

    # 7. Training loop
    print("Starting training...")
    train_losses = []
    val_losses = []
    val_accuracies = []
    best_val_acc = -1.0  # below any achievable accuracy, so the first epoch is always checkpointed
    best_model = None

    for epoch in range(epochs):
        # Train
        train_loss = train_cal_gat(model, train_loader, optimizer, device, num_classes)
        train_losses.append(train_loss)

        # Validate
        val_loss, val_acc, _, _ = evaluate_cal_gat(model, val_loader, device, num_classes)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        # Update learning rate
        scheduler.step(val_loss)

        # Save best model. state_dict() holds references to the live tensors and dict.copy() is
        # shallow, so clone them; otherwise later optimizer steps silently overwrite the checkpoint.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # 8. Load best model and evaluate on test set
    print("\nEvaluating on test set...")
    if best_model is not None:
        print(f"Restoring best model (validation accuracy {best_val_acc:.4f})")
        model.load_state_dict(best_model)
    test_loss, test_acc, test_preds, test_labels = evaluate_cal_gat(model, test_loader, device, num_classes)

    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # 9. Confusion matrix
    cm = confusion_matrix(test_labels, test_preds)
    print("\nConfusion Matrix:")
    print(cm)

    # 10. Plot training curves: losses on the left, validation accuracy on the right
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
    ax_loss.plot(train_losses, label='Train Loss')
    ax_loss.plot(val_losses, label='Validation Loss')
    ax_loss.set(xlabel='Epoch', ylabel='Losses', title='Training and validation loss')
    ax_loss.legend()

    ax_acc.plot(val_accuracies, label='Validation Accuracy')
    ax_acc.set(xlabel='Epoch', ylabel='Accuracy', title='Validation accuracy')
    ax_acc.legend()

    fig.tight_layout()
    plt.show()


if __name__ == '__main__':
    # In a Jupyter/IPython kernel __name__ is also '__main__', so running this cell executes the full pipeline.
    main()

Using device: cpu
Loading data from: path/to/your/local/data
Error loading data: [Errno 2] No such file or directory: 'path/to/your/local/data\\labels.npy'
