In [None]:
import os
import mne
import numpy as np
from mne.io import read_raw_edf
from mne.decoding import CSP
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score

# Subject and run identifiers; the loading loop joins them into EDF file names
# of the form '<subject><run>.edf'.
subjects = ['S001', 'S002']

# The loading loop labels 'T1'/'T2' events of the first group as 0/1 and
# 'T2' events of the second group as 2.
_runs_labelled_0_1 = ['R03', 'R04', 'R07', 'R08', 'R11', 'R12']
_runs_labelled_2 = ['R05', 'R06', 'R09', 'R10', 'R13', 'R14']
runs = [*_runs_labelled_0_1, *_runs_labelled_2]

# Per-run accumulators filled by the loading loop (kept index-aligned).
epochs_list, labels_list = [], []

def clean_channel_names(raw):
    """Normalise the channel names of ``raw``.

    Every name in ``raw.info['ch_names']`` has its trailing '.' characters
    removed and is upper-cased.  A normalised name of 'T9' becomes 'FT9' and
    'T10' becomes 'FT10'.  The resulting old-name -> new-name mapping is handed
    to ``raw.rename_channels``; nothing is returned.
    """
    substitutions = {'T9': 'FT9', 'T10': 'FT10'}

    def _normalise(original):
        # Strip the dot padding first so the substitution table sees bare names.
        bare = original.rstrip('.').upper()
        return substitutions.get(bare, bare)

    rename_map = {original: _normalise(original)
                  for original in raw.info['ch_names']}
    raw.rename_channels(rename_map)

# Annotation name -> class label, per run.  Events whose annotation name is
# not in the mapping for their run are discarded.
label_map_by_run = {}
for run_name in ('R03', 'R04', 'R07', 'R08', 'R11', 'R12'):
    label_map_by_run[run_name] = {'T1': 0, 'T2': 1}
for run_name in ('R05', 'R06', 'R09', 'R10', 'R13', 'R14'):
    label_map_by_run[run_name] = {'T2': 2}

# Loop-invariant settings: build the montage once and reuse it for every file.
montage = mne.channels.make_standard_montage('standard_1005')
tmin, tmax = 0., 4.  # epoch window passed to mne.Epochs (relative to each event)

for subject in subjects:
    for run in runs:
        # Decide the labelling first so runs without a mapping are skipped
        # before any file is read or filtered.
        label_map = label_map_by_run.get(run)
        if label_map is None:
            continue

        edf_path = f'/content/{subject}{run}.edf'
        if not os.path.isfile(edf_path):
            print(f'File {edf_path} not found. Please upload it.')
            continue
        raw = read_raw_edf(edf_path, preload=True, verbose=False)

        # Channel names must be normalised before the montage can be matched.
        clean_channel_names(raw)
        raw.set_montage(montage, match_case=False)

        # Band-pass 7-30 Hz on the continuous data, before epoching.
        raw.filter(7., 30., fir_design='firwin', skip_by_annotation='edge', verbose=False)

        events, event_id = mne.events_from_annotations(raw, verbose=False)

        picks = mne.pick_types(raw.info, eeg=True, exclude='bads')
        epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                            picks=picks, baseline=None, preload=True, verbose=False)

        # Invert event_id once per run instead of rebuilding two lists for
        # every event.  setdefault keeps the first name for a given code,
        # matching what a list.index() lookup would return.
        code_to_name = {}
        for name, code in event_id.items():
            code_to_name.setdefault(code, name)

        # Keep only the epochs whose annotation has a class label for this
        # run, and record those labels in the same order.
        labels = []
        selected_epochs = []
        for i, code in enumerate(epochs.events[:, 2]):
            event = code_to_name[int(code)]
            if event in label_map:
                labels.append(label_map[event])
                selected_epochs.append(i)

        epochs = epochs[selected_epochs]

        epochs_list.append(epochs)
        labels_list.append(labels)

# Merge the per-run results into one dataset.
#
# Stop with an explicit error when nothing was loaded: carrying on would
# either hit a NameError on `epochs`/`labels` (fresh kernel) or silently reuse
# whatever values an earlier execution left behind.
if not epochs_list:
    raise RuntimeError("No epochs found. Please check your data files and labels.")

epochs = mne.concatenate_epochs(epochs_list)
# astype(int) keeps the labels integral even if one run contributed an empty
# label list (np.concatenate would otherwise promote the result to float).
labels = np.concatenate(labels_list).astype(int)

# Feature array and targets consumed by the classification cell.
X = epochs.get_data()
y = labels



  epochs = mne.concatenate_epochs(epochs_list)


Not setting metadata
271 matching events found
No baseline correction applied


In [None]:
# Train a CSP + logistic-regression classifier and report hold-out performance.
X = epochs.get_data()
y = labels

# Stratified 80/20 split with a fixed seed so the reported numbers are
# reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

# The spatial filters are fitted on the training split only, so the test
# split does not influence the features.
csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)
csp.fit(X_train, y_train)

X_train_csp = csp.transform(X_train)
X_test_csp = csp.transform(X_test)

# `multi_class` is deprecated in scikit-learn 1.5 and is therefore not passed.
# With the lbfgs solver and more than two classes (labels 0/1/2 here) the
# default already fits a multinomial model, so the fit is unchanged.
clf = LogisticRegression(solver='lbfgs', max_iter=1000)
clf.fit(X_train_csp, y_train)

y_pred = clf.predict(X_test_csp)

print("Classification Report:")
print(classification_report(y_test, y_pred))

accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy * 100:.2f}%")


Computing rank from data with rank=None
    Using tolerance 0.00058 (2.2e-16 eps * 64 dim * 4.1e+10  max singular value)
    Estimated rank (data): 64
    data: rank 64 computed from 64 data channels with 0 projectors
Reducing data rank from 64 -> 64
Estimating class=0 covariance using EMPIRICAL
Done.
Estimating class=1 covariance using EMPIRICAL
Done.
Estimating class=2 covariance using EMPIRICAL
Done.
Classification Report:
              precision    recall  f1-score   support

           0       0.79      0.79      0.79        19
           1       0.57      0.72      0.63        18
           2       0.92      0.67      0.77        18

    accuracy                           0.73        55
   macro avg       0.76      0.73      0.73        55
weighted avg       0.76      0.73      0.73        55

Accuracy: 72.73%


