# GSAT-Net Structure
```
GSATNet(
  (layer): GSRLayer()
  (net): GraphUnet(
    (start_gcnpro): GCNPro(
      (proj): Linear(in_features=160, out_features=320, bias=True)
      (drop): Dropout(p=0.1, inplace=False)
    )
    (bottom_gcnpro): GCNPro(
      (proj): Linear(in_features=320, out_features=320, bias=True)
      (drop): Dropout(p=0.1, inplace=False)
    )
    (end_gcnpro): GCNPro(
      (proj): Linear(in_features=640, out_features=320, bias=True)
      (drop): Dropout(p=0.1, inplace=False)
    )
  )
  (gc1): GraphConvolution()
  (gc2): GraphConvolution()
)
Discriminator(
  (dense_1): Dense()
  (relu_1): ReLU()
  (dense_2): Dense()
  (relu_2): ReLU()
  (dense_3): Dense()
  (sigmoid): Sigmoid()
)
```

# Libraries and Reproducibility

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

# Seed the global NumPy and PyTorch RNGs so runs are repeatable.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Prefer the GPU when one is present; otherwise fall back to the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("CUDA not available. Using CPU.")
else:
    device = torch.device("cuda")
    print("CUDA is available. Using GPU.")
    # Seed the current GPU and every other visible GPU.
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    # Give up cuDNN autotuning in exchange for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Tensors created without an explicit device are placed on the GPU.
    torch.set_default_device(device)

# Datasets

In [None]:
from MatrixVectorizer import MatrixVectorizer

# Loading dataset
lr_train_path = './data/lr_train.csv'
lr_test_path = './data/lr_test.csv'
hr_train_path = './data/hr_train.csv'

lr_train_pd = pd.read_csv(lr_train_path)
hr_train_pd = pd.read_csv(hr_train_path)
lr_test_pd = pd.read_csv(lr_test_path)

# One row per subject. ``to_numpy`` gives the same 2-D (rows x columns) array
# as a row-by-row ``apply`` + ``tolist`` round-trip, without a Python-level
# loop over the rows.
lr_train = lr_train_pd.to_numpy()
hr_train = hr_train_pd.to_numpy()
lr_test = lr_test_pd.to_numpy()

# Rebuild the square adjacency matrices from their vectorised rows. The second
# argument is presumably the node count (160 for LR, 268 for HR); the padding
# code elsewhere pads 268 up to 320.
matrix_vectorizer = MatrixVectorizer()
lr_train_A = [matrix_vectorizer.anti_vectorize(g, 160) for g in lr_train]
hr_train_A = [matrix_vectorizer.anti_vectorize(g, 268) for g in hr_train]
lr_test_A = [matrix_vectorizer.anti_vectorize(g, 160) for g in lr_test]

# Layers

In [None]:
def weight_variable_glorot(output_dim):
    """Return a square ``output_dim x output_dim`` Glorot/Xavier-uniform matrix.

    Samples are drawn from NumPy's global RNG on U(-r, r), where
    r = sqrt(6 / (output_dim + output_dim)), i.e. the Glorot bound with
    fan_in == fan_out.
    """
    fan_in = output_dim  # the matrix is square, so fan_in == fan_out
    bound = np.sqrt(6.0 / (fan_in + output_dim))
    return np.random.uniform(-bound, bound, (fan_in, output_dim))


class GSRLayer(nn.Module):
    """
    Graph super-resolution (GSR) layer, from https://github.com/basiralab/AGSR-Net.

    Holds a learnable square ``hr_dim x hr_dim`` weight matrix. In the forward
    pass it projects node features through the eigenvectors of the LR
    adjacency to produce a larger adjacency matrix.
    """

    def __init__(self, hr_dim):
        super(GSRLayer, self).__init__()
        glorot = weight_variable_glorot(hr_dim)
        # Learnable (hr_dim x hr_dim) projection with Glorot initialisation.
        self.weights = torch.nn.Parameter(
            data=torch.from_numpy(glorot).type(torch.FloatTensor).to(device),
            requires_grad=True)

    def forward(self, A, X):
        """Return ``(adj, |sym(adj @ adj.T)|)`` for LR adjacency ``A`` and features ``X``.

        ``adj = |W @ [I; I] @ U.T @ X|`` with its diagonal set to 1. Here ``U``
        holds the eigenvectors of ``A`` (computed from its upper triangle) and
        ``[I; I]`` stacks two n x n identities, where n = ``A.shape[0]``. The
        second output is the symmetrised Gram matrix of ``adj``, also with a
        unit diagonal.
        """
        with torch.autograd.set_detect_anomaly(True):
            n_lr = A.shape[0]
            _, eigvecs = torch.linalg.eigh(A, UPLO='U')

            # [I; I] has shape (2n, n); it makes the (hr x hr) weights line up
            # with the n eigenvector rows (requires hr == 2n).
            identity = torch.eye(n_lr).type(torch.FloatTensor).to(device)
            stacked = torch.cat((identity, identity), 0)

            projection = torch.matmul(torch.matmul(self.weights, stacked), torch.t(eigvecs))
            adj = torch.abs(torch.matmul(projection, X))
            adj = adj.fill_diagonal_(1)

            gram = torch.mm(adj, adj.t())
            gram = (gram + gram.t()) / 2
            gram = gram.fill_diagonal_(1)
        return adj, torch.abs(gram)


