# Set up Environment

In [59]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import NNConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import BatchNorm
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable

from torch.distributions import normal, kl


import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE, InnerProductDecoder, ARGVA
from torch_geometric.utils import train_test_split_edges
from sklearn.model_selection import KFold
import pandas as pd
from MatrixVectorizer import MatrixVectorizer


In [2]:
from data_preparation import load_data_tensor

lr_train, lr_test, hr_train = load_data_tensor("dgl-icl")

# set global variables
N_SUBJECTS = 167

N_LR_NODES = 160

N_HR_NODES = 268

EPOCHS = 10

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

N_LR_NODES_F = int(N_LR_NODES * (N_LR_NODES-1) / 2)
N_HR_NODES_F = int(N_HR_NODES * (N_HR_NODES-1) / 2)


In [21]:
lr_X_dim1 = torch.load('model_autoencoder/encode_lr_1.pt')
lr_X_dim2 = torch.load('model_autoencoder/encode_lr_2.pt')
hr_X_dim1 = torch.load('model_autoencoder/encode_hr_1.pt')
hr_X_dim2 = torch.load('model_autoencoder/encode_hr_2.pt')

In [22]:
lr_X_dim1 = lr_X_dim1.detach()
lr_X_dim2 = lr_X_dim2.detach()
hr_X_dim1 = hr_X_dim1.detach()
hr_X_dim2 = hr_X_dim2.detach()

In [39]:
lr_X = torch.empty((167, 320, 32))
for i in range(len(lr_X_dim1)):
    a, b = lr_X_dim1[i], lr_X_dim2[i]
    lr_X[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

hr_X = torch.empty((167, 536, 32))
for i in range(len(hr_X_dim1)):
    a, b = hr_X_dim1[i], hr_X_dim2[i]
    hr_X[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

# Define Model Layers

## Model Layers

In [42]:
class SheafConvLayer(nn.Module):
    def __init__(self, n_nodes, d, f_in, f_out=None):
        super().__init__()
        self.d = d
        self.n_nodes = n_nodes
        self.f_out = f_out
        # random init weight matrices
        if f_out is None:
            f_out = f_in 
        self.weight1 = nn.Parameter(torch.randn((d, d))).to(DEVICE)
        self.weight2 = nn.Parameter(torch.randn((f_in, f_out))).to(DEVICE)
        self.edge_weights = nn.Parameter(torch.randn((d*n_nodes,2*d*n_nodes))).to(DEVICE)


    def forward(self, X, adj):
        kron_prod = torch.kron(torch.eye(self.n_nodes).to(DEVICE), self.weight1)
        L = self.sheaf_laplacian(X, adj)
        if self.f_out is None:
            return X - F.relu(L @ kron_prod @ X @ self.weight2) 
        else:
            return F.relu(L @ kron_prod @ X @ self.weight2) 


    def sheaf_laplacian(self, X, adj):
        laplacian_ls = []
        for v in range(self.n_nodes):
            L_v = torch.zeros((self.d, self.d)).to(DEVICE)
            for u in range(self.n_nodes):
                edge_weight = self.edge_weights[v*self.d:(v+1)*self.d, u*2*self.d:(u+1)*2*self.d]
                stacked_features = torch.concat((X[v*self.d:(v+1)*self.d], X[u*self.d:(u+1)*self.d]))
                lin_trans = F.relu(edge_weight @ stacked_features).to(DEVICE)
                L_v += adj[v, u] * lin_trans @ lin_trans.T
            laplacian_ls.append(L_v / torch.sum(adj[v]))
        return torch.block_diag(*laplacian_ls)


In [43]:
class SheafAligner(nn.Module):
    
    def __init__(self, d, f):
        super().__init__()

        self.sheafconv1 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm1 = BatchNorm(f)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm2 = BatchNorm(f)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm3 = BatchNorm(f)

    def forward(self, X, adj):

        x1 = self.sheafconv1(X, adj)
        x1 = F.sigmoid(self.batchnorm1(x1))
        x1 = F.dropout(x1, training=self.training)

        x2 = self.sheafconv2(x1, adj)
        x2 = F.sigmoid(self.batchnorm2(x2))
        x2 = F.dropout(x2, training=self.training)

        x3 = self.sheafconv3(x2, adj)
        x3 = F.sigmoid(self.batchnorm3(x3))
        # x3 = torch.cat([x3, x1], dim=1)

        # x4 = x3[:, 0:16]
        # x5 = x3[:, 16:2*16]


        return x3

        

In [44]:
class SheafGenerator(nn.Module):
    def __init__(self, d, f):
        super().__init__()
        
        self.sheafconv1 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm1 = BatchNorm(f)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, d, f)
        self.batchnorm2 = BatchNorm(f)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, d, f, N_HR_NODES)
        self.batchnorm3 = BatchNorm(N_HR_NODES)

        self.out_mat = nn.Parameter(torch.randn((N_LR_NODES, 2*N_LR_NODES))).to(DEVICE)
        
        


    def forward(self, X, adj):
        x1 = self.sheafconv1(X, adj) # returns (d*lr_n) * 16
        x1 = F.sigmoid(self.batchnorm1(x1))
        x1 = F.dropout(x1, p=0.1, training=self.training)

        x2 = self.sheafconv2(x1, adj) # returns (d*lr_n) * 16
        x2 = F.sigmoid(self.batchnorm2(x2))
        x2 = F.dropout(x2, p=0.1, training=self.training)

        x3 = self.sheafconv3(x2, adj) # returns (d*lr_n) * hr_n
        x3 = F.sigmoid(self.batchnorm3(x3))
        x3 = F.sigmoid(x3.T @  self.out_mat.T @ adj @ self.out_mat @ x3)

        return (x3 + x3.T) / 2 # to ensure the matrix is symmetric
 

