In [1]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import BatchNorm, GCNConv
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torch.distributions import normal, kl

from tqdm import tqdm

from sklearn.model_selection import KFold
import pandas as pd
from MatrixVectorizer import MatrixVectorizer


In [2]:
# set global variables
N_SUBJECTS = 167

N_LR_NODES = 160

N_HR_NODES = 268

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

N_LR_NODES_F = int(N_LR_NODES * (N_LR_NODES-1) / 2)
N_HR_NODES_F = int(N_HR_NODES * (N_HR_NODES-1) / 2)

# Model Layers

In [3]:
# class SheafConvLayer(nn.Module):
#     def __init__(self, n_nodes, d, f_in, f_out=None):
#         super().__init__()
#         self.d = d
#         self.n_nodes = n_nodes
#         self.f_in = f_in
#         self.f_out = f_out
#         # random init weight matrices
#         if f_out is None:
#             f_out = f_in 
#         self.weight1 = nn.Parameter(torch.randn((d, d), device=DEVICE))
#         self.weight2 = nn.Parameter(torch.randn((f_in, f_out), device=DEVICE))
#         self.edge_weights = nn.Parameter(torch.randn((n_nodes, n_nodes, d, 2*d), device=DEVICE))


#     def forward(self, X, adj):
#         kron_prod = torch.kron(torch.eye(self.n_nodes).to(DEVICE), self.weight1)
#         L = self.sheaf_laplacian(X, adj)
#         if self.f_out is None:
#             return X - F.elu(L @ kron_prod @ X @ self.weight2), L
#         else:
#             return F.elu(L @ kron_prod @ X @ self.weight2), L


#     def sheaf_laplacian(self, X, adj, epsilon=1e-6):
#         X_reshaped = X.reshape(self.n_nodes, self.d, -1)
#         idx_pairs = torch.cartesian_prod(torch.arange(self.n_nodes), torch.arange(self.n_nodes))
#         all_stacked_features = X_reshaped[idx_pairs].reshape(self.n_nodes, self.n_nodes, 2*self.d, -1).to(DEVICE)
#         lin_trans = F.elu(torch.matmul(self.edge_weights, all_stacked_features))
#         inner_transpose = torch.transpose(lin_trans, -1, -2)
#         L_v = -1 * torch.matmul(lin_trans, torch.transpose(inner_transpose, 0, 1))
#         # row_cond = torch.isclose(torch.sum(adj, dim=1), torch.zeros_like(torch.sum(adj, dim=1)))
#         # col_cond = torch.isclose(torch.sum(adj, dim=0), torch.zeros_like(torch.sum(adj, dim=0)))
#         adj_row_weights = adj / (torch.sum(adj, dim=1)[:, None] + epsilon)
#         adj_col_weights = adj / (torch.sum(adj, dim=0)[:, None] + epsilon)
#         # adj_col_weights = torch.where(col_cond[None, :], 0., adj / torch.sum(adj, dim=0)[None, :])
#         adj_weights = torch.maximum(adj_row_weights * adj_col_weights, torch.zeros_like(adj_row_weights))

#         adj_diag_weights = adj_row_weights ** 2
#         diag_blocks = torch.sum(adj_diag_weights[:, :, None, None] * torch.matmul(lin_trans, inner_transpose), dim=1)
#         L_v[range(self.n_nodes), range(self.n_nodes)] = diag_blocks
#         return L_v.reshape(-1, self.n_nodes * self.d)
#         ### NOTE IGNORE MATRIX NORMALISATION FOR NOW #####
#         # inv_root_diag_blocks = torch.pow(diag_blocks+epsilon, -1/2)
#         # normalise_mat = torch.block_diag(*inv_root_diag_blocks)

#         # return normalise_mat @ L_v.view(-1, self.n_nodes * self.d) @ normalise_mat
#         ################################################

