# **Load and prepare data from MNE-Python**

## Import libraries

In [1]:
import numpy as np
import pandas as pd
import mne
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci
import matplotlib.pyplot as plt

## Fetch data

In [3]:
%%capture

# Discard 6 subjects: 88, 89, 92, 100, 104, 106 (see https://www.sciencedirect.com/science/article/pii/S2352340924001525#sec0004)
# The candidate range must reach subject 109: the exclusion list contains 104
# and 106, and 109 - 6 = 103 is the subject count the expected length relies on.
# Stopping the range at 103 would silently leave only 99 subjects.
excluded_subjects = {88, 89, 92, 100, 104, 106}
subjects = sorted(set(range(1, 110)) - excluded_subjects)

# Define MI left/right hand runs
runs = [4, 8, 12]

# Directory in which to store the dataset. None lets MNE fall back to its
# configured default location; replace it with a path to store the data elsewhere.
path = None

# Collect the EDF file names of the selected runs for every retained subject.
raw_fnames_all = []
for subject in subjects:
    raw_fnames = eegbci.load_data(subject, runs, path=path)
    raw_fnames_all.extend(raw_fnames)

samples_per_run = 19680

raws = [read_raw_edf(f, preload=False) for f in raw_fnames_all]
for raw in raws:
    raw.crop(tmax=122.9937)  # Crop to standard recording length for run (19680 measurements)

# Stitch all runs into one continuous recording and convert it to a table.
raw = concatenate_raws(raws)

df_raw = raw.to_data_frame()

print("Shape of df_raw:", df_raw.shape)
# Run length * number of runs per subject * number of subjects
print("Expected length:", samples_per_run * len(runs) * len(subjects))

## Add annotation/task labels

In [6]:
# events_from_annotations returns (events array, {description: integer code}).
events = mne.events_from_annotations(raw)
event_samples, event_id = events

# Event array columns: sample index, an unused middle column, integer event code.
events_df = pd.DataFrame(event_samples, columns=["time_index", "X", "annotation"])
event_df = events_df.drop(columns="X")

# T0 = rest; T1 = left; T2 = right.
# Build the code -> label map from the mapping MNE actually used instead of
# assuming the codes are 1, 2 and 3.
task_by_description = {"T0": "rest", "T1": "left", "T2": "right"}
task_by_code = {
    int(code): task_by_description[description]
    for description, code in event_id.items()
    if description in task_by_description
}

# Expose the row number as a column so events can be joined on it.
df_raw = df_raw.reset_index().rename(columns={'index': 'time_index'})
df = pd.merge(df_raw, event_df, on='time_index', how='left')

# Only rows where an event starts carry a code after the left join; propagate
# it forward to every following row. Only the annotation column is filled, so
# the signal columns are left untouched.
df['annotation'] = df['annotation'].ffill().replace(task_by_code)

print(df['annotation'].unique())
print(df.shape)

Used Annotations descriptions: ['T0', 'T1', 'T2']
['rest' 'right' 'left']
(6081120, 67)


## Add subject_ID and run as columns

In [27]:
# Every run was cropped to the same number of rows and every subject has the
# same number of runs, so both labels follow directly from the row position.
samples_per_run = 19680
runs_per_subject = 3

row_position = np.arange(len(df))

# NOTE: subject numbers are sequential positions (S1, S2, ...) in the
# concatenated data, not the original dataset subject numbers.
subject_number = row_position // (samples_per_run * runs_per_subject) + 1
df['subject_ID'] = ['S' + str(number) for number in subject_number]

# Run number cycles 1..runs_per_subject within each subject.
run_number = (row_position // samples_per_run) % runs_per_subject + 1
df['run'] = ['run_' + str(number) for number in run_number]

print(df['subject_ID'].nunique())
print(df['run'].value_counts())

103
run_1    2027040
run_2    2027040
run_3    2027040
Name: run, dtype: int64


## Explore consistency in trial length

In [None]:
# A row starts a new segment when its annotation differs from the previous
# row's annotation (the very first row always counts as a start).
annotation = df['annotation']
starts_segment = annotation.ne(annotation.shift())

# Distance between consecutive segment starts = length of each segment
# (the last segment has no following start and is therefore not measured).
segment_starts = df.index[starts_segment].to_numpy()
lengths_of_changes = np.diff(segment_starts).tolist()

print("Number of changes:", len(lengths_of_changes))
print("Target number of changes:", 103*3*30-1) #103 subjects, 3 runs, 30 trials, - 1 for end
print("Length of trials and count:", np.unique(lengths_of_changes, return_counts=True))

## Truncate/remove samples

Because trial lengths are not consistent, we drop segments shorter than the standard length and truncate longer segments to it. The standard length is the most frequently observed one: 656 samples.

### Define functions

In [None]:
def split_dataframe(df, column):
    """Split *df* into consecutive segments of constant value in *column*.

    A new segment starts at every row whose value in *column* differs from
    the value in the row before it. Row order is preserved and each segment
    keeps the original index labels.

    Args:
        df: DataFrame to split.
        column: Name of the column whose value changes mark the boundaries.

    Returns:
        A list of DataFrame slices, in order of appearance. An empty *df*
        yields an empty list.
    """
    n_rows = len(df)
    if n_rows == 0:
        return []

    values = df[column].to_numpy()
    # Positions where the value differs from its predecessor, found in one
    # vectorised comparison instead of a per-row Python loop.
    change_positions = np.flatnonzero(values[1:] != values[:-1]) + 1
    bounds = [0, *change_positions.tolist(), n_rows]

    return [df.iloc[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]


def filter_and_truncate_splits(splits, min_length=656, max_length=656):
    """Keep only sufficiently long segments, cutting over-long ones down.

    Args:
        splits: Iterable of DataFrame segments.
        min_length: Segments with fewer rows than this are discarded.
        max_length: Kept segments with more rows than this are reduced to
            their first ``max_length`` rows.

    Returns:
        A list with the kept (and possibly truncated) segments, in the
        original order.
    """
    kept = []

    for segment in splits:
        n_rows = len(segment)
        # Too short: skip it entirely.
        if n_rows < min_length:
            continue
        # Within bounds: keep as is; otherwise keep only the leading rows.
        kept.append(segment if n_rows <= max_length else segment.iloc[:max_length])

    return kept

### Apply

In [None]:
# Remember the size of the unfiltered data: the target length printed below is
# defined relative to it, and df is overwritten with the filtered result.
n_rows_before = len(df)

final_filtered_splits = []

# Process each (subject, run) recording separately so that segments never span
# two recordings. sort=False keeps the groups in order of first appearance.
for _, sub_df in df.groupby(['subject_ID', 'run'], sort=False):
    splits = split_dataframe(sub_df, 'annotation')
    filtered_splits = filter_and_truncate_splits(splits, min_length=656, max_length=656)
    final_filtered_splits.extend(filtered_splits)

df = pd.concat(final_filtered_splits).reset_index(drop=True)

print("Length of dataset:", df.shape)
# Unfiltered length minus the rows of discarded short segments minus the rows
# cut off from over-long segments.
# NOTE(review): the segment counts below are presumably taken from the
# trial-length exploration output - confirm against that cell's result.
print("Target length of dataset:", n_rows_before - (63*416+11*624+25*640) - (970*672-970*656 + 11*688-11*656))


## Save

In [None]:
# Write the prepared dataset to CSV. The target is an absolute path (rooted at
# "/"), and with pandas' default index=True the row index is written as the
# first, unnamed column.
# NOTE(review): presumably the /data_csv directory must already exist - confirm.
df.to_csv("/data_csv/df_MI.csv")