In [48]:
class SheafDiscriminator(nn.Module):
    def __init__(self, d, f):
        super().__init__()
        self.sheafconv1 = SheafConvLayer(N_HR_NODES, d, f)

        self.sheafconv2 = SheafConvLayer(N_HR_NODES, d, f, 1)
        self.out = torch.nn.Linear(2*N_HR_NODES, 1)

    def forward(self, X, adj):
        x1 = F.sigmoid(self.sheafconv1(X, adj))
        x1 = F.dropout(x1, p=0.1, training=self.training)
        x2 = F.sigmoid(self.sheafconv2(X, adj))
        x3 = F.sigmoid(self.out(x2.flatten()))
        return x3

## Helper Functions

In [89]:
def pearson_coor(input, target):
    vx = input - torch.mean(input, dim=(1, 2))[:, None, None]
    vy = target - torch.mean(target, dim=(1, 2))[:, None, None]
    cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)))
    return cost

def GT_loss(target, predicted):

    # l1_loss
    l1_loss = torch.nn.L1Loss()
    loss_pix2pix = l1_loss(target, predicted)

    # topological_loss
    target_n = target.detach().cpu().clone().numpy()
    predicted_n = predicted.detach().cpu().clone().numpy()
    torch.cuda.empty_cache()

    topo_loss = []
    

    for i in range(len(target_n)):

        cur_target = target_n[i]
        cur_predicted = predicted_n[i]

        target_t = eigen_centrality(cur_target)
        real_topology = torch.tensor(target_t)
        predicted_t = eigen_centrality(cur_predicted)
        fake_topology = torch.tensor(predicted_t)
        topo_loss.append(l1_loss(real_topology, fake_topology))

    topo_loss = torch.sum(torch.stack(topo_loss))

    pc_loss = pearson_coor(target, predicted).to(DEVICE)
    torch.cuda.empty_cache()

    G_loss = loss_pix2pix + (1 - pc_loss) + topo_loss

    return G_loss


In [87]:
import numpy as np
import networkx as nx


# put it back into a 2D symmetric array


def topological_measures(data):
    # ROI is the number of brain regions (i.e.,35 in our case)
    ROI = 160

    topology = []



    # A = to_2d(data)
    np.fill_diagonal(data, 0)

    # create a graph from similarity matrix
    G = nx.from_numpy_matrix(np.absolute(data))
    U = G.to_undirected()

    # Centrality #

    # compute closeness centrality and transform the output to vector
    cc = nx.closeness_centrality(U, distance="weight")
    closeness_centrality = np.array([cc[g] for g in U])
    # compute betweeness centrality and transform the output to vector
    # bc = nx.betweenness_centrality(U, weight='weight')
    # bc = (nx.betweenness_centrality(U))
    betweenness_centrality = np.array([cc[g] for g in U])
    # # compute egeinvector centrality and transform the output to vector
    ec = nx.eigenvector_centrality_numpy(U)
    eigenvector_centrality = np.array([ec[g] for g in U])


    topology.append(closeness_centrality)  # 0
    topology.append(betweenness_centrality)  # 1
    topology.append(eigenvector_centrality)  # 2

    return topology
# put it back into a 2D symmetric array

def eigen_centrality(data):
    # ROI is the number of brain regions (i.e.,35 in our case)
    ROI = 160

    topology_eigen = []



    # A = to_2d(data)
    np.fill_diagonal(data, 0)

    # create a graph from similarity matrix
    G = nx.from_numpy_array(np.absolute(data))
    U = G.to_undirected()

    # Centrality #


    # # compute egeinvector centrality and transform the output to vector
    ec = nx.eigenvector_centrality_numpy(U)
    eigenvector_centrality = np.array([ec[g] for g in U])



    topology_eigen.append(eigenvector_centrality)  # 2

    return topology_eigen


In [30]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True

## Testing layers

In [9]:
X = torch.randn((2*N_LR_NODES, 16))
adj = lr_train[0]

In [10]:
aligner = SheafAligner(2).to('cuda')
aligned = aligner(X.to(DEVICE), adj.to(DEVICE)) # should return (n_lr * d) * f matrix

In [11]:
aligned.shape

torch.Size([320, 16])

In [12]:
generator = SheafGenerator(2).to('cuda')
generated = generator(aligned, adj.to(DEVICE))