class GraphConvolution(nn.Module):
    """
    Plain GCN layer in the style of Kipf & Welling, https://arxiv.org/abs/1609.02907.

    Computes ``act(adj @ (dropout(input) @ W))``. ``W`` is an
    ``in_features x out_features`` matrix with Xavier-uniform initialisation,
    and ``dropout`` is the feature-dropout probability, applied only in
    training mode.
    """

    def __init__(self, in_features, out_features, dropout, act=F.relu):
        super(GraphConvolution, self).__init__()
        self.in_features, self.out_features = in_features, out_features
        self.dropout = dropout
        self.act = act
        self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``weight`` from a Xavier/Glorot uniform distribution."""
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        """Propagate ``input`` node features over ``adj`` and apply ``act``."""
        dropped = F.dropout(input, self.dropout, self.training)
        return self.act(torch.mm(adj, torch.mm(dropped, self.weight)))


class GraphUnpool(nn.Module):
    """
    Graph unpooling, from https://github.com/basiralab/AGSR-Net.

    Writes the pooled node features back into rows ``idx`` of a zero matrix
    that has one row per node of ``A``. All other rows stay zero.
    """

    def __init__(self):
        super(GraphUnpool, self).__init__()

    def forward(self, A, X, idx):
        """Return ``A`` unchanged together with the scattered feature matrix."""
        num_nodes, num_feats = A.shape[0], X.shape[1]
        restored = torch.zeros([num_nodes, num_feats], device=device)
        restored[idx] = X
        return A, restored


class GraphPool(nn.Module):
    """
    Top-k graph pooling, from https://github.com/basiralab/AGSR-Net.

    Scores every node with a learned linear projection and keeps the
    ``int(k * num_nodes)`` best-scoring nodes. The kept features are scaled
    by their gate value, and the adjacency is restricted to the kept nodes.
    """

    def __init__(self, k, in_dim):
        super(GraphPool, self).__init__()
        self.k = k
        self.proj = nn.Linear(in_dim, 1).to(device)
        self.sigmoid = nn.Sigmoid().to(device)

    def forward(self, A, X):
        """Return ``(A_pooled, X_pooled, idx)``; ``idx`` lists the kept node indices."""
        # The /100 presumably keeps the sigmoid out of saturation so scores
        # stay distinguishable.
        gate = self.sigmoid(torch.squeeze(self.proj(X)) / 100)
        keep = int(self.k * A.shape[0])
        top_vals, idx = torch.topk(gate, keep)
        pooled_X = X[idx, :] * top_vals.unsqueeze(-1)
        pooled_A = A[idx, :][:, idx]
        return pooled_A, pooled_X, idx


class GCNPro(nn.Module):
    """
    GCN Pro layer.

    Forward pass:
    1. Feature dropout with probability ``p`` (0 by default).
    2. A linear projection.
    3. Propagation with ``A @ D^-1 + I``, where ``D = diag(rowsum(A) + 1e-7)``.
    4. ``activation``.

    NOTE(review): ``A @ D^-1`` divides column j by deg(j), so node i receives
    sum_j A_ij * x_j / deg(j). A per-node neighbour mean would be
    ``D^-1 @ A``. Confirm which was intended.
    """

    def __init__(self, in_features, out_features, activation=F.relu, p=0):
        super(GCNPro, self).__init__()
        self.proj = nn.Linear(in_features, out_features).to(device)
        self.activation = activation
        self.drop = nn.Dropout(p=p).to(device)

    def mean_normalization(self, A):
        """Return ``A @ diag(1 / (rowsum(A) + 1e-7)) + I`` as float32."""
        eps = 1e-7  # avoids division by zero for isolated nodes
        inv_degree = 1.0 / (A.sum(1) + eps)
        identity = torch.eye(A.size(0)).to(device)
        return (A @ torch.diag(inv_degree) + identity).float()

    def forward(self, A, X):
        """Drop, project, propagate over the normalised ``A``, then activate."""
        H = self.proj(self.drop(X))
        return self.activation(torch.matmul(self.mean_normalization(A), H))
    
    
class GAT(nn.Module):
    """
    GAT layer (single attention head).

    Node features are first transformed as ``X_new = dropout(input) @ W + b``.
    The pairwise score is ``S[i, j] = LeakyReLU([X_new[j] || X_new[i]] . phi)``.
    Scores are kept only where ``adj + I`` is non-zero, softmax-normalised per
    row, and used to aggregate ``X_new``. ``activation`` is applied last
    unless it is falsy.
    """

    def __init__(self, in_features, out_features, activation = F.relu, p = 0):
        super(GAT, self).__init__()
        # Trainable transform weight, bias and attention vector.
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.phi = nn.Parameter(torch.FloatTensor(2 * out_features, 1))
        self.activation = activation
        # Dropout has no parameters or buffers, so it needs no device move.
        self.drop = nn.Dropout(p=p)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``weight`` and ``phi`` uniformly in +-1/sqrt(size(1))."""
        stdv = 1. / np.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

        # NOTE(review): phi.size(1) is always 1, so this bound is 1.0; size(0)
        # may have been intended. Kept to preserve the initialisation.
        stdv = 1. / np.sqrt(self.phi.size(1))
        self.phi.data.uniform_(-stdv, stdv)

    def forward(self, adj, input):
        """Return attention-aggregated features for node features ``input``."""
        input = self.drop(input)
        X_new = self.bias + torch.matmul(input, self.weight)
        N = X_new.size(0)
        # combined[i, j] = [X_new[j] || X_new[i]] for every node pair.
        expanded = X_new.unsqueeze(0).expand(N, -1, -1)
        combined = torch.cat((expanded, expanded.transpose(0, 1)), dim=2)
        S = F.leaky_relu(torch.matmul(combined, self.phi).squeeze(2))
        # Attend only along edges of adj plus self-loops. The identity is built
        # on adj's device so the layer does not depend on a global default device.
        mask = (adj + torch.eye(N, device=adj.device)).bool()
        S_masked = S.masked_fill(~mask, float('-inf'))
        attention = torch.softmax(S_masked, dim=1)
        h = torch.matmul(attention, X_new)

        return self.activation(h) if self.activation else h
    

class GraphUnet(nn.Module):
    """
    Graph U-Net layer.

    Contains 3 GCN Pro layers (start, bottom, end). For every entry of ``ks``
    it also has one down-path GAT, one up-path GAT, one top-k pool and one
    unpool. ``ks[i]`` is the fraction of nodes kept by pooling level ``i``.
    """

    def __init__(self, ks, in_dim, out_dim, dim=320):
        super(GraphUnet, self).__init__()
        self.ks = ks

        self.start_gcnpro = GCNPro(in_dim, dim, p=0.1).to(device)
        self.bottom_gcnpro = GCNPro(dim, dim, p=0.1).to(device)
        self.end_gcnpro = GCNPro(2*dim, out_dim, p=0.1).to(device)
        # nn.ModuleList (not a plain list) registers these sub-layers. Only then
        # are their parameters returned by .parameters() (and so optimised),
        # moved by .to(), stored in state_dict() and switched by train()/eval().
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        self.l_n = len(ks)
        for i in range(self.l_n):
            self.down_gcns.append(GAT(dim, dim).to(device))
            self.up_gcns.append(GAT(dim, dim).to(device))
            self.pools.append(GraphPool(ks[i], dim).to(device))
            self.unpools.append(GraphUnpool().to(device))

    def forward(self, A, X):
        """Return ``(output, start_layer_output)`` for adjacency ``A`` and features ``X``."""
        adj_ms = []
        indices_list = []
        down_outs = []
        X = self.start_gcnpro(A, X)
        start_gcn_outs = X
        # Encoder: attend, remember the skip connection, then pool.
        for i in range(self.l_n):
            X = self.down_gcns[i](A, X)
            adj_ms.append(A)
            down_outs.append(X)
            A, X, idx = self.pools[i](A, X)
            indices_list.append(idx)
        X = self.bottom_gcnpro(A, X)
        # Decoder: unpool back to the matching encoder level, attend, add skip.
        for i in range(self.l_n):
            up_idx = self.l_n - i - 1
            A, idx = adj_ms[up_idx], indices_list[up_idx]
            A, X = self.unpools[i](A, X, idx)
            X = self.up_gcns[i](A, X)
            X = X.add(down_outs[up_idx])
        # Concatenate with the start-layer output before the final GCN Pro layer.
        X = torch.cat([X, start_gcn_outs], 1)
        X = self.end_gcnpro(A, X)

        return X, start_gcn_outs


