| Description

Authors of notebook: Mateus Balda and Alessandro Bof

Reference paper and dataset: https://www.frontiersin.org/journals/psychiatry/articles/10.3389/fpsyt.2021.707581/full

| Notebook structure
0. Setup and Imports  
1. Utility Functions  
2. Data Loading and Preprocessing  
3. Data Preparation for Training  
4. Model Definition  
5. Training and Evaluation  
6. Results
7. Conclusions, Problems, and Limitations

| Training
1. Binary classification for `y` main.disorder (Disorder vs. Health control)
2. Balancing using BorderlineSMOTE for minority classes
3. Training across the 21 subsets

## | 0. Setup and Imports

In [1]:
!pip install -q imbalanced-learn==0.12.4;
!pip install -q pytorch_tabnet;

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.3/258.3 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.5/44.5 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m29.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

from IPython.display import FileLink, display, HTML
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.impute import KNNImputer

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau

from imblearn.over_sampling import BorderlineSMOTE

from warnings import filterwarnings
filterwarnings('ignore')

torch.__version__

'2.5.1+cu124'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
experiment_name = "binary_classification_main_disorder_tabnet"
device

device(type='cuda')

In [None]:
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x787ea41d41b0>

## | 1. Utility Functions

### 1. OUTLIERS

In [5]:
def detect_outliers_summary(df):
    outliers_summary = {}

    for col in df.select_dtypes(include=['float64', 'int64']).columns:
        Q1 = df[col].quantile(0.25)
        Q3 = df[col].quantile(0.75)
        IQR = Q3 - Q1
        lower_limit = Q1 - 1.5 * IQR
        upper_limit = Q3 + 1.5 * IQR

        outliers = df[(df[col] < lower_limit) | (df[col] > upper_limit)][col]
        
        outliers_summary[col] = {
            'num_outliers': len(outliers),
            'percent_outliers': len(outliers) / len(df) * 100,
            'outliers': outliers.tolist(),
            'lower_limit': lower_limit,
            'upper_limit': upper_limit
        }

    return pd.DataFrame(outliers_summary).T


def treat_all_outliers_iqr(df, factor=1.5):
    df_treated = df.copy()
    
    for column in df_treated.select_dtypes(include=[np.number]).columns:
        Q1 = np.percentile(df_treated[column], 25)
        Q3 = np.percentile(df_treated[column], 75)
        IQR = Q3 - Q1

        lower_bound = Q1 - factor * IQR
        upper_bound = Q3 + factor * IQR

        df_treated[column] = np.clip(df_treated[column], lower_bound, upper_bound)

    return df_treated

### 2. NANS

In [6]:
def remove_missing_columns(df, threshold=0.5):
    limit = int(threshold * len(df))
    df = df.dropna(thresh=limit, axis=1)
    return df
    
def find_most_null_column(df, threshold=0.5):
    null_ratios = df.isnull().mean()
    for col, ration in null_ratios.items():
        if ration > threshold:
            return col
    return None

def analyze_missing_values(df):
    missing_values = df.isnull().sum()
    missing_values = missing_values[missing_values > 0]
    total_number_nans = df.isnull().sum().sum()
    
    return missing_values, total_number_nans

def handle_nans(df):
    columns_with_nans = df.columns[df.isnull().any()].tolist()
    
    knn_imputer = KNNImputer(n_neighbors=5, weights='uniform', metric='nan_euclidean')
    
    df_imputed = pd.DataFrame(knn_imputer.fit_transform(df[columns_with_nans]),
                                columns=columns_with_nans)
    
    df[columns_with_nans] = df_imputed[columns_with_nans]
    
    return df

### 3. SETS & SUBSETS

