<a href="https://www.kaggle.com/code/mateusbaldamota/eeg-psychiatric-disorders-using-ann?scriptVersionId=208869244" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

## | Description

Authors of notebook: Mateus Balda and Alessandro Bof

Reference paper and dataset: https://www.frontiersin.org/journals/psychiatry/articles/10.3389/fpsyt.2021.707581/full

## | 0. Import from libraries

In [1]:
!pip install -q torchsummary;
!pip install -q torchviz;

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import csv
import os

from IPython.display import FileLink, display, HTML
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.impute import KNNImputer

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchsummary import summary
from torchviz import make_dot

from imblearn.over_sampling import BorderlineSMOTE

from warnings import filterwarnings
filterwarnings('ignore')

torch.__version__

'2.4.0'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x7b9f619d8a50>

## | [IMPORTANT] Useful functions

### 1. OUTLIERS

In [None]:
def detect_outliers_summary(df):
    outliers_summary = {}

    for col in df.select_dtypes(include=['float64', 'int64']).columns:
        Q1 = df[col].quantile(0.25)
        Q3 = df[col].quantile(0.75)
        IQR = Q3 - Q1
        lower_limit = Q1 - 1.5 * IQR
        upper_limit = Q3 + 1.5 * IQR

        outliers = df[(df[col] < lower_limit) | (df[col] > upper_limit)][col]
        
        outliers_summary[col] = {
            'num_outliers': len(outliers),
            'percent_outliers': len(outliers) / len(df) * 100,
            'outliers': outliers.tolist(),
            'lower_limit': lower_limit,
            'upper_limit': upper_limit
        }

    return pd.DataFrame(outliers_summary).T


def treat_outliers_iqr(df, column, factor=1.5):
    Q1 = np.percentile(df[column], 25)
    Q3 = np.percentile(df[column], 75)
    IQR = Q3 - Q1
    
    lower_bound = Q1 - factor * IQR
    upper_bound = Q3 + factor * IQR
    
    outliers = df[(df[column] < lower_bound) | (df[column] > upper_bound)]
    
    df[column] = np.clip(df[column], lower_bound, upper_bound)
        
    return _, outliers

### 2. NANS

In [6]:
def remove_missing_columns(df, threshold=0.5):
    limit = int(threshold * len(df))
    df = df.dropna(thresh=limit, axis=1)
    return df
    
def find_most_null_column(df, threshold=0.5):
    null_ratios = df.isnull().mean()
    for col, ration in null_ratios.items():
        if ration > threshold:
            return col
    return None

def analyze_missing_values(df):
    missing_values = df.isnull().sum()
    missing_values = missing_values[missing_values > 0]
    total_number_nans = df.isnull().sum().sum()
    
    return missing_values, total_number_nans

def handle_nans(df):
    columns_with_nans = df.columns[df.isnull().any()].tolist()
    
    knn_imputer = KNNImputer(n_neighbors=5, weights='uniform', metric='nan_euclidean')
    
    df_imputed = pd.DataFrame(knn_imputer.fit_transform(df[columns_with_nans]),
                                columns=columns_with_nans)
    
    df[columns_with_nans] = df_imputed[columns_with_nans]
    
    return df

### 3. PREPARER

In [7]:
def minority_class_balancer(df, target_column):
    X = df.drop(columns=[target_column])
    y = df[target_column]
    
    borderline_smote = BorderlineSMOTE(k_neighbors=3, random_state=42)
    X_resampled, y_resampled = borderline_smote.fit_resample(X, y)
    
    df = pd.DataFrame(X_resampled, columns=X.columns)
    df[target_column] = y_resampled
    
    return df

### 4. SETS & SUBSETS

In [8]:
import pandas as pd

class Sets:
    # AB = PSD (Power Spectral Density) 19 * 6
    # COH = FC (Functional Connectivity) 171 * 6
    
    def __init__(self, 
                 dataframe: pd.DataFrame,
                 quantitative_features: list[str],
                #  qualitative_features: list[str],
                 target_main: str,
                #  target_specific: str
                 ):
        
        if not isinstance(dataframe, pd.DataFrame):
            raise ValueError("The parameter must be a pandas DataFrame")
        
        self.df = dataframe.copy()
        self.quantitative_features = self.df[quantitative_features]
        # self.qualitative_features = self.df[qualitative_features]
        self.target_main = self.df[[target_main]]
        # self.target_specific = self.df[[target_specific]]
        self.df_ab_psd = None
        self.df_coh_fc = None
        self.df_ab_psd_coh_fc = None
        #self.dfs_bands = {}
        
        self.__create_df_psd()
        self.__create_df_fc()
        self.__create_union_psd_fc()

    # 19 (6 bands)
    def __create_df_psd(self):
        columns_AB = [col for col in self.df.columns if col.startswith('AB')]
        self.df_ab_psd = self.df[columns_AB]
    
    # 171 (6 bands)
    def __create_df_fc(self):
        columns_COH = [col for col in self.df.columns if col.startswith('COH')]
        self.df_coh_fc = self.df[columns_COH]
        
    def __create_union_psd_fc(self):
        if self.df_ab_psd is not None and self.df_coh_fc is not None:
            self.df_ab_psd_coh_fc = pd.concat([self.df_ab_psd, self.df_coh_fc], axis=1)
        else:
            raise ValueError("Subsets AB and COH were not created correctly")

    def create_dfs_bands(self, bands: list[str] = None, df: pd.DataFrame = None):
        dfs_bands = {}
        
        if bands is None:
            bands = ['delta', 'theta', 'alpha', 'beta', 'highbeta','gamma']
        
        for band in bands:
            columns_band = [col for col in df.columns if f'.{band}.' in col]
            if columns_band:
                dfs_bands[band] = df[columns_band]
        
        return dfs_bands


