In [3]:
# NOTE(review): vmdpy is not imported anywhere in the visible code cells - confirm this install is still needed.
# NOTE(review): consider `%pip install vmdpy==0.2` so the version is pinned and the package lands in the kernel's environment.
pip install vmdpy

Collecting vmdpy
  Downloading vmdpy-0.2-py2.py3-none-any.whl.metadata (3.0 kB)
Downloading vmdpy-0.2-py2.py3-none-any.whl (6.5 kB)
Installing collected packages: vmdpy
Successfully installed vmdpy-0.2
Note: you may need to restart the kernel to use updated packages.


In [None]:
from pathlib import Path
import numpy as np
import pandas as pd

import mne

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from xgboost import XGBClassifier

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.optimizers import Adam

import warnings
warnings.filterwarnings("ignore")

# Directory scanned for the "Subject*_2.edf" recordings.
DATA_FOLDER = "/kaggle/input/ahmadi-dataset"
# CSV whose "Subject" and "Count quality" columns supply the per-subject labels.
INFO_CSV = f"{DATA_FOLDER}/subject-info.csv"

# Target sampling frequency handed to raw.resample() for every recording.
RESAMPLE_TO = 128

# Seed shared by the scikit-learn estimators and TensorFlow.
RANDOM_STATE = 42

# LSTM training settings passed to model.fit().
EPOCHS = 20
BATCH_SIZE = 8


def load_task_edf_to_tensor(folder_path, info_csv_path, resample_to=None):
    """Load every ``Subject*_2.edf`` recording in a folder into one array.

    Each file is read with MNE, optionally resampled, and paired with the
    "Count quality" label of its subject from the info CSV. Recordings are
    truncated to the shortest one so they can be stacked.

    Parameters
    ----------
    folder_path : str or Path
        Directory containing the ``Subject*_2.edf`` files.
    info_csv_path : str or Path
        CSV with a "Subject" column and a "Count quality" column.
    resample_to : float or None, optional
        If given, every recording is resampled to this frequency before
        its samples are extracted.

    Returns
    -------
    X : numpy.ndarray
        Stacked signals with shape (N, C, T).
    y : numpy.ndarray
        Integer labels, one per loaded recording.
    subjects : numpy.ndarray
        Subject identifier (file stem before the first "_") per recording.
    sfreq : float
        Sampling frequency shared by all loaded recordings.

    Raises
    ------
    NotADirectoryError
        If ``folder_path`` is not an existing directory.
    ValueError
        If no recording could be loaded, or if the loaded recordings do not
        share the same sampling frequency or channel count.
    """
    folder = Path(folder_path)
    if not folder.is_dir():
        raise NotADirectoryError(f"{folder_path} is not a valid directory")

    info_df = pd.read_csv(info_csv_path)
    label_map = dict(zip(info_df["Subject"], info_df["Count quality"]))

    X_list = []
    y_list = []
    subjects = []
    sfreq = None
    n_channels = None

    # Sorted so the sample order is deterministic across runs.
    for file_path in sorted(folder.glob("Subject*_2.edf")):
        subj = file_path.stem.split("_")[0]
        if subj not in label_map:
            print(f"Warning: {subj} not in subject-info, skipping.")
            continue

        print(f"Loading {file_path.name} for {subj}...")
        raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)

        if resample_to is not None:
            raw.resample(resample_to)

        # A single sfreq is returned for the whole tensor, so every file must
        # agree with it; otherwise the time axis would mix different rates.
        file_sfreq = raw.info["sfreq"]
        if sfreq is None:
            sfreq = file_sfreq
        elif not np.isclose(file_sfreq, sfreq):
            raise ValueError(
                f"{file_path.name} has sampling rate {file_sfreq}, expected {sfreq}. "
                "Pass resample_to to bring all recordings to a common rate."
            )

        data = raw.get_data()

        # Fail with a readable message instead of an opaque np.stack error.
        if n_channels is None:
            n_channels = data.shape[0]
        elif data.shape[0] != n_channels:
            raise ValueError(
                f"{file_path.name} has {data.shape[0]} channels, expected {n_channels}."
            )

        X_list.append(data)
        y_list.append(int(label_map[subj]))
        subjects.append(subj)

    if not X_list:
        raise ValueError("No *_2.edf files loaded. Check folder path and file pattern.")

    # Recordings may differ in length; crop all of them to the shortest.
    min_len = min(d.shape[1] for d in X_list)

    X = np.stack([d[:, :min_len] for d in X_list], axis=0)
    y = np.array(y_list, dtype=int)
    subjects = np.array(subjects)

    print("Final tensor X shape (N, C, T):", X.shape)
    print("Labels y shape:", y.shape)

    return X, y, subjects, sfreq


