In [None]:
%load_ext lab_black
import os
import pickle

import pywt
import mne
import scipy
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from sklearn.model_selection import train_test_split
from sklearn.decomposition import FastICA

from utils import *

In [None]:
# Ignore only the "FastICA did not converge" warnings; every other warning
# (deprecations, numerical problems, ...) stays visible.
# TODO investigate why it doesn't converge
import warnings

from sklearn.exceptions import ConvergenceWarning

warnings.filterwarnings("ignore", category=ConvergenceWarning)

# Load data

#### The data is read into a DataFrame. Each epoch is a single record (row).

In [None]:
# Where the cached DataFrame lives and where the questionnaire data is read from.
df_name = "go_nogo_df"
pickled_data_filename = "../data/" + df_name + ".pkl"
info_filename = "../data/Demographic_Questionnaires_Behavioral_Results_N=163.csv"

# Build the DataFrame from scratch only when no cached pickle exists yet.
if not os.path.isfile(pickled_data_filename):
    print("Pickled file not found. Loading data...")
    epochs = create_df_data(info_filename=info_filename)
    epochs.name = df_name
    # cache the freshly built DataFrame so the next run can load the pickle
    epochs.to_pickle("../data/" + epochs.name + ".pkl")
else:
    print("Pickled file found. Loading pickled data...")
    epochs = pd.read_pickle(pickled_data_filename)

#### Sort participants by the number of errors, descending. This way the participants with the most error epochs (the best ones for this analysis) come first.

In [None]:
# Per-participant counts of error / correct responses, broadcast back onto
# every epoch row of that participant.
def _count_marker(marker_value):
    """Return a callable counting the entries of a column equal to `marker_value`."""
    return lambda column: (column.values == marker_value).sum()


grouped = epochs.groupby("id")
epochs["error_sum"] = grouped[["marker"]].transform(_count_marker(ERROR))
epochs["correct_sum"] = grouped[["marker"]].transform(_count_marker(CORRECT))

# mergesort is stable: rows with an equal error_sum keep their relative order
epochs = epochs.sort_values(by="error_sum", kind="mergesort", ascending=False)

#### Get metadata

In [None]:
# Time axis and channel metadata are read from a single recording file.
_mne_epochs = load_epochs_from_file("../data/responses/GNG_AA0303-64 el.vhdr")
times = _mne_epochs.times

_channel_info = _mne_epochs.info["chs"]
channel_names = [ch["ch_name"] for ch in _channel_info]
# first three entries of each channel's "loc" field
channel_locations = np.array([ch["loc"][:3] for ch in _channel_info])

# Turn every channel location into an RGB string: shift each coordinate axis
# so it starts at 0, scale it to [0, 1], stretch to 0..255 and round down.
_shifted = channel_locations - channel_locations.min(axis=0)
_scaled = _shifted / _shifted.max(axis=0)
channel_colors = [
    f"rgb({r:.0f},{g:.0f},{b:.0f})" for r, g, b in np.floor(_scaled * 255)
]

log_freq = np.log2(get_frequencies())  # for plotting CWT

# Explore data

In [None]:
# # display electrode locations
# x, y, z = channel_locations.T
# scalp3d = go.FigureWidget(layout=base_layout)
# scalp3d.update_layout(width=700, height=700)
# scalp3d.add_scatter3d(
#     x=x,
#     y=y,
#     z=z,
#     mode="markers+text",
#     text=channel_names,
#     marker_size=3,
#     marker_color=channel_colors,
# )

In [None]:
# # those sliders are shared across plots
# participant_slider = Dropdown(options=epochs["id"].unique())
# channel_slider = Dropdown(value="Cz", options=channel_names)

In [None]:
print("plot all channels for a given epoch, and CWT for a chosen channel of this epoch")
max_amp = 0.00005  # y-axis limit of the amplitude plots (reused by later cells)

fig = go.FigureWidget(
    make_subplots(
        rows=3,
        vertical_spacing=0.1,
        subplot_titles=("all channels, single epoch", "complex CWT", "real CWT"),
    )
)
fig.update_layout(**base_layout)
fig.update_layout(
    xaxis_range=[times[0], times[-1]],
    yaxis_range=[-max_amp, max_amp],
    #     height=840,
)
# one (initially empty) line trace per channel; the callback fills in the y-data
for i in range(len(channel_names)):
    fig.add_scatter(x=times, row=1, col=1)
# the two heatmaps are added last, so the callback reaches them as data[-2] / data[-1]
fig.add_heatmap(x=times, row=2, col=1, zmin=0, zmax=40e-6, y=log_freq, colorscale="ice")
fig.add_heatmap(
    x=times, row=3, col=1, zmin=-100e-6, zmax=100e-6, y=log_freq, colorbar_x=1.1
)