In [13]:
generated.shape

torch.Size([268, 268])

In [None]:
Y = torch.randn((N_HR_NODES*2, 16))
discriminator = SheafDiscriminator(2).to('cuda')
dis_decision = discriminator(Y.to(DEVICE), generated.to(DEVICE))

# Attempt Model Training

## Load in Data

In [None]:
lr_X_dim1 = torch.load('model_autoencoder/encode_lr_1.pt')
lr_X_dim2 = torch.load('model_autoencoder/encode_lr_2.pt')
hr_X_dim1 = torch.load('model_autoencoder/encode_hr_1.pt')
hr_X_dim2 = torch.load('model_autoencoder/encode_hr_2.pt')

In [None]:
lr_X_dim1 = lr_X_dim1.detach()
lr_X_dim2 = lr_X_dim2.detach()
hr_X_dim1 = hr_X_dim1.detach()
hr_X_dim2 = hr_X_dim2.detach()

In [40]:
lr_X_all = torch.empty((167, 320, 32))
for i in range(len(lr_X_dim1)):
    a, b = lr_X_dim1[i], lr_X_dim2[i]
    lr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

hr_X_all = torch.empty((167, 536, 32))
for i in range(len(hr_X_dim1)):
    a, b = hr_X_dim1[i], hr_X_dim2[i]
    hr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

## Begin Training

In [70]:
d = 2 # number of dimensions in each node
f = 32 # length of node encoding
aligner = SheafAligner(d, f).to(DEVICE)
generator = SheafGenerator(d, f).to(DEVICE)
discriminator = SheafDiscriminator(d, f).to(DEVICE)

In [74]:
aligner_optimizer = torch.optim.AdamW(aligner.parameters(), lr=0.025, betas=(0.5, 0.999))
generator_optimizer = torch.optim.AdamW(generator.parameters(), lr=0.025, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=0.025, betas=(0.5, 0.999))

adversarial_loss = torch.nn.BCELoss()

In [91]:
for i, sample in enumerate(zip(lr_X_all, lr_train, hr_X_all, hr_train)):
    X_lr, adj_lr, X_hr, adj_hr = sample
    aligner_optimizer.zero_grad()
    generator_optimizer.zero_grad()
    discriminator_optimizer.zero_grad()


    aligned_X_lr = aligner(X_lr.to(DEVICE), adj_lr.to(DEVICE))
    torch.cuda.empty_cache()

    hr_mean = torch.mean(X_hr)
    hr_std = torch.std(X_hr)

    hr_X_sampled = torch.normal(hr_mean, hr_std, size=(N_LR_NODES*d, f))
    # hr_X_sampled = torch.Tensor(MatrixVectorizer().anti_vectorize(hr_X_sampled, N_HR_NODES))


    alignment_loss = torch.abs(F.kl_div(F.softmax(hr_X_sampled), F.softmax(X_lr), None, None, 'sum'))



    # generate hr adjacency
    generated_adj_hr = generator(aligned_X_lr.to(DEVICE), adj_lr.to(DEVICE))
    torch.cuda.empty_cache()


    # calculate losses for generator

    #### NOTE TEMPORARY MEASURE BECAUSE THEY TAKE IN (BATCHSIZE, xx, xx) shape ####
    temp_adj_hr = adj_hr.reshape(1, *adj_hr.shape)
    temp_generated_adj_hr = generated_adj_hr.reshape(1, *generated_adj_hr.shape)
    ##########################################################

    g_topology_loss = GT_loss(temp_adj_hr.to(DEVICE), temp_generated_adj_hr.to(DEVICE))
    torch.cuda.empty_cache()

    d_real = discriminator(X_hr.to(DEVICE), adj_hr.to(DEVICE))
    torch.cuda.empty_cache()

    d_fake = discriminator(X_hr.to(DEVICE), generated_adj_hr.to(DEVICE))
    torch.cuda.empty_cache()
    
    g_adversarial_loss = adversarial_loss(d_fake, (torch.ones_like(d_fake, requires_grad=False)))
    g_loss = g_adversarial_loss + g_topology_loss



    d_real_loss = adversarial_loss(d_real, (torch.ones_like(d_real, requires_grad=False)))
    torch.cuda.empty_cache()
    d_fake_loss = adversarial_loss(d_fake.detach(), torch.zeros_like(d_fake))
    d_loss = (d_real_loss + d_fake_loss) / 2
    torch.cuda.empty_cache()



    g_loss.backward(retain_graph=True)
    generator_optimizer.step()

    alignment_loss.backward(retain_graph=True)
    aligner_optimizer.step()

    d_loss.backward(retain_graph=True)
    discriminator_optimizer.step()

    print(f'sample {i}: align_loss = {alignment_loss}, generator_loss = {g_loss}, discriminator = {d_loss}')




  alignment_loss = torch.abs(F.kl_div(F.softmax(hr_X_sampled), F.softmax(X_lr), None, None, 'sum'))
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [0,0,0], thread: [0,0,0] Assertion `input_val >= zero && input_val <= one` failed.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
