| Description

Authors of notebook: Mateus Balda and Alessandro Bof

Reference paper and dataset: https://www.frontiersin.org/journals/psychiatry/articles/10.3389/fpsyt.2021.707581/full

| Notebook structure
0. Setup and Imports  
1. Utility Functions  
2. Data Loading and Preprocessing  
3. Data Preparation for Training  
4. Model Definition  
5. Training and Evaluation  
6. Results
7. Conclusions, Problems, and Limitations

| Training
1. Binary classification for `y` main.disorder (Disorder vs. Health control)
2. Balancing using BorderlineSMOTE for minority classes
3. Training across the 21 subsets

## | 0. Setup and Imports

In [1]:
!pip install -q imbalanced-learn==0.12.4;
!pip install rtdl;
!pip install libzero==0.0.4;

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.3/258.3 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hCollecting rtdl
  Downloading rtdl-0.0.13-py3-none-any.whl.metadata (1.0 kB)
Collecting torch<2,>=1.7 (from rtdl)
  Downloading torch-1.13.1-cp311-cp311-manylinux1_x86_64.whl.metadata (24 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch<2,>=1.7->rtdl)
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch<2,>=1.7->rtdl)
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu11==11.10.3.66 (from torch<2,>=1.7->rtdl)
  Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch<2,>=1.7->rtdl)
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Downloading rtdl-0.

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

from IPython.display import FileLink, display, HTML
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.impute import KNNImputer

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau

from imblearn.over_sampling import BorderlineSMOTE

from warnings import filterwarnings
filterwarnings('ignore')

torch.__version__

'1.13.1+cu117'

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
experiment_name = "binary_classification_main_disorder_fttransformer"
device

device(type='cuda')

In [4]:
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x7ba0f8467330>

## | 1. Utility Functions

### 1. OUTLIERS

In [5]:
def detect_outliers_summary(df):
    outliers_summary = {}

    for col in df.select_dtypes(include=['float64', 'int64']).columns:
        Q1 = df[col].quantile(0.25)
        Q3 = df[col].quantile(0.75)
        IQR = Q3 - Q1
        lower_limit = Q1 - 1.5 * IQR
        upper_limit = Q3 + 1.5 * IQR

        outliers = df[(df[col] < lower_limit) | (df[col] > upper_limit)][col]
        
        outliers_summary[col] = {
            'num_outliers': len(outliers),
            'percent_outliers': len(outliers) / len(df) * 100,
            'outliers': outliers.tolist(),
            'lower_limit': lower_limit,
            'upper_limit': upper_limit
        }

    return pd.DataFrame(outliers_summary).T


def treat_all_outliers_iqr(df, factor=1.5):
    df_treated = df.copy()
    
    for column in df_treated.select_dtypes(include=[np.number]).columns:
        Q1 = np.percentile(df_treated[column], 25)
        Q3 = np.percentile(df_treated[column], 75)
        IQR = Q3 - Q1

        lower_bound = Q1 - factor * IQR
        upper_bound = Q3 + factor * IQR

        df_treated[column] = np.clip(df_treated[column], lower_bound, upper_bound)

    return df_treated

### 2. NANS

In [6]:
def remove_missing_columns(df, threshold=0.5):
    limit = int(threshold * len(df))
    df = df.dropna(thresh=limit, axis=1)
    return df
    
def find_most_null_column(df, threshold=0.5):
    null_ratios = df.isnull().mean()
    for col, ration in null_ratios.items():
        if ration > threshold:
            return col
    return None

def analyze_missing_values(df):
    missing_values = df.isnull().sum()
    missing_values = missing_values[missing_values > 0]
    total_number_nans = df.isnull().sum().sum()
    
    return missing_values, total_number_nans

def handle_nans(df):
    columns_with_nans = df.columns[df.isnull().any()].tolist()
    
    knn_imputer = KNNImputer(n_neighbors=5, weights='uniform', metric='nan_euclidean')
    
    df_imputed = pd.DataFrame(knn_imputer.fit_transform(df[columns_with_nans]),
                                columns=columns_with_nans)
    
    df[columns_with_nans] = df_imputed[columns_with_nans]
    
    return df

### 3. SETS & SUBSETS

