In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
import os
# List every file found (recursively) under /input, printing one full path per line.
# NOTE(review): os.walk yields nothing, silently, when the directory does not exist.
# The data used later in this notebook is read from '../data', so confirm that
# '/input' is the intended location.
for dirname, _, filenames in os.walk('/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))


Below is an adapted version of the AGSR-Net implementation from the 2021 paper by Isallari and Rekik in *Medical Image Analysis*:

https://www.sciencedirect.com/science/article/pii/S1361841521001304

In [None]:
import torch
import numpy as np
import os
import scipy.io
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from scipy.io import loadmat
import seaborn as sns
import pandas as pd
import torch.optim as optim
from evaluation import *
from utils import *
import math

In [None]:
# Set a fixed random seed for reproducibility across multiple libraries
import random
random_seed = 42
# Seed Python's and NumPy's global generators with the shared value
for _seed_global_rng in (random.seed, np.random.seed):
    _seed_global_rng(random_seed)
# Seed PyTorch's default generator last (its return value is echoed by the cell)
torch.manual_seed(random_seed)

In [None]:
class MatrixVectorizer:
    """
    Convert between symmetric matrices and a flat vector layout.

    The layout walks the matrix column by column. Within each column it takes the
    strictly-upper-triangular entries from top to bottom, giving the order
    (0,1), (0,2), (1,2), (0,3), (1,3), (2,3), ...
    """

    def __init__(self):
        """Holds no state (all methods are static); kept for possible future options."""
        pass

    @staticmethod
    def vectorize(matrix, include_diagonal=False):
        """
        Flatten ``matrix`` in column-major, upper-triangle order.

        Parameters:
        - matrix (numpy.ndarray): square matrix; its size is taken from ``shape[0]``.
        - include_diagonal (bool, optional): when True, after each column's
          upper-triangle entries the element directly BELOW the diagonal
          (``matrix[col + 1, col]``) is also appended. This is the sub-diagonal,
          not the main diagonal. Defaults to False.

        Returns:
        - numpy.ndarray: the collected elements in traversal order.
        """
        size = matrix.shape[0]
        collected = []
        for col in range(size):
            # Above-diagonal entries of this column, top to bottom
            collected.extend(matrix[row, col] for row in range(col))
            # The sub-diagonal neighbour exists only when there is a next row
            if include_diagonal and col + 1 < size:
                collected.append(matrix[col + 1, col])
        return np.array(collected)

    @staticmethod
    def anti_vectorize(vector, matrix_size, include_diagonal=False):
        """
        Rebuild a symmetric matrix from the layout produced by ``vectorize``.

        Each consumed value is written to both (row, col) and its mirror (col, row).

        Parameters:
        - vector (numpy.ndarray): values in ``vectorize`` order.
        - matrix_size (int): side length of the square output.
        - include_diagonal (bool, optional): when True, the value following each
          column's upper-triangle entries is written to the sub-diagonal cell
          (col + 1, col) and its mirror. The next column then overwrites the mirror
          cell (col, col + 1). Defaults to False.

        Returns:
        - numpy.ndarray: float matrix of shape (matrix_size, matrix_size) with a zero
          main diagonal.
        """
        result = np.zeros((matrix_size, matrix_size))
        cursor = 0
        for col in range(matrix_size):
            for row in range(col):
                result[row, col] = result[col, row] = vector[cursor]
                cursor += 1
            if include_diagonal and col + 1 < matrix_size:
                result[col + 1, col] = result[col, col + 1] = vector[cursor]
                cursor += 1
        return result

In [None]:
# Paths to the competition CSV files. Each row holds one connectivity matrix in
# vectorized form (see MatrixVectorizer).
lr_data_path = '../data/lr_train.csv'
hr_data_path = '../data/hr_train.csv'
lr_data_test_path = '../data/lr_test.csv'  # path only; not loaded in this cell

# Load the training splits
lr_data = pd.read_csv(lr_data_path)
hr_data = pd.read_csv(hr_data_path)

# Shapes are (rows, vectorized entries per row)
print("LR Data Shape:", lr_data.shape)
print("HR Data Shape:", hr_data.shape)

# Preview the first rows of each split
print("Sample from LR Data:")
print(lr_data.head())

print("Sample from HR Data:")
print(hr_data.head())

In [None]:
# Define a function to calculate statistics and return them in a dictionary
def calculate_statistics(data):
    """
    Summarise an array-like of numbers.

    Returns a dict with keys, in order: 'Mean', 'Median', 'Standard Deviation'
    (population, ddof=0), 'Min' and 'Max'.
    """
    summaries = (
        ('Mean', np.mean),
        ('Median', np.median),
        ('Standard Deviation', np.std),
        ('Min', np.min),
        ('Max', np.max),
    )
    return {label: reducer(data) for label, reducer in summaries}

# Take row 0 of each split as a flat NumPy vector (one vectorized matrix each)
lr_array, hr_array = (frame.iloc[0].to_numpy() for frame in (lr_data, hr_data))

# Summary statistics for that row at each resolution
lr_stats, hr_stats = map(calculate_statistics, (lr_array, hr_array))

# Side-by-side comparison table, rounded to four decimals for readability
df_stats = pd.DataFrame({'LR Data': lr_stats, 'HR Data': hr_stats}).round(4)

df_stats

In [None]:
# Whitegrid Seaborn theme for the overlaid histograms
sns.set_theme(style="whitegrid")

fig, ax = plt.subplots(figsize=(10, 6))

# Overlay both value distributions on one axis: 100 bins, half-transparent
sns.histplot(lr_array, bins=100, color='blue', alpha=0.5, label='LR Data', ax=ax)
sns.histplot(hr_array, bins=100, color='red', alpha=0.5, label='HR Data', ax=ax)

ax.set_title('Combined Distribution of LR and HR Data')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.legend()

plt.show()

In [None]:
# Whitegrid Seaborn theme for the overlaid histograms
sns.set_theme(style="whitegrid")

# Keep only strictly positive entries (this drops zeros and any negatives)
lr_positive = lr_array[lr_array > 0].flatten()
hr_positive = hr_array[hr_array > 0].flatten()

fig, ax = plt.subplots(figsize=(10, 6))

# Overlay both filtered distributions: 100 bins, half-transparent
sns.histplot(lr_positive, bins=100, color='blue', alpha=0.5, label='LR Data', ax=ax)
sns.histplot(hr_positive, bins=100, color='red', alpha=0.5, label='HR Data', ax=ax)

ax.set_title('Combined Distribution of LR and HR Data (Excluding Zeros)')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.legend()

plt.show()

In [None]:
# Rebuild the full symmetric matrices from the vectorized rows
# (160 nodes at low resolution, 268 at high resolution)
lr_matrix = MatrixVectorizer.anti_vectorize(lr_array, 160)
hr_matrix = MatrixVectorizer.anti_vectorize(hr_array, 268)

fig, (ax_lr, ax_hr) = plt.subplots(1, 2, figsize=(20, 8))

# Heatmap of the complete LR matrix (no sub-setting), with its own colour scale
im_lr = ax_lr.imshow(lr_matrix, aspect='auto', cmap='viridis')
fig.colorbar(im_lr, ax=ax_lr)
ax_lr.set_title('LR Data Heatmap')

# Heatmap of the complete HR matrix, with its own colour scale
im_hr = ax_hr.imshow(hr_matrix, aspect='auto', cmap='viridis')
fig.colorbar(im_hr, ax=ax_hr)
ax_hr.set_title('HR Data Heatmap')

plt.show()

In [None]:
def pad_HR_adj(label, split):
    """
    Pads a high-resolution (HR) adjacency matrix with zeros on all sides.

    Only the last two dimensions are padded, so any leading batch dimension is kept.

    Parameters:
    - label (torch.Tensor or numpy.ndarray): The HR adjacency matrix to be padded.
    - split (int): The number of zeros to add to each side of the matrix.

    Returns:
    torch.Tensor or numpy.ndarray: The padded matrix, of the same type as ``label``.
    """
    if isinstance(label, np.ndarray):
        # F.pad only accepts tensors; mirror its behaviour (pad the last two dims) in NumPy
        pad_width = [(0, 0)] * (label.ndim - 2) + [(split, split), (split, split)]
        return np.pad(label, pad_width, mode="constant", constant_values=0)
    return F.pad(label, pad=(split, split, split, split), mode="constant", value=0)


def normalize_adj_torch(mx):
    """
    Symmetrically normalise an adjacency matrix: returns D^{-1/2} · mx^T · D^{-1/2}.

    D is the diagonal matrix of the row sums of ``mx``. Rows that sum to zero get a
    scale of 0 instead of inf. For a symmetric ``mx`` this is the usual
    D^{-1/2} A D^{-1/2}. Negative row sums are not guarded and produce NaN.
    """
    inv_sqrt_degree = torch.pow(mx.sum(1), -0.5).flatten()
    inv_sqrt_degree[torch.isinf(inv_sqrt_degree)] = 0.
    scale = torch.diag(inv_sqrt_degree)
    column_scaled = torch.matmul(mx, scale)
    return torch.matmul(torch.transpose(column_scaled, 0, 1), scale)


def unpad(data, split):
    """Strip ``split`` rows and columns from every edge of a 2-D matrix (the inverse of pad_HR_adj)."""
    n_rows, n_cols = data.shape[0], data.shape[1]
    return data[split:n_rows - split, split:n_cols - split]


def extract_data(subject, session_str, parcellation_str, subjects_roi):
    """
    Load one subject's ROI connectivity matrix from a .mat file and add it to a stack.

    NOTE(review): this reads the module-level names ``path`` and ``roi_str``, which
    are not defined in this notebook section. Confirm they are set before calling.

    Parameters:
    - subject (int): subject ID; also the sub-folder name under ``path``.
    - session_str (str): session sub-folder, e.g. 'session_1'.
    - parcellation_str (str): 'shen_268' is reshaped to 268x268; any other value
      is reshaped to 160x160.
    - subjects_roi (numpy.ndarray): stack of shape (n, N, N) to append to.

    Returns:
    numpy.ndarray: ``subjects_roi`` with this subject's matrix appended. The one
    exception is subject 25629, whose matrix replaces the stack entirely (it is
    treated as the first subject and discards the placeholder).
    """
    folder_path = os.path.join(path, str(subject), session_str, parcellation_str)
    roi = scipy.io.loadmat(os.path.join(folder_path, roi_str))['r']

    # Missing (NaN) connectivity values are set to the constant 1
    roi[np.isnan(roi)] = 1

    # Keep connection strengths only, as float32
    roi = np.absolute(roi, dtype=np.float32)

    size = 268 if parcellation_str == 'shen_268' else 160
    roi = np.reshape(roi, (1, size, size))

    # Subject 25629 starts a fresh stack. Every other subject is appended, so a stack
    # started from any other ID keeps its leading placeholder matrix.
    if subject == 25629:
        return roi
    return np.concatenate((subjects_roi, roi), axis=0)


def load_data(start_value, end_value):
    """
    Stack the session_1 connectomes for subject IDs in [start_value, end_value).

    Returns (lr_stack, hr_stack): Dosenbach-160 matrices of shape (n, 160, 160) and
    Shen-268 matrices of shape (n, 268, 268). Both stacks start from a single
    all-zero placeholder matrix, which extract_data discards only for subject 25629.
    Subjects whose folder under the module-level ``path`` lacks 'session_1' are skipped.
    """
    hr_stack = np.zeros((1, 268, 268))
    lr_stack = np.zeros((1, 160, 160))

    for subject in range(start_value, end_value):
        sessions = os.listdir(os.path.join(path, str(subject)))
        if 'session_1' not in sessions:
            continue
        hr_stack = extract_data(subject, 'session_1', 'shen_268', hr_stack)
        lr_stack = extract_data(subject, 'session_1', 'Dosenbach_160', lr_stack)

    return lr_stack, hr_stack


def data():
    """
    Build the train/test splits via load_data.

    Train: subject IDs 25629-25829. Test: IDs 25831-25862 followed by 30701-30756,
    concatenated along the first axis.

    Returns (train_adj, train_labels, test_adj, test_labels).
    """
    subjects_adj, subjects_labels = load_data(25629, 25830)
    test_parts = [load_data(25831, 25863), load_data(30701, 30757)]
    test_adj = np.concatenate([adj for adj, _ in test_parts], axis=0)
    test_labels = np.concatenate([labels for _, labels in test_parts], axis=0)
    return subjects_adj, subjects_labels, test_adj, test_labels

In [None]:
def weight_variable_glorot(output_dim):
    """
    Glorot/Xavier-uniform initial weights for a square (output_dim x output_dim) matrix.

    Draws from U(-r, r) with r = sqrt(6 / (fan_in + fan_out)), where
    fan_in = fan_out = output_dim, using NumPy's global RNG.
    """
    fan = output_dim
    bound = np.sqrt(6.0 / (fan + fan))
    return np.random.uniform(-bound, bound, (fan, fan))

class GraphUnpool(nn.Module):
    """
    Inverse of GraphPool: scatter pooled node features back to their original rows.

    forward(A, X, idx) returns A unchanged, plus an (A.shape[0], X.shape[1]) feature
    matrix whose row idx[i] equals X[i] and whose other rows are zero.
    """

    def __init__(self):
        super(GraphUnpool, self).__init__()

    def forward(self, A, X, idx):
        # new_zeros inherits X's dtype and device. A bare torch.zeros would use the
        # global defaults and fail to accept X when X lives on another device.
        new_X = X.new_zeros((A.shape[0], X.shape[1]))
        new_X[idx] = X
        return A, new_X


class GraphPool(nn.Module):
    """
    Top-k graph pooling.

    Every node is scored with a learned linear projection followed by a sigmoid.
    The ``int(k * N)`` highest-scoring nodes are kept, and their features are gated
    (multiplied) by their scores.

    forward(A, X) -> (A restricted to the kept nodes, gated kept features, kept indices).
    """

    def __init__(self, k, in_dim):
        super(GraphPool, self).__init__()
        self.k = k
        self.proj = nn.Linear(in_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, A, X):
        # (N, 1) -> (N,). squeeze(-1) rather than squeeze(), so a single-node graph
        # still yields a 1-D score vector that topk and indexing accept.
        scores = self.proj(X).squeeze(-1)
        # Raw projections are scaled by 1/100 before the sigmoid
        scores = self.sigmoid(scores / 100)
        num_nodes = A.shape[0]
        values, idx = torch.topk(scores, int(self.k * num_nodes))
        # Gate each kept node's features by its score
        new_X = torch.mul(X[idx, :], values.unsqueeze(-1))
        # Induced sub-graph on the kept nodes
        A = A[idx, :][:, idx]
        return A, new_X, idx


class GCN(nn.Module):
    """Graph convolution: aggregate with ``A @ X``, then apply a linear map to ``out_dim``."""

    def __init__(self, in_dim, out_dim):
        super(GCN, self).__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        # p=0: dropout is effectively disabled but kept as a hook
        self.drop = nn.Dropout(p=0)

    def forward(self, A, X):
        aggregated = torch.matmul(A, self.drop(X))
        return self.proj(aggregated)
    
class GCNGraphSage(nn.Module):
    """
    Graph convolution whose adjacency can be sparsified per row before aggregation.

    ``sampling`` selects how each row of ``A`` is reduced to at most ``k`` kept entries:
    - "random": k columns drawn uniformly without replacement (NumPy global RNG);
    - "deterministic": the k largest entries of the row;
    - "probabilistic": k draws with replacement, proportional to the row's values
      (Python ``random`` RNG), so fewer than k distinct entries may survive;
    - anything else (default None): ``A`` is used unchanged.
    Sampled adjacencies are rebuilt in NumPy from a detached copy, so in those modes
    no gradient flows back through ``A``.
    """

    def __init__(self, in_dim, out_dim, sampling=None, k=50):
        super(GCNGraphSage, self).__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.drop = nn.Dropout(p=0)
        self.sampling = sampling
        self.k = k

    def selectKRandom(self, adj):
        """Keep ``k`` uniformly chosen entries per row (every entry if the row is shorter)."""
        adj = adj.detach().numpy()
        rows, cols = adj.shape
        new_adj = np.zeros_like(adj)
        # Sampling without replacement cannot draw more columns than exist
        k = min(self.k, cols)
        for i in range(rows):
            row = adj[i]
            indices = np.random.choice(range(cols), size=k, replace=False)
            new_adj[i][indices] = row[indices]
        return torch.tensor(new_adj)

    def selectKDeterministic(self, adj):
        """Keep the ``k`` largest entries of each row."""
        adj = adj.detach().numpy()
        rows, cols = adj.shape
        new_adj = np.zeros_like(adj)
        # Clamp to [0, cols]; this also stops k == 0 from becoming a [-0:] slice,
        # which would select every column
        k = max(0, min(self.k, cols))
        for i in range(rows):
            row = adj[i]
            indices = np.argsort(row)[cols - k:]
            new_adj[i][indices] = row[indices]
        return torch.tensor(new_adj)

    def selectKProbabilistic(self, adj):
        """Draw ``k`` columns per row, with replacement, with probability proportional to the entry."""
        adj = adj.detach().numpy()
        rows, cols = adj.shape
        new_adj = np.zeros_like(adj)
        for i in range(rows):
            row = adj[i]
            total = np.sum(row)
            # A row with no positive mass has nothing to sample in proportion to
            # (random.choices would raise), so it is left all-zero
            if not total > 0:
                continue
            probabilities = row / total
            indices = random.choices(range(cols), probabilities, k=self.k)
            new_adj[i][indices] = row[indices]
        return torch.tensor(new_adj)

    def forward(self, A, X):
        if self.sampling == "random":
            A = self.selectKRandom(A)
        elif self.sampling == "deterministic":
            A = self.selectKDeterministic(A)
        elif self.sampling == "probabilistic":
            A = self.selectKProbabilistic(A)
        X = self.drop(X)
        X = torch.matmul(A, X)
        X = self.proj(X)
        return X
    
class Dense(nn.Module):
    """
    Bias-free fully connected layer computing ``out = x @ W``.

    ``W`` has shape (n1, n2) and is initialised from N(args.mean_dense, args.std_dense).
    ``x`` must be 2-D with n1 columns (torch.mm).
    """

    def __init__(self, n1, n2, args):
        super(Dense, self).__init__()
        self.weights = torch.nn.Parameter(
            torch.FloatTensor(n1, n2), requires_grad=True)
        nn.init.normal_(self.weights, mean=args.mean_dense, std=args.std_dense)

    def forward(self, x):
        # Deliberately no RNG seeding here. Reseeding NumPy/torch inside a forward pass
        # resets the global random streams on every call, e.g. making every noise
        # sample drawn right afterwards identical. Seeding belongs in the setup cell.
        return torch.mm(x, self.weights)
    
class GIN(nn.Module):
    """
    Neighbour aggregation ``A @ X`` followed by a two-layer MLP.

    The MLP runs Linear(in_dim -> hid_dim) -> ReLU -> Linear(hid_dim -> out_dim).
    """

    def __init__(self, in_dim, out_dim, hid_dim):
        super(GIN, self).__init__()
        self.dense_1 = nn.Linear(in_dim, hid_dim)
        self.relu_1 = nn.ReLU(inplace=False)
        self.dense_2 = nn.Linear(hid_dim, out_dim)

    def forward(self, A, X):
        hidden = self.relu_1(self.dense_1(torch.matmul(A, X)))
        return self.dense_2(hidden)

class GraphUnet(nn.Module):
    """
    Graph U-Net built from GIN blocks with top-k pooling.

    Encoder: ``start_gcn``, then for each level i, ``down_gcns[i]`` followed by
    ``pools[i]`` (keep ratio ``ks[i]``). ``bottom_gcn`` runs on the coarsest graph.
    Decoder: for each level, unpool back to the matching encoder graph, apply
    ``up_gcns[i]`` and add that level's encoder output (skip connection). The result
    is concatenated with the ``start_gcn`` output and mapped to ``out_dim`` by ``end_gcn``.

    forward(A, X) returns ``(output, start_gcn_outs)``.
    """

    def __init__(self, ks, in_dim, out_dim, dim=320):
        super(GraphUnet, self).__init__()
        self.ks = ks

        self.start_gcn = GIN(in_dim, dim, dim)
        self.bottom_gcn = GIN(dim, dim, dim)
        self.end_gcn = GIN(2*dim, out_dim, dim)
        # nn.ModuleList registers each layer as a submodule, so its parameters are
        # returned by .parameters() (and therefore optimised), follow .to(device) and
        # appear in state_dict. Plain Python lists hide them from all three.
        # state_dicts saved from a list-based model lack these keys (load with strict=False).
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        self.l_n = len(ks)
        for i in range(self.l_n):
            self.down_gcns.append(GIN(dim, dim, dim))
            self.up_gcns.append(GIN(dim, dim, dim))
            self.pools.append(GraphPool(ks[i], dim))
            self.unpools.append(GraphUnpool())

    def forward(self, A, X):
        adj_ms = []        # adjacency at each encoder level, reused when unpooling
        indices_list = []  # node indices kept by each pooling step
        down_outs = []     # encoder features per level, for the skip connections
        X = self.start_gcn(A, X)
        start_gcn_outs = X
        org_X = X
        for i in range(self.l_n):
            X = self.down_gcns[i](A, X)
            adj_ms.append(A)
            down_outs.append(X)
            A, X, idx = self.pools[i](A, X)
            indices_list.append(idx)
        X = self.bottom_gcn(A, X)
        for i in range(self.l_n):
            # Walk the encoder levels in reverse
            up_idx = self.l_n - i - 1
            A, idx = adj_ms[up_idx], indices_list[up_idx]
            A, X = self.unpools[i](A, X, idx)
            X = self.up_gcns[i](A, X)
            X = X.add(down_outs[up_idx])
        X = torch.cat([X, org_X], 1)
        X = self.end_gcn(A, X)

        return X, start_gcn_outs

In [None]:
class GSRLayer(nn.Module):
    """
    Graph super-resolution layer.

    Holds a learnable (hr_dim x hr_dim) weight matrix, initialised Glorot-uniform by
    weight_variable_glorot. forward maps node embeddings of a low-resolution graph to
    a high-resolution graph using the eigenvectors of the LR adjacency.
    """

    def __init__(self, hr_dim):
        super(GSRLayer, self).__init__()

        # Glorot-uniform NumPy init -> float32 tensor -> trainable parameter
        self.weights = torch.from_numpy(
            weight_variable_glorot(hr_dim)).type(torch.FloatTensor)
        self.weights = torch.nn.Parameter(
            data=self.weights, requires_grad=True)

    def forward(self, A, X):
        """
        Parameters:
        - A (torch.Tensor): square (lr_dim x lr_dim) LR adjacency; eigh reads only
          its upper triangle. Presumably the normalised adjacency; TODO confirm at call site.
        - X (torch.Tensor): node embeddings with lr_dim rows.

        Returns:
        (adj, X_hr):
        - adj = |W [I; I] U^T X| with its diagonal set to 1, shape (hr_dim, X.shape[1]);
        - X_hr = |(adj adj^T + (adj adj^T)^T) / 2| with diagonal 1, shape (hr_dim, hr_dim).

        Requires hr_dim == 2 * lr_dim, because W (hr_dim x hr_dim) multiplies the
        stacked identity [I; I] of shape (2*lr_dim, lr_dim).
        """
        # Anomaly detection makes autograd report the forward op behind a NaN/inf
        # gradient, at the cost of slowing every pass.
        with torch.autograd.set_detect_anomaly(True):

            lr = A
            lr_dim = lr.shape[0]
            f = X
            #eig_val_lr, U_lr = torch.symeig(lr, eigenvectors=True, upper=True)
            # Eigen-decomposition of the LR adjacency, using its upper triangle
            eig_val_lr, U_lr = torch.linalg.eigh(lr, UPLO='U')

            # U_lr = torch.abs(U_lr)
            # s_d = [I; I]: two stacked lr_dim identities, shape (2*lr_dim, lr_dim)
            eye_mat = torch.eye(lr_dim).type(torch.FloatTensor)
            s_d = torch.cat((eye_mat, eye_mat), 0)

            # f_d = W @ s_d @ U^T @ f, shape (hr_dim, X.shape[1])
            a = torch.matmul(self.weights, s_d)
            b = torch.matmul(a, torch.t(U_lr))
            f_d = torch.matmul(b, f)
            # Non-negative entries with a unit diagonal
            f_d = torch.abs(f_d)
            f_d = f_d.fill_diagonal_(1)
            adj = f_d

            # HR adjacency: symmetrised Gram matrix of adj's rows, unit diagonal
            X = torch.mm(adj, adj.t())
            X = (X + X.t())/2
            X = X.fill_diagonal_(1)
        return adj, torch.abs(X)


class GraphConvolution(nn.Module):
    """
    Dense graph-convolution layer in the style of https://arxiv.org/abs/1609.02907.

    Computes ``act(adj @ (dropout(input) @ W))``, where W has shape
    (in_features, out_features) and is Xavier-uniform initialised.
    """

    def __init__(self, in_features, out_features, dropout, act=F.relu):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.act = act
        self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        dropped = F.dropout(input, self.dropout, self.training)
        return self.act(torch.mm(adj, torch.mm(dropped, self.weight)))
    
class GraphSageConvolution(nn.Module):
    """
    Graph-convolution layer ``act(adj' @ (dropout(input) @ W))``.

    ``adj'`` is ``adj``, optionally sparsified per row to at most ``k`` kept entries
    before aggregation. ``sampling`` chooses how:
    - "random": k columns drawn uniformly without replacement (NumPy global RNG);
    - "deterministic": the k largest entries of each row;
    - "probabilistic": k draws with replacement, proportional to the row's values
      (Python ``random`` RNG), so fewer than k distinct entries may survive;
    - anything else (default None): ``adj`` is used as-is.
    Sampled adjacencies are rebuilt in NumPy from a detached copy, so in those modes
    no gradient flows back through ``adj``.
    """

    def __init__(self, in_features, out_features, dropout, act=F.relu, sampling=None, k=50):
        super(GraphSageConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.act = act
        self.weight = torch.nn.Parameter(
            torch.FloatTensor(in_features, out_features))
        self.reset_parameters()
        self.sampling = sampling
        # Use the caller's k (AGSRNet, for example, passes args.k)
        self.k = k

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def selectKRandom(self, adj):
        """Keep ``k`` uniformly chosen entries per row (every entry if the row is shorter)."""
        adj = adj.detach().numpy()
        rows, cols = adj.shape
        new_adj = np.zeros_like(adj)
        # Sampling without replacement cannot draw more columns than exist
        k = min(self.k, cols)
        for i in range(rows):
            row = adj[i]
            indices = np.random.choice(range(cols), size=k, replace=False)
            new_adj[i][indices] = row[indices]
        return torch.tensor(new_adj)

    def selectKDeterministic(self, adj):
        """Keep the ``k`` largest entries of each row."""
        adj = adj.detach().numpy()
        rows, cols = adj.shape
        new_adj = np.zeros_like(adj)
        # Clamp to [0, cols]; this also stops k == 0 from becoming a [-0:] slice,
        # which would select every column
        k = max(0, min(self.k, cols))
        for i in range(rows):
            row = adj[i]
            indices = np.argsort(row)[cols - k:]
            new_adj[i][indices] = row[indices]
        return torch.tensor(new_adj)

    def selectKProbabilistic(self, adj):
        """Draw ``k`` columns per row, with replacement, with probability proportional to the entry."""
        adj = adj.detach().numpy()
        rows, cols = adj.shape
        new_adj = np.zeros_like(adj)
        for i in range(rows):
            row = adj[i]
            total = np.sum(row)
            # A row with no positive mass has nothing to sample in proportion to
            # (random.choices would raise), so it is left all-zero
            if not total > 0:
                continue
            probabilities = row / total
            indices = random.choices(range(cols), probabilities, k=self.k)
            new_adj[i][indices] = row[indices]
        return torch.tensor(new_adj)

    def forward(self, input, adj):
        if self.sampling == "random":
            adj = self.selectKRandom(adj)
        elif self.sampling == "deterministic":
            adj = self.selectKDeterministic(adj)
        elif self.sampling == "probabilistic":
            adj = self.selectKProbabilistic(adj)
        input = F.dropout(input, self.dropout, self.training)
        support = torch.mm(input, self.weight)
        output = torch.mm(adj, support)
        output = self.act(output)
        return output

In [None]:
class AGSRNet(nn.Module):
    """
    Generator of the AGSR-Net pipeline.

    Data flow: LR adjacency -> normalised adjacency -> Graph U-Net embeddings ->
    GSR layer (super-resolved adjacency) -> two GraphSageConvolution layers with
    probabilistic neighbour sampling -> symmetrised, unit-diagonal, non-negative
    HR prediction.

    ``args`` must provide lr_dim, hr_dim, hidden_dim and k.
    """

    def __init__(self, ks, args):
        super(AGSRNet, self).__init__()

        self.lr_dim = args.lr_dim
        self.hr_dim = args.hr_dim
        self.hidden_dim = args.hidden_dim
        self.layer = GSRLayer(self.hr_dim)
        self.net = GraphUnet(ks, self.lr_dim, self.hr_dim)
        # Refinement layers hr_dim -> hidden_dim -> hr_dim, aggregating over a
        # probabilistically sampled copy of the GSR adjacency
        sage_kwargs = dict(act=F.relu, sampling="probabilistic", k=args.k)
        self.gc1 = GraphSageConvolution(self.hr_dim, self.hidden_dim, 0, **sage_kwargs)
        self.gc2 = GraphSageConvolution(self.hidden_dim, self.hr_dim, 0, **sage_kwargs)

    def forward(self, lr, lr_dim, hr_dim):
        """
        Predict the HR adjacency for one LR adjacency ``lr``.

        ``lr_dim`` and ``hr_dim`` are accepted but unused; the sizes configured in
        ``__init__`` apply. Returns (prediction, net_outs, start_gcn_outs,
        gsr_adjacency). The intermediates are also kept on ``self`` (net_outs,
        start_gcn_outs, outputs, Z, hidden1, hidden2).
        """
        with torch.autograd.set_detect_anomaly(True):
            # Identity node features for the LR graph
            node_features = torch.eye(self.lr_dim).type(torch.FloatTensor)
            A = normalize_adj_torch(lr).type(torch.FloatTensor)

            self.net_outs, self.start_gcn_outs = self.net(A, node_features)
            self.outputs, self.Z = self.layer(A, self.net_outs)

            self.hidden1 = self.gc1(self.Z, self.outputs)
            self.hidden2 = self.gc2(self.hidden1, self.outputs)

            # Symmetrise and force unit self-connections
            prediction = (self.hidden2 + self.hidden2.t()) / 2
            prediction = prediction.fill_diagonal_(1)

        return torch.abs(prediction), self.net_outs, self.start_gcn_outs, self.outputs

class Discriminator(nn.Module):
    """
    Three-layer bias-free MLP discriminator.

    Runs Dense -> ReLU -> Dense -> ReLU -> Dense -> Sigmoid. Each row of an
    (N, hr_dim) input maps to a score in [0, 1], giving an output of shape (N, 1).
    ``args`` must provide hr_dim, mean_dense and std_dense.
    """

    def __init__(self, args):
        super(Discriminator, self).__init__()
        self.dense_1 = Dense(args.hr_dim, args.hr_dim, args)
        self.relu_1 = nn.ReLU(inplace=False)
        self.dense_2 = Dense(args.hr_dim, args.hr_dim, args)
        self.relu_2 = nn.ReLU(inplace=False)
        self.dense_3 = Dense(args.hr_dim, 1, args)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        # Deliberately no RNG seeding here. Reseeding inside forward resets the global
        # NumPy/torch streams on every call and makes later "random" draws repeat.
        dc_den1 = self.relu_1(self.dense_1(inputs))
        dc_den2 = self.relu_2(self.dense_2(dc_den1))
        output = self.sigmoid(self.dense_3(dc_den2))
        # abs is a no-op after a sigmoid (already >= 0); kept so outputs are unchanged
        return torch.abs(output)

def gaussian_noise_layer(input_layer, args):
    """
    Return a noisy, symmetric, unit-diagonal copy of a square matrix.

    Adds N(args.mean_gaussian, args.std_gaussian) noise element-wise (torch global
    RNG), takes the absolute value, symmetrises with (Z + Z^T) / 2 and sets the
    diagonal to 1. ``input_layer`` itself is not modified.
    """
    noise = torch.empty_like(input_layer).normal_(mean=args.mean_gaussian, std=args.std_gaussian)
    noisy = torch.abs(input_layer + noise)
    symmetric = (noisy + noisy.t()) / 2
    return symmetric.fill_diagonal_(1)

In [None]:
from sklearn.metrics import mean_absolute_error
# Reconstruction loss used by `train` below (mean squared error between tensors)
criterion = nn.MSELoss()
# Alternative kept for reference: criterion = mean_absolute_error

def train(model, subjects_adj, subjects_labels, args):
    """
    Adversarially train ``model`` on (LR, HR) pairs given as NumPy arrays or tensors.

    For each pair: one discriminator step, then one generator step on
    ``lmbda * MSE(net_outs, start_gcn_outs) + MSE(GSR weights, HR eigenvectors)
    + MSE(prediction, padded HR) + adversarial BCE``.

    Returns:
    list: mean generator loss of each epoch.
    """
    bce_loss = nn.BCELoss()
    netD = Discriminator(args)
    print(netD)
    optimizerG = optim.Adam(model.parameters(), lr=args.lr)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr)

    all_epochs_loss = []
    for epoch in range(args.epochs):
        with torch.autograd.set_detect_anomaly(True):
            epoch_loss = []
            epoch_error = []
            for lr, hr in zip(subjects_adj, subjects_labels):
                optimizerD.zero_grad()
                optimizerG.zero_grad()

                # Convert to float32 tensors *before* padding, because pad_HR_adj relies
                # on F.pad, which needs a tensor. as_tensor accepts arrays or tensors.
                lr = torch.as_tensor(lr, dtype=torch.float32)
                padded_hr = pad_HR_adj(
                    torch.as_tensor(hr, dtype=torch.float32), args.padding)

                # Eigenvectors of the padded HR graph supervise the GSR layer's weights
                _, U_hr = torch.linalg.eigh(padded_hr, UPLO='U')

                model_outputs, net_outs, start_gcn_outs, _ = model(
                    lr, args.lr_dim, args.hr_dim)

                mse_loss = args.lmbda * criterion(net_outs, start_gcn_outs) + criterion(
                    model.layer.weights, U_hr) + criterion(model_outputs, padded_hr)

                error = criterion(model_outputs, padded_hr)
                real_data = model_outputs.detach()
                fake_data = gaussian_noise_layer(padded_hr, args)

                d_real = netD(real_data)
                d_fake = netD(fake_data)

                dc_loss_real = bce_loss(d_real, torch.ones(args.hr_dim, 1))
                dc_loss_fake = bce_loss(d_fake, torch.zeros(args.hr_dim, 1))
                dc_loss = dc_loss_real + dc_loss_fake

                dc_loss.backward()
                optimizerD.step()

                # NOTE(review): d_fake is recomputed from noisy ground truth rather than
                # from model_outputs, so gen_loss does not depend on the generator's
                # parameters; only mse_loss trains the generator.
                d_fake = netD(gaussian_noise_layer(padded_hr, args))

                gen_loss = bce_loss(d_fake, torch.ones(args.hr_dim, 1))
                generator_loss = gen_loss + mse_loss
                generator_loss.backward()
                optimizerG.step()

                epoch_loss.append(generator_loss.item())
                epoch_error.append(error.item())

            # `error` is an MSE, not a percentage
            print("Epoch: ", epoch, "Loss: ", np.mean(epoch_loss),
                  "Train MSE: ", np.mean(epoch_error))
            all_epochs_loss.append(np.mean(epoch_loss))
    return all_epochs_loss


def test(model, test_adj, test_labels, args):
    """
    Mean absolute error of ``model`` on (LR, HR) pairs given as NumPy arrays.

    Pairs where either matrix is all zeros are skipped. For each remaining pair, the
    HR target's diagonal is set to 1 on a float32 copy (the caller's arrays are left
    untouched). The prediction is cropped with ``unpad`` before the MAE is computed.
    No autograd graph is built.

    Returns the mean MAE over the evaluated pairs.
    """
    test_error = []

    with torch.no_grad():
        for lr, hr in zip(test_adj, test_labels):
            if not np.any(lr) or not np.any(hr):
                continue
            lr = torch.from_numpy(lr).type(torch.FloatTensor)
            # Copy before filling the diagonal so test_labels is not mutated in place
            target = np.array(hr, dtype=np.float32)
            np.fill_diagonal(target, 1)
            preds = model(lr, args.lr_dim, args.hr_dim)[0]
            preds = unpad(preds, args.padding)
            error = mean_absolute_error(preds.flatten().detach().numpy(), target.flatten())
            # float() works whether sklearn returns a NumPy scalar or a Python float
            test_error.append(float(error))

    print("Test error MAE: ", np.mean(test_error))
    return np.mean(test_error)

In [None]:
import re


criterion = nn.MSELoss()


def _train_one_epoch(model, netD, optimizerG, optimizerD, bce_loss,
                     subjects_adj, subjects_labels, args):
    """
    Run one adversarial pass over the training pairs (shared by train and train_validation).

    For each (lr, hr) tensor pair: one discriminator step, then one generator step on
    ``lmbda * MSE(net_outs, start_gcn_outs) + MSE(GSR weights, HR eigenvectors)
    + MSE(prediction, padded HR) + adversarial BCE``.

    Returns:
    (mean generator loss, mean reconstruction MSE) over the epoch.
    """
    epoch_loss = []
    epoch_error = []
    with torch.autograd.set_detect_anomaly(True):
        for lr, hr in zip(subjects_adj, subjects_labels):
            optimizerD.zero_grad()
            optimizerG.zero_grad()

            # Pad the HR target to the model's output size (hr is expected to be a tensor)
            padded_hr = pad_HR_adj(hr, args.padding).type(torch.FloatTensor)

            # Eigenvectors of the padded HR graph supervise the GSR layer's weights
            _, U_hr = torch.linalg.eigh(padded_hr, UPLO='U')

            model_outputs, net_outs, start_gcn_outs, _ = model(
                lr, args.lr_dim, args.hr_dim)

            mse_loss = args.lmbda * criterion(net_outs, start_gcn_outs) + criterion(
                model.layer.weights, U_hr) + criterion(model_outputs, padded_hr)
            error = criterion(model_outputs, padded_hr)

            # Discriminator step
            real_data = model_outputs.detach()
            fake_data = gaussian_noise_layer(padded_hr, args)
            d_real = netD(real_data)
            d_fake = netD(fake_data)
            dc_loss_real = bce_loss(d_real, torch.ones(args.hr_dim, 1))
            dc_loss_fake = bce_loss(d_fake, torch.zeros(args.hr_dim, 1))
            dc_loss = dc_loss_real + dc_loss_fake
            dc_loss.backward()
            optimizerD.step()

            # Generator step.
            # NOTE(review): d_fake is recomputed from noisy ground truth rather than
            # from model_outputs, so gen_loss does not depend on the generator's
            # parameters; only mse_loss trains the generator.
            d_fake = netD(gaussian_noise_layer(padded_hr, args))
            gen_loss = bce_loss(d_fake, torch.ones(args.hr_dim, 1))
            generator_loss = gen_loss + mse_loss
            generator_loss.backward()
            optimizerG.step()

            epoch_loss.append(generator_loss.item())
            epoch_error.append(error.item())
    return np.mean(epoch_loss), np.mean(epoch_error)


def train(model, subjects_adj, subjects_labels, args):
    """
    Adversarially train ``model`` for ``args.epochs`` epochs.

    Returns:
    list: mean generator loss of each epoch.
    """
    bce_loss = nn.BCELoss()
    netD = Discriminator(args)
    optimizerG = optim.Adam(model.parameters(), lr=args.lr)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr)

    all_epochs_loss = []
    for epoch in range(args.epochs):
        mean_loss, mean_error = _train_one_epoch(
            model, netD, optimizerG, optimizerD, bce_loss,
            subjects_adj, subjects_labels, args)
        # mean_error is a reconstruction MSE, not a percentage
        print("Epoch: ", epoch, "Loss: ", mean_loss,
              "Train MSE: ", mean_error, flush=True)
        all_epochs_loss.append(mean_loss)
    return all_epochs_loss