def build_lstm_model(timesteps, n_features):
    """Build and compile the binary LSTM classifier.

    The network is a 64-unit LSTM returning only its last output, followed
    by 50% dropout, a 32-unit ReLU layer and a single sigmoid output unit.
    It is compiled with Adam (learning rate 1e-3), binary cross-entropy
    loss and the accuracy metric.

    Parameters
    ----------
    timesteps : int
        Length of each input sequence.
    n_features : int
        Number of features per time step.

    Returns
    -------
    The compiled Keras ``Sequential`` model.
    """
    network = Sequential(
        [
            LSTM(64, input_shape=(timesteps, n_features), return_sequences=False),
            Dropout(0.5),
            Dense(32, activation="relu"),
            Dense(1, activation="sigmoid"),
        ]
    )
    network.compile(
        loss="binary_crossentropy",
        optimizer=Adam(learning_rate=1e-3),
        metrics=["accuracy"],
    )
    return network


def compute_channel_importance_with_rf(X_train_raw, y_train, random_state):
    """Rank channels by mean random-forest feature importance.

    A 300-tree random forest is fitted on the training array flattened to
    (N, C * T). The resulting per-feature importances are folded back to
    (C, T) and averaged over the time axis to give one score per channel.
    The ranking is printed, most important channel first.

    Parameters
    ----------
    X_train_raw : numpy.ndarray
        Training signals with shape (N, C, T).
    y_train : array-like
        Labels for the N training samples.
    random_state : int
        Seed for the random forest.

    Returns
    -------
    chan_perm : numpy.ndarray
        Channel indices ordered by decreasing importance.
    channel_importance : numpy.ndarray
        Mean importance of each channel, indexed by original channel number.
    """
    n_samples, n_channels, n_times = X_train_raw.shape

    forest = RandomForestClassifier(
        n_estimators=300,
        max_depth=None,
        random_state=random_state,
        n_jobs=-1,
    )
    forest.fit(X_train_raw.reshape(n_samples, n_channels * n_times), y_train)

    # Features were laid out channel-major, so this reshape recovers (C, T).
    per_channel = forest.feature_importances_.reshape(n_channels, n_times).mean(axis=1)
    ordering = np.argsort(-per_channel)

    print("\nChannel importance (sorted):")
    for position, channel in enumerate(ordering, start=1):
        print(f"Rank {position}: Channel {channel} -> importance {per_channel[channel]:.6f}")

    return ordering, per_channel


def _binary_metrics(y_true, y_pred):
    """Return ``[accuracy, precision, recall, f1]`` for binary label arrays."""
    return [
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
    ]


def _print_metric_block(title, metrics):
    """Print a title line followed by the four metrics in the summary layout."""
    acc, prec, rec, f1 = metrics
    print(title)
    print(f"Accuracy : {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall   : {rec:.4f}")
    print(f"F1-score : {f1:.4f}")