In [4]:
class AdjacencyDimChanger(nn.Module):
    def __init__(self, new_n, old_n, old_f, d):
        super().__init__()
        self.new_n = new_n
        self.old_n = old_n
        self.d = d
        #self.sheafconv = SheafConvLayer(old_n, d, old_f, new_n)  # Using Sheaf conv
        # TODO: ADD ANOTHER LAYER TYPE FROM TORCH_GEOMETRIC INSTEAD OF SHEAFCONV
        self.gcnconv = GCNConv(old_f, d)  # GCN
        self.layernorm = nn.LayerNorm([d, old_n]).to(DEVICE)

    def forward(self, X, adj):

        adj = adj - torch.diag_embed(torch.diagonal(adj, 0)).to(DEVICE) + torch.eye(adj.shape[0]).to(DEVICE)  # add self connections
        #x, L = self.sheafconv(X, adj)  # Using Sheaf conv
        # TODO: USE THE OTHER LAYER TYPE
        print(f'X: {X.shape}, adj: {adj.shape}')
        x = self.gcnconv(X, adj)  # Using GCNConv
        print(f'x: {x.shape}')

        x = x.reshape(self.old_n, self.d, self.new_n)
        x = torch.transpose(x, 0, -1)
        x = self.layernorm(x)
        
        x_mean = x.mean(dim=-1)

        #L_mean = L.reshape(self.old_n, self.old_n, self.d, self.d).max(dim=0)[0].mean(dim=0) # aggregate by eigenvalues of each n by n mat? # Using Sheaf conv
        # TODO: for the reshaping above, might be better to hard-code it with some values instead?
        # Using GCNConv: calculate L_mean based on the adjacency matrix adj
        L_mean = adj.sum(dim=1, keepdim=True)  # Sum of each row
        L_mean = L_mean.repeat(1, self.d)  # Repeat for each feature dimension
        L_mean = L_mean.unsqueeze(2)  # Add dimension for element-wise multiplication
        L_mean = L_mean * torch.eye(self.old_n).to(DEVICE)  # Keep only diagonal elements
        L_mean = L_mean.transpose(0, 1)  # Transpose to match the dimensions
        print(f'L_mean: {L_mean.shape}, self.old_n: {self.old_n},  self.d:{ self.d}')
        
        
        adj_new = torch.matmul(x_mean, L_mean)
        adj_new = torch.matmul(adj_new, x_mean.T)
        adj_new_T = torch.t(adj_new)
        adj_new = F.tanh(F.relu(((adj_new + adj_new_T) / 2))) # becomes a new f by new f adj1


        return x.reshape(self.new_n*self.d, -1), adj_new

In [5]:
class AdjacencyChangerUp(nn.Module):

    def __init__(self, d, f_in):
        super().__init__()
        self.d = d

        self.adjdim_changer1 = AdjacencyDimChanger(200, N_LR_NODES, f_in, d) # from 160 (initial size) to 200
        self.adjdim_changer2 = AdjacencyDimChanger(220, 200, N_LR_NODES, d) # from 200 to 220
        self.adjdim_changer3 = AdjacencyDimChanger(N_HR_NODES, 220, 200, d) # from 220 to 268 (final size) 

        
    def forward(self, X, adj):
        x1, adj1 = self.adjdim_changer1(X, adj)
        x2, adj2 = self.adjdim_changer2(x1, adj1)
        x3, adj3 = self.adjdim_changer3(x2, adj2)
        return [adj, adj1, adj2, adj3]
        

In [6]:
class AdjacencyChangerDown(nn.Module):

    def __init__(self, d, f_in):
        super().__init__()
        self.d = d

        self.adjdim_changer1 = AdjacencyDimChanger(220, N_HR_NODES, f_in, d).to(DEVICE) # from 268 (final size) to 220
        self.adjdim_changer2  = AdjacencyDimChanger(200, 220, N_HR_NODES, d).to(DEVICE) # from 220 to 200
        self.adjdim_changer3 = AdjacencyDimChanger(N_LR_NODES, 200, 220, d).to(DEVICE) # from 200 to 168

        
    def forward(self, X, adj):
        x1, adj1 = self.adjdim_changer1(X, adj)
        x2, adj2 = self.adjdim_changer2(x1, adj1)
        x3, adj3 = self.adjdim_changer3(x2, adj2)
        return [adj, adj1, adj2, adj3]

In [7]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True

In [8]:
import numpy as np
import networkx as nx

def eigen_centrality(data):
    # ROI is the number of brain regions (i.e.,35 in our case)
    ROI = 160

    topology_eigen = []

    G = nx.from_numpy_array(np.absolute(data))
    U = G.to_undirected()

    # A = to_2d(data)
    np.fill_diagonal(data, 0)

    # create a graph frL2
    # # compute egeinvector centrality and transform the output to vector
    ec = nx.eigenvector_centrality_numpy(U)
    
    eigenvector_centrality = np.array([ec[g] for g in U])



    topology_eigen.append(eigenvector_centrality)  # 2

    return topology_eigen

def pearson_coor(input, target, epsilon=1e-7):
    vx = input - torch.mean(input, dim=(1, 2))[:, None, None]
    vy = target - torch.mean(target, dim=(1, 2))[:, None, None]
    cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)+epsilon) * torch.sqrt(torch.sum(vy ** 2)+epsilon)+epsilon)
    return cost

def GT_loss(target, predicted):

    # l1_loss
    l1_loss = torch.nn.L1Loss()
    # loss_pix2pix = l1_loss(target, predicted)

    # topological_loss
    target_n = target.detach().cpu().clone().numpy()
    predicted_n = predicted.detach().cpu().clone().numpy()
    torch.cuda.empty_cache()

    topo_loss = []
    

    for i in range(len(target_n)):

        cur_target = target_n[i]
        cur_predicted = predicted_n[i]

        target_t = eigen_centrality(cur_target)
        real_topology = torch.tensor(target_t[0])
        predicted_t = eigen_centrality(cur_predicted)
        fake_topology = torch.tensor(predicted_t[0])
        topo_loss.append(l1_loss(real_topology, fake_topology))

    topo_loss = torch.sum(torch.stack(topo_loss))

    pc_loss = pearson_coor(target, predicted).to(DEVICE)
    torch.cuda.empty_cache()

    # G_loss = loss_pix2pix + (1 - pc_loss) + topo_loss
    G_loss = (1 - pc_loss) + topo_loss


    return G_loss

