# Set up Environment

In [1]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import NNConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import BatchNorm
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable

from torch.distributions import normal, kl

from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE, InnerProductDecoder, ARGVA
from sklearn.model_selection import KFold
import pandas as pd
from MatrixVectorizer import MatrixVectorizer


In [2]:
# set global variables
N_SUBJECTS = 167

N_LR_NODES = 160

N_HR_NODES = 268

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

N_LR_NODES_F = int(N_LR_NODES * (N_LR_NODES-1) / 2)
N_HR_NODES_F = int(N_HR_NODES * (N_HR_NODES-1) / 2)


# Define Model Layers

## Model Layers

In [3]:
class SheafConvLayer(nn.Module):
    def __init__(self, n_nodes, d, f_in, f_out=None):
        super().__init__()
        self.d = d
        self.n_nodes = n_nodes
        self.f_in = f_in
        self.f_out = f_out
        # random init weight matrices
        if f_out is None:
            f_out = f_in 
        self.weight1 = nn.Parameter(torch.randn((d, d), device=DEVICE))
        self.weight2 = nn.Parameter(torch.randn((f_in, f_out), device=DEVICE))
        self.edge_weights = nn.Parameter(torch.randn((n_nodes, n_nodes, d, 2*d), device=DEVICE))


    def forward(self, X, adj):
        kron_prod = torch.kron(torch.eye(self.n_nodes).to(DEVICE), self.weight1)
        L = self.sheaf_laplacian(X, adj)
        if self.f_out is None:
            return X - F.elu(L @ kron_prod @ X @ self.weight2), L
        else:
            return F.elu(L @ kron_prod @ X @ self.weight2), L


    def sheaf_laplacian(self, X, adj, epsilon=1e-6):
        X_reshaped = X.reshape(self.n_nodes, self.d, -1)
        idx_pairs = torch.cartesian_prod(torch.arange(self.n_nodes), torch.arange(self.n_nodes))
        all_stacked_features = X_reshaped[idx_pairs].reshape(self.n_nodes, self.n_nodes, 2*self.d, -1).to(DEVICE)
        lin_trans = F.elu(torch.matmul(self.edge_weights, all_stacked_features))
        inner_transpose = torch.transpose(lin_trans, -1, -2)
        L_v = -1 * torch.matmul(lin_trans, torch.transpose(inner_transpose, 0, 1))
        row_cond = torch.isclose(torch.sum(adj, dim=1), torch.zeros_like(torch.sum(adj, dim=1)))
        col_cond = torch.isclose(torch.sum(adj, dim=0), torch.zeros_like(torch.sum(adj, dim=0)))
        adj_row_weights = adj / (torch.sum(adj, dim=1)[:, None] + epsilon)
        adj_col_weights = adj / (torch.sum(adj, dim=0)[:, None] + epsilon)
        # adj_col_weights = torch.where(col_cond[None, :], 0., adj / torch.sum(adj, dim=0)[None, :])
        adj_weights = torch.maximum(adj_row_weights * adj_col_weights, torch.zeros_like(adj_row_weights))

        adj_diag_weights = adj_row_weights ** 2
        diag_blocks = torch.sum(adj_diag_weights[:, :, None, None] * torch.matmul(lin_trans, inner_transpose), dim=1)
        L_v[range(self.n_nodes), range(self.n_nodes)] = diag_blocks
        return L_v.view(-1, self.n_nodes * self.d)
        ### NOTE IGNORE MATRIX NORMALISATION FOR NOW #####
        # inv_root_diag_blocks = torch.pow(diag_blocks+epsilon, -1/2)
        # normalise_mat = torch.block_diag(*inv_root_diag_blocks)

        # return normalise_mat @ L_v.view(-1, self.n_nodes * self.d) @ normalise_mat
        ################################################

