In [None]:
%load_ext lab_black
import mne
import os
import pywt
import numpy as np
import pickle
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from scipy import signal

from utils import base_layout, tmax, tmin, get_wavelet, load_all_epochs

# sampling rate; cwt() uses 1 / signal_frequency as its sampling period
signal_frequency = 256
# positions of the two conditions inside epochs[participant]
ERROR, CORRECT = 0, 1

In [None]:
# load epochs (only a train set)
# epochs is a 5D structure:
# PARTICIPANTS x [ERROR, CORRECT] x EPOCH X CHANNEL x TIMEPOINT

pickled_data = "../data/train_epochs.p"
if os.path.isfile(pickled_data):
    # NOTE: unpickling can execute arbitrary code -- only load a cache file that
    # this notebook wrote itself, never one obtained from elsewhere.
    with open(pickled_data, "rb") as cache_file:
        epochs = pickle.load(cache_file)
else:
    epochs = load_all_epochs()
    # pickle data loaded by MNE to save on loading times later; the `with` block
    # guarantees the file is flushed and closed before anything reads it back
    with open(pickled_data, "wb") as cache_file:
        pickle.dump(epochs, cache_file)

In [None]:
print("participants:", len(epochs))
# participant 0's arrays for each condition (EPOCH x CHANNEL x TIMEPOINT), one per line
print(epochs[0][ERROR].shape, epochs[0][CORRECT].shape, sep="\n")

In [None]:
def cwt(epoch, mwt="mexh", density=5, n_octaves=7, sampling_frequency=None):
    """Continuous wavelet transform of a single-channel signal.

    Scales are chosen so that the analysed frequencies are log-spaced:
    ``2 ** k`` Hz for ``k`` in ``[0, n_octaves)`` in steps of ``1 / density``.

    Parameters
    ----------
    epoch : array-like, 1D
        Signal samples of one channel.
    mwt : str
        Name of a pywt continuous wavelet (default: Mexican hat).
    density : int
        Number of analysed frequencies per octave.
    n_octaves : int
        Number of octaves above 1 Hz to cover.
    sampling_frequency : float, optional
        Sampling rate in Hz; defaults to the notebook-level ``signal_frequency``.

    Returns
    -------
    numpy.ndarray
        CWT coefficients, FREQUENCY x TIMEPOINT.
    """
    if sampling_frequency is None:
        sampling_frequency = signal_frequency

    # scale 1 of the wavelet corresponds to this frequency (in cycles/sample);
    # multiplying by the sampling rate converts it to Hz
    center_wavelet_frequency = pywt.scale2frequency(mwt, [1])[0]
    const = center_wavelet_frequency * sampling_frequency

    # construct scales so that scale i analyses exactly freqs[i] Hz
    freqs = 2 ** np.arange(n_octaves, step=1 / density)
    scales = const / freqs

    # pywt also returns the frequency of each scale; it equals `freqs` by
    # construction, so it is discarded
    coef, _ = pywt.cwt(
        data=epoch,
        scales=scales,
        wavelet=mwt,
        sampling_period=1 / sampling_frequency,
    )
    return coef

In [None]:
channels = np.arange(epochs[0][ERROR].shape[1])
participant = 0

# CWT of every channel of every epoch of `participant`, one array per condition
# (error first, then correct); each is 4D:
# EPOCH x CHANNEL x FREQUENCY x TIMEPOINT
err_cwts, cor_cwts = (
    np.array(
        [[cwt(epoch[ch]) for ch in channels] for epoch in epochs[participant][cond]]
    )
    for cond in (ERROR, CORRECT)
)

print(err_cwts.shape, cor_cwts.shape, sep="\n")

In [None]:
# plot all channels for given epoch, and CWT for channel 0 of this epoch (blue one)
max_amp = 0.00003

# `epochs` is indexed as epochs[participant][condition] and is not an mne.Epochs
# object, so it has no `.times` attribute; rebuild the time axis from the sample
# count. NOTE(review): assumes samples start at `tmin` at `signal_frequency` Hz.
n_channels, n_times = epochs[0][ERROR].shape[1:]
times = tmin + np.arange(n_times) / signal_frequency

