Cell 0 — Install packages

In [4]:
# Install the third-party dependencies into the kernel's own environment
# (%pip, unlike !pip, targets the environment of the running kernel).
# NOTE(review): no versions are pinned, and tensorflow (imported in the next
# cell) is not installed here — presumably it ships with the compute image; confirm.
%pip install -q mne PyWavelets scikit-learn seaborn
%pip install -q imbalanced-learn
%pip install -q azureml-core azure-ai-ml azure-identity

print("✅ Packages installed (if no errors above).")


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
✅ Packages installed (if no errors above).


Cell 1 — Load libraries + set seeds

In [5]:
import os
import random
import numpy as np

import mne
import pywt
from sklearn.decomposition import FastICA
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
from sklearn.model_selection import StratifiedKFold

from collections import defaultdict
from typing import Optional, Union, Sequence, Dict, Tuple, List

import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Conv2D, AveragePooling2D, Flatten, Dense, Dropout, BatchNormalization,
    Input, DepthwiseConv2D, SeparableConv2D, Activation
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, CSVLogger

from imblearn.over_sampling import SMOTE, RandomOverSampler

# Best-effort: switch Keras to interactive (progress-bar) logging. Any failure
# is deliberately ignored so the notebook keeps running.
try:
    tf.keras.utils.enable_interactive_logging()
except Exception:
    pass

# Seed the Python, NumPy and TensorFlow random generators with one fixed value.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

print("✅ Imports done + seeds set.")


✅ Imports done + seeds set.


Cell 2 — Helper: load EEG (.set) with labels

In [6]:
def load_eeg_data_with_target(
    folder_path: str,
    session_name: str,
    max_samples: int = 118000,
    discard_samples: int = 10000
):
    """Load every EEGLAB ``.set`` recording in a folder and label it by session.

    Each recording is read with MNE, cast to float32, truncated to at most
    ``max_samples`` samples, and then its first ``discard_samples`` samples are
    dropped. Recordings that are not longer than ``discard_samples`` after
    truncation are skipped with a warning.

    Args:
        folder_path: Directory that is scanned (non-recursively) for ``*.set`` files.
        session_name: ``'ses-1'`` (label 0) or ``'ses-2'`` (label 1).
        max_samples: Maximum number of samples kept per recording before discarding.
        discard_samples: Number of leading samples removed from each recording.

    Returns:
        ``(data_list, targets, sfreq_list)`` — three lists of equal length holding
        the per-file data arrays, integer labels and sampling rates.
        For an unknown ``session_name`` a warning is printed and three empty
        lists are returned.
    """
    session_labels = {'ses-1': 0, 'ses-2': 1}

    data_list = []
    targets = []
    sfreq_list = []

    # Resolve the label up front. Previously an unknown session still appended
    # data but no target, so the returned lists silently went out of step.
    if session_name not in session_labels:
        print(f"⚠️ Unknown session name: {session_name}")
        return data_list, targets, sfreq_list
    label = session_labels[session_name]

    # os.listdir order is arbitrary; sort so the trial order is reproducible.
    eeg_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.set'))

    for eeg_file in eeg_files:
        file_path = os.path.join(folder_path, eeg_file)

        raw = mne.io.read_raw_eeglab(file_path, preload=True, verbose=False)
        data = raw.get_data().astype(np.float32)
        sfreq = float(raw.info["sfreq"])

        # Cap the length first, then drop the leading samples.
        if data.shape[1] > max_samples:
            data = data[:, :max_samples]

        if data.shape[1] > discard_samples:
            data = data[:, discard_samples:]
        else:
            print(f"⚠️ Not enough samples to discard in {eeg_file}, skipping.")
            continue

        data_list.append(data)
        sfreq_list.append(sfreq)
        targets.append(label)

    return data_list, targets, sfreq_list


Cell 3 — Leakage-safe preprocessing classes

In [7]:
def _names_from_index_mapping(
    n_channels: int,
    index_to_name: Optional[Dict[int, str]]
) -> List[str]:
    if index_to_name is None:
        return [f"EEG{i+1}" for i in range(n_channels)]

    keys = list(index_to_name.keys())
    is_zero_based = (0 in keys) and (1 not in keys)

    names = []
    for i in range(n_channels):
        key = i if is_zero_based else (i + 1)
        names.append(index_to_name.get(key, f"EEG{i+1}"))
    return names


def _make_raw(
    eeg: np.ndarray,
    sfreq: float,
    ch_names: List[str],
    use_standard_1020: bool = True
) -> Tuple[mne.io.Raw, bool]:
    """Wrap a data array in an MNE ``RawArray`` and optionally attach a montage.

    Channels whose name starts with ``"EOG"`` (case-insensitive) are typed
    ``'eog'``; all others are typed ``'eeg'``.

    Args:
        eeg: Data array handed to ``mne.io.RawArray`` (cast to float32 first).
        sfreq: Sampling rate in Hz.
        ch_names: One name per row of ``eeg``.
        use_standard_1020: When True, try to set the ``standard_1020`` montage,
            ignoring channels that the montage does not know.

    Returns:
        ``(raw, montage_applied)``. ``montage_applied`` is True only when a
        montage was requested and ``set_montage`` did not raise.
    """
    channel_kinds = []
    for name in ch_names:
        is_eog = str(name).upper().startswith("EOG")
        channel_kinds.append('eog' if is_eog else 'eeg')

    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=channel_kinds)
    samples = eeg.astype(np.float32, copy=False)
    raw = mne.io.RawArray(samples, info, verbose=False)

    if not use_standard_1020:
        return raw, False

    try:
        standard_montage = mne.channels.make_standard_montage("standard_1020")
        raw.set_montage(standard_montage, match_case=False, on_missing="ignore")
    except Exception:
        # Any montage failure just means "no montage"; the raw object is still usable.
        return raw, False

    return raw, True