## | 1. Load dataset

In [9]:
df = pd.read_csv('../input/eeg-psychiatric-disorders-dataset/EEG.machinelearing_data_BRMH.csv')
df.shape

(945, 1149)

In [10]:
# Checking / imputing Nan

output1 = analyze_missing_values(df)
output2 = find_most_null_column(df)
df = remove_missing_columns(df)
output3 = analyze_missing_values(df)
df = handle_nans(df)

display(output1)
display(HTML('<hr>'))
display(output2)
display(HTML('<hr>'))
display(output3)
display(df.shape)
display(HTML('<hr>'))
display(df.isna().sum().sum())
display(df)

(education        15
 IQ               13
 Unnamed: 122    945
 dtype: int64,
 973)

'Unnamed: 122'

(education    15
 IQ           13
 dtype: int64,
 28)

(945, 1148)

0

Unnamed: 0,no.,sex,age,eeg.date,education,IQ,main.disorder,specific.disorder,AB.A.delta.a.FP1,AB.A.delta.b.FP2,...,COH.F.gamma.o.Pz.p.P4,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2
0,1,M,57.0,2012.8.30,13.43871,101.580472,Addictive disorder,Alcohol use disorder,35.998557,21.717375,...,55.989192,16.739679,23.452271,45.678820,30.167520,16.918761,48.850427,9.422630,34.507082,28.613029
1,2,M,37.0,2012.9.6,6.00000,120.000000,Addictive disorder,Alcohol use disorder,13.425118,11.002916,...,45.595619,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261
2,3,M,32.0,2012.9.10,16.00000,113.000000,Addictive disorder,Alcohol use disorder,29.941780,27.544684,...,99.475453,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799
3,4,M,35.0,2012.10.8,18.00000,126.000000,Addictive disorder,Alcohol use disorder,21.496226,21.846832,...,59.986561,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873
4,5,M,36.0,2012.10.18,16.00000,112.000000,Addictive disorder,Alcohol use disorder,37.775667,33.607679,...,61.462720,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
940,941,M,22.0,2014.8.28,13.00000,116.000000,Healthy control,Healthy control,41.851823,36.771496,...,82.905657,34.850706,63.970519,63.982003,51.244725,62.203684,62.062237,31.013031,31.183413,98.325230
941,942,M,26.0,2014.9.19,13.00000,118.000000,Healthy control,Healthy control,18.986856,19.401387,...,65.917918,66.700117,44.756285,49.787513,98.905995,54.021304,93.902401,52.740396,92.807331,56.320868
942,943,M,26.0,2014.9.27,16.00000,113.000000,Healthy control,Healthy control,28.781317,32.369230,...,61.040959,27.632209,45.552852,33.638817,46.690983,19.382928,41.050717,7.045821,41.962451,19.092111
943,944,M,24.0,2014.9.20,13.00000,107.000000,Healthy control,Healthy control,19.929100,25.196375,...,99.113664,48.328934,41.248470,28.192238,48.665743,42.007147,28.735945,27.176500,27.529522,20.028446


In [11]:
# detecting outliers

outliers_summary = detect_outliers_summary(df)
outliers_summary.to_csv('outliers_summary.csv', index=True)
outliers_summary['num_outliers'].sum()

26452

In [12]:
X = df.iloc[:,8:]
target_main = df['main.disorder']
target_specific = df['specific.disorder']
quantitative_features = df.loc[:, ['age', 'education', 'IQ']]
sex = df['sex']

X.shape, target_main.shape, target_specific.shape, quantitative_features.shape, sex.shape

((945, 1140), (945,), (945,), (945, 3), (945,))

In [13]:
column_names = np.unique(target_main).tolist()
pd.DataFrame(np.unique(target_main))

Unnamed: 0,0
0,Addictive disorder
1,Anxiety disorder
2,Healthy control
3,Mood disorder
4,Obsessive compulsive disorder
5,Schizophrenia
6,Trauma and stress related disorder


In [14]:
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(target_main)
y_encoded = y_encoded.astype(np.float32)
(pd.DataFrame(y_encoded)).value_counts()

0  
3.0    266
0.0    186
6.0    128
5.0    117
1.0    107
2.0     95
4.0     46
Name: count, dtype: int64