In [None]:
class Sets:
    # AB = PSD (Power Spectral Density) 19 * 6
    # COH = FC (Functional Connectivity) 171 * 6
    
    def __init__(self, dataframe: pd.DataFrame):
        
        if not isinstance(dataframe, pd.DataFrame):
            raise ValueError("The parameter must be a pandas DataFrame")
        
        self.df = dataframe.copy()
        self.df_ab_psd = None
        self.df_coh_fc = None
        self.df_ab_psd_coh_fc = None
        
        self.__create_df_psd()
        self.__create_df_fc()
        self.__create_union_psd_fc()

    # 19 (6 bands)
    def __create_df_psd(self):
        columns_AB = [col for col in self.df.columns if col.startswith('AB')]
        self.df_ab_psd = self.df[columns_AB]
    
    # 171 (6 bands)
    def __create_df_fc(self):
        columns_COH = [col for col in self.df.columns if col.startswith('COH')]
        self.df_coh_fc = self.df[columns_COH]
        
    def __create_union_psd_fc(self):
        if self.df_ab_psd is not None and self.df_coh_fc is not None:
            self.df_ab_psd_coh_fc = pd.concat([self.df_ab_psd, self.df_coh_fc], axis=1)
        else:
            raise ValueError("Subsets AB and COH were not created correctly")

    def create_dfs_bands(self, bands: list[str] = None, df: pd.DataFrame = None):
        dfs_bands = {}
        
        if bands is None:
            bands = ['delta', 'theta', 'alpha', 'beta', 'highbeta','gamma']
        
        for band in bands:
            columns_band = [col for col in df.columns if f'.{band}.' in col]
            if columns_band:
                dfs_bands[band] = df[columns_band]
        
        return dfs_bands

## | 2. Data Loading and Preprocessing

In [8]:
df = pd.read_csv('../input/eeg-psychiatric-disorders-dataset/EEG.machinelearing_data_BRMH.csv')
df.shape

(945, 1149)

In [9]:
# Checking / imputing Nan

output1 = analyze_missing_values(df)
output2 = find_most_null_column(df)
df = remove_missing_columns(df)
output3 = analyze_missing_values(df)
df = handle_nans(df)

display(output1)
display(HTML('<hr>'))
display(output2)
display(HTML('<hr>'))
display(output3)
display(df.shape)
display(HTML('<hr>'))
display(df.isna().sum().sum())
display(df)

(education        15
 IQ               13
 Unnamed: 122    945
 dtype: int64,
 973)

'Unnamed: 122'

(education    15
 IQ           13
 dtype: int64,
 28)

(945, 1148)

0

Unnamed: 0,no.,sex,age,eeg.date,education,IQ,main.disorder,specific.disorder,AB.A.delta.a.FP1,AB.A.delta.b.FP2,...,COH.F.gamma.o.Pz.p.P4,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2
0,1,M,57.0,2012.8.30,13.43871,101.580472,Addictive disorder,Alcohol use disorder,35.998557,21.717375,...,55.989192,16.739679,23.452271,45.678820,30.167520,16.918761,48.850427,9.422630,34.507082,28.613029
1,2,M,37.0,2012.9.6,6.00000,120.000000,Addictive disorder,Alcohol use disorder,13.425118,11.002916,...,45.595619,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261
2,3,M,32.0,2012.9.10,16.00000,113.000000,Addictive disorder,Alcohol use disorder,29.941780,27.544684,...,99.475453,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799
3,4,M,35.0,2012.10.8,18.00000,126.000000,Addictive disorder,Alcohol use disorder,21.496226,21.846832,...,59.986561,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873
4,5,M,36.0,2012.10.18,16.00000,112.000000,Addictive disorder,Alcohol use disorder,37.775667,33.607679,...,61.462720,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
940,941,M,22.0,2014.8.28,13.00000,116.000000,Healthy control,Healthy control,41.851823,36.771496,...,82.905657,34.850706,63.970519,63.982003,51.244725,62.203684,62.062237,31.013031,31.183413,98.325230
941,942,M,26.0,2014.9.19,13.00000,118.000000,Healthy control,Healthy control,18.986856,19.401387,...,65.917918,66.700117,44.756285,49.787513,98.905995,54.021304,93.902401,52.740396,92.807331,56.320868
942,943,M,26.0,2014.9.27,16.00000,113.000000,Healthy control,Healthy control,28.781317,32.369230,...,61.040959,27.632209,45.552852,33.638817,46.690983,19.382928,41.050717,7.045821,41.962451,19.092111
943,944,M,24.0,2014.9.20,13.00000,107.000000,Healthy control,Healthy control,19.929100,25.196375,...,99.113664,48.328934,41.248470,28.192238,48.665743,42.007147,28.735945,27.176500,27.529522,20.028446


In [10]:
# detecting outliers

