In [None]:
import torch
import numpy as np
import torch.optim as optim
from sklearn.model_selection import KFold
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import scipy.io
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr
from scipy.spatial.distance import jensenshannon
import networkx as nx
from evaluation import *
from utils import *
import math


from scipy.io import savemat


from scipy.io import loadmat
import seaborn as sns
import pandas as pd

In [None]:
class MatrixVectorizer:
    """
    Converts between a square matrix and a flat vector of its off-diagonal entries.

    Both directions walk the matrix column by column. For every column the
    strictly-upper-triangular entries (row < col) are visited top to bottom;
    when ``include_diagonal`` is set, the entry directly below the main
    diagonal (row == col + 1) is visited right after them.
    """

    def __init__(self):
        """
        Creates a MatrixVectorizer.

        No state is stored; both conversion methods are static.
        """
        pass

    @staticmethod
    def vectorize(matrix, include_diagonal=False):
        """
        Flattens a matrix into a vector using column-wise traversal.

        Parameters:
        - matrix (numpy.ndarray): Square matrix to flatten; its size is taken
          from ``matrix.shape[0]``.
        - include_diagonal (bool, optional): When True, the entry at
          (col + 1, col) is appended after each column's upper-triangle
          entries. Defaults to False.

        Returns:
        - numpy.ndarray: The collected entries in traversal order.
        """
        size = matrix.shape[0]
        collected = []

        for col in range(size):
            # Entries strictly above the main diagonal, top to bottom.
            column_part = [matrix[row, col] for row in range(col)]

            # Optional first sub-diagonal entry of this column (absent for the last column).
            below = col + 1
            if include_diagonal and below < size:
                column_part.append(matrix[below, col])

            collected.extend(column_part)

        return np.array(collected)

    @staticmethod
    def anti_vectorize(vector, matrix_size, include_diagonal=False):
        """
        Rebuilds a symmetric matrix from a vector produced by ``vectorize``.

        Parameters:
        - vector (numpy.ndarray): Flat vector of entries in traversal order.
        - matrix_size (int): Side length of the square matrix to build.
        - include_diagonal (bool, optional): Must match the flag used when the
          vector was created. Defaults to False.

        Returns:
        - numpy.ndarray: A (matrix_size, matrix_size) matrix. Every consumed
          value is written to both (row, col) and (col, row); untouched cells
          stay zero.
        """
        matrix = np.zeros((matrix_size, matrix_size))
        cursor = 0  # position of the next unread vector element

        for col in range(matrix_size):
            # Upper-triangle entries of this column, mirrored into the lower triangle.
            for row in range(col):
                value = vector[cursor]
                matrix[row, col] = value
                matrix[col, row] = value
                cursor += 1

            # Optional first sub-diagonal entry, also mirrored.
            below = col + 1
            if include_diagonal and below < matrix_size:
                value = vector[cursor]
                matrix[below, col] = value
                matrix[col, below] = value
                cursor += 1

        return matrix

In [None]:
def pad_HR_adj(label, split):
    """
    Surrounds a high-resolution (HR) adjacency matrix with a zero border.

    Parameters:
    - label (torch.Tensor): The HR adjacency matrix to pad.
    - split (int): Width of the zero border added on every side of the last
      two dimensions.

    Returns:
    torch.Tensor: The zero-padded matrix.
    """
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    border = (split, split, split, split)
    return F.pad(label, pad=border, mode="constant", value=0)

def unpad(data, split):
    """
    Removes a border of width ``split`` from the first two dimensions of ``data``.

    Parameters:
    - data: 2-D indexable array/tensor exposing ``.shape``.
    - split (int): Number of rows/columns to drop from each side.

    Returns:
    The central slice ``data[split:rows - split, split:cols - split]``.
    """
    n_rows, n_cols = data.shape[0], data.shape[1]
    return data[split:n_rows - split, split:n_cols - split]

def normalize_adj_torch(mx):
    """
    Scales a square matrix by the inverse square roots of its row sums.

    With D = diag(rowsum(mx) ** -0.5) (infinite entries replaced by 0), the
    result is ``(mx @ D).T @ D``.

    Parameters:
    - mx (torch.Tensor): 2-D square tensor.

    Returns:
    torch.Tensor: The normalized matrix.
    """
    inv_sqrt_deg = mx.sum(1).pow(-0.5).flatten()
    # Rows summing to zero yield inf; zero them so they contribute nothing.
    inv_sqrt_deg[torch.isinf(inv_sqrt_deg)] = 0.
    scale = torch.diag(inv_sqrt_deg)
    scaled_cols = torch.matmul(mx, scale)
    return torch.matmul(torch.transpose(scaled_cols, 0, 1), scale)

## Layers

In [None]:

def weight_variable_glorot(output_dim):
    """
    Draws a square Glorot/Xavier-uniform weight matrix.

    Parameters:
    - output_dim (int): Side length; input and output dimensions are equal.

    Returns:
    numpy.ndarray: An (output_dim, output_dim) array sampled from
    U(-r, r) with r = sqrt(6 / (2 * output_dim)), using NumPy's global RNG.
    """
    fan_in = fan_out = output_dim
    bound = np.sqrt(6.0 / (fan_in + fan_out))
    return np.random.uniform(-bound, bound, (fan_in, fan_out))

# %% [markdown]
# ## 2. Layers

# %%
class GSRLayer(nn.Module):
  """
  Graph super-resolution layer.

  Holds a learnable (hr_dim, hr_dim) weight matrix and uses the eigenvectors
  of the low-resolution adjacency to lift node features to a higher-resolution
  adjacency estimate.
  """

  def __init__(self, hr_dim):
    """
    Parameters:
    - hr_dim (int): Side length of the learnable weight matrix.
    """
    super(GSRLayer, self).__init__()

    initial = torch.from_numpy(weight_variable_glorot(hr_dim)).type(torch.FloatTensor)
    self.weights = torch.nn.Parameter(data=initial, requires_grad=True)

  def forward(self, A, X):
    """
    Parameters:
    - A (torch.Tensor): Symmetric low-resolution adjacency, shape (lr_dim, lr_dim).
    - X (torch.Tensor): Node features with lr_dim rows.
      NOTE(review): the matmul chain requires hr_dim == 2 * lr_dim and the
      result is treated as square — confirm X is (lr_dim, hr_dim).

    Returns:
    - adj (torch.Tensor): Normalized |W [I; I] U^T X| with unit diagonal.
    - torch.Tensor: |sym(adj @ adj.T)| with its diagonal set to 1.
    """
    lr = A
    lr_dim = lr.shape[0]
    f = X
    # Eigenvectors of the LR adjacency; the eigenvalues are not used.
    _, U_lr = torch.linalg.eigh(lr, UPLO='U')
    eye_mat = torch.eye(lr_dim).type(torch.FloatTensor)
    # Two stacked identities: shape (2 * lr_dim, lr_dim).
    s_d = torch.cat((eye_mat, eye_mat), 0)

    a = torch.matmul(self.weights, s_d)
    b = torch.matmul(a, torch.t(U_lr))
    f_d = torch.matmul(b, f)
    f_d = torch.abs(f_d)
    self.f_d = f_d.fill_diagonal_(1)
    adj = normalize_adj_torch(self.f_d)
    X = torch.mm(adj, adj.t())
    X = (X + X.t())/2
    # Size the diagonal mask from the tensor itself instead of a hard-coded 320,
    # so the layer works for any hr_dim.
    idx = torch.eye(X.shape[0], dtype=bool)
    X[idx] = 1
    return adj, torch.abs(X)



