In [21]:
# Imports

import mne
from tqdm import tqdm
import numpy as np
from mne.io import concatenate_raws, read_raw_edf
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import StratifiedGroupKFold, cross_val_score
from sklearn.pipeline import Pipeline
import sys
from contextlib import redirect_stdout
import warnings

In [2]:
# Selecting Data

all_subjects = list(range(1, 110))
real_runs = [3, 7, 11]
imag_runs = [4, 8, 12]
tmin, tmax = 0.0, 4.0

motor_channels = ["C3..", "C4..", "Cp3.", "Cp4.", "Fc3.", "Fc4.", "Cz..", "Fcz."]

In [23]:
# Preprocessing

X_list = []
y_list = []
groups_list = []
valid_subjects = []

mne.set_log_level("ERROR")
warnings.filterwarnings("ignore", category=RuntimeWarning)

for subj in tqdm(all_subjects, desc="Processing subjects"):
    try:
        real_files = mne.datasets.eegbci.load_data(subjects=[subj], runs=real_runs)
        raw_real = concatenate_raws([read_raw_edf(f, preload=True, verbose=False) for f in real_files])
        events_real, event_id_real = mne.events_from_annotations(raw_real)
        epochs_real = mne.Epochs(
            raw_real, events_real,
            event_id={"left_real": event_id_real["T1"]},
            tmin=tmin, tmax=tmax, baseline=None,
            preload=True
        )

        imag_files = mne.datasets.eegbci.load_data(subjects=[subj], runs=imag_runs)
        raw_imag = concatenate_raws([read_raw_edf(f, preload=True, verbose=False) for f in imag_files])
        events_imag, event_id_imag = mne.events_from_annotations(raw_imag)
        epochs_imag = mne.Epochs(
            raw_imag, events_imag,
            event_id={"left_imag": event_id_imag["T1"]},
            tmin=tmin, tmax=tmax, baseline=None,
            preload=True
        )

        if len(epochs_real) < 5 or len(epochs_imag) < 5:
            continue

        epochs_real = epochs_real.copy().pick(motor_channels)
        epochs_imag = epochs_imag.copy().pick(motor_channels)

        epochs_real = epochs_real.set_eeg_reference("average").filter(8, 30, fir_design="firwin", verbose=False)
        epochs_imag = epochs_imag.set_eeg_reference("average").filter(8, 30, fir_design="firwin", verbose=False)

        epochs_real = epochs_real.resample(160, npad="auto")
        epochs_imag = epochs_imag.resample(160, npad="auto")

        X_list.append(epochs_real.get_data())
        X_list.append(epochs_imag.get_data())
        y_list.append(np.zeros(len(epochs_real)))
        y_list.append(np.ones(len(epochs_imag)))
        groups_list.append(np.full(len(epochs_real), subj))
        groups_list.append(np.full(len(epochs_imag), subj))

        valid_subjects.append(subj)

    except Exception:
        continue

X_epochs_all = np.concatenate(X_list, axis=0)
y_epochs_all = np.concatenate(y_list, axis=0)
groups = np.concatenate(groups_list)

Processing subjects: 100%|██████████| 109/109 [00:58<00:00,  1.88it/s]


In [24]:

# Classification

csp = CSP(n_components=6, reg=0.1, log=True, norm_trace=False)
lda = LDA()
clf_csp = Pipeline([('CSP', csp), ('LDA', lda)])



cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

with redirect_stdout(sys.stdout):
        scores = cross_val_score(
            clf_csp,
            X_epochs_all,
            y_epochs_all,
            cv=cv,
            groups=groups
        )


print("\nCSP + LDA CV accuracy per fold:", scores)
print("Mean CV accuracy:", scores.mean())



CSP + LDA CV accuracy per fold: [0.58426966 0.55667001 0.53553038 0.55991944 0.5712831 ]
Mean CV accuracy: 0.5615345171554565


In [25]:
# Channel importance
with redirect_stdout(sys.stdout):
    clf_csp.fit(X_epochs_all, y_epochs_all)

fitted_csp = clf_csp.named_steps['CSP']

patterns = fitted_csp.patterns_
n_csp_components = fitted_csp.n_components

X_csp = fitted_csp.transform(X_epochs_all)

var_class0 = np.var(X_csp[y_epochs_all == 0], axis=0)
var_class1 = np.var(X_csp[y_epochs_all == 1], axis=0)

weights = np.abs(var_class0 - var_class1)
weights /= weights.sum()

patterns_matched = patterns[:, :n_csp_components]

channel_importance = np.average(
    np.abs(patterns_matched),
    axis=1,
    weights=weights
)

channel_importance /= channel_importance.sum()

importance_pairs = list(zip(motor_channels, channel_importance))
importance_pairs.sort(key=lambda x: x[1], reverse=True)

print("\nChannel importance ranking:")
for ch, score in importance_pairs:
    print(f"{ch:6s} → {score:.3f}")


Channel importance ranking:
Fcz.   → 0.229
C3..   → 0.225
Cz..   → 0.158
C4..   → 0.119
Fc4.   → 0.096
Fc3.   → 0.082
Cp3.   → 0.046
Cp4.   → 0.045
