In [None]:
%load_ext lab_black
import mne
import os
import pywt
import numpy as np
import pickle
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from sklearn.model_selection import train_test_split
from sklearn.decomposition import FastICA
from scipy import signal

from utils import base_layout, get_wavelet, load_all_epochs, load_epochs_from_file

# sampling rate of the recordings; cwt() uses 1 / signal_frequency as the sampling period
signal_frequency = 256
# indices of the two conditions inside each participant's entry: epochs[p][ERROR], epochs[p][CORRECT]
ERROR, CORRECT = 0, 1

In [None]:
def get_frequencies(density=3, n_octaves=7):
    """Return log-spaced analysis frequencies for the CWT.

    The frequencies are ``2 ** k`` for exponents ``k`` from 0 (inclusive) up to
    ``n_octaves`` (exclusive) in steps of ``1 / density``. That gives ``density``
    frequencies per octave (per doubling), starting at 1.

    Args:
        density: number of frequencies per octave.
        n_octaves: number of octaves to span. The default of 7 keeps the original
            range: 1 up to just below 2 ** 7.

    Returns:
        1D float array of frequencies in ascending order.
    """
    return 2 ** (np.arange(n_octaves, step=1 / density))


def cwt(epoch, mwt="mexh"):
    """Continuous wavelet transform of a single 1D signal.

    Scales are chosen so that the wavelet's centre frequency lands on each value
    of ``get_frequencies()``, given the sampling rate ``signal_frequency``.

    Args:
        epoch: 1D signal (one channel of one epoch).
        mwt: PyWavelets continuous wavelet name, e.g. "mexh" or "cmor0.5-1".

    Returns:
        Coefficient array with one row per frequency in ``get_frequencies()`` and
        one column per sample. For complex wavelets (e.g. complex Morlet) the
        magnitude ``|coef|`` is returned, so the result is always real.
    """
    center_wavelet_frequency = pywt.scale2frequency(mwt, [1])[0]
    const = center_wavelet_frequency * signal_frequency

    # construct scales: scale = centre frequency * sampling rate / target frequency
    scales = const / get_frequencies()

    # compute coeffs (the returned frequencies are not needed)
    coef, _freqs = pywt.cwt(
        data=epoch, scales=scales, wavelet=mwt, sampling_period=1 / signal_frequency
    )
    if np.iscomplexobj(coef):
        # complex wavelets (cmor, cgau, ...) yield complex coefficients;
        # keep their magnitude so callers always get a real array
        coef = np.abs(coef)
    return coef


def get_separations(cond1, cond2):
    """Per-feature ratio of total scatter to within-class scatter.

    Both inputs are arrays with observations along axis 0 and matching trailing
    shapes; the result has that trailing shape.

    ``var(axis=0) * n`` is the sum of squared deviations from the mean. The
    numerator is therefore the total scatter of the pooled data, which equals
    within-class plus between-class scatter. The ratio is ``1 + between / within``:
    1 means no separation, and larger values mean the conditions separate better.
    """
    # TODO think if within_class equation is OK or should conditions be rescaled
    scatter_1 = np.var(cond1, axis=0) * cond1.shape[0]
    scatter_2 = np.var(cond2, axis=0) * cond2.shape[0]
    within_scatter = scatter_1 + scatter_2

    pooled = np.concatenate((cond1, cond2), axis=0)
    total_scatter = np.var(pooled, axis=0) * pooled.shape[0]
    return total_scatter / within_scatter


def filter_(data, spatial_filter):
    """Apply a spatial filter by contracting ``data`` over its axis 1.

    ``spatial_filter`` holds one weight per entry of that axis (one per channel
    for EPOCH x CHANNEL x ... arrays). The output keeps axis 0 followed by the
    remaining trailing axes of ``data``.
    """
    channel_axis = 1
    return np.tensordot(data, spatial_filter, axes=([channel_axis], [0]))