# Group by a scalar key (not a one-element list) so that `get_group` accepts a
# plain participant id on every pandas version.
grouped = epochs.groupby("id")


@interact(
    participant=Dropdown(options=epochs["id"].unique()),
    channel=Dropdown(value="Cz", options=channel_names),
    condition=Dropdown(options=[("error", ERROR), ("correct", CORRECT)]),
    epoch_num=IntSlider(min=0, max=7),
)
def update_plots(participant, channel, condition, epoch_num):
    """Redraw `fig` for the selected participant, channel, condition and epoch.

    Args:
        participant: value of the "id" column selecting the participant.
        channel: channel name; it is highlighted and used for both CWT heatmaps.
        condition: marker value (ERROR or CORRECT) selecting the epochs.
        epoch_num: position of the epoch among the participant's epochs of
            that condition.
    """
    channel = channel_names.index(channel)
    df = grouped.get_group(participant)
    condition_epochs = df.loc[df["marker"] == condition, "epoch"]
    # The slider range is fixed, but a participant may have fewer epochs of the
    # chosen condition; skip the redraw instead of raising IndexError.
    if epoch_num >= len(condition_epochs):
        print(f"only {len(condition_epochs)} epochs available for this selection")
        return
    epoch = condition_epochs.iloc[epoch_num]

    with fig.batch_update():
        for ch in range(len(channel_names)):
            fig.data[ch].y = epoch[ch]
            # thick line for the selected channel, thin lines for all others
            width = 3 if ch == channel else 0.3
            fig.data[ch].line = {"width": width, "color": channel_colors[ch]}
        fig.data[-2].z = cwt(epoch[channel], "cmor0.5-1")
        fig.data[-1].z = cwt(epoch[channel], "mexh")
        print(channel_names[channel])


fig

In [None]:
print("plot average ERP difference between conditions, and its CWT")
fig2 = go.FigureWidget(make_subplots(rows=2))
fig2.update_layout(**base_layout)
fig2.update_layout(
    xaxis_range=[times[0], times[-1]],
    yaxis_range=[-max_amp / 2, max_amp / 2],
)
# one (initially empty) line trace per channel, then the CWT heatmap as data[-1]
for i in range(len(channel_names)):
    fig2.add_scatter(x=times, row=1, col=1)
fig2.add_heatmap(x=times, row=2, col=1, zmin=-50e-6, zmax=50e-6)


# Group by a scalar key (not a one-element list) so that `get_group` accepts a
# plain participant id on every pandas version.
grouped = epochs.groupby("id")


@interact(
    participant=Dropdown(options=epochs["id"].unique()),
    channel=Dropdown(value="Cz", options=channel_names),
)
def update_plots(participant, channel):
    """Redraw `fig2` with the correct-minus-error average of one participant.

    Args:
        participant: value of the "id" column selecting the participant.
        channel: channel name; it is highlighted and used for the CWT heatmap.
    """
    channel = channel_names.index(channel)
    df = grouped.get_group(participant)
    err_epochs = df.loc[df["marker"] == ERROR, "epoch"]
    cor_epochs = df.loc[df["marker"] == CORRECT, "epoch"]
    # np.stack raises ValueError on an empty sequence, so bail out early when
    # one of the conditions has no epochs for this participant.
    if err_epochs.empty or cor_epochs.empty:
        print("participant has no epochs in at least one condition")
        return
    err = np.stack(err_epochs.values)
    cor = np.stack(cor_epochs.values)

    # average over epochs (axis 0) within each condition, then take the difference
    ERP_diff = cor.mean(axis=0) - err.mean(axis=0)
    with fig2.batch_update():
        for ch in range(len(channel_names)):
            fig2.data[ch].y = ERP_diff[ch]
            # thick line for the selected channel, thin lines for all others
            width = 3 if ch == channel else 0.3
            fig2.data[ch].line = {"width": width, "color": channel_colors[ch]}
        fig2.data[-1].z = cwt(ERP_diff[channel])
        print(channel_names[channel])


fig2

In [None]:
print(
    "all epochs, for a chosen participant and channel, green are correct, red are errors"
)
fig3 = go.FigureWidget()
fig3.update_layout(**base_layout)
fig3.update_layout(
    xaxis_range=[times[0], times[-1]],
    yaxis_range=[-max_amp, max_amp],
    height=300,
)
# Pre-allocate a fixed pool of traces that the callback re-uses.
for i in range(400):  # must be more than epochs for any participant
    fig3.add_scatter(x=times)

# Group by a scalar key (not a one-element list) so that `get_group` accepts a
# plain participant id on every pandas version.
grouped = epochs.groupby("id")