# the epoch slider must stay within the shortest epoch list of any
# participant/condition, otherwise indexing raises IndexError
min_epoch_count = min(
    len(participant_epochs[cond])
    for participant_epochs in epochs
    for cond in (ERROR, CORRECT)
)

fig = go.FigureWidget(make_subplots(rows=2))
fig.update_layout(**base_layout)
fig.update_layout(
    xaxis_range=[tmin, tmax],
    yaxis_range=[-max_amp, max_amp],
)
# one trace per channel in the top row, CWT of channel 0 in the bottom row
for _ in range(n_channels):
    fig.add_scatter(x=times, row=1, col=1)
fig.add_heatmap(x=times, row=2, col=1, zmin=-100e-6, zmax=100e-6)


@interact(
    # slider bounds are inclusive, so the last valid index is len - 1
    participant=IntSlider(value=1, min=0, max=len(epochs) - 1),
    epoch_num=IntSlider(value=1, min=0, max=min_epoch_count - 1),
    condition=Dropdown(options=["error", "correct"]),
)
def update_plots(participant, epoch_num, condition):
    """Show every channel of the selected epoch and the CWT of its channel 0."""
    cond_index = CORRECT if condition == "correct" else ERROR
    epoch = epochs[participant][cond_index][epoch_num]
    with fig.batch_update():
        for ch in range(n_channels):
            fig.data[ch].y = epoch[ch]
        fig.data[-1].z = cwt(epoch[0])


fig

In [None]:
# plot average ERP difference between conditions, and its CWT
channel = 0

# `cor_train` / `err_train` are not defined anywhere else in this notebook
# (NameError on a fresh kernel). Rebuild them by pooling every epoch of each
# condition across all participants of the loaded train set.
# NOTE(review): confirm pooling across participants is the intended average.
cor_train = np.concatenate([participant_epochs[CORRECT] for participant_epochs in epochs])
err_train = np.concatenate([participant_epochs[ERROR] for participant_epochs in epochs])
ERP_diff = cor_train.mean(axis=0) - err_train.mean(axis=0)

# `epochs` has no `.times` attribute; rebuild the time axis from the sample count
# (assumes samples start at `tmin` at `signal_frequency` Hz)
times = tmin + np.arange(ERP_diff.shape[-1]) / signal_frequency

fig = go.FigureWidget(make_subplots(rows=2))
fig.update_layout(**base_layout)
fig.update_layout(
    xaxis_range=[tmin, tmax],
    yaxis_range=[-max_amp, max_amp],
)
fig.add_scatter(y=ERP_diff[channel], x=times, row=1, col=1)
fig.add_heatmap(
    z=cwt(ERP_diff[channel]), x=times, row=2, col=1, zmin=-50e-6, zmax=50e-6
)
fig

In [None]:
# plot epoch and wavelet you can adjust
# fig2 = go.FigureWidget(layout=base_layout)
# fig2.update_layout(
#     xaxis_range=[tmin, tmax],
#     yaxis_range=[-max_amp, max_amp],
# )
# for i in range(2):
#     fig2.add_scatter(x=epochs.times)

# channel = 0


# @interact(
#     epoch_num=IntSlider(value=1, min=0, max=min_cond_count - 1),
#     condition=Dropdown(options=conditions),
#     latency=FloatSlider(value=0, min=tmin, max=tmax, step=0.005),
#     frequency=FloatSlider(value=3, min=0.1, max=40),
# )
# def update_plots(epoch_num, condition, latency, frequency):
#     with fig2.batch_update():
#         epoch = epochs[condition]._data[epoch_num]
#         fig2.data[0].y = epoch[channel]
#         fig2.data[1].y = get_wavelet(latency, frequency, epochs.times) * max_amp


# fig2