def get_best_separation(cond1, cond2, spatial_filter):
    """Spatially filter both conditions and locate the most separating feature.

    Returns:
        ``(best_index, separations)``. ``separations`` is ``get_separations``
        evaluated on the filtered data, and ``best_index`` is the tuple index of
        its maximum (the first occurrence on ties).
    """
    separations = get_separations(
        filter_(cond1, spatial_filter), filter_(cond2, spatial_filter)
    )
    flat_argmax = np.argmax(separations)
    return np.unravel_index(flat_argmax, separations.shape), separations

# Load data

In [None]:
# load epochs (only a train set)
# epochs is a list with one entry per participant. Each entry is a pair
# [error_epochs, correct_epochs], and each of those is an
# EPOCH x CHANNEL x TIMEPOINT array (epoch counts differ between participants).

pickled_data = "../data/train_epochs.p"
if os.path.isfile(pickled_data):
    # NOTE: unpickling can execute arbitrary code; only load this self-written cache
    with open(pickled_data, "rb") as cache_file:
        epochs = pickle.load(cache_file)
else:
    epochs = load_all_epochs()
    # pickle data loaded by MNE to save on loading times later;
    # `with` guarantees the cache file is flushed and closed
    with open(pickled_data, "wb") as cache_file:
        pickle.dump(epochs, cache_file)

# sort participants by the number of errors, descending
# this way the best participants are first
epochs.sort(reverse=True, key=lambda e: len(e[ERROR]))

print("participants\t", len(epochs))
print("p0 error\t", epochs[0][ERROR].shape)
print("p0 correct\t", epochs[0][CORRECT].shape)

In [None]:
# metadata (time axis, channel names and positions) read from a single recording;
# presumably all recordings share this montage and epoch timing — confirm
_mne_epochs = load_epochs_from_file("../data/responses/GNG_AA0303-64 el.vhdr")
times = _mne_epochs.times

_channel_info = _mne_epochs.info["chs"]
channel_names = [ch_desc["ch_name"] for ch_desc in _channel_info]
channel_locations = np.array([ch_desc["loc"][:3] for ch_desc in _channel_info])

# colour each channel by its position: shift and scale x/y/z per axis into 0..255
channel_colors = channel_locations - channel_locations.min(axis=0)
channel_colors = channel_colors / channel_colors.max(axis=0)
channel_colors = np.floor(channel_colors * 255)
channel_colors = ["rgb({:.0f},{:.0f},{:.0f})".format(*rgb) for rgb in channel_colors]

# y axis for the CWT heatmaps: log2 of the analysis frequencies
log_freq = np.log2(get_frequencies())

# Explore data

In [None]:
# display electrode locations in 3D, labelled with names and coloured by position
x, y, z = channel_locations.T
scalp3d = go.FigureWidget(layout=base_layout)
scalp3d.update_layout(width=700, height=700)
scalp3d.add_scatter3d(
    x=x,
    y=y,
    z=z,
    text=channel_names,
    mode="markers+text",
    marker=dict(color=channel_colors, size=3),
)

In [None]:
# those sliders are shared across plots, so every interactive view stays in sync
# max is the last valid participant index (epochs[len(epochs)] raises IndexError)
participant_slider = IntSlider(min=0, max=len(epochs) - 1)
channel_slider = Dropdown(value="Cz", options=channel_names)
# channel_slider = IntSlider(value=47, min=0, max=len(channel_names) - 1)

In [None]:
print("plot all channels for a given epoch, and CWT for a chosen channel of this epoch")
max_amp = 0.00005  # half-range of the raw-signal y axis; later plotting cells reuse it

fig = go.FigureWidget(
    make_subplots(
        rows=3,
        vertical_spacing=0.1,
        subplot_titles=("all channels, single epoch", "complex CWT", "real CWT"),
    )
)
fig.update_layout(**base_layout)
fig.update_layout(
    xaxis_range=[times[0], times[-1]],
    yaxis_range=[-max_amp, max_amp],
    #     height=840,
)
# traces: one line per channel (indices 0..n_channels-1), then the two CWT heatmaps
for _ in channel_names:
    fig.add_scatter(x=times, row=1, col=1)
