In [1]:
"""Data Files"""
import glob
from IPython.display import clear_output
import torch.nn.functional as F
import torch.nn as nn
import torch
import pandas as pd

# Root of the ds004504 dataset: participants table + preprocessed EEG derivatives.
DATA_ROOT = '../../../../../ds004504'

files_data = pd.read_csv(f'{DATA_ROOT}/participants.tsv', sep="\t")

# glob's result order depends on the filesystem. Every later cell (subject
# filtering, KFold folds, window stacking order) iterates `files`, so sort it
# to make runs reproducible across machines.
files = sorted(glob.glob(f"{DATA_ROOT}/derivatives/*/*/*.set"))

In [2]:
"""Build (or load from cache) the per-file EEG data used by the models.

Recordings are cut into 30-second epochs (``size=30``). This cell provides five
dictionaries keyed by file path:

* ``all_X``      - per-epoch signals from ``build_data`` (``raw_eeg=True, bands=True``), stored as float16
* ``all_feat``   - features computed from ``all_X`` by ``cal_features``
* ``all_graphs`` - connectivity graphs (``cal_conn="all"``), stored as float16
* ``all_y``      - labels
* ``MMSE``       - value returned by ``get_MMSE`` for each file

All five are cached as pickles under ``saved_files/`` so the expensive
extraction runs only once; set ``SAVED_DATA = False`` to rebuild the cache.
"""
import os
import numpy as np
import pickle

from read_data import build_data

SAVED_DATA = True  # True: load the cached pickles; False: recompute and overwrite them

CACHE_DIR = 'saved_files'
# One cache file per artifact, shared by the load and save branches so the two
# paths cannot drift apart.
CACHE_PATHS = {
    'all_X': os.path.join(CACHE_DIR, 'all_X_bands_30.pkl'),
    'all_feat': os.path.join(CACHE_DIR, 'feat_30.pkl'),
    'all_graphs': os.path.join(CACHE_DIR, 'all_graphs_30.pkl'),
    'all_y': os.path.join(CACHE_DIR, 'all_y_30.pkl'),
    'MMSE': os.path.join(CACHE_DIR, 'MMSE_30.pkl'),
}


def load_cached(name):
    """Unpickle the cached artifact registered as ``name`` in CACHE_PATHS.

    pickle.load can execute arbitrary code: only load caches this notebook wrote.
    """
    with open(CACHE_PATHS[name], 'rb') as fp:
        return pickle.load(fp)


def save_cached(name, obj):
    """Pickle ``obj`` to the cache file registered as ``name`` in CACHE_PATHS."""
    with open(CACHE_PATHS[name], 'wb') as fp:
        pickle.dump(obj, fp)


if SAVED_DATA:
    all_X = load_cached('all_X')
    all_feat = load_cached('all_feat')
    all_graphs = load_cached('all_graphs')
    all_y = load_cached('all_y')
    MMSE = load_cached('MMSE')
else:
    all_X, all_graphs, all_y, ch_names = build_data(files, size=30,
                                                    files_data=files_data,
                                                    cal_conn="all",
                                                    raw_eeg=True,
                                                    bands=True,
                                                    data_used="dem")
    # NOTE(review): cal_features and get_MMSE are not imported or defined in any
    # visible cell, so this branch raises NameError on a fresh kernel.
    # TODO import them from the module that defines them.
    # Features are computed before the float16 down-cast below.
    all_feat = cal_features(all_X)
    MMSE = {file: get_MMSE(file) for file in files}

    # Down-cast signals and graphs to float16 to shrink memory and the cache.
    for k, v in all_X.items():
        all_X[k] = np.float16(v)
    for k, v in all_graphs.items():
        all_graphs[k] = np.float16(v)

    # Make sure the cache directory exists so the expensive rebuild above is
    # not lost to a FileNotFoundError on the first write.
    os.makedirs(CACHE_DIR, exist_ok=True)
    save_cached('all_feat', all_feat)
    save_cached('all_X', all_X)
    save_cached('all_graphs', all_graphs)
    save_cached('all_y', all_y)
    save_cached('MMSE', MMSE)
clear_output()

In [34]:
"""Ablation study for trying different connectivities alone (Coherence, PLI, PLV)"""   

import torch, gc
from torch_geometric.loader import DataLoader
import random
from sklearn.model_selection import KFold    
from sklearn.preprocessing import OneHotEncoder

from utils import build_pyg_dl
from read_data import stack_arrays
from train import trainer, train_test_split_subjects, predict
from evaluate import cal_accuracy, cal_accuracy_loso
from models.ADgraph import ADGraph
from evaluate import avg_accuracy

BATCH_SIZE = 100
NUM_BANDS = 5           # frequency bands per connectivity measure
NUM_EDGES = NUM_BANDS   # edge features in this ablation: one connectivity x all bands
NUM_CONNS = 3           # connectivity measures in all_graphs: Coherence, PLI, PLV
NUM_CHANNELS = 19       # number of used EEG channels
DEVICE = "cpu"  # if has_mps else "cuda" if torch.cuda.is_available() else "cpu"
NUM_EPOCHS = 35
TASK = "AD"
NUM_TIMEPOINTS = None
NUM_CLASSES = 2
NUM_OUT_FEAT = 64
SKIP_LABEL = 2  # 0:C, 1:AD, 2:FTD
SEED = 2025
# NOTE(review): NUM_SIGNALS (passed to ADGraph below) is not defined in any
# visible cell; it currently comes from hidden kernel state. TODO define it here.

