# INFO
- experiment: p300 speller
- stimulation: rsvp 100ms(stim) / 75ms(blank) / 2500ms break between char / 15 flashes per char / random words of 10 chars
- users tested: 1
- devices tested : 
    - muse 2: freq 256Hz / channels TP9,AF7,AF8,TP10
    - muse 2+: freq 256Hz / channels TP9,AF7,AF8,TP10,POz
    - OpenBCI: freq 125Hz / channels FC3,FCz,FC4,T7,C3,Cz,C4,T8,P7,P3,Pz,P4,P8,O1,O2,Oz
- metric used : Area Under the Curve (AUC)


This code demonstrates the evolution of selected predictors over increasing amount of data

# DEPENDENCIES

In [1]:
# BUILT-IN
import os,sys
import math
from collections import OrderedDict
import itertools
import datetime

# DATAFRAMES
import pandas as pd
import numpy as np

# SCIKIT-LEARN
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import cross_val_score, StratifiedShuffleSplit,GridSearchCV
from sklearn.externals import joblib

# PYRIEMANN
from pyriemann.estimation import ERPCovariances, XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.classification import MDM
from pyriemann.spatialfilters import Xdawn

# MNE
from mne import Epochs, find_events
from mne.channels import read_montage
from mne import create_info, concatenate_raws
from mne.io import RawArray
from mne.decoding import Vectorizer

# PLOTTING
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

# SETTINGS

In [2]:
the_folder_path = "../data/p300_speller" # relative datasets path
the_user = "compmonks" # check available users in data folder or add new ones
the_device = "muse2+" # "muse2+" # "muse2" # "openbci_v207" # available devices
the_freq = 256 # 256 (Muse) # 125 (OpenBCI) # Sampling Frequency in Hertz
break_epoch = 2500 # ms epoch used to break between character sessions
the_montage = "standard_1005" # "standard_1005" (Muse) # "standard_1020" (OpenBCI) # channels montage
the_units = "uVolts" # "uVolts" # "Volts" # unit of received data from device
the_markers = {'Non-Target': 2, 'Target': 1} # markers from stim data
sns.set_context('talk')
sns.set_style('white')
diverging_color_palette = "coolwarm"
categorical_color_palette = "Paired"

# list of best known discriminators for P300
clfs = OrderedDict()
clfs['Vect + LR'] = make_pipeline(Vectorizer(), StandardScaler(), LogisticRegression())
clfs['Vect + RegLDA'] = make_pipeline(Vectorizer(), LDA(shrinkage='auto', solver='eigen'))
clfs['Xdawn + RegLDA'] = make_pipeline(Xdawn(2, classes=[1]), Vectorizer(), LDA(shrinkage='auto', solver='eigen'))
clfs['ERPCov + TS'] = make_pipeline(ERPCovariances(), TangentSpace(), LogisticRegression())
clfs['ERPCov + MDM'] = make_pipeline(ERPCovariances(), MDM())
clfs['XdawnCov + TS + LReg'] = make_pipeline(XdawnCovariances(2),TangentSpace(metric='riemann'),LogisticRegression())

# UTILS

In [3]:
def estimatorParamsDict():
    """# get all params of a list of estimators and return them as a distionary."""
    
    pass
#
def train_save(self, the_predictor, the_X, the_y):
    """ Train a given predictor formatted as a pipeline."""

    trained = the_predictor.fit(the_X,the_y)
    joblib.dump(trained,'predictor_{}.pkl'.format(datetime.datetime.now()))

# COMPARISON