fig.add_heatmap(x=times, row=2, col=1, zmin=0, zmax=40e-6, y=log_freq, colorscale="ice")
fig.add_heatmap(
    x=times, row=3, col=1, zmin=-100e-6, zmax=100e-6, y=log_freq, colorbar_x=1.1
)


@interact(
    participant=participant_slider,
    epoch_num=IntSlider(min=0, max=7),
    channel=channel_slider,
    condition=Dropdown(options=["error", "correct"]),
)
def update_epoch_plots(participant, epoch_num, channel, condition):
    """Show every channel of one epoch, highlight `channel`, and plot its two CWTs.

    The epoch slider has a fixed range, but participants have different numbers
    of epochs. An out-of-range choice is reported instead of raising IndexError.
    """
    channel = channel_names.index(channel)
    cond_index = CORRECT if condition == "correct" else ERROR
    cond_epochs = epochs[participant][cond_index]
    if epoch_num >= len(cond_epochs):
        print(f"participant {participant} has only {len(cond_epochs)} {condition} epochs")
        return
    epoch = cond_epochs[epoch_num]
    with fig.batch_update():
        for ch in range(len(channel_names)):
            fig.data[ch].y = epoch[ch]
            # highlight the selected channel with a thicker line
            width = 3 if ch == channel else 0.3
            fig.data[ch].line = {"width": width, "color": channel_colors[ch]}
        # complex Morlet: cwt() returns its magnitude; Mexican hat: real coefficients
        fig.data[-2].z = cwt(epoch[channel], "cmor0.5-1")
        fig.data[-1].z = cwt(epoch[channel], "mexh")
        print(channel_names[channel])


fig

In [None]:
print("plot average ERP difference between conditions, and its CWT")
fig2 = go.FigureWidget(make_subplots(rows=2))
fig2.update_layout(**base_layout)
# max_amp comes from the single-epoch plotting cell above
fig2.update_layout(
    xaxis_range=[times[0], times[-1]],
    yaxis_range=[-max_amp / 2, max_amp / 2],
)
for _ in channel_names:
    fig2.add_scatter(x=times, row=1, col=1)
# y=log_freq labels heatmap rows with log2(frequency), like the other CWT heatmaps
fig2.add_heatmap(x=times, row=2, col=1, zmin=-50e-6, zmax=50e-6, y=log_freq)


@interact(
    participant=participant_slider,
    channel=channel_slider,
)
def update_erp_plots(participant, channel):
    """Plot the correct-minus-error mean ERP of every channel and the CWT of `channel`."""
    channel = channel_names.index(channel)
    err, cor = epochs[participant]
    ERP_diff = cor.mean(axis=0) - err.mean(axis=0)
    with fig2.batch_update():
        for ch in range(len(channel_names)):
            fig2.data[ch].y = ERP_diff[ch]
            # highlight the selected channel with a thicker line
            width = 3 if ch == channel else 0.3
            fig2.data[ch].line = {"width": width, "color": channel_colors[ch]}
        fig2.data[-1].z = cwt(ERP_diff[channel])
        print(channel_names[channel])


fig2

In [None]:
print(
    "all epochs, for a chosen participant and channel, green are correct, red are errors"
)
fig3 = go.FigureWidget()
fig3.update_layout(**base_layout)
fig3.update_layout(
    xaxis_range=[times[0], times[-1]],
    yaxis_range=[-max_amp, max_amp],
    height=300,
)
# fixed pool of line traces: error epochs fill it from the front and correct epochs
# from the back, so it must hold the error + correct epochs of any participant
n_trace_slots = 400
for _ in range(n_trace_slots):
    fig3.add_scatter(x=times)