def train_validation(model, subjects_adj, subjects_labels, val_adj, val_labels, args,
                     patience=10, min_epochs=100):
    """
    Train like ``train``, evaluating on the validation pairs after every epoch.

    Stops early once the validation score has not improved for ``patience``
    consecutive epochs, but only after epoch index ``min_epochs``.

    Returns:
    (list, list): per-epoch mean generator loss and per-epoch validation score
    (the value returned by ``test``).
    """
    bce_loss = nn.BCELoss()
    netD = Discriminator(args)
    optimizerG = optim.Adam(model.parameters(), lr=args.lr)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr)

    all_epochs_loss = []
    all_epochs_val_loss = []

    best_val_mae = float('inf')
    counter = 0
    for epoch in range(args.epochs):
        mean_loss, mean_error = _train_one_epoch(
            model, netD, optimizerG, optimizerD, bce_loss,
            subjects_adj, subjects_labels, args)
        all_epochs_loss.append(mean_loss)

        val_mae = test(model, val_adj, val_labels, args)
        all_epochs_val_loss.append(val_mae)

        print("Epoch: ", epoch, "Loss: ", mean_loss, "Train MSE: ", mean_error,
              "Validation MAE: ", val_mae, flush=True)
        if val_mae < best_val_mae:
            best_val_mae = val_mae
            counter = 0
        else:
            counter += 1
            if counter >= patience and epoch > min_epochs:
                print("Early stopping at epoch ", epoch)
                break
    return all_epochs_loss, all_epochs_val_loss