In [None]:
def get_separation(cond1, cond2):
    """Per-feature Fisher ratio (between-class / within-class scatter) of two classes.

    Samples run along axis 0 of each input; every remaining axis is treated as an
    independent feature and kept in the output.

    Parameters
    ----------
    cond1, cond2 : array-like
        Samples of the two classes; identical shapes except along axis 0.

    Returns
    -------
    numpy.ndarray
        ``S_B / S_W`` per feature. 0 means the class means coincide; larger
        values mean better separation.
    """
    cond1 = np.asarray(cond1)
    cond2 = np.asarray(cond2)
    n1, n2 = len(cond1), len(cond2)

    # S_W: sum of squared deviations of each sample from its own class mean
    within_class_scatter = cond1.var(axis=0) * n1 + cond2.var(axis=0) * n2

    # S_B: class-size-weighted squared distance of each class mean from the grand
    # mean. (The pooled data's total scatter equals S_W + S_B, so dividing the
    # total scatter by S_W would give 1 + S_B / S_W instead of the Fisher ratio.)
    grand_mean = (cond1.sum(axis=0) + cond2.sum(axis=0)) / (n1 + n2)
    between_class_scatter = (
        n1 * (cond1.mean(axis=0) - grand_mean) ** 2
        + n2 * (cond2.mean(axis=0) - grand_mean) ** 2
    )
    return between_class_scatter / within_class_scatter


def get_separation_filtered(cor_cwts, err_cwts, spatial_filter):
    """Collapse channels with a spatial filter, then score separation per feature.

    ``spatial_filter`` holds one weight per channel. It is contracted with axis 1
    of both inputs, so EPOCH x CHANNEL x FREQUENCY x TIMEPOINT arrays become
    EPOCH x FREQUENCY x TIMEPOINT. The result is then passed to ``get_separation``
    with the correct trials as the first class.
    """
    cor_projected, err_projected = (
        np.tensordot(cwts, spatial_filter, axes=([1], [0]))
        for cwts in (cor_cwts, err_cwts)
    )
    return get_separation(cor_projected, err_projected)


# one-hot spatial filter that keeps only channel 0 (equals [1, 0, 0] for three
# channels, but follows the actual channel count of the CWT arrays)
spatial_filter = np.eye(cor_cwts.shape[1])[0]
separation = get_separation_filtered(cor_cwts, err_cwts, spatial_filter)

print(np.max(separation))

# `epochs` has no `.times` attribute; rebuild the time axis from the sample count
# (assumes samples start at `tmin` at `signal_frequency` Hz)
times = tmin + np.arange(separation.shape[-1]) / signal_frequency

fig = go.FigureWidget(layout=base_layout)
fig.add_heatmap(z=separation, x=times)
fig

In [None]:
# for i in range(100):
#     spatial_filter = np.random.randn(3)
#     separation = get_separation_filtered(cor_cwts, err_cwts, spatial_filter)
#     print(np.max(separation), spatial_filter)

In [None]:
# Scatter two hand-picked features of the spatially filtered CWTs against each
# other, one point per epoch, to see by eye how well they separate the conditions.
# NOTE(review): the indices were presumably picked from the separation heatmap -- confirm.
X_FEATURE = (8, 15)  # (frequency index, timepoint index)
Y_FEATURE = (7, 98)  # (frequency index, timepoint index)

# collapse the channel axis: EPOCH x FREQUENCY x TIMEPOINT
cor_cwts_filtered = np.tensordot(cor_cwts, spatial_filter, axes=([1], [0]))
err_cwts_filtered = np.tensordot(err_cwts, spatial_filter, axes=([1], [0]))

fig = go.FigureWidget(layout=base_layout)
fig.update_layout(
    width=500,
    height=500,
    xaxis_title=f"CWT coef at (freq, time) index {X_FEATURE}",
    yaxis_title=f"CWT coef at (freq, time) index {Y_FEATURE}",
)
fig.add_scatter(
    x=cor_cwts_filtered[:, X_FEATURE[0], X_FEATURE[1]],
    y=cor_cwts_filtered[:, Y_FEATURE[0], Y_FEATURE[1]],
    mode="markers",
    name="correct",
)
fig.add_scatter(
    x=err_cwts_filtered[:, X_FEATURE[0], X_FEATURE[1]],
    y=err_cwts_filtered[:, Y_FEATURE[0], Y_FEATURE[1]],
    mode="markers",
    name="error",
)
fig