In [1]:
###DATASET CONSTRUCTION
import torch
import scipy 
from torch_geometric.data import Data
import numpy as np

## SNET data: location of the .mat file and the variable name stored inside it
SNET_MAT_PATH = 'C:\\Users\\Gurur\\capa_snet_destrieux_normalizedweighted.mat'
SNET_MAT_KEY = "capa_snet_destrieux_normalizedweighted"
X = scipy.io.loadmat(SNET_MAT_PATH)[SNET_MAT_KEY]

# Keep only the rows whose second column equals "MCI" or "Erken AD".
# All "MCI" rows come first, followed by all "Erken AD" rows; this ordering
# defines the sample indices used by everything that follows.
mci_rows = np.where(X[:, 1] == "MCI")
erken_ad_rows = np.where(X[:, 1] == "Erken AD")
index = np.hstack((mci_rows, erken_ad_rows)).T
X = np.squeeze(X[index, :])


def convert_to_tensor(data, d_type=None):
    """Convert array-like ``data`` into a torch tensor.

    Args:
        data: Anything NumPy can turn into an array (nested lists, an ndarray,
            a tuple of index arrays such as the result of ``np.nonzero``, ...).
        d_type: Pass ``"long"`` to obtain an int64 tensor. Any other value,
            including the default ``None``, yields a float32 tensor.

    Returns:
        torch.Tensor: int64 when ``d_type == "long"``, float32 otherwise.
    """
    if d_type == "long":
        # Build the integer array directly. Going through float32 first (as
        # the float path does) silently rounds integers above 2**24.
        # np.array always copies, so the tensor never aliases the caller's data.
        return torch.from_numpy(np.array(data, dtype=np.int64))
    return torch.from_numpy(np.asarray(data, dtype=np.float32))


def _build_graph(item, label):
    """Turn one subject's connectivity matrix and label into a ``Data`` graph.

    ``item`` is binarised in place (every entry > 0 becomes 1), so the matrix
    stored in ``X`` is modified as a side effect.
    """
    # Coordinates of the non-zero entries, taken BEFORE binarisation:
    # first row = matrix row indices, second row = matrix column indices.
    edge_index = convert_to_tensor(np.nonzero(item), d_type="long")

    # SOURCE 0: TARGET 1 -- the two coordinate rows are swapped, so the matrix
    # column index ends up in row 0 and the matrix row index in row 1.
    new_index = torch.stack((edge_index[1], edge_index[0]))

    # Node features: binarised connectivity rows plus the identity matrix.
    item[item > 0] = 1
    x = convert_to_tensor(item + np.identity(148))

    # Class 1 for "Erken AD", class 0 for every other label.
    y = torch.tensor([1 if label == "Erken AD" else 0], dtype=torch.long)
    return Data(x=x, edge_index=new_index, y=y)


dataset = []

for subject in X:
    # Column 0 holds the connectivity matrix, column 1 the diagnosis label.
    dataset.append(_build_graph(subject[0], subject[1]))

In [2]:
### LOAD SAVED CROSS VALIDATION INDEXES FOR FAIR COMPARISON OF MODELS
load = True
fold_num = 8
if load:
    # Reuse the stored split so every model is evaluated on the same folds.
    # The context manager closes the .npz archive once the array is read.
    with np.load("8fold_v1.npz") as fold_archive:
        fold_arr = fold_archive["arr_0"]
else:
    # Draw a fresh random split of the sample indices.
    fold_arr = np.arange(len(dataset))
    np.random.shuffle(fold_arr)
    # array_split tolerates a dataset size that is not an exact multiple of
    # fold_num (np.split raises ValueError in that case); when the size does
    # divide evenly the result is identical to np.split.
    fold_arr = np.array_split(fold_arr, fold_num)

# Per-graph label tensors, in dataset order.
y = [graph.y for graph in dataset]