def test(model, test_adj, test_labels, args):
    """Compute the mean reconstruction error of ``model`` over a set of subjects.

    Subjects whose LR or HR matrix is entirely zero are skipped. For every other
    subject the HR matrix gets a diagonal of ones and is padded with ``pad_HR_adj``
    before being compared to the model prediction with the module-level ``criterion``.

    Args:
        model: Generator called as ``model(lr, args.lr_dim, args.hr_dim)``; the first
            element of the returned tuple is used as the prediction.
        test_adj: Iterable of LR adjacency tensors.
        test_labels: Iterable of HR adjacency tensors aligned with ``test_adj``.
            They are not modified: the diagonal is filled on a copy.
        args: Config object providing ``lr_dim``, ``hr_dim`` and ``padding``.

    Returns:
        float: Mean ``criterion`` value over the evaluated subjects, or ``nan`` when
        no subject could be evaluated.
    """
    test_error = []

    # Evaluation only: do not build autograd graphs. The model's train/eval mode is
    # left untouched because this is also called between training epochs.
    with torch.no_grad():
        for lr, hr in zip(test_adj, test_labels):
            if not (torch.any(lr) and torch.any(hr)):
                continue
            # Clone first: fill_diagonal_ is in-place, and the caller's labels (e.g. the
            # validation set, later reused for retraining) must not be mutated.
            hr = hr.clone().fill_diagonal_(1)
            hr = pad_HR_adj(hr, args.padding)
            preds, _, _, _ = model(lr, args.lr_dim, args.hr_dim)
            test_error.append(criterion(preds, hr).item())

    mean_error = float(np.mean(test_error)) if test_error else float("nan")
    print("Test error MAE: ", mean_error, flush=True)
    return mean_error

