# Set up Environment

In [26]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import NNConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import BatchNorm
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable

from torch.distributions import normal, kl


import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE, InnerProductDecoder, ARGVA
from sklearn.model_selection import KFold
import pandas as pd
from MatrixVectorizer import MatrixVectorizer


In [28]:
# set global variables
N_SUBJECTS = 167

N_LR_NODES = 160

N_HR_NODES = 268

EPOCHS = 10

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

N_LR_NODES_F = int(N_LR_NODES * (N_LR_NODES-1) / 2)
N_HR_NODES_F = int(N_HR_NODES * (N_HR_NODES-1) / 2)


# Define Model Layers

## Model Layers

In [29]:
class SheafConvLayer(nn.Module):
    def __init__(self, n_nodes, d, f_in, f_out=None):
        super().__init__()
        self.d = d
        self.n_nodes = n_nodes
        self.f_in = f_in
        self.f_out = f_out
        # random init weight matrices
        if f_out is None:
            f_out = f_in 
        self.weight1 = nn.Parameter(torch.randn((d, d), device=DEVICE))
        self.weight2 = nn.Parameter(torch.randn((f_in, f_out), device=DEVICE))
        self.edge_weights = nn.Parameter(torch.randn((n_nodes, n_nodes, d, 2*d), device=DEVICE))


    def forward(self, X, adj):
        kron_prod = torch.kron(torch.eye(self.n_nodes).to(DEVICE), self.weight1)
        L = self.sheaf_laplacian(X, adj)
        if self.f_out is None:
            return X - F.elu(L @ kron_prod @ X @ self.weight2), L
        else:
            return F.elu(L @ kron_prod @ X @ self.weight2), L


    def sheaf_laplacian(self, X, adj, epsilon=1e-6):
        X_reshaped = X.reshape(self.n_nodes, self.d, -1)
        idx_pairs = torch.cartesian_prod(torch.arange(self.n_nodes), torch.arange(self.n_nodes))
        all_stacked_features = X_reshaped[idx_pairs].reshape(self.n_nodes, self.n_nodes, 2*self.d, -1).to(DEVICE)
        lin_trans = F.elu(torch.matmul(self.edge_weights, all_stacked_features))
        inner_transpose = torch.transpose(lin_trans, -1, -2)
        L_v = -1 * torch.matmul(lin_trans, torch.transpose(inner_transpose, 0, 1))
        row_cond = torch.isclose(torch.sum(adj, dim=1), torch.zeros_like(torch.sum(adj, dim=1)))
        col_cond = torch.isclose(torch.sum(adj, dim=0), torch.zeros_like(torch.sum(adj, dim=0)))
        adj_row_weights = adj / (torch.sum(adj, dim=1)[:, None] + epsilon)
        adj_col_weights = adj / (torch.sum(adj, dim=0)[:, None] + epsilon)
        # adj_col_weights = torch.where(col_cond[None, :], 0., adj / torch.sum(adj, dim=0)[None, :])
        adj_weights = torch.maximum(adj_row_weights * adj_col_weights, torch.zeros_like(adj_row_weights))

        adj_diag_weights = adj_row_weights ** 2
        diag_blocks = torch.sum(adj_diag_weights[:, :, None, None] * torch.matmul(lin_trans, inner_transpose), dim=1)
        L_v[range(self.n_nodes), range(self.n_nodes)] = diag_blocks
        return L_v.view(-1, self.n_nodes * self.d)
        ### NOTE IGNORE MATRIX NORMALISATION FOR NOW #####
        # inv_root_diag_blocks = torch.pow(diag_blocks+epsilon, -1/2)
        # normalise_mat = torch.block_diag(*inv_root_diag_blocks)

        # return normalise_mat @ L_v.view(-1, self.n_nodes * self.d) @ normalise_mat
        ################################################

        # laplacian_ls = []
        # for v in range(self.n_nodes):
        #     L_v = torch.zeros((self.d, self.d)).to(DEVICE)
        #     if torch.sum(adj[v]) > 0:
        #         for u in range(self.n_nodes):
        #             edge_weight = self.edge_weights[v, u]
        #             stacked_features = torch.concat((X[v*self.d:(v+1)*self.d], X[u*self.d:(u+1)*self.d]))
        #             lin_trans = F.elu(edge_weight @ stacked_features).to(DEVICE)
        #             L_v += adj[v, u] * lin_trans @ lin_trans.T
        #         laplacian_ls.append(L_v / torch.sum(adj[v]))
        #     else:
        #         laplacian_ls.append(L_v)

        # return torch.block_diag(*laplacian_ls)