@interact(
    participant=Dropdown(options=epochs["id"].unique()),
    channel=Dropdown(value="Cz", options=channel_names),
)
def update_plots(participant, channel):
    """Show every epoch of one participant on one channel in `fig3`.

    Error epochs are drawn in red using traces from the start of the trace
    pool, correct epochs in green using traces from the end of the pool.

    Args:
        participant: value of the "id" column selecting the participant.
        channel: name of the channel to draw.
    """
    channel = channel_names.index(channel)
    df = grouped.get_group(participant)
    err = df.loc[df["marker"] == ERROR, "epoch"].values
    cor = df.loc[df["marker"] == CORRECT, "epoch"].values
    if len(err) + len(cor) > len(fig3.data):
        print("warning: more epochs than traces, some epochs are not shown")

    with fig3.batch_update():
        fig3.update_traces(visible=False)
    with fig3.batch_update():
        # Walk the pool backwards starting at the LAST trace. Indexing with
        # `-i` would map i == 0 onto trace 0, which the first error epoch
        # below overwrites, silently dropping the first correct epoch.
        for trace, epoch in zip(reversed(fig3.data), cor):
            trace.y = epoch[channel]
            trace.line = {"color": "green", "width": 0.2}
            trace.visible = True
        for trace, epoch in zip(fig3.data, err):
            trace.y = epoch[channel]
            trace.line = {"color": "red", "width": 0.2}
            trace.visible = True
        print(channel_names[channel])


fig3

# Extract features - this section is bad

In [None]:
# compute CWT for a chosen participant
# The participant is taken from the data (first id in `epochs`) instead of from
# a widget, so the cell runs on a fresh kernel and gives reproducible results.
participant = epochs["id"].unique()[0]
print(f"participant: {participant}")

mwt = "mexh"
# bandwidth = 0.5
# mwt = f"cmor{bandwidth}-1"

# `epochs` is a DataFrame with one row per epoch: select this participant's
# rows and stack the per-epoch arrays of each condition.
_participant_df = epochs[epochs["id"] == participant]
err = np.stack(_participant_df.loc[_participant_df["marker"] == ERROR, "epoch"].values)
cor = np.stack(
    _participant_df.loc[_participant_df["marker"] == CORRECT, "epoch"].values
)
# split out test sets
err_train, err_test = train_test_split(err, test_size=0.4, random_state=0)
cor_train, cor_test = train_test_split(cor, test_size=0.4, random_state=0)

density = 2


def _cwt_all(epoch_set):
    """Apply `cwt` to every channel signal of every epoch in `epoch_set`."""
    return np.array(
        [[cwt(ch_signal, mwt, density) for ch_signal in epoch] for epoch in epoch_set]
    )


err_cwts = _cwt_all(err_train)
cor_cwts = _cwt_all(cor_train)

err_cwts_test = _cwt_all(err_test)
cor_cwts_test = _cwt_all(cor_test)

# they are 4D numpy arrays:
# EPOCH x CHANNEL x FREQUENCY x TIMEPOINT
print(err_cwts.shape)
print(cor_cwts.shape)

In [None]:
# def reduce_over_timeslices(data, slice_size=30, ufunc=np.maximum):
#     indexes = np.arange(len(times) - slice_size)
#     slice_indexes = np.stack((indexes, indexes + slice_size), axis=-1).flatten()
#     return ufunc.reduceat(data, slice_indexes, axis=-1)[:, :, :, ::2]


# cor_cwts = reduce_over_timeslices(cor_cwts)
# err_cwts = reduce_over_timeslices(err_cwts)
# cor_cwts_test = reduce_over_timeslices(cor_cwts_test)
# err_cwts_test = reduce_over_timeslices(err_cwts_test)

In [None]:
# TODO delete
def get_best_separation(cond1, cond2, spatial_filter):
    """Find where two spatially filtered conditions are separated best.

    Both inputs are contracted with `spatial_filter` along their axis 1 and
    the two results are handed to `get_separations`.

    Returns:
        A tuple `(best_index, separations)`: the unravelled index of the
        maximum of `separations`, and the `separations` array itself.
    """
    filtered = [
        np.tensordot(cond, spatial_filter, axes=([1], [0])) for cond in (cond1, cond2)
    ]
    separations = get_separations(*filtered)

    flat_argmax = separations.argmax()
    return np.unravel_index(flat_argmax, separations.shape), separations

In [None]:
# Score every channel on its own: a one-hot spatial filter (a row of the
# identity matrix) selects a single channel, and the maximum of the resulting
# separation map is recorded as that channel's score.
best_separation = 1
sep_for_channels = []
for spatial_filter in np.eye(len(channel_names)):
    separations = get_best_separation(err_cwts, cor_cwts, spatial_filter)[1]
    best_separation = separations.max()
    sep_for_channels.append(best_separation)