In [None]:
def save_predictions(vectorized_predictions, fold_num=None):
    """Write flattened predictions to an ``ID,Predicted`` CSV in the working directory.

    Args:
        vectorized_predictions: Array-like of predictions of any shape; it is
            flattened in row-major order and IDs are numbered from 1.
        fold_num: Fold identifier. ``None`` writes ``predictions_test.csv``; any other
            value (including ``0``) writes ``predictions_fold_{fold_num}.csv``.
    """
    melted = np.array(vectorized_predictions).flatten()
    predictions_df = pd.DataFrame({'ID': np.arange(1, len(melted) + 1), 'Predicted': melted})
    # Compare with None explicitly so fold 0 is not mistaken for "no fold".
    if fold_num is None:
        filename = "predictions_test.csv"
    else:
        filename = f"predictions_fold_{fold_num}.csv"
    predictions_df.to_csv(filename, index=False)

def _centrality_vectors(adj):
    """Return (betweenness, eigenvector, PageRank) centrality lists for a weighted adjacency matrix.

    The matrix is converted with ``nx.from_numpy_array`` (weights stored under
    ``"weight"``). Each list follows the graph's node iteration order, so lists from
    two same-sized matrices line up element-wise.
    """
    graph = nx.from_numpy_array(adj, edge_attr="weight")
    bc = list(nx.betweenness_centrality(graph, weight="weight").values())
    ec = list(nx.eigenvector_centrality(graph, weight="weight").values())
    pc = list(nx.pagerank(graph, weight="weight").values())
    return bc, ec, pc