In [31]:
class SheafAligner(nn.Module):
    
    def __init__(self, d, f):
        super().__init__()

        self.d = d
        self.f = f

        self.sheafconv1 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm1 = BatchNorm(f)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm2 = BatchNorm(f)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm3 = BatchNorm(f)

    def forward(self, X, adj):

        x1, L1 = self.sheafconv1(X, adj)
        x1 = F.sigmoid(self.batchnorm1(x1))
        x1 = F.dropout(x1, training=self.training)

        mean_x1 = x1.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj1 = torch.matmul(mean_x1[:,None, None, :], L1.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj1 = torch.matmul(adj1, mean_x1[None, :, :, None])
        adj1 = F.sigmoid(adj1.squeeze())

        x2, L2 = self.sheafconv2(x1, adj1)
        x2 = F.sigmoid(self.batchnorm2(x2))
        x2 = F.dropout(x2, training=self.training)
        
        mean_x2 = x2.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj2 = torch.matmul(mean_x2[:,None, None, :], L2.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj2 = torch.matmul(adj2, mean_x2[None, :, :, None])
        adj2 = F.sigmoid(adj2.squeeze())

        x3, L3 = self.sheafconv3(x2, adj2)
        x3 = F.sigmoid(self.batchnorm3(x3))

        mean_x3 = x3.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj3 = torch.matmul(mean_x3[:,None, None, :], L3.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj3 = torch.matmul(adj3, mean_x3[None, :, :, None])
        adj3 = F.sigmoid(adj3.squeeze())

        # x3 = torch.cat([x3, x1], dim=1)

        # x4 = x3[:, 0:16]
        # x5 = x3[:, 16:2*16]


        return x3, adj3

        

In [95]:
class SheafGenerator(nn.Module):
    def __init__(self, d, f):
        super().__init__()

        self.d = d
        self.f = f
        
        self.sheafconv1 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm1 = BatchNorm(f)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm2 = BatchNorm(f)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, d, f, N_HR_NODES)
        self.batchnorm3 = BatchNorm(N_HR_NODES)

        self.out_mat = nn.Parameter(torch.randn((N_LR_NODES, 2*N_LR_NODES), device=DEVICE))

        self.out_sigmoid = nn.Sigmoid()
        
        


    def forward(self, X, adj):
        x1, L1 = self.sheafconv1(X, adj) # returns (d*lr_n) * f
        x1 = F.sigmoid(self.batchnorm1(x1))
        x1 = F.dropout(x1, p=0.1, training=self.training)

        mean_x1 = x1.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj1 = torch.matmul(mean_x1[:,None, None, :], L1.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj1 = torch.matmul(adj1, mean_x1[None, :, :, None])
        adj1 = F.sigmoid(adj1.squeeze())

        x2, L2 = self.sheafconv2(x1, adj1) # returns (d*lr_n) * f
        x2 = self.batchnorm2(x2)
        x2 = F.sigmoid(x2)
        x2 = F.dropout(x2, p=0.1, training=self.training)

        mean_x2 = x2.reshape(N_LR_NODES, self.d, self.f).mean(dim=-1)
        adj2 = torch.matmul(mean_x2[:,None, None, :], L2.reshape(N_LR_NODES, N_LR_NODES, self.d, self.d))
        adj2 = torch.matmul(adj2, mean_x2[None, :, :, None])
        adj2 = F.sigmoid(adj2.squeeze())

        x3, L3 = self.sheafconv3(x2, adj2) # returns (d*lr_n) * hr_n
        x3 = F.sigmoid(self.batchnorm3(x3))

        # mean_x3 = x3.reshape(N_LR_NODES, self.d, N_HR_NODES).mean(dim=0).T
        # adj3 = torch.matmul(mean_x3[:,None, None, :], L3.reshape(N_HR_NODES, , self.d, self.d))
        # adj3 = torch.matmul(adj3, mean_x3[None, :, :, None])
        # adj3 = F.sigmoid(adj3.squeeze())

        x3 = torch.matmul(self.out_mat, x3)
        adj3 = torch.sigmoid(torch.t(x3) @ adj2 @ x3)

        return (adj3 + torch.t(adj3)) / 2 # to ensure the matrix is symmetric
 

In [96]:
class SheafDiscriminator(nn.Module):
    def __init__(self, d, f):
        super().__init__()

        self.d = d
        self.f = f

        self.sheafconv1 = SheafConvLayer(N_HR_NODES, d, f)
        self.sheafconv2 = SheafConvLayer(N_HR_NODES, d, f, 1)
        self.out = torch.nn.Linear(2*N_HR_NODES, 1)

    def forward(self, X, adj):
        x1, L1 = self.sheafconv1(X, adj)
        x1 = F.sigmoid(x1)
        x1 = F.dropout(x1, p=0.1, training=self.training)

        mean_x1 = x1.reshape(N_HR_NODES, self.d, self.f).mean(dim=-1)
        adj1 = torch.matmul(mean_x1[:,None, None, :], L1.reshape(N_HR_NODES, N_HR_NODES, self.d, self.d))
        adj1 = torch.matmul(adj1, mean_x1[None, :, :, None])
        adj1 = F.sigmoid(adj1.squeeze())
        x2 = self.batchnorm2(x2)


        x2, L2 = self.sheafconv2(x1, adj)
        x2 = F.sigmoid(x2).flatten()
        x3 = F.sigmoid(self.out(x2))
        return x3

## Helper Functions

In [34]:
def pearson_coor(input, target, epsilon=1e-7):
    vx = input - torch.mean(input, dim=(1, 2))[:, None, None]
    vy = target - torch.mean(target, dim=(1, 2))[:, None, None]
    cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)+epsilon) * torch.sqrt(torch.sum(vy ** 2)+epsilon)+epsilon)
    return cost