In [3]:
###ATTENTION BASED GRAPH POOLING LAYER
from torch_geometric.nn.inits import reset
from torch_geometric.utils import softmax
from typing import Optional
from torch import Tensor
from torch_geometric.nn.aggr import Aggregation

class AttentionalAggregation(Aggregation):
    """Attention-weighted aggregation that also returns the attention weights.

    ``gate_nn`` scores every row of the input; the scores are normalised with
    ``torch_geometric.utils.softmax`` using ``index`` / ``ptr``, multiplied
    with the (optionally ``nn``-transformed) input and reduced via
    ``self.reduce``.

    Args:
        gate_nn: Module applied to the raw input to produce the gate scores.
        nn: Optional module applied to the input before it is weighted.
    """

    def __init__(self, gate_nn: torch.nn.Module,
                 nn: Optional[torch.nn.Module] = None):
        super().__init__()
        self.gate_nn = gate_nn
        self.nn = nn
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both sub-modules through ``torch_geometric.nn.inits.reset``."""
        for module in (self.gate_nn, self.nn):
            reset(module)

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        """Aggregate ``x`` and return ``(aggregated, gate)``.

        NOTE: the value returned is a 2-tuple (the reduced tensor and the
        normalised gate), even though the annotation says ``Tensor``.
        """
        self.assert_two_dimensional_input(x, dim)

        # The gate scores are computed from the input BEFORE ``nn`` touches it.
        scores = self.gate_nn(x)
        if self.nn is not None:
            x = self.nn(x)

        gate = softmax(scores, index, ptr, dim_size, dim)
        pooled = self.reduce(gate * x, index, ptr, dim_size, dim)
        return pooled, gate

    def __repr__(self) -> str:
        return f'{type(self).__name__}(gate_nn={self.gate_nn}, nn={self.nn})'

class GlobalAttention(AttentionalAggregation):
    """``AttentionalAggregation`` callable with an ``(x, batch, size)`` signature."""

    def __call__(self, x, batch=None, size=None):
        # ``batch`` is forwarded as the second positional argument of the
        # parent's ``__call__`` and ``size`` as its ``dim_size`` keyword.
        result = super().__call__(x, batch, dim_size=size)
        return result

In [None]:
###DAGNN
from torch_geometric.data import DataLoader
from torch.nn import Sequential, Linear, ReLU
import torch.nn.functional as F
from torch_geometric.nn import  global_add_pool, global_mean_pool, global_max_pool, Set2Set, GINConv,GATConv,GCNConv,TransformerConv
import math
from torch_geometric.utils import to_dense_batch, to_dense_adj
import warnings
from sklearn.metrics import f1_score, roc_auc_score
from scipy.spatial import distance
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")


def f1_calculation(y, patient_results):
    """Compute the F1 score from true labels and per-sample correctness flags.

    The predictions are rebuilt from the flags: a flag equal to 1 means the
    prediction matched the label, anything else means the prediction was
    ``1 - label`` (so labels are assumed to be binary 0/1).

    Args:
        y: True labels, indexable by position.
        patient_results: One flag per sample, 1 where the prediction was right.

    Returns:
        The value of ``f1_score(y, y_pred)``.
    """
    y_pred = np.zeros(len(y))
    for idx, was_correct in enumerate(patient_results):
        # Correct -> predicted the true label; wrong -> predicted the other class.
        y_pred[idx] = y[idx] if was_correct == 1 else 1 - y[idx]
    return f1_score(y, y_pred)


class DAGNN(torch.nn.Module):
    """Graph classifier: node MLP -> multi-head GAT layer -> mean pool -> linear.

    Args:
        num_heads: Number of attention heads of the GAT layer.
        num_features: Size of each input node feature vector (default 148).
        hidden_dim: Width of the node-wise MLP (default 128).
        head_dim: Output channels per attention head (default 32).
        num_classes: Number of output classes (default 2).

    The defaults reproduce the previously hard-coded architecture, so
    ``DAGNN(num_heads)`` builds exactly the same model as before.
    """

    def __init__(self, num_heads, num_features=148, hidden_dim=128,
                 head_dim=32, num_classes=2):
        super().__init__()

        self.heads = num_heads
        # Node-wise MLP applied before message passing.
        self.nn1 = Sequential(Linear(num_features, hidden_dim), ReLU(),
                              Linear(hidden_dim, hidden_dim))
        self.conv2 = GATConv(hidden_dim, head_dim, heads=num_heads)
        # The classifier input is sized for the concatenation of all heads.
        self.fc1 = Linear(head_dim * num_heads, num_classes)

    def forward(self, inx, edge_index, batch):
        """Classify every graph in the batch.

        Args:
            inx: Node feature matrix.
            edge_index: Edge indices passed to the GAT layer.
            batch: Graph assignment of every node, used for mean pooling.

        Returns:
            Tuple ``(log_probs, ind, weight)``: the per-graph class
            log-probabilities and the edge index / attention weights returned
            by the GAT layer when called with ``return_attention_weights=True``.
        """
        x = F.relu(self.nn1(inx))
        x, (ind, weight) = self.conv2(x, edge_index, return_attention_weights=True)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        x = self.fc1(x)
        return F.log_softmax(x, dim=-1), ind, weight
    

def disentanglement_loss(ind, weight, batch):
    """Mean pairwise L1 distance between the attention patterns of the heads.

    The attention weights are densified with ``to_dense_adj`` into a tensor
    that is unpacked as ``(B, N, N, M)``; ``M`` is the size of the last axis
    (one slice per attention head when ``weight`` comes from the GAT layer).
    For each of the ``N`` entries along the second axis, the ``M`` slices are
    compared pairwise with the L1 distance. The distances are averaged over
    that axis, then over the ``M * (M - 1) / 2`` distinct pairs, and finally
    over the batch.

    Args:
        ind: Edge index the attention weights refer to.
        weight: Attention weights per edge; must give a 4-D dense tensor.
        batch: Graph assignment of every node.

    Returns:
        Scalar tensor. When there are fewer than two heads there is no pair to
        compare, so a zero scalar is returned instead of the NaN that dividing
        by ``M * (M - 1) == 0`` would produce.
    """
    c = to_dense_adj(ind, batch, weight)
    _, _, _, M = c.shape
    if M < 2:
        # No pair of heads exists; avoid 0/0 poisoning the training loss.
        return c.new_zeros(())

    # Move the head axis in front of the last axis so that cdist compares heads.
    columns = c.permute(0, 1, 3, 2)
    dists = torch.cdist(columns, columns, p=1)
    avg_dists = torch.mean(dists, 1)
    # The strict upper triangle counts each unordered pair of heads once.
    pair_sum = torch.triu(avg_dists, diagonal=1).sum(dim=(1, 2))
    return (2 * pair_sum / (M * (M - 1))).mean()

    

def train(epoch):
    """Run one training pass over the global ``data_loader``.

    Relies on the module-level ``model``, ``optimizer``, ``device`` and
    ``data_loader``. The objective is the NLL classification loss plus
    ``0.1 * relu(2 - disentanglement_loss)``.

    Args:
        epoch: Current epoch number (not used inside the function).

    Returns:
        Tuple ``(ce_loss, l1_loss, total_loss)``, each averaged over the
        graphs that were actually processed in this pass.
    """
    model.train()

    # Uniform class weights, created on the active device so that the CPU
    # fallback works (torch.cuda.FloatTensor fails without CUDA).
    class_weights = torch.tensor([1.0, 1.0], device=device)

    ce_loss_all = 0
    l1_loss_all = 0
    loss_all = 0
    num_seen = 0
    for data in data_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, ind, weight = model(data.x, data.edge_index, data.batch)
        ce_loss = F.nll_loss(output, data.y, weight=class_weights)
        # Hinge on the head disentanglement: only penalised while it is below 2.
        l1_loss = F.relu(2 - disentanglement_loss(ind, weight, data.batch))
        loss = ce_loss + 0.1 * l1_loss
        loss.backward()
        loss_all += loss.item() * data.num_graphs
        ce_loss_all += ce_loss.item() * data.num_graphs
        l1_loss_all += l1_loss.item() * data.num_graphs
        num_seen += data.num_graphs
        optimizer.step()

    # Normalise by the number of training graphs seen, not by the size of the
    # full dataset (which also contains the held-out fold).
    denom = max(num_seen, 1)
    return ce_loss_all / denom, l1_loss_all / denom, loss_all / denom


def test(loader):
    """Evaluate the global ``model`` on every batch of ``loader``.

    Relies on the module-level ``model`` and ``device``.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``.

    Returns:
        Tuple ``(correct, class1_prob, correct_results)``:
        the number of correct predictions, the predicted probability of
        class 1 for every sample (tensor) and a boolean NumPy array flagging
        which samples were classified correctly. The last two cover ALL
        batches in loader order; for a single-batch loader this equals the
        previous behaviour.
    """
    model.eval()

    correct = 0
    prob_chunks = []
    hit_chunks = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output, _, _ = model(data.x, data.edge_index, data.batch)
            pred = output.max(dim=1)[1]
            hits = pred.eq(data.y)
            correct += hits.sum().item()
            # The model outputs log-probabilities; exp() recovers probabilities.
            prob_chunks.append(torch.exp(output)[:, 1])
            hit_chunks.append(hits.detach().cpu().numpy())

    if not prob_chunks:
        # Empty loader: return empty results instead of failing on undefined names.
        return correct, torch.empty(0), np.empty(0, dtype=bool)
    return correct, torch.cat(prob_chunks), np.concatenate(hit_chunks)

# Experiment settings
batch_size = 6      # graphs per training mini-batch
num_heads = 4       # attention heads of the model
test_num = 1        # number of repetitions of the whole cross-validation
epoch_num = 100     # training epochs per fold

# One entry per repetition.
results = np.zeros(test_num)
f1_arr = np.zeros_like(results)
auc_arr = np.zeros_like(results)

# One row per sample, one column per repetition.
patient_results = np.zeros((len(dataset), test_num))
patient_scores = np.zeros_like(patient_results)


# Repeat the whole cross-validation ``test_num`` times.
for t in range(test_num):
    acc = 0
    y_true = torch.cat(y)
    for i, fold in enumerate(fold_arr):
        print(i)

        # Train on every sample outside the current fold, evaluate on the fold.
        train_indices = np.delete(np.arange(len(dataset)), fold).tolist()
        dataset_train = torch.utils.data.Subset(dataset, train_indices)
        dataset_test = torch.utils.data.Subset(dataset, fold)
        data_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
        # The held-out fold is evaluated as a single batch.
        data_loader_test = DataLoader(dataset_test, batch_size=len(fold))

        # A fresh model and optimizer for every fold. ``model``, ``optimizer``,
        # ``device`` and ``data_loader`` are read as globals by train()/test().
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = DAGNN(num_heads).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

        for epoch in range(1, epoch_num + 1):
            train_loss = train(epoch)
            train_acc = test(data_loader)[0] / len(train_indices)
            test_acc, logits, test_results = test(data_loader_test)
            # Optional per-epoch logging:
            # print('Epoch: {:03d}, Train Loss: {:.7f}, '
            #       'Train Acc: {:.7f}'.format(epoch, train_loss, train_acc))

        # Only the evaluation after the final epoch is recorded for this fold.
        patient_scores[fold, t] = logits.detach().cpu().numpy()
        patient_results[fold, t] = test_results
        acc += test_acc
        del model

    f1_arr[t] = f1_calculation(y_true, patient_results[:, t])
    auc_arr[t] = roc_auc_score(y_true, patient_scores[:, t])
    # Turn the count of correct test predictions into an accuracy.
    acc = acc / len(dataset)
    print(acc)
    results[t] = acc