In [None]:
%load_ext lab_black
import matplotlib
import matplotlib.pyplot as plt
import mne
import numpy as np
from common import load_gonogo_responses, tmax, tmin

# plt.style.use("dark_background")

In [None]:
# Load the go/no-go response epochs via the project helper in `common`.
# Later cells index them by the event names "correct_response" and "error_response".
epochs = load_gonogo_responses()

In [None]:
# Browse the epochs one at a time, colouring event markers by event code.
# NOTE(review): assumes the two response events are coded 0 and 1 — confirm against epochs.event_id.
epochs.plot(n_epochs=1, event_colors={0: "g", 1: "m"});

In [None]:
# Split the epochs by response outcome and average each subset into an Evoked.
correct_response_epochs = epochs["correct_response"]
correct_response_evoked = correct_response_epochs.average()

error_response_epochs = epochs["error_response"]
error_response_evoked = error_response_epochs.average()

In [None]:
# Overlay the channel-averaged ERPs of both response types.
# ylim is given in plotting units (µV for EEG); invert_y draws negative voltages upwards.

evokeds_to_compare = {
    "correct_response": correct_response_evoked,
    "error_response": error_response_evoked,
}
mne.viz.plot_compare_evokeds(
    evokeds_to_compare,
    combine="mean",
    invert_y=True,
    ylim=dict(eeg=[-10, 10]),
    legend="upper left",
    show_sensors="upper right",
)

In [None]:
# Per-channel view of the error-response average: a joint (butterfly + topomap) plot,
# then scalp maps at fixed latencies given in seconds relative to the epoch's time zero.

topomap_times = [0.0, 0.08, 0.1, 0.12, 0.2]
error_response_evoked.plot_joint(picks="eeg")
error_response_evoked.plot_topomap(times=topomap_times, ch_type="eeg");

In [None]:
# Difference wave per channel: error_response minus correct_response.
# combine_evoked applies `weights` in list order, so the error average must come first
# (the previous ordering produced correct - error, the opposite sign of what is described).

evoked_diff = mne.combine_evoked(
    [error_response_evoked, correct_response_evoked], weights=[1, -1]
)
evoked_diff.plot_joint()
None

In [None]:
# Mean waveform per event type, keyed "<event name>_mean".
# Each value is the average over epochs of an (n_epochs, n_channels, n_times) array,
# i.e. an (n_channels, n_times) array.
# get_data() is MNE's public accessor; the private `_data` attribute may not be
# populated unless the epochs are preloaded.
events_mean_dict = {
    f"{key}_mean": epochs[key].get_data().mean(axis=0) for key in epochs.event_id
}

In [None]:
# Chart with averages of correct and error responses per channel.
# Each channel's mean waveform is drawn on its own baseline (stacked vertically),
# one colour per condition, against the real epoch time axis in milliseconds.

colors = ["b", "r", "g"]
channel_spacing = 10e-6  # vertical gap between stacked channel baselines, in data units
first_channel_offset = 1e-6  # baseline of the first (bottom) channel

fig_means, ax_means = plt.subplots(figsize=(10, 10))
times_ms = epochs.times * 1000  # use the true epoch times instead of hard-coded tick labels

for color_idx, (key, mean_data) in enumerate(events_mean_dict.items()):
    # mean_data is (n_channels, n_times); one offset per channel, so any channel count works
    offsets = first_channel_offset + channel_spacing * np.arange(mean_data.shape[0])
    lines = ax_means.plot(
        times_ms,
        mean_data.T + offsets,
        color=colors[color_idx % len(colors)],  # cycle instead of overflowing the list
    )
    # Label a single line per condition so the legend has one entry per condition,
    # not one per channel.
    if lines:
        lines[0].set_label(key)

ax_means.set_yticks([])
ax_means.set_xlabel("milliseconds", fontsize=15)
ax_means.set_ylabel("channels", fontsize=15)
ax_means.legend(loc="upper left")
plt.show()

## Playground

### Uwagi - małe przemyślenia

- wydaje mi się, że do wygodnego korzystania z playgroundu musi on pokazywać wszystkie kanały (ewentualnie z opcją ich wyłączania)
- pogrubiona linia ze średnią jest hardkodowana i trzeba ją zmieniać za każdym razem; do tego jest to średnia z konkretnego (jednego) kanału, podczas gdy widget *Scalp* pozwala na wybranie kilku kanałów, a pojedyncze ERP są już uśredniane względem tych kilku kanałów – to chyba trochę nie ma sensu, taki misz-masz
- fajniej chyba byłoby, gdyby na wykresie wyświetlane były wszystkie kanały i dla każdego kanału wyświetlana była średnia z tego kanału jako „pogrubione” odniesienie oraz zwykły ERP, plus ewentualnie jakieś inne uśrednienia

In [None]:
import mne
import numpy as np
import plotly.graph_objects as go
from common import band_pass, base_layout, extract_erp, mask

# import tensorflow as tf
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, interact
from mne.datasets import sample
from scipy import signal

In [None]:
# 3-D view of the electrode positions, each point labelled with its channel index.
# The first three entries of an MNE channel's "loc" field hold its x, y, z position.
channel_locations = np.array([chan["loc"][:3] for chan in epochs.info["chs"]])
x, y, z = channel_locations.T
channel_labels = list(range(len(x)))