def GT_loss(target, predicted):

    # l1_loss
    l1_loss = torch.nn.L1Loss()
    loss_pix2pix = l1_loss(target, predicted)

    # topological_loss
    target_n = target.detach().cpu().clone().numpy()
    predicted_n = predicted.detach().cpu().clone().numpy()
    torch.cuda.empty_cache()

    topo_loss = []
    

    for i in range(len(target_n)):

        cur_target = target_n[i]
        cur_predicted = predicted_n[i]

        target_t = eigen_centrality(cur_target)
        real_topology = torch.tensor(target_t[0])
        predicted_t = eigen_centrality(cur_predicted)
        fake_topology = torch.tensor(predicted_t[0])
        topo_loss.append(l1_loss(real_topology, fake_topology))

    topo_loss = torch.sum(torch.stack(topo_loss))

    pc_loss = pearson_coor(target, predicted).to(DEVICE)
    torch.cuda.empty_cache()

    G_loss = loss_pix2pix + (1 - pc_loss) + topo_loss

    return G_loss
        x2 = self.batchnorm2(x2)


In [59]:
import numpy as np
import networkx as nx


# put it back into a 2D symmetric array


def topological_measures(data):
    # ROI is the number of brain regions (i.e.,35 in our case)
    ROI = 160

    topology = []



    # A = to_2d(data)
    np.fill_diagonal(data, 0)

    # create a graph from similarity matrix
    G = nx.from_numpy_matrix(np.absolute(data))
    U = G.to_undirected()

    # Centrality #

    # compute closeness centrality and transform the output to vector
    cc = nx.closeness_centrality(U, distance="weight")
    closeness_centrality = np.array([cc[g] for g in U])
    # compute betweeness centrality and transform the output to vector
    # bc = nx.betweenness_centrality(U, weight='weight')
    # bc = (nx.betweenness_centrality(U))
    betweenness_centrality = np.array([cc[g] for g in U])
    # # compute egeinvector centrality and transform the output to vector
    ec = nx.eigenvector_centrality_numpy(U)
    eigenvector_centrality = np.array([ec[g] for g in U])


    topology.append(closeness_centrality)  # 0
    topology.append(betweenness_centrality)  # 1
    topology.append(eigenvector_centrality)  # 2

    return topology