In [7]:
class Sets:
    # AB = PSD (Power Spectral Density) 19 * 6
    # COH = FC (Functional Connectivity) 171 * 6
    
    def __init__(self, dataframe: pd.DataFrame):
        
        if not isinstance(dataframe, pd.DataFrame):
            raise ValueError("The parameter must be a pandas DataFrame")
        
        self.df = dataframe.copy()
        self.df_ab_psd = None
        self.df_coh_fc = None
        self.df_ab_psd_coh_fc = None
        
        self.__create_df_psd()
        self.__create_df_fc()
        self.__create_union_psd_fc()

    # 19 (6 bands)
    def __create_df_psd(self):
        columns_AB = [col for col in self.df.columns if col.startswith('AB')]
        self.df_ab_psd = self.df[columns_AB]
    
    # 171 (6 bands)
    def __create_df_fc(self):
        columns_COH = [col for col in self.df.columns if col.startswith('COH')]
        self.df_coh_fc = self.df[columns_COH]
        
    def __create_union_psd_fc(self):
        if self.df_ab_psd is not None and self.df_coh_fc is not None:
            self.df_ab_psd_coh_fc = pd.concat([self.df_ab_psd, self.df_coh_fc], axis=1)
        else:
            raise ValueError("Subsets AB and COH were not created correctly")

    def create_dfs_bands(self, bands: list[str] = None, df: pd.DataFrame = None):
        dfs_bands = {}
        
        if bands is None:
            bands = ['delta', 'theta', 'alpha', 'beta', 'highbeta','gamma']
        
        for band in bands:
            columns_band = [col for col in df.columns if f'.{band}.' in col]
            if columns_band:
                dfs_bands[band] = df[columns_band]
        
        return dfs_bands

## | 2. Data Loading and Preprocessing

In [8]:
df = pd.read_csv('../input/eeg-psychiatric-disorders-dataset/EEG.machinelearing_data_BRMH.csv')
df.shape

(945, 1149)

In [9]:
# Checking / imputing Nan

output1 = analyze_missing_values(df)
output2 = find_most_null_column(df)
df = remove_missing_columns(df)
output3 = analyze_missing_values(df)
df = handle_nans(df)

display(output1)
display(HTML('<hr>'))
display(output2)
display(HTML('<hr>'))
display(output3)
display(df.shape)
display(HTML('<hr>'))
display(df.isna().sum().sum())
display(df)

(education        15
 IQ               13
 Unnamed: 122    945
 dtype: int64,
 973)

'Unnamed: 122'

(education    15
 IQ           13
 dtype: int64,
 28)

(945, 1148)

0

Unnamed: 0,no.,sex,age,eeg.date,education,IQ,main.disorder,specific.disorder,AB.A.delta.a.FP1,AB.A.delta.b.FP2,...,COH.F.gamma.o.Pz.p.P4,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2
0,1,M,57.0,2012.8.30,13.43871,101.580472,Addictive disorder,Alcohol use disorder,35.998557,21.717375,...,55.989192,16.739679,23.452271,45.678820,30.167520,16.918761,48.850427,9.422630,34.507082,28.613029
1,2,M,37.0,2012.9.6,6.00000,120.000000,Addictive disorder,Alcohol use disorder,13.425118,11.002916,...,45.595619,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261
2,3,M,32.0,2012.9.10,16.00000,113.000000,Addictive disorder,Alcohol use disorder,29.941780,27.544684,...,99.475453,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799
3,4,M,35.0,2012.10.8,18.00000,126.000000,Addictive disorder,Alcohol use disorder,21.496226,21.846832,...,59.986561,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873
4,5,M,36.0,2012.10.18,16.00000,112.000000,Addictive disorder,Alcohol use disorder,37.775667,33.607679,...,61.462720,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
940,941,M,22.0,2014.8.28,13.00000,116.000000,Healthy control,Healthy control,41.851823,36.771496,...,82.905657,34.850706,63.970519,63.982003,51.244725,62.203684,62.062237,31.013031,31.183413,98.325230
941,942,M,26.0,2014.9.19,13.00000,118.000000,Healthy control,Healthy control,18.986856,19.401387,...,65.917918,66.700117,44.756285,49.787513,98.905995,54.021304,93.902401,52.740396,92.807331,56.320868
942,943,M,26.0,2014.9.27,16.00000,113.000000,Healthy control,Healthy control,28.781317,32.369230,...,61.040959,27.632209,45.552852,33.638817,46.690983,19.382928,41.050717,7.045821,41.962451,19.092111
943,944,M,24.0,2014.9.20,13.00000,107.000000,Healthy control,Healthy control,19.929100,25.196375,...,99.113664,48.328934,41.248470,28.192238,48.665743,42.007147,28.735945,27.176500,27.529522,20.028446


