In [1]:
%matplotlib qt

# Short ERP Analysis from OpenVibe data converted to gdf
Data was converted to gdf using openvibe-convert.cmd
- First step is to rename and filter the annotations from openvibe
- translate the annotations into markers
- create epochs
- process epochs
- plot ERP
- (?) a bit of statistics

In [2]:
import matplotlib
import numpy as np
#%matplotlib ipympl
import os
import mne
import itertools

In [3]:
# remove me
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf, read_raw_brainvision
from mne.time_frequency import tfr_multitaper
from mne.stats import permutation_cluster_1samp_test as pcluster_test
from mne.viz.utils import center_cmap

In [4]:

apply_infinite_reference = True
apply_CSD = False  # use Current Source Density (spatial filter)
apply_ASR = False  # use Artifact Subspace Reconstruction (artifact removal)

In [5]:
data_dir=r"C:\BCI\dev\p300_analysis_from_openvibe"
os.path.exists(data_dir)
fnames = []
for file in os.listdir(data_dir):
    if file.endswith(".gdf"):
        fnames.append(os.path.join(data_dir, file))
        print(os.path.join(data_dir, file))

C:\BCI\dev\p300_analysis_from_openvibe\data_calib-024-2021-05-26_17_04_59.gdf


### Load gdf files

In [6]:
# load and preprocess data ####################################################
raws = [mne.io.read_raw_gdf(f, preload=True) for f in [fnames[0]]]
#raw = concatenate_raws(raws)
raw = raws[0]

Extracting EDF parameters from C:\BCI\dev\p300_analysis_from_openvibe\data_calib-024-2021-05-26_17_04_59.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 50687  =      0.000 ...   197.996 secs...


  raws = [mne.io.read_raw_gdf(f, preload=True) for f in [fnames[0]]]


In [7]:
raw.info['ch_names']

['Channel 1',
 'Channel 2',
 'Channel 3',
 'Channel 4',
 'Channel 5',
 'Channel 6',
 'Channel 7',
 'Channel 8']

### Todo: rereference

In [8]:
# define channel names
cname = ['Fz', 'Cz', 'P3', 'Pz', 'P4', 'PO7', 'PO8', 'Oz']
if len(raw.info['ch_names'])>8:
    cname.extend(["ch{}".format(i) for i in range(9,16+1)])
print(cname)
cname_map = dict(zip(raw.info['ch_names'], cname))
# define channel types
types = list(itertools.repeat('eeg', 8))
if len(raw.info['ch_names'])>8:
    types.extend(list(itertools.repeat('misc', 8)))
type_map = dict(zip(cname, types))

# rename and pick eeg
raw.rename_channels(cname_map, allow_duplicates=False)
raw.set_channel_types(type_map)
raw.pick_types(eeg=True, misc=False)


['Fz', 'Cz', 'P3', 'Pz', 'P4', 'PO7', 'PO8', 'Oz']


0,1
Measurement date,Unknown
Experimenter,Unknown
Digitized points,Not available
Good channels,"0 magnetometer, 0 gradiometer,  and 8 EEG channels"
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,256.00 Hz
Highpass,0.00 Hz
Lowpass,128.00 Hz


Set the 10-05 montage defining electrode positions. This will allow topoplots and CSD processing to be made

In [9]:
montage = mne.channels.make_standard_montage('standard_1005')
raw = raw.set_montage(montage, match_case=False) 

In [10]:
### Calculate an infinite reference

In [11]:
if apply_infinite_reference:
    raw.del_proj()  # remove our average reference projector first
    sphere = mne.make_sphere_model('auto', 'auto', raw.info)
    src = mne.setup_volume_source_space(sphere=sphere, exclude=30., pos=15.)
    forward = mne.make_forward_solution(raw.info, trans=None, src=src, bem=sphere)
    raw_rest = raw.copy().set_eeg_reference('REST', forward=forward)

    for title, _raw in zip(['Original', 'REST (∞)'], [raw, raw_rest]):
        fig = _raw.plot(n_channels=len(raw), scalings=dict(eeg=5e-5))
        # make room for title
        fig.subplots_adjust(top=0.9)
        fig.suptitle('{} reference'.format(title), size='xx-large', weight='bold')

  sphere = mne.make_sphere_model('auto', 'auto', raw.info)


