## EEG classification (decoding)

#### load libraries

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score
from mne import Epochs, pick_types, events_from_annotations
from mne.channels import make_standard_montage
from mne.io import read_raw_edf, concatenate_raws
from mne.decoding import CSP

In [None]:
# Select the Qt5Agg matplotlib backend for every figure created below.
matplotlib.use('Qt5Agg')

### Load the data

In [None]:
# Root folder of the dataset; the loading loop reads each subject's .edf
# files from a sub-folder of this directory.
sample_data_folder = '/Users/christinadelta/datasets/eeg_motor_imagery'

In [None]:
# Number of subjects to load and number of runs recorded per subject.
subject = 1
runs = 14
# Filled by the loading loop with one Raw object per run.
raws = []

In [None]:
# Read every run of every subject into memory and collect the Raw objects.
for sub_no in range(1, subject + 1):
    this_sub = 'S0{0:02d}'.format(sub_no)
    # All runs of one subject live in that subject's own folder.
    sub_dir = os.path.join(sample_data_folder, this_sub)
    for run_no in range(1, runs + 1):
        this_run = 'R{0:02d}'.format(run_no)
        this_eeg = os.path.join(sub_dir, f'{this_sub}-S001{this_run}.edf')
        raws.append(read_raw_edf(this_eeg, preload=True))

In [None]:
# for now keep only runs 6, 10, 14 (zero-based positions 5, 9 and 13 in raws)
raws_temp = [raws[idx] for idx in (5, 9, 13)]

In [None]:
# Concatenate the 3 selected raw recordings into one continuous recording.
raw = concatenate_raws(raws_temp)
# Bare expression: shows a summary of `raw` as the notebook cell output.
raw

### Apply default sensor locations (montage) to the data
I'll use the function that MNE provides for standardization of raw:
https://github.com/mne-tools/mne-python/blob/maint/0.23/mne/datasets/eegbci/eegbci.py#L157-L174 

In [None]:
# function to standardise raw (taken from mne/datasets/eegbci/eegbci.py, linked above)
def standardize(raw):
    """Rewrite the channel names of ``raw`` in standard montage spelling.

    Dots at either end of a name are removed and the name is upper-cased,
    except that a final ``Z`` is written ``z`` and a leading ``FP`` is
    written ``Fp`` (e.g. ``'Fcz.'`` -> ``'FCz'``, ``'Fp1.'`` -> ``'Fp1'``).

    Parameters
    ----------
    raw : instance of Raw
        The raw data to standardize. Operates in-place.
    """
    def _clean(label):
        # Drop the padding dots first, then normalise the letter case.
        cleaned = label.strip('.').upper()
        if cleaned.endswith('Z'):
            cleaned = cleaned[:-1] + 'z'
        if cleaned.startswith('FP'):
            cleaned = 'Fp' + cleaned[2:]
        return cleaned

    # One old-name -> new-name entry per channel, applied in a single call.
    raw.rename_channels({label: _clean(label) for label in raw.ch_names})

In [None]:
# Build the 'standard_1005' template, bring the channel names into the
# spelling produced by standardize(), then attach the template to the data.
montage = make_standard_montage('standard_1005')
standardize(raw)
raw.set_montage(montage)

In [None]:
# Band-pass filter the continuous recording between 7 and 30 Hz.
raw.filter(l_freq=7., h_freq=30.)
# Integer event code assigned to each of the two annotation labels.
# NOTE(review): presumably T1/T2 are the two imagery conditions - confirm
# against the dataset description.
event_id = {'T1': 0, 'T2': 1}

In [None]:
# Convert the annotations listed in event_id into an events array.
events, _ = events_from_annotations(raw, event_id=event_id)
# Select EEG channels only (no MEG, stim or EOG), leaving out bad channels.
picks = pick_types(
    raw.info,
    meg=False,
    eeg=True,
    stim=False,
    eog=False,
    exclude='bads',
)

In [None]:
# Bare expression: shows the events array as the notebook cell output.
events

### Make epochs around each of the two conditions 

In [None]:
# Epoch window relative to each event: starts at tmin and ends at tmax.
tmin = -1.
tmax = 4.

In [None]:
# Cut the recording into epochs around every T1/T2 event on the picked
# channels; no baseline correction is applied and the data are preloaded.
epochs = Epochs(
    raw,
    events,
    event_id=event_id,
    tmin=tmin,
    tmax=tmax,
    proj=True,
    picks=picks,
    baseline=None,
    preload=True,
)

#### Choose a segment and not the entire trial for training

The whole epoch lasts for 5 sec (-1 to 4 sec). Here we train the classifier on a time window from 1 to 2 sec after onset, but we test it on sliding windows across the entire trial.

In [None]:
# Class label of every epoch: the last column of the epochs' events array.
labels = epochs.events[:, -1]
# Training segment: a copy of the epochs cropped to 1-2 s, so the
# full-length `epochs` object is left untouched.
epochs_train = epochs.copy().crop(tmin=1., tmax=2.)

### Construct the classifier

make K folds for cross-validation of classifier. 

In [None]:
# Plain arrays for the classifier: the cropped training segment and the
# full-length epochs.
epochs_data_train = epochs_train.get_data()
epochs_data = epochs.get_data()
# Empty container for scores (not filled by any of the cells shown here).
scores = []

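Draw 10 random train/test splits of the epochs: in each split 80% of the epochs are used for training and the remaining 20% for testing. (`ShuffleSplit` reshuffles the data for every split, so this is repeated random sub-sampling rather than strict 10-fold cross-validation.)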

In [None]:
# 10 random shuffles of the epochs, each holding out 20% of them as the test
# set; the fixed random_state makes the splits reproducible.
cv = ShuffleSplit(10, test_size=0.2, random_state=42)
# cv.split() returns a one-shot generator. Store the (train, test) index
# pairs in a list so the cell that loops over them can be re-run without
# silently iterating over nothing.
cv_split = list(cv.split(epochs_data_train))

### Assemble a classifier based on the common spatial patterns (CSP) for feature extraction

CSP is a spatial filter which takes the EEG channels and weights them appropriately to point at a source which shows a maximal power difference between the two conditions  

In [None]:
# Feature extractor: CSP with 4 components, no regularisation, log=True and
# norm_trace=False.
csp = CSP(
    n_components=4,
    reg=None,
    log=True,
    norm_trace=False,
)
# Classifier fitted on the CSP features.
lda = LinearDiscriminantAnalysis()

### Train the CSP classifier in order to visualise patterns (inverse of spatial filters)

In [None]:
# Fit CSP on the full-length epochs; the transformed features are not kept,
# only the fitted estimator is needed for the plot below.
csp.fit_transform(epochs_data, labels)
# Plot the fitted CSP patterns using the sensor layout stored in epochs.info.
csp.plot_patterns(
    epochs.info,
    ch_type='eeg',
    units='Patterns (AU)',
    size=1.5,
)

### Prepare to classify the data in a sliding window (starting from imagery onset)

In [None]:
# Sampling frequency of the recording; converts seconds to samples below.
sfreq = raw.info['sfreq']
# Running classifier: window step (0.1 s) and window length (0.5 s), both
# expressed in samples.
w_step = int(sfreq * 0.1)
w_length = int(sfreq * 0.5)
# Start sample of every window along the third axis of epochs_data.
n_samples = epochs_data.shape[2]
w_start = np.arange(0, n_samples - w_length, w_step)
# Receives one list of per-window scores for every cross-validation split.
w_scores = []

In [None]:
# For every cross-validation split: fit CSP + LDA on the training segment of
# the training epochs, then score the fitted model on each sliding window of
# the full-length test epochs.
for train_idx, test_idx in cv_split:
    y_train, y_test = labels[train_idx], labels[test_idx]

    # Fit the spatial filters and the classifier on the training epochs only.
    X_train = csp.fit_transform(epochs_data_train[train_idx], y_train)
    lda.fit(X_train, y_train)

    # Full-length test epochs of this split, selected once instead of once
    # per window (indexing with an index array copies the data).
    test_epochs = epochs_data[test_idx]

    # running classifier: test classifier on sliding window
    score_this_window = [
        lda.score(csp.transform(test_epochs[:, :, n:(n + w_length)]), y_test)
        for n in w_start
    ]
    w_scores.append(score_this_window)

In [None]:
# Bare expression: shows the per-split, per-window scores as cell output.
w_scores

In [None]:
# Plot scores over time
# Centre of every window converted from samples to seconds, then shifted by
# epochs.tmin so that the time axis is relative to the event.
w_times = (w_start + w_length / 2.) / sfreq + epochs.tmin
# Average the window scores over the cross-validation splits.
mean_scores = np.mean(w_scores, axis=0)

fig, ax = plt.subplots()
ax.plot(w_times, mean_scores, label='Score')
ax.axvline(0, linestyle='--', color='k', label='Onset')
ax.axhline(0.5, linestyle='-', color='k', label='Chance')
ax.set_xlabel('time (s)')
ax.set_ylabel('classification accuracy')
ax.set_title('Classification score over time')
ax.legend(loc='lower right')
plt.show()