In [8]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GENConv, GATv2Conv
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch_geometric.utils import dense_to_sparse

from torch.distributions import normal, kl

from tqdm import tqdm

from sklearn.model_selection import KFold
import pandas as pd
from MatrixVectorizer import MatrixVectorizer


In [4]:
# set global variables
N_SUBJECTS = 167

N_LR_NODES = 160

N_HR_NODES = 268

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

N_LR_NODES_F = int(N_LR_NODES * (N_LR_NODES-1) / 2)
N_HR_NODES_F = int(N_HR_NODES * (N_HR_NODES-1) / 2)

# Model Layers

In [9]:
import torch
import torch_geometric
from torch_geometric.utils import to_undirected

# Example adjacency matrix tensor
adjacency_matrix = lr_train[0]

# Find non-zero indices (edges) and their corresponding weights
edge_index = adjacency_matrix.nonzero(as_tuple=False).t()
edge_weight = adjacency_matrix[edge_index[0], edge_index[1]]

# Convert to undirected graph if needed
edge_index = to_undirected(edge_index)

print("Edge Index:", edge_index)
print("Edge Weight:", edge_weight)

Edge Index: tensor([[  0,   0,   0,  ..., 159, 159, 159],
        [  1,   2,   3,  ..., 156, 157, 158]])
Edge Weight: tensor([0.3388, 0.2025, 0.6895,  ..., 0.4202, 0.2553, 0.1834])


In [15]:
x = lr_X_dim1[0]
a = lr_train[0]

m = GraphCON(x.shape[0], 180, 32)
m(x, x, a)

