In [None]:
%matplotlib qt

In [None]:
import os, pyxdf, json, yaml
import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from mne.time_frequency import psd_welch, tfr_morlet, tfr_multitaper
# from multitaper_spectrogram_python import multitaper_spectrogram
from mne.decoding import Scaler, Vectorizer

from sklearn.pipeline import make_pipeline
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import RepeatedStratifiedKFold, HalvingGridSearchCV

from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import SVC
from sklearn.linear_model import SGDClassifier

### Settings

In [None]:
config_file = 'config_MI-hands.yaml'

# Load the experiment settings.
# NOTE(review): yaml.Loader can construct arbitrary Python objects, so only
# point this at configuration files you trust. If the config only uses plain
# YAML types, yaml.safe_load would be the safer choice.
with open(config_file) as f:
    config = yaml.load(f, Loader=yaml.Loader)
print(config)

# Expose every top-level config key as a notebook-level variable. Later cells
# read names such as subject, session and task that are not assigned anywhere
# else in this notebook, so they are expected to come from here.
globals().update(config)

# Build the path from separate components: the previous 'Documents\CurrentStudy'
# literal contained an invalid '\C' escape and only formed a real path on Windows.
lslDir = os.path.join(os.path.expanduser('~'), 'Documents', 'CurrentStudy')

### Find LSL Files

In [None]:
# Find files
# Collect every .xdf recording below lslDir whose file name matches the
# configured subject / session / task. An empty string or None disables the
# corresponding filter.
xdf_files = []
hasSubject = subject not in ('', None)
hasSession = session not in ('', None)
hasTask = task not in ('', None)

# Substrings that a file name must contain to be accepted.
required_tags = []
if hasSubject:
    required_tags.append('sub-' + str(subject))
if hasSession:
    required_tags.append('ses-S' + str(session).zfill(3))
if hasTask:
    required_tags.append('task-' + str(task))

for root, subdirs, files in os.walk(lslDir):
    # os.walk yields entries in arbitrary order; sorting both the directories
    # (in place, which steers the traversal) and the files makes the order in
    # which recordings are loaded and later concatenated reproducible.
    subdirs.sort()
    for file in sorted(files):
        if file.endswith('.xdf') and all(tag in file for tag in required_tags):
            print(file)
            xdf_files.append(os.path.join(root, file))

if len(xdf_files) == 0:
    print('No files found')

In [None]:
# Parse streams
# Collect the EEG and marker streams of every recording. A later cell pairs
# eeg_stream[n] with marker_stream[n], so both lists are filled in file order.
eeg_stream, marker_stream = [], []

print('Parsing streams')
for xdf_file in xdf_files:
    streams, header = pyxdf.load_xdf(xdf_file)
    file_name = os.path.basename(xdf_file)
    for stream in streams:
        stream_type = stream['info']['type'][0]
        if stream_type == eeg_stream_type:
            print("Found %s stream in %s" % (eeg_stream_type, file_name))
            eeg_stream.append(stream)
        elif stream_type == markers_stream_type:
            print("Found %s stream in %s" % (markers_stream_type, file_name))
            marker_stream.append(stream)
    # Release the streams that were not selected. Doing this inside the loop
    # (instead of once after it) also avoids a NameError when xdf_files is empty.
    stream = None
    del streams, header

if len(eeg_stream) != len(marker_stream):
    print('Warning: found %d EEG stream(s) but %d marker stream(s); '
          'streams are paired by position, so annotations may be misassigned'
          % (len(eeg_stream), len(marker_stream)))

### Extract EEG and Marker data

In [None]:
# Extract EEG Info
print("Extracting EEG info...")

# Channel labels and the sampling rate are read from the first EEG stream only.
first_eeg_info = eeg_stream[0]['info']
stream_desc = first_eeg_info['desc'][0]

if not stream_desc:
    # The stream carries no description: fall back to the configured labels.
    ch_names = default_ch_names
else:
    print("EEG channel names found")
    ch_names = [channel['label'][0]
                for channel in stream_desc['channels'][0]['channel']]
print('Channels: ', ch_names)

sfreq = float(first_eeg_info['nominal_srate'][0])
print('Sampling frequency: ', sfreq)