class WaveletICA:
    """FastICA projection applied to the wavelet approximation coefficients.

    The input is decomposed along axis 1 with ``pywt.wavedec``. Only the
    approximation band (``coeffs[0]``) is passed through FastICA
    (``transform`` followed by ``inverse_transform``); the detail bands are
    left untouched and the signal is reconstructed with ``pywt.waverec``.

    Attributes:
        wavelet: Wavelet name passed to PyWavelets.
        level: Decomposition level.
        n_components: Upper bound on ICA components (capped at the number of rows of X).
        random_state: Seed forwarded to ``FastICA``.
        ica_: The fitted ``FastICA`` instance, or ``None`` before ``fit``.
    """

    def __init__(self, wavelet="db4", level=3, n_components=10, random_state=42):
        self.wavelet = wavelet
        self.level = level
        self.n_components = n_components
        self.random_state = random_state
        self.ica_: Optional[FastICA] = None
        self._n_ch: Optional[int] = None

    def fit(self, X: np.ndarray):
        """Fit FastICA on the approximation coefficients of ``X``.

        Args:
            X: 2-D array decomposed along axis 1; ``X.shape[0]`` is stored as the
                expected row (channel) count for ``transform``.

        Returns:
            ``self``.
        """
        C = X.shape[0]
        self._n_ch = C

        coeffs = pywt.wavedec(X, wavelet=self.wavelet, level=self.level, axis=1)
        A = coeffs[0]

        k = int(min(self.n_components, C))
        self.ica_ = FastICA(n_components=k, random_state=self.random_state)

        # Only the fitted unmixing model is needed here. The previous version
        # also computed an inverse transform and a full wavelet reconstruction
        # whose results were thrown away.
        self.ica_.fit(A.T)

        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Project the approximation band of ``X`` through the fitted ICA and reconstruct.

        Args:
            X: 2-D array with the same number of rows as the array used in ``fit``.

        Returns:
            float32 array with the same number of columns as ``X`` (the
            reconstruction is trimmed, or zero-padded at the end, to match).

        Raises:
            AssertionError: If ``fit`` has not been called.
            ValueError: If the row count differs from the one seen in ``fit``.
        """
        assert self.ica_ is not None, "WaveletICA not fitted yet."

        if self._n_ch is not None and X.shape[0] != self._n_ch:
            raise ValueError(
                f"WaveletICA was fitted on {self._n_ch} channels but got {X.shape[0]}."
            )

        coeffs = pywt.wavedec(X, wavelet=self.wavelet, level=self.level, axis=1)
        A = coeffs[0]

        S = self.ica_.transform(A.T)
        A_denoised = self.ica_.inverse_transform(S).T

        coeffs[0] = A_denoised
        Y = pywt.waverec(coeffs, wavelet=self.wavelet, axis=1)

        # The reconstruction length can differ from the input length; force a match.
        if Y.shape[1] < X.shape[1]:
            pad_width = X.shape[1] - Y.shape[1]
            Y = np.pad(Y, ((0, 0), (0, pad_width)), mode="constant")
        elif Y.shape[1] > X.shape[1]:
            Y = Y[:, :X.shape[1]]

        return Y.astype(np.float32, copy=False)


class EEGPreprocessor:
    """EEG cleaning pipeline whose thresholds are learned once in ``fit``.

    Processing order (each step optional via the constructor flags):
    resample → notch filter → high-pass filter → average reference →
    transient repair using the fitted per-channel mean/std → bad-channel
    interpolation using the fitted median/MAD of channel std → WaveletICA.

    ``fit`` stores the statistics; ``transform`` applies them to any recording
    without recomputing them from that recording.
    """

    def __init__(
        self,
        *,
        index_to_name: Optional[Dict[int, str]] = None,
        use_standard_1020: bool = True,
        resample_to: Optional[float] = None,
        notch_freqs: Union[None, float, Sequence[float]] = 50.0,
        highpass: Optional[float] = 0.05,
        bad_point_z: float = 6.0,
        bad_channel_z: float = 5.0,
        interpolate_bad_channels: bool = False,
        car: bool = True,
        use_wica: bool = True,
        wica_components: int = 10,
        wica_wavelet: str = "db4",
        wica_level: int = 3,
        wica_random_state: int = 42
    ):
        """Store the configuration and create the (unfitted) WaveletICA stage.

        Args:
            index_to_name: Optional channel index → name mapping.
            use_standard_1020: Try to attach the ``standard_1020`` montage.
            resample_to: Target sampling rate in Hz, or None to keep the input rate.
            notch_freqs: Frequencies passed to ``raw.notch_filter``, or None to skip.
            highpass: High-pass cutoff in Hz, or None to skip.
            bad_point_z: Samples farther than this many training std from the
                training mean are repaired.
            bad_channel_z: Robust z-score threshold on channel std for marking bad channels.
            interpolate_bad_channels: Enable bad-channel interpolation.
            car: Apply an average EEG reference.
            use_wica: Enable the WaveletICA stage.
            wica_components, wica_wavelet, wica_level, wica_random_state:
                Forwarded to ``WaveletICA``.
        """
        self.index_to_name = index_to_name
        self.use_standard_1020 = use_standard_1020
        self.resample_to = resample_to
        self.notch_freqs = notch_freqs
        self.highpass = highpass
        self.bad_point_z = bad_point_z
        self.bad_channel_z = bad_channel_z
        self.interpolate_bad_channels = interpolate_bad_channels
        self.car = car
        self.use_wica = use_wica

        # State filled in by fit / _filter_and_reference.
        self._sfreq_out: Optional[float] = None
        self._train_mu: Optional[np.ndarray] = None
        self._train_sd: Optional[np.ndarray] = None
        self._robust_med: Optional[float] = None
        self._robust_mad: Optional[float] = None
        self._train_eeg_names: Optional[List[str]] = None

        self._wica = WaveletICA(
            wavelet=wica_wavelet,
            level=wica_level,
            n_components=wica_components,
            random_state=wica_random_state
        )

    @property
    def sfreq_out(self) -> float:
        """Sampling rate (Hz) of the most recently filtered recording."""
        assert self._sfreq_out is not None, "Preprocessor not run yet."
        return float(self._sfreq_out)

    def _filter_and_reference(self, raw: mne.io.Raw):
        """Resample, notch, high-pass and average-reference ``raw`` in place.

        Also records the resulting sampling rate in ``self._sfreq_out``.
        """
        if self.resample_to is not None and float(self.resample_to) != float(raw.info["sfreq"]):
            raw.resample(self.resample_to, npad="auto", verbose=False)

        self._sfreq_out = float(raw.info["sfreq"])

        if self.notch_freqs is not None:
            raw.notch_filter(freqs=self.notch_freqs, verbose=False)

        if self.highpass is not None:
            raw.filter(l_freq=self.highpass, h_freq=None, verbose=False)

        if self.car:
            # verbose=False keeps these consistent with the other MNE calls;
            # otherwise every trial prints several lines of projection logging.
            raw.set_eeg_reference("average", projection=True, verbose=False)
            raw.apply_proj(verbose=False)

    def _repair_transients_with_train_stats(self, raw: mne.io.Raw):
        """Replace out-of-range samples by linear interpolation, per channel.

        A sample is out of range when it lies outside
        ``train_mu ± bad_point_z * train_sd`` for its channel. A channel with
        fewer than two in-range samples is left unchanged.

        Raises:
            AssertionError: If the training statistics have not been set by ``fit``.
        """
        X = raw.get_data()
        mu = self._train_mu
        sd = self._train_sd
        assert mu is not None and sd is not None, "Training stats not set."

        hi = mu + self.bad_point_z * sd
        lo = mu - self.bad_point_z * sd
        mask = (X > hi) | (X < lo)

        if np.any(mask):
            X_fixed = X.copy()
            t = np.arange(X.shape[1], dtype=float)
            for ch in range(X.shape[0]):
                m = mask[ch]
                if m.any():
                    good = ~m
                    if good.sum() >= 2:
                        X_fixed[ch, m] = np.interp(t[m], t[good], X_fixed[ch, good])
            # Writes the repaired samples back through MNE's private data attribute.
            raw._data = X_fixed

    def _interpolate_bad_channels_with_train_calibration(self, raw: mne.io.Raw):
        """Mark EEG channels with an abnormal std as bad and interpolate them.

        The robust z-score ``0.6745 * (std - median) / MAD`` uses the median and
        MAD stored by ``fit``. Does nothing when interpolation is disabled,
        there are no EEG channels, or the calibration is missing or zero.
        """
        if not self.interpolate_bad_channels:
            return

        picks = mne.pick_types(raw.info, eeg=True)
        if len(picks) == 0:
            return

        X = raw.get_data(picks=picks)
        ch_std = X.std(axis=1)

        med = self._robust_med
        mad = self._robust_mad
        if med is None or mad is None or mad == 0:
            return

        z = 0.6745 * (ch_std - med) / mad
        eeg_names = mne.pick_info(raw.info, picks).ch_names
        bads = [eeg_names[i] for i in np.where(np.abs(z) > self.bad_channel_z)[0]]

        raw.info["bads"] = bads
        if bads:
            raw.interpolate_bads(reset_bads=True, verbose=False)

    def fit(self, X_train: np.ndarray, sfreq: float):
        """Learn the cleaning statistics (and the WaveletICA model) from ``X_train``.

        Args:
            X_train: 2-D array; ``X_train.shape[0]`` is taken as the channel count.
            sfreq: Sampling rate of ``X_train`` in Hz.

        Returns:
            ``self``.
        """
        C = X_train.shape[0]
        ch_names = _names_from_index_mapping(C, self.index_to_name)

        raw_train, montage_applied = _make_raw(X_train, sfreq, ch_names, self.use_standard_1020)
        self._filter_and_reference(raw_train)

        # Per-channel mean/std of the filtered data, taken before any repair.
        Xt = raw_train.get_data()
        self._train_mu = Xt.mean(axis=1, keepdims=True)
        self._train_sd = Xt.std(axis=1, keepdims=True) + 1e-12

        # Reset the bad-channel calibration so a refit never keeps stale values
        # (previously _train_eeg_names survived a refit that had no montage).
        self._robust_med = None
        self._robust_mad = None
        self._train_eeg_names = None

        if montage_applied:
            picks_eeg = mne.pick_types(raw_train.info, eeg=True)
            if len(picks_eeg):
                X_eeg = Xt[picks_eeg]
                ch_std = X_eeg.std(axis=1)
                med = np.median(ch_std)
                mad = np.median(np.abs(ch_std - med)) + 1e-12
                self._robust_med = float(med)
                self._robust_mad = float(mad)
                self._train_eeg_names = mne.pick_info(raw_train.info, picks_eeg).ch_names

        self._repair_transients_with_train_stats(raw_train)
        Xt = raw_train.get_data()

        if montage_applied and self.interpolate_bad_channels:
            self._interpolate_bad_channels_with_train_calibration(raw_train)
            Xt = raw_train.get_data()

        if self.use_wica:
            self._wica.fit(Xt)

        return self

    def transform(self, X: np.ndarray, sfreq: float) -> Tuple[np.ndarray, float]:
        """Clean one recording with the statistics stored by ``fit``.

        Args:
            X: 2-D array; ``X.shape[0]`` is taken as the channel count.
            sfreq: Sampling rate of ``X`` in Hz.

        Returns:
            ``(cleaned float32 array, output sampling rate)``.
        """
        C = X.shape[0]
        ch_names = _names_from_index_mapping(C, self.index_to_name)

        raw, montage_applied = _make_raw(X, sfreq, ch_names, self.use_standard_1020)
        self._filter_and_reference(raw)
        self._repair_transients_with_train_stats(raw)

        if montage_applied and self.interpolate_bad_channels:
            self._interpolate_bad_channels_with_train_calibration(raw)

        Xf = raw.get_data()
        if self.use_wica:
            Xf = self._wica.transform(Xf)

        return Xf.astype(np.float32, copy=False), self.sfreq_out

    def fit_transform(self, X_train: np.ndarray, sfreq: float) -> Tuple[np.ndarray, float]:
        """Run ``fit`` and then ``transform`` on the same array."""
        self.fit(X_train, sfreq)
        X_clean, fs_out = self.transform(X_train, sfreq)
        return X_clean, fs_out

print("✅ Preprocessing classes loaded.")

print("✅ Preprocessing classes loaded.")


✅ Preprocessing classes loaded.


Cell 4 — Azure download + load raw EEG

In [8]:
from azureml.core import Workspace, Datastore

print("[STEP 4] Connecting to Azure ML workspace...")

# Workspace coordinates.
# NOTE(review): these identifiers are hard-coded; consider reading them from
# configuration or environment variables before sharing the notebook.
subscription_id = "eccc04ba-d8b0-4f70-864a-b4a6753bfc72"
resource_group  = "somnasnest"
workspace_name  = "SomnasNest"

ws = Workspace(
    subscription_id=subscription_id,
    resource_group=resource_group,
    workspace_name=workspace_name
)

print("[STEP 4] Getting datastore 'workspaceblobstore'...")
datastore = Datastore.get(ws, "workspaceblobstore")

# Remote folder inside the datastore and the local directory it is mirrored into.
remote_prefix = "UI/2025-12-11_033542_UTC/New Dataset"
local_root = "./azureml_eeg_data"
os.makedirs(local_root, exist_ok=True)

# overwrite=False: files that already exist locally are not downloaded again.
print(f"[STEP 4] Downloading datastore prefix: {remote_prefix}")
n_files = datastore.download(
    target_path=local_root,
    prefix=remote_prefix,
    overwrite=False,
    show_progress=True,
)
print(f"[STEP 4] Downloaded {n_files} file(s).")

# The download keeps the remote folder structure underneath local_root.
base_path = os.path.join(local_root, *remote_prefix.split("/"))
print(f"[STEP 4] base_path = {base_path}")

if not os.path.isdir(base_path):
    raise RuntimeError(f"[STEP 4] base_path does not exist: {base_path}")

print("[STEP 4] Listing base_path contents:")
print(os.listdir(base_path)[:10], "...")

# Dataset rule: subjects 01-59 contribute their ses-1 recording (label 0),
# subjects 60-71 contribute their ses-2 recording (label 1).
# NOTE(review): under this rule the class label is tied to subject identity —
# confirm this is the intended design.
sub_ses1 = [f"sub-{i:02d}" for i in range(1, 60)]
sub_ses2 = [f"sub-{i:02d}" for i in range(60, 72)]
subjects_used = sub_ses1 + sub_ses2

print("\n[STEP 4] Dataset rule summary:")
print(f"  ses-1 subjects count: {len(sub_ses1)} (sub-01..sub-59)")
print(f"  ses-2 subjects count: {len(sub_ses2)} (sub-60..sub-71)")
print(f"  total subjects used : {len(subjects_used)}")

# Parallel lists: one entry per loaded recording.
raw_list = []
targets = []
sfreqs = []
subject_ids = []

print("\n[STEP 4] Loading raw EEG data according to rule...")

# One loop over (session, subjects) replaces two copy-pasted loops; the
# iteration order (all ses-1 subjects, then all ses-2 subjects) is unchanged.
for session, session_subjects in (("ses-1", sub_ses1), ("ses-2", sub_ses2)):
    for sub in session_subjects:
        path = os.path.join(base_path, sub, session)
        if not os.path.isdir(path):
            print(f"⚠️ Missing folder: {path}, skipping.")
            continue
        data_list, t_list, sf_list = load_eeg_data_with_target(path, session)
        print(f"  Loaded {len(data_list)} trial(s) from {sub}/{session}")
        for data, t, sf in zip(data_list, t_list, sf_list):
            raw_list.append(data.astype(np.float32, copy=False))
            targets.append(int(t))
            sfreqs.append(float(sf))
            subject_ids.append(sub)

targets = np.array(targets, dtype=np.int32)
subject_ids = np.array(subject_ids)

print("\n[STEP 4] Finished loading raw trials.")
print("  Total trials loaded:", len(raw_list))
print("  Targets shape:", targets.shape)
print("  Unique labels + counts:", np.unique(targets, return_counts=True))

sfreqs = np.array(sfreqs, dtype=np.float32)
if len(sfreqs) == 0:
    raise RuntimeError("No EEG data found. Check paths / dataset.")

# The first recording's rate is used for every trial downstream; a mismatch
# only produces a warning here.
fs = float(sfreqs[0])
if not np.allclose(sfreqs, fs):
    print("⚠️ Warning: Not all sampling frequencies are identical!")
print(f"[STEP 4] Using fs={fs} Hz")


[STEP 4] Connecting to Azure ML workspace...
[STEP 4] Getting datastore 'workspaceblobstore'...
[STEP 4] Downloaded 0 file(s).
[STEP 4] base_path = ./azureml_eeg_data/UI/2025-12-11_033542_UTC/New Dataset
[STEP 4] Listing base_path contents:
['CHANGES', 'code', 'dataset_description.json', 'participants.json', 'participants.tsv', 'README', 'session_1_eeg_data.csv', 'session_2_eeg_data.csv', 'sub-01', 'sub-01_eeg_data.pkl'] ...

[STEP 4] Dataset rule summary:
  ses-1 subjects count: 59 (sub-01..sub-59)
  ses-2 subjects count: 12 (sub-60..sub-71)
  total subjects used : 71

[STEP 4] Loading raw EEG data according to rule...
  Loaded 1 trial(s) from sub-01/ses-1
  Loaded 1 trial(s) from sub-02/ses-1
  Loaded 1 trial(s) from sub-03/ses-1
  Loaded 1 trial(s) from sub-04/ses-1
  Loaded 1 trial(s) from sub-05/ses-1
  Loaded 1 trial(s) from sub-06/ses-1
  Loaded 1 trial(s) from sub-07/ses-1
  Loaded 1 trial(s) from sub-08/ses-1
  Loaded 1 trial(s) from sub-09/ses-1
  Loaded 1 trial(s) from sub-1

Path already exists. Skipping download for ./azureml_eeg_data/UI/2025-12-11_033542_UTC/New Dataset/participants.tsv
Path already exists. Skipping download for ./azureml_eeg_data/UI/2025-12-11_033542_UTC/New Dataset/session_1_eeg_data.csv
Path already exists. Skipping download for ./azureml_eeg_data/UI/2025-12-11_033542_UTC/New Dataset/session_2_eeg_data.csv
Path already exists. Skipping download for ./azureml_eeg_data/UI/2025-12-11_033542_UTC/New Dataset/sub-01/ses-1/sub-01_ses-1_task-eyesclosed_eeg.fdt
Path already exists. Skipping download for ./azureml_eeg_data/UI/2025-12-11_033542_UTC/New Dataset/sub-01/ses-1/sub-01_ses-1_task-eyesclosed_eeg.set
Path already exists. Skipping download for ./azureml_eeg_data/UI/2025-12-11_033542_UTC/New Dataset/sub-01/ses-2/sub-01_ses-2_task-eyesclosed_eeg.fdt
Path already exists. Skipping download for ./azureml_eeg_data/UI/2025-12-11_033542_UTC/New Dataset/sub-01/ses-2/sub-01_ses-2_task-eyesclosed_eeg.set
Path already exists. Skipping download for .

Cell 5 — Fit preprocessor

In [9]:
# No explicit channel naming: channels get the default names EEG1, EEG2, ...
CHANNEL_MAP = None

pre = EEGPreprocessor(
    index_to_name=CHANNEL_MAP,
    use_standard_1020=True,
    resample_to=None,
    notch_freqs=[50.0, 100.0, 150.0],
    highpass=0.05,
    bad_point_z=6.0,
    bad_channel_z=5.0,
    interpolate_bad_channels=False,
    car=True,
    use_wica=True,
    wica_components=10,
    wica_wavelet="db4",
    wica_level=3,
    wica_random_state=42,
)

print("[STEP 5] Fitting EEGPreprocessor on a subset of loaded trials...")

max_calib_trials = min(10, len(raw_list))
if max_calib_trials == 0:
    raise RuntimeError("No data to fit EEGPreprocessor.")

# Calibration block: the first trials concatenated along the time axis.
# NOTE(review): with the loading order used above, these are all ses-1 trials.
calib_trials = raw_list[:max_calib_trials]
X_calib = np.concatenate(calib_trials, axis=1).astype(np.float32, copy=False)

# Only the fitted state is needed. The cleaned calibration block that
# fit_transform returned was never used, so the extra transform pass is skipped.
pre.fit(X_calib, fs)
fs_out = pre.sfreq_out
print(f"[STEP 5] Preprocessor fitted. Output fs: {fs_out} Hz")

data_clean = []
for i, segment in enumerate(raw_list):
    X_clean, _ = pre.transform(segment, fs)
    data_clean.append(X_clean.astype(np.float32, copy=False))
    if (i + 1) % 20 == 0:
        print(f"[STEP 5] Preprocessed {i+1}/{len(raw_list)} trials")

# Stacking needs identical trial shapes; fail with a message that names them
# instead of NumPy's generic inhomogeneous-shape error.
trial_shapes = {x.shape for x in data_clean}
if len(trial_shapes) != 1:
    raise ValueError(
        f"[STEP 5] Trials have different shapes and cannot be stacked: {sorted(trial_shapes)}"
    )
data_clean = np.stack(data_clean).astype(np.float32, copy=False)

print("[STEP 5] Done preprocessing all trials.")
print("  data_clean shape:", data_clean.shape)
print("  targets shape   :", targets.shape)


[STEP 5] Fitting EEGPreprocessor on a subset of loaded trials...
EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.
Created an SSP operator (subspace dimension = 1)
1 projection items activated
SSP projectors applied...
EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.
Created an SSP operator (subspace dimension = 1)
1 projection items activated
SSP projectors applied...
[STEP 5] Preprocessor fitted. Output fs: 500.0 Hz
EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.

Cell 6 — Augmentation

In [10]:
def augment_data(data: np.ndarray, target: int, segment_size: int = 100):
    """Cut a recording into consecutive, non-overlapping windows along axis 1.

    Args:
        data: 2-D array, or a 3-D array of which only ``data[0]`` is used.
        target: Label repeated once per window.
        segment_size: Window length in samples; a trailing remainder is dropped.

    Returns:
        ``(windows, labels)`` — a list of float32 arrays with ``segment_size``
        columns each, and a list with ``int(target)`` for every window.
    """
    signal = data[0] if data.ndim == 3 else data
    window_count = signal.shape[1] // segment_size

    windows = [
        signal[:, k * segment_size:(k + 1) * segment_size].astype(np.float32, copy=False)
        for k in range(window_count)
    ]
    labels = [int(target) for _ in windows]

    return windows, labels


print("[STEP 6] Augmenting all trials...")

# Flat, parallel collections: window data, window label, originating subject.
augmented, aug_targets, aug_subject_ids = [], [], []

for trial, y, subj in zip(data_clean, targets, subject_ids):
    # Every trial is cut into 100-sample windows that inherit its label and subject.
    segs, ys = augment_data(trial, int(y), segment_size=100)
    augmented += segs
    aug_targets += ys
    aug_subject_ids += [subj] * len(segs)

augmented = np.array(augmented, dtype=np.float32)
aug_targets = np.array(aug_targets, dtype=np.int32)
aug_subject_ids = np.array(aug_subject_ids)

print("[STEP 6] Augmentation done.")
print("  augmented shape:", augmented.shape)
print("  aug_targets shape:", aug_targets.shape)
print("  Class counts (augmented):", np.unique(aug_targets, return_counts=True))


[STEP 6] Augmenting all trials...
[STEP 6] Augmentation done.
  augmented shape: (76680, 61, 100)
  aug_targets shape: (76680,)
  Class counts (augmented): (array([0, 1], dtype=int32), array([63720, 12960]))


Cell 7 — Fair selection

In [11]:
print("[STEP 7] Grouping segments by subject and class...")

# subject -> {class label -> list of windows}, filled in window order.
subject_data = defaultdict(lambda: {0: [], 1: []})

for x, y, subj in zip(augmented, aug_targets, aug_subject_ids):
    subject_data[subj][int(y)].append(x)

max_per_class_per_subject = 200

selected_data = []
selected_targets = []

print("[STEP 7] Selecting up to 200 segments per (subject, class)...")

for subj in subjects_used:
    # Subjects without any window contribute nothing (and are not added to the dict).
    per_class = subject_data[subj] if subj in subject_data else {0: [], 1: []}
    for class_label in (0, 1):
        # Keep the first windows of each (subject, class) group, up to the cap.
        picked = per_class[class_label][:max_per_class_per_subject]
        selected_data += picked
        selected_targets += [class_label] * len(picked)

selected_data = np.array(selected_data, dtype=np.float32)
selected_targets = np.array(selected_targets, dtype=np.int32)

print("[STEP 7] Selection done.")
print("  selected_data shape:", selected_data.shape)
print("  selected_targets shape:", selected_targets.shape)
print("  Class counts (selected):", np.unique(selected_targets, return_counts=True))


[STEP 7] Grouping segments by subject and class...
[STEP 7] Selecting up to 200 segments per (subject, class)...
[STEP 7] Selection done.
  selected_data shape: (14200, 61, 100)
  selected_targets shape: (14200,)
  Class counts (selected): (array([0, 1], dtype=int32), array([11800,  2400]))


Cell 8 — Reshape for CNN input

In [12]:
print("[STEP 8] Reshaping for CNN...")

# Append a trailing singleton axis so each window becomes a one-channel "image".
X_all = np.expand_dims(selected_data, axis=-1).astype(np.float32, copy=False)
y_all = selected_targets.astype(np.int32, copy=False)

print("[STEP 8] Done.")
print("  X_all shape:", X_all.shape)
print("  y_all shape:", y_all.shape)
print("  Class counts:", np.unique(y_all, return_counts=True))


[STEP 8] Reshaping for CNN...
[STEP 8] Done.
  X_all shape: (14200, 61, 100, 1)
  y_all shape: (14200,)
  Class counts: (array([0, 1], dtype=int32), array([11800,  2400]))


Cell 9 — EEGNet model definition

In [13]:
def get_lr(opt):
    """Return the optimizer's current learning rate as a float.

    Tries, in order: the Keras backend value of ``opt.learning_rate``,
    ``opt.learning_rate.numpy()``, and ``opt.lr.numpy()``. If all of them fail,
    falls back to ``float(opt.lr)``, or ``0.0`` when the optimizer has no ``lr``.
    """
    readers = (
        lambda: tf.keras.backend.get_value(opt.learning_rate),
        lambda: opt.learning_rate.numpy(),
        lambda: opt.lr.numpy(),
    )
    for read in readers:
        try:
            return float(read())
        except Exception:
            # This access path is not available; try the next one.
            continue
    return float(getattr(opt, "lr", 0.0))


def create_eegnet(input_shape, dropout_rate=0.5, num_classes=1):
    """Build the (uncompiled) ``EEGNet_simple`` Keras model.

    Args:
        input_shape: Shape of one sample; ``input_shape[0]`` sets the height of
            the depthwise kernel.
        dropout_rate: Rate used by both Dropout layers.
        num_classes: Number of sigmoid output units.

    Returns:
        A ``tf.keras`` ``Model`` named ``"EEGNet_simple"``.
    """
    electrode_count = input_shape[0]
    net_input = Input(shape=input_shape)

    # Convolution along axis 2 only (kernel height 1).
    net = Conv2D(16, (1, 64), padding='same', use_bias=False)(net_input)
    net = BatchNormalization()(net)

    # Depthwise kernel spanning all of axis 1, which collapses that axis to size 1.
    net = DepthwiseConv2D(
        (electrode_count, 1), depth_multiplier=2, padding='valid', use_bias=False
    )(net)
    net = BatchNormalization()(net)
    net = Activation('elu')(net)
    net = AveragePooling2D((1, 4))(net)
    net = Dropout(dropout_rate)(net)

    # Separable convolution block.
    net = SeparableConv2D(16, (1, 16), padding='same', use_bias=False)(net)
    net = BatchNormalization()(net)
    net = Activation('elu')(net)
    net = AveragePooling2D((1, 8))(net)
    net = Dropout(dropout_rate)(net)

    # Classifier head.
    net = Flatten()(net)
    net = Dense(64, activation='relu')(net)
    net_output = Dense(num_classes, activation='sigmoid')(net)

    return Model(inputs=net_input, outputs=net_output, name="EEGNet_simple")


print("✅ Model function defined.")


print("✅ Model function defined.")


✅ Model function defined.


Cell 10 — Prepare full dataset + normalization

In [14]:
banner = "=" * 80
print("\n" + banner)
print("[STEP 10A] PREPARE FULL DATASET + NORMALIZATION")
print(banner)

X_train_raw = X_all.astype(np.float32, copy=False)
y_train_raw = y_all.astype(np.int32, copy=False)

# Axes 1 and 2 of the sample tensor define the network input shape.
n_electrodes, segment_size = X_train_raw.shape[1], X_train_raw.shape[2]
input_shape = (n_electrodes, segment_size, 1)

print(f"[STEP 10A] Input shape: {input_shape}")
print(f"[STEP 10A] Total samples: {X_train_raw.shape[0]}")

# One mean/std per index of axis 1, pooled over samples, axis 2 and the
# trailing axis. The std is floored at epsilon to avoid division by zero.
epsilon = 1e-6
pooled_axes = (0, 2, 3)
train_mean = X_train_raw.mean(axis=pooled_axes, keepdims=True)
train_std = np.maximum(X_train_raw.std(axis=pooled_axes, keepdims=True), epsilon)

X_train_norm = ((X_train_raw - train_mean) / train_std).astype(np.float32)

print("[STEP 10A] ✅ Normalization complete")



[STEP 10A] PREPARE FULL DATASET + NORMALIZATION
[STEP 10A] Input shape: (61, 100, 1)
[STEP 10A] Total samples: 14200
[STEP 10A] ✅ Normalization complete


Cell 11 — SMOTE balancing

In [15]:
from collections import Counter

rule = "-" * 80
print("\n" + rule)
print("[STEP 10B] SMOTE CONFIRMATION BLOCK")
print(rule)

# The samplers work on 2-D input, so every sample is flattened to one row.
X_train_2d = X_train_norm.reshape(len(X_train_norm), -1)

counts_before = Counter(y_train_raw.tolist())
total_before = sum(counts_before.values())
minority_n = min(counts_before.values())

print("\n[STEP 10B] Label counts BEFORE balancing:")
for cls in sorted(counts_before):
    print(f"  Class {cls}: {counts_before[cls]}")

print(f"[STEP 10B] Total samples BEFORE: {total_before}")
print(f"[STEP 10B] Minority class size : {minority_n}")

# With fewer than two minority samples, plain random oversampling is used
# instead of SMOTE. k_neighbors is capped at minority_n - 1 (at most 5).
if minority_n < 2:
    sampler = RandomOverSampler(random_state=42)
    sampler_used = "RandomOverSampler"
else:
    k_neighbors = max(1, min(5, minority_n - 1))
    sampler = SMOTE(random_state=42, k_neighbors=k_neighbors)
    sampler_used = f"SMOTE (k_neighbors={k_neighbors})"

X_train_bal_2d, y_train_bal_int = sampler.fit_resample(X_train_2d, y_train_raw)

print(f"\n[STEP 10B] ✅ Sampler USED: {sampler_used}")

counts_after = Counter(y_train_bal_int.tolist())
total_after = sum(counts_after.values())

print("\n[STEP 10B] Label counts AFTER balancing:")
for cls in sorted(counts_after):
    print(f"  Class {cls}: {counts_after[cls]}")

print(f"[STEP 10B] Total samples AFTER : {total_after}")
print(f"[STEP 10B] Samples added       : {total_after - total_before}")

# Back to the 4-D layout the network expects; labels become float32.
X_train_bal = X_train_bal_2d.reshape(-1, n_electrodes, segment_size, 1).astype(np.float32)
y_train_bal = y_train_bal_int.astype(np.float32)

print("\n[STEP 10B] Final balanced tensors:")
print(f"  X_train_bal shape: {X_train_bal.shape}")
print(f"  y_train_bal shape: {y_train_bal.shape}")

print(rule)



--------------------------------------------------------------------------------
[STEP 10B] SMOTE CONFIRMATION BLOCK
--------------------------------------------------------------------------------

[STEP 10B] Label counts BEFORE balancing:
  Class 0: 11800
  Class 1: 2400
[STEP 10B] Total samples BEFORE: 14200
[STEP 10B] Minority class size : 2400

[STEP 10B] ✅ Sampler USED: SMOTE (k_neighbors=5)

[STEP 10B] Label counts AFTER balancing:
  Class 0: 11800
  Class 1: 11800
[STEP 10B] Total samples AFTER : 23600
[STEP 10B] Samples added       : 9400

[STEP 10B] Final balanced tensors:
  X_train_bal shape: (23600, 61, 100, 1)
  y_train_bal shape: (23600,)
--------------------------------------------------------------------------------


Cell 13 — Callbacks

In [16]:
# All callbacks monitor the training loss ("loss").

# Halve the learning rate after 5 epochs without improvement, down to 1e-6.
lr_schedule = ReduceLROnPlateau(
    monitor="loss", factor=0.5, patience=5, min_lr=1e-6, verbose=1
)

# Stop after 15 epochs without improvement and restore the best weights.
early_stop = EarlyStopping(
    monitor="loss", patience=15, restore_best_weights=True, verbose=1
)

# Save the model to disk only when the monitored loss improves.
checkpoint = ModelCheckpoint(
    filepath="EEGNet-SD-Final.keras", monitor="loss", save_best_only=True, verbose=1
)

# Per-epoch metrics log; the file is overwritten on every run.
csv_log = CSVLogger("final_model_training_log.csv", append=False)

callbacks = [lr_schedule, early_stop, checkpoint, csv_log]

print("✅ Callbacks ready.")

print("✅ Callbacks ready.")


✅ Callbacks ready.


Cell 14 — Train

In [17]:
EPOCHS = 300
BATCH_SIZE = 200
# NOTE(review): the cell that originally built and compiled `model` is not in
# this notebook, so the learning rate below is an assumption — confirm.
LEARNING_RATE = 1e-3

# `model` was never defined in the notebook, so a fresh kernel failed here with
# NameError. Build and compile it in this cell so Restart & Run All works and
# re-running the cell always starts from new weights.
# A single metric named "accuracy" matches the progress-bar output and the
# two-value unpacking of model.evaluate in the evaluation cell.
model = create_eegnet(input_shape)
model.compile(
    optimizer=Adam(learning_rate=LEARNING_RATE),
    loss=BinaryCrossentropy(),
    metrics=[BinaryAccuracy(name="accuracy")],
)

print("\n[STEP 10D] 🚀 Starting training")
print("=" * 80)

history = model.fit(
    X_train_bal,
    y_train_bal,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    callbacks=callbacks,
    verbose=1
)

print("=" * 80)
print("[STEP 10D] ✅ Training complete")



[STEP 10D] 🚀 Starting training
Epoch 1/300
[1m 81/118[0m [32m━━━━━━━━━━━━━[0m[37m━━━━━━━[0m [1m10s[0m 274ms/step - accuracy: 0.6566 - loss: 0.6190[1m 82/118[0m [32m━━━━━━━━━━━━━[0m[37m━━━━━━━[0m [1m9s[0m 274ms/step - accuracy: 0.6568 - loss: 0.6187 [1m 95/118[0m [32m━━━━━━━━━━━━━━━━[0m[37m━━━━[0m [1m6s[0m 274ms/step - accuracy: 0.6604 - loss: 0.6150[1m 96/118[0m [32m━━━━━━━━━━━━━━━━[0m[37m━━━━[0m [1m6s[0m 274ms/step - accuracy: 0.6607 - loss: 0.6147

Cell 15 — Evaluation

In [None]:
# All numbers below are computed on the balanced training tensors themselves
# (the same data the model was fitted on).
train_loss, train_acc = model.evaluate(
    X_train_bal, y_train_bal, batch_size=BATCH_SIZE, verbose=1
)

print(f"\n[FINAL] Training loss    : {train_loss:.6f}")
print(f"[FINAL] Training accuracy: {train_acc:.6f}")

# Sigmoid outputs flattened to one probability per sample, thresholded at 0.5.
y_prob = model.predict(X_train_bal, batch_size=BATCH_SIZE, verbose=1).reshape(-1)
y_pred = (y_prob > 0.5).astype(int)

y_true = y_train_bal.astype(int)
cm = confusion_matrix(y_true, y_pred)

print("\n[FINAL] Training Confusion Matrix")
print(cm)

print("\n[FINAL] Classification Report")
print(classification_report(y_true, y_pred, target_names=["Class 0", "Class 1"]))

closing_rule = "=" * 80
print("\n" + closing_rule)
print("[FINAL] ✅ FULL PIPELINE COMPLETED SUCCESSFULLY")
print("       Model: EEGNet-SD-Final.keras")
print("       Log  : final_model_training_log.csv")
print(closing_rule)