@interact(
    participant=participant_slider,
    channel=channel_slider,
)
def update_epoch_overlay(participant, channel):
    """Overlay all epochs of `participant` at `channel`: correct green, errors red."""
    channel = channel_names.index(channel)
    err, cor = epochs[participant]
    if len(err) + len(cor) > len(fig3.data):
        raise ValueError(
            f"{len(err) + len(cor)} epochs but only {len(fig3.data)} traces; "
            "increase n_trace_slots"
        )
    with fig3.batch_update():
        fig3.update_traces(visible=False)
    with fig3.batch_update():
        for i, epoch in enumerate(cor):
            # count from the end: -(i + 1) is the last trace for i == 0
            # (plain -i put the first correct epoch in data[0], where the first
            # error epoch then overwrote it)
            trace = fig3.data[-(i + 1)]
            trace.y = epoch[channel]
            trace.line = {"color": "green", "width": 0.2}
            trace.visible = True
        for i, epoch in enumerate(err):
            trace = fig3.data[i]
            trace.y = epoch[channel]
            trace.line = {"color": "red", "width": 0.2}
            trace.visible = True
        print(channel_names[channel])


fig3

# Extract features

In [None]:
# CWT features for the participant currently selected in participant_slider
participant = participant_slider.value
print(f"participant: {participant}")

mwt = "mexh"
# bandwidth = 0.5
# mwt = f"cmor{bandwidth}-1"

err, cor = epochs[participant]
# transform every channel of every epoch; each result is a 4D numpy array
# EPOCH x CHANNEL x FREQUENCY x TIMEPOINT
err_cwts, cor_cwts = (
    np.array([[cwt(trace, mwt) for trace in single_epoch] for single_epoch in cond])
    for cond in (err, cor)
)
print(err_cwts.shape)
print(cor_cwts.shape)

# hold out 40% of each condition as a test set (seeded, so the split is reproducible)
err_cwts, err_cwts_test = train_test_split(err_cwts, test_size=0.4, random_state=0)
cor_cwts, cor_cwts_test = train_test_split(cor_cwts, test_size=0.4, random_state=0)

In [None]:
# def reduce_over_timeslices(data, slice_size=30, ufunc=np.maximum):
#     indexes = np.arange(len(times) - slice_size)
#     slice_indexes = np.stack((indexes, indexes + slice_size), axis=-1).flatten()
#     return ufunc.reduceat(data, slice_indexes, axis=-1)[:, :, :, ::2]


# cor_cwts = reduce_over_timeslices(cor_cwts)
# err_cwts = reduce_over_timeslices(err_cwts)
# cor_cwts_test = reduce_over_timeslices(cor_cwts_test)
# err_cwts_test = reduce_over_timeslices(err_cwts_test)

In [None]:
# for each channel alone (a one-hot spatial filter), record the best separation its
# CWT achieves between error and correct train epochs, then show it on the scalp
sep_for_channels = []
for ch_index in range(len(channel_names)):
    one_hot = np.zeros(len(channel_names))
    one_hot[ch_index] = 1
    _, separations = get_best_separation(err_cwts, cor_cwts, one_hot)
    sep_for_channels.append(separations.max())

# 2D electrode positions are taken straight from channel_locations rather than from
# the x, y left over by a plotting cell; `;` hides the returned tuple's repr
mne.viz.plot_topomap(sep_for_channels, channel_locations[:, :2]);

In [None]:
# compute ICA spatial components from this participant's raw (pre-CWT) epochs
# NOTE(review): err/cor still contain the epochs later held out as the CWT test set,
# so the unmixing is also fit on test data — consider fitting on the train split only
joined_epochs = np.concatenate((err, cor))

# EPOCH x CHANNEL x TIMEPOINT -> CHANNEL x (EPOCH * TIMEPOINT): epochs laid end to end
concat = np.concatenate(joined_epochs, axis=1)
print(concat.shape)

# fixed random_state so the components (and everything ranked by them) are reproducible
ica = FastICA(n_components=4, random_state=0)
ica.fit(concat.T)  # samples x features, i.e. timepoints x channels
print(ica.components_.shape)

