In [1]:
# import necessary packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric.utils
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv
from sklearn.model_selection import KFold, train_test_split
from sklearn.cluster import SpectralClustering
from karateclub import Graph2Vec
import networkx as nx
from tqdm import tqdm

from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr
from scipy.spatial.distance import jensenshannon
import torch
import networkx as nx
import numpy as np
from tqdm import tqdm

In [2]:
# set global variables
# Dataset sizes: number of train/test samples and nodes per low-/high-resolution matrix.
N_TRAIN = 167
N_TEST = 112
N_LR_NODES = 160
N_HR_NODES = 268
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the global NumPy and PyTorch RNGs so VAE sampling, weight initialisation and
# DataLoader shuffling are reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Helper Functions

## Functions and Class for loading in the data

In [3]:
class MatrixVectorizer:
    """
    Convert between symmetric matrices and their vectorised form.

    Entries are visited column by column. Within a column, the strictly-upper entries are
    taken top to bottom. When ``include_diagonal`` is True, the entry just below the main
    diagonal of that column (row == col + 1) follows them. The main diagonal itself is
    never stored.
    """

    def __init__(self):
        """
        Create a vectorizer.

        No state is kept; all methods are static. The constructor exists only for
        possible future configuration.
        """
        pass

    @staticmethod
    def vectorize(matrix, include_diagonal=False):
        """
        Flatten ``matrix`` into a 1-D array, column by column.

        Parameters:
        - matrix (numpy.ndarray): matrix to flatten; ``matrix.shape[0]`` is taken as its size.
        - include_diagonal (bool, optional): also collect ``matrix[col + 1, col]`` after each
          column's upper-triangle entries. Defaults to False.

        Returns:
        - numpy.ndarray: the collected entries in traversal order.
        """
        size = matrix.shape[0]
        collected = []
        for col in range(size):
            # strictly-upper entries of this column, top to bottom
            collected.extend(matrix[row, col] for row in range(col))
            # optional entry right below the diagonal, visited after the upper ones
            if include_diagonal and col + 1 < size:
                collected.append(matrix[col + 1, col])
        return np.array(collected)

    @staticmethod
    def anti_vectorize(vector, matrix_size, include_diagonal=False):
        """
        Rebuild a symmetric ``matrix_size`` x ``matrix_size`` matrix from its vectorised form.

        Values are consumed in the same column-by-column order as ``vectorize`` and written
        both to (row, col) and to its mirror (col, row). The diagonal stays zero.

        Parameters:
        - vector (numpy.ndarray): the vectorised matrix.
        - matrix_size (int): side length of the square matrix to build.
        - include_diagonal (bool, optional): the vector also holds one sub-diagonal value per
          column (as produced by ``vectorize(..., include_diagonal=True)``). Defaults to False.
          NOTE(review): each sub-diagonal value is overwritten when the next column's
          upper entries are written, so it does not survive in the result.

        Returns:
        - numpy.ndarray: the reconstructed square matrix.
        """
        matrix = np.zeros((matrix_size, matrix_size))
        position = 0
        for col in range(matrix_size):
            for row in range(col):
                # mirror each upper-triangle value into the lower triangle
                matrix[row, col] = matrix[col, row] = vector[position]
                position += 1
            if include_diagonal and col + 1 < matrix_size:
                matrix[col + 1, col] = matrix[col, col + 1] = vector[position]
                position += 1
        return matrix


def multi_anti_vectorize(arr, vectorizer, matrix_size):
    """Anti-vectorize every row of ``arr`` into a ``matrix_size`` x ``matrix_size`` matrix and stack them."""
    matrices = []
    for row_vector in arr:
        matrices.append(vectorizer.anti_vectorize(row_vector, matrix_size))
    return np.array(matrices)

def load_data_tensor(path_to_datasets, lr_n=N_LR_NODES, hr_n=N_HR_NODES):
    """
    Load the LR/HR csv files from ``path_to_datasets`` and rebuild the adjacency matrices.

    Parameters:
    - path_to_datasets (str): directory holding ``hr_train.csv``, ``lr_train.csv`` and
      ``lr_test.csv``. It is an argument so everyone can point at their own copy of the dataset.
    - lr_n (int, optional): node count of low-resolution matrices. Defaults to N_LR_NODES.
    - hr_n (int, optional): node count of high-resolution matrices. Defaults to N_HR_NODES.

    Returns:
    - tuple of torch.Tensor: (lr_train, lr_test, hr_train). NOTE the order: low-res train,
      low-res test, high-res train. Each tensor has shape (n_samples, n, n).
    """
    vectorizer = MatrixVectorizer()

    def _load(file_name, matrix_size):
        # every csv row is one vectorised matrix; rebuild and stack them
        raw = pd.read_csv(path_to_datasets + '/' + file_name)
        return torch.Tensor(multi_anti_vectorize(raw.values, vectorizer, matrix_size))

    hr_train = _load('hr_train.csv', hr_n)
    lr_train = _load('lr_train.csv', lr_n)
    lr_test = _load('lr_test.csv', lr_n)
    return lr_train, lr_test, hr_train

## Functions for data preparation

In [5]:
def adj_hop(adj, hop):
    """
    Raise every matrix of a batch to the power ``hop`` and stack the resulting rows.

    adj: tensor of shape (batch, n, n); hop: power to use (values <= 1 leave ``adj`` unchanged).
    Returns a 2-D tensor of shape (batch * n, n).
    """
    powered = adj
    for _ in range(1, hop):
        powered = torch.bmm(powered, adj)
    return powered.reshape(-1, adj.shape[1])


def upsampled_data(tensor, repeats, idxs):
    """
    Append ``repeats`` extra copies of the samples at ``idxs`` to ``tensor`` (along dim 0).

    Copies of the same sample are consecutive and come after all original samples, so the
    result has ``len(tensor) + repeats * len(idxs)`` entries.
    """
    extra = torch.repeat_interleave(tensor[idxs, :, :], repeats=repeats, dim=0)
    return torch.cat((tensor, extra), dim=0)

def standardization(matrix, avg=None, std=None):
    """
    Z-score ``matrix`` with a single global mean and standard deviation.

    Call without statistics to fit them on ``matrix`` (e.g. training data); this returns
    ``(standardized, avg, std)``. Call with both ``avg`` and ``std`` to apply previously
    fitted statistics (e.g. to test data); this returns only the standardized tensor.

    Raises:
    - ValueError: if exactly one of ``avg`` / ``std`` is supplied.
    """
    if avg is None and std is None:
        avg = torch.mean(matrix)
        std = torch.std(matrix)
        return (matrix - avg) / std, avg, std
    if avg is None or std is None:
        raise ValueError('standardization needs both avg and std, or neither')
    return (matrix - avg) / std

## Functions for helping define and train the model

In [None]:
def generate_steps(num_steps, low=N_LR_NODES, high=N_HR_NODES):
    """
    Return ``num_steps`` evenly spaced integer sizes from ``low`` to ``high`` (both included).

    Raises:
    - ValueError: if ``num_steps`` < 2, since both end points cannot then be represented.
    """
    if num_steps < 2:
        raise ValueError(f'num_steps must be at least 2, got {num_steps}')
    step_size = (high - low) / (num_steps - 1)
    return [round(low + step_size * i) for i in range(num_steps)]