def evaluate_predictions(pred_matrices, gt_matrices, num_nodes=268, verbose=False):
    """
    Evaluate predictions using various metrics.

    Args:
    pred_matrices (numpy array): Predicted high-resolution matrices.
    gt_matrices (numpy array): Ground truth high-resolution matrices.
    Both arguments are of the form:
     [[ Vectorized form of high-resolution matrix 1 ],
      [ Vectorized form of high-resolution matrix 2 ],
      ...
      [ Vectorized form of high-resolution matrix n ]]
    num_nodes (int): Side length of the matrices rebuilt by
        ``matrix_vectorizer.anti_vectorize`` (default 268).
    verbose (bool): If True, print the index of each sample as it is processed.

    Returns:
    dict: Dictionary containing evaluation measures ("MAE", "PCC", "JSD",
    "MAE (PC)", "MAE (BC)", "MAE (EC)").

    Raises:
    ValueError: If there are no samples or the two inputs differ in length.
    """
    num_test_samples = len(pred_matrices)
    if num_test_samples == 0:
        raise ValueError("evaluate_predictions needs at least one sample")
    if len(gt_matrices) != num_test_samples:
        raise ValueError(
            f"pred_matrices has {num_test_samples} samples but gt_matrices has {len(gt_matrices)}"
        )

    # Per-sample MAE between predicted and ground-truth centrality vectors.
    mae_bc, mae_ec, mae_pc = [], [], []

    for i in range(num_test_samples):
        if verbose:
            print(i)
        pred = matrix_vectorizer.anti_vectorize(pred_matrices[i], num_nodes)
        gt = matrix_vectorizer.anti_vectorize(gt_matrices[i], num_nodes)

        pred_bc, pred_ec, pred_pc = _centrality_vectors(pred)
        gt_bc, gt_ec, gt_pc = _centrality_vectors(gt)

        mae_bc.append(mean_absolute_error(pred_bc, gt_bc))
        mae_ec.append(mean_absolute_error(pred_ec, gt_ec))
        mae_pc.append(mean_absolute_error(pred_pc, gt_pc))

    avg_mae_bc = sum(mae_bc) / len(mae_bc)
    avg_mae_ec = sum(mae_ec) / len(mae_ec)
    avg_mae_pc = sum(mae_pc) / len(mae_pc)

    # Edge-level metrics over all samples concatenated into one long vector.
    pred_1d = np.concatenate(pred_matrices)
    gt_1d = np.concatenate(gt_matrices)

    mae = mean_absolute_error(pred_1d, gt_1d)
    pcc = pearsonr(pred_1d, gt_1d)[0]
    js_dis = jensenshannon(pred_1d, gt_1d)

    return {
        "MAE": mae,
        "PCC": pcc,
        "JSD": js_dis,
        "MAE (PC)": avg_mae_pc,
        "MAE (BC)": avg_mae_bc,
        "MAE (EC)": avg_mae_ec
    }
    