In [4]:
class SheafAligner(nn.Module):
    
    def __init__(self, d, f):
        super().__init__()

        self.d = d
        self.f = f

        self.sheafconv1 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm1 = BatchNorm(f)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm2 = BatchNorm(f)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm3 = BatchNorm(f)

    def forward(self, X, adj):

        x1, L1 = self.sheafconv1(X, adj)
        x1 = F.sigmoid(self.batchnorm1(x1))
        x1 = F.dropout(x1, training=self.training)

        mean_x1 = x1.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj1 = torch.matmul(mean_x1[:,None, None, :], L1.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj1 = torch.matmul(adj1, mean_x1[None, :, :, None])
        adj1 = F.sigmoid(adj1.squeeze())

        x2, L2 = self.sheafconv2(x1, adj1)
        x2 = F.sigmoid(self.batchnorm2(x2))
        x2 = F.dropout(x2, training=self.training)
        
        mean_x2 = x2.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj2 = torch.matmul(mean_x2[:,None, None, :], L2.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj2 = torch.matmul(adj2, mean_x2[None, :, :, None])
        adj2 = F.sigmoid(adj2.squeeze())

        x3, L3 = self.sheafconv3(x2, adj2)
        x3 = F.sigmoid(self.batchnorm3(x3))

        mean_x3 = x3.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj3 = torch.matmul(mean_x3[:,None, None, :], L3.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj3 = torch.matmul(adj3, mean_x3[None, :, :, None])
        adj3 = (adj3 + torch.t(adj3)) / 2 # to ensure the matrix is symmetric
        adj3 = F.tanh(F.relu(adj3)) # to ensure there is sparsity
        
        return x3, adj3

        

In [5]:
class SheafGenerator(nn.Module):
    def __init__(self, d, f):
        super().__init__()

        self.d = d
        self.f = f
        
        self.sheafconv1 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm1 = BatchNorm(f)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm2 = BatchNorm(f)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, d, f, N_HR_NODES)
        self.batchnorm3 = BatchNorm(N_HR_NODES)

        self.out_mat = nn.Parameter(torch.randn((N_LR_NODES, 2*N_LR_NODES), device=DEVICE))

        self.out_sigmoid = nn.Sigmoid()
        
        


    def forward(self, X, adj):
        x1, L1 = self.sheafconv1(X, adj) # returns (d*lr_n) * f
        x1 = F.sigmoid(self.batchnorm1(x1))
        x1 = F.dropout(x1, p=0.1, training=self.training)

        mean_x1 = x1.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj1 = torch.matmul(mean_x1[:,None, None, :], L1.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj1 = torch.matmul(adj1, mean_x1[None, :, :, None])
        adj1 = F.sigmoid(adj1.squeeze())

        x2, L2 = self.sheafconv2(x1, adj1) # returns (d*lr_n) * f
        x2 = F.sigmoid(self.batchnorm2(x2))
        x2 = F.dropout(x2, p=0.1, training=self.training)

        mean_x2 = x2.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj2 = torch.matmul(mean_x2[:,None, None, :], L2.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj2 = torch.matmul(adj2, mean_x2[None, :, :, None])
        adj2 = F.sigmoid(adj2.squeeze())

        x3, L3 = self.sheafconv3(x2, adj2) # returns (d*lr_n) * hr_n
        x3 = F.sigmoid(self.batchnorm3(x3))

        x3 = torch.matmul(self.out_mat, x3)
        adj3 = torch.t(x3) @ adj2 @ x3
        adj3 = (adj3 + torch.t(adj3)) / 2 # to ensure the matrix is symmetric
        adj3 = F.tanh(F.relu(adj3)) # to ensure there is sparsity

        return adj3
 

In [6]:
class SheafDiscriminator(nn.Module):
    def __init__(self, d, f):
        super().__init__()

        self.d = d
        self.f = f

        self.sheafconv1 = SheafConvLayer(N_HR_NODES, d, f)
        self.sheafconv2 = SheafConvLayer(N_HR_NODES, d, f, 1)
        self.out = torch.nn.Linear(2*N_HR_NODES, 1)

    def forward(self, X, adj):
        x1, L1 = self.sheafconv1(X, adj)
        x1 = F.sigmoid(x1)
        x1 = F.dropout(x1, p=0.1, training=self.training)

        mean_x1 = x1.reshape(N_HR_NODES, self.d, self.f).mean(dim=-1)
        adj1 = torch.matmul(mean_x1[:,None, None, :], L1.reshape(N_HR_NODES, N_HR_NODES, self.d, self.d))
        adj1 = torch.matmul(adj1, mean_x1[None, :, :, None])
        adj1 = F.sigmoid(adj1.squeeze())


        x2, L2 = self.sheafconv2(x1, adj)
        x2 = F.sigmoid(x2).flatten()
        x3 = F.sigmoid(self.out(x2))
        return x3

## Helper Functions

In [7]:
def pearson_coor(input, target, epsilon=1e-7):
    vx = input - torch.mean(input, dim=(1, 2))[:, None, None]
    vy = target - torch.mean(target, dim=(1, 2))[:, None, None]
    cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)+epsilon) * torch.sqrt(torch.sum(vy ** 2)+epsilon)+epsilon)
    return cost

def GT_loss(target, predicted):

    # l1_loss
    l1_loss = torch.nn.L1Loss()
    loss_pix2pix = l1_loss(target, predicted)

    # topological_loss
    target_n = target.detach().cpu().clone().numpy()
    predicted_n = predicted.detach().cpu().clone().numpy()
    torch.cuda.empty_cache()

    topo_loss = []
    

    for i in range(len(target_n)):

        cur_target = target_n[i]
        cur_predicted = predicted_n[i]

        target_t = eigen_centrality(cur_target)
        real_topology = torch.tensor(target_t[0])
        predicted_t = eigen_centrality(cur_predicted)
        fake_topology = torch.tensor(predicted_t[0])
        topo_loss.append(l1_loss(real_topology, fake_topology))

    topo_loss = torch.sum(torch.stack(topo_loss))

    pc_loss = pearson_coor(target, predicted).to(DEVICE)
    torch.cuda.empty_cache()

    G_loss = loss_pix2pix + (1 - pc_loss) + topo_loss

    return G_loss


In [8]:
import numpy as np
import networkx as nx


# put it back into a 2D symmetric array


def topological_measures(data):
    # ROI is the number of brain regions (i.e.,35 in our case)
    ROI = 160

    topology = []



    # A = to_2d(data)
    np.fill_diagonal(data, 0)

    # create a graph from similarity matrix
    G = nx.from_numpy_matrix(np.absolute(data))
    U = G.to_undirected()

    # Centrality #

    # compute closeness centrality and transform the output to vector
    cc = nx.closeness_centrality(U, distance="weight")
    closeness_centrality = np.array([cc[g] for g in U])
    # compute betweeness centrality and transform the output to vector
    # bc = nx.betweenness_centrality(U, weight='weight')
    # bc = (nx.betweenness_centrality(U))
    betweenness_centrality = np.array([cc[g] for g in U])
    # # compute egeinvector centrality and transform the output to vector
    ec = nx.eigenvector_centrality_numpy(U)
    eigenvector_centrality = np.array([ec[g] for g in U])


    topology.append(closeness_centrality)  # 0
    topology.append(betweenness_centrality)  # 1
    topology.append(eigenvector_centrality)  # 2

    return topology
# put it back into a 2D symmetric array

def eigen_centrality(data):
    # ROI is the number of brain regions (i.e.,35 in our case)
    ROI = 160

    topology_eigen = []

    G = nx.from_numpy_array(np.absolute(data))
    U = G.to_undirected()

    # A = to_2d(data)
    np.fill_diagonal(data, 0)

    # create a graph frL2
    # # compute egeinvector centrality and transform the output to vector
    ec = nx.eigenvector_centrality_numpy(U)
    
    eigenvector_centrality = np.array([ec[g] for g in U])



    topology_eigen.append(eigenvector_centrality)  # 2

    return topology_eigen


In [9]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True

# Attempt Model Training

## Load in Data

In [10]:
from data_preparation import load_data_tensor

lr_train, lr_test, hr_train = load_data_tensor("dgl-icl")

In [11]:
lr_X_dim1 = torch.load('model_autoencoder/encode_lr_1.pt')
lr_X_dim2 = torch.load('model_autoencoder/encode_lr_2.pt')
hr_X_dim1 = torch.load('model_autoencoder/encode_hr_1.pt')
hr_X_dim2 = torch.load('model_autoencoder/encode_hr_2.pt')

In [12]:
lr_X_all = torch.empty((167, 320, 32))
for i in range(len(lr_X_dim1)):
    a, b = lr_X_dim1[i], lr_X_dim2[i]
    lr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

hr_X_all = torch.empty((167, 536, 32))
for i in range(len(hr_X_dim1)):
    a, b = hr_X_dim1[i], hr_X_dim2[i]
    hr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

## Begin Training

In [13]:
d = 2 # number of dimensions in each node
f = 32 # length of node encoding
BATCHSIZE = 1
N_TRAIN_SAMPLES = len(lr_train)
EPOCHS = 30


aligner = SheafAligner(d, f).to(DEVICE)
generator = SheafGenerator(d, f).to(DEVICE)
discriminator = SheafDiscriminator(d, f).to(DEVICE)

In [27]:
sum(p.numel() for p in aligner.parameters()) + sum(p.numel() for p in generator.parameters()) + sum(p.numel() for p in discriminator.parameters()) 

2445361

In [28]:
# aligner_optimizer = torch.optim.AdamW(aligner.parameters(), lr=0.0001, betas=(0.5, 0.999))
generator_optimizer = torch.optim.AdamW(list(aligner.parameters()) + list(generator.parameters()), lr=0.001, betas=(0.5, 0.999))
# generator_optimizer = torch.optim.AdamW(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))

adversarial_loss = torch.nn.BCELoss()

In [29]:
aligner.train()
generator.train()
discriminator.train()
    
for epoch in range(EPOCHS):

    alignment_loss_ls = []
    generator_loss_ls = []
    disciriminator_loss_ls = []

    for i, sample in tqdm(enumerate(zip(lr_X_all, lr_train, hr_X_all, hr_train))):

        # aligner_optimizer.zero_grad()
        generator_optimizer.zero_grad()
        discriminator_optimizer.zero_grad()

        X_lr, adj_lr, X_hr, adj_hr = sample

        aligned_X_lr, aligned_adj_lr = aligner(X_lr.to(DEVICE), adj_lr.to(DEVICE))
        torch.cuda.empty_cache()

        hr_mean = torch.mean(X_hr)
        hr_std = torch.std(X_hr)

        adj_hr_sampled = torch.normal(hr_mean, hr_std, size=(N_LR_NODES, N_LR_NODES)).to(DEVICE)
        # hr_X_sampled = torch.Tensor(MatrixVectorizer().anti_vectorize(hr_X_sampled, N_HR_NODES))


        alignment_loss = torch.abs(F.kl_div(F.softmax(adj_hr_sampled, dim=-1), F.softmax(aligned_adj_lr, dim=-1), None, None, 'sum'))
        alignment_loss = alignment_loss / 1000

        alignment_loss_ls.append(alignment_loss.detach().item())

        # generate hr adjacency
        generated_adj_hr = generator(aligned_X_lr.to(DEVICE), aligned_adj_lr.to(DEVICE))
        torch.cuda.empty_cache()

        freeze_model(generator)
        freeze_model(aligner)
        unfreeze_model(discriminator)

        d_real = discriminator(X_hr.to(DEVICE), adj_hr.to(DEVICE))
        torch.cuda.empty_cache()

        d_fake = discriminator(X_hr.to(DEVICE), generated_adj_hr.to(DEVICE))
        torch.cuda.empty_cache()

        d_real_loss = adversarial_loss(d_real, torch.ones_like(d_real, requires_grad=False))
        torch.cuda.empty_cache()
        d_fake_loss = adversarial_loss(d_fake, torch.zeros_like(d_fake, requires_grad=False))
        d_loss = (d_real_loss + d_fake_loss) / 2
        torch.cuda.empty_cache()

        d_loss.backward(retain_graph=True)
        discriminator_optimizer.step()

        disciriminator_loss_ls.append(d_loss.detach().item())


        unfreeze_model(generator)
        unfreeze_model(aligner)
        freeze_model(discriminator)

        ### NOTE TEMPORARY MEASURE BECAUSE THEY TAKE IN (BATCHSIZE, xx, xx) shape ####
        temp_adj_hr = adj_hr.reshape(1, *adj_hr.shape)
        temp_generated_adj_hr = generated_adj_hr.reshape(1, *generated_adj_hr.shape)
        ##########################################################

        g_topology_loss = GT_loss(temp_adj_hr.to(DEVICE), temp_generated_adj_hr.to(DEVICE))
        torch.cuda.empty_cache()

        d_fake = discriminator(X_hr.to(DEVICE), generated_adj_hr.to(DEVICE))

        g_adversarial_loss = adversarial_loss(d_fake, (torch.ones_like(d_fake)))
        g_loss = g_adversarial_loss + g_topology_loss + alignment_loss
      
        g_loss.backward()
        generator_optimizer.step()

        generator_loss_ls.append(g_loss.detach().item())


    avg_align_loss = np.mean(alignment_loss_ls)
    avg_generator_loss = np.mean(generator_loss_ls)
    avg_discriminator_loss = np.mean(disciriminator_loss_ls)

    print(f'sample {i}: align_loss = {avg_align_loss}, generator_loss = {avg_generator_loss}, discriminator = {avg_discriminator_loss}')


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7965999516898287, generator_loss = 3.248340706285146, discriminator = 0.7017677319263984


167it [02:29,  1.12it/s]


sample 166: align_loss = 0.795399045159003, generator_loss = 3.2420367845093234, discriminator = 0.6954997767231421


167it [02:30,  1.11it/s]


sample 166: align_loss = 0.7946944914892048, generator_loss = 3.2320239408576374, discriminator = 0.7060076954835903


167it [02:27,  1.13it/s]


sample 166: align_loss = 0.7943154891094047, generator_loss = 3.2412356868407497, discriminator = 0.6986124486980324


167it [02:27,  1.13it/s]


sample 166: align_loss = 0.7940994900857618, generator_loss = 3.236818812950267, discriminator = 0.6986538696431828


167it [02:27,  1.13it/s]


sample 166: align_loss = 0.7939687406945372, generator_loss = 3.2583425432554076, discriminator = 0.6936831602793254


167it [02:27,  1.14it/s]


sample 166: align_loss = 0.7938886256275063, generator_loss = 3.2553153082928072, discriminator = 0.707389393609441


167it [02:27,  1.14it/s]


sample 166: align_loss = 0.7938321603986318, generator_loss = 3.2445529628231666, discriminator = 0.7056346560666661


167it [02:27,  1.13it/s]


sample 166: align_loss = 0.7937915210952302, generator_loss = 3.2354528117254637, discriminator = 0.7104832041049431


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7937619768216938, generator_loss = 3.2482992944535196, discriminator = 0.6995087708541733


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7937398613569979, generator_loss = 3.240412697849498, discriminator = 0.7070539797137597


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7937230867540052, generator_loss = 3.245352510592519, discriminator = 0.7069193089079714


167it [02:29,  1.12it/s]


sample 166: align_loss = 0.7937108210460868, generator_loss = 3.2279776563476474, discriminator = 0.6955548592670235


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7937004180725463, generator_loss = 3.247601592462558, discriminator = 0.7114887594462869


167it [02:29,  1.12it/s]


sample 166: align_loss = 0.7936909819791417, generator_loss = 3.2283712816741468, discriminator = 0.6942692263397628


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7936854676572148, generator_loss = 3.2243179143834784, discriminator = 0.717961979126502


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7936796774407346, generator_loss = 3.245733324489532, discriminator = 0.7009783854741536


167it [02:29,  1.12it/s]


sample 166: align_loss = 0.7936732883224944, generator_loss = 3.23987036874942, discriminator = 0.7022154052814323


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.79366909743783, generator_loss = 3.2322180630136206, discriminator = 0.7037931345180123


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7936656378700347, generator_loss = 3.252268266028611, discriminator = 0.7014898832686647


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7936611828689804, generator_loss = 3.2624085676102395, discriminator = 0.6998482285859342


167it [02:29,  1.12it/s]


sample 166: align_loss = 0.7936598683545689, generator_loss = 3.258113860755962, discriminator = 0.6992874559528099


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.793655875200283, generator_loss = 3.2425749328967868, discriminator = 0.7007992210502396


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7936541812862464, generator_loss = 3.2397737062551832, discriminator = 0.7021619845293239


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7936508352171161, generator_loss = 3.237967924829353, discriminator = 0.7061702416328612


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7936492901362345, generator_loss = 3.2300402563421127, discriminator = 0.70190732314915


167it [02:28,  1.12it/s]


sample 166: align_loss = 0.7936463331034084, generator_loss = 3.237903223630535, discriminator = 0.7085740248600166


167it [02:29,  1.11it/s]


sample 166: align_loss = 0.7936441776995173, generator_loss = 3.241238585875044, discriminator = 0.7045286287090735


167it [02:29,  1.12it/s]


sample 166: align_loss = 0.7936414005513677, generator_loss = 3.2406666709209446, discriminator = 0.6925859258560363


167it [02:29,  1.12it/s]

sample 166: align_loss = 0.7936385648693153, generator_loss = 3.248549182822229, discriminator = 0.6992024600862743





# Cross Validation

In [18]:
def train(aligner, generator, discriminator, train_X_lr, train_adj_lr, train_X_hr, train_adj_hr, generator_optimizer, discriminator_optimizer, adversarial_loss=torch.nn.BCELoss()):
    aligner.train()
    generator.train()
    discriminator.train()

    train_data = (train_X_lr, train_adj_lr, train_X_hr, train_adj_hr)
        
    for epoch in range(EPOCHS):

        alignment_loss_ls = []
        generator_loss_ls = []
        disciriminator_loss_ls = []



        for i, sample in tqdm(enumerate(zip(*train_data))):

            generator_optimizer.zero_grad()
            discriminator_optimizer.zero_grad()

            X_lr, adj_lr, X_hr, adj_hr = sample

            aligned_X_lr, aligned_adj_lr = aligner(X_lr.to(DEVICE), adj_lr.to(DEVICE))
            torch.cuda.empty_cache()

            hr_mean = torch.mean(X_hr)
            hr_std = torch.std(X_hr)

            adj_hr_sampled = torch.normal(hr_mean, hr_std, size=(N_LR_NODES, N_LR_NODES)).to(DEVICE)


            alignment_loss = torch.abs(F.kl_div(F.softmax(adj_hr_sampled, dim=-1), F.softmax(aligned_adj_lr, dim=-1), None, None, 'sum'))
            alignment_loss = alignment_loss / 1000

            alignment_loss_ls.append(alignment_loss.detach().item())

            # generate hr adjacency
            generated_adj_hr = generator(aligned_X_lr.to(DEVICE), aligned_adj_lr.to(DEVICE))
            torch.cuda.empty_cache()

            freeze_model(generator)
            freeze_model(aligner)
            unfreeze_model(discriminator)

            d_real = discriminator(X_hr.to(DEVICE), adj_hr.to(DEVICE))
            torch.cuda.empty_cache()

            d_fake = discriminator(X_hr.to(DEVICE), generated_adj_hr.to(DEVICE))
            torch.cuda.empty_cache()

            d_real_loss = adversarial_loss(d_real, torch.ones_like(d_real, requires_grad=False))
            torch.cuda.empty_cache()
            d_fake_loss = adversarial_loss(d_fake, torch.zeros_like(d_fake, requires_grad=False))
            d_loss = (d_real_loss + d_fake_loss) / 2
            torch.cuda.empty_cache()

            d_loss.backward(retain_graph=True)
            discriminator_optimizer.step()

            disciriminator_loss_ls.append(d_loss.detach().item())


            unfreeze_model(generator)
            unfreeze_model(aligner)
            freeze_model(discriminator)

            ### NOTE TEMPORARY MEASURE BECAUSE THEY TAKE IN (BATCHSIZE, xx, xx) shape ####
            temp_adj_hr = adj_hr.reshape(1, *adj_hr.shape)
            temp_generated_adj_hr = generated_adj_hr.reshape(1, *generated_adj_hr.shape)
            ##########################################################

            g_topology_loss = GT_loss(temp_adj_hr.to(DEVICE), temp_generated_adj_hr.to(DEVICE))
            torch.cuda.empty_cache()

            d_fake = discriminator(X_hr.to(DEVICE), generated_adj_hr.to(DEVICE))

            g_adversarial_loss = adversarial_loss(d_fake, (torch.ones_like(d_fake)))
            g_loss = g_adversarial_loss + g_topology_loss + alignment_loss
        
            g_loss.backward(retain_graph=True)
            generator_optimizer.step()
        
            generator_loss_ls.append(g_loss.detach().item())


        avg_align_loss = np.mean(alignment_loss_ls)
        avg_generator_loss = np.mean(generator_loss_ls)
        avg_discriminator_loss = np.mean(disciriminator_loss_ls)

        print(f'sample {i}: align_loss = {avg_align_loss}, generator_loss = {avg_generator_loss}, discriminator = {avg_discriminator_loss}')

    return aligner, generator, discriminator



In [19]:
from evaluation_fn import evaluate_predictions


def validation(aligner, generator, val_X_lr, val_adj_lr, val_adj_hr):
    print('begin validation')
    aligner.eval()
    generator.eval()

    all_predictions = torch.empty((len(val_X_lr), N_HR_NODES, N_HR_NODES), requires_grad=False).cpu()

    for i in range(len(val_X_lr)):
                    
        aligned_X_lr, aligned_adj_lr = aligner(val_X_lr[i].to(DEVICE), val_adj_lr[i].to(DEVICE))
        torch.cuda.empty_cache()

        generated_adj_hr = generator(aligned_X_lr.detach(), aligned_adj_lr.detach()).cpu()

        all_predictions[i] = generated_adj_hr

    return evaluate_predictions(all_predictions, val_adj_hr)


In [29]:
validation(aligner, generator, lr_X_all[:20], lr_train[:20], hr_train[:20])

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 15.70 GiB of which 30.75 MiB is free. Process 4068805 has 2.50 GiB memory in use. Including non-PyTorch memory, this process has 12.71 GiB memory in use. Of the allocated memory 11.51 GiB is allocated by PyTorch, and 933.96 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [20]:
def cross_validate(n_fold, X_lr, adj_lr, X_hr, adj_hr, d=2, f=32):
    kf = KFold(n_fold, shuffle=True, random_state=99)
    runs_results = []
    for train_idx, val_idx in kf.split(X_lr):
        train_X_lr, val_X_lr = X_lr[train_idx], X_lr[val_idx]
        train_adj_lr, val_adj_lr = adj_lr[train_idx], adj_lr[val_idx]
        train_X_hr = X_hr[train_idx]
        train_adj_hr, val_adj_hr = adj_hr[train_idx], adj_hr[val_idx]

        aligner = SheafAligner(d, f).to(DEVICE)
        generator = SheafGenerator(d, f).to(DEVICE)
        discriminator = SheafDiscriminator(d, f).to(DEVICE)

        generator_optimizer = torch.optim.AdamW(list(aligner.parameters()) + list(generator.parameters()), lr=0.001, betas=(0.5, 0.999))
        discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))        
        
        aligner, generator, discriminator = train(aligner, generator, discriminator, train_X_lr, train_adj_lr, train_X_hr, train_adj_hr, generator_optimizer, discriminator_optimizer)
        val_metrics = validation(aligner, generator, val_X_lr, val_adj_lr, val_adj_hr)
        runs_results.append(val_metrics)

    return runs_results