In [10]:
# detecting outliers

outliers_summary = detect_outliers_summary(df)
outliers_summary.to_csv('outliers_summary.csv', index=True)
outliers_summary['num_outliers'].sum()

26452

In [11]:
label_encoder = LabelEncoder()
target_name = 'main.disorder'

X = df.iloc[:,8:]
target_main = df[target_name]

quantitative_features = df.loc[:, ['age', 'education', 'IQ']]

sex = df['sex']
sex_encoded = label_encoder.fit_transform(sex)
sex_encoded = pd.Series(sex_encoded, name='sex')

X = pd.concat([quantitative_features, sex_encoded, X], axis=1)

X.shape, target_main.shape, quantitative_features.shape, sex.shape

((945, 1144), (945,), (945, 3), (945,))

In [12]:
main_disorders = np.unique(target_main).tolist()
main_disorders.remove('Healthy control')
main_disorders

['Addictive disorder',
 'Anxiety disorder',
 'Mood disorder',
 'Obsessive compulsive disorder',
 'Schizophrenia',
 'Trauma and stress related disorder']

In [13]:
X_concated = pd.concat([X, target_main], axis=1)

display(X.head())
display(X_concated.head())
display(X.isna().sum().sum())
display(X_concated.isna().sum().sum())

Unnamed: 0,age,education,IQ,sex,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,...,COH.F.gamma.o.Pz.p.P4,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2
0,57.0,13.43871,101.580472,1,35.998557,21.717375,21.51828,26.825048,26.611516,25.732649,...,55.989192,16.739679,23.452271,45.67882,30.16752,16.918761,48.850427,9.42263,34.507082,28.613029
1,37.0,6.0,120.0,1,13.425118,11.002916,11.942516,15.272216,14.15157,12.456034,...,45.595619,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261
2,32.0,16.0,113.0,1,29.94178,27.544684,17.150159,23.60896,27.087811,13.541237,...,99.475453,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799
3,35.0,18.0,126.0,1,21.496226,21.846832,17.364316,13.833701,14.100954,13.100939,...,59.986561,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873
4,36.0,16.0,112.0,1,37.775667,33.607679,21.865556,21.771413,22.854536,21.456377,...,61.46272,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662


Unnamed: 0,age,education,IQ,sex,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,...,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2,main.disorder
0,57.0,13.43871,101.580472,1,35.998557,21.717375,21.51828,26.825048,26.611516,25.732649,...,16.739679,23.452271,45.67882,30.16752,16.918761,48.850427,9.42263,34.507082,28.613029,Addictive disorder
1,37.0,6.0,120.0,1,13.425118,11.002916,11.942516,15.272216,14.15157,12.456034,...,17.510824,26.777368,28.201062,57.108861,32.375401,60.351749,13.900981,57.831848,43.463261,Addictive disorder
2,32.0,16.0,113.0,1,29.94178,27.544684,17.150159,23.60896,27.087811,13.541237,...,70.654171,39.131547,69.920996,71.063644,38.534505,69.908764,27.180532,64.803155,31.485799,Addictive disorder
3,35.0,18.0,126.0,1,21.496226,21.846832,17.364316,13.833701,14.100954,13.100939,...,63.822201,36.478254,47.117006,84.658376,24.724096,50.299349,35.319695,79.822944,41.141873,Addictive disorder
4,36.0,16.0,112.0,1,37.775667,33.607679,21.865556,21.771413,22.854536,21.456377,...,59.166097,51.465531,58.635415,80.685608,62.138436,75.888749,61.003944,87.455509,70.531662,Addictive disorder


0

0

In [14]:
# treat outliers X_concated

X = treat_all_outliers_iqr(X, factor=1.5)
outliers_summary = detect_outliers_summary(X)
outliers_summary['num_outliers'].sum()

0

In [15]:
binary_datasets = {}

