In [72]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GENConv, GATv2Conv
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch_geometric.utils import dense_to_sparse
import torch_geometric.utils

from torch.distributions import normal, kl

from tqdm import tqdm

from sklearn.model_selection import KFold
import pandas as pd
from MatrixVectorizer import MatrixVectorizer


In [2]:
# Notebook-wide constants.
N_SUBJECTS = 167

# LR / HR graph sizes; they are the default endpoints of generate_steps().
N_LR_NODES = 160
N_HR_NODES = 268

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# n * (n - 1) / 2: number of entries strictly above the diagonal of an n x n matrix.
N_LR_NODES_F = N_LR_NODES * (N_LR_NODES - 1) // 2
N_HR_NODES_F = N_HR_NODES * (N_HR_NODES - 1) // 2

In [3]:
from data_preparation import load_data_tensor

# Tensors returned by the project loader for the "dgl-icl" dataset.
lr_train, lr_test, hr_train = load_data_tensor("dgl-icl")

# Pre-computed embeddings saved under model_autoencoder/final_embeddings.
# Naming pattern: <resolution>_X_dim<k>[_test]  <-  encode_<resolution>[_test][_<k>].pt
lr_X_dim1 = torch.load('model_autoencoder/final_embeddings/encode_lr.pt')
lr_X_dim3 = torch.load('model_autoencoder/final_embeddings/encode_lr_3.pt')
hr_X_dim1 = torch.load('model_autoencoder/final_embeddings/encode_hr.pt')
hr_X_dim3 = torch.load('model_autoencoder/final_embeddings/encode_hr_3.pt')
lr_X_dim1_test = torch.load('model_autoencoder/final_embeddings/encode_lr_test.pt')

# The file name says these are LR test embeddings, so bind them under the
# matching `lr_` name.
lr_X_dim3_test = torch.load('model_autoencoder/final_embeddings/encode_lr_test_3.pt')
# Backward-compatible alias: earlier versions of this cell used the misleading
# `hr_` prefix for the same tensor. Prefer `lr_X_dim3_test` in new code.
hr_X_dim3_test = lr_X_dim3_test

# Model Layers

In [17]:
def generate_steps(num_steps, low=N_LR_NODES, high=N_HR_NODES):
    """Return ``num_steps`` integer sizes evenly spaced from ``low`` to ``high``.

    Args:
        num_steps: number of sizes to produce, both endpoints included.
        low: first size of the schedule.
        high: last size of the schedule.

    Returns:
        A list of ``num_steps`` ints obtained by rounding the evenly spaced
        values. With ``num_steps == 1`` the list is ``[round(low)]`` (the same
        convention as ``np.linspace(low, high, 1)``); with ``num_steps <= 0``
        the list is empty.
    """
    if num_steps == 1:
        # A single point has no spacing; without this guard the division
        # below raises ZeroDivisionError.
        return [round(low)]
    step_size = (high - low) / (num_steps - 1)
    return [round(low + step_size * i) for i in range(num_steps)]

In [20]:
# Smoke-test inputs: the first entry of the LR embeddings and of the LR
# training tensor, moved to DEVICE.
x = lr_X_dim1[0].to(DEVICE)
a = lr_train[0].to(DEVICE)
num_steps = 10


# Size schedule from N_LR_NODES to N_HR_NODES (generate_steps defaults).
dim_steps= generate_steps(num_steps)
# NOTE(review): AdjacencyDimChanger is defined in a later cell of this notebook,
# so on a fresh kernel this line raises NameError unless that cell has been run.
up_sampler = AdjacencyDimChanger(dim_steps=dim_steps, f=32).to(DEVICE)

In [61]:
# Broadcasting check for the gate multiplication used in AdjacencyStep.forward:
# f has shape (2,), so f[:, None] has shape (2, 1) and scales row j of each of
# the five 2x2 matrices in x by f[j].
# NOTE(review): this rebinds `x`, which another cell also uses for the LR embedding.
x = torch.randn((5, 2, 2))
f = torch.randn(2)
f[:, None] * x

tensor([[[ 1.7406e-01, -3.2386e-01],
         [ 5.6132e-02, -3.5660e-02]],

        [[-3.9412e-01,  3.2988e-01],
         [ 8.4503e-04, -1.6906e-02]],

        [[ 9.5981e-01, -3.3740e-01],
         [ 3.7067e-03,  7.7308e-02]],

        [[-1.7738e-01,  8.5273e-01],
         [-6.2215e-02,  6.0100e-02]],

        [[ 8.6247e-01,  6.7147e-01],
         [ 9.8556e-03, -7.8895e-02]]])