Fitted sphere radius:         95.0 mm
Origin head coordinates:      0.9 5.9 51.9 mm
Origin device coordinates:    0.9 5.9 51.9 mm

Equiv. model fitting -> RV = 0.00348816 %
mu1 = 0.944717    lambda1 = 0.13715
mu2 = 0.667488    lambda2 = 0.683746
mu3 = -0.270376    lambda3 = -0.0105261
Set up EEG sphere model with scalp radius    95.0 mm

Sphere                : origin at (0.9 5.9 51.9) mm
              radius  : 85.5 mm
grid                  : 15.0 mm
mindist               : 5.0 mm
Exclude               : 30.0 mm

Setting up the sphere...
Surface CM = (   0.9    5.9   51.9) mm
Surface fits inside a sphere with radius   85.5 mm
Surface extent:
    x =  -84.6 ...   86.5 mm
    y =  -79.7 ...   91.4 mm
    z =  -33.6 ...  137.5 mm
Grid extent:
    x =  -90.0 ...   90.0 mm
    y =  -90.0 ...  105.0 mm
    z =  -45.0 ...  150.0 mm
2548 sources before omitting any.
742 sources after omitting infeasible sources not within 30.0 - 85.5 mm.
615 sources remaining after excluding the sources outsi

## Bandpass the signal
Removes noise and drift from the EEG signal by applying a infinite impulse response (two-pass) filter between .5 and 40Hz

In [12]:
raw.filter(.5, 40, fir_window='hann', method='iir')
raw.notch_filter(50)  # removes 50Hz noise

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 0.50, 40.00 Hz: -6.02, -6.02 dB

Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1691 samples (6.605 sec)



0,1
Measurement date,Unknown
Experimenter,Unknown
Digitized points,11 points
Good channels,"0 magnetometer, 0 gradiometer,  and 8 EEG channels"
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,256.00 Hz
Highpass,0.50 Hz
Lowpass,40.00 Hz


Plot the filtered signal

In [13]:
raw.plot()
apply_CSD

False

## Apply current source density

In [14]:
if apply_CSD:
    raw_csd = mne.preprocessing.compute_current_source_density(raw)
    raw = raw_csd
    raw_csd.plot()


## Artifact Subspace Reconstruction fitting and reconstruction

In [15]:
if apply_ASR:
    #!pip install meegkit pymanopt
    from meegkit.asr import ASR
    fs = int(raw.info["sfreq"])  # sampling frequency
    method='riemann'  # if error, use 'euclid' -> actually the case
    window_s=.5  # .5 sec window of analysis
    data_interval_s  = None # (begin, end) in sec of the training sample
    estimator='lwf'  #leave blank if using euclidian mode 

    # define the ASR model using riemannian method
    #asr_model = ASR(sfreq=fs, method=method, win_len=window_s, estimator=estimator)

    # if failing (after trying twice. SVD error occurs for no reason sometimes)
    asr_model = ASR(sfreq=fs, method="euclid", win_len=window_s)

    # The best would be to choose another recording during the same session to train the model without overfitting
    data = raw._data  # the numpy array with data is stored in the _data variable

    # Select a time interval for training data
    train_idx = None
    if data_interval_s is not None:
        train_idx = np.arange(data_interval_s[0] * fs, data_interval_s[1] * fs, dtype=int)
    # otherwise select the whole training set
    else:
        train_idx = np.arange(0, data.shape[1])

    train_data = data[:, train_idx]
    print('Training on samples of size {}'.format(train_data.shape))

    # fir the ASR model with data intervals
    _, sample_mask = asr_model.fit(train_data)
    print('Model trained')