def pad_HR_adj(label, split):
    """Zero-pad a square matrix by ``split`` on every side and set its diagonal to 1.

    ``np.pad`` returns a new array, so ``label`` itself is not modified.
    """
    padded = np.pad(label, [(split, split)] * 2, mode="constant")
    np.fill_diagonal(padded, 1)
    return padded


def normalize_adj_torch(mx):
    """Symmetrically normalise an adjacency matrix.

    Returns ``D^-1/2 @ mx.T @ D^-1/2``, where ``D`` holds the row sums of
    ``mx``. For a symmetric ``mx`` this equals ``D^-1/2 @ mx @ D^-1/2``.
    Rows summing to zero get a scale of 0 instead of inf.
    """
    scale = mx.sum(1).pow(-0.5).flatten()
    scale[torch.isinf(scale)] = 0.
    d_half = torch.diag(scale)
    return torch.matmul(torch.transpose(torch.matmul(mx, d_half), 0, 1), d_half)


def unpad(data, split):
    """Strip ``split`` rows and columns from every edge of a 2-D array or tensor."""
    rows, cols = data.shape[0], data.shape[1]
    return data[split:rows - split, split:cols - split]

class GSATNet(nn.Module):
    """
    GSAT-Net: 1 GSR layer, 1 Graph U-Net and 2 graph convolution layers.

    Maps an LR adjacency to a symmetric, non-negative ``hr_dim x hr_dim``
    matrix with a unit diagonal.
    """

    def __init__(self, ks, lr_dim=160, hr_dim=320, hidden_dim=320):
        super(GSATNet, self).__init__()

        self.lr_dim, self.hr_dim, self.hidden_dim = lr_dim, hr_dim, hidden_dim
        # Sub-layers are built in a fixed order, which fixes how the global
        # RNGs are consumed during initialisation.
        self.layer = GSRLayer(self.hr_dim).to(device)
        self.net = GraphUnet(ks, self.lr_dim, self.hr_dim).to(device)
        self.gc1 = GraphConvolution(self.hr_dim, self.hidden_dim, 0, act=F.relu).to(device)
        self.gc2 = GraphConvolution(self.hidden_dim, self.hr_dim, 0, act=F.relu).to(device)

    def forward(self, lr, lr_dim, hr_dim):
        """Super-resolve the LR adjacency ``lr``.

        ``lr_dim`` and ``hr_dim`` are accepted for interface compatibility but
        are not read; the sizes given to ``__init__`` are used instead.

        Returns ``(|z|, net_outs, start_gcn_outs, outputs)``. The intermediate
        tensors are also cached as attributes on ``self``.
        """
        with torch.autograd.set_detect_anomaly(True):
            identity = torch.eye(self.lr_dim).type(torch.FloatTensor).to(device)
            A = normalize_adj_torch(lr).type(torch.FloatTensor).to(device)

            # U-Net on the LR graph, using identity (one-hot) node features.
            self.net_outs, self.start_gcn_outs = self.net(A, identity)
            # The GSR layer yields the HR adjacency (`outputs`) and features (`Z`).
            self.outputs, self.Z = self.layer(A, self.net_outs)

            self.hidden1 = self.gc1(self.Z, self.outputs)
            self.hidden2 = self.gc2(self.hidden1, self.outputs)

            z = (self.hidden2 + self.hidden2.t()) / 2
            z = z.fill_diagonal_(1)

        return torch.abs(z), self.net_outs, self.start_gcn_outs, self.outputs


class Dense(nn.Module):
    """
    Bias-free dense layer, from https://github.com/basiralab/AGSR-Net.

    Computes ``x @ weights``. ``weights`` is an ``n1 x n2`` matrix
    initialised from N(0, 0.01^2).
    """

    def __init__(self, n1, n2):
        super(Dense, self).__init__()
        self.weights = torch.nn.Parameter(
            torch.FloatTensor(n1, n2), requires_grad=True)
        nn.init.normal_(self.weights, mean=0, std=0.01)

    def forward(self, x):
        """Return ``x @ weights``.

        This method deliberately leaves the global RNGs alone. Re-seeding NumPy
        and PyTorch here on every call would reset their state mid-training,
        so every random draw made after a forward pass would repeat. Seed once
        at start-up instead.
        """
        return torch.mm(x, self.weights)