def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for parameter in model.parameters():
        parameter.requires_grad_(False)

def unfreeze_model(model):
    """Turn gradient tracking back on for every parameter of ``model`` (in place)."""
    for parameter in model.parameters():
        parameter.requires_grad_(True)

## Functions for visualisation of results

In [None]:
def generate_histogram(test_pred_tensor, target=None):
    """
    Overlay normalised histograms of predicted (and optionally ground-truth) values.

    input:
    - test_pred_tensor: CPU tensor/array of predictions (any shape; it is flattened)
    - target: CPU tensor/array of reference values (optional; any shape; it is flattened)
    """
    # density=True on both so inputs of different sizes are comparable
    plt.hist(test_pred_tensor.flatten(), bins=50, alpha=0.5, density=True, color='purple', label='Prediction')
    if target is not None:
        plt.hist(target.flatten(), bins=50, alpha=0.5, density=True, color='orange', label='Ground Truth')
    plt.title('Histogram of Tensor Values')
    plt.xlabel('Value')
    # the bars are normalised (density=True), so the y axis is a probability density, not a count
    plt.ylabel('Density')
    plt.legend()
    plt.show()

def generate_heatmap(single_tensor, save_path=None):
    """
    Produces a heatmap of a single square matrix (e.g. one adjacency matrix).

    input:
    - single_tensor: CPU torch.tensor or array of shape (n, n)
    - save_path: string (optional) file name the figure is saved to before being shown

    raises:
    - ValueError: if single_tensor is not a 2-D square matrix
    """
    if len(single_tensor.shape) != 2:
        raise ValueError(f'expected a single 2-D matrix, got shape {tuple(single_tensor.shape)}')
    if single_tensor.shape[0] != single_tensor.shape[1]:
        raise ValueError(f'tensor not square: shape {tuple(single_tensor.shape)}')
    plt.imshow(single_tensor, cmap='hot', interpolation='nearest')
    plt.colorbar()
    if save_path:
        plt.savefig(save_path)
    plt.show()

## Functions for evaluating model

In [None]:
def evaluate_predictions(tensor_pred, tensor_true):

    """
    Compute the evaluation metrics between predicted and ground-truth adjacency matrices.

    tensor_pred and tensor_true should both be tensors of shape
    (num_val_samples, hr_dim, hr_dim).

    For every sample, both matrices become weighted NetworkX graphs. Per-node betweenness,
    eigenvector and PageRank centralities are compared by MAE, and these per-sample MAEs are
    averaged. The strict upper triangles of all samples (MatrixVectorizer.vectorize) are then
    concatenated to compute the global edge-weight MAE, Pearson correlation and
    Jensen-Shannon distance.

    Returns:
    - dict with keys 'mae', 'pcc', 'js_dis', 'avg_mae_bc', 'avg_mae_ec', 'avg_mae_pc'
      (the same values are also printed).

    NOTE(review): an empty input raises ZeroDivisionError in the averaging step, and
    nx.eigenvector_centrality may raise if its power iteration does not converge.
    """
    # Initialize lists to store MAEs for each centrality measure
    mae_bc = []
    mae_ec = []
    mae_pc = []

    # Per-sample upper-triangle vectors, concatenated after the loop for the global metrics
    pred_1d_list = []
    gt_1d_list = []

    # Iterate over each test sample
    for i in tqdm(range(len(tensor_pred)), desc='Evaluating Predictions (Can be Slow)'):

        # Move to CPU and drop autograd history so the matrices can be converted to numpy
        pred_matrix = tensor_pred[i].cpu().detach().numpy()
        true_matrix = tensor_true[i].cpu().detach().numpy()

        # Convert adjacency matrices to NetworkX graphs; matrix entries become the "weight" edge attribute
        pred_graph = nx.from_numpy_array(pred_matrix, edge_attr="weight")
        gt_graph = nx.from_numpy_array(true_matrix, edge_attr="weight")

        # Compute centrality measures (weighted by the edge weights)
        pred_bc = nx.betweenness_centrality(pred_graph, weight="weight")
        pred_ec = nx.eigenvector_centrality(pred_graph, weight="weight")
        pred_pc = nx.pagerank(pred_graph, weight="weight")

        gt_bc = nx.betweenness_centrality(gt_graph, weight="weight")
        gt_ec = nx.eigenvector_centrality(gt_graph, weight="weight")
        gt_pc = nx.pagerank(gt_graph, weight="weight")

        # Convert centrality dictionaries to lists
        # (assumes pred and gt matrices have the same size, so node i lines up in both lists)
        pred_bc_values = list(pred_bc.values())
        pred_ec_values = list(pred_ec.values())
        pred_pc_values = list(pred_pc.values())

        gt_bc_values = list(gt_bc.values())
        gt_ec_values = list(gt_ec.values())
        gt_pc_values = list(gt_pc.values())

        # Compute MAEs
        mae_bc.append(mean_absolute_error(pred_bc_values, gt_bc_values))
        mae_ec.append(mean_absolute_error(pred_ec_values, gt_ec_values))
        mae_pc.append(mean_absolute_error(pred_pc_values, gt_pc_values))

        # Vectorize matrices (strict upper triangle only, diagonal excluded)
        pred_1d_list.append(MatrixVectorizer.vectorize(pred_matrix))
        gt_1d_list.append(MatrixVectorizer.vectorize(true_matrix))

    # Compute average MAEs
    avg_mae_bc = sum(mae_bc) / len(mae_bc)
    avg_mae_ec = sum(mae_ec) / len(mae_ec)
    avg_mae_pc = sum(mae_pc) / len(mae_pc)

    # Concatenate flattened matrices
    pred_1d = np.concatenate(pred_1d_list)
    gt_1d = np.concatenate(gt_1d_list)

    # Compute metrics over all edges of all samples at once
    mae = mean_absolute_error(pred_1d, gt_1d)
    pcc = pearsonr(pred_1d, gt_1d)[0]  # [0] = correlation coefficient (the p-value is discarded)
    js_dis = jensenshannon(pred_1d, gt_1d)

    print("MAE: ", mae)
    print("PCC: ", pcc)
    print("Jensen-Shannon Distance: ", js_dis)
    print("Average MAE betweenness centrality:", avg_mae_bc)
    print("Average MAE eigenvector centrality:", avg_mae_ec)
    print("Average MAE PageRank centrality:", avg_mae_pc)
    
    all_metrics = {
        'mae': mae,
        'pcc': pcc,
        'js_dis': js_dis,
        'avg_mae_bc': avg_mae_bc,
        'avg_mae_ec': avg_mae_ec,
        'avg_mae_pc': avg_mae_pc,
    }
    
    return all_metrics

# Data Preparation

