In [None]:
%load_ext lab_black
import mne
import os
import pywt
import numpy as np
import pickle
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from sklearn.model_selection import train_test_split
from scipy import signal

from utils import base_layout, get_wavelet, load_all_epochs, load_epochs_from_file

# sampling rate in Hz; `cwt` uses it as 1 / sampling_period
signal_frequency = 256
# indices of the two conditions inside each participant's entry of `epochs`
ERROR, CORRECT = 0, 1

In [None]:
# load epochs (only a train set)
# epochs is a 5D structure:
# PARTICIPANTS x [ERROR, CORRECT] x EPOCH X CHANNEL x TIMEPOINT
#
# The MNE-loaded epochs are cached as a pickle to save on loading times.
# NOTE: pickle.load can execute arbitrary code - only ever load a cache file
# this notebook wrote itself, never one obtained from elsewhere.
pickled_data = "../data/train_epochs.p"
if os.path.isfile(pickled_data):
    with open(pickled_data, "rb") as cache_file:
        epochs = pickle.load(cache_file)
else:
    epochs = load_all_epochs()
    # write to a temporary file first and move it into place, so an
    # interrupted dump never leaves a truncated cache that later runs would load
    _tmp_pickle = pickled_data + ".tmp"
    with open(_tmp_pickle, "wb") as cache_file:
        pickle.dump(epochs, cache_file)
    os.replace(_tmp_pickle, pickled_data)

print("participants\t", len(epochs))
# both lines describe the first participant (index 0)
print("p1 error\t", epochs[0][ERROR].shape)
print("p1 correct\t", epochs[0][CORRECT].shape)

In [None]:
# get metadata (time axis, electrode positions and names) from one recording;
# assumes every recording shares this channel layout - TODO confirm
_mne_epochs = load_epochs_from_file("../data/responses/GNG_AA0303-64 el.vhdr")
times = _mne_epochs.times

_channel_info = _mne_epochs.info["chs"]
# first three entries of each channel's "loc" (used below as x, y, z)
channel_locations = np.array([ch["loc"][:3] for ch in _channel_info])
channel_names = [ch["ch_name"] for ch in _channel_info]

# Colour each electrode by its position: shift and scale x, y, z to [0, 1]
# per axis, then map to 0..255 so neighbouring electrodes get similar colours.
channel_colors = channel_locations - channel_locations.min(axis=0)
_color_span = channel_colors.max(axis=0)
# an axis on which every electrode has the same coordinate would give 0/0 = NaN
# and an invalid "rgb(nan,...)" string; dividing by 1 keeps that channel at 0
_color_span[_color_span == 0] = 1
channel_colors /= _color_span
channel_colors = channel_colors * 255 // 1
channel_colors = [f"rgb({c[0]:.0f},{c[1]:.0f},{c[2]:.0f})" for c in channel_colors]

In [None]:
# 3D scatter of the electrode positions, labelled by channel name and
# coloured with the position-derived channel colours
x, y, z = channel_locations.T
scalp3d = go.FigureWidget(layout=base_layout)
scalp3d.update_layout(height=700, width=700)
electrode_trace = go.Scatter3d(
    x=x,
    y=y,
    z=z,
    mode="markers+text",
    text=channel_names,
    marker=dict(size=3, color=channel_colors),
)
scalp3d.add_trace(electrode_trace)
scalp3d