# Create MNE info object (every channel is declared as EEG)
eeg_info = mne.create_info(ch_names, sfreq, ch_types='eeg')

In [None]:
# Setup Montage
# Read the electrode layout from the file named by `montage_file` (a name this
# notebook does not assign itself - presumably supplied by the config). The
# montage is attached to each Raw object when the recordings are built.
montage = mne.channels.read_custom_montage(montage_file)
# montage.plot()

In [None]:
# Get all EEG data
def _sliding_window_offsets(start, stop, size, overlap):
    """Return the start offsets of every `size`-long window that fits in [start, stop].

    Consecutive windows start `size - overlap` apart.

    Raises:
        ValueError: if `overlap` is not smaller than `size`; the window would
            never advance and the generation loop would not terminate.
    """
    if size - overlap <= 0:
        raise ValueError('window_overlap (%s) must be smaller than window_size (%s)'
                         % (overlap, size))
    offsets = []
    offset = start
    while offset + size <= stop:
        offsets.append(offset)
        offset = offset + size - overlap
    return offsets


eeg_raw_list = []

# Recordings are paired with their marker stream by position.
if len(marker_stream) < len(eeg_stream):
    raise ValueError('Found %d EEG stream(s) but only %d marker stream(s)'
                     % (len(eeg_stream), len(marker_stream)))

is_mi_task = 'MI' in task
is_p300_task = 'P300' in task

# The window layout relative to a marker is identical for every marker, so it
# is computed once up front (MI tasks only).
window_offsets = []
if is_mi_task:
    window_offsets = _sliding_window_offsets(tmin, tmax, window_size, window_overlap)

for n in range(len(eeg_stream)):
    # Create MNE Raw object; MNE expects channels x samples.
    eeg_data = np.transpose(eeg_stream[n]['time_series'])
    # Scale by 1e-6 - presumably the amplifier streams microvolts while MNE
    # works in volts; confirm against the acquisition setup.
    eeg_data = eeg_data / 1e6
    print(eeg_data.shape)
    eeg_raw = mne.io.RawArray(eeg_data, eeg_info)

    # Set montage
    eeg_raw = eeg_raw.set_montage(montage)

    # Add annotations; onsets are expressed relative to the first EEG sample.
    onset, duration, description = [], [], []
    current_target = -1
    current_flash = -1
    eeg_start = eeg_stream[n]['time_stamps'][0]
    markers = marker_stream[n]
    for sample, time_stamp in zip(markers['time_series'], markers['time_stamps']):
        marker = sample[0]
        marker_time = time_stamp - eeg_start
        if is_mi_task:
            # Cue markers never start a window.
            if 'cue' in marker:
                continue
            if 'rest' in marker:
                label = marker
            elif 'task' in marker:
                # e.g. 'task_MI-hands' becomes 'MI/hands'
                label = marker.replace('task_', '').replace('-', '/')
            else:
                continue
            # One annotation per sliding window that follows the marker.
            for offset in window_offsets:
                onset.append(marker_time + offset)
                duration.append(window_size)
                description.append(label)
        elif is_p300_task:
            # Markers are JSON strings; a 'target' marker announces the current
            # target, each 'flash' marker is labelled by comparing against it.
            if 'target' in marker:
                current_target = json.loads(marker)['target']
            elif 'flash' in marker:
                current_flash = json.loads(marker)['flash']
                onset.append(marker_time)
                duration.append(task_duration)
                description.append("target" if current_flash == current_target else "nontarget")
    annotations = mne.Annotations(onset, duration, description)
    eeg_raw = eeg_raw.set_annotations(annotations)

    # Create list of raw objects
    eeg_raw_list.append(eeg_raw)

In [None]:
# Sanity check: show the orig_time attached to the annotations of the first
# recording before the recordings are concatenated.
print(eeg_raw_list[0].annotations.orig_time)

In [None]:
# Concatenate raw objects
# mne.concatenate_raws() extends the first Raw of the list in place. Passing a
# copy of the first recording keeps eeg_raw_list untouched, so re-running this
# cell (or the preprocessing below, which modifies `raw` in place) does not
# append the remaining recordings to eeg_raw_list[0] a second time.
raw = mne.concatenate_raws([eeg_raw_list[0].copy()] + eeg_raw_list[1:])
raw