outliers_summary = detect_outliers_summary(df)
outliers_summary.to_csv('outliers_summary.csv', index=True)
outliers_summary['num_outliers'].sum()

26452

In [None]:
label_encoder = LabelEncoder()
target_name = 'main.disorder'

X = df.iloc[:,8:]
target = df[target_name]

quantitative_features = df.loc[:, ['age', 'education', 'IQ']]

sex = df['sex']
sex_encoded = label_encoder.fit_transform(sex)
sex_encoded = pd.Series(sex_encoded, name='sex')

X = pd.concat([quantitative_features, sex_encoded, X], axis=1)

X.shape, target.shape, quantitative_features.shape, sex.shape

((945, 1144), (945,), (945, 3), (945,))

In [None]:
disorders = np.unique(target).tolist()
disorders.remove('Healthy control')
disorders

['Addictive disorder',
 'Anxiety disorder',
 'Mood disorder',
 'Obsessive compulsive disorder',
 'Schizophrenia',
 'Trauma and stress related disorder']

In [13]:
X_concated = pd.concat([X, target], axis=1)

display(X.head())
display(X_concated.head())
display(X.isna().sum().sum())
display(X_concated.isna().sum().sum())

Unnamed: 0,age,education,IQ,sex,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,...,COH.F.gamma.o.Pz.p.P4,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2
0,57.0,13.43871,101.580472,1,35.998557,21.717375,21.51828,26.825048,26.611516,25.732649,...,55.989192,16.739679,23.452271,45.67882,30.16752,16.918761,48.850427,9.42263,34.507082,28.613029
1,37.0,6.0,120.0,1,13.425118,11.002916,11.942516,15.272216,14.15157,12.456034,...,45.595619,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261
2,32.0,16.0,113.0,1,29.94178,27.544684,17.150159,23.60896,27.087811,13.541237,...,99.475453,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799
3,35.0,18.0,126.0,1,21.496226,21.846832,17.364316,13.833701,14.100954,13.100939,...,59.986561,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873
4,36.0,16.0,112.0,1,37.775667,33.607679,21.865556,21.771413,22.854536,21.456377,...,61.46272,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662


Unnamed: 0,age,education,IQ,sex,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,...,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2,main.disorder
0,57.0,13.43871,101.580472,1,35.998557,21.717375,21.51828,26.825048,26.611516,25.732649,...,16.739679,23.452271,45.67882,30.16752,16.918761,48.850427,9.42263,34.507082,28.613029,Addictive disorder
1,37.0,6.0,120.0,1,13.425118,11.002916,11.942516,15.272216,14.15157,12.456034,...,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261,Addictive disorder
2,32.0,16.0,113.0,1,29.94178,27.544684,17.150159,23.60896,27.087811,13.541237,...,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799,Addictive disorder
3,35.0,18.0,126.0,1,21.496226,21.846832,17.364316,13.833701,14.100954,13.100939,...,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873,Addictive disorder
4,36.0,16.0,112.0,1,37.775667,33.607679,21.865556,21.771413,22.854536,21.456377,...,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662,Addictive disorder


0

0

In [14]:
# treat outliers X_concated

X = treat_all_outliers_iqr(X, factor=1.5)
outliers_summary = detect_outliers_summary(X)
outliers_summary['num_outliers'].sum()

0

In [15]:
binary_datasets = {}

for disorder in disorders:
    
    mask = X_concated[target_name].isin(['Healthy control', disorder])
    
    df_filtered = X_concated[mask]
    df_filtered_X = X.loc[mask]
    
    y_binary = df_filtered[target_name].apply(lambda x: 0 if x == 'Healthy control' else 1)
    
    borderline_smote = BorderlineSMOTE(kind='borderline-2', random_state=42, k_neighbors=10, m_neighbors=15)
    X_resampled, y_resampled = borderline_smote.fit_resample(df_filtered_X, y_binary)
    
    dataset = pd.concat([X_resampled, y_resampled], axis=1)
    
    binary_datasets[disorder] = {
        'data': dataset,
        'labels': ['Healthy control', disorder]
    }

In [16]:
binary_datasets[disorders[0]]['data']

