In [37]:
import mne
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline
from mne.decoding import CSP
from scipy.stats import pearsonr

In [2]:
def loadData(): 
    """Collect the EEGLAB ``.set`` file names of every participant.

    Reads ``participants.tsv`` (tab-separated) from the current working
    directory, takes its ``participant_id`` column and lists the
    ``<participant_id>/eeg`` directory of each participant.

    Returns:
        tuple: ``(files, participants)``. ``participants`` is the list of
        participant ids in the order they appear in the TSV, and ``files[i]``
        is the alphabetically sorted list of ``.set`` file names (names only,
        no directory part) found for ``participants[i]``.

    Raises:
        OSError: propagated from ``pd.read_csv`` / ``os.listdir`` when
            ``participants.tsv`` or a participant's ``eeg`` directory is
            missing.
    """
    # get the list of all subjects
    participants_info = pd.read_csv('participants.tsv', sep='\t')
    participants = participants_info.participant_id.tolist()

    # get the EEGLAB .set files of each participant
    files = []
    for sub in participants:
        sub_path = os.path.join(sub, 'eeg')
        # os.listdir returns entries in arbitrary order: sort so the recording
        # order is the same on every run and on every machine. Match on the
        # full '.set' extension so unrelated names that merely end in "set"
        # are not picked up.
        part_files = sorted(fn for fn in os.listdir(sub_path) if fn.endswith('.set'))
        files.append(part_files)

    return files, participants


In [3]:
def loadRaws(files, participants): 
    """Open every listed EEGLAB recording and return them as one flat list.

    Args:
        files: ``files[i]`` is the list of file names belonging to
            ``participants[i]``.
        participants: participant directory names; each recording is read from
            ``<participant>/eeg/<file name>``.

    Returns:
        list: the objects returned by ``mne.io.read_raw_eeglab`` (called with
        ``preload=False``), ordered participant by participant.

    Raises:
        AssertionError: if there are not exactly 109 participants, the first
            participant does not have 14 recordings, or the total number of
            recordings is not 1526.
    """
    per_subject = []
    for idx, subject_files in enumerate(files):
        per_subject.append([
            mne.io.read_raw_eeglab(participants[idx] + '/eeg/' + fname, preload=False)
            for fname in subject_files
        ])

    # sanity checks tied to the expected dataset layout
    assert len(per_subject) == 109 and len(per_subject[0]) == 14
    flat_raws = [recording for group in per_subject for recording in group]
    assert len(flat_raws) == 1526
    return flat_raws

In [4]:
# function to match numeric event labels to textual true labels
def matchLabels(events, event_dict) : 
    """Translate the numeric event code of every event into its text label.

    Args:
        events: iterable of event rows; the numeric event code is the last
            element of each row.
        event_dict: mapping ``text label -> numeric event code``.

    Returns:
        list: the text label of each row of ``events``, in order.

    Raises:
        KeyError: if an event code does not occur among the values of
            ``event_dict``.
    """
    # Invert the mapping instead of assuming the codes are 1..n in key order:
    # this works for any number of event types (including none) and for any
    # code values.
    code_to_label = {code: label for label, code in event_dict.items()}

    # numeric event label is the last element of each row
    return [code_to_label[int(e[-1])] for e in events]

In [5]:
def epoch(raws): 
    """Epoch every recording and convert each epoch into CSP features.

    For each entry of ``raws`` the events are derived from its annotations,
    epochs are created with ``tmin=-0.1`` / ``tmax=0.1`` and no baseline
    correction, and a new 4-component CSP is fitted on, and applied to, the
    epochs of that recording.

    NOTE(review): a separate CSP is fitted per recording and the resulting
    feature rows are pooled, so rows coming from different recordings were
    produced by different spatial filters - confirm this is intended.

    Args:
        raws: iterable of recordings accepted by
            ``mne.events_from_annotations`` and ``mne.Epochs``.

    Returns:
        tuple: ``(X_data, y_true)`` - a flat list with one CSP feature row per
        epoch and the parallel flat list of text labels.
    """
    feature_rows = []
    label_rows = []
    for recording in raws:
        events, event_dict = mne.events_from_annotations(recording)
        segments = mne.Epochs(recording, events, tmin=-0.1, tmax=0.1, baseline=None, preload=False)

        # Keep this order: the data is fetched before `.events` is read.
        # NOTE(review): presumably get_data() can drop epochs and update
        # `.events` on a non-preloaded Epochs object - confirm before moving.
        segment_data = np.array(segments.get_data())
        segment_labels = matchLabels(segments.events, event_dict)

        spatial_filter = CSP(n_components=4, reg=None, log=True, norm_trace=False)
        # extending with the 2-D feature array adds one row per epoch
        feature_rows.extend(spatial_filter.fit_transform(segment_data, segment_labels))
        label_rows.extend(segment_labels)

    assert len(feature_rows) == len(label_rows)
    return feature_rows, label_rows

