In [None]:
%load_ext lab_black
import os
import math
import pickle
import inspect
import itertools
from time import time
from copy import deepcopy

import pywt
import mne
import scipy
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import xxhash
import matplotlib
import matplotlib.cm as cm
from cachier import cachier
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
from sklearn.decomposition import FastICA
from sklearn.linear_model import ElasticNet
from sklearn.linear_model import Lasso
from sklearn.linear_model import LinearRegression
from ipywidgets import HBox, VBox
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact

import sys

sys.path.append("..")

from utils import *
from architecture import *

In [None]:
np.set_printoptions(precision=3, suppress=True)

# Silence only sklearn's ConvergenceWarning (emitted e.g. when FastICA did not
# converge). A blanket `filterwarnings("ignore")` would also hide unrelated
# warnings (pandas deprecations, MNE runtime warnings) that signal real problems.
# TODO investigate why FastICA doesn't converge
import warnings

from sklearn.exceptions import ConvergenceWarning

warnings.filterwarnings("ignore", category=ConvergenceWarning)

# Load data

In [None]:
df_name = "go_nogo_df"
pickled_data_filename = "../../data/" + df_name + ".pkl"
info_filename = "../../data/Demographic_Questionnaires_Behavioral_Results_N=163.csv"

# Use the cached pickle when it exists; otherwise build the dataframe from the
# raw recordings and cache it for the next run.
# NOTE(review): pd.read_pickle can execute arbitrary code - only load pickles
# produced locally by this project.
if os.path.isfile(pickled_data_filename):
    print("Pickled file found. Loading pickled data...")
    epochs = pd.read_pickle(pickled_data_filename)
else:
    print("Pickled file not found. Loading data...")
    epochs = create_df_data(info_filename=info_filename)
    epochs.name = df_name
    # save loaded data into a pickle file
    epochs.to_pickle(pickled_data_filename)

# Per-participant numbers of error / correct responses, broadcast to every row of
# that participant. Summing a 0/1 Series per group is vectorised and returns a
# Series aligned with `epochs`, so each column is assigned from a Series.
epochs["error_sum"] = (
    (epochs["marker"] == ERROR).astype(int).groupby(epochs["id"]).transform("sum")
)
epochs["correct_sum"] = (
    (epochs["marker"] == CORRECT).astype(int).groupby(epochs["id"]).transform("sum")
)

# mergesort for stable sorting: rows are ordered by descending error count while
# rows with equal counts keep their original relative order
epochs = epochs.sort_values("error_sum", ascending=False, kind="mergesort")

# one raw recording, loaded only for its time axis and channel metadata
_mne_epochs = load_epochs_from_file("../../data/responses/GNG_AA0303-64 el.vhdr")
times = _mne_epochs.times

_channel_info = _mne_epochs.info["chs"]

log_freq = np.log2(get_frequencies())  # for plotting CWT

In [None]:
# Load the stored models. Expected to be a DataFrame with at least the
# "pipeline_name", "model" and "best_estimator" columns (filtered in the next cell).
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open("../../public_data/regression_PCA_error.pkl", "rb") as file:
    models_data = pickle.load(file)

# Load saved model

In [None]:
# Pick the elastic-net ("en") model trained with the "PCA_15_bins" pipeline.
model_data = models_data.loc[
    (models_data["pipeline_name"] == "PCA_15_bins") & (models_data["model"] == "en")
]
if model_data.empty:
    # fail with a clear message instead of an IndexError from .iloc[0]
    raise ValueError("no saved model with pipeline_name='PCA_15_bins' and model='en'")

# fitted sklearn Pipeline stored in the "best_estimator" column
model = model_data["best_estimator"].iloc[0]

In [None]:
# Channels kept by the model's "channels_filtering" step. The persisted list is
# 1-based, so shift it to 0-based indices into the MNE channel info.
significant_channels = model["channels_filtering"].channel_list
channel_index_list = np.subtract(significant_channels, 1)

In [None]:
# Per-channel metadata from the MNE info: xyz position (first 3 entries of
# "loc") and name, for every recorded channel.
channel_locations = np.array([ch["loc"][:3] for ch in _channel_info])
channel_names = np.array([ch["ch_name"] for ch in _channel_info])

# Colour each channel by its position: each axis is rescaled to 0..255 using the
# min/max over ALL channels (i.e. before trimming) and floored.
_loc_offset = channel_locations - channel_locations.min(axis=0)
_rgb_levels = (_loc_offset / _loc_offset.max(axis=0)) * 255 // 1
channel_colors = np.array(
    [f"rgb({c[0]:.0f},{c[1]:.0f},{c[2]:.0f})" for c in _rgb_levels]
)

# Keep only the channels used by the model, in the order of channel_index_list.
channel_names, channel_locations, channel_colors = (
    values[channel_index_list]
    for values in (channel_names, channel_locations, channel_colors)
)

# Visualize spatial components

