# INFO
- experiment: p300 speller

This code demonstrates how to predict a spelled word from a trained discriminator and unlabelled data.

# DEPENDENCIES

In [148]:
# BUILT-IN
import os,sys
import math
from collections import OrderedDict
import itertools

# DATAFRAMES
import pandas as pd
import numpy as np
# PLOTTING
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# MNE
from mne import Epochs, find_events
from mne.channels import read_montage
from mne import create_info, concatenate_raws
from mne.io import RawArray
from mne.decoding import Vectorizer
# SCIKIT-LEARN
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import cross_val_score, StratifiedShuffleSplit
from sklearn.externals import joblib
from sklearn.preprocessing import Imputer
# PYRIEMANN
from pyriemann.estimation import ERPCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.classification import MDM
from pyriemann.spatialfilters import Xdawn



from itertools import starmap, product
from operator import and_

# SETTINGS

In [149]:
# --- Data location -----------------------------------------------------------
# When adding more users or recordings, make sure they follow the same folder structure.
the_folder_path = "../data/p300_speller" # datasets root, relative to this notebook
the_user = "compmonks" # a user folder found in the data folder, or "all"
the_device = "muse2+" # one of: "muse2+", "muse2", "openbci_v207"

# --- Acquisition -------------------------------------------------------------
the_freq = 256 # sampling frequency in Hertz: 256 (Muse), 125 (OpenBCI)
the_montage = "standard_1005" # channels montage: "standard_1005" (Muse), "standard_1020" (OpenBCI)
the_units = "uVolts" # unit of the data received from the device: "uVolts" or "Volts"
the_markers = {'Non-Target': 2, 'Target': 1} # markers found in the stim data

# --- Speller -----------------------------------------------------------------
char_num = 10 # total number of characters in the word
break_epoch = 2500
# Flash index -> flashed characters. Indexes 6-11 are written out below; indexes 0-5
# are obtained by reading the same six strings position by position
# ('AGMSY5', 'BHNTZ6', 'CIOU17', 'DJPV28', 'EKQW39', 'FLRX40'). Index 12 is blank.
row_col = ['ABCDEF', 'GHIJKL', 'MNOPQR', 'STUVWX', 'YZ1234', '567890']
row_col = [''.join(column) for column in zip(*row_col)] + row_col

# --- Plotting ----------------------------------------------------------------
sns.set_context('talk')
sns.set_style('white')
diverging_color_palette = "coolwarm"
categorical_color_palette = "Paired"


In [163]:
def segmentByLengthSeq(the_list,the_val,the_min_len):
    """Locate the long-enough runs of a given value inside a sequence.

    Consecutive equal items of `the_list` are grouped into runs. Each run whose
    value equals `the_val` and whose length is at least `the_min_len` is
    reported as a `(value, run_length, start_position)` tuple, in order of
    appearance. An empty list is returned when no run qualifies.
    """

    matches = []
    position = 0  # position, in the_list, of the first item of the current run
    for value, run in itertools.groupby(the_list):
        run_length = sum(1 for _ in run)
        if value == the_val and run_length >= the_min_len:
            matches.append((value, run_length, position))
        position += run_length

    return matches

def all_intersections(sets):
    """Return the input sets together with the intersections derived from them.

    The inputs are de-duplicated as frozensets, then intersections are added
    round after round until a round produces nothing new. The result is a list
    of sorted lists, ordered by length first and by content second. The input
    collection itself is left untouched.
    """
    known = {frozenset(members) for members in sets}
    previous = fresh = known
    while fresh:
        # Intersect the earlier state with what the last round discovered
        candidates = {left & right for left, right in product(previous, fresh)}
        previous = set(known)            # snapshot before this round's additions
        fresh = candidates - previous    # keep only what was not known yet
        known |= fresh
    as_lists = [sorted(members) for members in known]
    as_lists.sort(key=lambda members: (len(members), members))
    return as_lists
    