# Binary task: keep control (0) and AD (1) files only; FTD (SKIP_LABEL) is excluded.
task_files = [file for file in files if all_y[file][0] in (0, 1)]

# Leave-one-subject-out CV: one fold per file, so each validation fold holds
# exactly one file (val_files[0] below). The split does not depend on the
# connectivity, so it is built once.
num_kfolds = len(task_files)
kf = KFold(n_splits=num_kfolds, shuffle=True, random_state=SEED)


def select_connectivity(graphs, conn_idx):
    """Return all frequency bands of connectivity ``conn_idx``.

    Unflattens each stacked graph to
    (NUM_CHANNELS, NUM_CHANNELS, NUM_BANDS, NUM_CONNS) and returns the
    (N, NUM_CHANNELS, NUM_CHANNELS, NUM_BANDS) slice for ``conn_idx``.
    NOTE(review): assumes the connectivity index is the last axis of the stored
    layout -- confirm against read_data.build_data.
    """
    n_samples = graphs.shape[0]
    return graphs.reshape(n_samples, NUM_CHANNELS, NUM_CHANNELS, NUM_BANDS, NUM_CONNS)[..., conn_idx]


results = {}
for conn_idx in range(NUM_CONNS):
    all_cm = {}
    AD_saved_models = []
    AD_saved_iters = []

    for train_idx, val_idx in kf.split(task_files):
        train_files = [task_files[i] for i in train_idx]
        val_files = [task_files[i] for i in val_idx]
        train_X2 = [all_feat[f] for f in train_files]
        val_X2 = [all_feat[f] for f in val_files]
        train_graphs = [all_graphs[f] for f in train_files]
        val_graphs = [all_graphs[f] for f in val_files]
        train_y = [all_y[f] for f in train_files]
        val_y = [all_y[f] for f in val_files]

        print("LOSO Subject:", val_files)

        # Stack the per-window arrays of all files into single arrays.
        train_X, train_graphs, train_y = stack_arrays(train_X2, train_graphs, train_y, task=TASK)
        val_X, val_graphs, val_y = stack_arrays(val_X2, val_graphs, val_y, task=TASK)

        # Keep one connectivity measure with all of its frequency bands.
        val_graphs = select_connectivity(val_graphs, conn_idx)
        train_graphs = select_connectivity(train_graphs, conn_idx)

        # One-hot encode labels; the encoder is fit on the training fold only.
        ohe = OneHotEncoder()
        train_y = ohe.fit_transform(train_y).toarray()
        val_y = ohe.transform(val_y).toarray()

        # Build the PyG datasets and loaders.
        train_dataset = [build_pyg_dl(x, g, y, NUM_EDGES, DEVICE) for x, g, y in zip(train_X, train_graphs, train_y)]
        val_dataset = [build_pyg_dl(x, g, y, NUM_EDGES, DEVICE) for x, g, y in zip(val_X, val_graphs, val_y)]
        train_iter = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_iter = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

        # Seed every RNG before building the model so each fold starts from the same state.
        random.seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)
        # NOTE(review): only one connectivity is fed to the model here, yet
        # num_conns=NUM_CONNS (3); confirm how ADGraph uses num_conns.
        model = ADGraph(num_nodes=NUM_CHANNELS,
                        num_out_feat=NUM_OUT_FEAT,
                        num_classes=NUM_CLASSES,
                        device=DEVICE,
                        num_edges=NUM_EDGES,
                        num_conns=NUM_CONNS,
                        timepoints=NUM_TIMEPOINTS,
                        operator="TransformerConv",
                        num_signals=NUM_SIGNALS).to(DEVICE)

        # Train, then evaluate on the held-out file.
        model, _, _ = trainer(NUM_EPOCHS, model, train_iter, val_iter, lr=0.00001)
        ytrue, ypreds = predict(model, val_iter)
        val_acc = cal_accuracy_loso(ytrue, ypreds)
        all_cm[val_files[0]] = val_acc[1]
        print("LOSO Accuracy: ", val_acc)
        AD_saved_models.append(model)
        AD_saved_iters.append(val_iter)
        del train_X, val_X
        gc.collect()

    results[conn_idx] = avg_accuracy(all_cm, all_y)

clear_output()

print("AD Prediction")
conns = ["Coherence", "PLI", "PLV"]  # display names, indexed by conn_idx
for conn_idx, (acc, sen, spec, prec, f1) in results.items():
    print(conns[conn_idx])
    print("Accuracy:", acc)
    print("Sensitivity:", sen)
    print("Specificity:", spec)
    print("Precision:", prec)
    print("F1:", f1)
    print("==================")