In [None]:
def plot_ica_comp(ica_comp, scale=1):
    """Draw one spatial component as a map of coloured channel markers.

    Each (trimmed) channel is drawn at its x/y location, coloured by its weight
    on a symmetric blue-black-red scale centred on zero.

    ica_comp
        weights, one per channel, in `channel_names` / `channel_locations` order
    scale
        multiplier for the figure and marker size

    Returns a plotly FigureWidget.
    """
    xs, ys, zs = channel_locations.T

    fig = go.FigureWidget()
    fig.update_layout(**base_layout)
    fig.update_layout(width=350 * scale, height=350 * scale)

    # draw channels in ascending z order (ties broken by x, y, name, weight),
    # so the markers of higher electrodes are drawn last
    rows = sorted(zip(zs, xs, ys, channel_names, ica_comp))
    sorted_x = [row[1] for row in rows]
    sorted_y = [row[2] for row in rows]
    weights = [row[4] for row in rows]

    # symmetric colour range, so a zero weight sits in the middle of the scale
    limit = max(np.abs(weights))

    fig.add_scatter(
        x=sorted_x,
        y=sorted_y,
        text=weights,
        marker_color=weights,
        mode="markers",
        marker_size=42 * scale,
        marker_colorscale=blue_black_red,
        marker_cmax=limit,
        marker_cmin=-limit,
    )
    return fig


def plot_erps_after_spatial_filter(
    spatial_filter, epochs, individual, plot_limit, erp_type, max_amp, scale=1
):
    """Plot ERPs of `epochs` projected through one spatial filter.

    spatial_filter
        channel weights: either one per recorded channel, or one per channel in
        `channel_index_list` (cropped channels then get weight 0)
    epochs
        dataframe with "id", "marker" and "epoch" columns
    individual
        if True, plot one curve per participant (the first `plot_limit` ids in
        order of appearance) chosen by `erp_type`; otherwise plot the grand
        average error (red) and correct (green) ERPs
    plot_limit
        maximum number of participants plotted when `individual` is True
    erp_type
        'correct', 'error', 'all' or 'difference' (correct minus error);
        only used when `individual` is True
    max_amp
        half-range of the y axis
    scale
        multiplier for the figure size

    Raises ValueError for an unknown `erp_type` when `individual` is True.
    Participants lacking the epochs a given `erp_type` needs are skipped
    (stacking an empty selection would otherwise raise).
    """
    if individual and erp_type not in ("correct", "error", "all", "difference"):
        raise ValueError("bad argument for erp_type")

    if len(spatial_filter) != len(_channel_info):
        # the channels are cropped: scatter the filter back onto all recorded
        # channels, with zero weight for the cropped ones
        full_filter = np.zeros(len(_channel_info))
        full_filter[channel_index_list] = spatial_filter
        spatial_filter = full_filter

    def project_mean(epoch_series):
        # average the epochs, then contract their first (channel) axis with
        # the filter - averaging and projecting commute, as both are linear
        mean_epoch = np.stack(epoch_series.values).mean(axis=0)
        return np.tensordot(mean_epoch, spatial_filter, axes=([0], [0]))

    fig = go.FigureWidget()
    fig.update_layout(**base_layout)
    fig.update_layout(
        height=350 * scale,
        width=600 * scale,
        xaxis_range=[times[0], times[-1]],
        yaxis_range=[-max_amp, max_amp],
    )

    if individual:
        # scalar key: get_group(participant_id) is deprecated for a 1-item list key
        by_id = epochs.groupby("id")
        for participant_id in epochs["id"].unique()[:plot_limit]:
            df = by_id.get_group(participant_id)
            err_epochs = df.loc[df["marker"] == ERROR, "epoch"]
            cor_epochs = df.loc[df["marker"] == CORRECT, "epoch"]

            if erp_type == "all":
                erp = project_mean(df["epoch"])
            elif erp_type == "correct":
                erp = project_mean(cor_epochs) if len(cor_epochs) else None
            elif erp_type == "error":
                erp = project_mean(err_epochs) if len(err_epochs) else None
            elif len(cor_epochs) and len(err_epochs):  # "difference"
                erp = project_mean(cor_epochs) - project_mean(err_epochs)
            else:
                erp = None

            if erp is None:
                continue  # participant lacks the epochs this erp_type needs
            fig.add_scatter(x=times, y=erp)
        fig.update_traces(line_width=1)

    else:
        err_erp = project_mean(epochs.loc[epochs["marker"] == ERROR, "epoch"])
        cor_erp = project_mean(epochs.loc[epochs["marker"] == CORRECT, "epoch"])

        fig.add_scatter(x=times, y=err_erp, line_color="red")
        fig.add_scatter(x=times, y=cor_erp, line_color="green")

    return fig