Unnamed: 0,age,education,IQ,sex,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,...,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2,main.disorder
0,56.030000,13.438710,101.580472,1,35.998557,21.717375,21.518280,26.825048,26.611516,25.732649,...,16.739679,23.452271,45.678820,30.167520,16.918761,48.850427,9.422630,34.507082,28.613029,1
1,37.000000,6.000000,120.000000,1,13.425118,11.002916,11.942516,15.272216,14.151570,12.456034,...,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261,1
2,32.000000,16.000000,113.000000,1,29.941780,27.544684,17.150159,23.608960,27.087811,13.541237,...,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799,1
3,35.000000,18.000000,126.000000,1,21.496226,21.846832,17.364316,13.833701,14.100954,13.100939,...,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873,1
4,36.000000,16.000000,112.000000,1,37.775667,33.607679,21.865556,21.771413,22.854536,21.456377,...,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
367,23.936684,12.914847,106.786020,0,14.437635,12.248253,12.008394,15.753527,16.710368,12.349961,...,62.841966,75.293952,64.040256,95.871374,46.718100,79.288627,44.872529,78.717620,54.905455,0
368,23.931990,15.948992,102.282243,1,13.138764,21.004126,12.924976,13.116845,17.862602,17.925295,...,16.902791,39.786575,38.205377,32.522101,29.408334,52.904032,10.820888,34.237195,32.547049,0
369,22.000000,13.000000,133.026971,1,15.690757,14.745960,11.360117,15.665005,20.809406,17.723404,...,39.895060,36.969294,35.186558,50.849725,31.069971,42.293424,21.050751,35.997421,47.988328,0
370,25.983926,16.190356,121.454457,1,17.235063,18.973345,13.134278,18.516893,21.669362,22.479819,...,64.383890,72.705952,77.612578,74.304228,61.461376,81.776479,46.188054,78.515296,67.269585,0


In [None]:
sets_binaries = {}

for disorder, bin_data in binary_datasets.items():
    df_bin = bin_data['data'].copy()
    
    X_bin = df_bin.drop(columns=[target_name,'age', 'education', 'IQ', 'sex'], errors='ignore') 
    X_quantitative = df_bin.loc[:, ['age', 'education', 'IQ', 'sex']]
    y_bin = df_bin[target_name] if target_name in df_bin.columns else df_bin.iloc[:, -1]

    sets = Sets(dataframe=X_bin)

    sets_binaries[disorder] = {
        'sets': sets,
        'labels': bin_data['labels'],
        'X_quantitative': X_quantitative,
        'y': y_bin
    }

In [18]:
for disorder in sets_binaries:
    sets = sets_binaries[disorder]['sets']

    dfs_bands_psd = sets.create_dfs_bands(df=sets.df_ab_psd)
    dfs_bands_fc = sets.create_dfs_bands(df=sets.df_coh_fc)
    dfs_bands_psd_fc = sets.create_dfs_bands(df=sets.df_ab_psd_coh_fc)

    sets_binaries[disorder].update({
        'dfs_bands_psd': dfs_bands_psd,
        'dfs_bands_fc': dfs_bands_fc,
        'dfs_bands_psd_fc': dfs_bands_psd_fc,
    })

In [19]:
for disorder, data in sets_binaries.items():
    print(f"Disorder: {disorder}")
    print(f"Labels: {data['labels']}")
    print(f"Number of samples: {len(data['y'])}")
    print(f"Quantitative features shape: {data['X_quantitative'].shape}")
    print(f"PSD bands shape: {data['sets'].df_ab_psd.shape}")
    print(f"FC bands shape: {data['sets'].df_coh_fc.shape}")
    print(f"PSD + FC bands shape: {data['sets'].df_ab_psd_coh_fc.shape}\n")

Disorder: Addictive disorder
Labels: ['Healthy control', 'Addictive disorder']
Number of samples: 372
Quantitative features shape: (372, 4)
PSD bands shape: (372, 114)
FC bands shape: (372, 1026)
PSD + FC bands shape: (372, 1140)

Disorder: Anxiety disorder
Labels: ['Healthy control', 'Anxiety disorder']
Number of samples: 214
Quantitative features shape: (214, 4)
PSD bands shape: (214, 114)
FC bands shape: (214, 1026)
PSD + FC bands shape: (214, 1140)

Disorder: Mood disorder
Labels: ['Healthy control', 'Mood disorder']
Number of samples: 532
Quantitative features shape: (532, 4)
PSD bands shape: (532, 114)
FC bands shape: (532, 1026)
PSD + FC bands shape: (532, 1140)