for disorder in main_disorders:
    
    mask = X_concated[target_name].isin(['Healthy control', disorder])
    
    df_filtered = X_concated[mask]
    df_filtered_X = X.loc[mask]
    
    y_binary = df_filtered[target_name].apply(lambda x: 0 if x == 'Healthy control' else 1)
    
    borderline_smote = BorderlineSMOTE(kind='borderline-2', random_state=42, k_neighbors=10, m_neighbors=15)
    X_resampled, y_resampled = borderline_smote.fit_resample(df_filtered_X, y_binary)
    
    dataset = pd.concat([X_resampled, y_resampled], axis=1)
    
    binary_datasets[disorder] = {
        'data': dataset,
        'labels': ['Healthy control', disorder]
    }

In [16]:
binary_datasets['Mood disorder']['data']

Unnamed: 0,age,education,IQ,sex,AB.A.delta.a.FP1,AB.A.delta.b.FP2,AB.A.delta.c.F7,AB.A.delta.d.F3,AB.A.delta.e.Fz,AB.A.delta.f.F4,...,COH.F.gamma.o.Pz.q.T6,COH.F.gamma.o.Pz.r.O1,COH.F.gamma.o.Pz.s.O2,COH.F.gamma.p.P4.q.T6,COH.F.gamma.p.P4.r.O1,COH.F.gamma.p.P4.s.O2,COH.F.gamma.q.T6.r.O1,COH.F.gamma.q.T6.s.O2,COH.F.gamma.r.O1.s.O2,main.disorder
0,32.870000,16.000000,108.000000,0,12.159137,13.113503,9.031007,14.879389,15.834830,19.595759,...,59.590594,77.310851,75.280467,57.311188,48.015594,59.579033,68.503920,82.885151,86.986191,1
1,20.240000,12.000000,127.000000,0,12.404484,9.737819,13.925651,12.325169,15.130696,10.292518,...,53.402639,60.535899,73.838548,72.298636,43.330485,71.298037,23.109295,62.098128,67.749204,1
2,19.890000,13.000000,113.000000,0,16.573145,15.586708,14.094928,12.660197,12.979617,10.611797,...,20.062716,36.041763,36.436509,47.491240,20.875426,53.747615,8.009907,50.202188,33.647889,1
3,39.180000,16.000000,112.600000,0,26.650019,22.823161,17.942133,17.148214,15.338128,14.571966,...,83.045735,85.752150,85.875624,86.286859,79.520327,84.541455,82.452224,86.950905,87.384296,1
4,28.420000,9.000000,98.800000,0,14.624474,14.277301,11.099375,16.047926,17.851932,17.165379,...,58.218838,61.211830,65.038486,76.899037,49.849540,71.791429,39.910653,82.739819,62.285969,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
527,29.859179,16.056555,129.547562,1,14.332130,23.574633,11.045143,18.351897,16.785585,19.659401,...,31.533846,55.828495,45.141456,50.233665,45.187841,55.899282,19.821091,39.923730,47.636457,0
528,29.536491,12.767765,122.175057,0,16.018351,17.909770,13.825752,14.791642,18.013545,20.968432,...,86.092666,53.775009,92.172863,90.911438,45.939741,88.578716,42.796943,85.889060,55.994294,0
529,27.580821,16.000000,114.744465,1,23.887229,23.936216,16.964908,23.727482,27.984030,22.996285,...,65.154648,48.855452,65.079410,69.760110,47.340168,68.406285,31.144459,94.580466,36.281022,0
530,21.262096,12.877715,104.755431,1,41.877695,44.607081,34.711693,28.247273,35.248240,36.466592,...,68.056056,70.364430,73.989504,74.344268,53.351700,72.497074,48.287194,79.983862,64.618663,0


In [17]:
sets_binaries = {}

for disorder, bin_data in binary_datasets.items():
    df_bin = bin_data['data'].copy()
    
    X_bin = df_bin.drop(columns=[target_name,'age', 'education', 'IQ', 'sex'], errors='ignore') 
    X_quantitative = df_bin.loc[:, ['age', 'education', 'IQ', 'sex']]
    y_bin = df_bin[target_name] if target_name in df_bin.columns else df_bin.iloc[:, -1]

    sets = Sets(dataframe=X_bin)

    sets_binaries[disorder] = {
        'sets': sets,
        'labels': bin_data['labels'],
        'X_quantitative': X_quantitative,
        'y': y_bin
    }