### Pre-processing

In [None]:
# Common average reference
# Keep the pre-reference signal around so both versions can be compared.
raw_orig = raw.copy()
raw = raw.set_eeg_reference(ref_channels='average', projection=False)

if plotGraphs:
    for plot_title, plot_raw in (('Before Re-referencing', raw_orig),
                                 ('After Re-referencing', raw)):
        fig = plot_raw.plot(title=plot_title, n_channels=16, scalings=scalings)

In [None]:
# Bandpass filter data
# Snapshot the unfiltered signal first; the filter then runs on `raw` itself.
raw_orig = raw.copy()
raw = raw.filter(bp_l_freq, bp_h_freq)

if plotGraphs:
    for plot_title, plot_raw in (('Before Filtering', raw_orig),
                                 ('After Filtering', raw)):
        fig = plot_raw.plot(title=plot_title, scalings=scalings, duration=plot_duration)

### ICA Artifact Removal

In [None]:
if performICA:
    print('Performing ICA artifact removal...')
    # Snapshot of the signal before any component is removed (used for plots).
    raw_orig = raw.copy()

    # The decomposition is fitted on a separate copy that is high-passed at
    # 1 Hz to remove slow drifts; `raw` itself is not filtered here.
    raw_filt = raw.copy().filter(l_freq=1., h_freq=None)

    # ICA decomposition
    ica_settings = dict(n_components=16, method='fastica', max_iter=200,
                        random_state=42, verbose=True)
    ica = mne.preprocessing.ICA(**ica_settings)
    ica = ica.fit(raw_filt)

In [None]:
# Plot ICA sources
# Shows the component time courses computed from raw_orig (the copy taken
# before ICA is applied) so artifact components can be picked by eye.
if performICA and plotGraphs:
    fig = ica.plot_sources(raw_orig)

In [None]:
# Select source that corresponds to artifact and remove it
if performICA:
    # The component index is hard-coded.
    # NOTE(review): presumably chosen by inspecting the source plot of one
    # particular recording - re-check it for every new dataset.
    ica.exclude = [0]
    # ica.exclude = [2]
    print('ICA sources to exclude: ', ica.exclude)

In [None]:
if performICA:
    # Remove the excluded components from `raw` (modified in place).
    ica.apply(raw)
    if plotGraphs:
        for plot_title, plot_raw in (('Before ICA', raw_orig), ('After ICA', raw)):
            fig = plot_raw.plot(title=plot_title, scalings=scalings, duration=plot_duration)

### Epoch data

In [None]:
# Epoch data
# Turn the annotations into events (restricted to the labels in event_dict),
# then cut one epoch of window_size seconds starting at each event.
events, event_id = mne.events_from_annotations(raw, event_id=event_dict)
epoch_settings = dict(
    event_id=event_id,
    tmin=0.,
    tmax=window_size,
    baseline=None,   # no baseline correction
    picks='eeg',
    preload=True,
)
epochs = mne.Epochs(raw, events, **epoch_settings)
print(epochs)

### Features

In [None]:
# Labels
# The last column of the events array holds the event code; shift the codes
# so that the smallest one becomes class 0.
event_codes = epochs.events[:, -1]
y = event_codes - event_codes.min()

In [None]:
plt.close('all')
# Time-Domain Features
if features == 'time':
    # Scale the epoch data with MNE's Scaler before classification.
    scaler = Scaler(epochs.info)
    X = scaler.fit_transform(epochs.get_data())

    if ('P300' in task) and plotGraphs:
        # Topomap times outside the epoch make plot_joint raise. The epochs
        # start at 0 s and end at window_size, so only the requested latencies
        # they actually cover are kept; MNE's default 'peaks' is used when
        # none of them fits.
        joint_times = [t for t in (-0.2, 0., 0.3, 0.7) if epochs.tmin <= t <= epochs.tmax]
        if not joint_times:
            joint_times = 'peaks'
        for condition in ('target', 'nontarget'):
            fig = epochs[condition].average().plot_joint(
                times=joint_times, ts_args=dict(ylim=dict(eeg=[-10, 10])))