In [21]:
EPOCHS = 15
cross_validate(3, lr_X_all, lr_train, hr_X_all, hr_train)

0it [00:00, ?it/s]

111it [01:36,  1.14it/s]


sample 110: align_loss = 0.7963261384147782, generator_loss = 3.2505935683239473, discriminator = 0.7009271485311491


111it [01:36,  1.14it/s]


sample 110: align_loss = 0.794862018512176, generator_loss = 3.2396781119822613, discriminator = 0.6988674165966274


111it [01:37,  1.14it/s]


sample 110: align_loss = 0.794393320341368, generator_loss = 3.2476265712387633, discriminator = 0.7013840122265859


111it [01:36,  1.15it/s]


sample 110: align_loss = 0.7941779794993701, generator_loss = 3.2512970190887756, discriminator = 0.7027396054955216


111it [01:37,  1.14it/s]


sample 110: align_loss = 0.7940458429826273, generator_loss = 3.2383339633299126, discriminator = 0.7025974129771327


111it [01:37,  1.14it/s]


sample 110: align_loss = 0.7939558593002526, generator_loss = 3.2416893662656454, discriminator = 0.703324428549758


111it [01:36,  1.14it/s]


sample 110: align_loss = 0.7938910298519306, generator_loss = 3.2330873981812758, discriminator = 0.7070250446731979