In [15]:
X_concated = pd.concat([quantitative_features, X], axis=1)
X_concated

Unnamed: 0,age,education,IQ,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,AB.A.delta.g.F8,...,COH.F.gamma.o.Pz.p.P4,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2
0,57.0,13.43871,101.580472,35.998557,21.717375,21.518280,26.825048,26.611516,25.732649,16.563408,...,55.989192,16.739679,23.452271,45.678820,30.167520,16.918761,48.850427,9.422630,34.507082,28.613029
1,37.0,6.00000,120.000000,13.425118,11.002916,11.942516,15.272216,14.151570,12.456034,8.436832,...,45.595619,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261
2,32.0,16.00000,113.000000,29.941780,27.544684,17.150159,23.608960,27.087811,13.541237,16.523963,...,99.475453,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799
3,35.0,18.00000,126.000000,21.496226,21.846832,17.364316,13.833701,14.100954,13.100939,14.613650,...,59.986561,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873
4,36.0,16.00000,112.000000,37.775667,33.607679,21.865556,21.771413,22.854536,21.456377,15.969042,...,61.462720,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
940,22.0,13.00000,116.000000,41.851823,36.771496,43.671792,36.860889,24.732236,23.607823,23.288260,...,82.905657,34.850706,63.970519,63.982003,51.244725,62.203684,62.062237,31.013031,31.183413,98.325230
941,26.0,13.00000,118.000000,18.986856,19.401387,27.586436,20.194732,19.407491,20.216570,16.465027,...,65.917918,66.700117,44.756285,49.787513,98.905995,54.021304,93.902401,52.740396,92.807331,56.320868
942,26.0,16.00000,113.000000,28.781317,32.369230,11.717778,23.134370,26.209302,25.484497,22.586688,...,61.040959,27.632209,45.552852,33.638817,46.690983,19.382928,41.050717,7.045821,41.962451,19.092111
943,24.0,13.00000,107.000000,19.929100,25.196375,14.445391,16.453456,16.590649,16.007279,18.909188,...,99.113664,48.328934,41.248470,28.192238,48.665743,42.007147,28.735945,27.176500,27.529522,20.028446


In [16]:
# class balancing 

borderline_smote = BorderlineSMOTE(k_neighbors=3, random_state=42)
X_resampled, y_resampled = borderline_smote.fit_resample(X_concated, y_encoded)

display(HTML('<hr>'))
display(X_resampled)
display((pd.DataFrame(y_resampled)).value_counts())

Unnamed: 0,age,education,IQ,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,AB.A.delta.g.F8,...,COH.F.gamma.o.Pz.p.P4,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2
0,57.000000,13.438710,101.580472,35.998557,21.717375,21.518280,26.825048,26.611516,25.732649,16.563408,...,55.989192,16.739679,23.452271,45.678820,30.167520,16.918761,48.850427,9.422630,34.507082,28.613029
1,37.000000,6.000000,120.000000,13.425118,11.002916,11.942516,15.272216,14.151570,12.456034,8.436832,...,45.595619,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261
2,32.000000,16.000000,113.000000,29.941780,27.544684,17.150159,23.608960,27.087811,13.541237,16.523963,...,99.475453,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799
3,35.000000,18.000000,126.000000,21.496226,21.846832,17.364316,13.833701,14.100954,13.100939,14.613650,...,59.986561,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873
4,36.000000,16.000000,112.000000,37.775667,33.607679,21.865556,21.771413,22.854536,21.456377,15.969042,...,61.462720,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1857,25.316734,16.000000,122.590560,18.848215,17.953192,18.124769,17.798791,17.155828,19.501348,18.809150,...,90.649213,59.274382,57.996010,65.906688,74.424361,51.600495,72.481835,37.669746,70.393744,59.026147
1858,37.910967,15.206766,119.717294,22.475810,37.214201,22.666290,23.451906,25.555333,24.532122,22.737343,...,77.966235,51.807610,58.152599,56.323073,74.984977,50.960234,69.986686,38.314046,70.768974,52.070437
1859,48.281079,12.254164,90.254164,28.445129,18.053016,19.203409,18.843402,20.442576,19.893870,16.028124,...,75.444990,42.442354,52.766711,44.048382,69.067053,42.758237,52.427951,28.734816,47.995477,59.294021
1860,42.111736,8.476453,94.771744,15.903714,14.886987,12.303060,13.781029,15.076520,17.184864,12.620410,...,79.225446,42.568994,42.391061,45.600630,67.811362,35.011289,54.835830,23.835465,50.603209,52.778712


0  
0.0    266
1.0    266
2.0    266
3.0    266
4.0    266
5.0    266
6.0    266
Name: count, dtype: int64

In [17]:
y_resampled = pd.Series(y_resampled, name='main.disorder')
display(y_resampled)
display(pd.DataFrame(y_resampled).value_counts())


0       0.0
1       0.0
2       0.0
3       0.0
4       0.0
       ... 
1857    6.0
1858    6.0
1859    6.0
1860    6.0
1861    6.0
Name: main.disorder, Length: 1862, dtype: float32

