# INFO
- experiment: p300 speller
- data preprocessing is similar to p300_speller_training

This code demonstrates how to predict  a spelled word from a trained discriminator and unlabelled data.

# DEPENDENCIES

In [1]:
# BUILT-IN
import os
from itertools import starmap, product, groupby
from operator import and_

# DATAFRAMES
import pandas as pd
import numpy as np

# MNE
from mne import Epochs, find_events
from mne import create_info, concatenate_raws
from mne.io import RawArray

# SCIKIT-LEARN
from sklearn.externals import joblib

# SETTINGS

In [12]:
the_folder_path = "../data/p300_speller" # relative datasets path
the_user = "compmonks" # check available users in data folder or add new ones
the_device = "openbci_v207" # "muse2+" # "muse2" # "openbci_v207" # available devices
the_freq = 125 # 256 (Muse) # 125 (OpenBCI) # Sampling Frequency in Hertz
the_units = "uVolts" # "uVolts" # "Volts" # unit of received data from device
the_predictor_name = "compmonks_T2_XdawnCov-TS-LReg_OPENBCI-Cytonv207_2019-6-14_17-16-54-379468.pkl" # trained discriminator in the data folder
the_montage = "standard_1020" # "standard_1005" (Muse) # "standard_1020" (OpenBCI) # channels montage
break_epoch = 2500 # time in ms of break between each character spelling sessions
blank_index = 12 # index used to mark a non-flashing time
#
# list of tokens ie. flashed rows and cols
row_col = ['AGMSY5', #0
           'BHNTZ6',
           'CIOU17',
           'DJPV28',
           'EKQW39',
           'FLRX40',
           'ABCDEF',
           'GHIJKL',
           'MNOPQR',
           'STUVWX',
           'YZ1234',
           '567890'] #11
            # 12 is blank

# UTILS

In [13]:
def segmentByLengthSeq(the_list,the_val,the_min_len):
    """Return an index of beginning of specific consecutive value sequences."""
    
    out = []
    i=0
    ind_=0
    for key,group in groupby(the_list):
        ind_ = len(list(group))
        if key == the_val and ind_>=the_min_len:
            out.append((key,ind_,i))
        i+=ind_
        
    return out
#
def all_intersections(sets):
    """Convert to set of frozensets for uniquification/type correctness."""
    
    last = new = sets = set(map(frozenset, sets))
    while new:
        new = set(starmap(and_, product(last, new)))
        last = sets.copy()
        new -= last
        sets |= new
        
    return sorted(map(sorted, sets), key=lambda x: (len(x), x))
#    
def checkSpelledAnswer(answers,flashed,tokens):
    """Return a token from a list with the most labelled intersections."""
    
    intersections = dict()
    selected_tokens = []
    for i,val in enumerate(answers):
        mask = val==1
        filtered = flashed[i][mask]
        filtered = np.unique(filtered)
        if filtered.size > 0:
            selected_tokens.append(set(tokens[filtered[0]]))
    for rc in  selected_tokens:
        for _char in rc:
            intersections[_char]=intersections.setdefault(_char,0)+1
    _sorted = sorted(intersections.items(), key=lambda kv: kv[1])

    return "{}".format(_sorted[-1][0])

# RETRIEVE WORD FROM SPELLED CHARACTERS

