In [None]:
%load_ext lab_black
import mne
import pywt
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from sklearn.model_selection import train_test_split
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from scipy import signal

from utils import base_layout, load_gonogo_responses, tmax, tmin, get_wavelet

In [None]:
epochs = load_gonogo_responses()

# Sampling rate in Hz, used by `cwt` to convert wavelet scales into frequencies.
# Read it from the data rather than hard-coding it, so a resampled dataset
# cannot silently shift the CWT frequency axis.
signal_frequency = epochs.info["sfreq"]

# Raw per-class arrays (MNE layout: [epoch, channel, time]).
cor_all = epochs["correct_response"]._data
err_all = epochs["error_response"]._data

# Hold out a third of each class; the fixed seed keeps the split reproducible.
cor_train, cor_test = train_test_split(cor_all, test_size=0.33, random_state=0)
err_train, err_test = train_test_split(err_all, test_size=0.33, random_state=0)

In [None]:
# Shapes of: all epochs, each class, then each class's train/test split.
for arr in (
    epochs._data,
    epochs["correct_response"]._data,
    epochs["error_response"]._data,
    cor_train,
    cor_test,
    err_train,
    err_test,
):
    print(arr.shape)

In [None]:
def cwt(epoch, mwt="mexh", density=10, n_octaves=7, sampling_frequency=None):
    """Continuous wavelet transform of one 1-D signal on a log-spaced frequency grid.

    The grid starts at 1 Hz and grows by a factor of 2 ** (1 / density) per row,
    spanning ``n_octaves`` octaves (``density * n_octaves`` rows; the defaults give
    70 rows from 1 Hz up to 2 ** 6.9 ~= 119 Hz).

    Parameters
    ----------
    epoch : 1-D array-like
        Signal samples.
    mwt : str
        PyWavelets continuous-wavelet name (default ``"mexh"``).
    density : int
        Number of frequencies per octave.
    n_octaves : int
        Number of octaves covered, starting at 1 Hz.
    sampling_frequency : float, optional
        Sampling rate in Hz; defaults to the notebook-level ``signal_frequency``.

    Returns
    -------
    numpy.ndarray
        Coefficients of shape ``(density * n_octaves, len(epoch))``; row ``i``
        corresponds to frequency ``2 ** (i / density)`` Hz.
    """
    if sampling_frequency is None:
        sampling_frequency = signal_frequency

    # Centre frequency of the wavelet (cycles per sample) at scale 1. A scale s then
    # maps to center * fs / s Hz, hence s = center * fs / f for a target frequency f.
    center_wavelet_frequency = pywt.scale2frequency(mwt, [1])[0]
    freqs = 2 ** (np.arange(n_octaves * density) / density)
    scales = center_wavelet_frequency * sampling_frequency / freqs

    # pywt also returns the frequencies it derived; our grid already defines them.
    coef, _ = pywt.cwt(
        data=epoch,
        scales=scales,
        wavelet=mwt,
        sampling_period=1 / sampling_frequency,
    )
    return coef

In [None]:
channels = np.arange(epochs._data.shape[1])

# One CWT per (channel, training epoch) for each class.
# Resulting arrays are laid out as [channel, epoch, frequency, time].
cor_cwts_allch, err_cwts_allch = (
    np.array([[cwt(trial[ch]) for trial in train_set] for ch in channels])
    for train_set in (cor_train, err_train)
)

print(cor_cwts_allch.shape)
print(err_cwts_allch.shape)

In [None]:
cor_cwts_allch[:, 0].shape  # [channel, frequency, time] for the first training epoch

In [None]:
np.shape(cor_train)  # [epoch, channel, time] of the correct-response training set

In [None]:
# Interactive view of one epoch: the first `n_plot_channels` channels (top panel)
# and the CWT of channel 0, i.e. the first/blue trace (bottom panel).
max_amp = 0.00003  # y-range of the raw-signal panel (MNE stores EEG in volts)
n_plot_channels = 3
conditions = list(epochs.event_id.keys())
# NOTE(review): the slider is bounded by the smaller *training* set, while the
# callback reads from the full epochs of the chosen condition.
min_cond_count = min(cor_train.shape[0], err_train.shape[0])

# Dedicated name on purpose: later cells rebind `fig`, and the @interact callback
# resolves its figure at call time, so sharing the name would make the sliders
# write into whichever figure was created last.
epoch_fig = go.FigureWidget(make_subplots(rows=2))
epoch_fig.update_layout(**base_layout)
epoch_fig.update_layout(
    xaxis_range=[tmin, tmax],
    yaxis_range=[-max_amp, max_amp],
)
for _ in range(n_plot_channels):
    epoch_fig.add_scatter(x=epochs.times, row=1, col=1)
# TODO set axis range to -200u:200u
epoch_fig.add_heatmap(x=epochs.times, row=2, col=1)


@interact(
    epoch_num=IntSlider(value=1, min=0, max=min_cond_count - 1),
    condition=Dropdown(options=conditions),
)
def update_plots(epoch_num, condition):
    """Redraw the raw traces and the channel-0 CWT for the selected epoch."""
    epoch = epochs[condition]._data[epoch_num]
    with epoch_fig.batch_update():
        for trace, channel_signal in zip(epoch_fig.data[:n_plot_channels], epoch):
            trace.y = channel_signal
        # TODO use precomputed CWT (note: those cover only the training split,
        # so indices would not line up with epochs[condition] directly)
        epoch_fig.data[-1].z = cwt(epoch[0])