### Clean the current dataset
Please check whether using this artifact filtering method increases signal to noise ratio rather than reducing it

In [16]:
if apply_ASR:
    clean =  asr_model.transform(raw._data)

    display_asr_results = True
    display_window_s = 15  # 

    if display_asr_results:  #
        data_p = raw._data[0:fs*display_window_s]  # reshape to (n_chans, n_times)
        clean_p = clean[0:fs*display_window_s]

        ###############################################################################
        # Plot the results
        # -----------------------------------------------------------------------------
        #
        # Data was trained on a 40s window from 5s to 45s onwards (gray filled area).
        # The algorithm then removes portions of this data with high amplitude
        # artifacts before running the calibration (hatched area = good).
        nb_ch_disp = 5
        times = np.arange(data_p.shape[-1]) / fs
        f, ax = plt.subplots(nb_ch_disp, sharex=True, figsize=(32, 16))
        for i in range(nb_ch_disp):
            # ax[i].fill_between(train_idx / fs, 0, 1, color='grey', alpha=.3,
            #                   transform=ax[i].get_xaxis_transform(),
            #                   label='calibration window')
            # ax[i].fill_between(train_idx / fs, 0, 1, where=sample_mask.flat,
            #                   transform=ax[i].get_xaxis_transform(),
            #                   facecolor='none', hatch='...', edgecolor='k',
            #                   label='selected window')
            ax[i].plot(times, data_p[i], lw=.5, label='before ASR')
            ax[i].plot(times, clean_p[i], label='after ASR', lw=.5)
            # ax[i].plot(times, raw[i]-clean[i], label='Diff', lw=.5)
            # ax[i].set_ylim([-50, 50])
            ax[i].set_ylabel(f'ch{i}')
            ax[i].set_yticks([])
        ax[i].set_xlabel('Time (s)')
        ax[0].legend(fontsize='small', bbox_to_anchor=(1.04, 1), borderaxespad=0)
        plt.subplots_adjust(hspace=0, right=0.75)
        plt.suptitle('Before/after ASR')
        plt.show()
    raw.data_ = clean

Dont forget to use the cleaned data

### Convert text annotations (i.e. unprocessed events) into events

Lets have a look at the annotations

In [17]:
import pprint
print(raw.annotations.to_data_frame())
df = raw.annotations.to_data_frame()
print('Displaying all annotations')
annot_codes = [np.int64(n) for n in np.unique(df['description'])]
annot_codes

                          onset  duration description
0    1970-01-01 00:00:43.785156  0.003906           0
1    1970-01-01 00:00:43.789062  0.003906       32773
2    1970-01-01 00:00:43.789062  0.003906       33038
3    1970-01-01 00:00:43.789062  0.003906       33286
4    1970-01-01 00:00:43.847656  0.003906       32779
...                         ...       ...         ...
2807 1970-01-01 00:03:15.339844  0.003906       32780
2808 1970-01-01 00:03:15.339844  0.003906       33028
2809 1970-01-01 00:03:15.339844  0.003906       33286
2810 1970-01-01 00:03:15.402344  0.003906       32779
2811 1970-01-01 00:03:15.535156  0.003906       32780

[2812 rows x 3 columns]
Displaying all annotations


[0,
 32773,
 32774,
 32779,
 32780,
 33025,
 33026,
 33027,
 33028,
 33029,
 33030,
 33031,
 33032,
 33033,
 33034,
 33035,
 33036,
 33037,
 33038,
 33285,
 33286]