# put it back into a 2D symmetric array

def eigen_centrality(data):
    # ROI is the number of brain regions (i.e.,35 in our case)
    ROI = 160

    topology_eigen = []

    G = nx.from_numpy_array(np.absolute(data))
    U = G.to_undirected()

    # A = to_2d(data)
    np.fill_diagonal(data, 0)

    # create a graph frL2
    # # compute egeinvector centrality and transform the output to vector
    ec = nx.eigenvector_centrality_numpy(U)
    
    eigenvector_centrality = np.array([ec[g] for g in U])



    topology_eigen.append(eigenvector_centrality)  # 2

    return topology_eigen


In [36]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True

## Testing layers

In [30]:
X = torch.randn((2*N_LR_NODES, 16))
adj = lr_train[0]

In [31]:
aligner = SheafAligner(2).to('cuda')
aligned = aligner(X.to(DEVICE), adj.to(DEVICE)) # should return (n_lr * d) * f matrix

TypeError: SheafAligner.__init__() missing 1 required positional argument: 'f'

In [32]:
aligned.shape

NameError: name 'aligned' is not defined

In [None]:
generator = SheafGenerator(2).to('cuda')
generated = generator(aligned, adj.to(DEVICE))

In [None]:
generated.shape

torch.Size([268, 268])

In [None]:
Y = torch.randn((N_HR_NODES*2, 16))
discriminator = SheafDiscriminator(2).to('cuda')
dis_decision = discriminator(Y.to(DEVICE), generated.to(DEVICE))

# Attempt Model Training

## Load in Data

In [10]:
from data_preparation import load_data_tensor

lr_train, lr_test, hr_train = load_data_tensor("dgl-icl")

In [11]:
lr_X_dim1 = torch.load('model_autoencoder/encode_lr_1.pt')
lr_X_dim2 = torch.load('model_autoencoder/encode_lr_2.pt')
hr_X_dim1 = torch.load('model_autoencoder/encode_hr_1.pt')
hr_X_dim2 = torch.load('model_autoencoder/encode_hr_2.pt')

In [12]:
lr_X_dim1 = lr_X_dim1.detach()
lr_X_dim2 = lr_X_dim2.detach()
hr_X_dim1 = hr_X_dim1.detach()
hr_X_dim2 = hr_X_dim2.detach()

In [13]:
lr_X_all = torch.empty((167, 320, 32))
for i in range(len(lr_X_dim1)):
    a, b = lr_X_dim1[i], lr_X_dim2[i]
    lr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

hr_X_all = torch.empty((167, 536, 32))
for i in range(len(hr_X_dim1)):
    a, b = hr_X_dim1[i], hr_X_dim2[i]
    hr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

## Begin Training

In [92]:
d = 2 # number of dimensions in each node
f = 32 # length of node encoding

aligner = SheafAligner(d, f).to(DEVICE)
generator = SheafGenerator(d, f).to(DEVICE)
discriminator = SheafDiscriminator(d, f).to(DEVICE)