111it [01:36,  1.15it/s]


sample 110: align_loss = 0.7938442316141214, generator_loss = 3.2418197925049683, discriminator = 0.6990944141740197


111it [01:36,  1.15it/s]


sample 110: align_loss = 0.7938065373145782, generator_loss = 3.250571236956412, discriminator = 0.6999446985957859


111it [01:36,  1.15it/s]


sample 110: align_loss = 0.7937809057063885, generator_loss = 3.229642957182803, discriminator = 0.6936066757451307


111it [01:36,  1.15it/s]


sample 110: align_loss = 0.7937571180833353, generator_loss = 3.2338481250133304, discriminator = 0.7096576572538497


111it [01:37,  1.14it/s]


sample 110: align_loss = 0.7937413647368148, generator_loss = 3.245629262819132, discriminator = 0.7080226433169734


111it [01:37,  1.14it/s]


sample 110: align_loss = 0.7937282141264494, generator_loss = 3.2437831064088627, discriminator = 0.7040401144070668


111it [01:38,  1.13it/s]


sample 110: align_loss = 0.7937172398910867, generator_loss = 3.2378315581955057, discriminator = 0.6923432956944715


111it [01:38,  1.12it/s]


sample 110: align_loss = 0.7937090767396463, generator_loss = 3.2495518933568692, discriminator = 0.7002628338229548
begin validation


  pcc = pearsonr(pred_1d, gt_1d)[0]