main.disorder
0.0              266
1.0              266
2.0              266
3.0              266
4.0              266
5.0              266
6.0              266
Name: count, dtype: int64

In [18]:
df_union = pd.concat([X_resampled, pd.DataFrame(y_resampled)], axis=1)

display(df_union.shape)
display(df_union.head())
display(df.isna().sum().sum())

(1862, 1144)

Unnamed: 0,age,education,IQ,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,AB.A.delta.g.F8,...,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2,main.disorder
0,57.0,13.43871,101.580472,35.998557,21.717375,21.51828,26.825048,26.611516,25.732649,16.563408,...,16.739679,23.452271,45.67882,30.16752,16.918761,48.850427,9.42263,34.507082,28.613029,0.0
1,37.0,6.0,120.0,13.425118,11.002916,11.942516,15.272216,14.15157,12.456034,8.436832,...,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261,0.0
2,32.0,16.0,113.0,29.94178,27.544684,17.150159,23.60896,27.087811,13.541237,16.523963,...,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799,0.0
3,35.0,18.0,126.0,21.496226,21.846832,17.364316,13.833701,14.100954,13.100939,14.61365,...,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873,0.0
4,36.0,16.0,112.0,37.775667,33.607679,21.865556,21.771413,22.854536,21.456377,15.969042,...,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662,0.0


0

In [19]:
# set and subset separation

sets = Sets(dataframe=df_union, 
            quantitative_features=['age', 'education', 'IQ'], 
            target_main='main.disorder',
            )


In [38]:
df_psd_all_bands = sets.df_ab_psd
df_fc_all_bands = sets.df_coh_fc
df_psd_fc_all_bands = sets.df_ab_psd_coh_fc
quantitative_features = sets.quantitative_features
target_main = sets.target_main

dfs_bands_psd = sets.create_dfs_bands(df=df_psd_all_bands)
dfs_bands_fc = sets.create_dfs_bands(df=df_fc_all_bands)
dfs_bands_psd_fc = sets.create_dfs_bands(df=df_psd_fc_all_bands)

In [22]:
df_dict_psd_all_band = {
    'psd_all_bands': df_psd_all_bands,
}

df_dict_fc_all_band = {
    'fc_all_bands': df_fc_all_bands,
}

df_dict_psd_fc_all_band = {
    'psd_fc_all_bands': df_psd_fc_all_bands,
}

df_dict_psd_band = {
    'psd_delta': dfs_bands_psd['delta'],
    'psd_theta': dfs_bands_psd['theta'],
    'psd_alpha': dfs_bands_psd['alpha'],
    'psd_beta': dfs_bands_psd['beta'],
    'psd_highbeta': dfs_bands_psd['highbeta'],
    'psd_gamma': dfs_bands_psd['gamma'],
}

df_dict_fc_band = {
    'fc_delta': dfs_bands_fc['delta'],
    'fc_theta': dfs_bands_fc['theta'],
    'fc_alpha': dfs_bands_fc['alpha'],
    'fc_beta': dfs_bands_fc['beta'],
    'fc_highbeta': dfs_bands_fc['highbeta'],
    'fc_gamma': dfs_bands_fc['gamma'],
}

df_dict_psd_fc_band = {
    'psd_fc_delta': dfs_bands_psd_fc['delta'],
    'psd_fc_theta': dfs_bands_psd_fc['theta'],
    'psd_fc_alpha': dfs_bands_psd_fc['alpha'],
    'psd_fc_beta': dfs_bands_psd_fc['beta'],
    'psd_fc_highbeta': dfs_bands_psd_fc['highbeta'],
    'psd_fc_gamma': dfs_bands_psd_fc['gamma'],
}

In [23]:
for dataset_name, df in df_dict_psd_all_band.items():
    print(f"{dataset_name}: {df.shape} shape")
    display(df.head())
    print("\n")

psd_all_bands: (1862, 114) shape


Unnamed: 0,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,AB.A.delta.g.F8,AB.A.delta.h.T3,AB.A.delta.i.C3,AB.A.delta.j.Cz,...,AB.F.gamma.j.Cz,AB.F.gamma.k.C4,AB.F.gamma.l.T4,AB.F.gamma.m.T5,AB.F.gamma.n.P3,AB.F.gamma.o.Pz,AB.F.gamma.p.P4,AB.F.gamma.q.T6,AB.F.gamma.r.O1,AB.F.gamma.s.O2
0,35.998557,21.717375,21.51828,26.825048,26.611516,25.732649,16.563408,29.891368,22.402246,22.582176,...,1.993727,1.765493,1.464281,1.501948,1.707307,1.553448,1.552658,1.388662,1.592717,1.806598
1,13.425118,11.002916,11.942516,15.272216,14.15157,12.456034,8.436832,9.975238,14.83474,10.950564,...,0.903383,0.931967,0.437117,0.930843,1.234874,1.373268,1.411808,1.140695,1.118041,3.162143
2,29.94178,27.544684,17.150159,23.60896,27.087811,13.541237,16.523963,12.775574,21.686306,18.367666,...,1.096713,1.691152,1.505663,1.133891,1.661768,1.403429,1.349457,1.270525,1.408471,1.454618
3,21.496226,21.846832,17.364316,13.833701,14.100954,13.100939,14.61365,8.063191,11.015078,11.63956,...,1.11504,1.122776,2.128138,1.648217,1.147666,1.049152,1.131654,1.415856,1.391048,1.527403
4,37.775667,33.607679,21.865556,21.771413,22.854536,21.456377,15.969042,9.434306,15.244523,17.041979,...,1.193191,2.320845,3.56282,1.441662,1.018804,1.274009,2.350806,2.30773,2.129431,3.76686