class GraphConvolution(nn.Module):
    """
    Linear graph-convolution layer computing ``adj @ (input @ weight)``.

    Modeled on the GCN layer of https://arxiv.org/abs/1609.02907. The
    ``dropout`` and ``act`` arguments are stored but not applied in ``forward``.
    """

    def __init__(self, in_features, out_features, dropout=0., act=F.relu):
        """
        Parameters:
        - in_features (int): Number of input feature columns.
        - out_features (int): Number of output feature columns.
        - dropout (float, optional): Stored only. Defaults to 0.
        - act (callable, optional): Stored only. Defaults to ``F.relu``.
        """
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.dropout = dropout
        self.act = act
        self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes ``weight`` with Xavier-uniform values."""
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        """
        Parameters:
        - input (torch.Tensor): Node features, shape (N, in_features).
        - adj (torch.Tensor): Matrix multiplied on the left, shape (M, N).

        Returns:
        torch.Tensor: ``adj @ input @ weight`` with no activation or dropout.
        """
        projected = torch.mm(input, self.weight)
        return torch.mm(adj, projected)

## Operations

In [None]:
class GraphUnpool(nn.Module):
    """Scatters pooled node features back into a zero-filled full-size feature matrix."""

    def __init__(self):
        super().__init__()

    def forward(self, A, X, idx):
        """
        Parameters:
        - A (torch.Tensor): Adjacency of the un-pooled graph; only its row count is used.
        - X (torch.Tensor): Features of the pooled nodes.
        - idx: Row indices in the un-pooled graph where the rows of X belong.

        Returns:
        - A, returned unchanged.
        - torch.Tensor of shape (A.shape[0], X.shape[1]) holding X at rows
          ``idx`` and zeros elsewhere.
        """
        restored = torch.zeros([A.shape[0], X.shape[1]])
        restored[idx] = X
        return A, restored


class GraphPool(nn.Module):
    """Top-k node pooling driven by a learned scalar score per node."""

    def __init__(self, k, in_dim):
        """
        Parameters:
        - k (float): Fraction of nodes to keep; ``int(k * num_nodes)`` are retained.
        - in_dim (int): Feature dimension fed to the scoring projection.
        """
        super().__init__()
        self.k = k
        self.proj = nn.Linear(in_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, A, X):
        """
        Parameters:
        - A (torch.Tensor): Adjacency matrix, shape (N, N).
        - X (torch.Tensor): Node features, shape (N, in_dim).

        Returns:
        - torch.Tensor: A restricted to the kept rows and columns.
        - torch.Tensor: Kept feature rows, each multiplied by its score.
        - torch.Tensor: Indices of the kept nodes, in descending score order.
        """
        # Score = sigmoid(projection / 100); the division flattens the sigmoid.
        gate = self.sigmoid(torch.squeeze(self.proj(X)) / 100)
        n_keep = int(self.k * A.shape[0])
        values, idx = torch.topk(gate, n_keep)

        gated_X = torch.mul(X[idx, :], torch.unsqueeze(values, -1))
        pooled_A = A[idx, :][:, idx]
        return pooled_A, gated_X, idx


class GCN(nn.Module):
    """
    Per-node linear projection used as the building block of ``GraphUnet``.

    The adjacency argument is accepted but not used: the layer applies
    dropout (p=0) followed by a linear map to the features only.
    """

    def __init__(self, in_dim, out_dim):
        """
        Parameters:
        - in_dim (int): Input feature dimension.
        - out_dim (int): Output feature dimension.
        """
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.drop = nn.Dropout(p=0)

    def forward(self, A, X):
        """
        Parameters:
        - A: Adjacency matrix (ignored).
        - X (torch.Tensor): Node features, shape (N, in_dim).

        Returns:
        torch.Tensor: Projected features, shape (N, out_dim).
        """
        return self.proj(self.drop(X))

class GraphUnet(nn.Module):
    """
    Graph U-Net: an encoder of (GCN, pool) stages, a bottom GCN, and a decoder
    of (unpool, GCN) stages with additive skip connections.

    The per-level layers are stored in ``nn.ModuleList`` containers so their
    parameters are registered with the module. With plain Python lists they
    are invisible to ``parameters()``, ``state_dict()`` and ``.to()``, and an
    optimizer built from ``model.parameters()`` never updates them.
    """

    def __init__(self, ks, in_dim, out_dim, dim=320):
        """
        Parameters:
        - ks (sequence of float): Keep-ratio for each pooling level; its
          length sets the U-Net depth.
        - in_dim (int): Input feature dimension.
        - out_dim (int): Output feature dimension.
        - dim (int, optional): Hidden feature dimension. Defaults to 320.
        """
        super(GraphUnet, self).__init__()
        self.ks = ks

        self.start_gcn = GCN(in_dim, dim)
        self.bottom_gcn = GCN(dim, dim)
        # Input is the decoder output concatenated with the start_gcn output.
        self.end_gcn = GCN(2*dim, out_dim)
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        self.l_n = len(ks)
        # Construction order is kept (down, up, pool, unpool per level) so the
        # RNG-driven initial weights match the previous layout.
        for i in range(self.l_n):
            self.down_gcns.append(GCN(dim, dim))
            self.up_gcns.append(GCN(dim, dim))
            self.pools.append(GraphPool(ks[i], dim))
            self.unpools.append(GraphUnpool())

    def forward(self, A, X):
        """
        Parameters:
        - A (torch.Tensor): Adjacency matrix, shape (N, N).
        - X (torch.Tensor): Node features, shape (N, in_dim).

        Returns:
        - torch.Tensor: Output features, shape (N, out_dim).
        - torch.Tensor: Output of the first GCN, shape (N, dim).
        """
        adj_ms = []
        indices_list = []
        down_outs = []
        X = self.start_gcn(A, X)
        start_gcn_outs = X
        org_X = X

        # Encoder: remember adjacency, features and kept indices at every level.
        for i in range(self.l_n):
            X = self.down_gcns[i](A, X)
            adj_ms.append(A)
            down_outs.append(X)
            A, X, idx = self.pools[i](A, X)
            indices_list.append(idx)

        X = self.bottom_gcn(A, X)

        # Decoder: walk the levels in reverse, restoring size and adding skips.
        for i in range(self.l_n):
            up_idx = self.l_n - i - 1

            A, idx = adj_ms[up_idx], indices_list[up_idx]
            A, X = self.unpools[i](A, X, idx)
            X = self.up_gcns[i](A, X)
            X = X.add(down_outs[up_idx])

        X = torch.cat([X, org_X], 1)
        X = self.end_gcn(A, X)

        return X, start_gcn_outs

## Model

In [None]:
class GSRNet(nn.Module):
  """
  Baseline graph super-resolution network: Graph U-Net, then GSR layer,
  then two graph-convolution layers.
  """

  def __init__(self, ks, lr_dim, hr_dim, hidden_dim):
    """
    Parameters:
    - ks (sequence of float): Pooling keep-ratios for the Graph U-Net.
    - lr_dim (int): Number of nodes in the low-resolution graph.
    - hr_dim (int): Number of nodes in the (padded) high-resolution graph.
    - hidden_dim (int): Width of the first graph-convolution layer.
    """
    super().__init__()

    self.lr_dim, self.hr_dim, self.hidden_dim = lr_dim, hr_dim, hidden_dim
    self.layer = GSRLayer(hr_dim)
    self.net = GraphUnet(ks, lr_dim, hr_dim)
    self.gc1 = GraphConvolution(hr_dim, hidden_dim, 0, act=F.relu)
    self.gc2 = GraphConvolution(hidden_dim, hr_dim, 0, act=F.relu)

  def forward(self, lr):
    """
    Parameters:
    - lr (torch.Tensor): Low-resolution adjacency, shape (lr_dim, lr_dim).

    Returns a 4-tuple:
    - |z|: symmetrized second graph-convolution output with unit diagonal.
    - net_outs: Graph U-Net output.
    - start_gcn_outs: output of the U-Net's first GCN.
    - outputs: normalized adjacency produced by the GSR layer.

    Intermediate tensors are also stored on the instance.
    """
    # Identity features: one one-hot row per LR node.
    identity = torch.eye(self.lr_dim).type(torch.FloatTensor)
    norm_adj = normalize_adj_torch(lr).type(torch.FloatTensor)

    self.net_outs, self.start_gcn_outs = self.net(norm_adj, identity)
    self.outputs, self.Z = self.layer(norm_adj, self.net_outs)

    self.hidden1 = self.gc1(self.Z, self.outputs)
    self.hidden2 = self.gc2(self.hidden1, self.outputs)

    # Symmetrize, then force ones on the diagonal.
    sym = (self.hidden2 + self.hidden2.t())/2
    diag_mask = torch.eye(self.hr_dim, dtype=bool)
    sym[diag_mask] = 1

    return torch.abs(sym), self.net_outs, self.start_gcn_outs, self.outputs
  

In [None]:

class SelfAttention(nn.Module):
    """Single-head dot-product self-attention without score scaling."""

    def __init__(self, in_dim):
        """
        Parameters:
        - in_dim (int): Feature dimension of inputs and outputs.
        """
        super().__init__()
        self.query = nn.Linear(in_dim, in_dim)
        self.key = nn.Linear(in_dim, in_dim)
        self.value = nn.Linear(in_dim, in_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Parameters:
        - x (torch.Tensor): Input of shape (..., seq_len, in_dim).

        Returns:
        torch.Tensor: softmax(Q K^T) V, same shape as ``x``.
        """
        q, k, v = self.query(x), self.key(x), self.value(x)
        weights = self.softmax(torch.matmul(q, k.transpose(-2, -1)))
        return torch.matmul(weights, v)

        
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head self-attention over an unbatched (seq_len, in_dim) input."""

    def __init__(self, in_dim, num_heads):
        """
        Parameters:
        - in_dim (int): Feature dimension of inputs and outputs.
        - num_heads (int): Number of heads; each head gets ``in_dim // num_heads`` features.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = in_dim // num_heads

        self.query = nn.Linear(in_dim, in_dim)
        self.key = nn.Linear(in_dim, in_dim)
        self.value = nn.Linear(in_dim, in_dim)

        self.fc = nn.Linear(in_dim, in_dim)

    def _split_heads(self, projected, seq_len):
        """Reshapes (seq_len, in_dim) into (num_heads, seq_len, head_dim)."""
        return projected.view(seq_len, self.num_heads, self.head_dim).transpose(0, 1)

    def forward(self, x):
        """
        Parameters:
        - x (torch.Tensor): Input of shape (seq_len, in_dim).

        Returns:
        torch.Tensor: Attention output after the final linear layer, shape (seq_len, in_dim).
        """
        seq_len, _ = x.size()

        q = self._split_heads(self.query(x), seq_len)
        k = self._split_heads(self.key(x), seq_len)
        v = self._split_heads(self.value(x), seq_len)

        scaled_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        probs = F.softmax(scaled_scores, dim=-1)

        # Merge the heads back into one feature axis.
        merged = torch.matmul(probs, v).transpose(0, 1).contiguous().view(seq_len, -1)
        return self.fc(merged)
    