In [6]:
# %%capture
# files, participants = loadData()
# raws = loadRaws(files, participants)

In [7]:
# %%capture
# X_data, y_true = epoch(raws)

In [8]:
# labels_text = sorted(list(set(y_true)))
# print(labels_text)

In [9]:
%%capture
# List the recordings of every participant, then open all of them as one flat
# list (loadRaws reads each file with preload=False).
files, participants = loadData()
raw_list = loadRaws(files, participants)

In [10]:
# The recordings do not share one sampling rate (the printed values for
# recordings 0 and 1218 are 160.0 and 128.0), so all of them are resampled to 128.
print(raw_list[0].info['sfreq'], raw_list[1218].info['sfreq'])
# NOTE(review): Raw.resample appears to work in place and return the same
# object, so raw_list and raw_list_resampled likely refer to the same (now
# resampled) recordings - confirm before relying on raw_list being untouched.
raw_list_resampled = [r.resample(128, npad='auto') for r in raw_list]

160.0 128.0


In [11]:
# mne.concatenate_raws appends everything onto its first element in place.
# Hand it a copy of the first recording so that raw_list_resampled[0] is left
# untouched and re-running this cell does not concatenate the data again.
raw = mne.concatenate_raws([raw_list_resampled[0].copy(), *raw_list_resampled[1:]])

In [62]:
# Turn the annotations of the concatenated recording into an events array plus
# the event_id mapping, then cut epochs with tmin=-0.4 / tmax=0.4 around every
# event without baseline correction. The log printed by this cell lists 16
# annotation descriptions.
events, event_id = mne.events_from_annotations(raw)
epochs = mne.Epochs(raw, events, event_id, tmin=-0.4, tmax=0.4, baseline=None)

Used Annotations descriptions: ['BASE1T0', 'BASE1T1', 'BASE2T0', 'BASE2T1', 'TASK1T0', 'TASK1T1', 'TASK1T2', 'TASK2T0', 'TASK2T1', 'TASK2T2', 'TASK3T0', 'TASK3T1', 'TASK3T2', 'TASK4T0', 'TASK4T1', 'TASK4T2']
Not setting metadata
39569 matching events found
No baseline correction applied
0 projection items activated


In [69]:
# Separate the data and labels: `data` is the array of epoch data, `labels` is
# the last column of the epochs' events array (the numeric event code).
data = epochs.get_data()
labels = epochs.events[:, -1]

Using data from preloaded Raw for 38041 events and 103 original time points ...


In [71]:
# Define the CSP and LDA pipeline: 4 CSP components feed a linear discriminant
# analysis classifier.
# `svc` and `logreg` are created here but are not part of `clf` in this cell.
csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)
lda = LDA()
svc = SVC(kernel='linear', C=1)
logreg = LogisticRegression()
clf = Pipeline([('CSP', csp), ('LDA', lda)])

In [72]:
# Define cross-validation and train the model: 20 random splits with 30% of
# the epochs held out for testing each time (seed fixed at 42). The whole
# pipeline is passed to cross_val_score, so it is re-fitted for every split.
# NOTE(review): ShuffleSplit does not stratify by class; with many event types
# the class proportions may differ between splits - confirm this is acceptable.
cv = ShuffleSplit(20, test_size=0.3, random_state=42)
scores = cross_val_score(clf, data, labels, cv=cv, n_jobs=1)

Computing rank from data with rank=None
    Using tolerance 0.0037 (2.2e-16 eps * 64 dim * 2.6e+11  max singular value)
    Estimated rank (mag): 64
    MAG: rank 64 computed from 64 data channels with 0 projectors
Reducing data rank from 64 -> 64
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 0.0025 (2.2e-16 eps * 64 dim * 1.8e+11  max singular value)
    Estimated rank (mag): 64
    MAG: rank 64 computed from 64 data channels with 0 projectors
Reducing data rank from 64 -> 64
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 0.0025 (2.2e-16 eps * 64 dim * 1.7e+11  max singular value)
    Estimated rank (mag): 64
    MAG: rank 64 computed from 64 data channels with 0 projectors
Reducing data rank from 64 -> 64
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 0.0034 (2.2e-16 eps * 64 dim * 2.4e+11  max singular value)
    

In [73]:
# Report the mean score over the cross-validation splits.
# NOTE(review): the epoching log lists 16 event descriptions, so chance level
# would be about 1/16 = 0.0625 only if the classes are balanced - confirm
# before interpreting this number.
print("Classification accuracy: ", scores.mean())

Classification accuracy:  0.12179094015596251