In [None]:
# find which ICA components separate best, and sort them by separation, descending
_components = []
for comp in ica.components_:
    _, separations = get_best_separation(err_cwts, cor_cwts, comp)
    _components.append((separations.max(), comp))

# sort on the score only: on tied scores, plain tuple comparison would fall through
# to comparing the ndarrays, which raises "truth value of an array is ambiguous"
_components.sort(reverse=True, key=lambda scored: scored[0])
print([separation for separation, comp in _components])
components = [comp for separation, comp in _components]

In [None]:
# mne.viz.plot_topomap's interpolation methods all showed strange artifacts here,
# so draw the best ICA component as coloured electrode markers instead
x, y, z = channel_locations.T
scalp = go.FigureWidget(layout=base_layout)
scalp.update_layout(width=400, height=400)
scalp.add_scatter(
    x=x,
    y=y,
    text=channel_names,
    mode="markers+text",
    # weights are negated so that positive weights show up red in RdBu
    marker=dict(size=30, color=-components[0], colorscale="RdBu"),
)

In [None]:
# show separation for a chosen spatial filter
# single-electrode filter for comparison; commented lists are alternative electrode sets
cz_spatial_filter = np.zeros(len(channel_names))
for ch_name in ["Cz"]:
    # for ch_name in ["Cz", "CPz", "FCz", "C1", "CP1", "FC1", "CP3", "C3", "FC3"]:
    # for ch_name in ["Cz", "FCz", "C1", "FC1"]:
    ch_index = channel_names.index(ch_name)
    cz_spatial_filter[ch_index] = 1

#########################################
# choose the spatial filter: the electrode filter above, or the best ICA component
# spatial_filter = cz_spatial_filter
spatial_filter = components[0]

# pick the (frequency, time) feature that separates best on the train set
index, separations_train = get_best_separation(err_cwts, cor_cwts, spatial_filter)
print("best index found", index)
print("separation on train set\t", separations_train.max())

# evaluate the same spatial filter at the same (frequency, time) index on the test set
_, separations_test = get_best_separation(err_cwts_test, cor_cwts_test, spatial_filter)
print("separation on test set\t", separations_test[index])

fig4 = go.FigureWidget(make_subplots(rows=2, subplot_titles=("train", "test")))
fig4.update_layout(**base_layout)
# zmin=1: a separation ratio of 1 means the conditions do not differ
fig4.add_heatmap(z=separations_train, x=times, row=1, col=1, zmin=1, zmax=2, y=log_freq)
fig4.add_heatmap(z=separations_test, x=times, row=2, col=1, zmin=1, zmax=2, y=log_freq)

In [None]:
# train-set value of the selected (frequency, time) feature for every epoch
err_end = filter_(err_cwts, spatial_filter)[:, index[0], index[1]]
cor_end = filter_(cor_cwts, spatial_filter)[:, index[0], index[1]]

# decision threshold: midpoint between the two class means; centre both sets on it
threshold = (err_end.mean() + cor_end.mean()) / 2
err_end = err_end - threshold
cor_end = cor_end - threshold

fig5 = go.FigureWidget(layout=base_layout)
fig5.add_scatter(x=err_end, mode="markers", marker=dict(color="red"))
fig5.add_scatter(x=cor_end, mode="markers", marker=dict(color="green"))
fig5

In [None]:
# final test: held-out epochs, centred on the threshold learned from the train set
err_end = filter_(err_cwts_test, spatial_filter)[:, index[0], index[1]] - threshold
cor_end = filter_(cor_cwts_test, spatial_filter)[:, index[0], index[1]] - threshold

fig6 = go.FigureWidget(layout=base_layout)
fig6.add_scatter(x=err_end, mode="markers", marker=dict(color="red"))
fig6.add_scatter(x=cor_end, mode="markers", marker=dict(color="green"))
fig6