Disorder: Obsessive compulsive disorder
Labels: ['Healthy control', 'Obsessive compulsive disorder']
Number of samples: 190
Quantitative features shape: (190, 4)
PSD bands shape: (190, 114)
FC bands shape: (190, 1026)
PSD + FC bands shape: (190, 1140)

Disorder: Schizophrenia
Labels: ['Healthy control', 'Schizophrenia']

In [None]:
dfs_dicts_binaries = {}

for disorder, data in sets_binaries.items():
    dfs_psd = data['dfs_bands_psd']
    dfs_fc = data['dfs_bands_fc']
    dfs_psd_fc = data['dfs_bands_psd_fc']
    sets = data['sets']

    df_dict_psd_all_bands = {
        'psd_all_bands': sets.df_ab_psd,
    }

    df_dict_fc_all_bands = {
        'fc_all_bands': sets.df_coh_fc,
    }

    df_dict_psd_fc_all_bands = {
        'psd_fc_all_bands': sets.df_ab_psd_coh_fc,
    }

    df_dict_psd_band = {
        f'psd_{band}': dfs_psd[band] for band in dfs_psd
    }

    df_dict_fc_band = {
        f'fc_{band}': dfs_fc[band] for band in dfs_fc
    }

    df_dict_psd_fc_band = {
        f'psd_fc_{band}': dfs_psd_fc[band] for band in dfs_psd_fc
    }

    dfs_dicts_binaries[disorder] = {
        'df_dict_psd_all_bands': df_dict_psd_all_bands,
        'df_dict_fc_all_bands': df_dict_fc_all_bands,
        'df_dict_psd_fc_all_bands': df_dict_psd_fc_all_bands,
        'df_dict_psd_band': df_dict_psd_band,
        'df_dict_fc_band': df_dict_fc_band,
        'df_dict_psd_fc_band': df_dict_psd_fc_band,
    }

In [21]:
from pytorch_tabnet.tab_model import TabNetClassifier

In [22]:
clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    scheduler_params={"step_size":30, 
                        "gamma":0.9},
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    verbose=10,
)

In [23]:
from sklearn.preprocessing import RobustScaler, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score
from collections import defaultdict

all_results = {}

for disorder, dicts in dfs_dicts_binaries.items():
    print(f"\n=== Disorder: {disorder} ===")
    
    y_binary = sets_binaries[disorder]['y']
    X_quantitative = sets_binaries[disorder]['X_quantitative']
    labels = sets_binaries[disorder]['labels']
    
    all_results[disorder] = {}
    
    for dict_name, df_dict in dicts.items():
        if not dict_name.startswith('df_dict_'):
            continue
        
        for df_name, df in df_dict.items():
            print(f"\n--- Processing: {df_name} ---")
            print("Shape:", df.shape)
            print("Missing values:", df.isna().sum().sum())
            
            df_full = pd.concat([X_quantitative, df], axis=1)
            assert all(df_full.index == y_binary.index), 'Misaligned indexes'
            
            X_train, X_test, y_train, y_test = train_test_split(
                df_full,
                y_binary,
                train_size=0.80,
                stratify=y_binary,
                random_state=42
            )
            
            X_cols = X_train.drop(columns=['sex']).columns
            
            scaler = RobustScaler()
            X_train_scaled = scaler.fit_transform(X_train[X_cols])
            X_test_scaled = scaler.transform(X_test[X_cols])
            
            X_train_scaled_df = pd.DataFrame(X_train_scaled, columns=X_cols, index=X_train.index)
            X_test_scaled_df = pd.DataFrame(X_test_scaled, columns=X_cols, index=X_test.index)

            X_train_final = pd.concat([X_train_scaled_df, X_train['sex']], axis=1)
            X_test_final = pd.concat([X_test_scaled_df, X_test['sex']], axis=1)

            X_train_array = X_train_final.to_numpy()
            X_test_array = X_test_final.to_numpy()
            y_train_array = y_train.to_numpy()
            y_test_array = y_test.to_numpy()
            
            clf.fit(
                X_train_array, y_train_array,
                eval_set=[(X_train_array, y_train_array), (X_test_array, y_test_array)],
                eval_name=['train', 'test'],
                eval_metric=['auc','accuracy'],
                max_epochs=200, patience=60,
                batch_size=48, virtual_batch_size=24,
                num_workers=0,
                weights=1,
                drop_last=True
            )
            
            y_pred = clf.predict(X_test_array)
            y_proba = clf.predict_proba(X_test_array)[:, 1]

            acc = accuracy_score(y_test_array, y_pred)
            auc = roc_auc_score(y_test_array, y_proba)
            report = classification_report(y_test_array, y_pred)
            
            all_results.setdefault(disorder, {}).setdefault(dict_name, {})[df_name] = {
                'accuracy': acc,
                'auc': auc,
                'classification_report': report,
                'features': list(X_test_final.columns),
                'X_train_shape': X_train_array.shape,
                'X_val_shape': X_test_array.shape,
                'clf': clf
            }