In [88]:
sum(p.numel() for p in aligner.parameters()) + sum(p.numel() for p in generator.parameters()) + sum(p.numel() for p in discriminator.parameters()) 

2445361

In [93]:
aligner_optimizer = torch.optim.AdamW(aligner.parameters(), lr=0.025, betas=(0.5, 0.999))
generator_optimizer = torch.optim.AdamW(generator.parameters(), lr=0.025, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=0.025, betas=(0.5, 0.999))

adversarial_loss = torch.nn.BCELoss()

In [101]:
aligner.train()
generator.train()
discriminator.train()
    
with torch.autograd.set_detect_anomaly(True):
    for i, sample in enumerate(zip(lr_X_all, lr_train, hr_X_all, hr_train)):
        X_lr, adj_lr, X_hr, adj_hr = sample
        aligner_optimizer.zero_grad()
        generator_optimizer.zero_grad()
        discriminator_optimizer.zero_grad()

        aligned_X_lr, aligned_adj_lr = aligner(X_lr.to(DEVICE), adj_lr.to(DEVICE))
        torch.cuda.empty_cache()

        hr_mean = torch.mean(X_hr)
        hr_std = torch.std(X_hr)

        adj_hr_sampled = torch.normal(hr_mean, hr_std, size=(N_LR_NODES, N_LR_NODES)).to(DEVICE)
        # hr_X_sampled = torch.Tensor(MatrixVectorizer().anti_vectorize(hr_X_sampled, N_HR_NODES))


        alignment_loss = torch.abs(F.kl_div(F.softmax(adj_hr_sampled, dim=-1), F.softmax(aligned_adj_lr, dim=-1), None, None, 'sum'))



        # generate hr adjacency
        generated_adj_hr = generator(aligned_X_lr.to(DEVICE), adj_lr.to(DEVICE))
        torch.cuda.empty_cache()


        # calculate losses for generator

        #### NOTE TEMPORARY MEASURE BECAUSE THEY TAKE IN (BATCHSIZE, xx, xx) shape ####
        temp_adj_hr = adj_hr.reshape(1, *adj_hr.shape)
        temp_generated_adj_hr = generated_adj_hr.reshape(1, *generated_adj_hr.shape)
        ##########################################################

        g_topology_loss = GT_loss(temp_adj_hr.to(DEVICE), temp_generated_adj_hr.to(DEVICE))
        torch.cuda.empty_cache()

        d_real = discriminator(X_hr.to(DEVICE), adj_hr.to(DEVICE))
        torch.cuda.empty_cache()

        d_fake = discriminator(X_hr.to(DEVICE), generated_adj_hr.to(DEVICE))
        torch.cuda.empty_cache()
        
        g_adversarial_loss = adversarial_loss(d_fake, (torch.ones_like(d_fake, requires_grad=False)))
        g_loss = g_adversarial_loss + g_topology_loss



        d_real_loss = adversarial_loss(d_real, torch.ones_like(d_real, requires_grad=False))
        torch.cuda.empty_cache()
        d_fake_loss = adversarial_loss(d_fake, torch.zeros_like(d_fake, requires_grad=False))
        d_loss = (d_real_loss + d_fake_loss) / 2
        torch.cuda.empty_cache()



        # g_loss.backward(retain_graph=True)
        g_loss.backward()
        generator_optimizer.step()

        alignment_loss.backward(retain_graph=True)
        # alignment_loss.backward()
        aligner_optimizer.step()

        # d_loss.backward(retain_graph=True)
        d_loss.backward()
        discriminator_optimizer.step()

        print(f'sample {i}: align_loss = {alignment_loss}, generator_loss = {g_loss}, discriminator = {d_loss}')




  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/vol/bitbucket/se223/DGL24-Group-Project/venv/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/vol/bitbucket/se223/DGL24-Group-Project/venv/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/vol/bitbucket/se223/DGL24-Group-Project/venv/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/vol/bitbucket/se223/DGL24-Group-Project/venv/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/lib/python3.10/asyncio/base_events.py",

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.