import torch_geometric as tg

class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear, ReLU, Linear."""

    def __init__(self, in_dim, hidden_dim, out_dim):
        """
        Parameters:
        - in_dim (int): Input feature dimension.
        - hidden_dim (int): Width of the hidden layer.
        - out_dim (int): Output feature dimension.
        """
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Returns ``fc2(relu(fc1(x)))``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

class ResidualBlock(nn.Module):
    """Two-layer feed-forward block with an additive skip connection and final ReLU."""

    def __init__(self, in_dim, hidden_dim):
        """
        Parameters:
        - in_dim (int): Input and output feature dimension.
        - hidden_dim (int): Width of the inner layer.
        """
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, in_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Returns ``relu(fc2(relu(fc1(x))) + x)``; ``x`` itself is not modified."""
        hidden = self.relu(self.fc1(x))
        out = self.fc2(hidden)
        out += x  # skip connection, accumulated in place on the fresh fc2 output
        return self.relu(out)

class AttentionResidualBlock(nn.Module):
    """Multi-head self-attention, then a two-layer feed-forward net, wrapped in a residual connection."""

    def __init__(self, in_dim, hidden_dim, num_heads):
        """
        Parameters:
        - in_dim (int): Input and output feature dimension.
        - hidden_dim (int): Width of the inner feed-forward layer.
        - num_heads (int): Number of attention heads.
        """
        super().__init__()
        self.self_attention = MultiHeadAttention(in_dim, num_heads)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, in_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Returns ``relu(fc2(relu(fc1(attention(x)))) + x)``."""
        attended = self.self_attention(x)
        hidden = self.relu(self.fc1(attended))
        out = self.fc2(hidden)
        out += x  # residual connection around attention + feed-forward
        return self.relu(out)

class AGSRNet(nn.Module):
    """
    Generator network: Graph U-Net, GSR layer, two dense GIN convolutions,
    multi-head attention with two attention-residual blocks, and a linear
    skip connection from the GSR layer's second output.
    """

    def __init__(self, ks, args):
        """
        Parameters:
        - ks (sequence of float): Pooling keep-ratios for the Graph U-Net.
        - args: Object exposing ``lr_dim``, ``hr_dim`` and ``hidden_dim``.
        """
        super().__init__()
        self.lr_dim, self.hr_dim, self.hidden_dim = args.lr_dim, args.hr_dim, args.hidden_dim
        hr, hidden = self.hr_dim, self.hidden_dim

        # Sub-modules are created in a fixed order so RNG-driven initialization is unchanged.
        self.layer = GSRLayer(hr)
        self.net = GraphUnet(ks, self.lr_dim, hr)

        self.self_attention = MultiHeadAttention(hr, num_heads=8)
        self.attention_residual_block1 = AttentionResidualBlock(hr, hidden, num_heads=8)
        self.attention_residual_block2 = AttentionResidualBlock(hr, hidden, num_heads=8)

        self.gin1 = tg.nn.dense.DenseGINConv(SimpleMLP(hr, hidden, hidden))
        self.gin2 = tg.nn.dense.DenseGINConv(SimpleMLP(hidden, hr, hr))

        self.skip_connection = nn.Linear(hr, hr)

    def forward(self, lr):
        """
        Parameters:
        - lr (torch.Tensor): Low-resolution adjacency, shape (lr_dim, lr_dim).

        Returns a 4-tuple:
        - |z|: symmetrized prediction with unit diagonal.
        - net_outs: Graph U-Net output.
        - start_gcn_outs: output of the U-Net's first GCN.
        - outputs: normalized adjacency produced by the GSR layer.

        Intermediate tensors are also stored on the instance.
        """
        identity = torch.eye(self.lr_dim).type(torch.FloatTensor)
        norm_adj = normalize_adj_torch(lr).type(torch.FloatTensor)

        self.net_outs, self.start_gcn_outs = self.net(norm_adj, identity)
        self.outputs, self.Z = self.layer(norm_adj, self.net_outs)

        self.hidden1 = F.relu(self.gin1(self.Z, self.outputs))
        # squeeze(0) drops the leading dimension of the second GIN output.
        self.hidden2 = F.relu(self.gin2(self.hidden1, self.outputs).squeeze(0))

        refined = self.self_attention(self.hidden2)
        refined = F.dropout(refined, p=0.5, training=self.training)

        for block in (self.attention_residual_block1, self.attention_residual_block2):
            refined = block(refined)

        refined = refined + self.skip_connection(self.Z)

        # Symmetrize and force ones on the diagonal.
        refined = (refined + refined.t()) / 2
        refined = refined.fill_diagonal_(1)

        return torch.abs(refined), self.net_outs, self.start_gcn_outs, self.outputs

class Discriminator(nn.Module):
    """
    Row-wise discriminator: two (Dense, BatchNorm, ReLU, ResidualBlock) stages
    followed by a Dense layer to one unit and a sigmoid.
    """

    def __init__(self, args):
        """
        Parameters:
        - args: Object exposing ``hr_dim`` plus whatever ``Dense`` reads from it.
        """
        super().__init__()
        width = args.hr_dim

        self.dense_1 = Dense(width, width, args)
        self.bn1 = nn.BatchNorm1d(width)
        self.relu_1 = nn.ReLU(inplace=False)

        self.residual_block1 = ResidualBlock(width, width)
        self.residual_block2 = ResidualBlock(width, width)

        self.dense_2 = Dense(width, width, args)
        self.bn2 = nn.BatchNorm1d(width)
        self.relu_2 = nn.ReLU(inplace=False)

        self.dense_3 = Dense(width, 1, args)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """
        Parameters:
        - inputs (torch.Tensor): Matrix with ``hr_dim`` columns; rows are
          treated as the batch by BatchNorm1d.

        Returns:
        torch.Tensor: One sigmoid score per input row, shape (rows, 1).
        """
        stage1 = self.dense_1(inputs)
        stage1 = self.relu_1(self.bn1(stage1))
        stage1 = self.residual_block1(stage1)

        stage2 = self.dense_2(stage1)
        stage2 = self.relu_2(self.bn2(stage2))
        stage2 = self.residual_block2(stage2)

        score = self.sigmoid(self.dense_3(stage2))
        return torch.abs(score)
class Dense(nn.Module):
    """Bias-free linear layer whose weights are drawn from a normal distribution."""

    def __init__(self, n1, n2, args):
        """
        Parameters:
        - n1 (int): Input feature dimension.
        - n2 (int): Output feature dimension.
        - args: Object exposing ``mean_dense`` and ``std_dense``, the mean and
          standard deviation used to initialize the weights.
        """
        super(Dense, self).__init__()
        self.weights = torch.nn.Parameter(
            torch.FloatTensor(n1, n2), requires_grad=True)
        nn.init.normal_(self.weights, mean=args.mean_dense, std=args.std_dense)

    def forward(self, x):
        """
        Parameters:
        - x (torch.Tensor): Input of shape (rows, n1).

        Returns:
        torch.Tensor: ``x @ weights``, shape (rows, n2).
        """
        # The global NumPy/torch RNGs are deliberately NOT reseeded here.
        # Reseeding on every forward pass rewinds the random stream for the
        # whole program, so every later random draw (edge dropout, Gaussian
        # noise, dropout masks) repeats identically on each iteration.
        # Seed once at program start instead.
        return torch.mm(x, self.weights)


  
def gaussian_noise_layer(input_layer, args):
    """
    Adds Gaussian noise to a square matrix and returns a non-negative,
    symmetric result with unit diagonal.

    Parameters:
    - input_layer (torch.Tensor): Square 2-D tensor.
    - args: Object exposing ``mean_gaussian`` and ``std_gaussian``.

    Returns:
    torch.Tensor: sym(|input + noise|) with the diagonal set to 1.
    """
    noise = torch.empty_like(input_layer).normal_(
        mean=args.mean_gaussian, std=args.std_gaussian)
    noisy = torch.abs(input_layer + noise)

    symmetric = (noisy + noisy.t())/2
    return symmetric.fill_diagonal_(1)


# Train

In [None]:

def perturb_graph(adj, drop_rate=0.1):
    """
    Randomly zeroes entries of an adjacency matrix (edge dropout).

    Each entry is kept independently when a uniform [0, 1) draw exceeds
    ``drop_rate``; the mask is not symmetrized.

    Parameters:
    - adj (torch.Tensor): Adjacency matrix.
    - drop_rate (float, optional): Threshold for dropping an entry. Defaults to 0.1.

    Returns:
    torch.Tensor: ``adj`` with the dropped entries set to zero.
    """
    keep_mask = torch.rand(adj.shape) > drop_rate
    return adj * keep_mask

# Shared reconstruction criterion used by the train/test functions below.
# It is L1 (mean absolute error), so quantities named "mse_loss" there are MAE terms.
criterion = nn.L1Loss()  # Changed from MSELoss to L1Loss for MAE

def train_asgr(model, subjects_adj, subjects_labels, args):
    """
    Trains ``model`` together with a freshly created ``Discriminator`` for
    ``args.epochs`` epochs, one (LR, HR) pair per optimization step.

    Parameters:
    - model: Generator exposing ``model.layer.weights`` and returning a 4-tuple
      (prediction, net_outs, start_gcn_outs, layer_outs) when called on an LR graph.
    - subjects_adj: Iterable of LR adjacency tensors.
    - subjects_labels: Iterable of HR adjacency tensors, paired with ``subjects_adj``.
    - args: Object exposing ``lr``, ``epochs``, ``padding``, ``lmbda``, ``hr_dim``
      and the fields read by ``Discriminator`` / ``gaussian_noise_layer``.

    Returns:
    list: Mean generator loss for every epoch.

    NOTE(review): ``model.train()`` is not called here, so the model runs in
    whatever mode it was left in by the caller.
    """

    bce_loss = nn.BCELoss()
    netD = Discriminator(args)
    optimizerG = optim.Adam(model.parameters(), lr=args.lr)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr)

    all_epochs_loss = []
    for epoch in range(args.epochs):
        with torch.autograd.set_detect_anomaly(True):
            epoch_loss = []
            epoch_error = []
            for lr, hr in zip(subjects_adj, subjects_labels):
                optimizerD.zero_grad()
                optimizerG.zero_grad()
                
                # Perturb the LR graph
                lr = perturb_graph(lr, drop_rate=0.1)  # Apply edge dropout
                
                # Pad HR adjacency matrix and ensure it is a PyTorch tensor
                hr_padded = pad_HR_adj(hr, args.padding)
                padded_hr = hr_padded.type(torch.FloatTensor)
    

                # Eigenvectors of the padded HR graph serve as the target for the
                # GSR layer weights; the eigenvalues (eig_val_hr) are unused.
                eig_val_hr, U_hr = torch.linalg.eigh(padded_hr, UPLO='U')

                model_outputs, net_outs, start_gcn_outs, layer_outs = model(
                    lr)

                # Despite the name, every term uses the module-level L1 criterion:
                # U-Net self-reconstruction (weighted by lmbda) + eigenvector
                # alignment of the GSR weights + prediction error.
                mse_loss = args.lmbda * criterion(net_outs, start_gcn_outs) + criterion(
                    model.layer.weights, U_hr) + criterion(model_outputs, padded_hr)

                error = criterion(model_outputs, padded_hr)
                # Discriminator targets as written: detached model output -> 1,
                # noised ground truth -> 0.
                real_data = model_outputs.detach()
                fake_data = gaussian_noise_layer(padded_hr, args)

                d_real = netD(real_data)
                d_fake = netD(fake_data)

                dc_loss_real = bce_loss(d_real, torch.ones(args.hr_dim, 1))
                dc_loss_fake = bce_loss(d_fake, torch.zeros(args.hr_dim, 1))
                dc_loss = dc_loss_real + dc_loss_fake

                dc_loss.backward()
                optimizerD.step()

                # NOTE(review): d_fake is computed from the noised ground truth only,
                # so gen_loss does not depend on the generator output as written —
                # confirm this is intended.
                d_fake = netD(gaussian_noise_layer(padded_hr, args))

                gen_loss = bce_loss(d_fake, torch.ones(args.hr_dim, 1))
                generator_loss = gen_loss + mse_loss
                generator_loss.backward()
                optimizerG.step()

                epoch_loss.append(generator_loss.item())
                epoch_error.append(error.item())

            print("Epoch: ", epoch, "Loss: ", np.mean(epoch_loss),
                  "Error: ", np.mean(epoch_error)*100, "%", flush=True)  # Error now represents MAE
            all_epochs_loss.append(np.mean(epoch_loss))
    return all_epochs_loss


def train_asgr_validation(model, subjects_adj, subjects_labels, val_adj, val_labels, args):
    """
    Same training loop as ``train_asgr`` with a per-epoch validation pass and
    early stopping.

    Parameters:
    - model: Generator exposing ``model.layer.weights`` and returning a 4-tuple
      (prediction, net_outs, start_gcn_outs, layer_outs) when called on an LR graph.
    - subjects_adj, subjects_labels: Paired iterables of LR / HR training tensors.
    - val_adj, val_labels: Paired iterables of LR / HR validation tensors,
      passed to ``test_asgr`` after every epoch.
    - args: Object exposing ``lr``, ``epochs``, ``padding``, ``lmbda``, ``hr_dim``
      and the fields read by ``Discriminator`` / ``gaussian_noise_layer``.

    Returns:
    (list, list): Mean generator loss per epoch and validation MAE per epoch.

    Early stopping triggers once the validation MAE has not improved for
    ``patience`` (10) consecutive epochs AND the epoch index exceeds 100.
    """

    bce_loss = nn.BCELoss()
    netD = Discriminator(args)
    optimizerG = optim.Adam(model.parameters(), lr=args.lr)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr)

    all_epochs_loss = []
    all_epochs_val_loss = []

    # Early stopping parameters
    best_val_mae = float('inf')
    patience = 10
    counter = 0
    for epoch in range(args.epochs):
        with torch.autograd.set_detect_anomaly(True):
            epoch_loss = []
            epoch_error = []
            for lr, hr in zip(subjects_adj, subjects_labels):
                optimizerD.zero_grad()
                optimizerG.zero_grad()
                
                # Perturb the LR graph
                lr = perturb_graph(lr, drop_rate=0.1)  # Apply edge dropout
                
                # Pad HR adjacency matrix and ensure it is a PyTorch tensor
                hr_padded = pad_HR_adj(hr, args.padding)
                padded_hr = hr_padded.type(torch.FloatTensor)
    

                # Eigenvectors of the padded HR graph serve as the target for the
                # GSR layer weights; the eigenvalues (eig_val_hr) are unused.
                eig_val_hr, U_hr = torch.linalg.eigh(padded_hr, UPLO='U')

                model_outputs, net_outs, start_gcn_outs, layer_outs = model(
                    lr)

                # Despite the name, every term uses the module-level L1 criterion.
                mse_loss = args.lmbda * criterion(net_outs, start_gcn_outs) + criterion(
                    model.layer.weights, U_hr) + criterion(model_outputs, padded_hr)

                error = criterion(model_outputs, padded_hr)
                # Discriminator targets as written: detached model output -> 1,
                # noised ground truth -> 0.
                real_data = model_outputs.detach()
                fake_data = gaussian_noise_layer(padded_hr, args)

                d_real = netD(real_data)
                d_fake = netD(fake_data)

                dc_loss_real = bce_loss(d_real, torch.ones(args.hr_dim, 1))
                dc_loss_fake = bce_loss(d_fake, torch.zeros(args.hr_dim, 1))
                dc_loss = dc_loss_real + dc_loss_fake

                dc_loss.backward()
                optimizerD.step()

                # NOTE(review): d_fake is computed from the noised ground truth only,
                # so gen_loss does not depend on the generator output as written —
                # confirm this is intended.
                d_fake = netD(gaussian_noise_layer(padded_hr, args))

                gen_loss = bce_loss(d_fake, torch.ones(args.hr_dim, 1))
                generator_loss = gen_loss + mse_loss
                generator_loss.backward()
                optimizerG.step()

                epoch_loss.append(generator_loss.item())
                epoch_error.append(error.item())

            all_epochs_loss.append(np.mean(epoch_loss))

            val_mae = test_asgr(model, val_adj, val_labels, args) # Test on validation set
            all_epochs_val_loss.append(val_mae)

            print("Epoch: ", epoch, "Loss: ", np.mean(epoch_loss),
                  "Error: ", np.mean(epoch_error)*100, "%", "Validation MAE: ", val_mae, flush=True)  # Error now represents MAE
            # Early-stopping bookkeeping: reset the counter on improvement,
            # otherwise count the stale epoch.
            if val_mae < best_val_mae:
                best_val_mae = val_mae
                counter = 0
            else:
                counter += 1
                # Never stop before epoch 101, regardless of patience.
                if counter >= patience and epoch > 100:
                    print("Early stopping at epoch ", epoch)
                    return all_epochs_loss, all_epochs_val_loss
    return all_epochs_loss, all_epochs_val_loss


def test_asgr(model, test_adj, test_labels, args):
    """
    Computes the mean absolute error of ``model`` over paired LR/HR graphs.

    Pairs where either tensor is entirely zero are skipped. Each HR label gets
    a unit diagonal and is zero-padded by ``args.padding`` before comparison.

    The model is switched to eval mode for the duration of the call (so
    dropout is disabled) and its previous training flag is restored
    afterwards. Inference runs under ``torch.no_grad()``. The label tensors
    passed in are not modified.

    Parameters:
    - model: Network returning a 4-tuple whose first element is the prediction.
    - test_adj: Iterable of LR adjacency tensors.
    - test_labels: Iterable of HR adjacency tensors, paired with ``test_adj``.
    - args: Object exposing ``padding``.

    Returns:
    float: Mean L1 error over the evaluated pairs (NaN if none were evaluated).
    """
    test_error = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for lr, hr in zip(test_adj, test_labels):
                # Skip subjects with an all-zero LR or HR matrix.
                if not torch.any(lr) or not torch.any(hr):
                    continue

                # Work on a copy so the caller's labels keep their original diagonal.
                hr = pad_HR_adj(hr.clone().fill_diagonal_(1), args.padding)
                preds, _, _, _ = model(lr)

                test_error.append(criterion(preds, hr).item())
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    print("Test error MAE: ", np.mean(test_error), flush=True)  # Changed MSE to MAE in print statement
    return np.mean(test_error)

# Args

In [None]:
# Module-level hyper-parameters. Only epochs, lr_dim, hr_dim and hidden_dim
# are forwarded into Args below; Args sets lr, lmbda and padding itself.
epochs = 400 # bring to 200
lr = 0.0001
splits = 10
# NOTE: Args.lmbda below is 0.1, not this value.
lmbda = 16
lr_dim = 160
hr_dim = 320
hidden_dim = 320
padding = 26

class Args:
    """Hyper-parameter container handed to the model, discriminator and training functions."""
    epochs = epochs
    lr = 0.0001
    # Weight of the U-Net self-reconstruction term in the generator loss.
    lmbda = 0.1
    lr_dim = lr_dim
    hr_dim = hr_dim
    hidden_dim = hidden_dim
    # Zero border added to each side of every HR matrix.
    # Presumably 2 * 26 pads the raw HR size up to hr_dim — TODO confirm.
    padding = 26
    # Normal-initialization parameters for the discriminator's Dense layers.
    mean_dense = 0
    std_dense = 0.01
    # Noise parameters used by gaussian_noise_layer.
    mean_gaussian = 0
    std_gaussian = 0.1

args = Args()
# Keep-ratios for the Graph U-Net pooling levels (one per level).
ks = [0.9, 0.7, 0.6, 0.5]

In [None]:
# Relative paths to the LR / HR CSV files of the three random-CV folds.
SPLIT_1_LR_PATH = 'RandomCV/Train/Fold1/lr_split_1.csv'
SPLIT_1_HR_PATH = 'RandomCV/Train/Fold1/hr_split_1.csv'
SPLIT_2_LR_PATH = 'RandomCV/Train/Fold2/lr_split_2.csv'
SPLIT_2_HR_PATH = 'RandomCV/Train/Fold2/hr_split_2.csv'
SPLIT_3_LR_PATH = 'RandomCV/Train/Fold3/lr_split_3.csv'
SPLIT_3_HR_PATH = 'RandomCV/Train/Fold3/hr_split_3.csv'

In [None]:
def compute_output(test_adj, model, padding=26):
    """
    Runs ``model`` on every LR graph and returns the predictions with the
    padding border cropped off.

    The model is put in eval mode (and left there) and inference runs under
    ``torch.no_grad()``. Outputs are returned as produced; no clipping is applied.

    Parameters:
    - test_adj: Iterable of LR adjacency tensors.
    - model: Callable network with an ``eval()`` method, returning a 4-tuple
      whose first element is a 2-D prediction tensor.
    - padding (int, optional): Border width removed from each side of the
      prediction. Defaults to 26.

    Returns:
    numpy.ndarray: Stacked cropped predictions, one per input graph.
    """
    outputs = []
    model.eval()

    with torch.no_grad():
        for lr_graph in test_adj:
            model_outputs, _, _, _ = model(lr_graph)
            n_rows, n_cols = model_outputs.shape[0], model_outputs.shape[1]
            cropped = model_outputs[padding:n_rows - padding, padding:n_cols - padding]
            outputs.append(cropped.detach().numpy())

    return np.array(outputs)

In [None]:
# Set seeds
SEED = 42
GET_METRICS = True
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load Data
split_1_adj, split_1_ground_truth = load_matrix_data(SPLIT_1_LR_PATH, SPLIT_1_HR_PATH, 93)
split_2_adj, split_2_ground_truth = load_matrix_data(SPLIT_2_LR_PATH, SPLIT_2_HR_PATH, 93)
split_3_adj, split_3_ground_truth = load_matrix_data(SPLIT_3_LR_PATH, SPLIT_3_HR_PATH, 93)

train_losses_all_with_val = []
val_losses_all = []
train_losses_all_no_val = []
fold_results = []

# (LR, HR) tensors of each fold, in fold order.
cv_folds = [
    (split_1_adj, split_1_ground_truth),
    (split_2_adj, split_2_ground_truth),
    (split_3_adj, split_3_ground_truth),
]
# Number of trailing subjects of each training fold held out for validation.
N_VAL_PER_FOLD = 20
# Epoch budget for the early-stopping run. args.epochs is temporarily lowered
# for the retraining run below and must be restored for every fold; otherwise
# later folds inherit the previous fold's early-stopped epoch count.
max_epochs = args.epochs

# Run 3-fold CV
for i in range(3):
    print(f"Fold {i+1}:")

    # Fold i is the test set; the remaining folds (in order) supply train/validation data.
    test_adj, test_ground_truth = cv_folds[i]
    other_folds = [fold for j, fold in enumerate(cv_folds) if j != i]

    train_adj = torch.cat([adj[:-N_VAL_PER_FOLD] for adj, _ in other_folds], dim=0)
    train_ground_truth = torch.cat([gt[:-N_VAL_PER_FOLD] for _, gt in other_folds], dim=0)
    val_adj = torch.cat([adj[-N_VAL_PER_FOLD:] for adj, _ in other_folds], dim=0)
    val_ground_truth = torch.cat([gt[-N_VAL_PER_FOLD:] for _, gt in other_folds], dim=0)

    # Find early stopping epoch (the training functions build their own optimizers).
    args.epochs = max_epochs
    model = AGSRNet(ks, args)
    train_losses, val_losses = train_asgr_validation(model, train_adj, train_ground_truth, val_adj, val_ground_truth, args)
    train_losses_all_with_val.append(train_losses)
    val_losses_all.append(val_losses)
    num_epochs = len(train_losses)

    # Retrain model on full training set (without validation)
    full_train_adj = torch.cat((train_adj, val_adj), dim=0)
    full_train_ground_truth = torch.cat((train_ground_truth, val_ground_truth), dim=0)

    model = AGSRNet(ks, args)
    args.epochs = num_epochs
    train_losses = train_asgr(model, full_train_adj, full_train_ground_truth, args)
    train_losses_all_no_val.append(train_losses)
    args.epochs = max_epochs

    # Get metrics for the left-out fold
    test_outputs = compute_output(test_adj, model)
    metrics = evaluate_all(test_ground_truth.detach().numpy(), test_outputs)
    fold_results.append(metrics)

In [None]:
# Plot the per-epoch validation MAE of each of the three folds, one figure per fold.
for i in range(3):
    plt.figure(figsize=(10, 5))
    # val_losses_all[i] holds the validation MAE history collected for fold i.
    plt.plot(val_losses_all[i], label='Validation MAE')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Fold {i+1} - Validation MAE')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

In [None]:
# Plot the per-epoch generator loss of the early-stopping (with validation) run, one figure per fold.
for i in range(3):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses_all_with_val[i], label='Training Loss (with validation)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Fold {i+1} - Training Loss')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

In [None]:
# Plot the per-epoch generator loss of the retraining (no validation) run, one figure per fold.
for i in range(3):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses_all_no_val[i], label='Training Loss (No Validation)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Fold {i+1} - Training Loss')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

In [None]:
# Load a results table from disk; the first CSV column becomes the index and the first row the header.
# Presumably one row per fold and one column per metric — TODO confirm against the file.
identity_df = pd.read_csv('15-randomCV_old.csv', index_col=0, header=0)

In [None]:
# Display the table as loaded.
identity_df

In [None]:
# Append 'mean' and 'std' summary rows, each computed per column over the data rows only.
# Both statistics come from the same snapshot, so 'std' no longer includes the freshly
# added 'mean' row, and re-running this cell does not fold earlier summary rows back in.
data_rows = identity_df.drop(index=['mean', 'std'], errors='ignore')
identity_df.loc['mean'] = data_rows.mean()
identity_df.loc['std'] = data_rows.std()

In [None]:
# Display the table including the appended 'mean' and 'std' rows.
identity_df

In [None]:
# Save the table, including the summary rows, to a CSV file in the working directory.
identity_df.to_csv('randomCV.csv')

# ClusterCV

In [None]:
# Hyper-parameters for the cluster-CV run. This cell repeats the earlier
# configuration cell and rebinds `args`, resetting args.epochs to `epochs`.
# Only epochs, lr_dim, hr_dim and hidden_dim are forwarded into Args below;
# Args sets lr, lmbda and padding itself.
epochs = 400 # bring to 200
lr = 0.0001
splits = 10
# NOTE: Args.lmbda below is 0.1, not this value.
lmbda = 16
lr_dim = 160
hr_dim = 320
hidden_dim = 320
padding = 26

class Args:
    """Hyper-parameter container handed to the model, discriminator and training functions."""
    epochs = epochs
    lr = 0.0001
    # Weight of the U-Net self-reconstruction term in the generator loss.
    lmbda = 0.1
    lr_dim = lr_dim
    hr_dim = hr_dim
    hidden_dim = hidden_dim
    # Zero border added to each side of every HR matrix.
    padding = 26
    # Normal-initialization parameters for the discriminator's Dense layers.
    mean_dense = 0
    std_dense = 0.01
    # Noise parameters used by gaussian_noise_layer.
    mean_gaussian = 0
    std_gaussian = 0.1

args = Args()
# Keep-ratios for the Graph U-Net pooling levels (one per level).
ks = [0.9, 0.7, 0.6, 0.5]

In [None]:
# Relative paths to the LR / HR CSV files of the three cluster-CV folds.
# These rebind the SPLIT_* names used earlier for the random-CV folds.
SPLIT_1_LR_PATH = 'Cluster-CV/Fold1/lr_clusterA.csv'
SPLIT_1_HR_PATH = 'Cluster-CV/Fold1/hr_clusterA.csv'
SPLIT_2_LR_PATH = 'Cluster-CV/Fold2/lr_clusterB.csv'
SPLIT_2_HR_PATH = 'Cluster-CV/Fold2/hr_clusterB.csv'
SPLIT_3_LR_PATH = 'Cluster-CV/Fold3/lr_clusterC.csv'
SPLIT_3_HR_PATH = 'Cluster-CV/Fold3/hr_clusterC.csv'

In [None]:
def compute_output(test_adj, model, padding=26):
    """
    Runs the model on every input graph and returns the un-padded predictions.

    The model is switched to evaluation mode and called once per element of
    ``test_adj``. Only the first element of the 4-tuple it returns is used; a border
    of ``padding`` rows and columns is cropped from each side of that 2-D output.
    Values are returned exactly as produced by the model (no clipping is applied).

    Parameters:
    - test_adj (iterable): Inputs to feed to the model one at a time, e.g. a stacked
      tensor iterated along its first dimension.
    - model: Callable exposing ``eval()`` and returning a 4-tuple whose first element
      is a 2-D tensor.
    - padding (int, optional): Border width removed from every side of the model
      output. Defaults to 26, the value that was previously hard-coded.

    Returns:
    - numpy.ndarray: The cropped outputs stacked along a new first axis.
    """
    model.eval()
    outputs = []

    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for lr_graph in test_adj:
            model_outputs, _net_outs, _start_gcn_outs, _layer_outs = model(lr_graph)

            # Strip the padded border from both dimensions.
            row_end = model_outputs.shape[0] - padding
            col_end = model_outputs.shape[1] - padding
            cropped = model_outputs[padding:row_end, padding:col_end]

            outputs.append(cropped.detach().cpu().numpy())

    return np.array(outputs)

In [None]:
# Set seeds
SEED = 42
GET_METRICS = True
torch.manual_seed(SEED)
np.random.seed(SEED)

# Number of trailing samples of each training cluster that are held out for validation.
VAL_SAMPLES_PER_SPLIT = 20

# Epoch budget for the early-stopping run. Read from the class attribute because
# `args.epochs` (the instance attribute) is overwritten inside the loop below.
MAX_EPOCHS = Args.epochs

# Load Data
# NOTE(review): the third argument (103 / 103 / 76) is presumably the number of
# samples in each cluster file — confirm against load_matrix_data.
split_1_adj, split_1_ground_truth = load_matrix_data(SPLIT_1_LR_PATH, SPLIT_1_HR_PATH, 103)
split_2_adj, split_2_ground_truth = load_matrix_data(SPLIT_2_LR_PATH, SPLIT_2_HR_PATH, 103)
split_3_adj, split_3_ground_truth = load_matrix_data(SPLIT_3_LR_PATH, SPLIT_3_HR_PATH, 76)

cluster_splits = [
    (split_1_adj, split_1_ground_truth),
    (split_2_adj, split_2_ground_truth),
    (split_3_adj, split_3_ground_truth),
]

train_losses_all_with_val = []
val_losses_all = []
train_losses_all_no_val = []
fold_results = []

# Run 3-fold CV: each cluster is the test fold once; the other two supply train/validation.
for i in range(len(cluster_splits)):
    print(f"Fold {i+1}:")

    # Determine train, validation, and test splits.
    # The remaining clusters keep their original order (1 before 2 before 3).
    test_adj, test_ground_truth = cluster_splits[i]
    remaining_splits = [split for j, split in enumerate(cluster_splits) if j != i]

    train_adj = torch.cat([adj[:-VAL_SAMPLES_PER_SPLIT] for adj, _ in remaining_splits], dim=0)
    train_ground_truth = torch.cat([gt[:-VAL_SAMPLES_PER_SPLIT] for _, gt in remaining_splits], dim=0)
    val_adj = torch.cat([adj[-VAL_SAMPLES_PER_SPLIT:] for adj, _ in remaining_splits], dim=0)
    val_ground_truth = torch.cat([gt[-VAL_SAMPLES_PER_SPLIT:] for _, gt in remaining_splits], dim=0)

    # Find early stopping epoch.
    # Restore the full epoch budget first: the previous fold left `args.epochs` at its
    # early-stopped epoch count, which would otherwise cap this fold's validation run
    # (assuming the training helper reads `args.epochs`).
    args.epochs = MAX_EPOCHS
    model = AGSRNet(ks, args)
    lr = 0.0001
    # NOTE(review): `optimizer` is not passed to the training helpers below; it is kept
    # in case they read it from module scope — confirm and remove if unused.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_losses, val_losses = train_asgr_validation(model, train_adj, train_ground_truth, val_adj, val_ground_truth, args)
    train_losses_all_with_val.append(train_losses)
    val_losses_all.append(val_losses)
    num_epochs = len(train_losses)

    # Retrain model on full training set (without validation)
    full_train_adj = torch.cat((train_adj, val_adj), dim=0)
    full_train_ground_truth = torch.cat((train_ground_truth, val_ground_truth), dim=0)

    model = AGSRNet(ks, args)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Train for as many epochs as the early-stopping run recorded.
    args.epochs = num_epochs
    train_losses = train_asgr(model, full_train_adj, full_train_ground_truth, args)
    train_losses_all_no_val.append(train_losses)

    # Get metrics for the left-out fold and keep them; previously each fold's result
    # was overwritten by the next and `fold_results` stayed empty.
    test_outputs = compute_output(test_adj, model)
    metrics = evaluate_all(test_ground_truth.detach().numpy(), test_outputs)
    fold_results.append(metrics)

In [None]:
# One figure per fold: the validation curve recorded for that fold.
for fold_idx in range(3):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(val_losses_all[fold_idx], label='Validation MAE')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'Fold {fold_idx + 1} - Validation MAE')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
# One figure per fold: training loss from the run that used a validation split.
for fold_idx in range(3):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses_all_with_val[fold_idx], label='Training Loss (with validation)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'Fold {fold_idx + 1} - Training Loss')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
# One figure per fold: training loss from the retraining run on train + validation data.
for fold_idx in range(3):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses_all_no_val[fold_idx], label='Training Loss (No Validation)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'Fold {fold_idx + 1} - Training Loss')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
# Load a saved results table from CSV: the first column becomes the row index
# and the first line supplies the column names.
# NOTE(review): this cell sits in the ClusterCV section and the table is later
# written to 'clusterCV.csv', yet the file read here is named 'ID-randomCV.csv'.
# Confirm this is the intended input and not a leftover from the random-CV section.
identity_df = pd.read_csv('ID-randomCV.csv', index_col=0, header=0)

In [None]:
# Display the table (a bare last expression is rendered as the cell output).
identity_df

In [None]:
# Append two summary ROWS, 'mean' and 'std', holding the column-wise statistics.
# Both statistics are computed from the data rows only, before either summary row
# is written. Previously 'std' was computed after the 'mean' row had been appended,
# so the mean itself was counted as an extra sample and the result was wrong.
# Dropping any existing 'mean'/'std' rows first also makes the cell safe to re-run.
data_rows = identity_df.drop(index=['mean', 'std'], errors='ignore')
column_mean = data_rows.mean()
column_std = data_rows.std()
identity_df.loc['mean'] = column_mean
identity_df.loc['std'] = column_std

In [None]:
# Display the table again, now with the 'mean' and 'std' rows appended.
identity_df

In [None]:
# Write the table, including the appended 'mean' and 'std' rows and the row index,
# to 'clusterCV.csv' in the current working directory.
identity_df.to_csv('clusterCV.csv')