In [18]:
for disorder in sets_binaries:
    sets = sets_binaries[disorder]['sets']

    dfs_bands_psd = sets.create_dfs_bands(df=sets.df_ab_psd)
    dfs_bands_fc = sets.create_dfs_bands(df=sets.df_coh_fc)
    dfs_bands_psd_fc = sets.create_dfs_bands(df=sets.df_ab_psd_coh_fc)

    sets_binaries[disorder].update({
        'dfs_bands_psd': dfs_bands_psd,
        'dfs_bands_fc': dfs_bands_fc,
        'dfs_bands_psd_fc': dfs_bands_psd_fc,
    })

In [19]:
for disorder, data in sets_binaries.items():
    print(f"Disorder: {disorder}")
    print(f"Labels: {data['labels']}")
    print(f"Number of samples: {len(data['y'])}")
    print(f"Quantitative features shape: {data['X_quantitative'].shape}")
    print(f"PSD bands shape: {data['sets'].df_ab_psd.shape}")
    print(f"FC bands shape: {data['sets'].df_coh_fc.shape}")
    print(f"PSD + FC bands shape: {data['sets'].df_ab_psd_coh_fc.shape}\n")

Disorder: Addictive disorder
Labels: ['Healthy control', 'Addictive disorder']
Number of samples: 372
Quantitative features shape: (372, 4)
PSD bands shape: (372, 114)
FC bands shape: (372, 1026)
PSD + FC bands shape: (372, 1140)

Disorder: Anxiety disorder
Labels: ['Healthy control', 'Anxiety disorder']
Number of samples: 214
Quantitative features shape: (214, 4)
PSD bands shape: (214, 114)
FC bands shape: (214, 1026)
PSD + FC bands shape: (214, 1140)

Disorder: Mood disorder
Labels: ['Healthy control', 'Mood disorder']
Number of samples: 532
Quantitative features shape: (532, 4)
PSD bands shape: (532, 114)
FC bands shape: (532, 1026)
PSD + FC bands shape: (532, 1140)

Disorder: Obsessive compulsive disorder
Labels: ['Healthy control', 'Obsessive compulsive disorder']
Number of samples: 190
Quantitative features shape: (190, 4)
PSD bands shape: (190, 114)
FC bands shape: (190, 1026)
PSD + FC bands shape: (190, 1140)

Disorder: Schizophrenia
Labels: ['Healthy control', 'Schizophrenia']

In [20]:
dfs_dicts_binaries = {}

for disorder, data in sets_binaries.items():
    dfs_psd = data['dfs_bands_psd']
    dfs_fc = data['dfs_bands_fc']
    dfs_psd_fc = data['dfs_bands_psd_fc']
    sets = data['sets']

    df_dict_psd_all_bands = {
        'psd_all_bands': sets.df_ab_psd,
    }

    df_dict_fc_all_bands = {
        'fc_all_bands': sets.df_coh_fc,
    }

    df_dict_psd_fc_all_bands = {
        'psd_fc_all_bands': sets.df_ab_psd_coh_fc,
    }

    df_dict_psd_band = {
        f'psd_{band}': dfs_psd[band] for band in dfs_psd
    }

    df_dict_fc_band = {
        f'fc_{band}': dfs_fc[band] for band in dfs_fc
    }

    df_dict_psd_fc_band = {
        f'psd_fc_{band}': dfs_psd_fc[band] for band in dfs_psd_fc
    }

    dfs_dicts_binaries[disorder] = {
        'df_dict_psd_all_bands': df_dict_psd_all_bands,
        'df_dict_fc_all_bands': df_dict_fc_all_bands,
        'df_dict_psd_fc_all_bands': df_dict_psd_fc_all_bands,
        'df_dict_psd_band': df_dict_psd_band,
        'df_dict_fc_band': df_dict_fc_band,
        'df_dict_psd_fc_band': df_dict_psd_fc_band,
    }

In [21]:
import torch.nn as nn
import torch
import rtdl
from rtdl import FTTransformer
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import RobustScaler, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score