In [21]:
# One forward pass through the up-sampler (x is passed as both X and Y; the
# result is discarded), then the total number of parameters of the model.
up_sampler(x, x, a)
sum(p.numel() for p in up_sampler.parameters())

400128

In [102]:
from torch import nn
import torch
import torch.nn.functional as F


class AdjacencyStep(nn.Module):
    """One resizing step: maps node features and an adjacency from ``old_dim`` to ``new_dim`` nodes.

    The new adjacency is a gated sum of (a) the previous adjacency projected to
    the new size and (b) a term built from GATv2-updated node features. The
    node-feature states ``X`` and ``Y`` are advanced by one explicit update of
    size ``dt`` and projected to the new size as well.

    Args:
        old_dim: number of nodes of the incoming graph.
        new_dim: number of nodes of the outgoing graph.
        channels: node-feature width (input and output width of the GNN).
        dt: step size of the feature update.
        alpha: coefficient of ``Y`` in the feature update.
        gamma: coefficient of ``X`` in the feature update.
        dropout: dropout probability applied to the returned ``X`` and ``Y``;
            ``None`` disables dropout.
    """

    def __init__(self, old_dim, new_dim, channels, dt=1., alpha=1., gamma=1., dropout=0.2):
        super().__init__()
        self.dt = dt
        self.alpha = alpha
        self.gamma = gamma
        # Edge weights are fed to the attention layer as 1-D edge attributes.
        self.gnn = GATv2Conv(channels, channels, edge_dim=1).to(DEVICE)
        self.dropout = dropout

        # Each projection is stored as two rank-1 factors whose product is
        # formed in forward():
        #   dim_changer   : (new_dim, old_dim)  node-feature resize
        #   A_dim_changer : (new_dim, old_dim)  adjacency resize
        #   Z_dim_changer : (channels, new_dim) feature -> adjacency column map
        self.dim_changer1 = nn.Parameter(torch.randn((new_dim, 1), device=DEVICE))
        self.dim_changer2 = nn.Parameter(torch.randn((1, old_dim), device=DEVICE))
        self.A_dim_changer1 = nn.Parameter(torch.randn((new_dim, 1), device=DEVICE))
        self.A_dim_changer2 = nn.Parameter(torch.randn((1, old_dim), device=DEVICE))
        self.Z_dim_changer1 = nn.Parameter(torch.randn((channels, 1), device=DEVICE))
        self.Z_dim_changer2 = nn.Parameter(torch.randn((1, new_dim), device=DEVICE))

        # Per-row gates (squashed with a sigmoid in forward()).
        self.forget_gate = nn.Parameter(torch.randn(new_dim, device=DEVICE))
        self.input_gate = nn.Parameter(torch.randn(new_dim, device=DEVICE))

    def forward(self, X, Y, A):
        """Advance the states by one step and resize the graph.

        Args:
            X: node features, ``(batch, old_dim, channels)`` or, for a single
                graph, ``(old_dim, channels)``.
            Y: second feature state, same shape as ``X``.
            A: adjacency, ``(batch, old_dim, old_dim)`` or ``(old_dim, old_dim)``.

        Returns:
            ``(X, Y, new_A)`` at the new size, with the same batched or
            un-batched layout as the inputs.
        """
        # The graph construction below iterates over the leading axis as a
        # batch axis. A single 2-D graph would be iterated row by row and
        # fail, so give it a batch axis here and strip it again at the end.
        unbatched = X.dim() == 2
        if unbatched:
            X, Y, A = X.unsqueeze(0), Y.unsqueeze(0), A.unsqueeze(0)

        dim_changer = self.dim_changer1 @ self.dim_changer2
        A_dim_changer = self.A_dim_changer1 @ self.A_dim_changer2
        Z_dim_changer = self.Z_dim_changer1 @ self.Z_dim_changer2

        # Forget gate applied to the previous adjacency projected to the new size.
        f = torch.sigmoid(self.forget_gate)
        i = torch.sigmoid(self.input_gate)
        forget_A = f[:, None] * F.relu(A_dim_changer @ A @ A_dim_changer.T)

        # Update node features with the GNN: one sparse graph per sample,
        # merged into a single PyG batch so the layer runs once.
        data_list = []
        for x, adj in zip(X, A):
            edge_index = adj.nonzero().t()
            edge_weights = adj[edge_index[0], edge_index[1]]
            data_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_weights.view(-1, 1)))
        graph_batch = torch_geometric.data.Batch.from_data_list(data_list)
        Z = self.gnn(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr).reshape(X.shape)
        input_Z = i[:, None] * F.relu(dim_changer @ Z @ Z_dim_changer)

        # New adjacency: gated sum, symmetrised, then squashed into [0, 1).
        new_A = forget_A + input_Z
        new_A = (new_A + torch.transpose(new_A, -1, -2)) / 2
        new_A = torch.tanh(F.relu(new_A))

        # Update the feature states (original author's note: "solve ODEs using
        # simple IMEX scheme"). Y is updated from the old X, and X from the
        # old Y, hence the temporary.
        Y_temp = Y
        Y = dim_changer @ (Y + self.dt * (Z - self.alpha * Y - self.gamma * X))
        X = dim_changer @ (X + self.dt * Y_temp)

        if self.dropout is not None:
            Y = F.dropout(Y, self.dropout, training=self.training)
            X = F.dropout(X, self.dropout, training=self.training)

        if unbatched:
            X, Y, new_A = X.squeeze(0), Y.squeeze(0), new_A.squeeze(0)

        return X, Y, new_A

    