scalp3d = go.FigureWidget(layout=base_layout)
scalp3d.update_layout(width=300, height=300)
scalp3d.add_scatter3d(
    x=x,
    y=y,
    z=z,
    text=channel_labels,
    mode="markers+text",
    marker_size=4,
    hoverinfo="skip",
)

In [None]:
# 2-D electrode map used as a channel picker: select points by dragging on the plot.
# The selected indices are read by the playground below; channel 0 is preselected.
scalp = go.FigureWidget(layout=base_layout)
scalp.update_layout(width=300, height=300)
scalp.add_scatter(
    x=x,
    y=y,
    mode="text",
    text=list(range(len(x))),
    hoverinfo="skip",
)
scalp.data[0].selectedpoints = (0,)
scalp

In [None]:
# Interactive ERP playground.
# Trace 0 (thick red) is the grand average of `channel_num` for the chosen condition.
# Traces 1..N are single epochs passed through `extract_erp` (band-pass + window)
# over the channels selected on the `scalp` widget.

channel_num = 1  # channel drawn as the grand-average reference line
max_amp = 0.000008  # y-axis half-range; presumably volts (MNE's EEG unit) — TODO confirm

sampling_freq = epochs.info["sfreq"]

conditions = list(epochs.event_id)
# Epoch count per condition; one thin trace is pre-allocated per epoch slot.
cond_counts = {cond: len(epochs[cond]) for cond in conditions}
max_cond_count = max(cond_counts.values())

fig = go.FigureWidget(layout=base_layout)
fig.update_layout(
    xaxis_range=[tmin, tmax],
    yaxis_range=[-max_amp, max_amp],
)
# trace 0: grand average
fig.add_scatter(
    x=epochs.times,
    hoverinfo="skip",
    mode="lines+markers",
    marker_opacity=0,
    line_color="red",
    line_width=4,
)
# traces 1..max_cond_count: single epochs, hidden until selected
for _ in range(max_cond_count):
    fig.add_scatter(x=epochs.times, hoverinfo="skip", line_width=1, opacity=0)

window = FloatRangeSlider(
    value=[tmin, tmax],
    min=tmin,
    max=tmax,
)
# 1-based epoch number: trace `epoch_num` shows epoch index `epoch_num - 1`.
# (The old 0..max_cond_count-1 range showed nothing at 0 and could not reach the last epoch.)
epoch_slider = IntSlider(value=1, min=1, max=max_cond_count)
cond_selector = Dropdown(options=conditions)

# setting window by dragging:
# def set_range(trace, points, selector):
#     window.value = selector.xrange
# fig.data[0].on_selection(set_range)


@interact(condition=cond_selector)
def set_condition(condition):
    """Redraw the grand average for `condition` and fit the epoch slider to its epoch count.

    Single-epoch traces are deliberately left untouched: `update_plots` also reacts to
    `cond_selector` and owns their visibility, so the two callbacks may fire in either
    order without one hiding what the other drew.
    """
    # ipywidgets clamps the slider value when max drops below it
    epoch_slider.max = max(cond_counts[condition], 1)
    grand_average = epochs[condition].get_data().mean(axis=0)[channel_num]
    fig.data[0].y = grand_average


@interact(
    epoch_num=epoch_slider,
    band_pass_range=FloatRangeSlider(value=[0.1, 20], min=0.1, max=50),
    window=window,
    condition=cond_selector,
    many_epochs=False,
)
def update_plots(epoch_num, band_pass_range, window, condition, many_epochs):
    """Show the selected single-epoch ERP(s) of `condition` and hide all other traces.

    Args:
        epoch_num: 1-based epoch number to show (with `many_epochs`, epochs 1..epoch_num).
        band_pass_range: (low, high) pass band handed to `extract_erp`
            (presumably Hz — TODO confirm against `common.extract_erp`).
        window: (start, end) time window handed to `extract_erp`, on the epoch time axis.
        condition: event name from `epochs.event_id`.
        many_epochs: overlay several epochs instead of a single one.
    """
    selected = scalp.data[0].selectedpoints
    # Fall back to the reference channel when the selection has been cleared
    # (selectedpoints may then be None or empty, and list(None) would raise).
    selected_chs = list(selected) if selected else [channel_num]
    all_epochs = epochs[condition].get_data()
    # Opacity depends on how many traces are overlaid, not on the epoch's index,
    # so a single displayed epoch is always fully opaque.
    n_shown = epoch_num if many_epochs else 1
    opacity = 1 / max(n_shown, 1) ** (1 / 3)
    with fig.batch_update():
        # Visit every epoch trace (not just this condition's epochs) so traces left over
        # from a previous, larger condition get hidden too.
        for trace_num, trace in enumerate(fig.data[1:], start=1):
            if many_epochs:
                wanted = trace_num <= epoch_num
            else:
                wanted = trace_num == epoch_num
            if wanted and trace_num <= len(all_epochs):
                trace.y = extract_erp(
                    all_epochs[trace_num - 1],
                    selected_chs,
                    band_pass_range,
                    sampling_freq,
                    window,
                )
                trace.opacity = opacity
            else:
                trace.opacity = 0


fig