AD Prediction
Coherence
Accuracy: 0.7660956881275842
Sensitivity: 0.7665995975855131
Specificity: 0.765379113018598
Precision: 0.8228941684665226
F1: 0.7937500000000001
PLI
Accuracy: 0.7359716479621973
Sensitivity: 0.7257304429783223
Specificity: 0.7531645569620253
Precision: 0.8315334773218143
F1: 0.7750377453447409
PLV
Accuracy: 0.7371529828706438
Sensitivity: 0.7836084905660378
Specificity: 0.6905325443786983
Precision: 0.7176025917926566
F1: 0.7491544532130778


In [38]:
"""performances of using all combined connectivities: Coherence, PLI, PLV averaged over frequency bands"""
  
NUM_EDGES = 3  # edge features: one per connectivity (Coherence, PLI, PLV), each averaged over bands
NUM_BANDS = 5  # frequency bands per connectivity in the stored graphs

# Leave-one-subject-out CV: one fold per file, so each validation fold holds
# exactly one file (val_files[0] below).
num_kfolds = len(task_files)
kf = KFold(n_splits=num_kfolds, shuffle=True, random_state=SEED)


def average_over_bands(graphs):
    """Average every connectivity measure over its frequency bands.

    Unflattens each stacked graph to
    (NUM_CHANNELS, NUM_CHANNELS, NUM_BANDS, NUM_CONNS) and returns
    (N, NUM_CHANNELS, NUM_CHANNELS, NUM_CONNS).
    NOTE(review): assumes the band axis precedes the connectivity axis in the
    stored layout -- confirm against read_data.build_data.
    """
    n_samples = graphs.shape[0]
    return np.mean(graphs.reshape(n_samples, NUM_CHANNELS, NUM_CHANNELS, NUM_BANDS, NUM_CONNS), axis=-2)


all_cm = {}
AD_saved_models = []
AD_saved_iters = []

for train_idx, val_idx in kf.split(task_files):
    train_files = [task_files[i] for i in train_idx]
    val_files = [task_files[i] for i in val_idx]
    train_X2 = [all_feat[f] for f in train_files]
    val_X2 = [all_feat[f] for f in val_files]
    train_graphs = [all_graphs[f] for f in train_files]
    val_graphs = [all_graphs[f] for f in val_files]
    train_y = [all_y[f] for f in train_files]
    val_y = [all_y[f] for f in val_files]

    print("LOSO Subject:", val_files)

    # Stack the per-window arrays of all files into single arrays.
    train_X, train_graphs, train_y = stack_arrays(train_X2, train_graphs, train_y, task=TASK)
    val_X, val_graphs, val_y = stack_arrays(val_X2, val_graphs, val_y, task=TASK)

    # Use all connectivity measures, each averaged over the frequency bands.
    val_graphs = average_over_bands(val_graphs)
    train_graphs = average_over_bands(train_graphs)

    # One-hot encode labels; the encoder is fit on the training fold only.
    ohe = OneHotEncoder()
    train_y = ohe.fit_transform(train_y).toarray()
    val_y = ohe.transform(val_y).toarray()

    # Build the PyG datasets and loaders.
    train_dataset = [build_pyg_dl(x, g, y, NUM_EDGES, DEVICE) for x, g, y in zip(train_X, train_graphs, train_y)]
    val_dataset = [build_pyg_dl(x, g, y, NUM_EDGES, DEVICE) for x, g, y in zip(val_X, val_graphs, val_y)]
    train_iter = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_iter = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Seed every RNG before building the model so each fold starts from the same state.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # NOTE(review): NUM_SIGNALS is not defined in any visible cell (hidden
    # kernel state). TODO define it in the configuration.
    model = ADGraph(num_nodes=NUM_CHANNELS,
                    num_out_feat=NUM_OUT_FEAT,
                    num_classes=NUM_CLASSES,
                    device=DEVICE,
                    num_edges=NUM_EDGES,
                    num_conns=NUM_CONNS,
                    timepoints=NUM_TIMEPOINTS,
                    operator="TransformerConv",
                    num_signals=NUM_SIGNALS).to(DEVICE)

    # Train, then evaluate on the held-out file.
    model, _, _ = trainer(NUM_EPOCHS, model, train_iter, val_iter, lr=0.00001)
    ytrue, ypreds = predict(model, val_iter)
    val_acc = cal_accuracy_loso(ytrue, ypreds)
    all_cm[val_files[0]] = val_acc[1]
    print("LOSO Accuracy: ", val_acc)
    AD_saved_models.append(model)
    AD_saved_iters.append(val_iter)
    del train_X, val_X
    gc.collect()

clear_output()

# NOTE(review): this replaces the per-connectivity `results` dict built by the
# ablation cell; copy it first if both are needed later.
results = {"average": avg_accuracy(all_cm, all_y)}

acc, sen, spec, prec, f1 = results["average"]
print("Using Coherence, PLV, PLI")
print("Accuracy:", acc)
print("Sensitivity:", sen)
print("Specificity:", spec)
print("Precision:", prec)
print("F1:", f1)

Using Coherence, PLV, PLI
Accuracy: 0.7141169521559362
Sensitivity: 0.7299687825182102
Specificity: 0.6933060109289617
Precision: 0.7575593952483801
F1: 0.7435082140964494