In [103]:
class AdjacencyDimChanger(nn.Module):
    """Chain of AdjacencyStep layers that walks a graph through ``dim_steps``.

    Args:
        dim_steps: sequence of node counts; layer ``k`` maps
            ``dim_steps[k]`` nodes to ``dim_steps[k + 1]`` nodes.
        f: channel width handed to every AdjacencyStep.
    """

    def __init__(self, dim_steps, f):
        super().__init__()

        steps = []
        for k in range(len(dim_steps) - 1):
            steps.append(AdjacencyStep(dim_steps[k], dim_steps[k + 1], f))
        self.layers = nn.ModuleList(steps)

    def forward(self, X, Y, A):
        """Return the input adjacency followed by the adjacency after each layer."""
        trajectory = [A]
        for step in self.layers:
            X, Y, A = step(X, Y, A)
            trajectory.append(A)
        return trajectory
        

In [104]:
def freeze_model(model):
    """Switch off gradient tracking for every parameter of ``model``, in place."""
    for weight in model.parameters():
        weight.requires_grad_(False)

def unfreeze_model(model):
    """Switch gradient tracking back on for every parameter of ``model``, in place."""
    for weight in model.parameters():
        weight.requires_grad_(True)

In [110]:
import numpy as np
import networkx as nx

def eigen_centrality(data):
    """Eigenvector centrality of each node of the graph described by ``data``.

    Args:
        data: square 2-D NumPy adjacency matrix. It is not modified.

    Returns:
        A one-element list whose item is a 1-D NumPy array holding one
        centrality value per node, in node order.
    """
    # Build the graph from |data| with the diagonal zeroed, so self-connections
    # do not become self-loops. np.absolute returns a new array, which keeps
    # the caller's matrix untouched.
    adjacency = np.absolute(data)
    np.fill_diagonal(adjacency, 0)

    # from_numpy_array already yields an undirected nx.Graph by default.
    graph = nx.from_numpy_array(adjacency)

    # NOTE(review): called with default arguments, i.e. without a `weight`
    # key - confirm whether the edge weights are meant to be ignored here.
    ec = nx.eigenvector_centrality_numpy(graph)

    # Convert the {node: centrality} mapping into a vector in node order.
    return [np.array([ec[node] for node in graph])]

def pearson_coor(input, target, epsilon=1e-7):
    """Correlation-style similarity between two batches of matrices.

    Each sample is centred with its own mean (taken over dims 1 and 2); the
    products and squared deviations are then summed over the whole batch.

    Args:
        input: tensor with at least three dimensions, batch first.
        target: tensor of the same shape as ``input``.
        epsilon: small constant added inside the square roots and to the
            denominator to avoid division by zero.

    Returns:
        A 0-dim tensor.
    """
    input_centered = input - input.mean(dim=(1, 2), keepdim=True)
    target_centered = target - target.mean(dim=(1, 2), keepdim=True)

    numerator = (input_centered * target_centered).sum()
    input_norm = (input_centered.pow(2).sum() + epsilon).sqrt()
    target_norm = (target_centered.pow(2).sum() + epsilon).sqrt()
    return numerator / (input_norm * target_norm + epsilon)

