In [1]:
import mne
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from mne.io import concatenate_raws, read_raw_edf

In [2]:
#Preprocessing
def load_left_hand_epochs(subject=1, tmin=0.0, tmax=4.0):
    """
    Returns:
        epochs_left_real: left hand REAL movement epochs
        epochs_left_imag: left hand IMAGINED movement epochs
    """

    # Real left vs right hand runs
    real_runs = [3, 7, 11]

    # Imagined left vs right hand runs
    imag_runs = [4, 8, 12]

    real_files = mne.datasets.eegbci.load_data(subjects=[subject], runs=real_runs)

    raw_real = concatenate_raws([
        read_raw_edf(f, preload=True, verbose=False) for f in real_files
    ])

    events_real, event_id_real = mne.events_from_annotations(raw_real)

    # Keep ONLY left hand trials: T1
    epochs_left_real = mne.Epochs(
        raw_real,
        events_real,
        event_id={"left_real": event_id_real["T1"]},
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        preload=True
    )

    # -------------------------
    # Load IMAGINED runs
    # -------------------------
    imag_files = mne.datasets.eegbci.load_data(subjects=[subject], runs=imag_runs)

    raw_imag = concatenate_raws([
        read_raw_edf(f, preload=True, verbose=False) for f in imag_files
    ])

    events_imag, event_id_imag = mne.events_from_annotations(raw_imag)

    # Keep ONLY left hand trials: T1
    epochs_left_imag = mne.Epochs(
        raw_imag,
        events_imag,
        event_id={"left_imag": event_id_imag["T1"]},
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        preload=True
    )

    return epochs_left_real, epochs_left_imag

epochs_left_real, epochs_left_imag = load_left_hand_epochs(subject=1)

Downloading EEGBCI data


Downloading file 'S001/S001R11.edf' from 'https://physionet.org/files/eegmmidb/1.0.0/S001/S001R11.edf' to 'C:\Users\rahto\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0'.


Download complete in 10s (2.5 MB)
Used Annotations descriptions: [np.str_('T0'), np.str_('T1'), np.str_('T2')]
Not setting metadata
23 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 23 events and 641 original time points ...
0 bad epochs dropped
Downloading EEGBCI data


Downloading file 'S001/S001R12.edf' from 'https://physionet.org/files/eegmmidb/1.0.0/S001/S001R12.edf' to 'C:\Users\rahto\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0'.


Download complete in 09s (2.5 MB)
Used Annotations descriptions: [np.str_('T0'), np.str_('T1'), np.str_('T2')]
Not setting metadata
23 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 23 events and 641 original time points ...
0 bad epochs dropped


In [3]:
#PSD - spectral analysis
epochs_real = epochs_left_real
epochs_imag = epochs_left_imag

X_real = epochs_real.get_data()
X_imag = epochs_imag.get_data()

# labels: 0 = real, 1 = imagined
y_real = [0] * len(X_real)
y_imag = [1] * len(X_imag)

X = np.concatenate([X_real, X_imag], axis=0)
y = np.array(y_real + y_imag)

print(X.shape, y.shape)


epochs_left_real = epochs_left_real.copy()
epochs_left_imag = epochs_left_imag.copy()

epochs_all = mne.concatenate_epochs([epochs_left_real, epochs_left_imag])

# labels: 0 = real, 1 = imagined
y = np.array([0] * len(epochs_left_real) + [1] * len(epochs_left_imag))
epochs_all = epochs_all.filter(8, 30)

psd = epochs_all.compute_psd(
    method="welch",
    fmin=8,
    fmax=30,
    n_fft=256,
    verbose=False
)

psds = psd.get_data()     # shape: (n_epochs, n_channels, n_freqs)
freqs = psd.freqs


def bandpower(psds, freqs, fmin, fmax):
    idx = np.logical_and(freqs >= fmin, freqs <= fmax)
    return psds[:, :, idx].mean(axis=2)

mu_power = bandpower(psds, freqs, 8, 13)
beta_power = bandpower(psds, freqs, 13, 30)

# Combine into one feature vector per epoch
X = np.concatenate([mu_power, beta_power], axis=1)

print(X.shape)  # (n_epochs, n_channels*2)

(46, 64, 641) (46,)
Not setting metadata
46 matching events found
No baseline correction applied
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 265 samples (1.656 s)



  epochs_all = mne.concatenate_epochs([epochs_left_real, epochs_left_imag])


(46, 128)


In [4]:
#Logistic Regression
clf = make_pipeline(
    StandardScaler(),
    LogisticRegression(max_iter=1000)
)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

scores = cross_val_score(clf, X, y, cv=cv)

print("CV accuracy:", scores.mean(), "+/-", scores.std())

CV accuracy: 0.5822222222222222 +/- 0.1829254722422816


In [7]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
X = np.log(X + 1e-10)

rf = RandomForestClassifier(
    n_estimators=500,
    random_state=42,
    class_weight="balanced",
    max_depth=None
)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

scores = cross_val_score(rf, X, y, cv=cv)




print("RF CV accuracy:", scores.mean(), "+/-", scores.std())


  X = np.log(X + 1e-10)


RF CV accuracy: 0.5 +/- 0.04969039949999535


In [8]:
rf.fit(X, y)
importances = rf.feature_importances_

# Since you concatenated mu + beta, split them
n_channels = X.shape[1] // 2
mu_importances = importances[:n_channels]
beta_importances = importances[n_channels:]

print("Top 5 mu-band channels:", np.argsort(mu_importances)[-5:])
print("Top 5 beta-band channels:", np.argsort(beta_importances)[-5:])


Top 5 mu-band channels: [51 55 57 48 50]
Top 5 beta-band channels: [43 57 26  2 21]
