In [1]:
import mne
import numpy as np
import matplotlib.pyplot as plt
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
from mne.channels import make_standard_montage

In [2]:
def load_and_preprocess_subject(subject, runs_dict, l_freq=1., h_freq=40.):
    """
    Load & preprocess EEGBCI data for one subject, one Raw object per condition.

    For every condition in ``runs_dict`` the listed runs are fetched with
    ``eegbci.load_data``, read with ``read_raw_edf(preload=True)``,
    concatenated with ``concatenate_raws`` and then preprocessed in order:

    1. ``eegbci.standardize`` -- standardize channel names;
    2. attach the ``standard_1005`` montage;
    3. add an average-reference *projector* (``projection=True``: it is added
       to the data's projections but not applied until ``apply_proj``);
    4. FIR band-pass ``l_freq``-``h_freq`` Hz with
       ``skip_by_annotation="edge"`` so each concatenated run is filtered as
       its own contiguous segment.

    Parameters:
    -----------
    subject : int
        Subject ID (e.g., 1)
    runs_dict : dict
        Mapping condition name -> list of run numbers, e.g.
        {'rest': [1], 'motor_execution': [3, 7, 11], 'motor_imagery': [4, 8, 12]}
    l_freq : float
        Band-pass low cutoff in Hz (default 1.0).
    h_freq : float
        Band-pass high cutoff in Hz (default 40.0). An alternative band noted
        by the author is (7.0, 30.0).

    Returns:
    --------
    subject_data : dict
        Same keys as ``runs_dict``; each value is the preprocessed, preloaded
        Raw object for that condition.
    """
    # The montage is identical for every condition: build it once.
    montage = make_standard_montage('standard_1005')
    subject_data = {}

    for condition, run_list in runs_dict.items():
        print(f"\n➡️ Loading {condition.upper()} | Runs: {run_list}")

        raw_fnames = eegbci.load_data(subject, run_list)
        raws = [read_raw_edf(f, preload=True) for f in raw_fnames]
        raw_concat = concatenate_raws(raws)

        # Preprocessing pipeline
        eegbci.standardize(raw_concat)
        raw_concat.set_montage(montage)
        raw_concat.set_eeg_reference(projection=True)
        raw_concat.filter(l_freq, h_freq, fir_design='firwin', skip_by_annotation="edge")

        # Public attributes instead of the private ``_data`` buffer.
        shape = (raw_concat.info['nchan'], raw_concat.n_times)
        print(f"✅ {condition} | Shape: {shape} | Duration: {raw_concat.times[-1] / 60:.2f} min")

        # Save preprocessed raw for this condition
        subject_data[condition] = raw_concat

    return subject_data

In [3]:
def quick_plot(raw, title="Raw EEG Debug"):
    """
    Open the MNE browser on ``raw`` as a visual sanity check.

    Eight channels are shown per page with automatic amplitude scaling;
    nothing is returned.
    """
    browser_options = dict(n_channels=8, scalings="auto", title=title, show=True)
    raw.plot(**browser_options)

In [4]:
# Subject and EEGBCI run numbers grouped by condition.
subject = 1
runs = dict(
    rest=[1],
    motor_execution=[3, 7, 11],
    motor_imagery=[4, 8, 12],
)

subject_data = load_and_preprocess_subject(subject, runs)


➡️ Loading REST | Runs: [1]
Extracting EDF parameters from C:\Users\flavi\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0\S001\S001R01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 9759  =      0.000 ...    60.994 secs...
EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cu

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 529 samples (3.306 s)

✅ motor_execution | Shape: (64, 60000) | Duration: 6.25 min

➡️ Loading MOTOR_IMAGERY | Runs: [4, 8, 12]
Extracting EDF parameters from C:\Users\flavi\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0\S001\S001R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Extracting EDF parameters from C:\Users\flavi\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0\S001\S001R08.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from C:\Users\flavi\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0\S001\S001R12.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.
Filtering raw data in 3 contiguous segments
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 pas

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


In [None]:
%matplotlib qt

# Debug plot (e.g., rest condition)
quick_plot(subject_data["rest"], title=f"Subject {subject} | REST")

# Optional: quick check of motor
quick_plot(subject_data["motor_execution"], title=f"Subject {subject} | MOTOR EXECUTION")