MAE:  0.73091316
PCC:  nan
Jensen-Shannon Distance:  0.34707408553065305
Average MAE betweenness centrality: 0.01531607274276826
Average MAE eigenvector centrality: 0.01719175730917245
Average MAE PageRank centrality: 0.0007483549436318534


111it [01:05,  1.69it/s]


sample 110: align_loss = 0.7975988425650038, generator_loss = 2.9283703311348703, discriminator = 0.7027852776888255


111it [01:00,  1.84it/s]


sample 110: align_loss = 0.7975799860181035, generator_loss = 2.856796785376231, discriminator = 0.6976501087884646


111it [01:01,  1.80it/s]


sample 110: align_loss = 0.797562643214389, generator_loss = 2.871366634334486, discriminator = 0.6970356122867482


111it [00:56,  1.97it/s]


sample 110: align_loss = 0.7975556082553692, generator_loss = 2.8213571084144835, discriminator = 0.6879967510163247


111it [00:56,  1.97it/s]


sample 110: align_loss = 0.79754989533811, generator_loss = 2.816919800746125, discriminator = 0.7061627985120894


111it [00:56,  1.96it/s]


sample 110: align_loss = 0.7975462288469881, generator_loss = 2.8275692238801526, discriminator = 0.7042101111497965


111it [00:56,  1.96it/s]