class Discriminator(nn.Module):
    """
    Discriminator, from https://github.com/basiralab/AGSR-Net.

    Three bias-free ``Dense`` layers, with a ReLU after each of the first two,
    followed by a sigmoid. The output is one score per input row.
    """

    def __init__(self, hr_dim):
        super(Discriminator, self).__init__()
        self.dense_1 = Dense(hr_dim, hr_dim).to(device)
        self.relu_1 = nn.ReLU(inplace=False)
        self.dense_2 = Dense(hr_dim, hr_dim).to(device)
        self.relu_2 = nn.ReLU(inplace=False)
        self.dense_3 = Dense(hr_dim, 1).to(device)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Return a (rows x 1) tensor of scores in [0, 1] for ``inputs`` (rows x hr_dim).

        This method does not re-seed the global NumPy/PyTorch RNGs. Doing so on
        every call made any noise sampled right after a discriminator pass
        identical at every training step.
        """
        hidden = self.relu_1(self.dense_1(inputs))
        hidden = self.relu_2(self.dense_2(hidden))
        output = self.sigmoid(self.dense_3(hidden))
        # Sigmoid output is already non-negative; abs is kept so the returned
        # values are unchanged.
        return torch.abs(output)


def gaussian_noise_layer(input_layer):
    """Return a noisy, symmetric, non-negative copy of a square matrix.

    Steps:
    1. Add N(0, 0.1^2) noise drawn from the global torch RNG.
    2. Take the absolute value.
    3. Symmetrise with ``(z + z.T) / 2``.
    4. Set the diagonal to 1.
    """
    noise = torch.empty_like(input_layer).normal_(mean=0, std=0.1)
    noisy = torch.abs(input_layer + noise)
    symmetric = (noisy + noisy.t()) / 2
    return symmetric.fill_diagonal_(1)


# Prediction

In [None]:
def predict(model, lr_test_A, out_path):
    """
    Predict model results and save them as a CSV file.

    Steps:
    1. Run ``model`` on every LR test adjacency in ``lr_test_A``.
    2. Strip 26 rows/columns of padding from each side of every prediction.
    3. Vectorize the predictions and concatenate them.
    4. Write the result to ``out_path`` with a 1-based ``ID`` column and a
       ``Predicted`` column.

    Returns the concatenated 1-D NumPy prediction vector.
    """

    model.eval()

    with torch.no_grad():
        pred_matrices = []
        for lr in lr_test_A:
            lr_tensor = torch.from_numpy(lr).type(torch.FloatTensor).to(device)
            hr_pred = model(lr_tensor, 160, 320)[0]
            pred_matrices.append(unpad(hr_pred, 26))

        # vectorize and flatten
        flat_parts = [
            torch.tensor(MatrixVectorizer.vectorize(m.cpu().detach()).flatten())
            for m in pred_matrices
        ]
        pred_1d_np = torch.cat(flat_parts).cpu().numpy()

        print(pred_1d_np.shape)

    # Two-column submission table: 1-based IDs and predicted values.
    pd.DataFrame({
        'ID': np.arange(1, len(pred_1d_np) + 1),
        'Predicted': pred_1d_np,
    }).to_csv(out_path, index=False)

    return pred_1d_np

# Train

In [None]:
# Mean-squared-error loss used by the training loop.
criterion = nn.MSELoss(reduction="mean")


def train(subjects_adj, subjects_labels, epochs=50, ks=(0.9, 0.7, 0.6, 0.5)):
    """
    Train model and print model structure.

    Trains a fresh GSATNet generator against a Discriminator for ``epochs``
    epochs, taking one optimiser step per subject. ``subjects_adj`` holds the
    LR adjacencies and ``subjects_labels`` the HR labels; each label is padded
    by 26 on every side before use.

    ``ks`` gives the pooling ratios of the Graph U-Net. The default is a tuple
    to avoid the shared mutable-default pitfall; any sequence works.

    Returns the trained generator.
    """

    # Build up model structure
    hr_dim = 320
    model = GSATNet(ks).to(device)
    print(model)
    bce_loss = nn.BCELoss()
    netD = Discriminator(hr_dim).to(device)
    print(netD)
    optimizerG = optim.Adam(model.parameters(), lr=0.0001)
    optimizerD = optim.Adam(netD.parameters(), lr=0.0001)

    # Start training
    for epoch in tqdm(range(epochs)):
        with torch.autograd.set_detect_anomaly(True):
            epoch_loss = []
            epoch_error = []
            model.train()
            for lr, hr in zip(subjects_adj, subjects_labels):
                optimizerD.zero_grad()
                optimizerG.zero_grad()

                hr = pad_HR_adj(hr, 26)
                lr = torch.from_numpy(lr).type(torch.FloatTensor).to(device)
                padded_hr = torch.from_numpy(hr).type(torch.FloatTensor).to(device)

                # Eigenvectors of the padded HR graph supervise the GSR weights.
                _, U_hr = torch.linalg.eigh(padded_hr, UPLO='U')

                model_outputs, net_outs, start_gcn_outs, layer_outs = model(
                    lr, 160, 320)

                # Reconstruction terms:
                #   0.1 x U-Net output vs its start-layer output
                #   0.8 x GSR weights vs HR eigenvectors
                #   1.0 x predicted vs padded HR adjacency
                mse_loss = 0.1 * criterion(net_outs, start_gcn_outs) + 0.8 * criterion(
                    model.layer.weights, U_hr) + criterion(model_outputs, padded_hr)

                error = criterion(model_outputs, padded_hr)
                real_data = model_outputs.detach()
                fake_data = gaussian_noise_layer(padded_hr)

                d_real = netD(real_data)
                d_fake = netD(fake_data)

                dc_loss_real = bce_loss(d_real, torch.ones(hr_dim, 1).to(device))
                dc_loss_fake = bce_loss(d_fake, torch.zeros(hr_dim, 1).to(device))
                dc_loss = dc_loss_real + dc_loss_fake

                dc_loss.backward()
                optimizerD.step()

                # NOTE(review): this adversarial term scores noisy ground truth,
                # not the generator output. It therefore gives the generator no
                # gradient; its discriminator gradients are cleared by the next
                # zero_grad(). Confirm this is intended.
                d_fake = netD(gaussian_noise_layer(padded_hr))

                gen_loss = bce_loss(d_fake, torch.ones(hr_dim, 1).to(device))
                generator_loss = gen_loss + mse_loss
                generator_loss.backward()
                optimizerG.step()

                epoch_loss.append(generator_loss.item())
                epoch_error.append(error.item())

            # Print training results in every epoch
            print("Epoch: ", epoch, "Loss: ", np.mean(epoch_loss),
                  "Error: ", np.mean(epoch_error)*100, "%")

    return model

# Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def bar_plots(eval_data):
    """
    Plot bar charts for evaluation metrics, including error bars.

    ``eval_data`` has one row per fold, with columns in the order
    MAE, PCC, JSD, MAE(PC), MAE(EC), MAE(BC). The figure has one panel per
    fold plus a final panel showing the across-fold mean with
    standard-deviation error bars, laid out two panels per row.
    """

    eval_tags = ["MAE", "PCC", "JSD", "MAE(PC)", "MAE(EC)", "MAE(BC)"]
    eval_color = ["#ff5860", "#47b45d", "#6666ff", "#ffc650", "#00ffff", "#00ff48"]

    eval_avg = np.mean(eval_data, axis=0)
    eval_std = np.std(eval_data, axis=0)
    print(f"Average: {eval_avg}")
    print(f"Standard Deviation: {eval_std}")

    subplots_num = len(eval_data) + 1
    col_num = 2
    row_num = -(-subplots_num // col_num)  # ceiling division

    # squeeze=False keeps axs 2-D even when there is a single row (one fold).
    # Otherwise axs[row, col] would raise IndexError.
    fig, axs = plt.subplots(row_num, col_num, figsize=(10, row_num*4), squeeze=False)

    for i in range(subplots_num):
        ax = axs[divmod(i, col_num)]
        if i == subplots_num - 1:
            ax.bar(eval_tags, eval_avg, yerr=eval_std, capsize=5, color=eval_color)
            ax.set_title("Avg. Across Folds")
        else:
            ax.bar(eval_tags, eval_data[i], color=eval_color)
            ax.set_title(f"Fold {i + 1}")

    # Remove the unused trailing panel when the panel count is odd.
    if subplots_num % col_num != 0:
        fig.delaxes(axs[-1, -1])

    plt.tight_layout()
    plt.show()

# Evaluation

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr
from scipy.spatial.distance import jensenshannon
import networkx as nx
from sklearn.model_selection import KFold

# HR matrices (268 nodes) are padded to 320 nodes. The 52 padding rows and
# columns are split evenly; any odd remainder goes to the bottom/right side.
padding_top = padding_left = (320 - 268) // 2
padding_bottom = padding_right = (320 - 268) - padding_top


# Unpadding matrix
def unpad_matrix(matrix, padding_top, padding_bottom, padding_left, padding_right):
    """Crop the given amount of padding from each side of a 2-D matrix or tensor."""
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    return matrix[padding_top:n_rows - padding_bottom,
                  padding_left:n_cols - padding_right]


def evaluation(model, lr_val_fold, hr_val_fold):
    """
    Evaluate model performance in MAE, PCC, Jensen-Shannon Distance, Average MAE betweenness/eigenvector/PageRank centrality.

    Each prediction is cropped with the module-level padding amounts before it
    is compared with the matching ``hr_val_fold`` matrix. Every metric is
    printed, and the function returns
    ``(mae, pcc, js_dis, avg_mae_bc, avg_mae_ec, avg_mae_pc)``.
    """

    model.eval()

    with torch.no_grad():
        pred_matrices = []
        gt_matrices = []
        for lr, hr in zip(lr_val_fold, hr_val_fold):
            lr_tensor = torch.tensor(lr, dtype=torch.float32).to(device)
            hr_pred = model(lr_tensor, 160, 320)[0]
            pred_matrices.append(unpad_matrix(hr_pred, padding_top, padding_bottom, padding_left, padding_right))
            gt_matrices.append(torch.tensor(hr, dtype=torch.float32))

    # Centrality measures compared per node: betweenness, eigenvector, PageRank.
    centralities = (
        ("bc", nx.betweenness_centrality),
        ("ec", nx.eigenvector_centrality),
        ("pc", nx.pagerank),
    )
    centrality_maes = {name: [] for name, _ in centralities}
    pred_1d_list = []
    gt_1d_list = []

    # Iterate over each validation sample
    for i in tqdm(range(len(pred_matrices))):
        pred_np = pred_matrices[i].cpu().detach().numpy()
        gt_np = gt_matrices[i].cpu().detach().numpy()

        # Weighted graphs built from the adjacency matrices
        pred_graph = nx.from_numpy_array(pred_np, edge_attr="weight")
        gt_graph = nx.from_numpy_array(gt_np, edge_attr="weight")

        pred_vals = {name: list(fn(pred_graph, weight="weight").values()) for name, fn in centralities}
        gt_vals = {name: list(fn(gt_graph, weight="weight").values()) for name, fn in centralities}
        for name, _ in centralities:
            centrality_maes[name].append(mean_absolute_error(pred_vals[name], gt_vals[name]))

        # Vectorize matrices
        pred_1d_list.append(MatrixVectorizer.vectorize(pred_np))
        gt_1d_list.append(MatrixVectorizer.vectorize(gt_np))

    # Average the per-sample centrality MAEs
    avg_mae_bc, avg_mae_ec, avg_mae_pc = (
        sum(centrality_maes[name]) / len(centrality_maes[name]) for name in ("bc", "ec", "pc"))

    # Concatenate flattened matrices and compute global metrics
    pred_1d = np.concatenate(pred_1d_list)
    gt_1d = np.concatenate(gt_1d_list)

    mae = mean_absolute_error(pred_1d, gt_1d)
    pcc = pearsonr(pred_1d, gt_1d)[0]
    js_dis = jensenshannon(pred_1d, gt_1d)

    print("MAE: ", mae)
    print("PCC: ", pcc)
    print("Jensen-Shannon Distance: ", js_dis)
    print("Average MAE betweenness centrality:", avg_mae_bc)
    print("Average MAE eigenvector centrality:", avg_mae_ec)
    print("Average MAE PageRank centrality:", avg_mae_pc)

    return mae, pcc, js_dis, avg_mae_bc, avg_mae_ec, avg_mae_pc

# Three-Fold Cross-Validation

In [None]:
def train_with_kfold(lr_train_A, hr_train_A, train_func, train_paras, fold=3):
    """
    Train model with k-fold validation (k = ``fold``, 3 by default).

    For each shuffled split:
    1. Train a fresh model with
       ``train_func(lr, hr, train_paras[0], train_paras[1])``.
    2. Evaluate it on the held-out part.

    The per-fold metrics are then plotted. Returns ``(mean validation MAE,
    list of trained models)``.
    """

    kfold = KFold(n_splits=fold, shuffle=True, random_state=random_seed)
    lr_all = np.array(lr_train_A)
    hr_all = np.array(hr_train_A)

    eval_data = []
    models = []
    total_mae = 0
    n_done = 0
    for n_done, (train_idx, val_idx) in enumerate(kfold.split(lr_all), start=1):
        print(f"Fold {n_done}")

        # Training and validation split for the current fold
        lr_train_fold, hr_train_fold = lr_all[train_idx], hr_all[train_idx]
        lr_val_fold, hr_val_fold = lr_all[val_idx], hr_all[val_idx]

        model = train_func(lr_train_fold, hr_train_fold, train_paras[0], train_paras[1])
        models.append(model)
        mae, pcc, js_dis, avg_mae_bc, avg_mae_ec, avg_mae_pc = evaluation(model, lr_val_fold, hr_val_fold)
        # Column order matches the tags used by bar_plots.
        eval_data.append([mae, pcc, js_dis, avg_mae_pc, avg_mae_ec, avg_mae_bc])
        total_mae += mae

    bar_plots(eval_data)

    return total_mae / n_done, models

# Main Loop

In [None]:
import pickle

TEST = True     # Test mode
KFOLD = True    # K fold validation
PREDICT = True  # Output prediction results

# Testing mode
if TEST:
    # Divide datasets: the first 132 subjects for training, the rest held out.
    subjects_adj = lr_train_A[:132]
    subjects_ground_truth = hr_train_A[:132]
    # NOTE(review): adj_test / ground_truth are not used by either branch below.
    adj_test = lr_train_A[132:]
    ground_truth = hr_train_A[132:]

    # Not use k fold validation
    if not KFOLD:
        print(f"TEST TRAINING ON {device}")
        model = train(subjects_adj, subjects_ground_truth, 150, ks = [0.9, 0.7, 0.6, 0.5])
        # Save model
        with open('models_CUDA_GAT_TEST.pkl', 'wb') as file:
            pickle.dump(model, file)
        # Save prediction results
        if PREDICT:
            predicts = predict(model, lr_test_A, 'predicts_agsr_GAT_test.csv')

    # Use k fold validation (on the full training set)
    else:
        print(f"USING KFOLD ON {device}")
        # Set up training parameters: (epochs, pooling ratios)
        train_paras = (100, [0.9, 0.7, 0.6, 0.5])
        # Train with k fold (3 fold in this case) validation
        avg_mae, models = train_with_kfold(np.array(lr_train_A), np.array(hr_train_A), train, train_paras, fold=3)
        print(f'Average MAE over folds: {avg_mae}')
        # Save model
        with open('models_CUDA_GAT_KFOLDS.pkl', 'wb') as file:
            pickle.dump(models, file)
        # Save prediction results for every fold
        if PREDICT:
            for i in range(len(models)):
                predicts = predict(models[i], lr_test_A, f'predictions_fold_{i+1}.csv')

# Training mode: a single training run on the whole dataset (no k-fold).
else:
    print(f"FULL-DATASET TRAINING ON {device}")
    best_model = train(np.array(lr_train_A), np.array(hr_train_A), 150, ks = [0.9, 0.7, 0.6, 0.5])
    # Save model
    with open('model_CUDA_GAT.pkl', 'wb') as file:
        pickle.dump(best_model, file)
    # Save prediction results
    predicts = predict(best_model, lr_test_A, 'predicts_agsr_best.csv')

## Training Method

In [None]:
# Mean-squared-error loss used by paper_train.
criterion = nn.MSELoss(reduction="mean")


def paper_train(subjects_adj, subjects_labels, validation_lr, validation_hr, epochs=50, ks=(0.9, 0.7, 0.6, 0.5), search_es = False):
    """
    Train model and print model structure, optionally searching for an early-stopping epoch.

    Trains a fresh GSATNet generator against a Discriminator for up to
    ``epochs`` epochs, taking one optimiser step per subject. ``ks`` gives the
    Graph U-Net pooling ratios; the default is a tuple to avoid a shared
    mutable default.

    When ``search_es`` is true, the validation MAE on (``validation_lr``,
    ``validation_hr``) is computed after every epoch:
    - An improvement of more than 1e-4 over the best MAE resets a counter.
    - Any other result increments the counter.
    - Training stops once the counter reaches 10 and ``epoch >= 99``. This
      check runs only on epochs whose MAE is not below the best so far.

    Returns ``(model, all_epochs_loss, es)``. ``es`` is the epoch at which
    training stopped early, or -1 if it never did.
    """
    best_mae = float('inf')
    # Epochs since the last significant validation improvement
    ctr = 0
    # Epoch at which early stopping fired (-1 = never)
    es = -1

    # Build up model structure
    hr_dim = 320
    model = GSATNet(ks).to(device)
    print(model)
    bce_loss = nn.BCELoss()
    netD = Discriminator(hr_dim).to(device)
    print(netD)
    optimizerG = optim.Adam(model.parameters(), lr=0.0001)
    optimizerD = optim.Adam(netD.parameters(), lr=0.0001)

    # Start training
    all_epochs_loss = []
    for epoch in tqdm(range(epochs)):
        with torch.autograd.set_detect_anomaly(True):
            epoch_loss = []
            epoch_error = []
            model.train()
            for lr, hr in zip(subjects_adj, subjects_labels):
                optimizerD.zero_grad()
                optimizerG.zero_grad()

                hr = pad_HR_adj(hr, 26)
                lr = torch.from_numpy(lr).type(torch.FloatTensor).to(device)
                padded_hr = torch.from_numpy(hr).type(torch.FloatTensor).to(device)

                # Eigenvectors of the padded HR graph supervise the GSR weights.
                _, U_hr = torch.linalg.eigh(padded_hr, UPLO='U')

                model_outputs, net_outs, start_gcn_outs, layer_outs = model(
                    lr, 160, 320)

                # Reconstruction terms:
                #   0.1 x U-Net output vs its start-layer output
                #   0.8 x GSR weights vs HR eigenvectors
                #   1.0 x predicted vs padded HR adjacency
                mse_loss = 0.1 * criterion(net_outs, start_gcn_outs) + 0.8 * criterion(
                    model.layer.weights, U_hr) + criterion(model_outputs, padded_hr)

                error = criterion(model_outputs, padded_hr)
                real_data = model_outputs.detach()
                fake_data = gaussian_noise_layer(padded_hr)

                d_real = netD(real_data)
                d_fake = netD(fake_data)

                dc_loss_real = bce_loss(d_real, torch.ones(hr_dim, 1).to(device))
                dc_loss_fake = bce_loss(d_fake, torch.zeros(hr_dim, 1).to(device))
                dc_loss = dc_loss_real + dc_loss_fake

                dc_loss.backward()
                optimizerD.step()

                # NOTE(review): this adversarial term scores noisy ground truth,
                # not the generator output. It therefore gives the generator no
                # gradient. Confirm this is intended.
                d_fake = netD(gaussian_noise_layer(padded_hr))

                gen_loss = bce_loss(d_fake, torch.ones(hr_dim, 1).to(device))
                generator_loss = gen_loss + mse_loss
                generator_loss.backward()
                optimizerG.step()

                epoch_loss.append(generator_loss.item())
                epoch_error.append(error.item())

            # Print training results in every epoch
            print("Epoch: ", epoch, "Loss: ", np.mean(epoch_loss),
                  "Error: ", np.mean(epoch_error)*100, "%")
            all_epochs_loss.append(np.mean(epoch_loss))

            if search_es:
                print("Now check for es.")
                model.eval()
                with torch.no_grad():
                    pred_matrices = []
                    gt_matrices = []

                    for lr, hr in zip(validation_lr, validation_hr):
                        model_outputs, net_outs, start_gcn_outs, layer_outs = model(torch.tensor(lr, dtype=torch.float32).to(device), 160, 320)
                        pred_matrices.append(unpad_matrix(model_outputs, padding_top, padding_bottom, padding_left, padding_right))
                        gt_matrices.append(torch.tensor(hr, dtype=torch.float32))

                    num_val = len(pred_matrices)
                    pred_1d_list = []
                    gt_1d_list = []

                    # Vectorize each validation sample
                    for i in range(num_val):
                        pred_1d_list.append(MatrixVectorizer.vectorize(pred_matrices[i].cpu().detach().numpy()))
                        gt_1d_list.append(MatrixVectorizer.vectorize(gt_matrices[i].cpu().detach().numpy()))

                    # Concatenate flattened matrices
                    pred_1d = np.concatenate(pred_1d_list)
                    gt_1d = np.concatenate(gt_1d_list)

                    # Compute metrics
                    mae = mean_absolute_error(pred_1d, gt_1d)
                    print(f"MAE_validate: {mae}")

                    if mae < best_mae:
                        if best_mae - mae > 0.0001:
                            best_mae = mae
                            print(f"Best MAE: {best_mae}, PASS")
                            ctr = 0
                        else:
                            ctr += 1
                            print(f"Best MAE: {best_mae}, BUT UNDER THRESHOLD")
                    else:
                        ctr += 1
                        if ctr >= 10:
                            if epoch >= 99:
                                print(f"Early stopping at epoch {epoch}!!!!!!!")
                                es = epoch
                                break
                            else:
                                print(f"Worse MAE for {ctr} times, continue training.")
                                continue

    return model, all_epochs_loss, es

## Evaluation Method

In [None]:
from evaluation import evaluate_all


def paper_eval(model, lr_val_fold, hr_val_fold):
    """
    Run ``model`` on each LR validation adjacency and collect matrices for scoring.

    No metrics are computed here. The function returns
    ``(pred_matrices, gt_matrices)``, two lists of NumPy arrays. Each
    prediction has been cropped with the module-level padding amounts so it
    can be compared directly with the matching ``hr_val_fold`` entry. Scoring
    is left to the caller.
    """

    model.eval()

    with torch.no_grad():
        pred_matrices = []
        gt_matrices = []
        for lr, hr in zip(lr_val_fold, hr_val_fold):
            lr_tensor = torch.tensor(lr, dtype=torch.float32).to(device)
            hr_pred = model(lr_tensor, 160, 320)[0]
            cropped = unpad_matrix(hr_pred, padding_top, padding_bottom, padding_left, padding_right)
            pred_matrices.append(cropped.cpu().detach().numpy())
            gt_matrices.append(np.array(hr))

    return pred_matrices, gt_matrices

## K-Fold Training with Early-Stopping Search (Fair Comparison)

In [None]:
from MatrixVectorizer import MatrixVectorizer
import evaluation
def paper_kfold_es(train_func, train_params, data_dir='../../Cluster-CV2',
                   n_folds=3, n_val=10, lr_nodes=160, hr_nodes=268):
    """
    Run k-fold cross-validation to find an early-stopping point for each fold.

    Loading: for each fold ``i`` (1-based), the files
    ``{data_dir}/Fold{i}/lr_train_split_{i}.csv`` and
    ``{data_dir}/Fold{i}/hr_train_split_{i}.csv`` are read. The last ``n_val``
    rows of each become that fold's validation subset, and the remaining rows
    become its training subset. Every row is rebuilt into a square matrix with
    ``MatrixVectorizer.anti_vectorize``.

    Training: for fold ``k`` (0-based), ``train_func`` is called on the
    concatenated training subsets of all other folds. The concatenated
    validation subsets of those same folds are used for early stopping.

    Args:
        train_func: called as ``train_func(lr_tr, hr_tr, lr_val, hr_val,
            train_params[0], train_params[1], search_es=True)``. It must return
            ``(model, loss, es)``.
        train_params: pair whose two elements are forwarded positionally to
            ``train_func``. Element 0 is presumably the maximum epoch count.
        data_dir: root directory holding the ``Fold{i}`` sub-directories.
            Point it at the Cluster-CV or Random-CV split as needed.
        n_folds: number of folds (``Fold1`` .. ``Fold{n_folds}``); must be >= 2.
        n_val: number of rows taken from the tail of each fold CSV for
            validation; must be >= 0.
        lr_nodes: side length of the rebuilt low-resolution matrices.
        hr_nodes: side length of the rebuilt high-resolution matrices.

    Returns:
        (models, losses, ess): per-fold lists of what ``train_func`` returned.

    Raises:
        ValueError: if ``n_folds < 2`` or ``n_val < 0``.

    Side effects:
        Prints the fold number. Rebinds the module globals ``lr/hr_train_fold``,
        ``lr/hr_validate_fold`` and ``lr/hr_test_fold`` for each fold, so after
        the call they hold the last fold's arrays. The test arrays are the held-out
        fold's *training* subset: its last ``n_val`` rows are excluded.
    """
    global lr_test_fold, hr_test_fold, lr_train_fold, hr_train_fold, lr_validate_fold, hr_validate_fold

    if n_folds < 2:
        raise ValueError(f"n_folds must be at least 2, got {n_folds}")
    if n_val < 0:
        raise ValueError(f"n_val must be non-negative, got {n_val}")

    matrix_vectorizer = MatrixVectorizer()

    def split_fold(kind, i, n_nodes):
        # Load one fold CSV, rebuild its matrices and split off the last n_val rows.
        rows = pd.read_csv(f'{data_dir}/Fold{i}/{kind}_train_split_{i}.csv').to_numpy()
        matrices = np.array([matrix_vectorizer.anti_vectorize(row, n_nodes) for row in rows])
        # Explicit cut index: `matrices[:-n_val]` would be empty when n_val == 0.
        cut = max(len(matrices) - n_val, 0)
        return matrices[:cut], matrices[cut:]

    lrs, hrs, vlrs, vhrs = [], [], [], []
    for i in range(1, n_folds + 1):
        lr_train, lr_validate = split_fold('lr', i, lr_nodes)
        hr_train, hr_validate = split_fold('hr', i, hr_nodes)
        lrs.append(lr_train)
        hrs.append(hr_train)
        vlrs.append(lr_validate)
        vhrs.append(hr_validate)

    models = []
    losses = []
    ess = []
    for fold in range(n_folds):
        print(f"Fold {fold+1}")
        # Train/validate on every other fold, in ascending fold order.
        others = [j for j in range(n_folds) if j != fold]
        lr_train_fold = np.concatenate([lrs[j] for j in others])
        hr_train_fold = np.concatenate([hrs[j] for j in others])
        lr_validate_fold = np.concatenate([vlrs[j] for j in others])
        hr_validate_fold = np.concatenate([vhrs[j] for j in others])
        # The held-out fold is not passed to train_func; it is only published
        # through the module-level globals declared at the top of this function.
        lr_test_fold = lrs[fold]
        hr_test_fold = hrs[fold]
        model, loss, es = train_func(lr_train_fold, hr_train_fold, lr_validate_fold, hr_validate_fold, train_params[0], train_params[1], search_es=True)
        ess.append(es)
        models.append(model)
        losses.append(loss)

    return models, losses, ess

## K-Fold Retraining and Model Performance Evaluation

In [None]:
from MatrixVectorizer import MatrixVectorizer
import evaluation
def paper_kfold(train_func, train_params, ess, data_dir='../../Cluster-CV2',
                n_folds=3, output_path_fmt='ID-randomCluster2-{fold}-fixed.csv',
                lr_nodes=160, hr_nodes=268):
    """
    Retrain one model per cross-validation fold and evaluate it on the held-out fold.

    Loading: for each fold ``i`` (1-based), every row of
    ``{data_dir}/Fold{i}/lr_train_split_{i}.csv`` and
    ``{data_dir}/Fold{i}/hr_train_split_{i}.csv`` is rebuilt into a square
    matrix. Unlike the early-stopping search, no rows are held out for
    validation.

    Training and evaluation: for fold ``k`` (0-based), the model is trained on
    the concatenation of all other folds for ``ess[k] + 1`` epochs. Its
    predictions on fold ``k`` (from ``paper_eval``) are then passed to
    ``evaluate_all`` with ``output_path=output_path_fmt.format(fold=k)``.

    Args:
        train_func: called as ``train_func(lr_tr, hr_tr, None, None, ess[k] + 1,
            train_params[1], search_es=False)``. It must return a 3-tuple whose
            first two items are the model and its loss.
        train_params: only element 1 is used; it is forwarded to ``train_func``.
        ess: per-fold early-stopping values, e.g. as returned by
            ``paper_kfold_es``; needs at least ``n_folds`` entries.
        data_dir: root directory holding the ``Fold{i}`` sub-directories.
            Point it at the Cluster-CV or Random-CV split as needed.
        n_folds: number of folds (``Fold1`` .. ``Fold{n_folds}``); must be >= 2.
        output_path_fmt: format string for the ``evaluate_all`` output path.
            It receives the 0-based fold index as ``{fold}``.
        lr_nodes: side length of the rebuilt low-resolution matrices.
        hr_nodes: side length of the rebuilt high-resolution matrices.

    Returns:
        (models, losses): per-fold lists of what ``train_func`` returned.

    Raises:
        ValueError: if ``n_folds < 2`` or ``ess`` has fewer than ``n_folds``
            entries. Both are checked before any data is loaded or any model
            is trained.

    Side effects:
        Prints the fold number. Rebinds the module globals ``lr/hr_train_fold``
        and ``lr/hr_test_fold`` for each fold, so after the call they hold the
        last fold's arrays.
    """
    global lr_test_fold, hr_test_fold, lr_train_fold, hr_train_fold

    if n_folds < 2:
        raise ValueError(f"n_folds must be at least 2, got {n_folds}")
    if len(ess) < n_folds:
        raise ValueError(f"ess has {len(ess)} entries but {n_folds} folds were requested")

    matrix_vectorizer = MatrixVectorizer()

    def load_fold(kind, i, n_nodes):
        # Load one fold CSV in full and rebuild each row into a square matrix.
        rows = pd.read_csv(f'{data_dir}/Fold{i}/{kind}_train_split_{i}.csv').to_numpy()
        return np.array([matrix_vectorizer.anti_vectorize(row, n_nodes) for row in rows])

    lrs = []
    hrs = []
    for i in range(1, n_folds + 1):
        lrs.append(load_fold('lr', i, lr_nodes))
        hrs.append(load_fold('hr', i, hr_nodes))

    models = []
    losses = []
    for fold in range(n_folds):
        print(f"Fold {fold+1}")
        # Train on every other fold, in ascending fold order.
        others = [j for j in range(n_folds) if j != fold]
        lr_train_fold = np.concatenate([lrs[j] for j in others])
        hr_train_fold = np.concatenate([hrs[j] for j in others])
        lr_test_fold = lrs[fold]
        hr_test_fold = hrs[fold]
        # ess[fold] is presumably the 0-based epoch at which early stopping fired,
        # so +1 turns it into an epoch count -- TODO confirm against train_func.
        model, loss, _ = train_func(lr_train_fold, hr_train_fold, None, None, ess[fold] + 1, train_params[1], search_es=False)
        models.append(model)
        losses.append(loss)
        pred_matrices, gt_matrices = paper_eval(model, lr_test_fold, hr_test_fold)
        evaluate_all(np.array(pred_matrices), np.array(gt_matrices), output_path=output_path_fmt.format(fold=fold))
    return models, losses


## Main Method for Paper Evaluation

In [None]:
# Arguments for the early-stopping search. Element 0 is forwarded to paper_train
# positionally (presumably the maximum epoch count). Element 1 is forwarded
# unchanged as paper_train's sixth positional argument.
train_params = (300, [0.9, 0.7, 0.6, 0.5])

# NOTE: `model` receives the per-fold *list* of trained models, not a single model.
model, losses, ess = paper_kfold_es(paper_train, train_params)

# One line for the per-fold losses, one for the per-fold early-stopping values.
print(f"Losses: {losses}", f"Early Stoppings: {ess}", sep="\n")