These annotations seem to relate to hex codes. OpenViBE definitions can be found on [OpenViBE's website](http://openvibe.inria.fr/stimulation-codes/). Let's parse the copypasted list

In [18]:
import re
tr_sim= ''
pat_extract= re.compile('^([^ ]+)[ ]+0x[0-9A-Fa-f]+[ \/]+([0-9]+)')
#OVTK_GDF_125_Watt                                     0x585       //  1413
k_stim = []
k_stim_int = []
v_stim = []

# read and convert annotations
with open(r'.\ov_stims.txt', 'r') as fd:
    for line in fd.readlines():
        m = pat_extract.match(line)
        v, k = m.groups()
        k_stim.append(k)
        k_stim_int.append(int(k))
        v_stim.append(v)

# format dict and list
stim_map = dict(zip(k_stim_int, v_stim))
stim_map_inv = dict(zip(v_stim, k_stim))

stim_tup = list(zip(k_stim_int, v_stim))

Make a dataframe of the stimuli in common between both

In [19]:
import pandas as pd
df = pd.DataFrame.from_dict(stim_tup)
df.columns = ['coden', 'desc']
df[[c in annot_codes for c in df.coden]]

Unnamed: 0,coden,desc
125,33025,OVTK_StimulationId_Label_01
126,33026,OVTK_StimulationId_Label_02
127,33027,OVTK_StimulationId_Label_03
128,33028,OVTK_StimulationId_Label_04
129,33029,OVTK_StimulationId_Label_05
130,33030,OVTK_StimulationId_Label_06
131,33031,OVTK_StimulationId_Label_07
132,33032,OVTK_StimulationId_Label_08
133,33033,OVTK_StimulationId_Label_09
134,33034,OVTK_StimulationId_Label_0A


From this table, lets locate and save the codes for Target and Non-Target and give them the following values: target=1 and non-target=0 

In [20]:
target_map = {'33286':0, '33285':1}

Then we can convert annotations into events

In [21]:
events, _ = mne.events_from_annotations(raw, event_id=target_map)
print("Found {} events".format(len(events[:])))

Used Annotations descriptions: ['33285', '33286']
Found 701 events


### Choose the channels to analyze

In [22]:
raw_backup = raw.copy()

In [23]:
# pick all channels
picks = mne.pick_channels(raw.info["ch_names"], include=[])
picks
raw.plot_sensors(show_names=True)
fig = raw.plot_sensors('3d')

### Create epochs for each class

Check for duplicates

In [24]:
from collections import Counter
events[:, 0]
a = np.array(events[:, 0])
dups = [item for item, count in Counter(a).items() if count > 1]
if dups:
    print("WARNING: Duplicate found at sample(s) {}".format(dups))



In [25]:
event_ids = dict(NonTarget=0, Target=1) 
epochs = mne.Epochs(raw, events, event_id=event_ids, tmin=-0.5, tmax=0.6, event_repeated='drop', picks = ['eeg'],
                    preload=True)
fig = epochs.plot()

Not setting metadata
Multiple event values for single event times found. Keeping the first occurrence and dropping all others.
Not setting metadata
700 matching events found
Setting baseline interval to [-0.5, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Loading data for 700 events and 283 original time points ...
0 bad epochs dropped


### todo: reject some events


In [26]:
if False:
    reject_criteria = dict(eeg=100e-6,  # 100 µV
                       eog=200e-6)  # 200 µV
    _ = epochs.drop_bad(reject=reject_criteria)
    epochs.plot_drop_log()

### Average the epochs of each class

In [27]:
l_nt = epochs['NonTarget'].average()
l_target = epochs['Target'].average()

In [28]:
fig, ax = plt.subplots(2, 1)
fig1 = l_target.plot(spatial_colors=True, axes=ax[0])
fig2 = l_nt.plot(spatial_colors=True, axes=ax[1])
# Add title
fig.suptitle("Target(top) - Non-Target(bottom)")
# Fix font spacing
plt.subplots_adjust(hspace=0.5)

In [29]:
spec_kw = dict(width_ratios=[1,1,1,.15], wspace=0.5,
               hspace=0.5,height_ratios=[1,1])
                         #hspace=0.5, height_ratios=[1, 2])

fig, ax = plt.subplots(2, 4, gridspec_kw=spec_kw)
l_target.plot_topomap(times=[-0.2, 0.1, 0.4], average=0.05, axes=ax[0,:])
l_nt.plot_topomap(times=[-0.2, 0.1, 0.4], average=0.05, axes=ax[1,:])
fig.suptitle("Target(top) - Non-Target(bottom)")
plt.subplots_adjust(hspace=0.5)

In [30]:
l_target.plot_joint()
plt.gcf().canvas.set_window_title('Target joint plot')
l_nt.plot_joint()
plt.gcf().canvas.set_window_title('Non-Target joint plot')

No projector specified for this dataset. Please consider the method self.add_proj.
No projector specified for this dataset. Please consider the method self.add_proj.


### Compare conditions

In [31]:

evokeds = dict(NonTarget=list(epochs['NonTarget'].iter_evoked()), 
               Target=list(epochs['Target'].iter_evoked()))
#picks = [f'eeg{n}' for n in range(10, 15)]
mne.viz.plot_compare_evokeds(evokeds, picks=picks, combine='mean')

More than 6 channels, truncating title ...
combining channels using "mean"
combining channels using "mean"


[<Figure size 800x600 with 1 Axes>]

In [32]:
nb_chans = epochs['Target']._data.shape[1]
splt_width = int(np.ceil(np.sqrt(1.0*nb_chans+1)))  # adding an extra plot with all channels combined at the end
fig, ax = plt.subplots(splt_width,splt_width)

evokeds = dict(NonTarget=list(epochs['NonTarget'].iter_evoked()), 
               Target=list(epochs['Target'].iter_evoked()))
#picks = [f'eeg{n}' for n in range(10, 15)]

shape_epochs = epochs['Target']._data.shape
for ch_idx in range(nb_chans):
    print('plotting channel {}'.format(ch_idx+1))
    mne.viz.plot_compare_evokeds(evokeds,picks=[epochs.info['ch_names'][ch_idx]],
                                 legend=False,
                                 axes=ax[ch_idx//splt_width, ch_idx%splt_width],)
    plt.show(block=False)
    plt.subplots_adjust(hspace=0.5, wspace=.5)
    plt.pause(.1)
print('plotting averaged channels')
mne.viz.plot_compare_evokeds(evokeds, picks=picks, combine='mean',
                             legend=True,
                             axes=ax[-1,-1])

plt.subplots_adjust(hspace=0.5, wspace=.5)

plotting channel 1
plotting channel 2
plotting channel 3
plotting channel 4
plotting channel 5
plotting channel 6
plotting channel 7
plotting channel 8
plotting averaged channels
More than 6 channels, truncating title ...
combining channels using "mean"
combining channels using "mean"


In [33]:
epochs['Target']

0,1
Number of events,100
Events,Target: 100
Time range,-0.500 – 0.602 sec
Baseline,-0.500 – 0.000 sec


In [34]:
epochs['Target']._data.shape

(100, 8, 283)

### Display single epochs

In [35]:
epochs['Target'].plot_image(combine='mean')
plt.gcf().canvas.set_window_title('Target')
epochs['NonTarget'].plot_image(combine='mean')
plt.gcf().canvas.set_window_title('Non-Target')

Not setting metadata
Not setting metadata
100 matching events found
No baseline correction applied
0 projection items activated
0 bad epochs dropped
combining channels using "mean"
Not setting metadata
Not setting metadata
600 matching events found
No baseline correction applied
0 projection items activated
0 bad epochs dropped
combining channels using "mean"


In [36]:
epochs.info

0,1
Measurement date,Unknown
Experimenter,Unknown
Digitized points,11 points
Good channels,"0 magnetometer, 0 gradiometer,  and 8 EEG channels"
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,256.00 Hz
Highpass,0.50 Hz
Lowpass,40.00 Hz


In [37]:
mne.channels.find_layout(epochs.info, ch_type='eeg')

<Layout | EEG - Channels: Fz, Cz, P3 ...>

### Check for trials that should be rejected

In [38]:
reject_criteria = dict(eeg=150e-6)       # 150 µV

for ch_type, title in dict(eeg='EEG').items():
    layout = mne.channels.find_layout(epochs.info, ch_type=ch_type)
    epochs['Target'].plot_topo_image(layout=layout, fig_facecolor='w',
                                            font_color='k', title=title+'Target Trial x time amplitude')
    epochs['NonTarget'].plot_topo_image(layout=layout, fig_facecolor='w',
                                            font_color='k', title=title+'Non-Target Trial x time amplitude')
epochs.drop_bad(reject=reject_criteria)

0 bad epochs dropped


0,1
Number of events,700
Events,NonTarget: 600 Target: 100
Time range,-0.500 – 0.602 sec
Baseline,-0.500 – 0.000 sec


### Difference Target vs Non-Target

In [39]:
nb_chans = diff_vis._data.shape[1]
splt_width = int(np.ceil(np.sqrt(1.0*nb_chans+1)))  # adding an extra plot with all channels combined at the end
fig, ax = plt.subplots(splt_width,splt_width)

evokeds = dict(NonTarget=list(epochs['NonTarget'].iter_evoked()), 
               Target=list(epochs['Target'].iter_evoked()))
#picks = [f'eeg{n}' for n in range(10, 15)]

shape_epochs = epochs['Target']._data.shape
for ch_idx in range(nb_chans):
    print('plotting channel {}'.format(ch_idx+1))
    mne.viz.plot_compare_evokeds(evokeds,picks=[epochs.info['ch_names'][ch_idx]],
                                 legend=False,
                                 axes=ax[ch_idx//splt_width, ch_idx%splt_width],)
    plt.show(block=False)
    plt.subplots_adjust(hspace=0.5, wspace=.5)
    plt.pause(.1)
print('plotting averaged channels')
mne.viz.plot_compare_evokeds(evokeds, picks=picks, combine='mean',
                             legend=True,
                             axes=ax[-1,-1])

plt.subplots_adjust(hspace=0.5, wspace=.5)

NameError: name 'diff_vis' is not defined

In [None]:
epochs_balanced, _ = epochs.equalize_event_counts(event_ids=event_ids, method='truncate')
epochs_balanced

In [None]:
l_nt_balanced = epochs_balanced['NonTarget'].average()
l_target_balanced = epochs_balanced['Target'].average()
diff_vis = mne.combine_evoked([l_target_balanced, l_nt_balanced], weights=[1, -1])
diff_vis.plot_joint()

### Grand average (disabled)

In [None]:
if False:
    # Grand average of all signal (useless here, but would be good for averaging all participants or runs or clusters of runs)
    grand_average = mne.grand_average([epochs['Target'].average(), epochs['NonTarget'].average()])
    print(grand_average)
    grand_average.evoked()

### LDA

In [79]:
from sklearn import metrics
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import KFold

In [150]:
if False:
    #kwargs = dict(n_permutations=100, step_down_p=0.05, seed=1,
    #              buffer_size=None, out_type='mask')
    mne.stats.permutation_t_test()
    

    
X = epochs._data[:,1,:]  # input data at CZ (TODO:flatten all electrodes)
X = epochs._data[:,:,:]  # input data at CZ (TODO:flatten all electrodes)
y = epochs.events[:,2]  # ground truth

# remove the information 
    
#mne.stats.permutation_t_test()
#X.shape
X = np.moveaxis(X,1,-1)
X = X.reshape([X.shape[0],X.shape[1]*X.shape[2]])

Make K-folds

In [195]:
clf = LinearDiscriminantAnalysis(solver='lsqr',shrinkage='auto')
kf = KFold(n_splits=5)
kf.get_n_splits(X)

5

In [197]:
accuracy = []
transformed = []
for train_index, test_index in kf.split(X):
    #print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    clf.fit(X_train, y_train)
    kscore = clf.score(X_test,y_test)
    print(kscore)
    accuracy.append(kscore)
    #transformed.append(clf.transform(X_test))
print('Average accuracy {}-Fold = {}'.format(kf.get_n_splits(X), np.mean(accuracy)))

0.9071428571428571
0.9142857142857143
0.9142857142857143
0.95
0.9142857142857143
Average accuracy 5-Fold = 0.9200000000000002


In [190]:
kf_array = np.array(transformed)
kf_array.shape
#kf_array = kf_array.mean(axis=0)
kf_array = kf_array.reshape(kf_array.size,1)
kf_array.shape
y.shape

(700,)

In [191]:


target_names = ['NonTarget', 'Target']
colors = ['navy', 'turquoise', 'darkorange']
lw = 2

plt.figure()
for color, i, target_name in zip(colors, [0, 1, 2], target_names):
    plt.scatter(kf_array[y == i, 0], kf_array[y == i, 0], alpha=.8, color=color,
                label=target_name)
plt.legend(loc='best', shadow=False, scatterpoints=1)
plt.title('LDA dataset')

Text(0.5, 1.0, 'LDA dataset')

In [96]:
X_r[y == i, 1]

IndexError: index 1 is out of bounds for axis 1 with size 1

LinearDiscriminantAnalysis()

In [None]:
print(epochs['NonTarget']._data.shape)
m_nt = epochs['NonTarget']._data[:,1,...]
m_t = epochs['Target']._data[:,1,...]

Define the r_squared calculation function adapted for handling MNE data

In [231]:
# From https://github.com/bbci/wyrm/blob/master/wyrm/processing.py
# Bastian Venthur for wyrm
# Code initially from Benjamin Blankertz for bbci (Matlab)

def calculate_signed_r_square_mne(epochs, classes=[0,1], classaxis=0, **kwargs):
    """Calculate the signed r**2 values.
    This method calculates the signed r**2 values over the epochs of the
    ``dat``.
    Parameters
    ----------
    epochs : MNE epoched data
    classes: list, optional 
        (either int index or str for the class name of the epoch))
    classaxis : int, optional
        the dimension containing epochs
    Returns
    -------
    signed_r_square : ndarray
        the signed r**2 values, signed_r_square has one axis less than
        the ``dat`` parameter, the ``classaxis`` has been removed
    Examples
    --------
    >>> dat.data.shape
    (400, 100, 64)
    >>> r = calculate_signed_r_square(dat)
    >>> r.shape
    (100, 64)
    """
    # TODO: explain the algorithm in the docstring and add a reference
    # to a paper.
    # select class 0 and 1
    # TODO: make class 0, 1 variables
    fv1 = epochs[classes[0]]._data
    fv2 = epochs[classes[1]]._data
    # number of epochs per class
    l1 = epochs[classes[0]]._data.shape[classaxis]
    l2 = epochs[classes[1]]._data.shape[classaxis]
    # calculate r-value (Benjamin approved!)
    a = (fv1.mean(axis=classaxis) - fv2.mean(axis=classaxis)) * np.sqrt(l1 * l2)
    b = epochs._data.std(axis=classaxis) * (l1 + l2)
    r = a / b
    # return signed r**2
    return np.sign(r) * np.square(r)


Apply signed r square function

In [232]:
# display using 
rsq = calculate_signed_r_square_mne(epochs, classes=['Target','NonTarget'])
#rsq.shape
#plt.imshow(rsq, cmap='Blues')

In [233]:
#!pip install seaborn

In [246]:
import seaborn as sns
hm = sns.heatmap(rsq,linewidths=0,cmap="coolwarm").set(title='Signed r-square maps Target vs Non-Target', xlabel='Time (samples)')