sample 110: align_loss = 0.7975517994648701, generator_loss = 2.817022312757365, discriminator = 0.6997884076994818


111it [00:56,  1.97it/s]


sample 110: align_loss = 0.7975492515005507, generator_loss = 2.833361236328846, discriminator = 0.6977764931884972


111it [00:56,  1.98it/s]


sample 110: align_loss = 0.7975660197369687, generator_loss = 2.8247117884353186, discriminator = 0.6931395149445748


111it [00:56,  1.97it/s]


sample 110: align_loss = 0.7975582501909755, generator_loss = 2.8283984846019474, discriminator = 0.698121477891733


111it [00:56,  1.97it/s]


sample 110: align_loss = 0.7975570785032736, generator_loss = 2.830449732224929, discriminator = 0.694651693374187


111it [00:55,  1.99it/s]


sample 110: align_loss = 0.7975729311908688, generator_loss = 2.819719614164132, discriminator = 0.6989114295254957


111it [00:55,  1.99it/s]


sample 110: align_loss = 0.7975733532561912, generator_loss = 2.844377701696241, discriminator = 0.6973009114866858


111it [00:56,  1.98it/s]


sample 110: align_loss = 0.7975785077155173, generator_loss = 2.7946865476071863, discriminator = 0.7028823038479229