def visualize_spatial_components(
    pipeline,
    epochs,
    individual=False,
    plot_limit=200,
    erp_type=None,
    max_amp=0.0001,
    scale=1,
    flip_mask=None,
):
    """
    Show every spatial component of `pipeline` next to the ERPs it produces.

    For each row of the "spatial_filter" step's `components_`, display its
    channel map and, beside it, the ERPs of `epochs` projected through it.

    pipeline
        sklearn Pipeline to be visualized
    epochs
        dataframe with all the epochs
    individual
        whether to plot ERP for each person instead of averaged across everyone
    plot_limit
        how many people to plot - too high makes the plot hard to read
        only has effect if individual==True
    erp_type
        only has effect if individual==True
        possible values:
            'correct'     plot the average of all correct epochs for each person
            'error'       plot the average of all error epochs for each person
            'all'         plot the average of all the epochs for each person
            'difference'  plot the average of correct epochs minus average of error epochs for each person
    max_amp
        maximum amplitude for component plotting
    scale
        scale of the plots
    flip_mask
        optional array of 1s and -1s
        its length must be the same as the number of spatial components
        setting -1 for a corresponding spatial component flips its sign for better readability
    """
    spatial_step = dict(pipeline.steps)["spatial_filter"]

    for idx, component in enumerate(spatial_step.components_):
        if flip_mask is not None:
            component = component * flip_mask[idx]
        channel_map = plot_ica_comp(component, scale=scale)
        erp_plot = plot_erps_after_spatial_filter(
            component,
            epochs,
            individual=individual,
            plot_limit=plot_limit,
            erp_type=erp_type,
            max_amp=max_amp,
            scale=scale,
        )
        display(HBox([channel_map, erp_plot]))

In [None]:
# weights the final "en" step gives to its input features
model["en"].coef_

In [None]:
# display the fitted pipeline and its steps
model

In [None]:
# for the first 10 participants, plot the difference of their
# average of correct epochs minus average of error epochs
# passed through the spatial filters
# (epochs are sorted by error count, so these are the 10 participants with
# the most errors; flip_mask flips the first component's sign for readability)

visualize_spatial_components(
    model,
    epochs,
    individual=True,
    plot_limit=10,
    erp_type="difference",
    max_amp=0.00014,
    flip_mask=[-1, 1, 1, 1, 1],
)

# Visualize components

In [None]:
def plot_pca_shape(pca_comps, mwt, clf_coefs, xs, max_amp, scale=1, heatmap=False):
    """Plot the time-course shapes the PCA components match, weighted by the
    classifier, plus their sum.

    pca_comps
        PCA components belonging to one spatial filter, one row per component
    mwt
        wavelet parameter of the CWT step, or None when there is no CWT step;
        when given, each component is reshaped to (frequencies, timepoints) and
        turned back into a time course by summing weighted wavelets
    clf_coefs
        classifier weight of each component (indexed like `pca_comps`)
    xs
        time points of the x axis
    max_amp
        half-range of the y axis (or of the heatmap colour scale)
    scale
        multiplier for the figure size
    heatmap
        draw one heatmap row per shape instead of overlaid lines

    The sum (thick yellow line, or the bottom heatmap row) is the sum of the
    weighted per-component shapes. Raises ValueError if `pca_comps` is empty.
    """
    if len(pca_comps) == 0:
        raise ValueError("pca_comps is empty")

    fig = go.FigureWidget()
    fig.update_layout(**base_layout)
    fig.update_layout(height=350 * scale, width=600 * scale)
    if not heatmap:
        fig.update_layout(yaxis_range=[-max_amp, max_amp])

    weighted_shapes = []
    for i, comp in enumerate(pca_comps):
        if mwt is not None:
            # CWT+PCA in practice, multiplies by this shape
            # TODO test this block, if this is really the shape
            comp = comp.reshape(-1, timepoints_count)
            shape = np.zeros_like(xs)
            for amps_for_freq, freq in zip(comp, get_frequencies()):
                for amp, latency in zip(amps_for_freq, xs):
                    wv = get_wavelet(latency, freq, xs, mwt)
                    shape += wv * amp
        else:
            shape = np.copy(comp)

        # weight by the component importance from classifier
        shape *= clf_coefs[i]
        if not heatmap:
            fig.add_scatter(x=xs, y=shape)
        weighted_shapes.append(shape)

    # show also the sum of all pca comps weighted by importance. Summing the
    # reconstructed shapes (not the raw components) keeps this valid with CWT,
    # where a raw component is frequency x time and does not match `xs`.
    total = np.sum(weighted_shapes, axis=0)
    weighted_shapes.append(total)

    if not heatmap:
        fig.add_scatter(x=xs, y=total, line_width=5, line_color="yellow")
    else:
        # reverse, so that later components are on the bottom
        rows = np.array(weighted_shapes[::-1])
        fig.add_heatmap(
            x=xs, z=rows, zmin=-max_amp, zmax=max_amp, colorscale=blue_black_red
        )

    return fig