## Custom dataset PyTorch

In [24]:
class EEGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

### Split for train, val and test

In [25]:
train_ratio = 0.7
val_ratio = 0.15

output_dim = len(np.unique(y_encoded))

# params 
neurons_psd_all_bands = [
    [456, 228, 114, 57, 28],
    [256, 128, 64, 32, 16],
    [114, 57, 28, 14, 7], #
    [96, 48, 24, 12, 6],
]
neurons_fc_all_bands = [
    [2048, 1024, 512, 256, 128],
    [1536, 768, 384, 192, 96],
    [1026, 513, 256, 128, 64], #
    [896, 448, 224, 112, 56],
]
neurons_psd_fc_all_bands = [
    [2048, 1024, 512, 256, 128],
    [1536, 768, 384, 192, 96],
    [1140, 570, 285, 142, 71], #
    [896, 448, 224, 112, 56],
]
neurons_psd_per_band = [
    [95, 64, 32, 16, 8],
    [64, 32, 16, 8, 4],
    [48, 24, 12, 6, 3],
    [19, 16, 8, 4, 2], #
]
neurons_fc_per_band = [
    [684, 342, 171, 85, 42],
    [171, 114, 76, 38, 19], #
    [128, 85, 57, 28, 14],
]
neurons_psd_fc_per_band = [
    [760, 380, 190, 95, 48],
    [256, 128, 64, 32, 16],
    [190, 127, 85, 42, 21], #
    [128, 85, 57, 28, 14],
]

dropout_rates = [0.10, 0.20, 0.30]
learning_rates = [0.0001, 0.001, 0.01, 0.1]
batch_sizes = [32, 64, 128]
weight_decays = [0.0001, 0.001, 0.01]  

num_epochs = 500

In [26]:
def prepare_datasets(
    X, y, 
    train_ratio = train_ratio,
    val_ratio = val_ratio,
    batch_size = batch_sizes[0], 
    scaler = None
):
    
    if scaler is None:
        scaler = StandardScaler() 
    
    X_scaled = scaler.fit_transform(X)

    dataset = EEGDataset(X_scaled, y)
    
    total_size = len(dataset)
    train_size = int(train_ratio * total_size) # 70% train
    val_size = int(val_ratio * total_size) # 15% val
    test_size = total_size - train_size - val_size # 15% test

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size]
    )
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return {
        'datasets': {
            'train': train_dataset,
            'val': val_dataset,
            'test': test_dataset,
        },
        'loaders': {
            'train': train_loader,
            'val': val_loader,
            'test': test_loader,
        },
        'X_scaled': X_scaled
    }

## Neural network architecture

In [27]:
class EEGANNClassifier(nn.Module):
    def __init__(self, input_dim, output_dim, neurons, dropout=0.10):
        super(EEGANNClassifier, self).__init__()
        
        self.layer1 = nn.Linear(input_dim, neurons[0])
        self.layer2 = nn.Linear(neurons[0], neurons[1])
        self.layer3 = nn.Linear(neurons[1], neurons[2])
        self.layer4 = nn.Linear(neurons[2], neurons[3])
        self.layer5 = nn.Linear(neurons[3], neurons[4])
        self.output_layer = nn.Linear(neurons[4], output_dim)

        self.bn1 = nn.BatchNorm1d(neurons[0])
        self.bn2 = nn.BatchNorm1d(neurons[1])
        self.bn3 = nn.BatchNorm1d(neurons[2])
        self.bn4 = nn.BatchNorm1d(neurons[3])
        self.bn5 = nn.BatchNorm1d(neurons[4])

        self.gelu = nn.GELU()

        self.dropout = nn.Dropout(p=dropout)
        
        init.xavier_uniform_(self.layer1.weight)
        init.xavier_uniform_(self.layer2.weight)
        init.xavier_uniform_(self.layer3.weight)
        init.xavier_uniform_(self.layer4.weight)
        init.xavier_uniform_(self.layer5.weight)
        init.xavier_uniform_(self.output_layer.weight)

    def forward(self, x):
        x = self.gelu(self.bn1(self.layer1(x))) 
        x = self.dropout(x)
        
        x = self.gelu(self.bn2(self.layer2(x)))
        x = self.dropout(x)
        
        x = self.gelu(self.bn3(self.layer3(x)))
        x = self.dropout(x)
        
        x = self.gelu(self.bn4(self.layer4(x)))
        x = self.dropout(x)
        
        x = self.gelu(self.bn5(self.layer5(x)))
        
        x = self.output_layer(x)
        
        return x

