In [1]:
import numpy as np
import json
import matplotlib.pyplot as plt
import mne               # package for EEG and MEG data analysis
import seaborn as sns    # for visualization 
import os
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
import antropy as an
import pandas as pd
import imblearn


import warnings
# NOTE(review): this ignores *every* Python warning (e.g. sklearn ConvergenceWarning too),
# not just MNE noise — consider narrowing by category/module.
warnings.simplefilter("ignore")
# MNE logging: keep only ERROR messages (hides INFO + WARNING).
mne.set_log_level(verbose="ERROR")

# Loading Data

In [2]:
# X_5: all windows concatenated, shape (n_windows, channels, n_samples).
# X5_window_list: one array of windows per subject; the per-subject counts are
# expected to add up to X_5.shape[0] (checked further below).
# NOTE(review): allow_pickle=True unpickles an object array, which can execute code —
# only load files produced by this project's own preprocessing.
X_5 = np.load('data_50_overlap.npy')
X5_window_list = np.load('all_windows_5p_list.npy', allow_pickle=True)

# shapes
print(
    f'Shape of concatenated data:{X_5.shape}\nShape of one patient\'s data after sliding window (N_windows, channels, n_samples): {X5_window_list[0].shape}'
)

Shape of concatenated data:(140336, 19, 500)
Shape of one patient's data after sliding window (N_windows, channels, n_smaple): (1198, 19, 500)


In [3]:
from pathlib import Path

# Dataset root. Override it with the AD_DATA_DIR environment variable instead of
# editing an absolute, machine-specific path in the notebook.
DATA_DIR = Path(os.environ.get("AD_DATA_DIR", "E:/data_AD/SearchingDataset/SearchingDataset"))
participants_file = DATA_DIR / "participants.tsv"

participants_data = pd.read_csv(participants_file, sep='\t')
print("Participants data loaded successfully!")

# Sanity-check the columns this table is expected to provide.
assert {'participant_id', 'Group'}.issubset(participants_data.columns), participants_data.columns

participants_data.head()

Participants data loaded successfully!


Unnamed: 0,participant_id,Gender,Age,Group,MMSE
0,sub-001,F,57,A,16
1,sub-002,F,78,A,22
2,sub-003,M,70,A,14
3,sub-004,F,67,A,20
4,sub-005,M,70,A,22


In [4]:
# Number of sliding windows contributed by each subject.
window_counts = [w.shape[0] for w in X5_window_list]
# The per-subject arrays must account for exactly the rows of the concatenated array.
assert sum(window_counts) == X_5.shape[0], (sum(window_counts), X_5.shape[0])
len(window_counts)

88

In [5]:
# One group label ('A' / 'C' / 'F') per window, taken from participants.tsv instead of
# hard-coded subject counts (36 / 29 / 23).
# NOTE(review): assumes X5_window_list is ordered like participants.tsv
# (sub-001, sub-002, ...) — confirm against the script that produced the .npy file.
subject_groups = participants_data['Group'].to_numpy().astype(str)
assert len(subject_groups) == len(window_counts), (
    f'{len(subject_groups)} participants vs {len(window_counts)} per-subject window arrays'
)

# Repeat each subject's label once per window of that subject.
labels = np.repeat(subject_groups, window_counts)
assert labels.shape[0] == X_5.shape[0], (labels.shape, X_5.shape)

In [6]:
np.shape(labels)  # one label per window

(140336,)

In [7]:
# X_5[0,0,:].shape

# With PSD analysis: SVD entropy of each channel's power spectrum within each frequency band

In [8]:
from mne.time_frequency import psd_array_welch
import antropy as an
import numpy as np

def extract_psd_svd_entropy(window, sfreq, bands, order=10):
    """
    Normalized SVD entropy of each channel's PSD spectrum inside each frequency band.

    The PSD is estimated with MNE's ``psd_array_welch`` using a Hann window and
    ``n_fft`` equal to the window length, so the frequency resolution is
    ``sfreq / n_times``. With ``n_overlap = n_times // 2`` and a segment as long as
    the window, effectively a single segment is used (a Hann-windowed periodogram).

    Parameters
    ----------
    window : array-like, shape (n_channels, n_times)
        One EEG window.
    sfreq : float
        Sampling frequency in Hz.
    bands : dict[str, tuple[float, float]]
        Band name -> (low, high) in Hz. Both edges are inclusive, so a bin lying
        exactly on a shared edge (e.g. 4 Hz) contributes to both adjacent bands.
    order : int
        Maximum embedding order for ``antropy.svd_entropy``. It is capped at
        ``n_bins - 1`` for bands that contain few frequency bins.

    Returns
    -------
    numpy.ndarray, shape (n_bands * n_channels,)
        Band-major layout: all channels of the first band, then all channels of
        the second band, and so on (in ``bands`` iteration order).

    Raises
    ------
    ValueError
        If ``window`` is not 2-D, or a band contains fewer than 3 PSD bins
        (the SVD-entropy embedding needs an order of at least 2).
    """
    window = np.asarray(window)
    if window.ndim != 2:
        raise ValueError(
            f"window must be 2-D (n_channels, n_times), got shape {window.shape}"
        )

    n_times = window.shape[1]
    psd, freqs = psd_array_welch(
        window,
        sfreq=sfreq,
        fmin=min(low for low, _ in bands.values()),
        fmax=max(high for _, high in bands.values()),
        n_fft=n_times,
        n_overlap=n_times // 2,
        window="hann",
        average="mean",
    )

    feature_vector = []
    for band_name, (low, high) in bands.items():
        mask = (freqs >= low) & (freqs <= high)
        n_bins = int(mask.sum())
        # Fail early with a readable message instead of deep inside antropy.
        if n_bins < 3:
            raise ValueError(
                f"Band {band_name!r} ({low}-{high} Hz) has only {n_bins} PSD bins at "
                f"{sfreq / n_times:g} Hz resolution; SVD entropy needs at least 3."
            )
        emb_order = min(order, n_bins - 1)

        # Rows are channels; each row is that channel's PSD spectrum within the band.
        for channel_psd in psd[:, mask]:
            feature_vector.append(
                an.svd_entropy(channel_psd, order=emb_order, normalize=True)
            )

    return np.array(feature_vector)


In [9]:
bands = {
    "delta": (1, 4),
    "theta": (4, 8),
    "alpha": (8, 12),
    "beta":  (12, 30),
    "gamma": (30, 45)
}

# NOTE(review): assumes the recordings are sampled at 500 Hz — confirm against the raw files.
sfreq = 500

# One feature vector per window (slow: one PSD + n_bands * n_channels SVD entropies each).
X_features = np.array([
    extract_psd_svd_entropy(window, sfreq=sfreq, bands=bands, order=10)
    for window in X_5
])

# One row per window, one column per (band, channel) pair — the column names depend on this.
assert X_features.shape == (X_5.shape[0], len(bands) * X_5.shape[1]), X_features.shape


In [10]:
np.shape(X_features)  # (n_windows, n_bands * n_channels)

(140336, 95)

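# Without Power Spectral Density (PSD) analysis (alternative, not used below)
- SVD (Singular Value Decomposition) entropy computed directly on each window. The next cell, which would split such an `entropy_features` matrix into per-band blocks, is commented out; the rest of the notebook uses the PSD-based `X_features`.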

In [10]:
# band_names = ['delta', 'theta', 'alpha', 'beta', 'gamma']
# n_channels = 19

# bands_dict = {
#     band: entropy_features[:, i*n_channels:(i+1)*n_channels]
#     for i, band in enumerate(band_names)
# }

In [12]:
# Subject index of every window: subject k's index repeated once per window it contributed.
subject_ids = np.concatenate(
    [np.full(len(subject_windows), idx) for idx, subject_windows in enumerate(X5_window_list)]
)
assert len(subject_ids) == sum(len(subject_windows) for subject_windows in X5_window_list)

In [13]:
# Subject index per window (0 for the first participant ... 87 for the last).
subject_ids

array([ 0,  0,  0, ..., 87, 87, 87])

In [28]:
# Column names must follow the layout produced by extract_psd_svd_entropy:
# band-major (outer loop over `bands`), channel-minor (inner loop over channels).
# Deriving them from `bands` and X_5 keeps names and features aligned if either changes.
bands_name = list(bands)
n_channels = X_5.shape[1]

cols = [[f'{name}_ch{i+1}' for i in range(n_channels)] for name in bands_name]
columns = [col for band_cols in cols for col in band_cols]

assert len(columns) == X_features.shape[1], (len(columns), X_features.shape)

In [41]:
# ch_names = [
#     'Fp1','Fp2','F3','F4','C3','C4','P3','P4','O1','O2',
#     'F7','F8','T3','T4','T5','T6','Fz','Cz','Pz'
# ]
# Window-level feature table: one row per window, plus its subject index and group label.
df = (
    pd.DataFrame(X_features, columns=columns)
    .assign(subject_id=subject_ids, label=labels)
)




In [42]:
# Window-level table (pandas truncates the display): features, subject_id, label.
df

Unnamed: 0,delta_ch1,delta_ch2,delta_ch3,delta_ch4,delta_ch5,delta_ch6,delta_ch7,delta_ch8,delta_ch9,delta_ch10,...,gamma_ch12,gamma_ch13,gamma_ch14,gamma_ch15,gamma_ch16,gamma_ch17,gamma_ch18,gamma_ch19,subject_id,label
0,0.496560,0.490786,0.485825,0.480512,0.475196,0.459005,0.458097,0.447380,0.415884,0.452967,...,0.771155,0.770102,0.755954,0.732206,0.734761,0.769686,0.754894,0.748604,0,A
1,0.163230,0.164937,0.167320,0.173573,0.175321,0.189434,0.192057,0.209154,0.213297,0.219594,...,0.682464,0.713220,0.653349,0.615737,0.561788,0.680767,0.668823,0.593564,0,A
2,0.556496,0.545845,0.569107,0.568987,0.588856,0.585000,0.589776,0.487758,0.583757,0.577450,...,0.797770,0.778012,0.797364,0.770276,0.770603,0.799757,0.779243,0.754030,0,A
3,0.488949,0.486645,0.495018,0.493495,0.506939,0.504582,0.515286,0.527812,0.514774,0.514166,...,0.743134,0.742521,0.734884,0.697604,0.619662,0.743754,0.736124,0.709793,0,A
4,0.234658,0.233691,0.243551,0.244590,0.238606,0.257500,0.244734,0.228977,0.275483,0.278925,...,0.801625,0.799459,0.805403,0.804389,0.783891,0.803353,0.802329,0.798102,0,A
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
140331,0.449629,0.419543,0.397881,0.335375,0.425585,0.383526,0.372221,0.345858,0.207975,0.248871,...,0.758723,0.656702,0.699093,0.578136,0.722870,0.645093,0.649560,0.705697,87,F
140332,0.530804,0.509992,0.550735,0.536541,0.573633,0.584482,0.578450,0.577111,0.548338,0.517166,...,0.749133,0.747085,0.745321,0.725991,0.743420,0.730727,0.719970,0.746017,87,F
140333,0.036553,0.033087,0.053825,0.086469,0.089126,0.125686,0.122006,0.080776,0.158060,0.113643,...,0.703946,0.621573,0.704520,0.666167,0.699399,0.752779,0.741463,0.686966,87,F
140334,0.143700,0.157403,0.109887,0.134349,0.102490,0.147490,0.116032,0.157772,0.133099,0.124137,...,0.709595,0.726655,0.737996,0.697661,0.674979,0.731326,0.753269,0.709177,87,F


In [43]:
# First row index of each subject together with its label.
df.loc[:, ['subject_id', 'label']].drop_duplicates().head(100)

Unnamed: 0,subject_id,label
0,0,A
1198,1,A
2783,2,A
3394,3,A
4805,4,A
...,...,...
133794,83,F
135097,84,F
136228,85,F
137538,86,F


In [44]:
# Check that `df` is laid out subject by subject: subject ids never decrease, and every
# subject's row count equals its own window count in X5_window_list (the original spot
# check 2783 - 1198 only verified subject 1).
df['subject_id'].is_monotonic_increasing and np.array_equal(
    df.groupby('subject_id').size().to_numpy(),
    [w.shape[0] for w in X5_window_list],
)

True

In [45]:
# Every subject must carry exactly one group label.
labels_per_subject = df.groupby('subject_id')['label'].nunique()
assert labels_per_subject.max() == 1

In [46]:
# Number of windows contributed by each subject, summarized across subjects.
windows_per_subject = df.groupby('subject_id').size()
windows_per_subject.describe()

count      88.000000
mean     1594.727273
std       277.651511
min       611.000000
25%      1496.000000
50%      1616.000000
75%      1753.000000
max      2562.000000
dtype: float64

In [47]:
# labels distribution: number of subjects (not windows) per group

unique_subject_labels = df[['subject_id', 'label']].drop_duplicates()
subject_labels = unique_subject_labels.value_counts(subset='label')

print(subject_labels)


label
A    36
C    29
F    23
Name: count, dtype: int64


In [212]:
# Subjects per group (one row per subject after de-duplication).
df.loc[:, ['subject_id', 'label']].drop_duplicates().loc[:, 'label'].value_counts()

label
A    36
C    29
F    23
Name: count, dtype: int64

In [48]:
# Collapse the window-level table to one row per participant: every feature is averaged
# over all of that participant's windows (e.g. subject 0 has 1198 windows), which leaves
# 88 rows for the 88 participants. subject_id becomes the index.
df_subject = df.groupby('subject_id').mean(numeric_only=True)

# Re-attach each participant's group label (aligned on the subject_id index).
df_subject['label'] = (
    df.loc[:, ['subject_id', 'label']]
    .drop_duplicates()
    .set_index('subject_id')
    .loc[:, 'label']
)

df_subject

Unnamed: 0_level_0,delta_ch1,delta_ch2,delta_ch3,delta_ch4,delta_ch5,delta_ch6,delta_ch7,delta_ch8,delta_ch9,delta_ch10,...,gamma_ch11,gamma_ch12,gamma_ch13,gamma_ch14,gamma_ch15,gamma_ch16,gamma_ch17,gamma_ch18,gamma_ch19,label
subject_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.369480,0.370520,0.367196,0.366549,0.363453,0.363257,0.363106,0.363655,0.362589,0.365914,...,0.716049,0.714878,0.719276,0.715889,0.718409,0.717614,0.717378,0.719031,0.718048,A
1,0.414458,0.409961,0.411412,0.409816,0.422480,0.427246,0.434102,0.427139,0.433581,0.425332,...,0.704920,0.708077,0.704668,0.704271,0.707713,0.707126,0.699614,0.702143,0.708734,A
2,0.385187,0.384465,0.382611,0.383899,0.386318,0.383197,0.393005,0.386392,0.393052,0.392397,...,0.713962,0.711570,0.712799,0.713557,0.713629,0.718020,0.709767,0.708528,0.709359,A
3,0.433619,0.430516,0.431423,0.426255,0.425962,0.423228,0.433675,0.410082,0.430863,0.424324,...,0.721113,0.722704,0.723202,0.722091,0.723734,0.720550,0.719683,0.718424,0.721366,A
4,0.387507,0.367278,0.382198,0.390422,0.405090,0.383984,0.366921,0.393229,0.350497,0.409059,...,0.716220,0.714842,0.719216,0.720024,0.718865,0.721736,0.716553,0.710776,0.715244,A
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
83,0.365247,0.365294,0.363397,0.364392,0.364686,0.362765,0.367260,0.363381,0.366842,0.360329,...,0.717446,0.717085,0.718549,0.716165,0.718757,0.717140,0.715019,0.716186,0.715369,F
84,0.361318,0.361212,0.361042,0.360842,0.361288,0.359793,0.359917,0.358019,0.358336,0.359235,...,0.718410,0.718926,0.717997,0.719036,0.718691,0.717940,0.717924,0.718664,0.718638,F
85,0.368730,0.368730,0.368730,0.368730,0.368730,0.368730,0.368730,0.368730,0.368730,0.368730,...,0.717841,0.717841,0.717841,0.717841,0.717841,0.717841,0.717841,0.717841,0.717841,F
86,0.341776,0.339213,0.354719,0.359863,0.356814,0.360941,0.360961,0.362937,0.365031,0.365860,...,0.719087,0.716462,0.716333,0.719060,0.720961,0.716839,0.713468,0.712052,0.712788,F


In [140]:
# ===================================
# Two-group comparison. Set GROUPS_TO_COMPARE to:
#   ('A', 'C') -> AD vs C,   ('F', 'C') -> FTD vs C,   ('F', 'A') -> FTD vs AD
# ===================================
GROUPS_TO_COMPARE = ('A', 'C')

df_filtered = df_subject[df_subject['label'].isin(GROUPS_TO_COMPARE)]

# Guard against typos in GROUPS_TO_COMPARE: both groups must actually be present.
assert df_filtered['label'].nunique() == 2, df_filtered['label'].unique()

# df_filtered now only contains subjects from the two selected groups.
print(df_filtered['label'].value_counts())


label
A    36
C    29
Name: count, dtype: int64


In [141]:
# ===================================
# Features (X) and target (y) for the selected two-group comparison
# ==================================
y = df_filtered['label']
X = df_filtered.drop(columns='label')
y.shape

In [142]:
df.shape, df_subject.shape
# df: one row per window -> 95 feature columns + subject_id + label = (140336, 97).
# df_subject: one row per subject (subject_id is the index) -> 95 averaged features + label = (88, 96).

((140336, 97), (88, 96))

In [145]:
from sklearn.model_selection import train_test_split

# Hold-out split at subject level (each row of X is one subject), stratified on the group label.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, stratify=y, test_size=0.3
)

In [146]:
# Class counts on each side of the split.
for title, target in (('Train dataset', y_train), ('Test dataset', y_test)):
    print(title)
    print(target.value_counts())

Train dataset
label
A    25
C    20
Name: count, dtype: int64
Test dataset
label
A    11
C     9
Name: count, dtype: int64


In [134]:
# smote = imblearn.over_sampling.SMOTE(random_state=42)

# X_train_sm, y_train_sm = smote.fit_resample(X_train, y_train)

In [135]:
#print(y_train_sm.value_counts())


In [136]:
# from sklearn.preprocessing import StandardScaler

# scaler = StandardScaler()

# X_train_sm = scaler.fit_transform(X_train_sm)
# X_test = scaler.transform(X_test)


In [147]:
from imblearn.pipeline import Pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


models = {
    'Logistic Regression': LogisticRegression(max_iter=1000, random_state=42),
    'Support Vector Machine (SVM)': SVC(random_state=42),
    'K-Nearest Neighbors (KNN)': KNeighborsClassifier(),
    'Decision Tree': DecisionTreeClassifier(random_state=42),
    'Random Forest': RandomForestClassifier(random_state=42),
    #'AdaBoost': AdaBoostClassifier(random_state=42),
    #'XGBoost': XGBClassifier(random_state=42)
}

# The splitter does not depend on the model, so build it once.
# df_filtered already has one row per subject, so these folds are subject-wise;
# GroupKFold is not needed here.
# Lower-variance alternative on this small training set:
#   RepeatedStratifiedKFold(n_splits=5, n_repeats=10, random_state=42)
cv = StratifiedKFold(
    n_splits=5,
    shuffle=True,
    random_state=42
)

for model_name, model in models.items():
    # Scaling and SMOTE are pipeline steps, so cross_val_score fits them on the
    # training folds only and the validation fold is never resampled.
    pipe = Pipeline([
        ('scaler', StandardScaler()),
        ('smote', SMOTE(random_state=42)),
        (model_name, model),
    ])

    cv_scores = cross_val_score(
        pipe,
        X_train,
        y_train,
        cv=cv,
        scoring='accuracy'
    )

    print(model_name)
    print("CV Accuracy:%.3f" % cv_scores.mean(), "+/-%.4f" % cv_scores.std())
    print('=' * 80)

Logistic Regression
CV Accuracy:0.822 +/-0.1133
Support Vector Machine (SVM)
CV Accuracy:0.778 +/-0.0994
K-Nearest Neighbors (KNN)
CV Accuracy:0.778 +/-0.0703
Decision Tree
CV Accuracy:0.800 +/-0.0831
Random Forest
CV Accuracy:0.778 +/-0.0000


In [148]:

from sklearn.metrics import accuracy_score, classification_report

# Final evaluation: fit each pipeline on the whole training split, then score it once
# on the held-out test subjects.
# NOTE(review): every model is scored on the same small test set (20 subjects for A vs C);
# choosing the best model from these numbers makes its test accuracy optimistic —
# select on the CV scores above and report the test score of the chosen model only.
for model_name, model in models.items():
    steps = [
        ('scaler', StandardScaler()),
        ('smote', SMOTE(random_state=42)),
        (model_name, model),
    ]
    pipe = Pipeline(steps)
    y_pred = pipe.fit(X_train, y_train).predict(X_test)

    print(model_name)
    print("Final Test Accuracy:", accuracy_score(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print('=' * 80)
    # Optional: print(confusion_matrix(y_test, y_pred))

Logistic Regression
Final Test Accuracy: 0.95
              precision    recall  f1-score   support

           A       1.00      0.91      0.95        11
           C       0.90      1.00      0.95         9

    accuracy                           0.95        20
   macro avg       0.95      0.95      0.95        20
weighted avg       0.96      0.95      0.95        20

Support Vector Machine (SVM)
Final Test Accuracy: 0.8
              precision    recall  f1-score   support

           A       0.77      0.91      0.83        11
           C       0.86      0.67      0.75         9

    accuracy                           0.80        20
   macro avg       0.81      0.79      0.79        20
weighted avg       0.81      0.80      0.80        20

K-Nearest Neighbors (KNN)
Final Test Accuracy: 0.9
              precision    recall  f1-score   support

           A       0.91      0.91      0.91        11
           C       0.89      0.89      0.89         9

    accuracy                   