(tensor([[ -55.6210,   17.0611,   31.0198,  ...,   41.2022,    1.5587,
            36.1190],
         [  38.4514,  -11.7945,  -21.4443,  ...,  -28.4836,   -1.0776,
           -24.9695],
         [ -19.8355,    6.0843,   11.0623,  ...,   14.6935,    0.5559,
            12.8808],
         ...,
         [  38.1203,   -0.0000,  -21.2597,  ...,  -28.2383,   -1.0683,
           -24.7544],
         [-115.0808,   35.2997,    0.0000,  ...,   85.2483,    3.2250,
            74.7309],
         [ -10.7544,    3.2988,    5.9977,  ...,    7.9665,    0.3014,
             6.9836]], grad_fn=<MulBackward0>),
 tensor([[ 22.0319,  -4.9978,  -0.0000,  ...,  -0.0000,  -8.8853,  -0.0000],
         [ -0.0000,   3.4550,  13.7855,  ...,   0.0000,   6.1425,  10.1823],
         [  0.0000,  -1.7823,  -7.1114,  ...,  -4.9164,  -3.1687,  -0.0000],
         ...,
         [ -0.0000,   0.0000,   0.0000,  ...,   9.4485,   6.0896,  10.0946],
         [ 45.5844, -10.3406, -41.2586,  ..., -28.5239, -18.3839,  -0.0000],
   

In [13]:
from torch import nn
import torch
import torch.nn.functional as F


class GraphCON(nn.Module):
    def __init__(self, old_dim, new_dim, channels, dt=1., alpha=1., gamma=1., dropout=0.2):
        super(GraphCON, self).__init__()
        self.dt = dt
        self.alpha = alpha
        self.gamma = gamma
        # self.gnn = GENConv(channels, channels, aggr='powermean', learn_p=True)
        self.gnn = GATv2Conv(channels, channels, edge_dim=1)
        self.dropout = dropout
        self.dim_changer1 = nn.Parameter(torch.randn((new_dim, 1), device=DEVICE))
        self.dim_changer2 = nn.Parameter(torch.randn((1, old_dim), device=DEVICE))
        self.A_dim_changer1 = nn.Parameter(torch.randn((new_dim, 1), device=DEVICE))   
        self.A_dim_changer2 = nn.Parameter(torch.randn((1, old_dim), device=DEVICE))
        self.Z_dim_changer1 = nn.Parameter(torch.randn((channels, 1), device=DEVICE))   
        self.Z_dim_changer2 = nn.Parameter(torch.randn((1, new_dim), device=DEVICE))

        self.forget_gate = nn.Parameter(torch.randn(new_dim, device=DEVICE))
        self.input_gate = nn.Parameter(torch.randn(new_dim, device=DEVICE))


    def forward(self, X, Y, A):
        # solve ODEs using simple IMEX scheme
        dim_changer = self.dim_changer1 @ self.dim_changer2 
        A_dim_changer = self.A_dim_changer1 @ self.A_dim_changer2
        Z_dim_changer = self.Z_dim_changer1 @ self.Z_dim_changer2

        # forget gate from previous adjacency
        f = F.sigmoid(self.forget_gate)
        i = F.sigmoid(self.input_gate)
        forget_A = f[:, None] * F.relu(A_dim_changer @ A @ A_dim_changer.T)
    
        # update node features with gcn
        edge_index, edge_weights = dense_to_sparse(A)
        Z = self.gnn(X, edge_index, edge_weights)
        input_Z = i[:, None] * F.relu(dim_changer @ Z @ Z_dim_changer)

        # get new adjacency matrix
        new_A = forget_A + input_Z
        new_A = (new_A + new_A.T) / 2
        new_A = F.tanh(F.relu(new_A))
        # new_A = new_A + torch.diag_embed(torch.diagonal(new_A, 0)).to(DEVICE) - torch.eye(new_A.shape[0]).to(DEVICE)  # remove self connections


        # update feature vectorss
        Y_temp = Y
        Y = dim_changer @ (Y + self.dt * (Z - self.alpha * Y - self.gamma * X))
        X = dim_changer @ (X + self.dt * Y_temp)        

        if (self.dropout is not None):
            Y = F.dropout(Y, self.dropout, training=self.training)
            X = F.dropout(X, self.dropout, training=self.training)

        return X, Y, A

    

In [4]:
class AdjacencyDimChanger(nn.Module):
    def __init__(self, new_n, old_n, old_f, d):
        super().__init__()
        self.new_n = new_n
        self.old_n = old_n
        self.d = d
        self.sheafconv = SheafConvLayer(old_n, d, old_f, new_n)
        self.layernorm = nn.LayerNorm([d, old_n]).to(DEVICE)

    def forward(self, X, adj):

        adj = adj - torch.diag_embed(torch.diagonal(adj, 0)).to(DEVICE) + torch.eye(adj.shape[0]).to(DEVICE)  # add self connections
        x, L = self.sheafconv(X, adj)

        x = x.reshape(self.old_n, self.d, self.new_n)
        x = torch.transpose(x, 0, -1)
        x = self.layernorm(x)
        
        x_mean = x.mean(dim=-1)

        L_mean = L.reshape(self.old_n, self.old_n, self.d, self.d).max(dim=0)[0].mean(dim=0) # aggregate by eigenvalues of each n by n mat?
        adj_new = torch.matmul(x_mean, L_mean)
        adj_new = torch.matmul(adj_new, x_mean.T)
        adj_new_T = torch.t(adj_new)
        adj_new = F.tanh(F.relu(((adj_new + adj_new_T) / 2))) # becomes a new f by new f adj1

        return x.reshape(self.new_n*self.d, -1), adj_new

In [5]:
class AdjacencyChangerUp(nn.Module):

    def __init__(self, d, f_in):
        super().__init__()
        self.d = d

        self.adjdim_changer1 = AdjacencyDimChanger(200, N_LR_NODES, f_in, d)
        self.adjdim_changer2 = AdjacencyDimChanger(220, 200, N_LR_NODES, d)
        self.adjdim_changer3 = AdjacencyDimChanger(N_HR_NODES, 220, 200, d)

        
    def forward(self, X, adj):
        x1, adj1 = self.adjdim_changer1(X, adj)
        x2, adj2 = self.adjdim_changer2(x1, adj1)
        x3, adj3 = self.adjdim_changer3(x2, adj2)
        return [adj, adj1, adj2, adj3]
        

In [6]:
class AdjacencyChangerDown(nn.Module):

    def __init__(self, d, f_in):
        super().__init__()
        self.d = d

        self.adjdim_changer1 = AdjacencyDimChanger(220, N_HR_NODES, f_in, d).to(DEVICE)
        self.adjdim_changer2  = AdjacencyDimChanger(200, 220, N_HR_NODES, d).to(DEVICE)
        self.adjdim_changer3 = AdjacencyDimChanger(N_LR_NODES, 200, 220, d).to(DEVICE)

        
    def forward(self, X, adj):
        x1, adj1 = self.adjdim_changer1(X, adj)
        x2, adj2 = self.adjdim_changer2(x1, adj1)
        x3, adj3 = self.adjdim_changer3(x2, adj2)
        return [adj, adj1, adj2, adj3]

In [7]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True

In [8]:
import numpy as np
import networkx as nx

def eigen_centrality(data):
    # ROI is the number of brain regions (i.e.,35 in our case)
    ROI = 160

    topology_eigen = []

    G = nx.from_numpy_array(np.absolute(data))
    U = G.to_undirected()

    # A = to_2d(data)
    np.fill_diagonal(data, 0)

    # create a graph frL2
    # # compute egeinvector centrality and transform the output to vector
    ec = nx.eigenvector_centrality_numpy(U)
    
    eigenvector_centrality = np.array([ec[g] for g in U])



    topology_eigen.append(eigenvector_centrality)  # 2

    return topology_eigen

def pearson_coor(input, target, epsilon=1e-7):
    vx = input - torch.mean(input, dim=(1, 2))[:, None, None]
    vy = target - torch.mean(target, dim=(1, 2))[:, None, None]
    cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)+epsilon) * torch.sqrt(torch.sum(vy ** 2)+epsilon)+epsilon)
    return cost