### Network summary

In [28]:
# model = EEGANNClassifier(input_dim, output_dim, selected_neurons, dropout)
# print(f'{model}\n{input_dim}\n{output_dim}')

In [29]:
# summary(model, input_size=(input_dim, ))

### Network graph

In [30]:
# model.eval()

# sample_input = torch.randn(1, input_dim)  

# output = model(sample_input)
# graph = make_dot(output, params=dict(model.named_parameters()))

# graph.render("eeg_classifier", format="png")

# file_path = "eeg_classifier.png"

# FileLink(file_path)

## Training

In [31]:
# Method for train

def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer, 
    scheduler, 
    num_epochs = num_epochs,
    patience = 20,
    min_delta = 0.01
):
    
    best_loss = float('inf')
    early_stop_counter = 0
    train_losses = []
    val_losses = []
    val_accs = []
    final_epoch = 0

    for epoch in range(num_epochs):
        
        # ------------------- Train -------------------
        model.train()
        running_loss = 0.0

        for data in train_loader:
            inputs, labels = data
            
            optimizer.zero_grad()  # zero gradients

            # Forward pass
            outputs = model(inputs)
            labels = labels.long()

            # Calculate a loss
            loss = criterion(outputs, labels)
            loss.backward()  # Calculate gradients
            optimizer.step()  # Update weights

            running_loss += loss.item()

        # Average training loss
        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")
        train_losses.append(epoch_loss)

        # ------------------- Val -------------------
        model.eval() 
        val_loss = 0.0
        val_accuracy = 0.0

        with torch.no_grad():
            for data in val_loader:
                inputs, labels = data
                
                outputs = model(inputs)
                
                labels = labels.long()

                loss = criterion(outputs, labels)
                val_loss += loss.item()

                # Calculate accuracy
                _, predicted = torch.max(outputs.data, 1)
                accuracy = (predicted == labels).float().mean()
                val_accuracy += accuracy.item()

        val_loss /= len(val_loader)
        val_accuracy /= len(val_loader)

        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n")
        val_losses.append(val_loss)
        val_accs.append(val_accuracy)

        # ------------------- Early Stopping -------------------
        if val_loss < best_loss - min_delta:
            best_loss = val_loss
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}.")
            final_epoch = epoch + 1
            break 

        # update scheduler
        scheduler.step(val_loss)

    return train_losses, val_losses, val_accs, final_epoch

In [32]:
criterion = nn.CrossEntropyLoss()

In [34]:
# Method for evaluation