def _make_classical_models():
    """Create fresh, unfitted baseline classifiers for a single fold."""
    return {
        "KNN": KNeighborsClassifier(n_neighbors=3),
        "RandomForest": RandomForestClassifier(
            n_estimators=300,
            max_depth=None,
            random_state=RANDOM_STATE,
            n_jobs=-1
        ),
        "DecisionTree": DecisionTreeClassifier(
            max_depth=None,
            random_state=RANDOM_STATE
        ),
        "XGBoost": XGBClassifier(
            n_estimators=300,
            max_depth=4,
            learning_rate=0.05,
            subsample=0.8,
            colsample_bytree=0.8,
            objective="binary:logistic",
            eval_metric="logloss",
            n_jobs=-1,
            tree_method="hist",
            # Seeded like the other estimators so row/column subsampling is explicit.
            random_state=RANDOM_STATE
        ),
    }


def main():
    """Run the subject-wise cross-validation experiment and print the results.

    For every subject in turn, that subject is the test set, the next
    subject (cyclically) is the validation set, and all remaining subjects
    form the training set. Within each fold the channels are reordered by
    random-forest importance computed on the training data only; KNN,
    random forest, decision tree and XGBoost are fitted on the flattened
    signals and an LSTM on the (time, channel) sequences.

    Per-fold scores are printed as the run progresses. The final summary
    is computed on the predictions pooled over all folds: a fold's test set
    is a single subject, so per-fold precision/recall/F1 are degenerate
    (a correctly predicted negative scores 0) and their average is not a
    meaningful estimate.

    Raises
    ------
    ValueError
        If fewer than three distinct subjects were loaded, since disjoint
        train/validation/test subjects cannot be formed.
    """
    X, y, subjects, _ = load_task_edf_to_tensor(
        DATA_FOLDER,
        INFO_CSV,
        resample_to=RESAMPLE_TO
    )

    _, C, T = X.shape

    # One fold per distinct subject, kept in first-seen (load) order.
    _, first_seen = np.unique(subjects, return_index=True)
    unique_subjects = subjects[np.sort(first_seen)]
    n_subjects = len(unique_subjects)
    if n_subjects < 3:
        raise ValueError(
            f"Need at least 3 subjects for disjoint train/val/test folds, got {n_subjects}."
        )

    classical_names = ("KNN", "RandomForest", "DecisionTree", "XGBoost")
    pooled_true = []
    pooled_pred = {name: [] for name in classical_names}
    pooled_pred["LSTM"] = []

    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        device_name = "/GPU:0"
        print("\nUsing GPU")
    else:
        device_name = "/CPU:0"
        print("\nGPU not found, using CPU")

    for fold_idx in range(n_subjects):
        test_subj = unique_subjects[fold_idx]
        val_subj = unique_subjects[(fold_idx + 1) % n_subjects]

        test_mask = (subjects == test_subj)
        val_mask = (subjects == val_subj)
        train_mask = ~(test_mask | val_mask)

        X_train_raw = X[train_mask]
        X_val_raw = X[val_mask]
        X_test_raw = X[test_mask]

        y_train = y[train_mask]
        y_val = y[val_mask]
        y_test = y[test_mask]
        pooled_true.append(y_test)

        print("\n" + "=" * 60)
        print(f"Fold {fold_idx+1}/{n_subjects}")
        print(f"Train subjects: {subjects[train_mask]}")
        print(f"Val subject   : {subjects[val_mask]}")
        print(f"Test subject  : {subjects[test_mask]}")

        # The channel ranking uses training data only, so no test information leaks in.
        chan_perm, _ = compute_channel_importance_with_rf(
            X_train_raw, y_train, RANDOM_STATE
        )

        # NOTE(review): the signals are used without any scaling - confirm that
        # is intended for the distance-based KNN and for the LSTM.
        X_train_raw_perm = X_train_raw[:, chan_perm, :]
        X_val_raw_perm = X_val_raw[:, chan_perm, :]
        X_test_raw_perm = X_test_raw[:, chan_perm, :]

        X_train_flat_perm = X_train_raw_perm.reshape(X_train_raw_perm.shape[0], C * T)
        X_test_flat_perm = X_test_raw_perm.reshape(X_test_raw_perm.shape[0], C * T)

        for name, model in _make_classical_models().items():
            model.fit(X_train_flat_perm, y_train)
            y_pred_m = model.predict(X_test_flat_perm)
            pooled_pred[name].append(np.asarray(y_pred_m).ravel())
            acc_m, prec_m, rec_m, f1_m = _binary_metrics(y_test, y_pred_m)
            print(f"{name} (permuted channels) -> acc: {acc_m:.4f}, prec: {prec_m:.4f}, rec: {rec_m:.4f}, f1: {f1_m:.4f}")

        # Keras expects (samples, timesteps, features); channels act as features.
        X_train_seq_perm = np.transpose(X_train_raw_perm, (0, 2, 1))
        X_val_seq_perm = np.transpose(X_val_raw_perm, (0, 2, 1))
        X_test_seq_perm = np.transpose(X_test_raw_perm, (0, 2, 1))

        timesteps = X_train_seq_perm.shape[1]
        n_features = X_train_seq_perm.shape[2]

        with tf.device(device_name):
            tf.keras.backend.clear_session()
            # Seeds Python, NumPy and TensorFlow together so each fold starts
            # from the same initial weights and shuffling state.
            tf.keras.utils.set_random_seed(RANDOM_STATE)
            model_lstm = build_lstm_model(timesteps, n_features)
            model_lstm.fit(
                X_train_seq_perm,
                y_train,
                epochs=EPOCHS,
                batch_size=BATCH_SIZE,
                validation_data=(X_val_seq_perm, y_val),
                verbose=1
            )
            y_pred_prob = model_lstm.predict(X_test_seq_perm)

        y_pred_lstm = (y_pred_prob >= 0.5).astype(int).ravel()
        pooled_pred["LSTM"].append(y_pred_lstm)

        acc, prec, rec, f1 = _binary_metrics(y_test, y_pred_lstm)
        print(f"LSTM (permuted channels) -> acc: {acc:.4f}, prec: {prec:.4f}, rec: {rec:.4f}, f1: {f1:.4f}")

    # Score once on all held-out predictions instead of averaging per-fold
    # values that were each computed on a single subject.
    y_true_all = np.concatenate(pooled_true)

    print("\n" + "=" * 60)
    print("Pooled metrics over subject-wise folds:")
    _print_metric_block(
        "LSTM (RF-based channel ordering):",
        _binary_metrics(y_true_all, np.concatenate(pooled_pred["LSTM"]))
    )

    for name in classical_names:
        _print_metric_block(
            f"\n{name} (RF-based channel ordering):",
            _binary_metrics(y_true_all, np.concatenate(pooled_pred[name]))
        )