In [22]:
def fit(model, train_loader, epochs=200, lr=1e-3, weight_decay=1e-4, device='cpu', patience=20):
        model.to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

        best_acc = 0
        patience_counter = 0
        history = {'train_loss': [], 'train_acc': []}
        
        for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            correct, total = 0, 0
            
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                
                assert set(targets.cpu().tolist()).issubset({0.0, 1.0}), "Targets must be binary (0 or 1)"
                
                optimizer.zero_grad()
                outputs = model(inputs, None)
                
                # shape
                outputs = outputs.view(-1).float()
                targets = targets.view(-1).float() 
                
                assert outputs.shape == targets.shape, "Outputs and targets must have the same shape"
                assert outputs.dtype == torch.float32 and targets.dtype == torch.float32, "Outputs and targets must be float32"
                assert outputs.dim() == 1 and targets.dim() == 1, "Outputs and targets must be 1-dimensional"
                
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                
                preds = (torch.sigmoid(outputs) > 0.5).long()
                correct += (preds == targets.long()).sum().item()
                total += targets.size(0)
                running_loss += loss.item() * inputs.size(0)
                
            train_acc = correct / total
            epoch_loss = running_loss / len(train_loader.dataset)
            history['train_acc'].append(train_acc)
            history['train_loss'].append(epoch_loss)
            
            if train_acc > best_acc:
                best_acc = train_acc
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print("Early stopping triggered.")
                break
            
            scheduler.step(epoch_loss)
            
            print(
                f"Epoch [{epoch+1}/{epochs}] | "
                f"Train Loss: {epoch_loss:.4f}, Train Acc: {train_acc:.4f} | "
            )
        
        return history


def test(model, test_loader, device='cpu'):
    model.to(device)
    model.eval()
    
    all_preds, all_probas, all_targets = [], [], []
    correct, total = 0, 0
    test_loss = 0.0
    criterion = nn.BCEWithLogitsLoss()
    history = {'test_loss': [], 'test_acc': [], 'test_auc': []}

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs, None)
            
            outputs = outputs.view(-1).float()
            targets = targets.view(-1).float() 
            
            loss = criterion(outputs, targets)
            
            proba = torch.sigmoid(outputs)
            pred = (proba > 0.5).long()

            all_preds.append(pred.cpu())
            all_probas.append(proba.cpu())
            all_targets.append(targets.cpu())
            
            correct += (pred == targets.long()).sum().item()
            total += targets.size(0)
            test_loss += loss.item() * inputs.size(0)
    
    test_acc = correct / total
    test_loss /= len(test_loader.dataset)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    
    y_true = torch.cat(all_targets).numpy().astype(int)
    y_pred = torch.cat(all_preds).numpy().astype(int)
    y_proba = torch.cat(all_probas).numpy().astype(float)

    try:
        auc = roc_auc_score(y_true, y_proba)
    except ValueError:
        auc = float('nan')

    report = classification_report(y_true, y_pred, output_dict=True)
    history['test_auc'].append(auc)

    print(
        f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}, Test AUC: {auc:.4f}"
        )
    
    return test_acc, auc, report, history

In [23]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
import gc

In [None]:
all_results = {}