In [25]:
the_data_path = os.path.join(the_folder_path,the_device,the_user,"testing")
the_predictor = joblib.load(os.path.join(the_data_path,the_predictor_name),"r")
the_word = ""
# loop through subdirs and concatenate data
for root, subdirs, files in os.walk(the_data_path):
    for dirs in subdirs:
        print("device:{} user:{} dir:{}".format(the_device,the_user,dirs))
        file_list = os.listdir(os.path.join(root,dirs))
        for _f in file_list:
            if len(_f.split('.')[0]) <= 1 or ".pkl" in _f:
                file_list.remove(_f)
        if len(file_list) == 2:
            if "FEEDBACK" in file_list[0]:
                dataF = pd.read_hdf(os.path.join(root,dirs,file_list[0]),'data')
                dataA = pd.read_hdf(os.path.join(root,dirs,file_list[1]),'EEG')
            else:
                dataF = pd.read_hdf(os.path.join(root,dirs,file_list[1]),'data')
                dataA = pd.read_hdf(os.path.join(root,dirs,file_list[0]),'EEG')
        # break down dataset by character session
        the_char_epoch = (the_freq * (break_epoch/1000))*0.9
        chunks_indexes = segmentByLengthSeq(dataF['index'],blank_index,the_char_epoch) # 12 is the index
        for i,_char in enumerate(chunks_indexes):
            raw = []
            start_F = dataF.index[_char[2]+_char[1]]
            stop_F =  dataF.index[chunks_indexes[i+1][2] if i+1 < len(chunks_indexes) else -1]
            start_A = dataA.index.get_loc(start_F,method='ffill')
            stop_A = dataA.index.get_loc(stop_F,method='ffill')
            chunk_dataF = dataF.iloc[dataF.index.get_loc(start_F,method='ffill'):dataF.index.get_loc(stop_F,method='ffill')][:]
            chunk_dataA = dataA.iloc[start_A:stop_A][:]
            chunk_dataA['Stim'] = np.nan
            prev_marker = 0
            prev_flash = blank_index
            index_list = []
            for index,row in chunk_dataA.iterrows():
                new_marker = chunk_dataF.iloc[chunk_dataF.index.get_loc(pd.to_datetime(index),method='nearest')]['marker'].astype(int)
                if prev_marker == 0:
                    if new_marker != 0:
                        row['Stim'] = new_marker
                        prev_marker = new_marker
                        new_flash = chunk_dataF['index'][chunk_dataF.index.get_loc(pd.to_datetime(index),method='nearest')].astype(int)
                        index_list.append(new_flash)
                    else:
                        row['Stim'] = 0
                else:
                    row['Stim'] = 0
                    prev_marker = new_marker
            channel_names = list(chunk_dataA.keys())
            channel_types = ['eeg'] * (len(list(chunk_dataA.keys()))-1) + ['stim']
            the_data = chunk_dataA.values[:].T
            the_data[:-1] *= 1e-6 if the_units == "uVolts" else 1
            info = create_info(ch_names = channel_names, 
                               ch_types = channel_types,
                               sfreq = the_freq, 
                               montage = the_montage)
            raw.append(RawArray(data = the_data, info = info))
            the_raw = concatenate_raws(raw)
            the_raw=the_raw.filter(.1,20, method='iir',iir_params = dict(order=8,ftype='butter',output='sos'))
            the_events = find_events(the_raw)
            
            # remove first index if orphaned data found first during find_events
            if the_events[0][0] != 0:
                index_list = index_list[1:]            
            the_epochs = Epochs(the_raw,
                                events = the_events, 
                                event_id = None,
                                tmin = -0.1, tmax = 0.7,
                                baseline = None,
                                reject = {'eeg': 75e-6}, # remove amplitudes above 75uV ie. eye blinks
                                preload = True,
                                verbose = False,
                                picks = list(range(len(channel_names)-1))
                               )
            # epoch averaging
            #evoked = the_epochs.average() # inplace
            
            # pick indexes according to selected epochs
            index_list = [index_list[i] for i in the_epochs.selection]
            the_epochs = the_epochs.pick_types(eeg=True)
            X = the_epochs.get_data()
            preds = the_predictor.predict(X)
            
            # relabel epoch events with predictions
            #the_epochs.events[:,-1] = preds
            
            # retrieve the spelled character
            the_word += checkSpelledAnswer(preds,index_list,row_col)
#
print("DID YOU SPELL: {} ?".format(the_word))

device:openbci_v207 user:compmonks dir:session_000
Creating RawArray with float64 data, n_channels=17, n_times=4245
    Range : 0 ... 4244 =      0.000 ...    33.952 secs
Ready.
Setting up band-pass filter from 0.1 - 20 Hz
Trigger channel has a non-zero initial value of 1 (consider using initial_event=True to detect this event)
Removing orphaned offset at the beginning of the file.
75 events found
Event IDs: [1]
the events:[[  66    0    1]
 [  91    0    1]
 [ 116    0    1]
 [ 166    0    1]
 [ 191    0    1]
 [ 296    0    1]
 [ 321    0    1]
 [ 346    0    1]
 [ 371    0    1]
 [ 426    0    1]
 [ 481    0    1]
 [ 506    0    1]
 [ 561    0    1]
 [ 616    0    1]
 [ 671    0    1]
 [ 726    0    1]
 [ 751    0    1]
 [ 806    0    1]
 [ 861    0    1]
 [ 916    0    1]
 [ 971    0    1]
 [ 996    0    1]
 [1126    0    1]
 [1146    0    1]
 [1191    0    1]
 [1241    0    1]
 [1306    0    1]
 [1356    0    1]
 [1376    0    1]
 [1446    0    1]
 [1561    0    1]
 [1581    0    

Creating RawArray with float64 data, n_channels=17, n_times=4104
    Range : 0 ... 4103 =      0.000 ...    32.824 secs
Ready.
Setting up band-pass filter from 0.1 - 20 Hz
Trigger channel has a non-zero initial value of 1 (consider using initial_event=True to detect this event)
Removing orphaned offset at the beginning of the file.
69 events found
Event IDs: [1]
the events:[[  96    0    1]
 [ 116    0    1]
 [ 186    0    1]
 [ 231    0    1]
 [ 301    0    1]
 [ 321    0    1]
 [ 391    0    1]
 [ 461    0    1]
 [ 486    0    1]
 [ 506    0    1]
 [ 621    0    1]
 [ 691    0    1]
 [ 711    0    1]
 [ 781    0    1]
 [ 806    0    1]
 [ 896    0    1]
 [ 921    0    1]
 [ 941    0    1]
 [1011    0    1]
 [1056    0    1]
 [1151    0    1]
 [1171    0    1]
 [1196    0    1]
 [1311    0    1]
 [1331    0    1]
 [1401    0    1]
 [1471    0    1]
 [1516    0    1]
 [1536    0    1]
 [1631    0    1]
 [1651    0    1]
 [1746    0    1]
 [1766    0    1]
 [1791    0    1]
 [1811    0 