In [None]:
def cwt(epoch, mwt="mexh", density=3, n_octaves=7, sampling_frequency=None):
    """Continuous wavelet transform on a log-spaced frequency grid.

    Frequencies are 2 ** (k / density) Hz for k = 0 .. n_octaves*density - 1,
    i.e. starting at 1 Hz in steps of 1/`density` octave. Each frequency is
    converted to the wavelet scale whose centre frequency matches it.

    Args:
        epoch: signal passed straight to ``pywt.cwt`` (this notebook calls it
            with one channel's 1-D time series).
        mwt: name of a continuous wavelet known to pywt (default Mexican hat).
        density: number of frequency steps per octave.
        n_octaves: number of octaves covered above 1 Hz (upper end excluded).
        sampling_frequency: sampling rate in Hz; defaults to the notebook-level
            ``signal_frequency``.

    Returns:
        The coefficient array from ``pywt.cwt``. Its first axis follows the
        scales, i.e. one row per frequency, lowest frequency first.
    """
    if sampling_frequency is None:
        sampling_frequency = signal_frequency

    center_wavelet_frequency = pywt.scale2frequency(mwt, [1])[0]
    # scale s corresponds to frequency center_wavelet_frequency * fs / s
    const = center_wavelet_frequency * sampling_frequency

    # construct scales; integer exponents divided by `density` give a
    # deterministic number of frequencies, unlike np.arange with a float step
    freqs = 2 ** (np.arange(n_octaves * density) / density)
    scales = const / freqs

    # compute coeffs (the frequencies pywt returns equal `freqs` above)
    coef, _ = pywt.cwt(
        data=epoch,
        scales=scales,
        wavelet=mwt,
        sampling_period=1 / sampling_frequency,
    )
    return coef

In [None]:
# plot all channels for a given epoch, and CWT for a chosen channel of this epoch
max_amp = 0.00005

# A dedicated figure name (instead of the shared `fig`) keeps this cell's
# slider callback bound to this figure after later cells rebind `fig`.
epoch_fig = go.FigureWidget(make_subplots(rows=2))
epoch_fig.update_layout(**base_layout)
epoch_fig.update_layout(
    xaxis_range=[times[0], times[-1]],
    yaxis_range=[-max_amp, max_amp],
)
for i in range(len(channel_names)):
    epoch_fig.add_scatter(
        x=times, row=1, col=1, line_width=0.3, line_color=channel_colors[i]
    )
epoch_fig.add_heatmap(x=times, row=2, col=1, zmin=-100e-6, zmax=100e-6)


@interact(
    # IntSlider's max is inclusive, so the last valid participant is len - 1
    participant=IntSlider(min=0, max=len(epochs) - 1),
    epoch_num=IntSlider(min=0, max=7),
    channel=IntSlider(min=0, max=len(channel_names) - 1),
    condition=Dropdown(options=["error", "correct"]),
)
def update_plots(participant, epoch_num, channel, condition):
    """Draw one epoch's channels (highlighting `channel`) and that channel's CWT."""
    cond_index = CORRECT if condition == "correct" else ERROR
    cond_epochs = epochs[participant][cond_index]
    if len(cond_epochs) == 0:
        return  # no epochs recorded for this participant/condition
    # the epoch slider has a fixed range; clamp for participants with fewer epochs
    epoch = cond_epochs[min(epoch_num, len(cond_epochs) - 1)]
    with epoch_fig.batch_update():
        for ch in range(len(channel_names)):
            epoch_fig.data[ch].y = epoch[ch]
            # set only the width: assigning a whole `line` dict would replace
            # the line object and drop the per-channel colour
            epoch_fig.data[ch].line.width = 0.3
        epoch_fig.data[channel].line.width = 3
        epoch_fig.data[-1].z = cwt(epoch[channel])


epoch_fig

In [None]:
# plot average ERP difference between conditions, and its CWT
# (own figure name so the slider callback is not redirected when `fig` is
# rebound by later cells)
erp_fig = go.FigureWidget(make_subplots(rows=2))
erp_fig.update_layout(**base_layout)
erp_fig.update_layout(
    xaxis_range=[times[0], times[-1]],
    yaxis_range=[-max_amp / 2, max_amp / 2],
)
for i in range(len(channel_names)):
    erp_fig.add_scatter(
        x=times, row=1, col=1, line_width=0.3, line_color=channel_colors[i]
    )
erp_fig.add_heatmap(x=times, row=2, col=1, zmin=-50e-6, zmax=50e-6)