def checkSpelledAnswer(answers,flashed,tokens):
    """Vote for the spelled character from per-flash predictions.

    Parameters
    ----------
    answers : sequence
        One predicted label per flash. A flash counts as a hit when its label
        equals 1.
    flashed : sequence
        For each flash, the position in `tokens` of the group that was flashed.
        Plain ints, NumPy integer scalars and integer arrays are accepted.
    tokens : sequence of str
        Character groups that can be flashed. Positions that fall outside this
        sequence (e.g. a blank flash) are ignored.

    Returns
    -------
    str or None
        The character contained in the largest number of hit groups, or None
        when there is no usable hit. The outcome is also printed.
    """

    vote_counts = dict()
    # COUNT, FOR EVERY CHARACTER, THE NUMBER OF PREDICTED ROWS AND COLS CONTAINING IT
    for i,val in enumerate(answers):
        mask = np.asarray(val) == 1
        # np.asarray lets plain Python ints be masked like NumPy scalars
        hits = np.unique(np.asarray(flashed[i])[mask])
        if hits.size == 0:
            continue
        token_index = int(hits[0])
        if not 0 <= token_index < len(tokens):
            # blank flash or unknown group: nothing to vote for
            continue
        # dict.fromkeys de-duplicates while keeping the token's character order,
        # so ties do not depend on set iteration order
        for _char in dict.fromkeys(tokens[token_index]):
            vote_counts[_char] = vote_counts.get(_char,0)+1

    if not vote_counts:
        print("NO CHARACTER FOUND")
        return None

    # stable sort: among equally voted characters the last one counted wins
    _sorted = sorted(vote_counts.items(), key=lambda kv: kv[1])
    the_char = _sorted[-1][0]
    print("FOUND CHARACTER: {}".format(the_char))
    return the_char
    

In [164]:
# Folder holding the recordings and the trained model of the selected device / user
the_data_path = os.path.join(the_folder_path,the_device,the_user) if the_user != "all" else os.path.join(the_folder_path,the_device)

# NOTE: joblib.load unpickles the file; only load model files coming from a trusted source.
the_predictor = joblib.load(os.path.join(the_data_path,"compmonks_T2_erp-cov-ts.pkl"),"r")

# Minimal number of consecutive blank flashes (index 12) that separates two characters:
# 90% of the break duration, expressed in samples.
the_char_epoch = (the_freq * (break_epoch/1000))*0.9

