In [1]:
import numpy as np
import csv
import pandas as pd
import numpy as np
import os

import numpy as np
import torch
import torch.nn as nn
import random
from torch.nn import TransformerEncoderLayer
from torch.utils.data import Dataset, DataLoader

import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

from scipy.signal import hilbert

In [2]:
data = np.load('../AL-NEGAT/Data/ABIDE2.npy',allow_pickle=True)
filename = '../AL-NEGAT/Data/ABIDEII_Composite_Phenotypic.csv'
csv2dict = pd.read_csv(filename,encoding='windows-1252').to_dict()
ID_list = np.array(list(csv2dict['SUB_ID'].values()))
FC_mat = []
T1_mat = []
lbl_arr = []
for id in ID_list:
    try:
        if data[id]['FC'].shape[0]==data[id]['T1'].shape[0]:
            FC_mat.append(data[id]['FC'])
            T1_mat.append(data[id]['T1'])
            lbl_arr.append(data[id]['label'])
    except:
        pass

In [3]:
FC_mat = np.array(FC_mat)
T1_mat = np.array(T1_mat)
lbl_arr = np.array(lbl_arr)

In [4]:
FC_mat.shape, T1_mat.shape, lbl_arr.shape

((546, 400, 400), (546, 400, 4), (546,))

In [5]:
lbl_arr[lbl_arr==2] = 0
lbl_arr.mean()

0.46886446886446886

In [6]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        z, x, y = self.data[idx]
        z = torch.from_numpy(z.astype(np.float32))
        x = torch.from_numpy(x.astype(np.float32))
        # edge_indx = torch.tensor(np.indices((400,400)).reshape(2,-1),dtype=torch.long).t().contiguous()
        # iden_mat = torch.eye(400,dtype=torch.float32)
        y = torch.tensor(y)
        return z,x,y

In [7]:
#random number generator for train test split
rng = np.random.RandomState(42)
indices = np.arange(len(FC_mat))
rng.shuffle(indices)
FC_mat = FC_mat[indices]
T1_mat = T1_mat[indices]
lbl_arr = lbl_arr[indices]

data_dict = {}
for i in range(len(FC_mat)):
    data_dict[i] = [FC_mat[i], T1_mat[i], lbl_arr[i]]

train_data = CustomDataset(list(data_dict.values())[:int(0.8*len(data_dict))])
test_data = CustomDataset(list(data_dict.values())[int(0.8*len(data_dict)):])

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=True)

In [8]:
class NEGA_block(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.Wqk = nn.Parameter(torch.randn(400,4))
        self.Wv = nn.Parameter(torch.randn(400,400))
        self.softmax = nn.Softmax(dim=1)
        #batchnorm with relu
        self.batchnorm_relu = nn.Sequential(
            nn.BatchNorm1d(400),
            nn.ReLU()
        )
        self.W_z = nn.Parameter(torch.randn(400,400))
        self.W_x = nn.Parameter(torch.randn(4,4))

    def forward(self, z,x):
        
        v = z*self.Wv
        qk = x*self.Wqk
        # print(qk.shape, v.shape)
        A = self.softmax(qk@torch.transpose(qk,1,2))*v

        Z_lp1 = self.batchnorm_relu(A@z@self.W_z)+z
        X_lp1 = self.batchnorm_relu(A@x@self.W_x)+x
        return Z_lp1, X_lp1

In [11]:
#create 5 layers of NEGA blocks and flatten the output and apply mlp individually on node and edge features and concatenate them and apply mlp on the concatenated features for classification
class NEGA(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.nega1 = NEGA_block()
        self.nega2 = NEGA_block()
        self.nega3 = NEGA_block()
        self.nega4 = NEGA_block()
        self.nega5 = NEGA_block()
        self.mlp_node = nn.Sequential(
            nn.Linear(4*400,128),
            nn.ReLU(),
            nn.Linear(128,16),
            nn.Dropout(0.5)
        )
        self.mlp_edge = nn.Sequential(
            nn.Linear(160000,512),
            nn.ReLU(),
            nn.Linear(512,16),
            nn.Dropout(0.5)
        )
        self.mlp_concat = nn.Sequential(
            nn.Linear(32,16),
            nn.ReLU(),
            nn.Linear(16,1)
        )
    def forward(self, z,x):
        z1,x1 = self.nega1(z,x)
        z2,x2 = self.nega2(z1,x1)
        z3,x3 = self.nega3(z2,x2)
        z4,x4 = self.nega4(z3,x3)
        z5,x5 = self.nega5(z4,x4)
        # print(x5.flatten(1,2).shape, z5.flatten(1,2).shape)

        node_out = self.mlp_node(x5.flatten(1,2))
        edge_out = self.mlp_edge(z5.flatten(1,2))
        concat_out = self.mlp_concat(torch.cat((node_out,edge_out),dim=1))

        return concat_out

In [12]:
#train
import wandb
torch.random.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NEGA().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)
model.train()
# wandb.init(project='NEGA')
for epoch in range(20):
    for i, (z,x,y) in enumerate(train_loader):
        z = z.to(device)
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        out = model(z,x)
        train_loss = criterion(out,y.unsqueeze(1).float())
        train_loss.backward()
        optimizer.step()
    scheduler.step(train_loss.item())
    # wandb.log({'train_loss':train_loss.item()})

    #test
    model.eval()
    for i, (z,x,y) in enumerate(test_loader):
        z = z.to(device)
        x = x.to(device)
        y = y.to(device)
        out = model(z,x)
        test_loss = criterion(out,y.unsqueeze(1).float())
    # wandb.log({'test_loss':test_loss.item()})
    print('Epoch: {} Train Loss: {} Test Loss: {}'.format(epoch, train_loss.item(), test_loss.item()))

Epoch: 0 Train Loss: 35.74465560913086 Test Loss: 8.053668975830078
Epoch: 1 Train Loss: 2.3253366947174072 Test Loss: 0.9752073287963867
Epoch: 2 Train Loss: 1.3835289478302002 Test Loss: 2.8403632640838623
Epoch: 3 Train Loss: 1.217729926109314 Test Loss: 1.4212478399276733
Epoch: 4 Train Loss: 1.3130813837051392 Test Loss: 1.469170331954956
Epoch: 5 Train Loss: 0.376838743686676 Test Loss: 1.9218261241912842
Epoch: 6 Train Loss: 0.06437404453754425 Test Loss: 2.784975290298462
Epoch: 7 Train Loss: 0.00044554146006703377 Test Loss: 3.0164759159088135
Epoch: 8 Train Loss: 0.0021474198438227177 Test Loss: 0.9099162220954895
Epoch: 9 Train Loss: 0.0009143541101366282 Test Loss: 1.444753885269165
Epoch: 10 Train Loss: 0.0008626551134511828 Test Loss: 0.9214375615119934
Epoch: 11 Train Loss: 0.0009268997237086296 Test Loss: 1.401872992515564
Epoch: 12 Train Loss: 0.004754604771733284 Test Loss: 161.9132843017578
Epoch 00014: reducing learning rate of group 0 to 5.0000e-04.
Epoch: 13 Train