111it [00:56,  1.97it/s]


sample 110: align_loss = 0.7975819422318055, generator_loss = 2.8289459368331467, discriminator = 0.6984498098089889
begin validation


  pcc = pearsonr(pred_1d, gt_1d)[0]
  p = p / np.sum(p, axis=axis, keepdims=True)


MAE:  0.25090793
PCC:  nan
Jensen-Shannon Distance:  nan
Average MAE betweenness centrality: 0.015412674840715746
Average MAE eigenvector centrality: 0.01734299046230713
Average MAE PageRank centrality: 0.0007402635545339637


112it [01:03,  1.75it/s]


sample 111: align_loss = 0.7976914133344378, generator_loss = 2.894211779893698, discriminator = 0.7025277258030006


112it [00:56,  1.98it/s]


sample 111: align_loss = 0.7976698407105037, generator_loss = 2.808177985624804, discriminator = 0.6933394920613084


112it [00:56,  1.97it/s]


sample 111: align_loss = 0.7976588542972293, generator_loss = 2.821406282014415, discriminator = 0.6941365775253091


112it [00:56,  1.99it/s]


sample 111: align_loss = 0.797634192343269, generator_loss = 2.8176566381403854, discriminator = 0.6997857716466699


112it [00:56,  1.98it/s]


sample 111: align_loss = 0.7976143594299044, generator_loss = 2.8093899675968754, discriminator = 0.6922769205910819


112it [00:56,  1.98it/s]


sample 111: align_loss = 0.7975747452250549, generator_loss = 2.829363536366324, discriminator = 0.6857746876776218


112it [00:56,  1.99it/s]


sample 111: align_loss = 0.797548172729356, generator_loss = 2.82663897312177, discriminator = 0.7000684988285814


112it [00:56,  1.98it/s]


sample 111: align_loss = 0.7974959622536387, generator_loss = 2.8181051241059625, discriminator = 0.7024704196623394