# loop through subdirs and predict one character per chunk of data
for root, subdirs, files in os.walk(the_data_path):
    for dirs in subdirs:
        print("device:{} user:{} dir:{}".format(the_device,the_user,dirs))
        session_path = os.path.join(root,dirs)
        # Keep the data files only (a new list is built: removing items from a list
        # while iterating over it skips entries).
        file_list = [_f for _f in os.listdir(session_path)
                     if len(_f.split('.')[0]) > 1 and ".pkl" not in _f]
        if len(file_list) != 2:
            # exactly one feedback file and one EEG file are expected
            print("skipped dir:{} ({} data files found, 2 expected)".format(dirs,len(file_list)))
            continue
        if "FEEDBACK" in file_list[0]:
            feedback_file, eeg_file = file_list
        else:
            eeg_file, feedback_file = file_list
        dataF = pd.read_hdf(os.path.join(session_path,feedback_file),'data')
        dataA = pd.read_hdf(os.path.join(session_path,eeg_file),'EEG')

        # BREAK DATA BY CHARACTER _____________________________________________________
        # each entry is (value, run length, run start) of a long enough blank run
        chunks_indexes = segmentByLengthSeq(dataF['index'],12.0,the_char_epoch)
        for i,_char in enumerate(chunks_indexes):
            first_flash_pos = _char[2]+_char[1] # first feedback row after the blank run
            if first_flash_pos >= len(dataF):
                # the blank run ends the recording: no character follows
                continue
            start_F = dataF.index[first_flash_pos]
            stop_F =  dataF.index[chunks_indexes[i+1][2] if i+1 < len(chunks_indexes) else -1]
            start_A = dataA.index.get_loc(start_F,method='ffill')
            stop_A = dataA.index.get_loc(stop_F,method='ffill')
            #
            chunk_dataF = dataF.iloc[dataF.index.get_loc(start_F,method='ffill'):dataF.index.get_loc(stop_F,method='ffill')]
            # explicit copy: a 'Stim' column is added below and dataA must stay untouched
            chunk_dataA = dataA.iloc[start_A:stop_A].copy()
            if chunk_dataA.empty or chunk_dataF.empty:
                continue

            # BUILD THE STIM CHANNEL __________________________________________________
            # A marker is written on the first EEG sample where the nearest feedback
            # marker becomes non-zero; every other sample gets 0. The flashed index of
            # each marked sample is stored in index_list.
            # TO CHECK ALSO IN TRAINING: BETTER MARKING FOR EVENTS AND CHAR INDEXES
            stim_values = []
            index_list = []
            prev_marker = 0
            for index in chunk_dataA.index:
                nearest_F = chunk_dataF.index.get_loc(pd.to_datetime(index),method='nearest')
                new_marker = chunk_dataF.iloc[nearest_F]['marker'].astype(int)
                if prev_marker == 0 and new_marker != 0:
                    stim_values.append(new_marker)
                    index_list.append(chunk_dataF['index'].iloc[nearest_F].astype(int))
                else:
                    stim_values.append(0)
                prev_marker = new_marker
            # assigned in one go: values written into the rows yielded by iterrows()
            # are not guaranteed to reach the frame
            chunk_dataA['Stim'] = stim_values

            channel_names = list(chunk_dataA.keys())
            channel_types = ['eeg'] * (len(channel_names)-1) + ['stim']

            # channels x samples, as a fresh float array that can be scaled in place
            the_data = np.array(chunk_dataA.values, dtype=float).T
            the_data[:-1] *= 1e-6 if the_units == "uVolts" else 1

            info = create_info(ch_names = channel_names,
                               ch_types = channel_types,
                               sfreq = the_freq,
                               montage = the_montage)
            raw = [RawArray(data = the_data, info = info)]
            the_raw = concatenate_raws(raw)

            the_raw=the_raw.filter(.1,20, method='iir',iir_params = dict(order=8,ftype='butter',output='sos'))

            the_events = find_events(the_raw)
            if len(the_events) == 0:
                print("NO EVENTS FOUND")
                continue
            # find_events (initial_event left to its default) ignores a marker sitting
            # on the very first sample: drop the matching flash index in that case only
            if stim_values[0] != 0:
                index_list = index_list[1:]

            the_epochs = Epochs(the_raw,
                                events = the_events,
                                tmin = -0.1, tmax = 0.7, # in s. ie. 100ms before and 700ms after
                                baseline = None, # no baseline correction necessary after bandpass
                                preload = True,
                                picks = list(range(len(channel_names)-1)) # every channel but Stim
                               )
            # keep the flash indexes of the epochs that were actually kept
            index_list = [index_list[kept] for kept in the_epochs.selection]

            X = the_epochs.get_data()
            preds = the_predictor.predict(X)

            checkSpelledAnswer(preds,index_list,row_col)

device:muse2+ user:compmonks dir:debug
Creating RawArray with float64 data, n_channels=6, n_times=8359
    Range : 0 ... 8358 =      0.000 ...    32.648 secs
Ready.
Setting up band-pass filter from 0.1 - 20 Hz
Trigger channel has a non-zero initial value of 2 (consider using initial_event=True to detect this event)
Removing orphaned offset at the beginning of the file.
178 events found
Event IDs: [1 2]
178 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Loading data for 178 events and 206 original time points ...
4 bad epochs dropped
FOUND CHARACTER: O
Creating RawArray with float64 data, n_channels=6, n_times=8405
    Range : 0 ... 8404 =      0.000 ...    32.828 secs
Ready.
Setting up band-pass filter from 0.1 - 20 Hz
Trigger channel has a non-zero initial value of 2 (consider using initial_event=True to detect this event)
Removing orphaned offset at the beginning of the file.
179 events found
Event IDs: [1 2]
179 matching events