In [None]:
# Frequency-Domain Features
if features == 'psd':
    psds, freqs = psd_welch(epochs, average='mean', fmin=bp_l_freq, fmax=bp_h_freq, n_fft=126, n_jobs=-1)
    # Convert to dB. Values are clamped to the smallest positive float first so
    # that an exactly-zero PSD (e.g. a flat channel) gives a finite number
    # instead of -inf, which the classifiers cannot handle. Positive values
    # are left unchanged.
    X = 10 * np.log10(np.maximum(psds, np.finfo(float).tiny))
#     X = psds / np.sum(psds, axis=-1, keepdims=True)

    if ('MI' in task) and plotGraphs:
        # Plot every channel (axis 1 of X) rather than assuming 16 of them.
#         sel_chs = [2, 3, 4, 5, 6, 7, 14, 15]
        sel_chs = range(X.shape[1])
        # Class-wise mean spectra, arranged as frequencies x selected channels.
        psd_means_class_0 = np.transpose(np.mean(X[y==0], axis=0))[:, sel_chs]
        psd_means_class_1 = np.transpose(np.mean(X[y==1], axis=0))[:, sel_chs]

        fig = plt.figure()
        ax = fig.add_subplot(111)
        for i, ch in enumerate(sel_chs):
            # Dotted line for class 0, solid line in the same colour for class 1.
            line = ax.plot(freqs, psd_means_class_0[:,i], ':', label=ch_names[ch] + ' Rest')
            ax.plot(freqs, psd_means_class_1[:,i], '-', label=ch_names[ch] + ' MI-hands', color=line[0].get_color())
        ax.set(title='Welch PSD', xlabel='Frequency (Hz)', ylabel='Power Spectral Density (dB)')
        ax.set_ylim(bottom=-135, top=-85)
        ax.legend(loc='best')

In [None]:
# Time-frequency features
if features =='tfr':
    # Parameters: 8 log-spaced frequencies between 6 and 35 Hz, with the number
    # of cycles growing proportionally to the frequency.
    freqs = np.logspace(*np.log10([6, 35]), num=8)
#     freqs = np.linspace(2, 40, 20)
    print('TFR freqs: ', freqs)
    n_cycles = freqs / 2.
    time_bandwidth = 4.0 # param for multitaper

    # Compute TFR Power (per epoch, no averaging, no inter-trial coherence)
    tfr_kwargs = dict(n_cycles=n_cycles, use_fft=False, return_itc=False, average=False, n_jobs=-1)
    if tfr_type == 'morlet':
        power = tfr_morlet(epochs, freqs, **tfr_kwargs)
    elif tfr_type == 'multitaper':
        power = tfr_multitaper(epochs, freqs, time_bandwidth=time_bandwidth, **tfr_kwargs)
    else:
        # Without this branch an unknown type left `power` undefined, or
        # silently reused the value from a previous run of this cell.
        raise ValueError("Unknown tfr_type %r; expected 'morlet' or 'multitaper'" % (tfr_type,))
    print(power.data.shape)
    X = power.data

    if ('MI' in task) and plotGraphs:
        for condition, plot_title in (('rest', 'Average power (Rest)'),
                                      ('MI/hands', 'Average power (MI-hands)')):
            fig = power[condition].average().plot_topo(
                baseline=baseline, mode='percent', cmap='jet', tmin=tmin, tmax=tmax,
                vmin=vmin, vmax=vmax, title=plot_title)

In [None]:
# # plt.close('all')
# baseline = (0., 0.1)
# vmin, vmax = -1, 1
# fig = power['rest'].average().plot_topo(baseline=baseline, mode='percent', tmin=tmin, tmax=tmax, vmin=vmin, vmax=vmax, title='Average power (Rest)')
# fig = power['MI/hands'].average().plot_topo(baseline=baseline, mode='percent', tmin=tmin, tmax=tmax, vmin=vmin, vmax=vmax, title='Average power (MI-hands)')

In [None]:
# # Select Channels
# print('Selecting channels from data...')
# print('Original X.shape: ', X.shape)
# X = X[:,2:, :]
# print('X.shape: ', X.shape)