# topographic map of the scores, positioned by the first two location coordinates
x, y = channel_locations.T[:2]
mne.viz.plot_topomap(sep_for_channels, np.stack((x, y), axis=-1))

In [None]:
# show separation for a chosen spatial filter
cz_spatial_filter = np.zeros(len(channel_names))
for ch_name in ["Cz"]:
    # for ch_name in ["Cz", "CPz", "FCz", "C1", "CP1", "FC1", "CP3", "C3", "FC3"]:
    # for ch_name in ["Cz", "FCz", "C1", "FC1"]:
    ch_index = channel_names.index(ch_name)
    cz_spatial_filter[ch_index] = 1

#########################################
# `ica_components` is not defined anywhere in this notebook, so selecting
# `ica_components[0]` raised NameError on a fresh kernel. Use the Cz filter
# built above, which is also what the "test using Cz electrode" step expects.
spatial_filter = cz_spatial_filter
# spatial_filter = ica_components[0]

# best-separating index is chosen on the train set only
index, separations_train = get_best_separation(err_cwts, cor_cwts, spatial_filter)
print("best index found", index)
print("separation on train set\t", separations_train.max())

# test using Cz electrode
# the test set is evaluated at the index found on the train set
_, separations_test = get_best_separation(err_cwts_test, cor_cwts_test, spatial_filter)
print("separation on test set\t", separations_test[index])

# train separations on top, test separations below
fig4 = go.FigureWidget(make_subplots(rows=2))
fig4.update_layout(**base_layout)
fig4.add_heatmap(z=separations_train, x=times, row=1, col=1, zmin=1, zmax=2, y=log_freq)
fig4.add_heatmap(z=separations_test, x=times, row=2, col=1, zmin=1, zmax=2, y=log_freq)

In [None]:
# Contract the training CWTs with the spatial filter (over the channel axis)
# and keep, for every epoch, the single value at the best-separating index.
err_end = np.tensordot(err_cwts, spatial_filter, axes=([1], [0]))[:, index[0], index[1]]
cor_end = np.tensordot(cor_cwts, spatial_filter, axes=([1], [0]))[:, index[0], index[1]]

# Shift both groups by the midpoint between their means, so the threshold sits at 0.
threshold = (err_end.mean() + cor_end.mean()) / 2
err_end -= threshold
cor_end -= threshold

# Strip plot: value on the x-axis, points spread evenly over [0, 1] on the y-axis.
fig5 = go.FigureWidget(layout=base_layout)
fig5.update_layout(height=150)
for values, color in ((err_end, "red"), (cor_end, "green")):
    fig5.add_scatter(
        x=values, y=np.linspace(0, 1, len(values)), mode="markers", marker_color=color
    )
fig5

# TODO use raincloud plots

In [None]:
# Final test: same projection as for the training set, applied to the held-out
# epochs and shifted by the `threshold` computed on the training set.
err_end = np.tensordot(err_cwts_test, spatial_filter, axes=([1], [0]))[
    :, index[0], index[1]
]
cor_end = np.tensordot(cor_cwts_test, spatial_filter, axes=([1], [0]))[
    :, index[0], index[1]
]
err_end -= threshold
cor_end -= threshold

# Strip plot: value on the x-axis, points spread evenly over [0, 1] on the y-axis.
fig6 = go.FigureWidget(layout=base_layout)
fig6.update_layout(height=150)
for values, color in ((err_end, "red"), (cor_end, "green")):
    fig6.add_scatter(
        x=values, y=np.linspace(0, 1, len(values)), mode="markers", marker_color=color
    )
fig6

In [None]:
# Earlier experiment with Student's t fits, kept for reference:
# t_params = scipy.stats.t.fit(err_end, fdf=len(err_end) - 1)
# distr = scipy.stats.t(*t_params).pdf

# t_params = scipy.stats.t.fit(cor_end, fdf=len(cor_end) - 1)
# distr2 = scipy.stats.t(*t_params).pdf

# Fit a normal distribution to each group (errors first, then corrects) and
# keep the fitted probability density functions.
fitted_pdfs = []
for values in (err_end, cor_end):
    params = scipy.stats.norm.fit(values)
    fitted_pdfs.append(scipy.stats.norm(*params).pdf)
    print(params)
distr, distr2 = fitted_pdfs

# distr = scipy.stats.norm(0, 1).pdf
# distr2 = scipy.stats.t(10, 0, 1).pdf

# Evaluate both densities on a common grid and draw them as two curves.
# xs = np.linspace(-2, 2, 100)
xs = np.linspace(-0.08, 0.1, 100)
fig7 = go.FigureWidget(layout=base_layout)
fig7.update_layout(height=300)
fig7.add_scatter(x=xs, y=distr(xs))
fig7.add_scatter(x=xs, y=distr2(xs))