In [3]:
import colorsys
import matplotlib.colors as mc
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import pickle
import random
import sys
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim

from gensim.models.word2vec import Word2Vec
from karateclub.estimator import Estimator
from karateclub.utils.walker import BiasedRandomWalker
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr
from scipy.spatial.distance import jensenshannon
from tqdm.notebook import tqdm
from typing import List

## Reproducibility

In [None]:
# One fixed seed shared by every source of randomness used below (Python, NumPy, PyTorch)
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Prefer the GPU when present; in that case also seed CUDA and pin cuDNN to deterministic kernels
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available. Using GPU.")
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
else:
    print("CUDA not available. Using CPU.")

## Define Node2Vec class for Node Embedding Generation

In [5]:
class Node2Vec(Estimator):
    """Implementation of "Node2Vec" <https://cs.stanford.edu/~jure/pubs/node2vec-kdd16.pdf> from the
       KDD '16 paper "node2vec: Scalable Feature Learning for Networks". The procedure uses biased
       second order random walks to approximate the pointwise mutual information matrix obtained
       by pooling normalized adjacency matrix powers.

    Args:
        walk_number (int): Number of random walks.
        walk_length (int): Length of random walks.
        p (float): Return parameter; stepping back to the previous node has unnormalised weight 1/p.
        q (float): In-out parameter; moving away from the previous node has unnormalised weight 1/q.
        dimensions (int): Dimensionality of embedding.
        workers (int): Number of Word2Vec worker threads.
        window_size (int): Word2Vec context window size (matrix power order).
        epochs (int): Number of epochs.
        learning_rate (float): Learning rate.
        min_count (int): Minimal count of node occurrences.
        seed (int): Random seed value.

    NOTE(review): gensim documents that Word2Vec is only fully deterministic with
    workers=1; with the default workers=4 the embeddings may differ between runs
    even though `seed` is fixed.
    """
    _embedding: List[np.ndarray]

    def __init__(
        self,
        walk_number: int = 2,
        walk_length: int = 10,
        p: float = 0.8,
        q: float = 1.5,
        dimensions: int = 268,
        workers: int = 4,
        window_size: int = 5,
        epochs: int = 1,
        learning_rate: float = 0.05,
        min_count: int = 1,
        seed: int = random_seed,
    ):
        super(Node2Vec, self).__init__()

        self.walk_number = walk_number
        self.walk_length = walk_length
        self.p = p
        self.q = q
        self.dimensions = dimensions
        self.workers = workers
        self.window_size = window_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.min_count = min_count
        self.seed = seed

    def fit(self, graph):
        """
        Fit Node2Vec on `graph` and store one embedding per node.

        Args:
            graph: graph validated by karateclub's `_check_graph`.

        Embeddings are looked up as `model.wv[str(n)]` for n in
        `range(graph.number_of_nodes())`, so node ids must be 0..n-1 and every
        node must occur in the walks (otherwise the lookup raises KeyError).
        """
        self._set_seed()
        # Checking if input graph is in expected format
        graph = self._check_graph(graph)
        # Generate biased (p, q) second-order random walks over the graph
        walker = BiasedRandomWalker(self.walk_length, self.walk_number, self.p, self.q)
        walker.do_walks(graph)

        # Train a Word2Vec model on the walks, using hierarchical softmax (hs=1)
        # instead of negative sampling (negative=0)
        model = Word2Vec(
            walker.walks,
            hs=1,
            alpha=self.learning_rate,
            epochs=self.epochs,
            vector_size=self.dimensions,
            window=self.window_size,
            min_count=self.min_count,
            workers=self.workers,
            seed=self.seed,
            negative=0
        )

        # Retrieve node embeddings from the trained model, in node-id order
        n_nodes = graph.number_of_nodes()
        self._embedding = [model.wv[str(n)] for n in range(n_nodes)]

    def get_embedding(self) -> np.ndarray:
        """Return the node embeddings stacked into an array of shape (n_nodes, dimensions)."""
        return np.array(self._embedding)

## Define GSR (Graph Super Resolution) layer

In [6]:
# Weight-initialisation helper used by GSRLayer
def weight_variable_glorot(output_dim):
    """Draw a square (output_dim x output_dim) Glorot/Xavier-uniform weight matrix.

    Because the matrix is square, fan_in == fan_out == output_dim, so values are
    drawn uniformly from [-sqrt(6 / (2 * output_dim)), +sqrt(6 / (2 * output_dim))].
    Uses NumPy's global random state.
    """
    fan_in = fan_out = output_dim
    bound = np.sqrt(6.0 / (fan_in + fan_out))
    return np.random.uniform(-bound, bound, (fan_in, fan_out))

In [7]:
class GSRLayer(nn.Module):
    '''
    Graph super-resolution layer: maps a low-resolution (LR) graph to a
    high-resolution (HR) graph using a learnable (hr_dim, hr_dim) weight matrix.

    Args:
        hr_dim (int): Number of HR nodes.
    '''
    def __init__(self, hr_dim):
        super(GSRLayer, self).__init__()

        # Glorot-uniform square weights, placed on the notebook-level `device`
        self.weights = torch.from_numpy(
            weight_variable_glorot(hr_dim)).type(torch.FloatTensor).to(device)
        self.weights = torch.nn.Parameter(
            data=self.weights, requires_grad=True)

    def forward(self, A_l, Z_l):
        """
        Take the low resolution graph's connectivity A_l and LR node embeddings Z_l
        (from Node2Vec in this notebook) and return the high resolution
        connectivity A_h and node embeddings X_h.

        Args:
            A_l (torch.Tensor): (lr_dim, lr_dim) LR adjacency. Only its upper
                triangle is read by eigh, so it is presumably symmetric.
            Z_l (torch.Tensor): (lr_dim, hr_dim) LR node embeddings.

        Returns:
            tuple: (A_h, X_h), both (hr_dim, hr_dim), non-negative, unit diagonal.

        Requires hr_dim <= 2 * lr_dim, since S_d only has 2 * lr_dim rows.
        """
        lr_dim = A_l.shape[0]
        hr_dim = Z_l.shape[1]

        _, U_l = torch.linalg.eigh(A_l, UPLO='U')  # Eigenvectors of A_l (as columns)

        # Build the upsampling matrix directly on the weights' device/dtype, so the
        # layer keeps working after .to(...)/.double() and avoids a per-call host copy
        I = torch.eye(lr_dim, lr_dim, device=self.weights.device, dtype=self.weights.dtype)
        S_d = torch.cat((I, I), dim=0)  # Stacked identities for upsampling
        S_d = S_d[:hr_dim]  # Keep hr_dim rows -> (hr_dim, lr_dim)

        # Super-resolution of adjacency matrix: A_h = W * S_d * U_l^T * Z_l
        A_h = torch.matmul(torch.matmul(torch.matmul(self.weights, S_d), torch.t(U_l)), Z_l)
        A_h = torch.abs(A_h)
        A_h = A_h.fill_diagonal_(1)  # Add self-loops to the high resolution graph

        # Super-resolution of node embeddings: symmetric A_h A_h^T with unit diagonal
        X_h = torch.matmul(A_h, torch.t(A_h))
        X_h = (X_h + torch.t(X_h)) / 2
        X_h = X_h.fill_diagonal_(1)
        X_h = torch.abs(X_h)

        return A_h, X_h

## Define the GCN (Graph Convolutional Network) layer

In [8]:
class GCN(nn.Module):
    """
    Single graph-convolution layer in the style of Kipf & Welling,
    "Semi-Supervised Classification with Graph Convolutional Networks"
    (https://arxiv.org/abs/1609.02907).

    Computes act(adj @ (dropout(input) @ weight)), where input is
    (n_nodes, in_features), adj is (n_nodes, n_nodes) and weight is
    (in_features, out_features).
    """
    def __init__(self, in_features, out_features, dropout, act=F.relu):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.dropout = dropout
        self.act = act
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)draw the weight matrix with Xavier/Glorot uniform initialisation."""
        init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        """Propagate node features one hop over `adj`: act(adj @ (X @ W))."""
        dropped = F.dropout(input, self.dropout, self.training)
        projected = torch.mm(dropped, self.weight)
        return self.act(torch.mm(adj, projected))

## Define the AGSR-Vec (Adversarial Graph Super Resolution-Vec) model

#### Helper function(s) for the model

In [9]:
def normalize_adj_torch(A):
    '''
    Symmetrically normalise an adjacency matrix: D^{-1/2} A^T D^{-1/2}.

    D is the diagonal matrix of the row sums of A. For a symmetric A this is the
    usual GCN normalisation D^{-1/2} A D^{-1/2}. Nodes with zero degree get a
    scale factor of 0 instead of inf, so their rows/columns stay zero.

    Args:
        A (torch.Tensor): Square 2-D adjacency matrix.

    Returns:
        torch.Tensor: The normalised matrix, same shape as A.
    '''
    # Calculate inverse square root of each node's (row-sum) degree
    r_inv_sqrt = torch.pow(A.sum(1), -0.5).flatten()
    # Zero degree gives inf; map it to 0 to avoid division-by-zero artefacts
    r_inv_sqrt[torch.isinf(r_inv_sqrt)] = 0.

    # Same result as diag(r) @ A^T @ diag(r), but by broadcasting: O(n^2) element-wise
    # scaling instead of two dense O(n^3) matmuls with diagonal matrices.
    return r_inv_sqrt.unsqueeze(1) * torch.transpose(A, 0, 1) * r_inv_sqrt.unsqueeze(0)

### Define the AGSR-Vec model

In [10]:

class AGSRVec(nn.Module):
    '''
    AGSRVec model for super-resolving low-resolution (LR) input data with an upsampling layer and graph convolutional networks (GCNs).

    Args:
        hr_dim (int): Dimensionality of the high-resolution (HR) embeddings.
        hidden_dim (int): Dimensionality of the hidden layers in GCNs.

    Attributes:
        gsr_layer (GSRLayer): Upsampling layer.
        gcn_1 (GCN): First graph convolutional network (hr_dim -> hidden_dim).
        gcn_2 (GCN): Second graph convolutional network (hidden_dim -> hr_dim).

    For detailed explanation refer to README.md.
    '''

    def __init__(self, hr_dim, hidden_dim):
        super().__init__()
        # Registration order (gsr_layer, gcn_1, gcn_2) fixes parameter order and RNG draws
        self.gsr_layer = GSRLayer(hr_dim)
        self.gcn_1 = GCN(hr_dim, hidden_dim, 0, act=F.relu)
        self.gcn_2 = GCN(hidden_dim, hr_dim, 0, act=F.relu)

    def forward(self, A_l, X_initial):
        '''
        Super-resolve one LR graph:
            (1) normalise A_l and upsample the initial LR node embeddings with the GSR layer,
            (2) refine the HR node features with GCN_1 and then GCN_2 over the HR adjacency,
            (3) symmetrise by averaging with the transpose, set unit self-connections,
                and take the absolute value.

        Args:
            A_l: LR input adjacency matrix, (lr_dim, lr_dim).
            X_initial: Initial LR node embeddings (from Node2Vec), (lr_dim, hr_dim).

        Returns:
            torch.Tensor: Predicted HR matrix, (hr_dim, hr_dim).
        '''
        normalized_adj = normalize_adj_torch(A_l)
        A_h, hidden = self.gsr_layer(normalized_adj, X_initial)

        # Both GCN layers propagate over the same super-resolved adjacency A_h
        for gcn in (self.gcn_1, self.gcn_2):
            hidden = gcn(hidden, A_h)

        symmetric = (hidden + hidden.t()) / 2
        symmetric = symmetric.fill_diagonal_(1)
        return symmetric.abs()

## Define the discriminator model

In [11]:
class Dense(nn.Module):
    '''
    Bias-free fully-connected layer: out = x @ weights.

    Args:
        n1 (int): Number of input features.
        n2 (int): Number of output features.
        mean_dense (float): Mean of the normal distribution used to initialise the weights.
        std_dense (float): Standard deviation of that normal distribution.
    '''
    def __init__(self, n1, n2, mean_dense, std_dense):
        super().__init__()
        weight_data = torch.FloatTensor(n1, n2)
        self.weights = nn.Parameter(weight_data, requires_grad=True)
        init.normal_(self.weights, mean=mean_dense, std=std_dense)

    def forward(self, x):
        """Project (batch, n1) inputs to (batch, n2)."""
        return x.mm(self.weights)

class Discriminator(nn.Module):
    '''
    Adversarial discriminator: scores whether a high-resolution connectome comes from
    the ground-truth HR distribution or was generated.

    Architecture: Dense(hr_dim, hr_dim) -> ReLU -> Dense(hr_dim, hr_dim) -> ReLU
    -> Dense(hr_dim, 1) -> Sigmoid.

    Args:
        hr_dim (int): Dimensionality of the input features.
        mean_dense (float): Mean of the normal distribution for weight initialization.
        std_dense (float): Standard deviation of the normal distribution for weight initialization.
    '''
    def __init__(self, hr_dim, mean_dense, std_dense):
        super().__init__()
        self.dense_1 = Dense(hr_dim, hr_dim, mean_dense, std_dense)
        self.relu_1 = nn.ReLU(inplace=False)

        self.dense_2 = Dense(hr_dim, hr_dim, mean_dense, std_dense)
        self.relu_2 = nn.ReLU(inplace=False)

        self.dense_3 = Dense(hr_dim, 1, mean_dense, std_dense)
        self.sigmoid = nn.Sigmoid()  # Squash the final score into [0, 1]

    def forward(self, inputs):
        """Return one score per input row, shape (n_rows, 1)."""
        hidden = inputs
        for dense, relu in ((self.dense_1, self.relu_1), (self.dense_2, self.relu_2)):
            hidden = relu(dense(hidden))
        scores = self.sigmoid(self.dense_3(hidden))
        # Sigmoid output is already non-negative; abs is kept for identical behaviour
        # (note it has zero gradient where the sigmoid saturates to exactly 0)
        return torch.abs(scores)

## Matrix Vectorising and Anti-Vectorising functions

In [12]:
class MatrixVectorizer:
    """
    A class for transforming between matrices and vector representations.

    Provides methods to convert a symmetric matrix into a vector (vectorize)
    and to reconstruct the matrix from its vector form (anti_vectorize). Both
    walk the matrix column by column (vertical traversal) over the strict upper
    triangle in the same order. For a symmetric matrix the round trip
    anti_vectorize(vectorize(M), n) therefore restores every off-diagonal entry.
    The main diagonal is never stored and is reconstructed as 0.
    """

    def __init__(self):
        pass

    @staticmethod
    def vectorize(matrix, include_diagonal=False):
        """
        Converts a matrix into a vector by vertically extracting elements.

        For each column `col` (left to right), appends the strict upper-triangle
        entries matrix[0, col], ..., matrix[col - 1, col] in row order.
        Main-diagonal entries (row == col) are never included.

        Note: despite its name, include_diagonal=True does NOT add the main
        diagonal. It additionally appends matrix[col + 1, col] (the entry just
        below the diagonal) after each column's upper-triangle entries.

        Arguments:
        - matrix (numpy.ndarray): The square matrix to be vectorized. Only
          matrix.shape[0] is used as its size.
        - include_diagonal (bool, optional): Also append the first sub-diagonal
          entry of each column (see note above). Defaults to False.

        Returns:
        - numpy.ndarray: The vectorized form of the matrix. For an n x n input it
          has n * (n - 1) / 2 elements, plus n - 1 more when include_diagonal is True.
        """
        # Determine the size of the matrix based on its first dimension
        matrix_size = matrix.shape[0]

        # Initialize an empty list to accumulate vector elements
        vector_elements = []

        # Iterate over columns and then rows to collect the relevant elements
        for col in range(matrix_size):
            for row in range(matrix_size):
                # The main diagonal (row == col) is always skipped
                if row != col:
                    if row < col:
                        # Collect upper triangle elements
                        vector_elements.append(matrix[row, col])
                    elif include_diagonal and row == col + 1:
                        # Optionally include the element immediately below the diagonal
                        vector_elements.append(matrix[row, col])

        return np.array(vector_elements)

    @staticmethod
    def anti_vectorize(vector, matrix_size, include_diagonal=False):
        """
        Reconstructs a symmetric matrix from its vector form, filling it vertically.

        Consumes `vector` in the same column-major order as vectorize(). Each
        value is written to (row, col) and mirrored to (col, row). Diagonal
        entries stay 0.

        Note: with include_diagonal=True, the value written at (col + 1, col) is
        overwritten when column col + 1 is processed. That column's
        upper-triangle entry (col, col + 1) is mirrored onto the same cell, so
        the extra sub-diagonal values are consumed but do not survive.

        Arguments:
        - vector (numpy.ndarray): The vector to be transformed into a matrix.
          It must hold at least as many elements as vectorize() would produce
          for this size, or an IndexError is raised.
        - matrix_size (int): The size of the square matrix to be reconstructed.
        - include_diagonal (bool, optional): Must match the flag used in
          vectorize(). Defaults to False.

        Returns:
        - numpy.ndarray: The reconstructed (float64) square matrix.
        """
        # Initialize a square matrix of zeros with the specified size
        matrix = np.zeros((matrix_size, matrix_size))

        # Index to keep track of the current position in the vector
        vector_idx = 0

        # Fill the matrix by iterating over columns and then rows
        for col in range(matrix_size):
            for row in range(matrix_size):
                # The main diagonal (row == col) is always skipped and stays 0
                if row != col:
                    if row < col:
                        # Reflect vector elements into the upper triangle and its mirror in the lower triangle
                        matrix[row, col] = vector[vector_idx]
                        matrix[col, row] = vector[vector_idx]
                        vector_idx += 1
                    elif include_diagonal and row == col + 1:
                        # Sub-diagonal entry; overwritten later by column col + 1 (see docstring)
                        matrix[row, col] = vector[vector_idx]
                        matrix[col, row] = vector[vector_idx]
                        vector_idx += 1

        return matrix

## Define noise, prediction and evaluation functions

In [13]:
def gaussian_noise_layer(input_layer, mean_gaus, std_gaus):
    '''
    Perturb a square matrix with i.i.d. Gaussian noise.

    Args:
        input_layer (torch.Tensor): Square 2-D tensor to perturb.
        mean_gaus (float): Mean of the noise.
        std_gaus (float): Standard deviation of the noise.

    Returns:
        torch.Tensor: |input + noise|, symmetrised by averaging with its
        transpose, with the diagonal set to 1.
    '''
    noise = torch.empty_like(input_layer).normal_(mean=mean_gaus, std=std_gaus)
    noisy = (input_layer + noise).abs()
    symmetric = (noisy + noisy.t()) / 2
    return symmetric.fill_diagonal_(1)

In [14]:
def predict(model, A, X_initial):
    '''
    Predict HR matrices for every non-empty LR adjacency matrix.

    Args:
        model (nn.Module): The graph nn model (AGSRVec) to use for prediction.
            It is left in eval mode afterwards.
        A (list): LR adjacency matrices (numpy arrays), one per subject.
        X_initial (numpy.ndarray, list or torch.Tensor): Initial LR node embeddings,
            one per matrix in A.

    Returns:
        tuple: (adjacency, vector). adjacency holds the predicted HR adjacency matrices
        stacked into a numpy array. vector holds their MatrixVectorizer.vectorize forms.

    Note:
        All-zero LR matrices are skipped. The outputs can then be shorter than A and
        are no longer index-aligned with it or with the ground truth.
    '''
    model.eval()

    # Convert X_initial to a float32 tensor on `device` if needed (also accepts lists of arrays)
    if not isinstance(X_initial, torch.Tensor):
        X_initial = torch.as_tensor(np.asarray(X_initial), dtype=torch.float32).to(device)

    preds_list_adjacency = []
    preds_list_vector = []
    # Inference only: skip building the autograd graph to save memory and time
    with torch.no_grad():
        for lr, X in zip(A, X_initial):
            if not np.any(lr):
                continue  # Skip empty (all-zero) LR matrices
            lr = torch.as_tensor(np.asarray(lr), dtype=torch.float32).to(device)
            X = X.float().to(device)

            # Generate predictions using the model
            adjacency = model(lr, X).cpu().numpy()
            preds_list_adjacency.append(adjacency)
            preds_list_vector.append(MatrixVectorizer.vectorize(adjacency))

    return np.array(preds_list_adjacency), np.array(preds_list_vector)

In [15]:
def _accumulate_centrality_maes(pred, truth, measures, maes_per_measure):
    """Append one MAE per centrality measure for a sample; skipped unless both graphs are connected."""
    pred_graph = nx.from_numpy_array(pred)
    gt_graph = nx.from_numpy_array(truth)
    if not (nx.is_connected(pred_graph) and nx.is_connected(gt_graph)):
        return
    # Align values by node id rather than relying on dict iteration order
    nodes = list(gt_graph)
    for measure, maes in zip(measures, maes_per_measure):
        pred_values = measure(pred_graph)
        gt_values = measure(gt_graph)
        maes.append(mean_absolute_error([pred_values[n] for n in nodes],
                                        [gt_values[n] for n in nodes]))


def _mean_or_nan(values):
    """Arithmetic mean of `values`, or NaN when empty (no sample had two connected graphs)."""
    return sum(values) / len(values) if values else float('nan')


def evaluate(preds, truths, full_eval=True):
    '''
    Evaluate metrics comparing predicted results against ground truth data.

    Args:
        preds (numpy.ndarray): Predicted HR adjacency matrices, one square matrix per sample.
        truths (numpy.ndarray): Ground-truth HR adjacency matrices, aligned with preds.
        full_eval (bool, optional): Also compare graph centrality measures (slow). Centralities
            are only compared for samples whose predicted and ground-truth graphs are both connected.

    Returns:
        tuple: Tuple containing evaluation metrics.
            If full_eval is True:
                - Mean absolute error (MAE)
                - Pearson correlation coefficient (PCC)
                - Jensen-Shannon distance (JS_DIS, scipy's jensenshannon)
                - Average MAE for betweenness centrality (avg_mae_bc)
                - Average MAE for eigenvector centrality (avg_mae_ec)
                - Average MAE for pagerank centrality (avg_mae_pc)
                - Average MAE for communicability betweenness centrality (avg_mae_cbc)
                - Average MAE for degree centrality (avg_mae_dc)
              Each average is NaN if no sample had two connected graphs.
            If full_eval is False:
                - Mean absolute error (MAE)
                - Pearson correlation coefficient (PCC)
                - Jensen-Shannon distance (JS_DIS)
    '''
    # The SAME function is applied to predicted and ground-truth graphs. The order
    # matches the returned averages (bc, ec, pc, cbc, dc).
    centrality_measures = (
        nx.betweenness_centrality,
        nx.eigenvector_centrality,
        nx.pagerank,
        nx.communicability_betweenness_centrality,
        nx.degree_centrality,
    )
    centrality_maes = [[] for _ in centrality_measures]

    pred_1d_list = []
    gt_1d_list = []
    for count, (pred, truth) in enumerate(zip(preds, truths)):
        if full_eval:
            print(count)  # Progress indicator: centralities are slow
            _accumulate_centrality_maes(pred, truth, centrality_measures, centrality_maes)

        # Vectorize matrices (strict upper triangle) for the element-wise metrics
        pred_1d_list.append(MatrixVectorizer.vectorize(pred))
        gt_1d_list.append(MatrixVectorizer.vectorize(truth))

    # Concatenate flattened matrices across samples
    pred_1d = np.concatenate(pred_1d_list)
    gt_1d = np.concatenate(gt_1d_list)

    mae = mean_absolute_error(pred_1d, gt_1d)
    pcc = pearsonr(pred_1d, gt_1d)[0]
    js_dis = jensenshannon(pred_1d, gt_1d)

    if not full_eval:
        return mae, pcc, js_dis

    avg_mae_bc, avg_mae_ec, avg_mae_pc, avg_mae_cbc, avg_mae_dc = (
        _mean_or_nan(maes) for maes in centrality_maes
    )
    return mae, pcc, js_dis, avg_mae_bc, avg_mae_ec, avg_mae_pc, avg_mae_cbc, avg_mae_dc

In [16]:
import pandas as pd
import numpy as np
import networkx as nx
from scipy.stats import pearsonr
from scipy.spatial.distance import jensenshannon
from sklearn.metrics import mean_absolute_error
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
from community import community_louvain
import os

def calculate_centralities(adj_matrix):
    """
    Summarise one square adjacency matrix with graph-level measures.

    Args:
        adj_matrix (numpy.ndarray): Square (n, n) adjacency matrix.

    Returns:
        dict with keys:
            'pr'  - mean PageRank (alpha=0.9),
            'ec'  - mean eigenvector centrality (eigenvector_centrality_numpy, max_iter=100),
            'bc'  - mean normalised betweenness centrality (endpoints excluded),
            'ns'  - per-node ARRAY (not a mean) of degree_centrality * (n - 1),
            'pc'  - mean participation coefficient over a Louvain partition,
            'acc' - average clustering coefficient (weight=None).

    Raises:
        ValueError: If adj_matrix is not square.

    NOTE(review): 'ns' is derived from degree_centrality, which counts edges
    rather than summing weights. "Node strength" usually means weighted degree,
    so confirm intent. community_louvain.best_partition is called without a
    random_state, so 'pc' may depend on RNG state.
    """
    n_rows, n_cols = adj_matrix.shape[0], adj_matrix.shape[1]
    if n_rows != n_cols:
        raise ValueError(f"Adjacency matrix is not square: shape={adj_matrix.shape}")
    print(f"Processing adjacency matrix of shape: {adj_matrix.shape}")

    graph = nx.from_numpy_array(adj_matrix)
    communities = community_louvain.best_partition(graph)
    participation = participation_coefficient(graph, communities)

    pagerank = nx.pagerank(graph, alpha=0.9)
    eigen = nx.eigenvector_centrality_numpy(graph, max_iter=100)
    betweenness = nx.betweenness_centrality(graph, normalized=True, endpoints=False)
    degree_fraction = nx.degree_centrality(graph)
    node_strength = np.array(list(degree_fraction.values())) * (len(graph.nodes()) - 1)
    avg_clustering = nx.average_clustering(graph, weight=None)

    def node_mean(per_node):
        # Average a {node: value} mapping
        return np.mean(list(per_node.values()))

    return {
        'pr': node_mean(pagerank),
        'ec': node_mean(eigen),
        'bc': node_mean(betweenness),
        'ns': node_strength,
        'pc': node_mean(participation),
        'acc': avg_clustering,
    }

def participation_coefficient(G, partition):
    '''
    Binary participation coefficient of every node (Guimerà & Amaral):

        P_i = 1 - sum_s (k_is / k_i) ** 2

    Here k_i is the number of neighbours of node i and k_is is the number of
    those neighbours assigned to module s. The sum runs over ALL modules i
    touches, not only its own. P_i is 0 when every link stays inside one module
    and grows as links spread over more modules.

    Self-loops (e.g. a unit diagonal) are ignored: they are not links to another
    module, and counting them would make the fractions inconsistent with k_i.

    Args:
        G: Graph supporting G.nodes() and G[node] (neighbour iteration), e.g. a NetworkX graph.
        partition (dict): Maps each node to its module label.

    Returns:
        dict: {node: participation coefficient}; isolated nodes get 0.0.
    '''
    from collections import Counter

    pc_dict = {}
    for node in G.nodes():
        neighbours = [nbr for nbr in G[node] if nbr != node]
        if not neighbours:
            pc_dict[node] = 0.0
            continue
        k_i = len(neighbours)
        # Number of neighbours falling in each module
        module_counts = Counter(partition[nbr] for nbr in neighbours)
        pc_dict[node] = 1.0 - sum((k_is / k_i) ** 2 for k_is in module_counts.values())

    return pc_dict


def evaluate_all(true_hr_matrices, predicted_hr_matrices, output_path='ID-randomCV.csv'):
    """
    Compute per-subject metrics between true and predicted HR matrices and append them to a CSV.

    Args:
        true_hr_matrices (numpy.ndarray): (n_subjects, n, n) ground-truth matrices.
        predicted_hr_matrices (numpy.ndarray): (n_subjects, n, n) predictions, aligned by index.
        output_path (str): CSV file to append to. The header is written only when the file
            does not exist yet.

    Each row holds the subject ID (1-based), element-wise MAE, PCC and Jensen-Shannon
    distance over the full flattened matrices (diagonal included), and the MAE of each
    calculate_centralities measure. Subjects whose shapes mismatch or are not square are
    reported and skipped.
    """
    print(true_hr_matrices.shape)
    print(predicted_hr_matrices.shape)

    rows = []
    for i in range(true_hr_matrices.shape[0]):
        true_matrix = true_hr_matrices[i, :, :]
        pred_matrix = predicted_hr_matrices[i, :, :]

        print(f"Evaluating subject {i+1} with matrix shapes: true={true_matrix.shape}, pred={pred_matrix.shape}")

        shapes_ok = true_matrix.shape == pred_matrix.shape and true_matrix.shape[0] == true_matrix.shape[1]
        if not shapes_ok:
            print(f"Error: Matrix shape mismatch or not square for subject {i+1}: true={true_matrix.shape}, pred={pred_matrix.shape}")
            continue

        true_flat = true_matrix.flatten()
        pred_flat = pred_matrix.flatten()
        metrics = {
            'ID': i + 1,
            'MAE': mean_absolute_error(true_flat, pred_flat),
            'PCC': pearsonr(true_flat, pred_flat)[0],
            'JSD': jensenshannon(true_flat, pred_flat),
        }

        true_metrics = calculate_centralities(true_matrix)
        pred_metrics = calculate_centralities(pred_matrix)
        for key in ('NS', 'PR', 'EC', 'BC', 'PC', 'ACC'):
            metrics[f'MAE in {key}'] = mean_absolute_error([true_metrics[key.lower()]], [pred_metrics[key.lower()]])

        rows.append(metrics)

    df = pd.DataFrame(rows)
    if df.empty:
        print("No data to save.")
        return

    # Write the header only when creating the file; later runs just append rows
    write_header = not os.path.isfile(output_path)
    df.to_csv(output_path, mode='a', header=write_header, index=False)
    print(f"Results appended to {output_path}.")



## Define the Jensen-Shannon divergence loss and value-histogram helper

In [17]:
def js_divergence(p, q):
    """
    Jensen-Shannon divergence between two categorical distributions.

    JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = (p + q) / 2.

    p and q must be probability vectors (non-negative; Categorical normalises them
    to sum to 1). KL is computed with torch.distributions.kl_divergence, which is
    numerically safer than taking logs manually. The result uses natural-log units.
    """
    mixture = 0.5 * (p + q)
    as_categorical = torch.distributions.Categorical
    kl = torch.distributions.kl_divergence
    kl_p = kl(as_categorical(probs=p), as_categorical(probs=mixture))
    kl_q = kl(as_categorical(probs=q), as_categorical(probs=mixture))
    return 0.5 * (kl_p + kl_q)

def compute_histogram(tensor, bins=10, min=0, max=1, eps=1e-10):
    """
    Turn the values of `tensor` into a normalised histogram (a probability vector).

    torch.histc buckets the values into `bins` equal-width bins over [min, max]
    (values outside that range are ignored). `eps` is added to every bin so no
    probability is exactly zero, then the counts are divided by their total.
    """
    counts = torch.histc(tensor, bins=bins, min=min, max=max) + eps
    return counts / counts.sum()

def jsd_loss(Z_h, A_h):
    '''
    Jensen-Shannon Divergence between the value distributions of a predicted and a
    ground-truth matrix.

    Both tensors are flattened and binned into 10 equal-width bins over [0, 1] by
    compute_histogram. That function already returns strictly positive, normalised
    probabilities, which is exactly what js_divergence expects.

    Args:
        Z_h (torch.Tensor): Predicted HR matrix.
        A_h (torch.Tensor): Ground truth HR matrix.

    Returns:
        torch.Tensor: Scalar JSD (natural-log units).

    NOTE(review): torch.histc is not differentiable, so this term presumably provides
    no gradient to Z_h. Confirm before relying on it as a training signal.
    '''
    pred_dist = compute_histogram(Z_h.flatten(), bins=10, min=0, max=1)
    true_dist = compute_histogram(A_h.flatten(), bins=10, min=0, max=1)
    # Pass probabilities, not log-probabilities. Logs are negative and do not form a
    # distribution, and Categorical's argument validation rejects them when enabled.
    return js_divergence(pred_dist, true_dist)

## Main Training Loop

In [18]:
def train(model, A_train, GT_train, X_train_initial, num_epochs, lr, hr_dim, mean_dense, std_dense, mean_gaus, std_gaus, lmbda, mu=1, A_test=None, GT_test=None, X_test_initial=None, patience=10, min_epochs=110):
    '''
    Train the AGSRVec generator together with an adversarial Discriminator.

    For every training sample the generator minimises
        MSE(Z_h, A_h) + mu * MSE(W, U_h) + lmbda * JSD(Z_h, A_h) + BCE(D(noisy A_h), 1)
    where Z_h is the predicted HR matrix, A_h the ground-truth HR matrix, U_h the
    eigenvectors of A_h and W the weights of ``model.gsr_layer``. When validation data
    (A_test, GT_test, X_test_initial) is supplied, the validation MAE is computed after
    every epoch and drives early stopping.

    Args:
        model (nn.Module): The graph-based learning model (AGSRVec) to train.
        A_train (list): List of training LR adjacency matrices.
        GT_train (list): List of ground truth HR training output.
        X_train_initial (numpy.ndarray or list): Initial node embeddings of the training LR graphs.
        num_epochs (int): Maximum number of training epochs.
        lr (float): Learning rate of both Adam optimizers.
        hr_dim (int): Dimensionality of the high-resolution data.
        mean_dense (float): Mean of the normal distribution for weight initialization in Discriminator dense layers.
        std_dense (float): Standard deviation of the normal distribution for weight initialization in Discriminator dense layers.
        mean_gaus (float): Mean passed to gaussian_noise_layer when perturbing the ground-truth HR matrices fed to the Discriminator.
        std_gaus (float): Standard deviation passed to gaussian_noise_layer for the same perturbation.
        lmbda (float): Weight for the Jensen-Shannon Divergence loss.
        mu (float, optional): Weight for the eigenvector loss MSE(W, U_h). Default is 1.
        A_test (list, optional): Validation LR adjacency matrices. Default is None.
        GT_test (list, optional): Validation HR ground truths. Default is None.
        X_test_initial (numpy.ndarray or list, optional): Initial node embeddings of the validation graphs. Default is None.
        patience (int, optional): Consecutive epochs without a validation-MAE improvement before stopping. Default is 10.
        min_epochs (int, optional): Early stopping is only allowed from this epoch onwards. Default is 110.
    '''

    # Loss criteria
    bce_loss = nn.BCELoss()
    criterion_mse = nn.MSELoss()
    criterion_mae = nn.L1Loss()

    # Discriminator plus one Adam optimizer per network
    modelD = Discriminator(hr_dim, mean_dense, std_dense).to(device)
    optimizerG = optim.Adam(model.parameters(), lr=lr)
    optimizerD = optim.Adam(modelD.parameters(), lr=lr)

    has_validation = A_test is not None and GT_test is not None and X_test_initial is not None
    best_loss = float('inf')  # best validation MAE seen so far
    counter = 0               # consecutive epochs without a validation improvement

    for epoch in tqdm(range(1, num_epochs+1), desc='Epochs', file=sys.stdout, mininterval=1):
        epoch_loss = []
        epoch_mse_error = []
        epoch_mae_error = []
        epoch_jsd_error = []
        epoch_eigen_error = []

        model.train()
        modelD.train()

        # Iterate over training samples paired with their initial node embeddings
        for A_l, A_h, X_initial in zip(A_train, GT_train, X_train_initial):
            optimizerG.zero_grad()
            optimizerD.zero_grad()

            A_l = torch.from_numpy(A_l).float().to(device)
            A_h = torch.from_numpy(A_h).float().to(device)
            X_initial = torch.from_numpy(X_initial).float().to(device)

            # Eigenvectors of the ground truth (eigh reads only the upper triangle,
            # i.e. A_h is treated as symmetric)
            _, U_h = torch.linalg.eigh(A_h, UPLO='U')

            Z_h = model(A_l, X_initial)

            # Each term is computed once and reused for both the objective and the logs
            mse_error = criterion_mse(Z_h, A_h)
            mae_error = criterion_mae(Z_h, A_h)
            eigen_error = criterion_mse(model.gsr_layer.weights, U_h)
            jsd_error = jsd_loss(Z_h, A_h)

            # Reconstruction part of the AGSRVec generator objective
            mse_loss = mse_error + mu * eigen_error + lmbda * jsd_error

            # Discriminator step: generator output labelled 1, noisy ground truth labelled 0
            real_data = Z_h.detach()
            fake_data = gaussian_noise_layer(A_h, mean_gaus, std_gaus)
            d_real = modelD(real_data)
            d_fake = modelD(fake_data)
            d_loss = bce_loss(d_real, torch.ones_like(d_real)) + bce_loss(d_fake, torch.zeros_like(d_fake))
            d_loss.backward()
            optimizerD.step()

            # Generator step.
            # NOTE(review): the adversarial term is evaluated on freshly noised ground truth,
            # not on Z_h, so it appears not to depend on the generator's parameters; only
            # mse_loss moves the generator. Its gradients land in modelD and are cleared by
            # the next optimizerD.zero_grad(). Confirm this is intended.
            d_fake = modelD(gaussian_noise_layer(A_h, mean_gaus, std_gaus))
            gen_loss = bce_loss(d_fake, torch.ones_like(d_fake)) + mse_loss
            gen_loss.backward()
            optimizerG.step()

            epoch_loss.append(gen_loss.item())
            epoch_mse_error.append(mse_error.item())
            epoch_mae_error.append(mae_error.item())
            epoch_jsd_error.append(jsd_error.item())
            epoch_eigen_error.append(eigen_error.item())

        tqdm.write("Epoch: {}, Loss: {:.4f}, MSE_Error: {:.4f}%, JSD_Error: {:.4f}, Eigen_Error: {:.4f},  MAE_Error: {:.4f}%".format(epoch, np.mean(epoch_loss), np.mean(epoch_mse_error)*100, lmbda * np.mean(epoch_jsd_error), np.mean(epoch_eigen_error), np.mean(epoch_mae_error)*100))

        # Validation and early stopping (only when validation data was supplied)
        if has_validation:
            preds_adjacencies, _ = predict(model, A_test, X_test_initial)
            mae, pcc, js_dis = evaluate(preds_adjacencies, GT_test, full_eval=False)
            tqdm.write("Evaluation  mae: {:.4f}%, pcc: {:.4f}, js_dis: {:.4f}".format(mae*100, pcc, js_dis))

            if mae >= best_loss:
                counter += 1
                if counter >= patience and epoch >= min_epochs:
                    break
            else:
                best_loss = mae
                counter = 0
                tqdm.write(f'----best model epoch: {epoch}----')


#### Now that all functions for model initialization and training are defined, we define the main scripts for loading data, k-fold cross-validation, and full-model training.

## Load in {lr_train, hr_train}
#### Each LR matrix is standardized column-wise (StandardScaler), shifted by +0.5 so its mean is 0.5, and clipped so that negative values become 0. HR matrices are used unscaled.

In [None]:
# Initialize StandardScalers for HR and LR separately
# (only scaler_lr is applied below; the HR matrices keep their original scale)
scaler_hr = StandardScaler()
scaler_lr = StandardScaler()

# Load the three LR folds and report their shapes
lr_split_paths = (
    'RandomCV/Fold1/lr_split_1.csv',
    'RandomCV/Fold2/lr_split_2.csv',
    'RandomCV/Fold3/lr_split_3.csv',
)
training_1, training_2, training_3 = [pd.read_csv(path).to_numpy() for path in lr_split_paths]
for fold_vectors in (training_1, training_2, training_3):
    print(fold_vectors.shape)

# Anti-vectorize every LR sample (size 160), standardize it column-wise,
# shift by +0.5 and clip negative entries to zero
lr_matrices = []
for lr_vector in np.concatenate((training_1, training_2, training_3), axis=0):
    lr_matrix = scaler_lr.fit_transform(
        MatrixVectorizer.anti_vectorize(lr_vector, 160, include_diagonal=False)
    ) + 0.5
    lr_matrix[lr_matrix < 0.0] = 0.0
    lr_matrices.append(lr_matrix)
training_A = np.array(lr_matrices)

# Load the three HR folds (ground truths) and report their shapes
hr_split_paths = (
    'RandomCV/Fold1/hr_split_1.csv',
    'RandomCV/Fold2/hr_split_2.csv',
    'RandomCV/Fold3/hr_split_3.csv',
)
training_truths_1, training_truths_2, training_truths_3 = [pd.read_csv(path).to_numpy() for path in hr_split_paths]
for fold_vectors in (training_truths_1, training_truths_2, training_truths_3):
    print(fold_vectors.shape)

# Anti-vectorize every HR sample (size 268) without any scaling
training_truths = np.array([
    MatrixVectorizer.anti_vectorize(hr_vector, 268, include_diagonal=False)
    for hr_vector in np.concatenate((training_truths_1, training_truths_2, training_truths_3), axis=0)
])


## Pre-compute Node2Vec Embeddings for lr_train



In [None]:
# Fit Node2Vec on every LR training graph and stack the per-graph embeddings;
# they are used as the initial node features of AGSRVec.
node2vec = Node2Vec()
embedding_list = []

for sample_idx in range(len(training_A)):
    print(f"Computing Node2Vec initial embeddings for training sample {sample_idx}")
    lr_graph = nx.from_numpy_array(training_A[sample_idx])  # NetworkX graph built from the adjacency matrix
    node2vec.fit(lr_graph)
    embedding_list.append(torch.tensor(np.array(node2vec._embedding)).float())

training_initial_embeddings = torch.stack(embedding_list).numpy()

# Cache to disk so the next cell can reload the embeddings without recomputing them
np.save('training_embeddings.npy', training_initial_embeddings)

### Load the initial embeddings produced by Node2Vec
To avoid recomputation, run this cell to load the pre-computed embeddings saved by the cell above (`training_embeddings.npy`).

In [23]:
# Load the cached Node2Vec embeddings (one entry per LR training sample)
training_initial_embeddings = np.load('training_embeddings.npy')
# Guard against a stale cache (e.g. computed for another dataset or sample order):
# entry i must belong to training_A[i] for the fold slicing below to be valid.
assert len(training_initial_embeddings) == len(training_A), (
    f"Cached embeddings cover {len(training_initial_embeddings)} samples but training_A has "
    f"{len(training_A)}; recompute them with the Node2Vec cell above."
)

### Perform K-Fold Cross-Validation
We use K = 3 folds.

In [24]:
# Hyperparameters tuned with a Bayesian search and shared by every fold:
#   lmbda_best -> weight of the JSD loss term
#   mu_best    -> weight of the eigenvector loss term
#   lr_best    -> Adam learning rate (generator and discriminator)
lmbda_best, mu_best, lr_best = 0.2, 5, 1.5e-4

#### Random dataset
##### If `train_all` is True, train on all samples outside the held-out test fold (no validation set) and save a checkpoint. Otherwise, split those samples into a training set and a validation set, train only on the training set, and use the validation MAE for early stopping.

In [None]:
# fold1 (random split): samples [0, 93) form the held-out test fold

train_all = True
# Kept for compatibility with other cells; nothing in this cell fills them
fold_predictions = []
fold_truths = []

print("Fold number:", 1)

# Re-seed before the model is built so this fold's initialization and training do not
# depend on which cells were executed earlier in the session.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

if train_all:
    # Train on the two remaining folds (no validation set) and save a checkpoint
    A_train, A_test = training_A[93:], training_A[:93]
    X_train_initial, X_test_initial = training_initial_embeddings[93:], training_initial_embeddings[:93]
    GT_train, GT_test = training_truths[93:], training_truths[:93]
    for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=60, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best)
    torch.save(model.state_dict(), 'checkpoint_fold1_random_all.pt')

else:
    # Carve a validation set out of the two remaining folds; train() uses it for early stopping
    A_train = np.concatenate((training_A[93:166], training_A[186:259]), axis=0)
    A_val = np.concatenate((training_A[166:186], training_A[259:]), axis=0)
    A_test = training_A[:93]
    X_train_initial = np.concatenate((training_initial_embeddings[93:166], training_initial_embeddings[186:259]), axis=0)
    X_val_initial = np.concatenate((training_initial_embeddings[166:186], training_initial_embeddings[259:]), axis=0)
    X_test_initial = training_initial_embeddings[:93]
    GT_train = np.concatenate((training_truths[93:166], training_truths[186:259]), axis=0)
    GT_val = np.concatenate((training_truths[166:186], training_truths[259:]), axis=0)
    GT_test = training_truths[:93]
    for arr in (A_train, A_val, A_test, X_train_initial, X_val_initial, X_test_initial, GT_train, GT_val, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=500, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best,
          A_test=A_val, GT_test=GT_val, X_test_initial=X_val_initial)

In [None]:
# fold2 (random split): samples [93, 186) form the held-out test fold

train_all = True
# Kept for compatibility with other cells; nothing in this cell fills them
fold_predictions = []
fold_truths = []

print("Fold number:", 2)

# Re-seed before the model is built so this fold's initialization and training do not
# depend on which cells were executed earlier in the session.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

if train_all:
    # Train on the two remaining folds (no validation set) and save a checkpoint
    A_train = np.concatenate((training_A[:93], training_A[186:]), axis=0)
    A_test = training_A[93:186]
    X_train_initial = np.concatenate((training_initial_embeddings[:93], training_initial_embeddings[186:]), axis=0)
    X_test_initial = training_initial_embeddings[93:186]
    GT_train = np.concatenate((training_truths[:93], training_truths[186:]), axis=0)
    GT_test = training_truths[93:186]
    for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=54, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best)
    torch.save(model.state_dict(), 'checkpoint_fold2_random_all.pt')

else:
    # Carve a validation set out of the two remaining folds; train() uses it for early stopping
    A_train = np.concatenate((training_A[:73], training_A[186:259]), axis=0)
    A_val = np.concatenate((training_A[73:93], training_A[259:]), axis=0)
    A_test = training_A[93:186]
    X_train_initial = np.concatenate((training_initial_embeddings[:73], training_initial_embeddings[186:259]), axis=0)
    X_val_initial = np.concatenate((training_initial_embeddings[73:93], training_initial_embeddings[259:]), axis=0)
    X_test_initial = training_initial_embeddings[93:186]
    GT_train = np.concatenate((training_truths[:73], training_truths[186:259]), axis=0)
    GT_val = np.concatenate((training_truths[73:93], training_truths[259:]), axis=0)
    GT_test = training_truths[93:186]
    for arr in (A_train, A_val, A_test, X_train_initial, X_val_initial, X_test_initial, GT_train, GT_val, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=500, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best,
          A_test=A_val, GT_test=GT_val, X_test_initial=X_val_initial)

In [None]:
# fold3 (random split): samples [186, end) form the held-out test fold

train_all = True
# Kept for compatibility with other cells; nothing in this cell fills them
fold_predictions = []
fold_truths = []

print("Fold number:", 3)

# Re-seed before the model is built so this fold's initialization and training do not
# depend on which cells were executed earlier in the session.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

if train_all:
    # Train on the two remaining folds (no validation set) and save a checkpoint
    A_train, A_test = training_A[:186], training_A[186:]
    X_train_initial, X_test_initial = training_initial_embeddings[:186], training_initial_embeddings[186:]
    GT_train, GT_test = training_truths[:186], training_truths[186:]
    for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=63, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best)
    torch.save(model.state_dict(), 'checkpoint_fold3_random_all.pt')

else:
    # Carve a validation set out of the two remaining folds; train() uses it for early stopping
    A_train = np.concatenate((training_A[:73], training_A[93:166]), axis=0)
    A_val = np.concatenate((training_A[73:93], training_A[166:186]), axis=0)
    A_test = training_A[186:]
    X_train_initial = np.concatenate((training_initial_embeddings[:73], training_initial_embeddings[93:166]), axis=0)
    X_val_initial = np.concatenate((training_initial_embeddings[73:93], training_initial_embeddings[166:186]), axis=0)
    X_test_initial = training_initial_embeddings[186:]
    GT_train = np.concatenate((training_truths[:73], training_truths[93:166]), axis=0)
    GT_val = np.concatenate((training_truths[73:93], training_truths[166:186]), axis=0)
    GT_test = training_truths[186:]
    for arr in (A_train, A_val, A_test, X_train_initial, X_val_initial, X_test_initial, GT_train, GT_val, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=500, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best,
          A_test=A_val, GT_test=GT_val, X_test_initial=X_val_initial)

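#### Cluster dataset
Note: these cells slice `training_A`, `training_truths` and `training_initial_embeddings` at indices 102/204, but the loading cell above reads the `RandomCV` splits. Load the cluster-based splits (and recompute or reload their Node2Vec embeddings) before running these cells.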

In [None]:
# fold1 (cluster split): samples [0, 102) form the held-out test fold
# NOTE(review): the data-loading cell reads the RandomCV splits; make sure training_A,
# training_truths and training_initial_embeddings hold the cluster-split data first.

train_all = True
# Kept for compatibility with other cells; nothing in this cell fills them
fold_predictions = []
fold_truths = []

print("Fold number:", 1)

# Re-seed before the model is built so this fold's initialization and training do not
# depend on which cells were executed earlier in the session.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

if train_all:
    # Train on the two remaining folds (no validation set) and save a checkpoint
    A_train, A_test = training_A[102:], training_A[:102]
    X_train_initial, X_test_initial = training_initial_embeddings[102:], training_initial_embeddings[:102]
    GT_train, GT_test = training_truths[102:], training_truths[:102]
    for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=43, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best)
    torch.save(model.state_dict(), 'checkpoint_fold1_cluster_all.pt')

else:
    # Carve a validation set out of the two remaining folds; train() uses it for early stopping
    A_train = np.concatenate((training_A[102:184], training_A[204:259]), axis=0)
    A_val = np.concatenate((training_A[184:204], training_A[259:]), axis=0)
    A_test = training_A[:102]
    X_train_initial = np.concatenate((training_initial_embeddings[102:184], training_initial_embeddings[204:259]), axis=0)
    X_val_initial = np.concatenate((training_initial_embeddings[184:204], training_initial_embeddings[259:]), axis=0)
    X_test_initial = training_initial_embeddings[:102]
    GT_train = np.concatenate((training_truths[102:184], training_truths[204:259]), axis=0)
    GT_val = np.concatenate((training_truths[184:204], training_truths[259:]), axis=0)
    GT_test = training_truths[:102]
    for arr in (A_train, A_val, A_test, X_train_initial, X_val_initial, X_test_initial, GT_train, GT_val, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=500, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best,
          A_test=A_val, GT_test=GT_val, X_test_initial=X_val_initial)

In [None]:
# fold2 (cluster split): samples [102, 204) form the held-out test fold
# NOTE(review): the data-loading cell reads the RandomCV splits; make sure training_A,
# training_truths and training_initial_embeddings hold the cluster-split data first.

train_all = True
# Kept for compatibility with other cells; nothing in this cell fills them
fold_predictions = []
fold_truths = []

print("Fold number:", 2)

# Re-seed before the model is built so this fold's initialization and training do not
# depend on which cells were executed earlier in the session.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

if train_all:
    # Train on the two remaining folds (no validation set) and save a checkpoint
    A_train = np.concatenate((training_A[:102], training_A[204:]), axis=0)
    A_test = training_A[102:204]
    X_train_initial = np.concatenate((training_initial_embeddings[:102], training_initial_embeddings[204:]), axis=0)
    X_test_initial = training_initial_embeddings[102:204]
    GT_train = np.concatenate((training_truths[:102], training_truths[204:]), axis=0)
    GT_test = training_truths[102:204]
    for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=51, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best)
    torch.save(model.state_dict(), 'checkpoint_fold2_cluster_all.pt')

else:
    # Carve a validation set out of the two remaining folds; train() uses it for early stopping
    A_train = np.concatenate((training_A[:82], training_A[204:259]), axis=0)
    A_val = np.concatenate((training_A[82:102], training_A[259:]), axis=0)
    A_test = training_A[102:204]
    X_train_initial = np.concatenate((training_initial_embeddings[:82], training_initial_embeddings[204:259]), axis=0)
    X_val_initial = np.concatenate((training_initial_embeddings[82:102], training_initial_embeddings[259:]), axis=0)
    X_test_initial = training_initial_embeddings[102:204]
    GT_train = np.concatenate((training_truths[:82], training_truths[204:259]), axis=0)
    GT_val = np.concatenate((training_truths[82:102], training_truths[259:]), axis=0)
    GT_test = training_truths[102:204]
    for arr in (A_train, A_val, A_test, X_train_initial, X_val_initial, X_test_initial, GT_train, GT_val, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=500, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best,
          A_test=A_val, GT_test=GT_val, X_test_initial=X_val_initial)

In [None]:
# fold3 (cluster split): samples [204, end) form the held-out test fold
# NOTE(review): the data-loading cell reads the RandomCV splits; make sure training_A,
# training_truths and training_initial_embeddings hold the cluster-split data first.

train_all = True
# Kept for compatibility with other cells; nothing in this cell fills them
fold_predictions = []
fold_truths = []

print("Fold number:", 3)

# Re-seed before the model is built so this fold's initialization and training do not
# depend on which cells were executed earlier in the session.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

if train_all:
    # Train on the two remaining folds (no validation set) and save a checkpoint
    A_train, A_test = training_A[:204], training_A[204:]
    X_train_initial, X_test_initial = training_initial_embeddings[:204], training_initial_embeddings[204:]
    GT_train, GT_test = training_truths[:204], training_truths[204:]
    for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=158, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best)
    torch.save(model.state_dict(), 'checkpoint_fold3_cluster_all.pt')

else:
    # Carve a validation set out of the two remaining folds; train() uses it for early stopping
    A_train = np.concatenate((training_A[:82], training_A[102:184]), axis=0)
    A_val = np.concatenate((training_A[82:102], training_A[184:204]), axis=0)
    A_test = training_A[204:]
    X_train_initial = np.concatenate((training_initial_embeddings[:82], training_initial_embeddings[102:184]), axis=0)
    X_val_initial = np.concatenate((training_initial_embeddings[82:102], training_initial_embeddings[184:204]), axis=0)
    X_test_initial = training_initial_embeddings[204:]
    GT_train = np.concatenate((training_truths[:82], training_truths[102:184]), axis=0)
    GT_val = np.concatenate((training_truths[82:102], training_truths[184:204]), axis=0)
    GT_test = training_truths[204:]
    for arr in (A_train, A_val, A_test, X_train_initial, X_val_initial, X_test_initial, GT_train, GT_val, GT_test):
        print(arr.shape)

    model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
    train(model, A_train, GT_train, X_train_initial, num_epochs=500, lr=lr_best, hr_dim=268,
          mean_dense=0., std_dense=0.01, mean_gaus=0., std_gaus=0.1, lmbda=lmbda_best, mu=mu_best,
          A_test=A_val, GT_test=GT_val, X_test_initial=X_val_initial)

## Evaluation

### Random split

In [None]:
# fold1 (random split): evaluate the checkpoint trained with samples [0, 93) held out

# Use exactly the split of the fold-1 training cell. This cell previously evaluated on
# samples [186:], which belong to fold 1's *training* data (test-set leakage).
A_train, A_test = training_A[93:], training_A[:93]
X_train_initial, X_test_initial = training_initial_embeddings[93:], training_initial_embeddings[:93]
GT_train, GT_test = training_truths[93:], training_truths[:93]
for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
    print(arr.shape)

# Rebuild the architecture and load the trained weights; map_location lets a
# GPU-saved checkpoint be evaluated on a CPU-only machine.
model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
state_dict_path = 'checkpoint_fold1_random_all.pt'
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model.eval()  # inference mode; harmless if predict() already switches modes

preds_adjacencies, preds_vectors = predict(model, A_test, X_test_initial)

# NOTE: results are written in append mode (mode='a'), so re-running this cell adds duplicate rows.
evaluate_all(GT_test, preds_adjacencies, output_path='4-randomCV.csv')

In [None]:
# fold2 (random split): evaluate the checkpoint trained with samples [93, 186) held out

A_train = np.concatenate((training_A[:93], training_A[186:]), axis=0)
A_test = training_A[93:186]
X_train_initial = np.concatenate((training_initial_embeddings[:93], training_initial_embeddings[186:]), axis=0)
X_test_initial = training_initial_embeddings[93:186]
GT_train = np.concatenate((training_truths[:93], training_truths[186:]), axis=0)
GT_test = training_truths[93:186]
for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
    print(arr.shape)

# Rebuild the architecture and load the trained weights; map_location lets a
# GPU-saved checkpoint be evaluated on a CPU-only machine.
model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
state_dict_path = 'checkpoint_fold2_random_all.pt'
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model.eval()  # inference mode; harmless if predict() already switches modes

preds_adjacencies, preds_vectors = predict(model, A_test, X_test_initial)

# NOTE: results are written in append mode (mode='a'), so re-running this cell adds duplicate rows.
evaluate_all(GT_test, preds_adjacencies, output_path='4-randomCV.csv')

In [None]:
# fold3 (random split): evaluate the checkpoint trained with samples [186, end) held out

A_train, A_test = training_A[:186], training_A[186:]
X_train_initial, X_test_initial = training_initial_embeddings[:186], training_initial_embeddings[186:]
GT_train, GT_test = training_truths[:186], training_truths[186:]
for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
    print(arr.shape)

# Rebuild the architecture and load the trained weights; map_location lets a
# GPU-saved checkpoint be evaluated on a CPU-only machine.
model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
state_dict_path = 'checkpoint_fold3_random_all.pt'
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model.eval()  # inference mode; harmless if predict() already switches modes

preds_adjacencies, preds_vectors = predict(model, A_test, X_test_initial)

# NOTE: results are written in append mode (mode='a'), so re-running this cell adds duplicate rows.
evaluate_all(GT_test, preds_adjacencies, output_path='4-randomCV.csv')

### Cluster split

In [None]:
# fold 1 (cluster split): evaluate the checkpoint trained with samples [0, 102) held out
# NOTE(review): requires the cluster-split data (not the RandomCV data) in training_A etc.

A_train, A_test = training_A[102:], training_A[:102]
X_train_initial, X_test_initial = training_initial_embeddings[102:], training_initial_embeddings[:102]
GT_train, GT_test = training_truths[102:], training_truths[:102]
for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
    print(arr.shape)

# Rebuild the architecture and load the trained weights; map_location lets a
# GPU-saved checkpoint be evaluated on a CPU-only machine.
model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
state_dict_path = 'checkpoint_fold1_cluster_all.pt'
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model.eval()  # inference mode; harmless if predict() already switches modes

preds_adjacencies, preds_vectors = predict(model, A_test, X_test_initial)

# NOTE: results are written in append mode (mode='a'), so re-running this cell adds duplicate rows.
evaluate_all(GT_test, preds_adjacencies, output_path='4-clusterCV.csv')

In [None]:
# fold2 (cluster split): evaluate the checkpoint trained with samples [102, 204) held out
# NOTE(review): requires the cluster-split data (not the RandomCV data) in training_A etc.

A_train = np.concatenate((training_A[:102], training_A[204:]), axis=0)
A_test = training_A[102:204]
X_train_initial = np.concatenate((training_initial_embeddings[:102], training_initial_embeddings[204:]), axis=0)
X_test_initial = training_initial_embeddings[102:204]
GT_train = np.concatenate((training_truths[:102], training_truths[204:]), axis=0)
GT_test = training_truths[102:204]
for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
    print(arr.shape)

# Rebuild the architecture and load the trained weights; map_location lets a
# GPU-saved checkpoint be evaluated on a CPU-only machine.
model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
state_dict_path = 'checkpoint_fold2_cluster_all.pt'
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model.eval()  # inference mode; harmless if predict() already switches modes

preds_adjacencies, preds_vectors = predict(model, A_test, X_test_initial)

# NOTE: results are written in append mode (mode='a'), so re-running this cell adds duplicate rows.
evaluate_all(GT_test, preds_adjacencies, output_path='4-clusterCV.csv')

In [None]:
# fold3 (cluster split): evaluate the checkpoint trained with samples [204, end) held out
# NOTE(review): requires the cluster-split data (not the RandomCV data) in training_A etc.

A_train, A_test = training_A[:204], training_A[204:]
X_train_initial, X_test_initial = training_initial_embeddings[:204], training_initial_embeddings[204:]
GT_train, GT_test = training_truths[:204], training_truths[204:]
for arr in (A_train, A_test, X_train_initial, X_test_initial, GT_train, GT_test):
    print(arr.shape)

# Rebuild the architecture and load the trained weights; map_location lets a
# GPU-saved checkpoint be evaluated on a CPU-only machine.
model = AGSRVec(hr_dim=268, hidden_dim=268*2).to(device)
state_dict_path = 'checkpoint_fold3_cluster_all.pt'
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model.eval()  # inference mode; harmless if predict() already switches modes

preds_adjacencies, preds_vectors = predict(model, A_test, X_test_initial)

# NOTE: results are written in append mode (mode='a'), so re-running this cell adds duplicate rows.
evaluate_all(GT_test, preds_adjacencies, output_path='4-clusterCV.csv')