def GT_loss(target, predicted):
    """Topology-plus-correlation loss between two batches of adjacency matrices.

    Args:
        target: batch of adjacency matrices, batch axis first.
        predicted: batch of the same shape as ``target``.

    Returns:
        ``(1 - pearson_coor(target, predicted)) + topo_loss`` where
        ``topo_loss`` is the sum over samples of the L1 distance between the
        eigenvector-centrality vectors of the two graphs.
    """
    mae = torch.nn.L1Loss()

    # The topology term is evaluated on detached CPU copies, so no gradient
    # flows through it; only the correlation term is differentiable.
    target_np = target.detach().cpu().clone().numpy()
    predicted_np = predicted.detach().cpu().clone().numpy()
    torch.cuda.empty_cache()

    per_sample_topo = []
    for idx, real_adj in enumerate(target_np):
        fake_adj = predicted_np[idx]
        real_topology = torch.tensor(eigen_centrality(real_adj)[0])
        fake_topology = torch.tensor(eigen_centrality(fake_adj)[0])
        per_sample_topo.append(mae(real_topology, fake_topology))
    topo_loss = torch.stack(per_sample_topo).sum()

    pc_loss = pearson_coor(target, predicted).to(DEVICE)
    torch.cuda.empty_cache()

    return (1 - pc_loss) + topo_loss

In [116]:
def loss_calc(adj_ls, opp_adj_ls):
    """Weighted sum of per-level losses between two adjacency trajectories.

    ``adj_ls`` is compared entry by entry with ``opp_adj_ls`` reversed. Level
    ``i`` (0-based) contributes ``GT_loss + MSE`` scaled by
    ``len(adj_ls) / (i + 1)``, so earlier entries of ``adj_ls`` weigh more.

    Args:
        adj_ls: list of batched adjacency tensors.
        opp_adj_ls: list of batched adjacency tensors produced in the opposite
            direction; its reverse must match ``adj_ls`` shape for shape.

    Returns:
        A tensor of shape ``(1,)`` on DEVICE.
    """
    mse = torch.nn.MSELoss()
    num_levels = len(adj_ls)
    mirrored = opp_adj_ls[::-1]

    total_loss = torch.Tensor([0]).to(DEVICE)
    for level, (adj, opp_adj) in enumerate(zip(adj_ls, mirrored), start=1):
        level_gt = GT_loss(adj, opp_adj)
        level_mse = mse(adj, opp_adj)
        total_loss = total_loss + (level_gt + level_mse) * num_levels / level

    return total_loss


def end_adj_loss_calc(adj, opp_adj):
    """Monitoring metrics between a final adjacency and its reference.

    Accepts a single matrix ``(n, n)`` or a batch ``(batch, n, n)``. No
    gradient is tracked.

    Args:
        adj: adjacency tensor (single matrix or batch).
        opp_adj: reference adjacency tensor with the same shape as ``adj``.

    Returns:
        ``(mae, gt)`` as Python floats: the mean absolute error and the
        GT_loss value.
    """
    adj = adj.detach()
    opp_adj = opp_adj.detach()

    mae_loss = torch.nn.L1Loss()(adj, opp_adj)

    # GT_loss iterates over a leading batch axis. Add one only when it is
    # missing: reshaping an already batched tensor would produce a 4-D input
    # and hand 3-D arrays to the topology computation.
    if adj.dim() == 2:
        adj = adj.reshape(1, *adj.shape)
    if opp_adj.dim() == 2:
        opp_adj = opp_adj.reshape(1, *opp_adj.shape)
    gt_loss = GT_loss(adj, opp_adj)

    return mae_loss.item(), gt_loss.detach().item()

# Training

In [122]:
# Each dataset item is the 6-tuple (lr_X_dim1, lr_X_dim3, lr_train, hr_X_dim1,
# hr_X_dim3, hr_train) for one index; train() unpacks batches in this order.
trainloader = DataLoader(list(zip(lr_X_dim1, lr_X_dim3, lr_train, hr_X_dim1, hr_X_dim3, hr_train)), shuffle=True, batch_size=32)

# Size schedule from N_LR_NODES to N_HR_NODES (generate_steps defaults).
dim_steps = generate_steps(num_steps=10)

# Two mirrored models: one walks the schedule upwards, the other downwards.
up_changer = AdjacencyDimChanger(dim_steps, f=32).to(DEVICE)
down_changer = AdjacencyDimChanger(dim_steps[::-1],f=32).to(DEVICE)