for disorder, dicts in dfs_dicts_binaries.items():
    print(f"\n=== Disorder: {disorder} ===")
    
    y_binary = sets_binaries[disorder]['y']
    X_quantitative = sets_binaries[disorder]['X_quantitative']
    labels = sets_binaries[disorder]['labels']
    
    all_results[disorder] = {}
    
    for dict_name, df_dict in dicts.items():
        if not dict_name.startswith('df_dict_'):
            continue
        
        for df_name, df in df_dict.items():
            print(f"\n--- Processing: {df_name} ---")
            print("Shape:", df.shape)
            print("Missing values:", df.isna().sum().sum())
            
            df_full = pd.concat([X_quantitative, df], axis=1)
            assert all(df_full.index == y_binary.index), 'Misaligned indexes'
            
            X_train, X_test, y_train, y_test = train_test_split(
                df_full,
                y_binary,
                train_size=0.80,
                stratify=y_binary,
                random_state=42
            )
            
            X_cols = X_train.drop(columns=['sex']).columns
            
            scaler = RobustScaler()
            X_train_scaled = scaler.fit_transform(X_train[X_cols])
            X_test_scaled = scaler.transform(X_test[X_cols])
            
            X_train_scaled_df = pd.DataFrame(X_train_scaled, columns=X_cols, index=X_train.index)
            X_test_scaled_df = pd.DataFrame(X_test_scaled, columns=X_cols, index=X_test.index)

            X_train_final = pd.concat([X_train_scaled_df, X_train['sex']], axis=1)
            X_test_final = pd.concat([X_test_scaled_df, X_test['sex']], axis=1)

            X_train_array = X_train_final.to_numpy()
            X_test_array = X_test_final.to_numpy()
            y_train_array = y_train.to_numpy()
            y_test_array = y_test.to_numpy()
            
            # tensors
            X_train_tensor = torch.tensor(X_train_array, dtype=torch.float32).to(device)
            y_train_tensor = torch.tensor(y_train_array, dtype=torch.long).to(device)
            X_test_tensor = torch.tensor(X_test_array, dtype=torch.float32).to(device)
            y_test_tensor = torch.tensor(y_test_array, dtype=torch.long).to(device)
            
            train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
            test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

            train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
            test_loader = DataLoader(test_dataset, batch_size=32)
            
            model = FTTransformer.make_default(
                n_num_features=X_train_array.shape[1],
                cat_cardinalities=None,
                last_layer_query_idx=[-1],
                d_out=1,  # binary
            )
            
            fit_history = fit(
                model=model,
                train_loader=train_loader,
                epochs=200,
                lr=1e-3,
                weight_decay=1e-4,
                device=device,
                patience=20
            )
            
            acc, auc, report, test_history = test(
                model=model,
                test_loader=test_loader,
                device=device
            )
            
            all_results.setdefault(disorder, {}).setdefault(dict_name, {})[df_name] = {
                'accuracy': acc,
                'auc': auc,
                'classification_report': report,
                'features': list(X_train_final.columns),
                'X_train_shape': X_train_tensor.shape,
                'X_test_shape': X_test_tensor.shape,
                'history': {
                    'fit': fit_history,
                    'test': test_history
                }
            }
            
            # memory cleaning
            del model
            del X_train_tensor, X_test_tensor, y_train_tensor, y_test_tensor
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()


=== Disorder: Addictive disorder ===

--- Processing: psd_all_bands ---
Shape: (372, 114)
Missing values: 0
Epoch [1/200] | Train Loss: 0.7246, Train Acc: 0.5185 | 
Epoch [2/200] | Train Loss: 0.6970, Train Acc: 0.5286 | 
Epoch [3/200] | Train Loss: 0.6957, Train Acc: 0.4983 | 
Epoch [4/200] | Train Loss: 0.6883, Train Acc: 0.5387 | 
Epoch [5/200] | Train Loss: 0.6726, Train Acc: 0.5926 | 
Epoch [6/200] | Train Loss: 0.5658, Train Acc: 0.7340 | 
Epoch [7/200] | Train Loss: 0.5724, Train Acc: 0.6768 | 
Epoch [8/200] | Train Loss: 0.5268, Train Acc: 0.7374 | 
Epoch [9/200] | Train Loss: 0.4473, Train Acc: 0.8047 | 
Epoch [10/200] | Train Loss: 0.5451, Train Acc: 0.7710 | 
Epoch [11/200] | Train Loss: 0.4532, Train Acc: 0.8047 | 
Epoch [12/200] | Train Loss: 0.4359, Train Acc: 0.7980 | 
Epoch [13/200] | Train Loss: 0.3977, Train Acc: 0.8350 | 
Epoch [14/200] | Train Loss: 0.5103, Train Acc: 0.7744 | 
Epoch [15/200] | Train Loss: 0.4825, Train Acc: 0.7912 | 
Epoch [16/200] | Train Loss: 0

## | 6. Results

In [None]:
rows = []

for disorder, dicts in all_results.items():
    for dict_name, df_results in dicts.items():
        for df_name, metrics in df_results.items():
            row = {
                'Disorder': disorder,
                'Dict': dict_name,
                'DataFrame': df_name,
                'Accuracy': metrics['accuracy'],
                'AUC': metrics['auc'],
                'Classification Report': metrics['classification_report']
            }
            
            rows.append(row)

results_df = pd.DataFrame(rows)
results_df.to_csv(f'{experiment_name}.csv', index=False)

# best_results = df.loc[df.groupby('Disorder')['Accuracy'].idxmax()]

## | 7. Conclusions, Problems, and Limitations

[need to update]