@interact(
    # IntSlider's max is inclusive, so the last valid participant is len - 1
    participant=IntSlider(min=0, max=len(epochs) - 1),
    channel=IntSlider(min=0, max=len(channel_names) - 1),
)
def update_plots(participant, channel):
    """Draw correct-minus-error mean ERP per channel and the CWT of `channel`."""
    err, cor = epochs[participant]
    if len(err) == 0 or len(cor) == 0:
        return  # the mean of an empty condition is undefined
    ERP_diff = cor.mean(axis=0) - err.mean(axis=0)
    with erp_fig.batch_update():
        for ch in range(len(channel_names)):
            erp_fig.data[ch].y = ERP_diff[ch]
            # set only the width so the per-channel colour is preserved
            erp_fig.data[ch].line.width = 0.3
        erp_fig.data[channel].line.width = 3
        erp_fig.data[-1].z = cwt(ERP_diff[channel])


erp_fig

In [None]:
def get_separations(cond1, cond2):
    """Per-feature ratio of total scatter to pooled within-class scatter.

    Samples are stacked along axis 0 of each input; every remaining axis is
    an independent feature. With n_k samples in class k and population
    variances (ddof=0):

        within = n_1 * var(cond1) + n_2 * var(cond2)
        total  = (n_1 + n_2) * var(cond1 and cond2 stacked together)

    The result is total / within. Because total = within + between-class
    scatter, this equals 1 + between / within: 1 means no separation, larger
    values mean the class means lie further apart relative to their spread.
    The within-class term is the usual pooled sum of squared deviations, so
    classes are weighted by their sample counts. Whether unequal class sizes
    should be rebalanced is still an open question (TODO).

    Returns an array shaped like a single sample; entries are inf/NaN where
    the within-class scatter is zero.
    """
    n_first, n_second = len(cond1), len(cond2)
    within = n_first * cond1.var(axis=0) + n_second * cond2.var(axis=0)

    pooled = np.concatenate((cond1, cond2), axis=0)
    total = len(pooled) * pooled.var(axis=0)
    return total / within


def filter(data, spatial_filter):
    """Weight and sum `data` over its channel axis (axis 1).

    Axis 1 of `data` is contracted with axis 0 of `spatial_filter`. For this
    notebook's EPOCH x CHANNEL x FREQUENCY x TIMEPOINT arrays and a
    per-channel weight vector, the result is EPOCH x FREQUENCY x TIMEPOINT.

    NOTE: this name shadows the built-in ``filter`` inside the notebook.
    """
    filtered = np.tensordot(data, spatial_filter, axes=(1, 0))
    return filtered


def get_best_separation(cond1, cond2, spatial_filter):
    """Spatially filter both conditions and locate the most separable feature.

    Returns:
        (best_index, separations): `separations` is ``get_separations`` of
        the two filtered arrays (one value per remaining feature position,
        e.g. FREQUENCY x TIMEPOINT for this notebook's CWT arrays).
        `best_index` is the tuple index of its largest entry.
        NOTE: numpy's argmax picks the first NaN if any separation is NaN.
    """
    filtered_first, filtered_second = (
        filter(cond, spatial_filter) for cond in (cond1, cond2)
    )
    separations = get_separations(filtered_first, filtered_second)

    flat_best = np.argmax(separations)
    best_index = np.unravel_index(flat_best, separations.shape)
    return best_index, separations

In [None]:
# Inspect epoch counts for a single participant. 38 was presumably picked via
# np.array([len(p[0]) for p in epochs]).argmax(), i.e. the participant with
# the most error epochs.
participant = 38
err, cor = epochs[participant]
for condition_epochs in (err, cor):
    print(condition_epochs.shape)

In [None]:
# compute CWT for a chosen participant
participant = 11
err, cor = epochs[participant]


def _cwt_all_channels(condition_epochs):
    """CWT of every channel of every epoch, as a 4D numpy array:
    EPOCH x CHANNEL x FREQUENCY x TIMEPOINT."""
    return np.array(
        [[cwt(ch_signal) for ch_signal in ep] for ep in condition_epochs]
    )


err_cwts, cor_cwts = _cwt_all_channels(err), _cwt_all_channels(cor)
print(err_cwts.shape)
print(cor_cwts.shape)

# hold out 30% of each condition's epochs as a test set (fixed seed);
# err_cwts / cor_cwts are rebound to the remaining training part
err_cwts, err_cwts_test = train_test_split(err_cwts, test_size=0.3, random_state=0)
cor_cwts, cor_cwts_test = train_test_split(cor_cwts, test_size=0.3, random_state=0)

