## Linear SVM Classifier on the Motor Imagery vs Rest - Low-Cost EEG System Dataset

In [None]:
import mne
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.decomposition import FastICA
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from mne.decoding import CSP
from mne.preprocessing import ICA
from scipy.stats import pearsonr

In [32]:
def loadData():
    """Collect the EDF recording file names for every listed participant.

    Reads ``participants.tsv`` (tab-separated, must contain a
    ``participant_id`` column) from the current working directory, then lists
    the ``<participant_id>/eeg/`` folder of each participant.

    Returns
    -------
    files : list of list of str
        One inner list per participant with the names (not full paths) of the
        ``.edf`` files found in that participant's ``eeg`` folder, sorted by
        name.
    participants : list of str
        Participant identifiers, in the order they appear in
        ``participants.tsv``.

    Raises
    ------
    FileNotFoundError
        If ``participants.tsv`` or a participant's ``eeg`` folder is missing.
    """
    # get all subject list
    participants_info = pd.read_csv('participants.tsv', sep='\t')
    participants = participants_info.participant_id.tolist()

    # get edf data for each participant
    files = []
    for sub in participants:
        sub_path = sub + '/eeg/'
        # os.listdir yields entries in arbitrary order; sorting keeps the
        # recording order (and everything derived from it downstream, such as
        # seeded train/test splits) reproducible across machines and runs.
        # Matching on '.edf' (with the dot) avoids picking up names that merely
        # end in the letters "edf".
        part_files = sorted(fn for fn in os.listdir(sub_path) if fn.endswith('.edf'))
        files.append(part_files)

    return files, participants


In [33]:
def loadRaws(files, participants):
    """Read every listed EDF file and return the recordings as one flat list.

    Parameters
    ----------
    files : list of list of str
        ``files[i]`` holds the EDF file names belonging to ``participants[i]``.
    participants : list of str
        Participant folder names; each file is read from
        ``<participant>/eeg/<file name>``.

    Returns
    -------
    list
        The objects returned by ``mne.io.read_raw_edf(..., preload=True)``,
        ordered participant by participant and, within a participant, in the
        order given by ``files[i]``.
    """
    # A single flat pass replaces the "list of lists, then flatten" approach;
    # the participant folder is looked up once per file, exactly as before.
    return [
        mne.io.read_raw_edf(participants[idx] + '/eeg/' + edf_name, preload=True)
        for idx, edf_names in enumerate(files)
        for edf_name in edf_names
    ]

In [34]:
%%capture
files, participants = loadData()
raw_list = loadRaws(files, participants)
# resample frequency may affect data resolution
raw_list_resampled = [r.resample(240, npad='auto') for r in raw_list]

In [35]:
# Merge all resampled recordings into one continuous Raw object.
# mne.concatenate_raws appends onto the first list element in place, so passing
# the list directly would make this cell non-idempotent: re-running it would
# append the remaining recordings a second time. Concatenating onto a copy of
# the first recording leaves raw_list_resampled untouched.
raw = mne.concatenate_raws([raw_list_resampled[0].copy(), *raw_list_resampled[1:]])

In [36]:
# Band-pass filter between 8 and 30 Hz (mu 8-13, beta 13-30) - the motor imagery related band.
# Per the MNE docs Raw.filter / Raw.notch_filter operate in place and return the
# same object, so `mubeta_raw` is another name for the (now filtered) `raw`.
mubeta_raw = raw.filter(l_freq=8, h_freq=30)
mubeta_raw = mubeta_raw.notch_filter(50)  # remove power line noise

# Fit ICA on the filtered recording.
# ICA.fit returns the ICA estimator itself, so its result must NOT be assigned
# to `mubeta_raw` (doing so replaced the filtered Raw with the ICA object).
# NOTE(review): the decomposition is only fitted here; no component is marked
# in `ica.exclude` and `ica.apply(...)` is never called, so the data used
# downstream are not ICA-cleaned — confirm whether that is intended.
ica = ICA(n_components=10, random_state=42).fit(mubeta_raw)

Filtering raw data in 50 contiguous segments
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 397 samples (1.654 sec)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  15 out of  15 | elapsed:    0.0s finished


Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1585 samples (6.604 sec)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.2s remaining:    0.0s


Fitting ICA to data using 15 channels (please be patient, this may take a while)


[Parallel(n_jobs=1)]: Done  15 out of  15 | elapsed:    1.1s finished


Selecting by number: 10 components
Fitting ICA took 34.0s.


In [37]:
# Convert every annotation in the recording into events; the full event array
# and its description -> code mapping are kept for reference.
all_events, event_dict = mne.events_from_annotations(raw)

# Keep only the two conditions of interest, with fixed event codes:
# motor imagery ("Right") -> 7 and rest ("Tongue") -> 9.
env = dict(OVTK_GDF_Right=7, OVTK_GDF_Tongue=9)
events, MI_events_dict = mne.events_from_annotations(raw, event_id=env)

Used Annotations descriptions: ['OVTK_GDF_Correct', 'OVTK_GDF_Cross_On_Screen', 'OVTK_GDF_End_Of_Session', 'OVTK_GDF_End_Of_Trial', 'OVTK_GDF_Feedback_Continuous', 'OVTK_GDF_Incorrect', 'OVTK_GDF_Right', 'OVTK_GDF_Start_Of_Trial', 'OVTK_GDF_Tongue', 'OVTK_StimulationId_BaselineStart', 'OVTK_StimulationId_BaselineStop', 'OVTK_StimulationId_Beep', 'OVTK_StimulationId_ExperimentStart', 'OVTK_StimulationId_ExperimentStop', 'OVTK_StimulationId_Train']
Used Annotations descriptions: ['OVTK_GDF_Right', 'OVTK_GDF_Tongue']


In [38]:
# Cut the continuous recording into trials: one epoch per selected event,
# spanning 0.75 s to 2.5 s after the event marker, without baseline correction.
# (Original note: 20 trials per condition are expected — presumably per
# recording; verify against the event count reported below.)
epochs = mne.Epochs(raw, events, tmin=0.75, tmax=2.5, baseline=None).load_data()

# Epoch data and the label of each epoch (the last column of the event array,
# i.e. the codes assigned when the events were extracted).
mubeta_data = epochs.get_data()
mubeta_labels = epochs.events[:, -1]

# Classifier: CSP log-variance features -> standardisation -> RBF-kernel SVM.
# NOTE(review): n_components=50 exceeds the 15 channels reported in the ICA
# log above, so CSP presumably yields at most one component per channel — confirm.
# NOTE(review): the notebook title says "Linear SVM" but the kernel is 'rbf'.
csp = CSP(n_components=50, reg=None, log=True)
svc = SVC(C=50, kernel='rbf')
csp_pipeline = Pipeline([
    ('csp', csp),
    ('standardscaler', StandardScaler()),
    ('svc', svc),
])

# Five random 80/20 train/test splits with a fixed seed.
cv = ShuffleSplit(n_splits=5, test_size=0.2, random_state=42)

Not setting metadata
1690 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1690 events and 421 original time points ...
0 bad epochs dropped


In [39]:
%%capture
muband_scores = cross_val_score(csp_pipeline, mubeta_data, mubeta_labels, cv=cv, n_jobs=2)

In [40]:
# Report the accuracy averaged over the cross-validation splits.
# NOTE(review): the label says "mu band", but the data were band-passed to
# 8-30 Hz (mu + beta) earlier in the notebook.
print("mu band accuracy: ", np.mean(muband_scores))

mu band accuracy:  0.7017751479289941