In [None]:
def extract_condition(raw, condition, tmin=0.0, tmax=2.0, random_state=None):
    """
    Extract REST (T0) or TASK (T1 + T2) epochs from an annotated Raw object.

    Epochs overlapping ``BAD``/``EDGE`` annotations (such as the run boundaries
    inserted by ``concatenate_raws``) are rejected when the epochs are built,
    via ``reject_by_annotation=True``. TASK epochs are then shuffled to avoid
    block-structure bias.

    Parameters
    ----------
    raw : mne.io.Raw
        Preprocessed raw data carrying T0/T1/T2 annotations.
    condition : {"rest", "task"}
        Which segments to return.
    tmin, tmax : float
        Epoch window in seconds relative to each event onset (default 0-2 s).
    random_state : int | numpy.random.Generator | None
        Seed for the TASK shuffle. ``None`` keeps the previous behaviour of
        drawing from NumPy's global RNG (reproducible only if
        ``np.random.seed`` was called beforehand).

    Returns
    -------
    mne.Epochs
        T0 epochs for ``condition="rest"``; shuffled T1+T2 epochs for "task".

    Raises
    ------
    ValueError
        If ``condition`` is not "rest" or "task".
    """
    # Validate first so a typo fails before the costly epoching step.
    if condition not in ("rest", "task"):
        raise ValueError("Condition must be either 'rest' or 'task'")

    print(f"\n📌 {condition.upper()} Extraction")

    # Get events (fully qualified: these names are not imported bare)
    events, event_id = mne.events_from_annotations(raw)
    print(f"Found events: {event_id}")

    # Epochs has no ``drop_bad_annotations`` method; epochs touching
    # BAD/EDGE boundary annotations are dropped here instead.
    epochs = mne.Epochs(
        raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
        baseline=None, reject_by_annotation=True, preload=True,
    )
    print(f"🔍 Dropped epochs with boundary flags. Remaining epochs: {len(epochs)}")

    if condition == "rest":
        rest_epochs = epochs["T0"]
        print(f"🟢 REST epochs: {len(rest_epochs)}")
        return rest_epochs

    # A list of keys selects every epoch tagged T1 or T2.
    task_epochs = epochs[["T1", "T2"]]
    print(f"🟠 TASK epochs before shuffle: {len(task_epochs)}")

    # Shuffle the epochs (optionally seeded for reproducibility)
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    task_epochs = task_epochs[rng.permutation(len(task_epochs))]
    print("🔀 TASK epochs shuffled.")

    return task_epochs

In [None]:
# For REST run - no need to segment (optional)
raw_rest = subject_data["rest"]

# Split each motor condition into REST (T0) and TASK (T1+T2) epochs.
# ``segment_condition`` is not defined anywhere in this notebook; the
# ``extract_condition`` helper from the previous cell does this split.
segments_exec = {
    label: extract_condition(subject_data["motor_execution"], label)
    for label in ("rest", "task")
}
segments_imag = {
    label: extract_condition(subject_data["motor_imagery"], label)
    for label in ("rest", "task")
}

# Optional sanity plots: explicit emptiness checks instead of relying on
# the truthiness of an Epochs object.
if len(segments_exec["task"]) > 0:
    segments_exec["task"].plot(n_channels=8, title="Motor Execution - TASK ONLY")

if len(segments_exec["rest"]) > 0:
    segments_exec["rest"].plot(n_channels=8, title="Motor Execution - REST ONLY")

In [6]:
def _pick_event_ids(event_id, labels):
    """Return ``{label: code}`` for ``labels``; raise ValueError if any is missing."""
    missing = [label for label in labels if label not in event_id]
    if missing:
        raise ValueError(
            f"Annotations {missing} not found in raw; available: {sorted(event_id)}"
        )
    return {label: event_id[label] for label in labels}


def extract_clean_epochs(raw, tmin=0.0, tmax=4.0, reject_boundary_epochs=True):
    """
    Extract REST (T0) and TASK (T1 + T2) epochs from annotated raw data.

    Parameters:
    -----------
    raw : mne.io.Raw
        Preprocessed raw object carrying T0/T1/T2 annotations.
    tmin : float
        Start time (in seconds) relative to the event marker.
    tmax : float
        End time (in seconds) relative to the event marker.
    reject_boundary_epochs : bool
        Forwarded to ``mne.Epochs(reject_by_annotation=...)``. When True,
        epochs overlapping ``BAD``/``EDGE`` annotations (e.g. run boundaries
        added by ``concatenate_raws``) are dropped.

    Returns:
    --------
    dict
        {'rest': epochs_T0, 'task': epochs_T1_T2_combined}. Both Epochs keep
        their string labels, so ``result['task']['T1']`` still works.

    Raises:
    -------
    ValueError
        If T0, T1 or T2 annotations are absent from ``raw``.
    """
    # Extract events from annotations (T0, T1, T2)
    events, event_id = mne.events_from_annotations(raw)
    print(f"\n⏺️ Used Annotations descriptions: {list(event_id.keys())}")

    # Select by name and pass a name->code dict to Epochs:
    #  - a missing "T0" used to become ``event_id=None``, which makes MNE
    #    epoch EVERY event and silently label task trials as rest;
    #  - a missing T1/T2 used to pass ``None`` codes to MNE;
    #  - bare integer codes also drop the "T1"/"T2" labels from the result.
    rest_event_id = _pick_event_ids(event_id, ("T0",))
    task_event_id = _pick_event_ids(event_id, ("T1", "T2"))

    epoch_kwargs = dict(
        tmin=tmin, tmax=tmax, baseline=None,
        reject_by_annotation=reject_boundary_epochs,
        preload=True,
    )
    epochs_rest = mne.Epochs(raw, events, event_id=rest_event_id, **epoch_kwargs)
    epochs_task = mne.Epochs(raw, events, event_id=task_event_id, **epoch_kwargs)

    # Debug prints
    print(f"📊 Extracted {len(epochs_rest)} REST epochs & {len(epochs_task)} TASK epochs (combined T1+T2)")

    return {"rest": epochs_rest, "task": epochs_task}

In [None]:
# Extract REST/TASK epochs for each motor condition loaded above (the
# notebook never defines a bare ``raw`` variable).
# Result layout: epochs_dict[condition]["rest" | "task"].
# The baseline "rest" recording is skipped: extract_clean_epochs needs T1/T2
# markers, which that run presumably lacks -- TODO confirm for run 1.
epochs_dict = {
    condition: extract_clean_epochs(subject_data[condition])
    for condition in ("motor_execution", "motor_imagery")
}