epoch_fig

In [None]:
# (Disabled) Plot one channel of an epoch together with a wavelet whose latency and frequency can be adjusted interactively.
# fig2 = go.FigureWidget(layout=base_layout)
# fig2.update_layout(
#     xaxis_range=[tmin, tmax],
#     yaxis_range=[-max_amp, max_amp],
# )
# for i in range(2):
#     fig2.add_scatter(x=epochs.times)

# channel = 0


# @interact(
#     epoch_num=IntSlider(value=1, min=0, max=min_cond_count - 1),
#     condition=Dropdown(options=conditions),
#     latency=FloatSlider(value=0, min=tmin, max=tmax, step=0.005),
#     frequency=FloatSlider(value=3, min=0.1, max=40),
# )
# def update_plots(epoch_num, condition, latency, frequency):
#     with fig2.batch_update():
#         epoch = epochs[condition]._data[epoch_num]
#         fig2.data[0].y = epoch[channel]
#         fig2.data[1].y = get_wavelet(latency, frequency, epochs.times) * max_amp


# fig2

In [None]:
# Collapse the channel axis with a linear spatial filter (one weight per channel;
# [1, 0, 0] simply selects channel 0). Result: [epoch, frequency, time].
spatial_filter = [1, 0, 0]
cor_cwts = np.tensordot(cor_cwts_allch, spatial_filter, axes=([0], [0]))
err_cwts = np.tensordot(err_cwts_allch, spatial_filter, axes=([0], [0]))

# Fisher-style class separation for every (frequency, time) cell:
#   between-class scatter / within-class scatter.
n_cor, n_err = len(cor_cwts), len(err_cwts)
within_class_scatter = cor_cwts.var(axis=0) * n_cor + err_cwts.var(axis=0) * n_err

all_cwts = np.append(cor_cwts, err_cwts, axis=0)
cor_mean = cor_cwts.mean(axis=0)
err_mean = err_cwts.mean(axis=0)
grand_mean = all_cwts.mean(axis=0)
# Note: all_cwts.var(axis=0) * len(all_cwts) is the *total* scatter
# (= within + between), so it cannot serve as the between-class term.
# np.abs keeps this correct for complex-valued CWTs as well.
between_class_scatter = (
    n_cor * np.abs(cor_mean - grand_mean) ** 2
    + n_err * np.abs(err_mean - grand_mean) ** 2
)
separation = between_class_scatter / within_class_scatter

In [None]:
# Heatmap of the class-separation score from the previous cell; rows are CWT
# frequency indices, columns are epoch time points. Own figure name so the
# interactive epoch plot's shared `fig` is not rebound.
separation_fig = go.FigureWidget(layout=base_layout)
separation_fig.update_layout(xaxis_title="time [s]", yaxis_title="CWT frequency index")
separation_fig.add_heatmap(z=separation, x=epochs.times)

In [None]:
# Correct vs. error trials in a 2-D feature space made of two CWT cells,
# each given as (frequency index, time index).
# NOTE(review): indices look hand-picked from the separation heatmap — confirm.
x_cell = (8, 15)
y_cell = (7, 98)

scatter_fig = go.FigureWidget(layout=base_layout)
scatter_fig.update_layout(
    width=500,
    height=500,
    xaxis_title=f"CWT[f={x_cell[0]}, t={x_cell[1]}]",
    yaxis_title=f"CWT[f={y_cell[0]}, t={y_cell[1]}]",
)
scatter_fig.add_scatter(
    x=cor_cwts[:, x_cell[0], x_cell[1]],
    y=cor_cwts[:, y_cell[0], y_cell[1]],
    mode="markers",
    name="correct",
)
scatter_fig.add_scatter(
    x=err_cwts[:, x_cell[0], x_cell[1]],
    y=err_cwts[:, y_cell[0], y_cell[1]],
    mode="markers",
    name="error",
)

In [None]:
# Same comparison with a different pair of CWT cells (frequency index, time index):
# the same frequency row at two latencies.
# NOTE(review): indices look hand-picked from the separation heatmap — confirm.
x_cell = (8, 15)
y_cell = (8, 188)

scatter_fig = go.FigureWidget(layout=base_layout)
scatter_fig.update_layout(
    width=500,
    height=500,
    xaxis_title=f"CWT[f={x_cell[0]}, t={x_cell[1]}]",
    yaxis_title=f"CWT[f={y_cell[0]}, t={y_cell[1]}]",
)
scatter_fig.add_scatter(
    x=cor_cwts[:, x_cell[0], x_cell[1]],
    y=cor_cwts[:, y_cell[0], y_cell[1]],
    mode="markers",
    name="correct",
)
scatter_fig.add_scatter(
    x=err_cwts[:, x_cell[0], x_cell[1]],
    y=err_cwts[:, y_cell[0], y_cell[1]],
    mode="markers",
    name="error",
)