def visualize_pipeline(
    pipeline,
    clf_coefs_all=None,
    max_amp=0.018,
    scale=1,
    heatmap=False,
    one_pca=False,
    flip_mask=None,
):
    """
    For each spatial component, show its channel map next to the time-course
    shapes its PCA components (weighted by the classifier) match.

    pipeline
        sklearn Pipeline to be visualized; needs a "spatial_filter" step and a
        "pca" or "feature_selection" step
    clf_coefs_all
        optional classifier coefficients to be used to weigh component plots
        if left as None, the coefficients of the pipeline's "lasso", "en" or
        "lda" step are used (the first one found, in that order)
    max_amp
        maximum amplitude for component plotting
    scale
        scale of the plots
    heatmap
        whether to use a heatmap instead of overlayed plots for each component
    one_pca
        whether only one PCA is fitted (on all spatial components together)
        instead of a separate PCA for each spatial component
    flip_mask
        optional array of 1s and -1s
        its length must be the same as the number of spatial components
        setting -1 for a corresponding spatial component flips its sign for better readability

    Raises ValueError when the dimensionality-reduction step is missing, or when
    clf_coefs_all is None and no supported classifier step exists.
    """
    if heatmap:
        print("the component on the bottom is the sum of all the above components")

    fitted_steps = dict(pipeline.steps)
    spatial = fitted_steps["spatial_filter"]

    if "pca" in fitted_steps:
        dims_reduction = fitted_steps["pca"]
    elif "feature_selection" in fitted_steps:
        dims_reduction = fitted_steps["feature_selection"]
    else:
        raise ValueError('pipeline needs a "pca" or a "feature_selection" step')

    if clf_coefs_all is None:
        if "lasso" in fitted_steps:
            clf_coefs_all = fitted_steps["lasso"].coef_
        elif "en" in fitted_steps:
            clf_coefs_all = fitted_steps["en"].coef_
        elif "lda" in fitted_steps:
            clf_coefs_all = fitted_steps["lda"].coef_[0]
        else:
            raise ValueError(
                'clf_coefs_all not given and pipeline has no "lasso", "en" or "lda" step'
            )
    clf_coefs_all = np.asarray(clf_coefs_all)

    # with binning, each feature corresponds to the middle sample of its bin
    if "binning" in fitted_steps:
        bin_step = fitted_steps["binning"].step
        xs = times[bin_step // 2 :: bin_step]
    else:
        xs = times

    # wavelet parameter of the CWT step; the same for every spatial component
    mwt = fitted_steps["cwt"].mwt if "cwt" in fitted_steps else None

    n_spatial = len(spatial.components_)
    if one_pca:
        # a single PCA over all spatial components: split each flat component
        # into one row per spatial component
        pca_comps_separated = dims_reduction.components_.reshape(
            len(dims_reduction.components_), n_spatial, -1
        )
    else:
        pcas = dims_reduction.PCAs
        clf_coefs_for_each_ica_comp = clf_coefs_all.reshape(n_spatial, -1)
        # the line below was tested visually to be the wrong unflattening
        # clf_coefs_for_each_ica_comp = clf_coefs_all.reshape(-1, n_spatial).T

    for ica_comp_num, ica_comp in enumerate(spatial.components_):
        if one_pca:
            pca_comps = pca_comps_separated[:, ica_comp_num, :]
            clf_coefs = clf_coefs_all
        else:
            pca_comps = pcas[ica_comp_num].components_
            clf_coefs = clf_coefs_for_each_ica_comp[ica_comp_num]

        if flip_mask is not None:
            ica_comp = ica_comp * flip_mask[ica_comp_num]
            pca_comps = pca_comps * flip_mask[ica_comp_num]

        display(
            HBox(
                [
                    plot_ica_comp(ica_comp, scale=scale),
                    plot_pca_shape(
                        pca_comps,
                        mwt,
                        clf_coefs,
                        xs=xs,
                        max_amp=max_amp,
                        scale=scale,
                        heatmap=heatmap,
                    ),
                ]
            )
        )

In [None]:
# for each of spatial filters, plot what shape the PCA components are trying to match
# yellow line shows the sum of all the shapes - it's the shape that the whole model
# tries to match for that spatial filter
# (one_pca=True treats the model as having a single PCA fitted over all spatial
#  components; flip_mask flips the first component's sign for readability)

visualize_pipeline(
    model, heatmap=False, one_pca=True, flip_mask=[-1, 1, 1, 1, 1], max_amp=0.04
)

In [None]:
def visualize_pipeline_but_focus_on_pca_comps(
    pipeline,
    clf_coefs_all=None,
    max_amp=0.015,
    scale=1,
):
    """
    For each component of the common PCA, show it as a heatmap weighted by its
    classifier coefficient, with one row per spatial component.

    Note: works only with common PCA for all channels and no CWT

    pipeline
        sklearn Pipeline to be visualized
    clf_coefs_all
        optional classifier coefficients to be used to weigh component plots
        if left as None, the coefficients of the pipeline's "lasso", "en" or
        "lda" step are used (the first one found, in that order)
    max_amp
        maximum amplitude for component plotting
    scale
        scale of the plots

    Raises ValueError if clf_coefs_all is None and no such classifier step exists.
    """
    fitted_steps = dict(pipeline.steps)
    spatial = fitted_steps["spatial_filter"]
    pca = fitted_steps["pca"]

    if clf_coefs_all is None:
        if "lasso" in fitted_steps:
            clf_coefs_all = fitted_steps["lasso"].coef_
        elif "en" in fitted_steps:
            clf_coefs_all = fitted_steps["en"].coef_
        elif "lda" in fitted_steps:
            clf_coefs_all = fitted_steps["lda"].coef_[0]
        else:
            raise ValueError(
                'clf_coefs_all not given and pipeline has no "lasso", "en" or "lda" step'
            )

    # with binning, each feature column corresponds to the middle sample of its
    # bin, so the heatmap's x axis must use those times, not all of `times`
    if "binning" in fitted_steps:
        bin_step = fitted_steps["binning"].step
        xs = times[bin_step // 2 :: bin_step]
    else:
        xs = times

    for pca_comp, clf_coef in zip(pca.components_, clf_coefs_all):
        # one row per spatial component
        pca_comp_2d = pca_comp.reshape(len(spatial.components_), -1)

        fig = go.FigureWidget()
        fig.update_layout(**base_layout)
        fig.update_layout(height=350 * scale, width=600 * scale)
        fig.add_heatmap(
            x=xs,
            z=pca_comp_2d * clf_coef,
            zmin=-max_amp,
            zmax=max_amp,
            colorscale=blue_black_red,
        )
        display(fig)

In [None]:
# for each PCA component, plot a heatmap with its components
# with each row corresponding to one spatial filter
# Uses the saved model loaded above; its final step is "en", so its
# coefficients are passed explicitly as the component weights.
visualize_pipeline_but_focus_on_pca_comps(model, clf_coefs_all=model["en"].coef_)

In [None]:
# # check that PCA+LDA can be replaced by just a dot product with a shape computed from weighted PCA comps

# split_index = 0
# fitted_steps = dict(pipelines[split_index].steps)
# pcas = fitted_steps["pca"].PCAs
# ica_comp_num = 0
# pca_comps = pcas[ica_comp_num].components_
# lda = fitted_steps["lda"]

# # do ICA steps
# X_ = pipelines[0].steps[0][1].transform(X)
# X_ = pipelines[0].steps[1][1].transform(X_)
# X_ = pipelines[0].steps[2][1].transform(X_)

# # compute a shape from the weighted PCA comps
# acc = np.zeros_like(times)
# for comp, coef in zip(pca_comps, lda.coef_[0]):
#     acc += comp * coef

# features = []
# for epoch in X_[0]:
#     feature = np.sum(acc * epoch)
#     features.append(feature)
# features = np.array(features).reshape(-1)

# real_features = pipelines[split_index].decision_function(X)

# corr = np.corrcoef(features, real_features)[0][1]
# np.isclose(corr, 1)
# # if True, the dot product gives the same as the normal pipeline execution

# Test ICA stability

In [None]:
# The CV must use exactly 2 splits!
# With more splits the components would look stable, but only trivially:
# they would be trained on overlapping data, so no wonder they are similar.

In [None]:
def correlations(a0, a1):
    """Cosine-similarity matrix between the rows of two matrices.

    Similar to np.corrcoef, but the mean is not subtracted when calculating
    the sum of squares: entry (i, j) is a0[i] . a1[j] / (|a0[i]| * |a1[j]|).

    Parameters
    ----------
    a0, a1 : ndarray
        2-D arrays containing multiple variables and observations.
        Each row represents a variable, and each column a single
        observation of all those variables.
        Their number of columns must be equal.

    Returns
    -------
    ndarray of shape (len(a0), len(a1))
    """
    squared_norms0 = np.sum(a0 * a0, axis=1)
    squared_norms1 = np.sum(a1 * a1, axis=1)
    denominator = np.outer(squared_norms0, squared_norms1) ** (1 / 2)
    return (a0 @ a1.T) / denominator


def factor_similarity(a0, a1):
    """Measure how similar two sets of factors (matrix rows) are.

    Every factor is matched with its most similar counterpart by absolute
    cosine similarity (see `correlations`), so reordering the factors,
    rescaling them or flipping their sign doesn't change the result.

    Returns the mean best-match similarity, taking the smaller (more
    pessimistic) of the two matching directions.
    """
    abs_sim = np.abs(correlations(a0, a1))  # a flipped sign doesn't matter
    # best match for every factor of a1 (columns) and of a0 (rows)
    best_per_col = abs_sim.max(axis=0)
    best_per_row = abs_sim.max(axis=1)
    # if one factor is the best match of two others, one direction is too
    # optimistic, so report the more pessimistic one
    # TODO? a more robust way would be to generate permutations and check them
    return min(best_per_col.mean(), best_per_row.mean())

In [None]:
# spatial filter components of every fitted pipeline; assumes step 1 of each
# pipeline is the spatial filter (presumably the step named "spatial_filter"
# in the plotting functions above) - TODO confirm
# NOTE(review): `pipelines` (one fitted pipeline per CV split) is not created in
# the cells visible here; it must be defined before running this cell.
spatial_filters = [pipeline.steps[1][1].components_ for pipeline in pipelines]

In [None]:
# cosine similarities (no mean subtraction, see `correlations`) between the
# spatial components of the first split (rows) and of the second split (columns)
print("correlations between factors found in the first, and the second split")
correlations(spatial_filters[0], spatial_filters[1])

In [None]:
# one summary number: mean best-match absolute similarity between the two
# splits' components (1 means identical up to order, scale and sign)
factor_similarity(spatial_filters[0], spatial_filters[1])

In [None]:
# print(
#     "similarity measures between factors found in each pair of splits, for a single participant"
# )
# similarities = np.array(
#     [
#         [factor_similarity(sf_i, sf_j) for sf_i in spatial_filters]
#         for sf_j in spatial_filters
#     ]
# )
# print(similarities)
# print("mean", similarities.mean())

In [None]:
# # mne plotting for comparison
# x, y, z = channel_locations.T
# mne.viz.plot_topomap(
#     spatial_filters[participant][2], np.stack((x, y), axis=-1)
# )

In [None]:
# # try to find corresponding components
# best_similarity = 0
# for perm in itertools.permutations(range(3)):
#     perm = list(perm)
#     diag = corr[perm].diagonal()
#     similarity = abs(diag).mean()
#     if similarity > best_similarity:
#         best_similarity = similarity
#         best_perm = perm

# print(best_similarity)
# print(best_perm)
# corr[best_perm]

# Visualize personal differences

In [None]:
# split participants in half - one for training of common model, one for validation
# `epochs` is sorted by error count, so `ids` is too: an unshuffled KFold would
# make the participants with the most errors the test half and those with the
# fewest the training half. Shuffle (seeded, for reproducibility) instead.
from sklearn.model_selection import KFold

ids = epochs["id"].unique()

kf = KFold(n_splits=2, shuffle=True, random_state=0)
train_index, test_index = next(kf.split(ids))
train_ids = ids[train_index]
test_ids = ids[test_index]

train_epochs = epochs[epochs["id"].isin(train_ids)]
test_epochs = epochs[epochs["id"].isin(test_ids)]

X_train = np.array(train_epochs["epoch"].to_list())
y_train = np.array(train_epochs["marker"].to_list())
X_test = np.array(test_epochs["epoch"].to_list())
y_test = np.array(test_epochs["marker"].to_list())

# fit a fresh copy of the pipeline template with the first parameter setting
pipeline = Pipeline(deepcopy(steps), memory=cachedir)
pipeline.set_params(**ParameterGrid(regressor_params)[0])
pipeline.fit(X_train, y_train)

In [None]:
# scores on the held-out participants
final_estimator = steps[-1][1]
if type(final_estimator) is LinearDiscriminantAnalysis:
    # classifier: score each epoch by the probability of classes_[1]
    y_pred = pipeline.predict_proba(X_test)[:, 1]
else:
    # regressor: the raw prediction is the score
    y_pred = pipeline.predict(X_test)

auroc, corr, r2 = (
    roc_auc_score(y_test, y_pred),
    np.corrcoef(y_test, y_pred)[0, 1],
    r2_score(y_test, y_pred),
)
auroc, corr, r2

In [None]:
# scores on train set (compare with the held-out scores to gauge overfitting)
is_lda = type(steps[-1][1]) is LinearDiscriminantAnalysis
y_pred = pipeline.predict_proba(X_train)[:, 1] if is_lda else pipeline.predict(X_train)

auroc = roc_auc_score(y_train, y_pred)
corr_matrix = np.corrcoef(y_train, y_pred)
corr = corr_matrix[0, 1]
r2 = r2_score(y_train, y_pred)
auroc, corr, r2

In [None]:
# visualize the freshly trained pipeline; flip_mask has 4 entries, presumably one
# per spatial component of this pipeline - TODO confirm
visualize_pipeline(pipeline, one_pca=True, flip_mask=[-1, 1, 1, 1])

In [None]:
# every step except the final estimator: maps epochs to the features that the
# final regressor/classifier receives
truncated_pipeline = Pipeline(pipeline.steps[:-1])

In [None]:
# adapted from https://github.com/eriklindernoren/ML-From-Scratch/blob/master/mlfromscratch/supervised_learning/regression.py
class l_half_regularization:
    """L1/2 penalty ``alpha * sum(sqrt(|w| + eps))`` and its gradient.

    The gradient magnitude grows as |w| shrinks, so small weights are pushed
    towards zero harder than with L1, giving sparser solutions. `eps` keeps the
    penalty's gradient finite at w == 0.

    alpha
        strength of the penalty (0 disables it)
    eps
        small offset added to |w| before the square root
    """

    def __init__(self, alpha, eps=0.00001):
        self.alpha = alpha
        self.eps = eps

    def __call__(self, w):
        return self.alpha * np.sum((np.abs(w) + self.eps) ** (1 / 2))

    def grad(self, w):
        # d/dw sqrt(|w| + eps) = sign(w) / (2 * sqrt(|w| + eps))
        return self.alpha * 1 / 2 / ((np.abs(w) + self.eps) ** (1 / 2)) * np.sign(w)


class Regression(object):
    """Base regression model. Models the relationship between a scalar dependent
    variable y and the independent variables X, fitted by gradient descent.

    Subclasses must set `self.regularization`: a callable returning the penalty
    for a weight vector, with a `grad(w)` method returning its gradient.

    Parameters:
    -----------
    n_iterations: int
        The number of training iterations the algorithm will tune the weights for.
    learning_rate: float
        The step length that will be used when updating the weights.

    After `fit`, `w` holds the weights (w[0] is the bias) and `training_errors`
    the loss recorded at every iteration.
    """

    def __init__(self, n_iterations, learning_rate):
        self.n_iterations = n_iterations
        self.learning_rate = learning_rate

    def initialize_weights(self, n_features):
        """Initialize weights uniformly at random in [-1/sqrt(N), 1/sqrt(N)]."""
        limit = 1 / math.sqrt(n_features)
        self.w = np.random.uniform(-limit, limit, (n_features,))

    def fit(self, X, y, reinit=True):
        """Run `n_iterations` steps of gradient descent.

        X
            2-D array, one row per sample
        y
            targets, one per row of X
        reinit
            if False, continue from the current weights and keep appending to
            `training_errors` (e.g. to retrain with a different penalty);
            a model that was never fitted is initialised anyway
        """
        # Insert constant ones for bias weights
        X = np.insert(X, 0, 1, axis=1)
        if reinit or not hasattr(self, "w"):
            self.training_errors = []
            self.initialize_weights(n_features=X.shape[1])

        for _ in range(self.n_iterations):
            y_pred = X.dot(self.w)
            # logged loss: mean half squared error plus the penalty
            loss = np.mean(0.5 * (y - y_pred) ** 2 + self.regularization(self.w))
            self.training_errors.append(loss)
            # NOTE: the data term of the gradient is summed (not averaged) over
            # samples, unlike the logged loss, so the penalty's relative weight
            # shrinks with more samples; the bias w[0] is penalised too.
            grad_w = -(y - y_pred).dot(X) + self.regularization.grad(self.w)
            # Update the weights
            self.w -= self.learning_rate * grad_w

    def predict(self, X):
        """Return X @ w[1:] + w[0] for a 2-D X."""
        # Insert constant ones for bias weights
        X = np.insert(X, 0, 1, axis=1)
        y_pred = X.dot(self.w)
        return y_pred


class LHalfRegression(Regression):
    """Linear regression with an L1/2 weight penalty, fitted by gradient descent.

    reg_factor
        strength (alpha) of the L1/2 penalty; 0 gives plain least squares
    n_iterations, learning_rate
        passed on to `Regression`
    """

    def __init__(self, reg_factor, n_iterations=1000, learning_rate=0.001):
        super().__init__(n_iterations, learning_rate)
        self.regularization = l_half_regularization(alpha=reg_factor)

In [None]:
# prepare data: the features the pipeline's final estimator would receive
features_train = truncated_pipeline.transform(X_train)
features_test = truncated_pipeline.transform(X_test)

In [None]:
# initialize regressor without a penalty (reg_factor=0), giving a plain
# least-squares starting point; the penalty is raised in the next cell.
# Initial weights are drawn randomly, so seed numpy's global RNG to make
# re-running the notebook reproducible.
np.random.seed(0)
reg = LHalfRegression(0, n_iterations=6000, learning_rate=0.0000001)
reg.fit(features_train, y_train)

In [None]:
# manually retrain with different alphas, continuing from the current weights
# (reinit=False keeps reg.w and appends to reg.training_errors)
# suggested consecutive alphas: 0, 3, 10, 30, 10
reg.regularization.alpha = 10
reg.fit(features_train, y_train, reinit=False)

In [None]:
# learned weights; reg.w[0] is the bias (fit inserts a column of ones first)
reg.w

In [None]:
# scores of the L1/2 regression on the held-out participants
y_pred = reg.predict(features_test)

auroc, corr, r2 = (
    roc_auc_score(y_test, y_pred),
    np.corrcoef(y_test, y_pred)[0, 1],
    r2_score(y_test, y_pred),
)

auroc, corr, r2

In [None]:
# scores on train set (compare with the held-out scores to gauge overfitting)
y_pred = reg.predict(features_train)

auroc = roc_auc_score(y_train, y_pred)
corr_matrix = np.corrcoef(y_train, y_pred)
corr = corr_matrix[0, 1]
r2 = r2_score(y_train, y_pred)
auroc, corr, r2

In [None]:
# loss at every gradient-descent iteration since the last (re)initialisation
px.scatter(reg.training_errors)

In [None]:
# reuse the pipeline plots, weighting the components by the L1/2 regression
# weights instead (reg.w[0] is the bias, so it is dropped)
visualize_pipeline(
    truncated_pipeline, one_pca=True, flip_mask=[-1, 1, 1, 1], clf_coefs_all=reg.w[1:]
)

In [None]:
# features with a non-negligible L1/2 regression weight (reg.w[0] is the bias)
indices = np.where(np.abs(reg.w[1:]) > 0.01)[0]
# the plots below need 2 (2-D plots) or 3 (3-D plot) selected features
assert 2 <= len(indices) <= 3, f"expected 2 or 3 selected features, got {len(indices)}"
indices

In [None]:
# coordinates of the held-out epochs along the selected features
xs = features_test[:, indices[0]]
ys = features_test[:, indices[1]]
# a 3rd coordinate exists only with 3 selected features; reset it otherwise so a
# value from an earlier run cannot silently leak into the 3-D plot
zs = features_test[:, indices[2]] if len(indices) >= 3 else None

In [None]:
# plot feature points for all participants: every `skip`-th held-out epoch,
# coloured by its marker
max_amp = 4
skip = 16

feature_plot_2d = go.FigureWidget()
feature_plot_2d.update_layout(**base_layout)
feature_plot_2d.update_layout(
    width=600,
    height=600,
    xaxis_range=[-max_amp, max_amp],
    yaxis_range=[-max_amp, max_amp],
)
feature_plot_2d.add_scatter(
    x=xs[::skip],
    y=ys[::skip],
    marker_color=test_epochs["marker"].iloc[::skip],
    mode="markers",
    marker_size=4,
    marker_colorscale=blue_black_red,
)

In [None]:
# plot feature points for one participant, greens are CORRECT, reds are ERROR
participant_num = 1
test_id = test_ids[participant_num]

error_mask = test_epochs["marker"] == ERROR
correct_mask = test_epochs["marker"] == CORRECT
id_mask = test_epochs["id"] == test_id

own_correct = id_mask & correct_mask
own_error = id_mask & error_mask
x_cor, y_cor = xs[own_correct], ys[own_correct]
x_err, y_err = xs[own_error], ys[own_error]

max_amp = 4
grouped_plot_2d = go.FigureWidget()
grouped_plot_2d.update_layout(**base_layout)
grouped_plot_2d.update_layout(
    width=600,
    height=600,
    xaxis_range=[-max_amp, max_amp],
    yaxis_range=[-max_amp, max_amp],
)
for x_pts, y_pts, colour in ((x_cor, y_cor, "green"), (x_err, y_err, "red")):
    grouped_plot_2d.add_scatter(
        x=x_pts, y=y_pts, marker_color=colour, mode="markers", marker_size=4
    )

grouped_plot_2d

In [None]:
# for each participant, show arrow in feature space
# from their error median to correct median
# arrows are colored by the chosen column: value / 5 on matplotlib's "hot"
# colormap (values outside [0, 1] get its end colours) - hotter color means
# a higher value


@interact(column=Dropdown(value="Sex", options=epochs.columns))
def update_plots(column):
    """Draw one arrow per held-out participant, coloured by `column`.

    column
        name of a numeric column of `epochs`; its value is read from the
        participant's first test epoch (assumed constant per participant)

    Participants without any error or any correct test epoch are skipped:
    their median would be NaN and no arrow could be drawn.
    """
    error_mask = test_epochs["marker"] == ERROR
    correct_mask = test_epochs["marker"] == CORRECT

    arrow_plot_2d = go.FigureWidget()
    arrow_plot_2d.update_layout(**base_layout)
    max_amp = 2.7
    arrow_plot_2d.update_layout(
        width=600,
        height=600,
        xaxis_range=[-max_amp, max_amp],
        yaxis_range=[-max_amp, max_amp],
    )

    for test_id in test_ids:
        id_mask = test_epochs["id"] == test_id
        cor_mask = id_mask & correct_mask
        err_mask = id_mask & error_mask
        if not cor_mask.any() or not err_mask.any():
            continue
        color_val = test_epochs.loc[id_mask, column].iloc[0] / 5
        arrow_plot_2d.add_annotation(
            # arrow head at the correct median, tail at the error median
            x=np.median(xs[cor_mask]),
            y=np.median(ys[cor_mask]),
            ax=np.median(xs[err_mask]),
            ay=np.median(ys[err_mask]),
            xref="x",
            yref="y",
            axref="x",
            ayref="y",
            text="",  # if you want only the arrow
            showarrow=True,
            arrowhead=3,
            arrowsize=1,
            arrowwidth=1,
            arrowcolor=matplotlib.colors.rgb2hex(cm.hot(color_val)),
        )

    display(arrow_plot_2d)

In [None]:
# for each participant, show arrow in feature space
# from their error median to correct median
# arrows are colored by rumination - hotter color means higher rumination
# (rumination / 5 on matplotlib's "hot" colormap)

# the 3-D plot needs a third selected feature; without this check `zs` could
# be undefined or left over from an earlier run
if len(indices) < 3:
    raise ValueError("the 3-D arrow plot needs 3 selected features (see `indices`)")

error_mask = test_epochs["marker"] == ERROR
correct_mask = test_epochs["marker"] == CORRECT

arrow_plot_3d = go.FigureWidget()
arrow_plot_3d.update_layout(**base_layout)

for test_id in test_ids:
    id_mask = test_epochs["id"] == test_id
    cor_mask = id_mask & correct_mask
    err_mask = id_mask & error_mask
    if not cor_mask.any() or not err_mask.any():
        continue  # median of an empty selection is NaN: nothing to draw
    rumination = test_epochs.loc[id_mask, "Rumination Full Scale"].iloc[0]
    # segment from the error median (no marker) to the correct median (small marker)
    arrow_plot_3d.add_scatter3d(
        x=[np.median(xs[err_mask]), np.median(xs[cor_mask])],
        y=[np.median(ys[err_mask]), np.median(ys[cor_mask])],
        z=[np.median(zs[err_mask]), np.median(zs[cor_mask])],
        line_color=matplotlib.colors.rgb2hex(cm.hot(rumination / 5)),
        marker_size=[0, 3],
    )

arrow_plot_3d