In [None]:
# Load the datasets from the "dgl-icl" directory (adjust the path to your copy);
# the returned order is (lr_train, lr_test, hr_train)
lr_train, lr_test, hr_train = load_data_tensor("dgl-icl") # load in data from designated file path
# Split sample indices 80/20 into training and validation sets; random_state makes the split reproducible
train_idx, val_idx = train_test_split(list(range(len(lr_train))), test_size=0.2, shuffle=True, random_state=0) # split data into training and validation set

## Learning initial node embeddings from adjacency matrix using VAE

### Prepare data for VAE models

In [None]:
# get the 3 hop adjacency matrices
# adj_hop cubes every matrix (A @ A @ A) and returns its rows already stacked as (n_samples * n, n)
lr_train3 = adj_hop(lr_train, 3)
hr_train3 = adj_hop(hr_train, 3)
lr_test3 = adj_hop(lr_test, 3)

In [None]:
# stack the matrices so they are by rows
lr_stack = lr_train.reshape(-1, N_LR_NODES)
hr_stack = hr_train.reshape(-1, N_HR_NODES)
lr_test = lr_test.reshape(-1, N_LR_NODES)

lr_stack3 = lr_train3.reshape(-1, N_LR_NODES)
hr_stack3 = hr_train3.reshape(-1, N_HR_NODES)
lr_test3 = lr_test3.reshape(-1, N_LR_NODES)

In [None]:
# standardise the values
lr_stack_norm, lr_avg, lr_std = standardization(lr_stack)
hr_stack_norm, temp1, temp2  = standardization(hr_stack)
lr_test_stack_norm = standardization(lr_test, lr_avg, lr_std)

lr_stack3_norm, lr_avg3, lr_std3 = standardization(lr_stack3)
hr_stack3_norm, temp1, temp2 = standardization(hr_stack3)
lr_test_stack3_norm = standardization(lr_test3, lr_avg3, lr_std3)

### Define VAE models for both low- and high-resolution adjacency matrices