# Entry point: run main() when this code executes as the top-level module.
if __name__ == "__main__":
    main()


Loading Subject00_2.edf for Subject00...
Loading Subject01_2.edf for Subject01...
Loading Subject02_2.edf for Subject02...
Loading Subject03_2.edf for Subject03...
Loading Subject04_2.edf for Subject04...
Loading Subject05_2.edf for Subject05...
Loading Subject06_2.edf for Subject06...
Loading Subject07_2.edf for Subject07...
Loading Subject08_2.edf for Subject08...
Loading Subject09_2.edf for Subject09...
Loading Subject10_2.edf for Subject10...
Loading Subject11_2.edf for Subject11...
Loading Subject12_2.edf for Subject12...
Loading Subject13_2.edf for Subject13...
Loading Subject14_2.edf for Subject14...
Loading Subject15_2.edf for Subject15...
Loading Subject16_2.edf for Subject16...
Loading Subject17_2.edf for Subject17...
Loading Subject18_2.edf for Subject18...
Loading Subject19_2.edf for Subject19...
Loading Subject20_2.edf for Subject20...
Loading Subject21_2.edf for Subject21...
Loading Subject22_2.edf for Subject22...
Loading Subject23_2.edf for Subject23...
Loading Subject2



RandomForest (permuted channels) -> acc: 0.0000, prec: 0.0000, rec: 0.0000, f1: 0.0000
DecisionTree (permuted channels) -> acc: 0.0000, prec: 0.0000, rec: 0.0000, f1: 0.0000