def evaluate_model(
    model,
    test_loader,
    class_names,
    train_losses=None,
    val_losses=None,
    val_accs=None,
    num_classes=None,
    plot_curves=True,
    plot_conf_matrix=True,
    plot_roc=True,
    compute_auc=True,
    verbose=True
):
    model.eval()

    # 1. Plot training and validation loss/accuracy
    if plot_curves and train_losses is not None and val_losses is not None and val_accs is not None:
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(train_losses, label='Training Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(val_accs, label='Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.legend()

        plt.tight_layout()
        plt.savefig('training_validation_plot.png')
        plt.show()

    # 2. Get predictions and labels
    all_preds, all_labels = [], []
    correct, total = 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            probs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(probs, 1)

            all_preds.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # 3. Accuracy
    accuracy = 100 * correct / total
    if verbose:
        print(f"\n✅ Accuracy on test set: {accuracy:.2f}%")

    # 4. Classification Report
    pred_classes = np.argmax(all_preds, axis=1)
    report = classification_report(all_labels, pred_classes, target_names=class_names, digits=4)
    if verbose:
        print("\n📋 Classification Report:")
        print(report)

    # 5. Confusion Matrix
    if plot_conf_matrix:
        cm = confusion_matrix(all_labels, pred_classes, labels=np.arange(len(class_names)))
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
        plt.xlabel('Predicted Class')
        plt.ylabel('True Class')
        plt.title('Confusion Matrix')
        plt.tight_layout()
        plt.savefig('confusion_matrix.png')
        plt.show()

    # 6. AUC Score
    aucs, mean_auc = None, None
    if compute_auc and num_classes is not None:
        aucs = []
        for i in range(num_classes):
            auc = roc_auc_score(all_labels == i, all_preds[:, i])
            aucs.append(auc)
        mean_auc = np.mean(aucs)
        if verbose:
            print(f"\n📈 AUC per class: {aucs}")
            print(f"📊 Mean AUC: {mean_auc:.4f}")

    # 7. ROC Curves
    if plot_roc:
        plt.figure(figsize=(10, 8))
        for i, class_name in enumerate(class_names):
            fpr, tpr, _ = roc_curve(all_labels == i, all_preds[:, i])
            auc = roc_auc_score(all_labels == i, all_preds[:, i])
            plt.plot(fpr, tpr, label=f'{class_name} (AUC = {auc:.4f})')
        plt.plot([0, 1], [0, 1], 'k--', label='Chance (AUC = 0.5)')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curves by Class')
        plt.legend(loc='lower right')
        plt.savefig('roc_curves_class.png')
        plt.show()

    return {
        "accuracy": accuracy,
        "classification_report": report,
        "aucs": aucs,
        "mean_auc": mean_auc
    }


In [54]:
# Method for random search

import random

def random_search(
    num_trials=10,
    neurons=None,
    dropout_rates=None,
    learning_rates=None,
    batch_sizes=None,
    weight_decays=None,
    train_ratio=None,
    val_ratio=None,
    X=None,
    y_encoded=None,
    input_dim=None,
    output_dim=None,
    num_epochs=None,
    device=None,
    criterion=None,
    ):
    
    best_accuracy = 0
    best_params = {}
    best_results = {}
    
    for _ in range(num_trials):
        # Randomly select parameters
        neurons_choice = random.choice(neurons)
        dropout_choice = random.choice(dropout_rates)
        learning_rate_choice = random.choice(learning_rates)
        batch_size_choice = random.choice(batch_sizes)
        weight_decay_choice = random.choice(weight_decays)

        # Prepare datasets
        datasets = prepare_datasets(
            X, y_encoded, 
            train_ratio=train_ratio,
            val_ratio=val_ratio,
            batch_size=batch_size_choice
        )

        # Create model
        model = EEGANNClassifier(input_dim, output_dim, neurons_choice, dropout_choice).to(device)

        optimizer = optim.Adam(
            model.parameters(), 
            lr=learning_rate_choice,
            weight_decay=weight_decay_choice
        )
        
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor = 0.1,
            patience = 3,
            min_lr = 1e-6, 
            verbose = True
        )

        # Train model
        train_losses, val_losses, val_accs, final_epoch = train_model(
            model,
            datasets['loaders']['train'],
            datasets['loaders']['val'],
            criterion,
            optimizer,
            scheduler,
            num_epochs=num_epochs
        )
        
        print(f'{neurons_choice} neurons, {dropout_choice} dropout, {learning_rate_choice} learning rate, {batch_size_choice} batch size, {weight_decay_choice} weight decay')

        # Evaluate model on test set
        test_results = evaluate_model(
                                      model=model, 
                                      test_loader=datasets['loaders']['test'],
                                      class_names=column_names,
                                      train_losses=train_losses,
                                      val_losses=val_losses,
                                      val_accs=val_accs,
                                      num_classes=output_dim,
                                      )
        
        display(HTML('<hr>'))
        display(test_results['mean_auc'])
        display(HTML('<hr>'))
        display(datasets['X_scaled'].shape)
        
        if test_results['accuracy'] > best_accuracy:
            best_accuracy = test_results['accuracy']
            
            best_params = {
                'neurons': neurons_choice,
                'dropout': dropout_choice,
                'learning_rate': learning_rate_choice,
                'batch_size': batch_size_choice,
                'weight_decay': weight_decay_choice
            }
            
            best_results = {
                'final_epoch': final_epoch,
                'best_accuracy': test_results['accuracy'],
                'mean_auc': test_results['mean_auc'],
                'classification_report': test_results['classification_report'],
                'aucs': test_results['aucs'],
                'train_losses': train_losses,
                'val_losses': val_losses,
                'val_accs': val_accs
                
            }
    
    return best_params, best_results

In [55]:
from sklearn.preprocessing import RobustScaler

scaler2 = RobustScaler()

quantitative_features_scaled = scaler2.fit_transform(quantitative_features)

quantitative_features = pd.DataFrame(quantitative_features_scaled, columns=quantitative_features.columns)

quantitative_features

Unnamed: 0,age,education,IQ
0,2.556680,0.060827,-0.118573
1,0.873993,-1.950991,0.784104
2,0.453321,0.753534,0.441059
3,0.705724,1.294439,1.078143
4,0.789859,0.753534,0.392052
...,...,...,...
1857,-0.108971,0.753534,0.911058
1858,0.950637,0.539002,0.770250
1859,1.823119,-0.259537,-0.673635
1860,1.304066,-1.281228,-0.452245


In [56]:
# Method for training all subsets

def train_all_subsets(
    df_dict: dict,
    num_trials: int = 10,
    train_ratio = 0.7,
    val_ratio = 0.15,
    y=None,
    quantitative_features=None,
    neurons=None,
    dropout_rates=None,
    learning_rates=None,
    batch_sizes=None,
    weight_decays=None,
    num_epochs=500,
    device=None,
    criterion=nn.CrossEntropyLoss(),
):
    all_results = {}
    
    for dataset_name, df in df_dict.items():
        print(f"Training on {dataset_name}...")
        print(f"Dataset shape: {df.shape}")
        
        display(HTML('<hr>'))
        display(df.head())
        display(HTML('<hr>'))
        
        df_quantitative_features = pd.concat([quantitative_features, df], axis=1)
        
        best_params, best_results = random_search(
            num_trials=num_trials,
            neurons=neurons,
            dropout_rates=dropout_rates,
            learning_rates=learning_rates,
            batch_sizes=batch_sizes,
            weight_decays=weight_decays,
            train_ratio=train_ratio,
            val_ratio=val_ratio,
            X=df_quantitative_features,
            y_encoded=y,
            input_dim=df_quantitative_features.shape[1],
            output_dim=len(np.unique(y)),
            num_epochs=num_epochs,
            criterion=criterion
        )
        
        all_results[dataset_name] = {
            'best_params': best_params,
            'best_results': best_results
        }
        
    return all_results