In [None]:
all_results =  pd.DataFrame({'Temp' : []})
raw = []
session_num = 0
the_training_path = os.path.join(the_folder_path,the_device,the_user,"training")
the_additional_path = os.path.join(the_folder_path,the_device,the_user,"additional")
all_sessions_data_path = list(os.walk(the_training_path)) + list(os.walk(the_additional_path))
#
for root, subdirs, files in all_sessions_data_path:
    for dirs in subdirs:
        print("device:{} user:{} dir:{} session dir:{}".format(the_device,the_user,subdirs,dirs))
        file_list = [fn for fn in os.listdir(os.path.join(root,dirs)) if fn.endswith("hdf5")]
        if len(file_list) == 2:
            try:
                print("session_num: {}".format(session_num))
                print("files: {}".format(file_list))
                if "FEEDBACK" in file_list[0]:
                    dataF = pd.read_hdf(os.path.join(root,dirs,file_list[0]),'data')
                    dataA = pd.read_hdf(os.path.join(root,dirs,file_list[1]),'EEG')
                else:
                    dataF = pd.read_hdf(os.path.join(root,dirs,file_list[1]),'data')
                    dataA = pd.read_hdf(os.path.join(root,dirs,file_list[0]),'EEG')
                #
                dataA['Stim'] = np.nan
                prev_marker = 0
                for index,row in dataA.iterrows():
                    new_marker = dataF.iloc[dataF.index.get_loc(pd.to_datetime(index),method='nearest')]['marker'].astype(int)
                    if prev_marker == 0:
                        if new_marker != 0:
                            row['Stim'] = new_marker
                            prev_marker = new_marker
                        else:
                            row['Stim'] = 0
                    else:
                        row['Stim'] = 0
                        prev_marker = new_marker

                channel_names = list(dataA.keys())
                # temporary fix on possibly erroneous channel labelling on some of the former data
                for _it,_ch in enumerate(channel_names):
                    if _ch == "P2":
                        channel_names[_it] = "O2"
                #
                channel_types = ['eeg'] * (len(list(dataA.keys()))-1) + ['stim']
                the_data = dataA.values[:].T
                the_data[:-1] *= 1e-6 if the_units == "uVolts" else 1
                info = create_info(ch_names = channel_names, 
                                   ch_types = channel_types,
                                   sfreq = the_freq, 
                                   montage = the_montage)
                raw.append(RawArray(data = the_data, info = info))
                the_raw = concatenate_raws(raw)
                the_events = find_events(the_raw)
                the_epochs = Epochs(the_raw,
                                    events = the_events, 
                                    event_id = the_markers,
                                    tmin = -0.1, tmax = 0.7,
                                    baseline = None,
                                    reject = {'eeg': 75e-6},
                                    preload = True,
                                    verbose = False,
                                    picks = list(range(len(channel_names)-1))
                                   )
                #
                the_epochs.pick_types(eeg=True)
                X = the_epochs.get_data() * 1e6
                y = the_epochs.events[:, -1]
                cv = StratifiedShuffleSplit(n_splits=10, test_size=0.25, random_state=42)
                auc = []
                methods = []
                #params = []
                # cross validation AUC score by classifier
                for m in clfs:
                    #clf = GridSearchCV(clfs[m], grid_params)
                    #clf.fit(train, y_train)
                    res = cross_val_score(clfs[m], X, y==2, scoring='roc_auc', cv=cv, n_jobs=-1)
                    auc.extend(res)
                    methods.extend([m]*len(res))
                the_best_score = max(auc)
                the_best_method = methods[auc.index(the_best_score)]
                print("The best predictor is: {} score:{}".format(the_best_method, the_best_score))
                # plot
                results = pd.DataFrame(data=auc, columns=['AUC'])
                results['Method'] = methods
                results['session_num'] = session_num
                if all_results.empty:
                    palette = sns.color_palette(categorical_color_palette, len(results))
                    all_results = results.copy()
                else:
                    all_results = pd.concat([all_results, results], ignore_index=True, sort=False)
                session_num +=1
            except Exception as e:
                print("ERROR: {}".format(e))
                pass
# train and save best predictor with all data
train_save(clfs[the_best_method], X, y)

# plotting the accuracy comparison over sessions
lp = sns.lineplot(x='session_num', y='AUC', hue='Method', data=all_results)
lp.legend(loc='center right', bbox_to_anchor=(1.25, 0.5), ncol=1)
plt.savefig(os.path.join(the_additional_path,"AUC.png"),dpi = 800,format = "png")

device:muse2+ user:compmonks dir:['session_003', 'session_004', 'session_002', 'session_000', 'session_001'] session dir:session_003
session_num: 0
files: ['compmonks_T2_FEEDBACK_2019-5-2_8-16-27-246000.hdf5', 'compmonks_T2_INTERAXON-Muse2_2019-5-2_8-24-15-829809.hdf5']
Creating RawArray with float64 data, n_channels=6, n_times=96829
    Range : 0 ... 96828 =      0.000 ...   378.234 secs
Ready.
1798 events found
Event IDs: [1 2]
The best predictor is: ERPCov + MDM score:0.707277628032345
device:muse2+ user:compmonks dir:['session_003', 'session_004', 'session_002', 'session_000', 'session_001'] session dir:session_004
session_num: 1
files: ['compmonks_T2_INTERAXON-Muse2_2019-5-2_11-6-58-216310.hdf5', 'compmonks_T2_FEEDBACK_2019-5-2_10-59-17-340000.hdf5']
Creating RawArray with float64 data, n_channels=6, n_times=96478
    Range : 0 ... 96477 =      0.000 ...   376.863 secs
Ready.
3585 events found
Event IDs: [1 2]
The best predictor is: XdawnCov + TS + LReg score:0.74629743475102
devi