In [None]:
class VAELR(nn.Module):
    """
    Fully connected VAE over rows of a low-resolution adjacency matrix.

    Expects x of shape (batch, N_LR_NODES) and uses a 32-dimensional latent space.
    The layer order inside the nn.Sequential blocks fixes the state_dict keys.
    """
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(N_LR_NODES, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            # NOTE(review): Sigmoid after ReLU squashes the features into [0.5, 1); confirm intended
            nn.Sigmoid()

        )
        self.fc_m = nn.Linear(64,32)
        # despite the name, forward/reparametrize treat this output as the log-variance
        self.fc_std = nn.Linear(64,32)
        self.decoder = nn.Sequential(
            
            nn.Linear(32,64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            nn.Linear(64,128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, N_LR_NODES)
        )
    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for input rows x."""
        h1 = self.encoder(x)
        return self.fc_m(h1), self.fc_std(h1)
    
    def reparametrize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar), eps ~ N(0, I); samples in eval mode too."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def decode(self, z):
        """Map latent codes back to N_LR_NODES-wide rows."""
        return self.decoder(z)
    
    def forward(self, x):
        """Return (reconstruction, mu, logvar) for input rows x."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu,logvar)
        return self.decode(z), mu, logvar
    

 
class VAEHR(nn.Module):
    """
    Fully connected VAE over rows of a high-resolution adjacency matrix.

    Expects x of shape (batch, N_HR_NODES) and uses a 32-dimensional latent space.
    The layer order inside the nn.Sequential blocks fixes the state_dict keys.
    """
    def __init__(self):
        super(VAEHR, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(N_HR_NODES, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128,64),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            # NOTE(review): Sigmoid after ReLU squashes the features into [0.5, 1); confirm intended
            nn.Sigmoid()

        )
        self.fc_m = nn.Linear(64,32)
        # despite the name, forward/reparametrize treat this output as the log-variance
        self.fc_std = nn.Linear(64,32)
        self.decoder = nn.Sequential(
            
            nn.Linear(32,64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            nn.Linear(64,128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, N_HR_NODES)
        )
    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for input rows x."""
        h1 = self.encoder(x)
        return self.fc_m(h1), self.fc_std(h1)
    
    def reparametrize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar), eps ~ N(0, I); samples in eval mode too."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def decode(self, z):
        """Map latent codes back to N_HR_NODES-wide rows."""
        return self.decoder(z)
    
    def forward(self, x):
        """Return (reconstruction, mu, logvar) for input rows x."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu,logvar)
        return self.decode(z), mu, logvar
           

In [None]:
def vae_loss(recon_x, x, mu, logvar):
    """
    VAE objective: summed squared reconstruction error plus the KL divergence between
    N(mu, exp(logvar)) and the standard normal prior.
    """
    reconstruction = F.mse_loss(recon_x, x, reduction="sum")
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction + kl_divergence

def train(model, adj_matrix, num_epoch=200, lr=0.0001, batch_size=64, log_every=10):
    """
    Fit a VAE on the rows of ``adj_matrix`` with AdamW, using fixed-order mini-batches.

    Parameters:
    - model: VAE whose forward returns (reconstruction, mu, logvar), e.g. VAELR / VAEHR.
    - adj_matrix: 2-D tensor of stacked (standardised) adjacency rows.
    - num_epoch, lr, batch_size: optimisation settings.
    - log_every: print the mean batch loss every ``log_every`` epochs (<= 0 disables logging).

    NOTE: the name ``train`` is re-bound later in the notebook to the main-model trainer,
    so the VAE training cells must run before that definition.
    """
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    n_rows = len(adj_matrix)
    for epoch in range(num_epoch):
        epoch_losses = []
        for start in range(0, n_rows, batch_size):
            batch = adj_matrix[start:start + batch_size]
            # BatchNorm1d cannot compute training statistics from a single row
            if len(batch) < 2:
                continue
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss = vae_loss(recon, batch, mu, logvar)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        # one summary line per logging interval instead of one tensor print per batch
        if log_every > 0 and (epoch + 1) % log_every == 0 and epoch_losses:
            print(f'VAE epoch {epoch + 1}/{num_epoch}: mean batch loss = {np.mean(epoch_losses):.4f}')
    

### Train VAE models

In [None]:
# train for the low resolution 1-hop adjacency matrix
# `train` here is the VAE helper train(model, adj_matrix, ...); the name is re-bound to the
# main-model trainer further down, so this cell must run before that definition
lr_vae = VAELR()
train(lr_vae, lr_stack_norm)

In [None]:
# train for the high resolution 1-hop adjacency matrix
# (uses the VAE helper `train`; run before the main-model `train` is defined)
hr_vae = VAEHR()
train(hr_vae, hr_stack_norm)

In [None]:
# train for the low resolution 3-hop adjacency matrix
lr_vae3 = VAELR()
# fit the dedicated 3-hop VAE; `lr_vae` must stay trained on the 1-hop data only
train(lr_vae3, lr_stack3_norm)

In [None]:
# train for the high resolution 3-hop adjacency matrix
hr_vae3 = VAEHR()
train(hr_vae3, hr_stack3_norm)

### Obtain Embeddings

In [None]:
# get training embeddings from VAE based on low resolution 1-hop adjacency matrix
lr_vae.eval()
# no_grad: the embeddings are fixed inputs to the main model. With the VAE graph attached,
# main-model training would backprop into it and fail on the second backward of a batch.
# NOTE(review): reparametrize still samples noise in eval mode; `mu` alone would be deterministic.
with torch.no_grad():
    mu, logvar = lr_vae.encode(lr_stack_norm)
    lr_X_dim1 = lr_vae.reparametrize(mu,logvar).reshape(N_TRAIN, N_LR_NODES, 32)

In [None]:
# get training embeddings from VAE based on high resolution 1-hop adjacency matrix
hr_vae.eval()
# no_grad: keep the VAE graph out of the embeddings used as fixed inputs for the main model
with torch.no_grad():
    mu, logvar = hr_vae.encode(hr_stack_norm)
    hr_X_dim1 = hr_vae.reparametrize(mu,logvar).reshape(N_TRAIN, N_HR_NODES, 32)

In [None]:
# get training embeddings from VAE based on low resolution 3-hop adjacency matrix
lr_vae3.eval()
# no_grad: keep the VAE graph out of the embeddings used as fixed inputs for the main model
with torch.no_grad():
    mu, logvar = lr_vae3.encode(lr_stack3_norm)
    lr_X_dim3 = lr_vae3.reparametrize(mu,logvar).reshape(N_TRAIN, N_LR_NODES, 32)

In [None]:
# get training embeddings from VAE based on high resolution 3-hop adjacency matrix
hr_vae3.eval()
# no_grad: keep the VAE graph out of the embeddings used as fixed inputs for the main model
with torch.no_grad():
    mu, logvar = hr_vae3.encode(hr_stack3_norm)
    hr_X_dim3 = hr_vae3.reparametrize(mu,logvar).reshape(N_TRAIN, N_HR_NODES, 32)

In [None]:
# get testing embeddings from VAE based on low resolution 1-hop adjacency matrix
lr_vae.eval()
# inference only: no autograd graph is needed for the test embeddings
with torch.no_grad():
    mu, logvar = lr_vae.encode(lr_test_stack_norm)
    lr_X_dim1_test = lr_vae.reparametrize(mu,logvar).reshape(N_TEST, N_LR_NODES, 32)

In [None]:
# get testing embeddings from VAE based on low resolution 3-hop adjacency matrix
lr_vae3.eval()
# inference only: no autograd graph is needed for the test embeddings
with torch.no_grad():
    mu, logvar = lr_vae3.encode(lr_test_stack3_norm)
    lr_X_dim3_test = lr_vae3.reparametrize(mu,logvar).reshape(N_TEST, N_LR_NODES, 32)

## Upsampling based on minority class from the HR adjacency matrix

In [None]:
def upsampling_idx(hr_train):
    """
    Find the positions of the minority cluster among the given HR adjacency matrices.

    Each matrix becomes a NetworkX graph and is embedded with Graph2Vec. The embeddings
    are split into two groups with spectral clustering, and the positions of the smaller
    group (the minority class to upsample) are returned.

    Parameters:
    - hr_train: iterable of (n, n) adjacency tensors.

    Returns:
    - list of int: positions within ``hr_train`` that belong to the smaller cluster.
    """
    hr_adj_list = [tensor.cpu().numpy() for tensor in hr_train]
    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is its replacement
    graphs = [nx.from_numpy_array(adj_matrix, create_using=nx.DiGraph) for adj_matrix in hr_adj_list]

    # generate embeddings for each graph
    graph2vec = Graph2Vec(dimensions=130)
    graph2vec.fit(graphs)
    graph_embeddings = graph2vec.get_embedding()

    # apply spectral clustering on the embeddings
    num_clusters = 2   # we only partition into majority and minority class
    clustering_model = SpectralClustering(n_clusters=num_clusters, assign_labels="discretize", random_state=0)
    clusters = clustering_model.fit_predict(graph_embeddings)

    # cluster ids are arbitrary, so the label with fewer members is taken as the minority class
    minority_label = int(np.argmin(np.bincount(clusters, minlength=num_clusters)))
    return [int(i) for i in np.where(clusters == minority_label)[0]]

In [None]:
# idxs are positions within the train_idx subset, so every tensor below is restricted to train_idx first
idxs = upsampling_idx(hr_train[train_idx])
repeats = 4  # append 4 extra copies of each minority sample (5 occurrences in total)

# upsample the data; the same idxs/repeats for every tensor keeps the samples aligned
lr_train_up = upsampled_data(lr_train[train_idx], repeats, idxs)
hr_train_up = upsampled_data(hr_train[train_idx], repeats, idxs)
lr_X_dim1_train_up = upsampled_data(lr_X_dim1[train_idx], repeats, idxs)
hr_X_dim1_train_up = upsampled_data(hr_X_dim1[train_idx], repeats, idxs)
lr_X_dim3_train_up = upsampled_data(lr_X_dim3[train_idx], repeats, idxs)
hr_X_dim3_train_up = upsampled_data(hr_X_dim3[train_idx], repeats, idxs)

In [None]:
# load in data into a dataloader for training
trainloader = DataLoader(list(zip(lr_X_dim1_train_up, lr_X_dim3_train_up, lr_train_up, hr_X_dim1_train_up, hr_X_dim3_train_up, hr_train_up)), shuffle=True, batch_size=32)
# validation batches are (X_lr, Y_lr, adj_lr, adj_hr): only the up changer is evaluated, and this
# 4-item layout matches the unpacking in the validation loop of `train`
valloader = DataLoader(list(zip(lr_X_dim1[val_idx], lr_X_dim3[val_idx], lr_train[val_idx], hr_train[val_idx])), shuffle=True, batch_size=32)
# the model needs one (n, n) matrix per test sample; the reshape restores that layout if
# `lr_test` was flattened to stacked rows by an earlier cell (a no-op otherwise)
testloader = DataLoader(list(zip(lr_X_dim1_test, lr_X_dim3_test, lr_test.reshape(-1, N_LR_NODES, N_LR_NODES))), shuffle=False, batch_size=32)

# Define main model

## Define layers used in the model

In [None]:
# Define module that stacks GCN layers
class StackedGCN(nn.Module):
    """
    A stack of GATv2 graph-attention layers, each followed by sigmoid, batch norm and dropout.

    Parameters:
    - n_nodes (int): number of nodes per graph (the channel size of the BatchNorm layers).
    - channel_ls (list[int]): feature sizes; layer i maps channel_ls[i] -> channel_ls[i + 1].
      Every output size must be divisible by HEADS.
    - dropout (float): dropout probability applied after every layer.
    """
    HEADS = 2  # attention heads per GATv2 layer; their outputs are concatenated

    def __init__(self, n_nodes, channel_ls, dropout):
        super().__init__()
        self.n_nodes = n_nodes
        self.gcn_layers, self.batch_norm_layers = self._init_layers(channel_ls)
        self.dropout = dropout

    def forward(self, X, A):
        """
        X: node features, (batch, n_nodes, channel_ls[0]).
        A: weighted adjacency, (batch, n_nodes, n_nodes); nonzero entries become weighted edges.
        Returns node features of shape (batch, n_nodes, channel_ls[-1]).
        """
        # the edges depend only on A, which is the same for every layer, so build the batch once
        graph_batch = self._create_batch(X, A)
        for gcn, batch_norm in zip(self.gcn_layers, self.batch_norm_layers):
            # a PyG batch concatenates node features graph by graph, i.e. X flattened over (batch, node)
            node_features = X.reshape(-1, X.shape[-1])
            X = F.sigmoid(gcn(node_features, graph_batch.edge_index, graph_batch.edge_attr).reshape(*X.shape[:2], -1))
            torch.cuda.empty_cache()
            X = F.dropout(batch_norm(X), self.dropout, training=self.training)
            torch.cuda.empty_cache()
        return X

    # initialise stacks of GCN layer -> batch normalisation
    def _init_layers(self, channel_ls):
        """Create one GATv2Conv and one non-affine BatchNorm per consecutive pair of channel sizes."""
        layers_ls = []
        batch_norm_ls = []
        for in_channels, out_channels in zip(channel_ls[:-1], channel_ls[1:]):
            # GATv2Conv concatenates its heads, so each head produces out_channels / HEADS features
            if out_channels % self.HEADS != 0:
                raise ValueError(f'output channels ({out_channels}) must be divisible by {self.HEADS} heads')
            layers_ls.append(GATv2Conv(in_channels, out_channels // self.HEADS, heads=self.HEADS, edge_dim=1))
            # normalises over the node dimension of (batch, n_nodes, channels) inputs
            batch_norm_ls.append(torch_geometric.nn.norm.BatchNorm(self.n_nodes, affine=False))
        return nn.ModuleList(layers_ls), nn.ModuleList(batch_norm_ls)

    # generate a Batch object for torch geometric
    def _create_batch(self, X, A):
        """Build one PyG graph per sample: nodes carry X, edges are A's nonzero entries weighted by value."""
        data_list = []
        for x, adj in zip(X, A):
            edge_index = adj.nonzero().t()
            edge_weights = adj[edge_index[0], edge_index[1]]
            data_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_weights.view(-1, 1)))
        return torch_geometric.data.Batch.from_data_list(data_list)

In [None]:
# define layer that projects adjacency matrix and node embeddings to higher dimensions
class AdjacencyStep(nn.Module):
    """
    One projection step from ``old_dim`` to ``new_dim`` nodes.

    It updates the node states with a StackedGCN and builds a new (new_dim x new_dim)
    adjacency from a gated mix of the projected old adjacency and the GCN embeddings.
    It then projects the node states X and Y to new_dim nodes.

    Parameters:
    - old_dim, new_dim (int): node counts before / after the step.
    - channels_ls (list[int]): StackedGCN channel sizes; X and Y carry channels_ls[0] features.
    - dt, alpha, gamma (float): coefficients of the X/Y state update in ``forward``.
    - dropout (float): dropout probability for the GCN and for the output node states.
    """
    def __init__(self, old_dim, new_dim, channels_ls, dt=1., alpha=0.9, gamma=0.9, dropout=0.5):
        super().__init__()
        self.dt = dt
        self.alpha = alpha
        self.gamma = gamma
        self.gnn = StackedGCN(old_dim, channels_ls, dropout).to(DEVICE)
        self.dropout = dropout
        # (new_dim, old_dim) projection applied on the node axis of Z, X and Y
        self.dim_changer = nn.Parameter(torch.randn((new_dim, old_dim), device=DEVICE))

        # forget path: P @ A @ P^T + b projects the previous adjacency to new_dim x new_dim
        self.A_dim_changer = nn.Parameter(torch.randn((new_dim, old_dim), device=DEVICE))
        self.A_dim_bias = nn.Parameter(torch.randn((new_dim, 1), device=DEVICE))

        # input path from the GCN embeddings; Z_dim_lower maps them from channels_ls[-1]
        # back to channels_ls[0] features for the X/Y update
        self.Z_dim_changer = nn.Parameter(torch.randn((channels_ls[-1], new_dim), device=DEVICE))
        self.Z_dim_bias = nn.Parameter(torch.randn((new_dim, 1), device=DEVICE))
        self.Z_dim_lower = nn.Conv1d(channels_ls[-1], channels_ls[0], kernel_size=1)

        # per-row gate logits (a sigmoid is applied in forward)
        self.forget_gate = nn.Parameter(torch.randn(new_dim, device=DEVICE))
        self.input_gate = nn.Parameter(torch.randn(new_dim, device=DEVICE))

        self.batchnorm_A = torch_geometric.nn.norm.BatchNorm(new_dim)
        self.batchnorm_X = torch_geometric.nn.norm.BatchNorm(new_dim)
        self.batchnorm_Y = torch_geometric.nn.norm.BatchNorm(new_dim)

    def forward(self, X, Y, A):
        """
        X, Y: node states of shape (batch, old_dim, channels_ls[0]).
        A: adjacency of shape (batch, old_dim, old_dim).
        Returns (X, Y, new_A) with new_dim nodes; new_A is symmetric with values in [0, 1].
        """
        # update node features with gcn
        Z = self.gnn(X, A)
        torch.cuda.empty_cache()

        # forget gate from previous adjacency
        f = F.sigmoid(self.forget_gate)
        forget_A = F.elu(self.A_dim_changer @ A @ self.A_dim_changer.T + self.A_dim_bias)
        forget_A = f[:, None] * forget_A
        torch.cuda.empty_cache()

        # input gate from newly learnt node embeddings
        i = F.sigmoid(self.input_gate)
        input_Z = F.elu(self.dim_changer @ Z @ self.Z_dim_changer + self.Z_dim_bias)
        input_Z = i[:, None] * input_Z
        torch.cuda.empty_cache()

        # get new adjacency matrix: normalise, symmetrise, then clip into [0, 1]
        new_A = forget_A + input_Z
        new_A = self.batchnorm_A(new_A)
        new_A = (new_A + torch.transpose(new_A, -1, -2)) / 2
        # hardtanh is idempotent, so one application yields the same clipped values
        new_A = F.hardtanh(new_A, min_val=0)
        torch.cuda.empty_cache()

        # update feature embeddings with step dt:
        #   Y <- Y + dt * (Z - alpha * Y - gamma * X),  X <- X + dt * Y_old,
        # then project both onto new_dim nodes
        Z = torch.transpose(self.Z_dim_lower(torch.transpose(Z, -1, -2)), -1, -2)
        Y_temp = Y
        Y = self.dim_changer @ (Y + self.dt * (Z - self.alpha * Y - self.gamma * X))
        X = self.dim_changer @ (X + self.dt * Y_temp)

        # add batch normalisation and dropout at the end
        X = self.batchnorm_X(X)
        Y = self.batchnorm_Y(Y)
        Y = F.dropout(Y, self.dropout, training=self.training)
        X = F.dropout(X, self.dropout, training=self.training)
        return X, Y, new_A

    

In [None]:
# define module that combines all the steps together
class AdjacencyDimChanger(nn.Module):
    """
    A chain of AdjacencyStep layers that moves an adjacency through the node counts in ``dim_steps``.

    Parameters:
    - dim_steps (list[int]): node count before the first step and after each step.
    - channels_ls (list[int]): channel sizes passed to every step's StackedGCN.
    - **step_kwargs: optional AdjacencyStep hyper-parameters (dt, alpha, gamma, dropout),
      forwarded to every step; any omitted ones keep AdjacencyStep's defaults.
    """
    def __init__(self, dim_steps, channels_ls, **step_kwargs):
        super().__init__()
        self.layers = nn.ModuleList([
            AdjacencyStep(old_dim, new_dim, channels_ls, **step_kwargs)
            for old_dim, new_dim in zip(dim_steps[:-1], dim_steps[1:])
        ])

    def forward(self, X, Y, A):
        """Return [A, A_1, ..., A_k]: the input adjacency followed by each step's output adjacency."""
        adj_ls = [A]
        for layer in self.layers:
            X, Y, A = layer(X, Y, A)
            adj_ls.append(A)
        return adj_ls

## Loss functions that will be calculated during training

In [None]:
def up_loss_fn(up_adj_ls, down_adj_ls,  gamma, epoch, a=0.5, b=1, c=0.2):
    """
    Loss for the up (LR -> HR) changer.

    It combines two parts:
    (i) the MSE between the strict upper triangles of ``up_adj_ls[-1]`` and ``down_adj_ls[0]``;
    (ii) a weighted MSE between each remaining up adjacency and the down adjacency of the same
    size (``down_adj_ls[1:]`` reversed). The weights are proportional to the step index, so
    later steps weigh more, and they sum to 1.

    The parts are mixed by alpha = b + (a - b) * (1 - exp(-c * epoch)), which moves from b
    towards a as epochs pass. Part (ii) is additionally scaled by gamma ** 2.
    """
    mse = nn.MSELoss()

    # final step: only the strict upper triangle is scored
    size = up_adj_ls[-1].shape[-1]
    rows, cols = torch.triu_indices(size, size, offset=1).unbind()
    final_mse_loss = mse(up_adj_ls[-1][:, rows, cols], down_adj_ls[0][:, rows, cols])

    # intermediate steps, each paired with the equally sized down adjacency
    intermediates = up_adj_ls[:-1]
    n = len(intermediates)
    weights = torch.Tensor([2 * (k + 1) / (n * (n + 1)) for k in range(n)])
    intermediate_mse_loss = torch.Tensor([0]).to(DEVICE)
    for k, (up_adj, down_adj) in enumerate(zip(intermediates, down_adj_ls[1:][::-1])):
        intermediate_mse_loss = intermediate_mse_loss + weights[k] * mse(up_adj, down_adj)

    alpha = (1 - np.exp(-c * epoch)) * (a - b) + b
    total_loss = torch.Tensor([0]).to(DEVICE)
    return total_loss + alpha * final_mse_loss + (1 - alpha) * intermediate_mse_loss * (gamma ** 2)

def down_loss_fn(down_adj_ls, up_adj_ls):
    """
    Loss for the down (HR -> LR) changer.

    It sums the MSE between each down adjacency and the up adjacency of the same size
    (``up_adj_ls`` reversed), each weighted by 1 / len(down_adj_ls).
    """
    mse = nn.MSELoss()
    total_loss = torch.Tensor([0]).to(DEVICE)
    n = len(down_adj_ls)
    weights = torch.Tensor([1 / n for _ in range(n)])
    for k, (down_adj, up_adj) in enumerate(zip(down_adj_ls, up_adj_ls[::-1])):
        total_loss = total_loss + weights[k] * mse(down_adj, up_adj)
    return total_loss

def l1_regularization_loss(models, l1_lambda):
    """Return l1_lambda times the sum over ``models`` of the L1 norm of each model's flattened parameters."""
    penalty = torch.Tensor([0]).to(DEVICE)
    for model in models:
        flat_params = torch.cat([param.view(-1) for param in model.parameters()])
        penalty = penalty + l1_lambda * torch.norm(flat_params, 1)
    return penalty
    
def reconstruction_loss_fn(gt_adj, pred_adj):
    """Mean absolute error between the ground-truth and reconstructed adjacency tensors."""
    return nn.L1Loss()(gt_adj, pred_adj)

def end_adj_loss_calc(adj, opp_adj):
    """
    Mean absolute error over the strict upper triangle of two adjacency tensors (evaluation only).

    Accepts a single (n, n) matrix or a batch of shape (..., n, n). The triangle is taken over
    the last two dimensions, so the batch dimension is never mistaken for the node dimension.

    Returns:
    - float: the MAE, computed on detached tensors.
    """
    mae_loss_fn = torch.nn.L1Loss()
    n = adj.shape[-1]

    # CPU index tensors can index both CPU and CUDA tensors
    rows, cols = torch.triu_indices(n, n, offset=1)
    upper_tri_adj = adj.detach()[..., rows, cols]
    upper_tri_opp_adj = opp_adj.detach()[..., rows, cols]
    mae_loss = mae_loss_fn(upper_tri_adj, upper_tri_opp_adj)
    return mae_loss.detach().item()

## Define Training function

In [None]:
def train(epochs, up_changer, down_changer, trainloader, up_optimizer, down_optimizer, reconstruction_optimizer, valloader=None, completed_epochs=0, loss_log=None):
    """
    Jointly train the up (LR -> HR) and down (HR -> LR) adjacency changers.

    Every training batch runs three optimisation steps, each with an added L1 weight penalty:
    1. the down changer only, with ``down_loss_fn`` against the up changer's adjacencies;
    2. the up changer only, with ``up_loss_fn`` against the down changer's adjacencies;
    3. both changers, with the L1 reconstruction loss of HR -> LR -> HR.

    Parameters:
    - epochs (int): number of epochs to run.
    - trainloader: yields (X_lr, Y_lr, adj_lr, X_hr, Y_hr, adj_hr) batches.
    - valloader (optional): batches whose first three items are (X_lr, Y_lr, adj_lr) and whose
      last item is adj_hr; both 4-item and 6-item batches are accepted.
    - completed_epochs (int): epochs already trained, added to the epoch counter so a resumed
      run continues the up-loss schedule and the epoch labels.
    - loss_log (dict, optional): existing log to append to (e.g. when resuming).

    Returns:
    - (up_changer, down_changer, loss_log); loss_log also gains a 'reconstruction' list.
    """
    if loss_log is None:
        loss_log = {'up': [], 'down': [], 'up_end_mae':[], 'down_end_mae':[], 'val_up_end_mae':[]}

    l1_lambda = 0.000001 # for L1 regularization
    for local_epoch in range(epochs):
        epoch = completed_epochs + local_epoch

        # for logging
        up_losses = []
        up_final_mae_ls = []
        down_final_mae_ls = []
        down_losses = []
        reconstruction_losses = []

        # change to training mode (the validation pass below leaves up_changer in eval mode)
        up_changer.train()
        down_changer.train()

        for X_lr, Y_lr, adj_lr, X_hr, Y_hr, adj_hr in tqdm(trainloader, desc=f'Epoch {epoch} Train'):
            # move the batch to the device once; all three steps below reuse it
            X_lr, Y_lr, adj_lr = X_lr.to(DEVICE), Y_lr.to(DEVICE), adj_lr.to(DEVICE)
            X_hr, Y_hr, adj_hr = X_hr.to(DEVICE), Y_hr.to(DEVICE), adj_hr.to(DEVICE)

            # ---- step 1: train down changer (up changer frozen) ----
            freeze_model(up_changer)
            unfreeze_model(down_changer)
            down_optimizer.zero_grad()

            up_adj_ls = up_changer(X_lr, Y_lr, adj_lr)
            torch.cuda.empty_cache()
            down_adj_ls = down_changer(X_hr, Y_hr, adj_hr)
            torch.cuda.empty_cache()

            # each down output is matched with the equally sized up adjacency
            down_loss = down_loss_fn(down_adj_ls[1:], up_adj_ls[:-1]) + l1_regularization_loss([down_changer], l1_lambda)

            # for printing loss only: final down output vs. the LR input adjacency
            down_final_mae_ls.append(end_adj_loss_calc(down_adj_ls[-1].detach(), up_adj_ls[0].detach()))
            torch.cuda.empty_cache()

            down_loss.backward()
            down_optimizer.step()
            down_losses.append(down_loss.detach().item())
            del down_loss
            torch.cuda.empty_cache()

            # ---- step 2: train up changer (down changer frozen) ----
            unfreeze_model(up_changer)
            freeze_model(down_changer)
            up_optimizer.zero_grad()

            up_adj_ls = up_changer(X_lr, Y_lr, adj_lr)
            torch.cuda.empty_cache()
            down_adj_ls = down_changer(X_hr, Y_hr, adj_hr)
            torch.cuda.empty_cache()

            up_loss = up_loss_fn(up_adj_ls[1:], down_adj_ls[:-1], gamma=1, epoch=epoch) + l1_regularization_loss([up_changer], l1_lambda)

            # for printing loss only: final up output vs. the HR input adjacency
            up_final_mae_ls.append(end_adj_loss_calc(up_adj_ls[-1].detach(), down_adj_ls[0].detach()))
            torch.cuda.empty_cache()

            up_loss.backward()
            up_optimizer.step()
            up_losses.append(up_loss.detach().item())
            del up_loss
            torch.cuda.empty_cache()

            # ---- step 3: train both changers on the HR -> LR -> HR reconstruction ----
            unfreeze_model(up_changer)
            unfreeze_model(down_changer)
            reconstruction_optimizer.zero_grad()

            # the low dimension projection of the HR adjacency is the up changer's input adjacency
            down_adj_end = down_changer(X_hr, Y_hr, adj_hr)[-1]
            torch.cuda.empty_cache()
            up_adj_end = up_changer(X_lr, Y_lr, down_adj_end)[-1]
            torch.cuda.empty_cache()

            reconstruction_loss = reconstruction_loss_fn(up_adj_end, adj_hr) + l1_regularization_loss([up_changer, down_changer], l1_lambda) / 2
            reconstruction_loss.backward()
            reconstruction_optimizer.step()
            torch.cuda.empty_cache()

            reconstruction_losses.append(reconstruction_loss.detach().item())
            del reconstruction_loss
            torch.cuda.empty_cache()

        # do validation if necessary; without a valloader the printed value is nan
        epoch_val_up_final_mae = float('nan')
        if valloader is not None:
            val_up_final_mae_ls = []
            # change to eval mode
            up_changer.eval()
            freeze_model(up_changer)
            # evaluation only needs forward passes, so no autograd graph is built
            with torch.no_grad():
                for batch in tqdm(valloader, desc=f'Epoch {epoch} Val'):
                    # accepts (X_lr, Y_lr, adj_lr, adj_hr) and (X_lr, Y_lr, adj_lr, X_hr, Y_hr, adj_hr)
                    X_lr, Y_lr, adj_lr, adj_hr = batch[0], batch[1], batch[2], batch[-1]
                    up_adj_ls = up_changer(X_lr.to(DEVICE), Y_lr.to(DEVICE), adj_lr.to(DEVICE))
                    torch.cuda.empty_cache()
                    val_up_final_mae_ls.append(end_adj_loss_calc(up_adj_ls[-1].detach(), adj_hr.to(DEVICE)))
            epoch_val_up_final_mae = np.mean(val_up_final_mae_ls)
            loss_log['val_up_end_mae'].append(epoch_val_up_final_mae)

        # for logging the results
        epoch_up_loss = np.mean(up_losses)
        epoch_down_loss = np.mean(down_losses)
        epoch_reconstruction_loss = np.mean(reconstruction_losses)
        epoch_up_final_mae = np.mean(up_final_mae_ls)
        epoch_down_final_mae = np.mean(down_final_mae_ls)

        loss_log['up'].append(epoch_up_loss)
        loss_log['down'].append(epoch_down_loss)
        loss_log['up_end_mae'].append(epoch_up_final_mae)
        loss_log['down_end_mae'].append(epoch_down_final_mae)
        # setdefault keeps logs created before this key existed working
        loss_log.setdefault('reconstruction', []).append(epoch_reconstruction_loss)

        if (epoch + 1) % 5 == 0:
            print(f'ep{epoch}: DOWN L={epoch_down_loss}, UP L={epoch_up_loss}, RC L={epoch_reconstruction_loss}, DOWN MAE={epoch_down_final_mae}, UP MAE={epoch_up_final_mae}, UP VAL MAE={epoch_val_up_final_mae}')

    return up_changer, down_changer, loss_log

# Model training

In [None]:
# define the model
# 8 evenly spaced node counts from N_LR_NODES to N_HR_NODES -> 7 projection steps
dim_steps = generate_steps(num_steps=8)
channels_ls = [32, 64]

# up: LR -> HR; down: HR -> LR through the same node counts in reverse
up_changer = AdjacencyDimChanger(dim_steps, channels_ls).to(DEVICE)
down_changer = AdjacencyDimChanger(dim_steps[::-1], channels_ls).to(DEVICE)

# define the optimisers
# the reconstruction optimiser covers both models' parameters and keeps its own AdamW state
up_optimizer = torch.optim.AdamW(up_changer.parameters(), lr=0.002)
down_optimizer = torch.optim.AdamW(down_changer.parameters(), lr=0.002)
reconstruction_optimizer = torch.optim.AdamW(list(up_changer.parameters()) + list(down_changer.parameters()), lr=0.002)

# model size: total number of parameters in both changers (displayed as the cell output)
sum(p.numel() for model in [up_changer, down_changer] for p in model.parameters())

In [None]:
# train model
# 300 epochs with the held-out validation split; returns the trained changers and the per-epoch loss log
up_changer, down_changer, loss_log = train(300, up_changer, down_changer, trainloader, up_optimizer, down_optimizer, reconstruction_optimizer, valloader=valloader)

In [None]:
# plot loss curve: per-epoch end-adjacency MAE of the up changer on training vs validation data
train_mae = loss_log['up_end_mae']
val_mae = loss_log['val_up_end_mae']
plt.plot(np.arange(len(train_mae)), train_mae, label='Training')
# validation entries only exist for epochs run with a valloader, so use this list's own length
plt.plot(np.arange(len(val_mae)), val_mae, label='Validation')
plt.title('Up changer end-adjacency MAE')
plt.legend()
plt.ylabel('MAE Loss')
plt.xlabel('Epochs')
plt.show()

# Generate predictions from trained model

In [None]:
# to generate test predictions
# the model needs one (n, n) matrix per test sample; the reshape restores that layout if `lr_test`
# was flattened to stacked rows by an earlier cell (a no-op otherwise)
testloader = DataLoader(list(zip(lr_X_dim1_test, lr_X_dim3_test, lr_test.reshape(-1, N_LR_NODES, N_LR_NODES))), shuffle=False, batch_size=16)

up_changer.eval()
test_predictions = []
# inference only: no_grad avoids building an autograd graph for every batch
with torch.no_grad():
    for X_lr, Y_lr, adj_lr in tqdm(testloader):
        pred = up_changer(X_lr.to(DEVICE), Y_lr.to(DEVICE), adj_lr.to(DEVICE))[-1].detach()
        test_predictions.append(pred)
test_predictions = torch.cat(test_predictions)

In [None]:
# generate heat map for predictions (first test sample; moved to CPU for matplotlib)
generate_heatmap(test_predictions[0].cpu())

In [None]:
# compare distribution between predictions and training target data
# (no HR matrices are loaded for the test set, so the training HR matrices are the reference)
generate_histogram(test_predictions.cpu(), hr_train)

# Cross Validation

In [None]:
def validation(up_changer, testloader, val_adj_hr):
    """Score the upsampling model on a held-out split.

    Args:
        up_changer: trained model, called as ``up_changer(X, Y, adj)``; the last
            element of its output (``[-1]``) is used as the prediction.
        testloader: DataLoader yielding ``(X_lr, Y_lr, adj_lr)`` batches. It must
            not shuffle, so that the concatenated predictions line up with
            ``val_adj_hr``.
        val_adj_hr: ground-truth targets for the same samples, in the same order.

    Returns:
        The result of ``evaluate_predictions(val_predictions, val_adj_hr)``.

    Note:
        The model is switched to eval mode and left there.
    """
    print('begin validation')
    up_changer.eval()

    # generate predictions for the held-out data; inference only, so no_grad
    # avoids building (and then discarding) the autograd graph
    val_predictions = []
    with torch.no_grad():
        for X_lr, Y_lr, adj_lr in tqdm(testloader):
            pred = up_changer(X_lr.to(DEVICE), Y_lr.to(DEVICE), adj_lr.to(DEVICE))[-1].detach()
            val_predictions.append(pred)
    val_predictions = torch.cat(val_predictions)

    # evaluate the performance of prediction
    return evaluate_predictions(val_predictions, val_adj_hr)

def cross_validate(epochs, batch_size, n_fold, X_lr, Y_lr, adj_lr, X_hr, Y_hr, adj_hr,
                   repeats=4, random_state=99):
    """Run K-fold cross-validation of the up/down model pair.

    For each fold:
    - the training split is upsampled with ``upsampled_data`` (using ``repeats``
      and indices from ``upsampling_idx``);
    - a fresh pair of models and optimisers is built and trained for ``epochs``
      epochs;
    - the upsampling model is scored on the held-out split via ``validation``.

    Args:
        epochs: training epochs per fold.
        batch_size: batch size for both the training and validation loaders.
        n_fold: number of folds passed to ``KFold``.
        X_lr, Y_lr, adj_lr: low-resolution inputs, indexed along the first axis.
        X_hr, Y_hr, adj_hr: high-resolution counterparts, aligned with the LR inputs.
        repeats: value passed to ``upsampled_data`` (default 4, as before).
        random_state: seed for the KFold shuffle (default 99, as before).

    Returns:
        list: one ``validation`` result per fold, in fold order.
    """
    kf = KFold(n_fold, shuffle=True, random_state=random_state)
    runs_results = []
    for train_idx, val_idx in kf.split(X_lr):

        # perform upsampling on training data. The indices come from this fold's
        # targets in the `adj_hr` argument; previously the notebook global
        # `hr_train` was used, which is only correct when adj_hr is hr_train.
        idxs = upsampling_idx(adj_hr[train_idx])

        train_adj_lr = upsampled_data(adj_lr[train_idx], repeats, idxs)
        train_adj_hr = upsampled_data(adj_hr[train_idx], repeats, idxs)
        train_X_lr = upsampled_data(X_lr[train_idx], repeats, idxs)
        train_X_hr = upsampled_data(X_hr[train_idx], repeats, idxs)
        train_Y_lr = upsampled_data(Y_lr[train_idx], repeats, idxs)
        train_Y_hr = upsampled_data(Y_hr[train_idx], repeats, idxs)

        # split for validation data (not upsampled)
        val_X_lr = X_lr[val_idx]
        val_Y_lr = Y_lr[val_idx]
        val_adj_lr = adj_lr[val_idx]
        val_adj_hr = adj_hr[val_idx]

        # load into DataLoader; validation must not shuffle so predictions
        # stay aligned with val_adj_hr
        trainloader = DataLoader(list(zip(train_X_lr, train_Y_lr, train_adj_lr, train_X_hr, train_Y_hr, train_adj_hr)), shuffle=True, batch_size=batch_size)
        testloader = DataLoader(list(zip(val_X_lr, val_Y_lr, val_adj_lr)), shuffle=False, batch_size=batch_size)

        # define a fresh model pair for this fold
        dim_steps = generate_steps(num_steps=8)
        channels_ls = [32, 64]
        up_changer = AdjacencyDimChanger(dim_steps, channels_ls).to(DEVICE)
        down_changer = AdjacencyDimChanger(dim_steps[::-1], channels_ls).to(DEVICE)

        # define the optimisers
        up_optimizer = torch.optim.AdamW(up_changer.parameters(), lr=0.002)
        down_optimizer = torch.optim.AdamW(down_changer.parameters(), lr=0.002)
        reconstruction_optimizer = torch.optim.AdamW(list(up_changer.parameters()) + list(down_changer.parameters()), lr=0.002)

        # train model
        up_changer, down_changer, _ = train(epochs, up_changer, down_changer, trainloader, up_optimizer, down_optimizer, reconstruction_optimizer)

        # evaluate model on the held-out fold
        val_metrics = validation(up_changer, testloader, val_adj_hr)
        runs_results.append(val_metrics)

    return runs_results

In [None]:
# perform 3-fold cross validation: 200 training epochs per fold, batch size 32
cv_results = cross_validate(
    epochs=200,
    batch_size=32,
    n_fold=3,
    X_lr=lr_X_dim1,
    Y_lr=lr_X_dim3,
    adj_lr=lr_train,
    X_hr=hr_X_dim1,
    Y_hr=hr_X_dim3,
    adj_hr=hr_train,
)