### .

In [None]:
all_results = train_all_subsets(
    df_dict=df_dict_psd_all_band,
    num_trials=10,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    y=y_resampled,
    neurons=neurons_psd_all_bands,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    batch_sizes=batch_sizes,
    weight_decays=weight_decays,
    num_epochs=num_epochs
)

all_results2 =  train_all_subsets(
    df_dict=df_dict_fc_all_band,
    num_trials=10,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    y=y_resampled,
    neurons=neurons_fc_all_bands,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    batch_sizes=batch_sizes,
    weight_decays=weight_decays,
    num_epochs=num_epochs
)

all_results3 =  train_all_subsets(
    df_dict=df_dict_psd_fc_all_band,
    num_trials=10,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    y=y_resampled,
    neurons=neurons_psd_fc_all_bands,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    batch_sizes=batch_sizes,
    weight_decays=weight_decays,
    num_epochs=num_epochs
)

all_results4 =  train_all_subsets(
    df_dict=df_dict_psd_band,
    num_trials=10,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    y=y_resampled,
    neurons=neurons_psd_per_band,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    batch_sizes=batch_sizes,
    weight_decays=weight_decays,
    num_epochs=num_epochs
)

all_results5 =  train_all_subsets(
    df_dict=df_dict_fc_band,
    num_trials=10,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    y=y_resampled,
    neurons=neurons_fc_per_band,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    batch_sizes=batch_sizes,
    weight_decays=weight_decays,
    num_epochs=num_epochs
)

all_results6 =  train_all_subsets(
    df_dict=df_dict_psd_fc_band,
    num_trials=10,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    y=y_resampled,
    neurons=neurons_psd_fc_per_band,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    batch_sizes=batch_sizes,
    weight_decays=weight_decays,
    num_epochs=num_epochs
)

## | Results
Training/Validation Loss and Validation Accuracy

In [None]:
all_result_sets = [
    ("psd_all_band", all_results),
    ("fc_all_band", all_results2),
    ("psd_fc_all_band", all_results3),
    ("psd_per_band", all_results4),
    ("fc_per_band", all_results5),
    ("psd_fc_per_band", all_results6),
]

compiled_results = []

for group_name, result_dict in all_result_sets:
    for subset_name, data in result_dict.items():
        best_params = data["best_params"]
        best_results = data["best_results"]

        compiled_results.append({
            "group": group_name,
            "subset": subset_name,
            "final_epoch": best_results["final_epoch"],
            "best_accuracy": best_results["best_accuracy"],
            "mean_auc": best_results["mean_auc"],
            "train_losses": best_results["train_losses"],
            "val_losses": best_results["val_losses"],
            "val_accs": best_results["val_accs"],
            "classification_report": best_results["classification_report"],
            "aucs": best_results["aucs"],
            "neurons": best_params["neurons"],
            "dropout": best_params["dropout"],
            "learning_rate": best_params["learning_rate"],
            "batch_size": best_params["batch_size"],
            "weight_decay": best_params["weight_decay"]
        })

df_summary = pd.DataFrame(compiled_results)
df_summary

Unnamed: 0,group,subset,final_epoch,best_accuracy,mean_auc,train_losses,val_losses,val_accs,classification_report,aucs,neurons,dropout,learning_rate,batch_size,weight_decay
0,psd_all_band,psd_all_bands,143,57.5,0.887556,"[2.1971133095877513, 2.130477093514942, 2.1083...","[1.9643330812454223, 1.9422193050384522, 1.914...","[0.14864130467176437, 0.18301630467176438, 0.2...",precision ...,"[0.8062058714232628, 0.9505208333333333, 0.957...","[456, 228, 114, 57, 28]",0.2,0.0001,64,0.0001


## Conclusions, Problems, and Limitations

[need to update]

* The class imbalance is evident for the multiclasses of the main.disorder variable. The SMOTE technique was applied to balance the classes and perform data augmentation.

* The data pre-processed with PSD and FC generated 1140 features. PCA with 99% variance was applied to the data, and thus 223 principal components were extracted.

* To avoid overfitting, techniques such as dropout, batch normalization, and early stopping were applied.

* The weights were initialized in the network with xavier_uniform.

* The network has convergence issues due to the high variability of the data. Despite using PCA, there were no significant improvements in performance compared to using all features from PSD and FC.

* Due to the enrichment of the data, the minority class obsessive compulsive disorder shows a tendency to overfitting, as the original data was 46 and increased to 266 with Euclidean distance.