def plot_evaluation_measures(measures):
    """
    Plot one bar chart per cross-validation fold plus an "Average across folds" panel.

    Measures is a list of dictionaries, one per fold, all sharing the same keys
    (measure name -> value). Panels are laid out two per row: the folds first, then
    the average panel, whose error bars are the population std (``np.std``, ddof=0).
    Unused trailing panels are hidden. For 3 folds this is the original 2x2 grid.
    Any number of folds is supported.
    """
    folds = len(measures)
    keys = list(measures[0].keys())
    # TODO: feel free to change these colours (bar() cycles them if there are more keys)
    colours = ["red", "orange", "yellow", "cyan", "blue", "purple"]
    positions = range(len(keys))

    ncols = 2
    nrows = (folds + ncols) // ncols  # == ceil((folds + 1) / ncols): folds + 1 panels
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 4 * nrows), squeeze=False)
    flat_axes = axes.ravel()

    for i, fold_measures in enumerate(measures):
        ax = flat_axes[i]
        # Index by the shared key order so bars line up with the tick labels even if
        # a fold's dict was built in a different order.
        ax.bar(positions, [fold_measures[key] for key in keys], color=colours)
        ax.set_xticks(positions, keys)
        ax.set_title(f"Fold {i+1}")

    # Calculate mean and standard deviation over folds
    values = np.array([[d[key] for key in keys] for d in measures])
    mean = np.mean(values, axis=0)
    std_dev = np.std(values, axis=0)

    ax = flat_axes[folds]
    ax.bar(positions, mean, color=colours)
    ax.errorbar(positions, mean, yerr=std_dev, fmt='none', capsize=5, elinewidth=1, markeredgewidth=1, ecolor='k')
    ax.set_xticks(positions, keys)
    ax.set_title("Average across folds")

    for unused_ax in flat_axes[folds + 1:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
#     plt.savefig("metrics_plot1.png")
    plt.show()

In [None]:
def plot_evaluation_measures2(measures):
    """
    Plot one panel per measure: a bar per fold plus an "Avg" bar with a ±std error bar.

    Measures is a list of dictionaries, one per cross-validation fold, all sharing
    the same keys. Panels are laid out three per row; the std is the population std
    (``np.std``, ddof=0). Colours cycle if there are more measures than colours, and
    unused trailing panels are hidden. Six measures give the original 2x3 grid.
    """
    keys = list(measures[0].keys())
    folds = len(measures)

    # TODO: feel free to change these colors
    colors = ["red", "orange", "yellow", "cyan", "blue", "purple"]

    ncols = 3
    nrows = max(1, (len(keys) + ncols - 1) // ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 4 * nrows), squeeze=False)
    flat_axes = axes.ravel()

    fold_positions = np.arange(folds)
    tick_labels = [f"Fold {i + 1}" for i in range(folds)] + ["Avg"]

    for metric_idx, metric in enumerate(keys):
        ax = flat_axes[metric_idx]
        color = colors[metric_idx % len(colors)]
        values = [d[metric] for d in measures]

        # Plot the bars for each fold
        ax.bar(fold_positions, values, color=color)

        # The extra bar at x == folds shows the mean with a std error bar
        mean = np.mean(values)
        std_dev = np.std(values)
        ax.bar(folds, mean, color=color)
        ax.errorbar(folds, mean, yerr=std_dev, fmt='none', capsize=5, elinewidth=1, markeredgewidth=1, ecolor='k')

        ax.set_xticks(np.arange(folds + 1))
        ax.set_xticklabels(tick_labels)
        ax.set_xlabel('Folds')
        ax.set_title(metric)

    for unused_ax in flat_axes[len(keys):]:
        unused_ax.set_visible(False)

    plt.tight_layout()
#     plt.savefig("metrics_plot2.png")
    plt.show()

# RandomCV

In [None]:
# RandomCV input files, relative to the working directory: one (LR, HR) CSV pair per
# split. The CV cell below uses each split once as the held-out test fold.
SPLIT_1_LR_PATH = 'RandomCV/Train/Fold1/lr_split_1.csv'
SPLIT_1_HR_PATH = 'RandomCV/Train/Fold1/hr_split_1.csv'
SPLIT_2_LR_PATH = 'RandomCV/Train/Fold2/lr_split_2.csv'
SPLIT_2_HR_PATH = 'RandomCV/Train/Fold2/hr_split_2.csv'
SPLIT_3_LR_PATH = 'RandomCV/Train/Fold3/lr_split_3.csv'
SPLIT_3_HR_PATH = 'RandomCV/Train/Fold3/hr_split_3.csv'

In [None]:
def compute_output(args, test_adj, model):
    """Run ``model`` on each LR graph and return the un-padded HR predictions.

    Args:
        args: Config providing ``lr_dim``, ``hr_dim`` and ``padding``. ``padding`` is
            the number of rows/columns stripped from each side of the prediction.
        test_adj: Iterable of LR adjacency tensors.
        model: Generator called as ``model(lr, args.lr_dim, args.hr_dim)``; the first
            element of the returned tuple is the padded HR prediction.

    Returns:
        np.ndarray stacking the un-padded predictions, one per input graph. Values
        are returned as produced by the model (no clipping).

    Note: the model is switched to eval mode and left there.
    """
    pad = args.padding
    outputs = []
    model.eval()

    with torch.no_grad():
        for lr_graph in test_adj:
            model_outputs, _, _, _ = model(lr_graph, args.lr_dim, args.hr_dim)
            rows, cols = model_outputs.shape[0], model_outputs.shape[1]
            # Remove the symmetric padding instead of a hard-coded 26 so this stays
            # in sync with args.padding (also used by pad_HR_adj in test()).
            unpadded = model_outputs[pad:rows - pad, pad:cols - pad]
            outputs.append(unpadded.detach().cpu().numpy())

    return np.array(outputs)

In [None]:
class Args:
    """Simple attribute container: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
# from paper -- hyperparameters shared by AGSRNet construction, training and evaluation
args = Args(
    epochs=400, # NOTE(review): old comment said "down from 200", but 400 > 200 -- confirm; the CV cell also overwrites this per fold
    lr=0.0001,
    lmbda=0.1,  # weight of criterion(net_outs, start_gcn_outs) in the generator loss
    lr_dim=160,
    hr_dim=320,  # presumably 268 HR nodes + 2 * padding (26) -- confirm
    hidden_dim=320,
    padding=26,  # rows/cols of padding (passed to pad_HR_adj and stripped in compute_output)
    mean_dense=0,
    std_dense=0.01,
    mean_gaussian=0,  # presumably consumed by gaussian_noise_layer(padded_hr, args) -- confirm
    std_gaussian=0.1,
    k = 50
)

# NOTE(review): passed to AGSRNet(ks, args); presumably per-level pooling ratios -- confirm
ks = [0.9, 0.7, 0.6, 0.5]
# Module-level instance used by evaluate_predictions to rebuild matrices from vectors.
matrix_vectorizer = MatrixVectorizer()

In [None]:
# Set seeds
SEED = 42
GET_METRICS = True  # NOTE(review): not read anywhere in this cell
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load Data (third argument: number of subjects per RandomCV split file)
split_1_adj, split_1_ground_truth = load_matrix_data(SPLIT_1_LR_PATH, SPLIT_1_HR_PATH, 93)
split_2_adj, split_2_ground_truth = load_matrix_data(SPLIT_2_LR_PATH, SPLIT_2_HR_PATH, 93)
split_3_adj, split_3_ground_truth = load_matrix_data(SPLIT_3_LR_PATH, SPLIT_3_HR_PATH, 93)

# Trailing subjects taken from each training split to form the early-stopping validation set.
N_VAL_PER_SPLIT = 20
# Epoch budget for the early-stopping search. The retraining step sets
# args.epochs = num_epochs; without resetting it, each later fold's search would be
# capped at the previous fold's stopping epoch (presumably train_validation iterates
# over args.epochs).
max_epochs = args.epochs

cv_splits = [
    (split_1_adj, split_1_ground_truth),
    (split_2_adj, split_2_ground_truth),
    (split_3_adj, split_3_ground_truth),
]

train_losses_all_with_val = []
val_losses_all = []
train_losses_all_no_val = []
fold_results = []

# Run 3-fold CV: each split is the test fold once. The other two splits, in their
# original order, supply training data, and their last N_VAL_PER_SPLIT subjects form
# the validation set.
for i in range(len(cv_splits)):
    print(f"Fold {i+1}:")

    test_adj, test_ground_truth = cv_splits[i]
    train_splits = [split for j, split in enumerate(cv_splits) if j != i]
    train_adj = torch.cat([adj[:-N_VAL_PER_SPLIT] for adj, _ in train_splits], dim=0)
    train_ground_truth = torch.cat([gt[:-N_VAL_PER_SPLIT] for _, gt in train_splits], dim=0)
    val_adj = torch.cat([adj[-N_VAL_PER_SPLIT:] for adj, _ in train_splits], dim=0)
    val_ground_truth = torch.cat([gt[-N_VAL_PER_SPLIT:] for _, gt in train_splits], dim=0)

    # Find early stopping epoch, always starting from the full epoch budget
    args.epochs = max_epochs
    model = AGSRNet(ks, args)
    lr = 0.0001
    # NOTE(review): `optimizer` is not passed to train_validation/train; kept in case they read the global
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_losses, val_losses = train_validation(model, train_adj, train_ground_truth, val_adj, val_ground_truth, args)
    train_losses_all_with_val.append(train_losses)
    val_losses_all.append(val_losses)
    num_epochs = len(train_losses)

    # Retrain model on full training set (without validation)
    full_train_adj = torch.cat((train_adj, val_adj), dim=0)
    full_train_ground_truth = torch.cat((train_ground_truth, val_ground_truth), dim=0)

    model = AGSRNet(ks, args)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    args.epochs = num_epochs
    train_losses = train(model, full_train_adj, full_train_ground_truth, args)
    train_losses_all_no_val.append(train_losses)

    # Get metrics for the left-out fold and keep them (previously discarded)
    test_outputs = compute_output(args, test_adj, model)
    metrics = evaluate_all(test_ground_truth.detach().numpy(), test_outputs)
    fold_results.append(metrics)

# Restore the configured budget so re-running this cell behaves the same.
args.epochs = max_epochs

In [None]:
# Validation MAE curve from each fold's early-stopping run
for i, fold_val_mae in enumerate(val_losses_all):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(fold_val_mae, label='Validation MAE')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Validation MAE')
    ax.set_title(f'Fold {i+1} - Validation MAE')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
# Training-loss curve of each fold's early-stopping run (train split only)
for i in range(3):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses_all_with_val[i], label='Training Loss (with validation)')
    ax.set(xlabel='Epoch', ylabel='Loss', title=f'Fold {i+1} - Training Loss')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
# Training-loss curve of each fold's retraining run on train + validation data
for i in range(3):
    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot()
    ax.plot(train_losses_all_no_val[i], label='Training Loss (No Validation)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'Fold {i+1} - Training Loss')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
# Load a saved results table: the first column becomes the row index, the first line the header.
# NOTE(review): this reads a file from disk; `fold_results` computed above is not used here -- confirm provenance.
identity_df = pd.read_csv('29-randomCV_old.csv', index_col=0, header=0)

In [None]:
# Inspect the loaded table before summary rows are appended.
identity_df

In [None]:
# Append 'mean' and 'std' summary rows, computed per column over the data rows only.
# Both come from the same snapshot so the std does not include the freshly appended
# mean row. Dropping existing summary rows first makes the cell safe to re-run.
result_rows = identity_df.drop(index=['mean', 'std'], errors='ignore')
identity_df.loc['mean'] = result_rows.mean()
identity_df.loc['std'] = result_rows.std()

In [None]:
# Inspect the table with the appended summary rows.
identity_df

In [None]:
# Save the table (including its summary rows) to a new CSV; the '_old' input file is left untouched.
identity_df.to_csv('29-randomCV.csv')

# ClusterCV

In [None]:
# ClusterCV input files (these rebind the RandomCV constants), relative to the working
# directory: one (LR, HR) CSV pair per cluster. Each cluster is the test fold once below.
SPLIT_1_LR_PATH = 'Cluster-CV/Fold1/lr_clusterA.csv'
SPLIT_1_HR_PATH = 'Cluster-CV/Fold1/hr_clusterA.csv'
SPLIT_2_LR_PATH = 'Cluster-CV/Fold2/lr_clusterB.csv'
SPLIT_2_HR_PATH = 'Cluster-CV/Fold2/hr_clusterB.csv'
SPLIT_3_LR_PATH = 'Cluster-CV/Fold3/lr_clusterC.csv'
SPLIT_3_HR_PATH = 'Cluster-CV/Fold3/hr_clusterC.csv'


In [None]:
def compute_output(args, test_adj, model):
    """Run ``model`` on each LR graph and return the un-padded HR predictions.

    Args:
        args: Config providing ``lr_dim``, ``hr_dim`` and ``padding``. ``padding`` is
            the number of rows/columns stripped from each side of the prediction.
        test_adj: Iterable of LR adjacency tensors.
        model: Generator called as ``model(lr, args.lr_dim, args.hr_dim)``; the first
            element of the returned tuple is the padded HR prediction.

    Returns:
        np.ndarray stacking the un-padded predictions, one per input graph. Values
        are returned as produced by the model (no clipping).

    Note: the model is switched to eval mode and left there.
    """
    pad = args.padding
    outputs = []
    model.eval()

    with torch.no_grad():
        for lr_graph in test_adj:
            model_outputs, _, _, _ = model(lr_graph, args.lr_dim, args.hr_dim)
            rows, cols = model_outputs.shape[0], model_outputs.shape[1]
            # Remove the symmetric padding instead of a hard-coded 26 so this stays
            # in sync with args.padding (also used by pad_HR_adj in test()).
            unpadded = model_outputs[pad:rows - pad, pad:cols - pad]
            outputs.append(unpadded.detach().cpu().numpy())

    return np.array(outputs)

In [None]:
class Args:
    """Simple attribute container: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
# from paper -- hyperparameters shared by AGSRNet construction, training and evaluation
args = Args(
    epochs=400, # NOTE(review): old comment said "down from 200", but 400 > 200 -- confirm; the CV cell also overwrites this per fold
    lr=0.0001,
    lmbda=0.1,  # weight of criterion(net_outs, start_gcn_outs) in the generator loss
    lr_dim=160,
    hr_dim=320,  # presumably 268 HR nodes + 2 * padding (26) -- confirm
    hidden_dim=320,
    padding=26,  # rows/cols of padding (passed to pad_HR_adj and stripped in compute_output)
    mean_dense=0,
    std_dense=0.01,
    mean_gaussian=0,  # presumably consumed by gaussian_noise_layer(padded_hr, args) -- confirm
    std_gaussian=0.1,
    k = 50
)

# NOTE(review): passed to AGSRNet(ks, args); presumably per-level pooling ratios -- confirm
ks = [0.9, 0.7, 0.6, 0.5]
# Module-level instance used by evaluate_predictions to rebuild matrices from vectors.
matrix_vectorizer = MatrixVectorizer()

In [None]:
# Set seeds
SEED = 42
GET_METRICS = True  # NOTE(review): not read anywhere in this cell
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load Data (third argument: number of subjects in each cluster file)
split_1_adj, split_1_ground_truth = load_matrix_data(SPLIT_1_LR_PATH, SPLIT_1_HR_PATH, 103)
split_2_adj, split_2_ground_truth = load_matrix_data(SPLIT_2_LR_PATH, SPLIT_2_HR_PATH, 103)
split_3_adj, split_3_ground_truth = load_matrix_data(SPLIT_3_LR_PATH, SPLIT_3_HR_PATH, 76)

# Trailing subjects taken from each training split to form the early-stopping validation set.
N_VAL_PER_SPLIT = 20
# Epoch budget for the early-stopping search. The retraining step sets
# args.epochs = num_epochs; without resetting it, each later fold's search would be
# capped at the previous fold's stopping epoch (presumably train_validation iterates
# over args.epochs).
max_epochs = args.epochs

cv_splits = [
    (split_1_adj, split_1_ground_truth),
    (split_2_adj, split_2_ground_truth),
    (split_3_adj, split_3_ground_truth),
]

train_losses_all_with_val = []
val_losses_all = []
train_losses_all_no_val = []
fold_results = []

# Run 3-fold CV: each cluster is the test fold once. The other two clusters, in their
# original order, supply training data, and their last N_VAL_PER_SPLIT subjects form
# the validation set.
for i in range(len(cv_splits)):
    print(f"Fold {i+1}:")

    test_adj, test_ground_truth = cv_splits[i]
    train_splits = [split for j, split in enumerate(cv_splits) if j != i]
    train_adj = torch.cat([adj[:-N_VAL_PER_SPLIT] for adj, _ in train_splits], dim=0)
    train_ground_truth = torch.cat([gt[:-N_VAL_PER_SPLIT] for _, gt in train_splits], dim=0)
    val_adj = torch.cat([adj[-N_VAL_PER_SPLIT:] for adj, _ in train_splits], dim=0)
    val_ground_truth = torch.cat([gt[-N_VAL_PER_SPLIT:] for _, gt in train_splits], dim=0)

    # Find early stopping epoch, always starting from the full epoch budget
    args.epochs = max_epochs
    model = AGSRNet(ks, args)
    lr = 0.0001
    # NOTE(review): `optimizer` is not passed to train_validation/train; kept in case they read the global
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_losses, val_losses = train_validation(model, train_adj, train_ground_truth, val_adj, val_ground_truth, args)
    train_losses_all_with_val.append(train_losses)
    val_losses_all.append(val_losses)
    num_epochs = len(train_losses)

    # Retrain model on full training set (without validation)
    full_train_adj = torch.cat((train_adj, val_adj), dim=0)
    full_train_ground_truth = torch.cat((train_ground_truth, val_ground_truth), dim=0)

    model = AGSRNet(ks, args)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    args.epochs = num_epochs
    train_losses = train(model, full_train_adj, full_train_ground_truth, args)
    train_losses_all_no_val.append(train_losses)

    # Get metrics for the left-out fold and keep them (previously discarded)
    test_outputs = compute_output(args, test_adj, model)
    metrics = evaluate_all(test_ground_truth.detach().numpy(), test_outputs)
    fold_results.append(metrics)

# Restore the configured budget so re-running this cell behaves the same.
args.epochs = max_epochs

In [None]:
# Validation MAE curve from each fold's early-stopping run
for i, fold_val_mae in enumerate(val_losses_all):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(fold_val_mae, label='Validation MAE')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Validation MAE')
    ax.set_title(f'Fold {i+1} - Validation MAE')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
# Training-loss curve of each fold's early-stopping run (train split only)
for i in range(3):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses_all_with_val[i], label='Training Loss (with validation)')
    ax.set(xlabel='Epoch', ylabel='Loss', title=f'Fold {i+1} - Training Loss')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
# Training-loss curve of each fold's retraining run on train + validation data
for i in range(3):
    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot()
    ax.plot(train_losses_all_no_val[i], label='Training Loss (No Validation)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'Fold {i+1} - Training Loss')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
# Load a saved results table: the first column becomes the row index, the first line the header.
# NOTE(review): this ClusterCV section reads 'ID-randomCV.csv', which looks like a RandomCV file
# (possible copy-paste slip) -- confirm it is the intended cluster-CV results file.
identity_df = pd.read_csv('ID-randomCV.csv', index_col=0, header=0)

In [None]:
# Inspect the loaded table before summary rows are appended.
identity_df

In [None]:
# Append 'mean' and 'std' summary rows, computed per column over the data rows only.
# Both come from the same snapshot so the std does not include the freshly appended
# mean row. Dropping existing summary rows first makes the cell safe to re-run.
result_rows = identity_df.drop(index=['mean', 'std'], errors='ignore')
identity_df.loc['mean'] = result_rows.mean()
identity_df.loc['std'] = result_rows.std()

In [None]:
# Inspect the table with the appended summary rows.
identity_df

In [None]:
# Save the table (including its summary rows) as the ClusterCV results CSV.
identity_df.to_csv('clusterCV.csv')