In [None]:
# Baseline: separation at every (frequency, time) position using only Cz.
electrode_index = channel_names.index("Cz")
spatial_filter = np.eye(len(channel_names))[electrode_index]  # one-hot on Cz

index, separations_train = get_best_separation(err_cwts, cor_cwts, spatial_filter)
print("use only Cz")
print("best index found", index)
print("separation on train set\t", separations_train.max())

# evaluate on the held-out epochs at the position chosen on the train set
_, separations_test = get_best_separation(err_cwts_test, cor_cwts_test, spatial_filter)
print("separation on test set\t", separations_test[index])

fig = go.FigureWidget(make_subplots(rows=2))
fig.update_layout(**base_layout)
for row, sep in enumerate((separations_train, separations_test), start=1):
    fig.add_heatmap(z=sep, x=times, row=row, col=1, zmin=1, zmax=2)
fig

In [None]:
# Random search for a spatial filter: draw random channel weightings and keep
# the one whose best (frequency, time) position separates the TRAIN epochs
# most. get_separations returns total/within scatter, which is >= 1, so 1 is
# the starting bar. A fixed seed makes the search (and every result derived
# from best_filter) reproducible on re-run.
n_filter_trials = 300
filter_rng = np.random.default_rng(0)
best_separation = 1
best_filter = None
for i in range(n_filter_trials):
    spatial_filter = filter_rng.standard_normal(len(channel_names))

    _, separations = get_best_separation(err_cwts, cor_cwts, spatial_filter)
    trial_best = separations.max()
    if trial_best > best_separation:
        best_separation = trial_best
        print(i, best_separation)
        best_filter = spatial_filter
# later cells index with best_filter; fail here with a clear message instead
assert best_filter is not None, "no random filter beat a separation of 1"
print(best_filter)

In [None]:
# Same evaluation as the Cz baseline, but with the randomly found spatial filter.
index, separations_train = get_best_separation(err_cwts, cor_cwts, best_filter)
print("use randomly found spatial filter")
print("best index found", index)
print("separation on train set\t", separations_train.max())

# evaluate on the held-out epochs at the position chosen on the train set
_, separations_test = get_best_separation(err_cwts_test, cor_cwts_test, best_filter)
print("separation on test set\t", separations_test[index])

fig = go.FigureWidget(make_subplots(rows=2))
fig.update_layout(**base_layout)
panels = {1: separations_train, 2: separations_test}
for row, sep in panels.items():
    fig.add_heatmap(z=sep, x=times, row=row, col=1, zmin=1, zmax=2)
fig

In [None]:
# Project the TRAIN epochs onto the chosen filter at the best (frequency, time)
# position (`index` comes from the most recently run separation cell), and
# centre them on the midpoint of the two class means: a value's sign says on
# which side of this decision threshold it falls. `threshold` is reused below.
err_end = filter(err_cwts, best_filter)[:, index[0], index[1]]
cor_end = filter(cor_cwts, best_filter)[:, index[0], index[1]]

threshold = (err_end.mean() + cor_end.mean()) / 2
err_end -= threshold
cor_end -= threshold
fig = go.FigureWidget(layout=base_layout)
# name the traces so the legend tells the classes apart
fig.add_scatter(x=err_end, mode="markers", name="error")
fig.add_scatter(x=cor_end, mode="markers", name="correct")
fig.update_layout(xaxis_title="filtered CWT value - threshold (train)")
fig

In [None]:
# final test: apply the filter, position and threshold fitted on the train
# set, unchanged, to the held-out epochs
err_end = filter(err_cwts_test, best_filter)[:, index[0], index[1]]
cor_end = filter(cor_cwts_test, best_filter)[:, index[0], index[1]]
err_end -= threshold
cor_end -= threshold

fig = go.FigureWidget(layout=base_layout)
# name the traces so the legend tells the classes apart
fig.add_scatter(x=err_end, mode="markers", name="error")
fig.add_scatter(x=cor_end, mode="markers", name="correct")
fig.update_layout(xaxis_title="filtered CWT value - threshold (test)")
fig