=== Disorder: Addictive disorder ===

--- Processing: psd_all_bands ---
Shape: (372, 114)
Missing values: 0
epoch 0  | loss: 0.75306 | train_auc: 0.52206 | train_accuracy: 0.51852 | test_auc: 0.44523 | test_accuracy: 0.46667 |  0:00:01s
epoch 10 | loss: 0.69248 | train_auc: 0.58271 | train_accuracy: 0.56566 | test_auc: 0.5825  | test_accuracy: 0.56    |  0:00:03s
epoch 20 | loss: 0.67129 | train_auc: 0.60017 | train_accuracy: 0.57576 | test_auc: 0.51636 | test_accuracy: 0.57333 |  0:00:05s
epoch 30 | loss: 0.68976 | train_auc: 0.59931 | train_accuracy: 0.54209 | test_auc: 0.633   | test_accuracy: 0.54667 |  0:00:08s
epoch 40 | loss: 0.6478  | train_auc: 0.69883 | train_accuracy: 0.62626 | test_auc: 0.64083 | test_accuracy: 0.6     |  0:00:10s
epoch 50 | loss: 0.62572 | train_auc: 0.74134 | train_accuracy: 0.63973 | test_auc: 0.55974 | test_accuracy: 0.58667 |  0:00:12s
epoch 60 | loss: 0.6033  | train_auc: 0.7481  | train_accuracy: 0.6633  | test_auc: 0.64509 | test_accuracy: 0.6     

## | 6. Results

In [24]:
rows = []

for disorder, dicts in all_results.items():
    for dict_name, df_results in dicts.items():
        for df_name, metrics in df_results.items():
            row = {
                'Disorder': disorder,
                'Dict': dict_name,
                'DataFrame': df_name,
                'Accuracy': metrics['accuracy'],
                'AUC': metrics['auc'],
                'Classification Report': metrics['classification_report']
            }
            
            rows.append(row)

results_df = pd.DataFrame(rows)
results_df.to_csv(f'{experiment_name}.csv', index=False)
results_df

Unnamed: 0,Disorder,Dict,DataFrame,Accuracy,AUC,Classification Report
0,Addictive disorder,df_dict_psd_all_bands,psd_all_bands,0.920000,0.937411,precision recall f1-score ...
1,Addictive disorder,df_dict_fc_all_bands,fc_all_bands,0.893333,0.923186,precision recall f1-score ...
2,Addictive disorder,df_dict_psd_fc_all_bands,psd_fc_all_bands,0.786667,0.842105,precision recall f1-score ...
3,Addictive disorder,df_dict_psd_band,psd_delta,0.893333,0.894737,precision recall f1-score ...
4,Addictive disorder,df_dict_psd_band,psd_theta,0.946667,0.951636,precision recall f1-score ...
...,...,...,...,...,...,...
121,Trauma and stress related disorder,df_dict_psd_fc_band,psd_fc_theta,0.903846,0.889053,precision recall f1-score ...
122,Trauma and stress related disorder,df_dict_psd_fc_band,psd_fc_alpha,0.750000,0.781065,precision recall f1-score ...
123,Trauma and stress related disorder,df_dict_psd_fc_band,psd_fc_beta,0.903846,0.933432,precision recall f1-score ...
124,Trauma and stress related disorder,df_dict_psd_fc_band,psd_fc_highbeta,0.826923,0.868343,precision recall f1-score ...


## | 7. Conclusions, Problems, and Limitations

[need to update]