112it [00:56,  1.99it/s]


sample 111: align_loss = 0.7974584209067481, generator_loss = 2.8227147002557293, discriminator = 0.7060143351554871


112it [00:56,  1.99it/s]


sample 111: align_loss = 0.7973908952304295, generator_loss = 2.816436887253166, discriminator = 0.6976693235337734


112it [00:56,  1.98it/s]


sample 111: align_loss = 0.7973480810012136, generator_loss = 2.831261266968121, discriminator = 0.7017229438892433


112it [00:56,  1.99it/s]


sample 111: align_loss = 0.7972805249903884, generator_loss = 2.8391023763928627, discriminator = 0.7093329557350704


112it [00:56,  1.98it/s]


sample 111: align_loss = 0.7972132680671555, generator_loss = 2.8190940282649706, discriminator = 0.7068851121834346


112it [00:56,  1.99it/s]


sample 111: align_loss = 0.797135978937149, generator_loss = 2.8211140230765595, discriminator = 0.7014998947935445


112it [00:56,  1.99it/s]


sample 111: align_loss = 0.7970623964709895, generator_loss = 2.815761498750056, discriminator = 0.6986792923084327
begin validation
MAE:  0.25977245
PCC:  nan
Jensen-Shannon Distance:  nan
Average MAE betweenness centrality: 0.015483741598290088
Average MAE eigenvector centrality: 0.01720620378716123
Average MAE PageRank centrality: 0.0007260023412142038


  pcc = pearsonr(pred_1d, gt_1d)[0]
  p = p / np.sum(p, axis=axis, keepdims=True)


[{'mae': 0.73091316,
  'pcc': nan,
  'js_dis': 0.34707408553065305,
  'avg_mae_bc': 0.01531607274276826,
  'avg_mae_ec': 0.01719175730917245,
  'avg_mae_pc': 0.0007483549436318534},
 {'mae': 0.25090793,
  'pcc': nan,
  'js_dis': nan,
  'avg_mae_bc': 0.015412674840715746,
  'avg_mae_ec': 0.01734299046230713,
  'avg_mae_pc': 0.0007402635545339637},
 {'mae': 0.25977245,
  'pcc': nan,
  'js_dis': nan,
  'avg_mae_bc': 0.015483741598290088,
  'avg_mae_ec': 0.01720620378716123,
  'avg_mae_pc': 0.0007260023412142038}]

In [15]:
torch.save(aligner.state_dict(), 'tim_files/aligner.pth')
torch.save(generator.state_dict(), 'tim_files/generator.pth')

In [14]:
aligner.load_state_dict(torch.load('tim_files/aligner.pth'))
generator.load_state_dict(torch.load('tim_files/generator.pth'))

RuntimeError: Error(s) in loading state_dict for SheafGenerator:
	Missing key(s) in state_dict: "out_mat". 
	size mismatch for sheafconv3.weight2: copying a param with shape torch.Size([32, 32]) from checkpoint, the shape in current model is torch.Size([32, 268]).
	size mismatch for batchnorm3.module.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([268]).
	size mismatch for batchnorm3.module.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([268]).
	size mismatch for batchnorm3.module.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([268]).
	size mismatch for batchnorm3.module.running_var: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([268]).

In [31]:
# predictions = torch.empty((len(lr_test), N_HR_NODES, N_HR_NODES))
# for i in range(len(lr_test)):
#     X_lr = lr_X_all[i]
#     adj_lr = lr_test[i]
#     aligned aligner

all_predictions = torch.empty((len(lr_test), N_HR_NODES, N_HR_NODES), requires_grad=False).cpu()

for i in range(len(lr_test)):
                
    aligned_X_lr, aligned_adj_lr = aligner(lr_X_all[i].to(DEVICE), lr_test[i].to(DEVICE))
    torch.cuda.empty_cache()

    generated_adj_hr = generator(aligned_X_lr.detach(), aligned_adj_lr.detach()).cpu()

    all_predictions[i] = generated_adj_hr



OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 11.69 GiB of which 7.94 MiB is free. Process 2720755 has 178.00 MiB memory in use. Including non-PyTorch memory, this process has 11.47 GiB memory in use. Of the allocated memory 10.59 GiB is allocated by PyTorch, and 733.61 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
from data_preparation import generate_submission_file

generate_submission_file(all_predictions, 'average_IMAN_submission.csv')

In [22]:
validation(aligner, generator, lr_X_all[:20], lr_train[:20], hr_train[:20])

begin validation
MAE:  0.7193633
PCC:  0.007882498734272114
Jensen-Shannon Distance:  0.36142108855438787
Average MAE betweenness centrality: 0.015426310551565453
Average MAE eigenvector centrality: 0.019239364692123003
Average MAE PageRank centrality: 0.0008014847652272805


{'mae': 0.7193633,
 'pcc': 0.007882498734272114,
 'js_dis': 0.36142108855438787,
 'avg_mae_bc': 0.015426310551565453,
 'avg_mae_ec': 0.019239364692123003,
 'avg_mae_pc': 0.0008014847652272805}