In [9]:
def loss_calc(adj_ls, opp_adj_ls):
    total_loss = torch.Tensor([0]).to(DEVICE)
    mse_loss_fn = torch.nn.MSELoss()
    for i, (adj, opp_adj) in enumerate(zip(adj_ls[::-1], opp_adj_ls)):

        ### NOTE TEMPORARY MEASURE BECAUSE THEY TAKE IN (BATCHSIZE, xx, xx) shape ####
        temp_adj = adj.reshape(1, *adj.shape)
        temp_opp_adj = opp_adj.reshape(1, *opp_adj.shape)
        ##########################################################
        gt_loss = GT_loss(temp_adj, temp_opp_adj) / len(adj_ls)
        mse_loss = torch.pow(mse_loss_fn(adj, opp_adj), 1/(i+1)) 
        total_loss = total_loss + mse_loss + gt_loss.to(DEVICE)
    return total_loss

# Training

In [10]:
from data_preparation import load_data_tensor

lr_train, lr_test, hr_train = load_data_tensor('data')

In [11]:
lr_X_dim1 = torch.load('model_autoencoder/encode_lr_1.pt')
lr_X_dim2 = torch.load('model_autoencoder/encode_lr_2.pt')
hr_X_dim1 = torch.load('model_autoencoder/encode_hr_1.pt')
hr_X_dim2 = torch.load('model_autoencoder/encode_hr_2.pt')


lr_X_all = torch.empty((167, 320, 32))
for i in range(len(lr_X_dim1)):
    a, b = lr_X_dim1[i], lr_X_dim2[i]
    lr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

hr_X_all = torch.empty((167, 536, 32))
for i in range(len(hr_X_dim1)):
    a, b = hr_X_dim1[i], hr_X_dim2[i]
    hr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

In [12]:
trainloader = DataLoader(list(zip(lr_X_all, lr_train, hr_X_all, hr_train)), shuffle=True, batch_size=8)


up_changer = AdjacencyChangerUp(d=2,f_in=32).to(DEVICE)
down_changer = AdjacencyChangerDown(d=2,f_in=32).to(DEVICE)

up_optimizer = torch.optim.AdamW(up_changer.parameters(), lr=0.001, betas=(0.5, 0.999))
down_optimizer = torch.optim.AdamW(down_changer.parameters(), lr=0.001, betas=(0.5, 0.999))

sum(p.numel() for model in [up_changer, down_changer] for p in model.parameters())


6908

In [13]:
def train(epochs, up_changer, down_changer, trainloader, up_optimizer, down_optimizer):

    up_changer.train()
    down_changer.train()
    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(epochs):

            up_losses = []
            down_losses = []

            for X_lr, adj_lr, X_hr, adj_hr in tqdm(trainloader):

                freeze_model(up_changer)
                unfreeze_model(down_changer)
            
                down_optimizer.zero_grad()
                up_optimizer.zero_grad()

                down_batch_loss = []

                for i in range(len(X_lr)):

                    up_adj_ls = up_changer(X_lr[i].to(DEVICE), adj_lr[i].to(DEVICE))
                    torch.cuda.empty_cache()
                    down_adj_ls = down_changer(X_hr[i].to(DEVICE), adj_hr[i].to(DEVICE))
                    torch.cuda.empty_cache()

                    down_batch_loss.append(loss_calc(down_adj_ls, up_adj_ls))

                down_loss = torch.mean(torch.stack(down_batch_loss))
                down_loss.backward()
                down_optimizer.step()

                down_losses.append(down_loss.detach().item())
                del down_loss
                del down_batch_loss
                torch.cuda.empty_cache()

                unfreeze_model(up_changer)
                freeze_model(down_changer)
            
                down_optimizer.zero_grad()
                up_optimizer.zero_grad()

                up_batch_loss = []


                for i in range(len(X_lr)):

                    up_adj_ls = up_changer(X_lr[i].to(DEVICE), adj_lr[i].to(DEVICE))
                    torch.cuda.empty_cache()
                    down_adj_ls = down_changer(X_hr[i].to(DEVICE), adj_hr[i].to(DEVICE))
                    torch.cuda.empty_cache()
                
                    up_batch_loss.append(loss_calc(up_adj_ls, down_adj_ls))

                up_loss = torch.mean(torch.stack(up_batch_loss))
                up_loss.backward()
                up_optimizer.step()

                up_losses.append(up_loss.detach().item())
                del up_loss
                del up_batch_loss
                torch.cuda.empty_cache()



            epoch_up_loss = np.mean(up_losses)
            epoch_down_loss = np.mean(down_losses)

            print(f'epoch {epoch}: down loss = {epoch_down_loss}, up loss = {epoch_up_loss}')

        return up_changer, down_changer


In [14]:
up_changer, down_changer = train(20, up_changer, down_changer, trainloader, up_optimizer, down_optimizer)

  0%|          | 0/21 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]


X: torch.Size([320, 32]), adj: torch.Size([160, 160])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 160 but got size 2 for tensor number 1 in the list.