In [None]:
# Vectorize features
# The classifiers expect one flat feature vector per epoch, so anything with
# more than two dimensions is flattened to (n_epochs, n_features).
if X.ndim > 2:
    print('Vectorizing features to 2D...')
    print('Original X.shape: ', X.shape)
    vec = Vectorizer()
    X = vec.fit_transform(X)

for shape_label, array in (('X.shape: ', X), ('y.shape: ', y)):
    print(shape_label, array.shape)

### Classification

In [None]:
# Set up cross validation: 5 stratified folds, repeated 5 times, fixed seed.
# NOTE(review): for MI tasks the epochs are sliding windows cut from the same
# marker, so windows of one trial can end up in both the training and the test
# fold - confirm whether a grouped split is needed.
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=5, random_state=42)

# Set up scoring
scoring = 'accuracy'
# One list per column of the result table; filled by the training cell.
scores = {column: [] for column in ('Classifier', 'Score', 'Std')}

In [None]:
# Set up Classifiers
# Each entry is (name, estimator, parameter grid). Estimators with internal
# randomness and the halving search itself (which subsamples the data) are
# seeded so that repeated runs give the same scores.
clf_random_state = 42

classifier_specs = [
    ('KNN', KNeighborsClassifier(), {
        'n_neighbors': np.arange(2, 11, 1),
        'weights': ['uniform', 'distance'],
    }),
    ('DT', DecisionTreeClassifier(random_state=clf_random_state), {
        'criterion': ['gini', 'entropy'],
        'min_samples_split': np.arange(2, 11, 2),
    }),
    ('RF', RandomForestClassifier(random_state=clf_random_state), {
        'criterion': ['gini', 'entropy'],
        'n_estimators': (10, 20, 30),
        'min_samples_split': np.arange(2, 11, 2),
    }),
    ('LDA', LinearDiscriminantAnalysis(), {
        'solver': ['svd'],
    }),
    ('SVM', SVC(), {
        'C': (1e-4, 1e-2, 1),
        'gamma': (1e-4, 1e-2, 1, 10),
        'kernel': ['linear', 'rbf'],
    }),
    # NOTE(review): recent scikit-learn releases renamed the 'log' loss to
    # 'log_loss'; adjust the grid if the installed version rejects 'log'.
    ('SGD', SGDClassifier(random_state=clf_random_state), {
        'loss': ['hinge', 'log', 'modified_huber', 'squared_hinge', 'perceptron'],
        'penalty': ['l2', 'l1', 'elasticnet'],
        'alpha': (1e-4, 1e-2, 1, 10),
    }),
]

# classifiers keeps its original layout: a list of [name, search, params].
classifiers = []
for clf_name, estimator, params in classifier_specs:
    clf = HalvingGridSearchCV(estimator, param_grid=params, n_jobs=-1, cv=cv,
                              scoring=scoring, random_state=clf_random_state)
    classifiers.append([clf_name, clf, params])

In [None]:
# Train Classifiers
# Start from empty result lists so that re-running this cell does not append
# a second set of rows to the summary table.
scores = {'Classifier': [], 'Score': [], 'Std': []}

for clf_name, search, _params in classifiers:
    print("Training %s..." % clf_name)
    clf = search.fit(X, y)
    # Report the spread of the winning candidate itself. Averaging
    # std_test_score over cv_results_ would mix in every candidate of every
    # halving iteration and would not correspond to best_score_.
    best_std = clf.cv_results_['std_test_score'][clf.best_index_]
    print('%s score: %2.2f' % (clf_name, clf.best_score_))
    print('%s std  : %2.2f' % (clf_name, best_std))
    print()
    scores['Classifier'].append(clf_name)
    scores['Score'].append(clf.best_score_)
    scores['Std'].append(best_std)

In [None]:
# Score summary
# One row per trained classifier, built from the lists collected in `scores`
# (columns: Classifier, Score, Std).
df = pd.DataFrame(scores)
df

In [None]:
# Best Classifier
# Display the summary row with the highest Score.
print('Best Classifier:')
df.loc[df['Score'].idxmax()]