def GT_loss(target, predicted):

    # l1_loss
    l1_loss = torch.nn.L1Loss()
    # loss_pix2pix = l1_loss(target, predicted)

    # topological_loss
    target_n = target.detach().cpu().clone().numpy()
    predicted_n = predicted.detach().cpu().clone().numpy()
    torch.cuda.empty_cache()

    topo_loss = []
    

    for i in range(len(target_n)):

        cur_target = target_n[i]
        cur_predicted = predicted_n[i]

        target_t = eigen_centrality(cur_target)
        real_topology = torch.tensor(target_t[0])
        predicted_t = eigen_centrality(cur_predicted)
        fake_topology = torch.tensor(predicted_t[0])
        topo_loss.append(l1_loss(real_topology, fake_topology))

    topo_loss = torch.sum(torch.stack(topo_loss))

    pc_loss = pearson_coor(target, predicted).to(DEVICE)
    torch.cuda.empty_cache()

    # G_loss = loss_pix2pix + (1 - pc_loss) + topo_loss
    G_loss = (1 - pc_loss) + topo_loss


    return G_loss

In [9]:
def loss_calc(adj_ls, opp_adj_ls):
    total_loss = torch.Tensor([0]).to(DEVICE)
    mse_loss_fn = torch.nn.MSELoss()
    for i, (adj, opp_adj) in enumerate(zip(adj_ls[::-1], opp_adj_ls)):

        ### NOTE TEMPORARY MEASURE BECAUSE THEY TAKE IN (BATCHSIZE, xx, xx) shape ####
        temp_adj = adj.reshape(1, *adj.shape)
        temp_opp_adj = opp_adj.reshape(1, *opp_adj.shape)
        ##########################################################
        gt_loss = GT_loss(temp_adj, temp_opp_adj) / len(adj_ls)
        mse_loss = torch.pow(mse_loss_fn(adj, opp_adj), 1/(i+1)) 
        total_loss = total_loss + mse_loss + gt_loss.to(DEVICE)
    return total_loss

# Training

In [1]:
from data_preparation import load_data_tensor

lr_train, lr_test, hr_train = load_data_tensor("dgl-icl")