# One optimizer per model; train() steps them in alternation.
up_optimizer = torch.optim.AdamW(up_changer.parameters(), lr=0.001, betas=(0.5, 0.999))
down_optimizer = torch.optim.AdamW(down_changer.parameters(), lr=0.001, betas=(0.5, 0.999))

# Total number of parameters across both models.
sum(p.numel() for model in [up_changer, down_changer] for p in model.parameters())


67284

In [123]:
def _adjacency_training_step(learner, reference, learner_inputs, reference_inputs, optimizer):
    """Update ``learner`` once so its adjacency trajectory matches ``reference``'s.

    ``reference`` is frozen for the step and its forward pass runs without
    gradient tracking. Returns ``(loss, end_mae, end_gt)`` as floats, where the
    last two compare the learner's final adjacency with the reference's input
    adjacency.
    """
    freeze_model(reference)
    unfreeze_model(learner)

    learner_adj_ls = learner(*learner_inputs)
    with torch.no_grad():
        reference_adj_ls = reference(*reference_inputs)

    loss = loss_calc(learner_adj_ls, reference_adj_ls).mean()
    loss.backward()
    optimizer.step()

    # Monitoring only. Samples are scored one at a time (2-D matrices), with
    # a single call per sample giving both metrics.
    end_metrics = [
        end_adj_loss_calc(pred, truth)
        for pred, truth in zip(learner_adj_ls[-1].detach(), reference_adj_ls[0].detach())
    ]
    end_mae = np.mean([mae for mae, _ in end_metrics])
    end_gt = np.mean([gt for _, gt in end_metrics])

    loss_value = loss.detach().item()
    del loss, learner_adj_ls, reference_adj_ls
    torch.cuda.empty_cache()
    return loss_value, end_mae, end_gt


def train(epochs, up_changer, down_changer, trainloader, up_optimizer, down_optimizer):
    """Alternately train the down- and up-sampling models against each other.

    For every batch the down model is updated first (up model frozen), then
    the up model is updated (down model frozen). Both updates run on the whole
    batch at once.

    Args:
        epochs: number of passes over ``trainloader``.
        up_changer: model mapping LR graphs to HR graphs.
        down_changer: model mapping HR graphs to LR graphs.
        trainloader: yields ``(X_lr, Y_lr, adj_lr, X_hr, Y_hr, adj_hr)`` batches.
        up_optimizer: optimizer over ``up_changer``'s parameters.
        down_optimizer: optimizer over ``down_changer``'s parameters.

    Returns:
        ``(up_changer, down_changer)`` after training.
    """
    up_changer.train()
    down_changer.train()
    # NOTE(review): anomaly detection is a debugging aid that slows training
    # noticeably; consider disabling it once training is stable.
    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(epochs):

            up_losses = []
            down_losses = []

            for X_lr, Y_lr, adj_lr, X_hr, Y_hr, adj_hr in tqdm(trainloader):
                lr_inputs = (X_lr.to(DEVICE), Y_lr.to(DEVICE), adj_lr.to(DEVICE))
                hr_inputs = (X_hr.to(DEVICE), Y_hr.to(DEVICE), adj_hr.to(DEVICE))

                # Down phase: HR -> LR model learns, LR -> HR model is the reference.
                down_optimizer.zero_grad()
                up_optimizer.zero_grad()
                down_loss, down_mae, down_gt = _adjacency_training_step(
                    down_changer, up_changer, hr_inputs, lr_inputs, down_optimizer
                )
                # The first value is an L1 (mean absolute) error.
                print(f'Down end adj mae {down_mae}, gt loss {down_gt}')
                down_losses.append(down_loss)

                # Up phase: roles swapped.
                down_optimizer.zero_grad()
                up_optimizer.zero_grad()
                up_loss, up_mae, up_gt = _adjacency_training_step(
                    up_changer, down_changer, lr_inputs, hr_inputs, up_optimizer
                )
                print(f'Up end adj mae {up_mae}, gt loss {up_gt}')
                up_losses.append(up_loss)

            epoch_up_loss = np.mean(up_losses)
            epoch_down_loss = np.mean(down_losses)

            print(f'epoch {epoch}: down loss = {epoch_down_loss}, up loss = {epoch_up_loss}')

        return up_changer, down_changer


In [124]:
# Train both models for 10 epochs and rebind the names to the returned models.
up_changer, down_changer = train(10, up_changer, down_changer, trainloader, up_optimizer, down_optimizer)

  0%|          | 0/6 [00:00<?, ?it/s]