In [6]:
lr_X_dim1 = torch.load('model_autoencoder/encode_lr_1.pt')
lr_X_dim2 = torch.load('model_autoencoder/encode_lr_2.pt')
hr_X_dim1 = torch.load('model_autoencoder/encode_hr_1.pt')
hr_X_dim2 = torch.load('model_autoencoder/encode_hr_2.pt')


lr_X_all = torch.empty((167, 320, 32))
for i in range(len(lr_X_dim1)):
    a, b = lr_X_dim1[i], lr_X_dim2[i]
    lr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

hr_X_all = torch.empty((167, 536, 32))
for i in range(len(hr_X_dim1)):
    a, b = hr_X_dim1[i], hr_X_dim2[i]
    hr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

RuntimeError: The expanded size of the tensor (320) must match the existing size (240) at non-singleton dimension 0.  Target sizes: [320, 32].  Tensor sizes: [240, 32]

In [12]:
trainloader = DataLoader(list(zip(lr_X_all, lr_train, hr_X_all, hr_train)), shuffle=True, batch_size=8)


up_changer = AdjacencyChangerUp(d=2,f_in=32).to(DEVICE)
down_changer = AdjacencyChangerDown(d=2,f_in=32).to(DEVICE)

up_optimizer = torch.optim.AdamW(up_changer.parameters(), lr=0.001, betas=(0.5, 0.999))
down_optimizer = torch.optim.AdamW(down_changer.parameters(), lr=0.001, betas=(0.5, 0.999))

sum(p.numel() for model in [up_changer, down_changer] for p in model.parameters())


2389928

In [14]:
def train(epochs, up_changer, down_changer, trainloader, up_optimizer, down_optimizer):

    up_changer.train()
    down_changer.train()
    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(epochs):

            up_losses = []
            down_losses = []

            for X_lr, adj_lr, X_hr, adj_hr in tqdm(trainloader):

                freeze_model(up_changer)
                unfreeze_model(down_changer)
            
                down_optimizer.zero_grad()
                up_optimizer.zero_grad()

                down_batch_loss = []

                for i in range(len(X_lr)):

                    up_adj_ls = up_changer(X_lr[i].to(DEVICE), adj_lr[i].to(DEVICE))
                    torch.cuda.empty_cache()
                    down_adj_ls = down_changer(X_hr[i].to(DEVICE), adj_hr[i].to(DEVICE))
                    torch.cuda.empty_cache()

                    down_batch_loss.append(loss_calc(down_adj_ls, up_adj_ls))

                down_loss = torch.mean(torch.stack(down_batch_loss))
                down_loss.backward()
                down_optimizer.step()

                down_losses.append(down_loss.detach().item())
                del down_loss
                del down_batch_loss
                torch.cuda.empty_cache()

                unfreeze_model(up_changer)
                freeze_model(down_changer)
            
                down_optimizer.zero_grad()
                up_optimizer.zero_grad()

                up_batch_loss = []


                for i in range(len(X_lr)):

                    up_adj_ls = up_changer(X_lr[i].to(DEVICE), adj_lr[i].to(DEVICE))
                    torch.cuda.empty_cache()
                    down_adj_ls = down_changer(X_hr[i].to(DEVICE), adj_hr[i].to(DEVICE))
                    torch.cuda.empty_cache()
                
                    up_batch_loss.append(loss_calc(up_adj_ls, down_adj_ls))

                up_loss = torch.mean(torch.stack(up_batch_loss))
                up_loss.backward()
                up_optimizer.step()

                up_losses.append(up_loss.detach().item())
                del up_loss
                del up_batch_loss
                torch.cuda.empty_cache()



            epoch_up_loss = np.mean(up_losses)
            epoch_down_loss = np.mean(down_losses)

            print(f'epoch {epoch}: down loss = {epoch_down_loss}, up loss = {epoch_up_loss}')

        return up_changer, down_changer


In [15]:
up_changer, down_changer = train(20, up_changer, down_changer, trainloader, up_optimizer, down_optimizer)

 33%|███▎      | 7/21 [05:51<11:35, 49.71s/it]