In [None]:
# x_loc = channel_locations.T[0]
# y_loc = channel_locations.T[1]
# mne.viz.plot_topomap(ica.components_[0], np.stack((x_loc, y_loc), axis=-1))

# def plot_pca_comps_on_cwt(pca_comps):
#     amplitude = 0.1
#     fig = go.FigureWidget(make_subplots(rows=len(pca_comps)))
#     fig.update_layout(**base_layout)
#     fig.update_layout(height=350, width=600)
#     fig.update_xaxes(visible=False)
#     for i, comp in enumerate(pca_comps):
#         comp = comp.reshape(-1, timepoints_count)
#         fig.add_heatmap(
#             z=comp,
#             x=times,
#             row=i + 1,
#             col=1,
#             zmin=-amplitude,
#             zmax=amplitude,
#             y=log_freq,
#             colorscale=blue_black_red,
#         )
#     return fig

In [None]:
%load_ext lab_black
%load_ext autoreload
%autoreload 2
import os
import math
import pickle
import inspect
import itertools
from time import time
from copy import deepcopy

import pywt
import mne
import scipy
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import xxhash
import matplotlib
import matplotlib.cm as cm
from cachier import cachier
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
from sklearn.decomposition import FastICA
from sklearn.linear_model import ElasticNet
from sklearn.linear_model import Lasso
from sklearn.linear_model import LinearRegression
from ipywidgets import HBox, VBox
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact

import sys

sys.path.append("..")

from utils import *
from architecture import *
from visualization_helpers import *

In [None]:
np.set_printoptions(precision=3, suppress=True)

# ignore FastICA did not converge warnings
# TODO investigate why doesn't it converge
import warnings

warnings.filterwarnings("ignore")

# Load data

In [None]:
df_name = "go_nogo_df"
pickled_data_filename = "../../data/" + df_name + ".pkl"
info_filename = "../../data/Demographic_Questionnaires_Behavioral_Results_N=163.csv"

# Load the epochs DataFrame, using a pickle next to the raw data as a cache.
# NOTE: read_pickle can execute arbitrary code - only load trusted cache files.
if os.path.isfile(pickled_data_filename):
    print("Pickled file found. Loading pickled data...")
    epochs = pd.read_pickle(pickled_data_filename)
else:
    print("Pickled file not found. Loading data...")
    epochs = create_df_data(info_filename=info_filename)
    # save loaded data into a pickle file
    epochs.to_pickle(pickled_data_filename)
# pandas does not pickle ad-hoc DataFrame attributes such as `.name`, so set it
# after loading on both paths (it used to be missing on the cached path).
epochs.name = df_name

# Per-participant counts of ERROR and CORRECT responses, broadcast to every row.
marker_by_id = epochs.groupby("id")["marker"]
epochs["error_sum"] = marker_by_id.transform(lambda markers: (markers == ERROR).sum())
epochs["correct_sum"] = marker_by_id.transform(
    lambda markers: (markers == CORRECT).sum()
)

# Participants with the most errors first; mergesort is stable, so ties keep
# their original order.
epochs = epochs.sort_values("error_sum", ascending=False, kind="mergesort")

# Read one raw recording only to obtain the epoch time axis and channel metadata.
_mne_epochs = load_epochs_from_file("../../data/responses/GNG_AA0303-64 el.vhdr")
times = _mne_epochs.times

_channel_info = _mne_epochs.info["chs"]
channel_locations = np.array([ch["loc"][:3] for ch in _channel_info])
channel_names = [ch["ch_name"] for ch in _channel_info]

# Colour each channel by its 3-D position. Shift to >= 0 and scale each axis to
# [0, 1]; an axis with zero extent maps to 0 instead of nan. Then map to 0..255.
channel_colors = channel_locations - channel_locations.min(axis=0)
extent = channel_colors.max(axis=0)
channel_colors = np.divide(
    channel_colors, extent, out=np.zeros_like(channel_colors), where=extent > 0
)
channel_colors = channel_colors * 255 // 1
channel_colors = [f"rgb({c[0]:.0f},{c[1]:.0f},{c[2]:.0f})" for c in channel_colors]

log_freq = np.log2(get_frequencies())  # for plotting CWT

# Train and test

In [None]:
# Joblib cache directory for fitted pipeline steps (passed as `memory=` below).
# expanduser avoids hard-coding one user's home directory.
cachedir = os.path.expanduser("~/.erpinator_cache")

# steps = steps_simple  # one PCA for all

# Start from the parallel-PCA template and adapt it.
steps = deepcopy(steps_parallel_pca)
steps.pop(3)  # remove CWT
steps.pop(-1)  # remove scaler

# StandardScaler doesn't seem to change anything for LDA
# steps = steps[:-2] + [("lasso", Lasso())]
# steps = steps[:-2] + [("lda", LinearDiscriminantAnalysis())]
# steps = steps[:-1] + [("knr", KNeighborsRegressor())]
steps = steps[:-1] + [("lasso", Lasso())]  # replace the final estimator

steps[1] = ("spatial_filter", PCA(random_state=0))  # replace ICA with PCA

# Hyper-parameter grid (one value each -> a single combination).
regressor_params = dict(
    spatial_filter__n_components=[4],
    #     cwt__mwt=["morl"],
    #     cwt__octaves=[4],
    pca__n_components=[8],
    # featurize__power__cwt__mwt=["cmor0.5-1"],
    # featurize__power__pca__n_components=[3],
    # featurize__shape__cwt__mwt=["mexh"],
    # featurize__shape__pca__n_components=[3],
    #     svr__C=[0.1],
    #     knr__n_neighbors=[11],
    lasso__alpha=[0.0000003],
    # lda__solver=["lsqr"],  # to turn off scaling, to simplify visualizing
)
steps

In [None]:
# Display the hyper-parameter grid passed to the grid search.
regressor_params

### One model for all people

In [None]:
# Stack all epochs into one array; the trial marker (ERROR/CORRECT) is the target.
# Uncomment the [::20] slices to subsample for quick experiments.
X = np.array(epochs["epoch"].to_list())  # [::20]
y = np.array(epochs["marker"].to_list())  # [::20]
# y = np.array(epochs["Rumination Full Scale"].to_list())  # [::20]

In [None]:
%%time

# Grid search over `regressor_params` with 2-fold CV, caching steps via `memory=cachedir`.
# `pipelines` is indexed by `split_index` below - presumably one fitted pipeline per CV split.
pipelines = custom_gridsearch(X, y, steps, cv=2, regressor_params=regressor_params, memory=cachedir)

# Visualize spatial components

In [None]:
# For the first 10 participants (plot_limit=10), plot their average CORRECT epoch
# minus their average ERROR epoch, passed through the spatial filters of the
# pipeline fitted on CV split `split_index`.
# flip_mask presumably flips the sign of individual components for display - TODO confirm.


split_index = 0
visualize_spatial_components(
    pipelines[split_index],
    epochs,
    channel_locations,
    channel_names,
    times,
    plot_limit=10,
    erp_type="error",
    max_amp=0.00014,
    flip_mask=[-1, 1, 1, 1],
    scale=0.8,
)

# Visualize temporal components

In [None]:
# For each spatial filter, plot the shape its PCA components are trying to match.
# The yellow line shows the sum of all the shapes: it's the shape that the whole
# model tries to match for that spatial filter.

split_index = 1
visualize_pipeline(
    pipelines[split_index],
    channel_locations,
    channel_names,
    times,
    heatmap=False,
    one_pca=False,
    flip_mask=[-1, 1, -1, 1],
    max_amp=230,
    scale=0.8,
)

In [None]:
# # check that PCA+LDA can be replaced by just a dot product with a shape computed from weighted PCA comps

# split_index = 0
# fitted_steps = dict(pipelines[split_index].steps)
# pcas = fitted_steps["pca"].PCAs
# ica_comp_num = 0
# pca_comps = pcas[ica_comp_num].components_
# lda = fitted_steps["lda"]

# # do ICA steps
# X_ = pipelines[0].steps[0][1].transform(X)
# X_ = pipelines[0].steps[1][1].transform(X_)
# X_ = pipelines[0].steps[2][1].transform(X_)

# # compute a shape from the weighted PCA comps
# acc = np.zeros_like(times)
# for comp, coef in zip(pca_comps, lda.coef_[0]):
#     acc += comp * coef

# features = []
# for epoch in X_[0]:
#     feature = np.sum(acc * epoch)
#     features.append(feature)
# features = np.array(features).reshape(-1)

# real_features = pipelines[split_index].decision_function(X)

# corr = np.corrcoef(features, real_features)[0][1]
# np.isclose(corr, 1)
# # if True, the dot product gives the same as the normal pipeline execution

# Test spatial filter (PCA/ICA) stability

In [None]:
# cv must be 2 here!
# With more folds the training sets overlap, so the components would look stable
# only trivially: they were fitted on largely the same data.

In [None]:
# Spatial filters learned in each CV split: components_ of pipeline step 1, which
# the setup cell sets to the "spatial_filter" PCA.
spatial_filters = [pipeline.steps[1][1].components_ for pipeline in pipelines]

In [None]:
# Compare the spatial filters learned on the two CV splits (star-imported helper).
print("correlations between factors found in the first, and the second split")
correlations(spatial_filters[0], spatial_filters[1])

In [None]:
# Similarity measure between the two splits' spatial filters (star-imported helper).
factor_similarity(spatial_filters[0], spatial_filters[1])

In [None]:
# print(
#     "similarity measures between factors found in each pair of splits, for a single participant"
# )
# similarities = np.array(
#     [
#         [factor_similarity(sf_i, sf_j) for sf_i in spatial_filters]
#         for sf_j in spatial_filters
#     ]
# )
# print(similarities)
# print("mean", similarities.mean())

In [None]:
# # mne plotting for comparison
# x, y, z = channel_locations.T
# mne.viz.plot_topomap(
#     spatial_filters[participant][2], np.stack((x, y), axis=-1)
# )

In [None]:
# # try to find corresponding components
# best_similarity = 0
# for perm in itertools.permutations(range(3)):
#     perm = list(perm)
#     diag = corr[perm].diagonal()
#     similarity = abs(diag).mean()
#     if similarity > best_similarity:
#         best_similarity = similarity
#         best_perm = perm

# print(best_similarity)
# print(best_perm)
# corr[best_perm]

# Visualize personal differences

In [None]:
# Split *participants* (not epochs) in half: one half trains the common model and
# the other half is held out for validation, so no participant is in both sets.
# random_state fixes the shuffle, so the split is reproducible across runs.

ids = epochs["id"].unique()

kf = KFold(n_splits=2, shuffle=True, random_state=0)
train_index, test_index = next(kf.split(ids))
train_ids = ids[train_index]
test_ids = ids[test_index]

train_epochs = epochs[epochs["id"].isin(train_ids)]
test_epochs = epochs[epochs["id"].isin(test_ids)]

X_train = np.array(train_epochs["epoch"].to_list())
y_train = np.array(train_epochs["marker"].to_list())
X_test = np.array(test_epochs["epoch"].to_list())
y_test = np.array(test_epochs["marker"].to_list())


# Fit a fresh copy of the pipeline with the first parameter combination
# (the grid currently has only one).
pipeline = Pipeline(deepcopy(steps), memory=cachedir)
pipeline.set_params(**ParameterGrid(regressor_params)[0])
pipeline.fit(X_train, y_train)

In [None]:
# Scores on the held-out participants. An LDA classifier is scored on the
# probability of the positive class; regressors (e.g. Lasso) on their prediction.
# Checks the fitted pipeline's own final estimator rather than the `steps` template.
if isinstance(pipeline[-1], LinearDiscriminantAnalysis):
    y_pred = pipeline.predict_proba(X_test)[:, 1]
else:
    y_pred = pipeline.predict(X_test)

auroc = roc_auc_score(y_test, y_pred)
corr = np.corrcoef(y_test, y_pred)[0][1]
r2 = r2_score(y_test, y_pred)
auroc, corr, r2

In [None]:
# Per-feature scale learned by the "scaler" step (x1e4 for readability).
# The setup cell pops a step commented "remove scaler", so the step may be
# absent; report that instead of raising KeyError.
sc = pipeline.named_steps.get("scaler")
sc.scale_ * 10000 if sc is not None else "pipeline has no 'scaler' step"

In [None]:
# Scores on the training participants (compare with the held-out scores to gauge
# overfitting). An LDA classifier is scored on the positive-class probability;
# regressors on their prediction. Checks the fitted pipeline's own final estimator.
if isinstance(pipeline[-1], LinearDiscriminantAnalysis):
    y_pred = pipeline.predict_proba(X_train)[:, 1]
else:
    y_pred = pipeline.predict(X_train)

auroc = roc_auc_score(y_train, y_pred)
corr = np.corrcoef(y_train, y_pred)[0][1]
r2 = r2_score(y_train, y_pred)
auroc, corr, r2

In [None]:
# Temporal components of the common model trained on half of the participants.
# channel_locations, channel_names and times are passed positionally, as in the
# other visualize_pipeline call in this notebook (they were missing here).
visualize_pipeline(
    pipeline,
    channel_locations,
    channel_names,
    times,
    one_pca=True,
    flip_mask=[-1, 1, 1, 1],
)

In [None]:
# Every step except the final regressor. The fitted step objects are reused, so
# no refit is needed; this maps raw epochs to the features the regressor sees.
truncated_pipeline = Pipeline(pipeline.steps[:-1])

In [None]:
# adapted from https://github.com/eriklindernoren/ML-From-Scratch/blob/master/mlfromscratch/supervised_learning/regression.py
class l_half_regularization:
    """L1/2 penalty: ``alpha * sum(sqrt(|w| + eps))``.

    Non-convex penalty that pushes weights towards exact zero more strongly
    than L1. ``eps`` keeps the value and gradient finite at ``w == 0``.

    Parameters
    ----------
    alpha : float
        Penalty strength; 0 disables the penalty.
    eps : float, optional
        Small constant added to ``|w|`` before the square root
        (default 1e-5, the value previously hard-coded).
    """

    def __init__(self, alpha, eps=0.00001):
        self.alpha = alpha
        self.eps = eps

    def __call__(self, w):
        """Return the scalar penalty for weight vector ``w``."""
        return self.alpha * np.sum(np.sqrt(np.abs(w) + self.eps))
        # return 0  # to see only the fit error

    def grad(self, w):
        """Return d(penalty)/dw; 0 where ``w == 0`` because of ``np.sign``."""
        return self.alpha * 0.5 / np.sqrt(np.abs(w) + self.eps) * np.sign(w)


class Regression(object):
    """Linear regression trained by batch gradient descent.

    Base class: a subclass must set ``self.regularization`` before ``fit`` is
    called. It must be a callable returning a scalar penalty and have a ``grad``
    method (see ``LHalfRegression``).

    Parameters
    -----------
    n_iterations: int
        Number of gradient-descent steps performed by each call to ``fit``.
    learning_rate: float
        Step size used when updating the weights.

    Attributes
    ----------
    w : ndarray of shape (n_features + 1,)
        Weights; ``w[0]`` is the bias (a column of ones is prepended to X).
    training_errors : list of float
        Loss recorded at every iteration; accumulates across ``reinit=False`` fits.
    """

    def __init__(self, n_iterations, learning_rate):
        self.n_iterations = n_iterations
        self.learning_rate = learning_rate

    def initialize_weights(self, n_features):
        """Initialize weights uniformly at random in [-1/sqrt(N), 1/sqrt(N)]."""
        limit = 1 / math.sqrt(n_features)
        self.w = np.random.uniform(-limit, limit, (n_features,))

    def fit(self, X, y, reinit=True):
        """Run ``n_iterations`` gradient-descent steps on (X, y).

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
        y : array of shape (n_samples,)
        reinit : bool
            If True (default), start from fresh random weights and clear
            ``training_errors``. If False, continue from the current weights
            (e.g. after changing the penalty strength). If the model was never
            fitted, it falls back to a fresh start instead of failing.
        """
        # Insert constant ones for bias weights
        X = np.insert(X, 0, 1, axis=1)
        if reinit or not hasattr(self, "w"):
            self.training_errors = []
            self.initialize_weights(n_features=X.shape[1])

        for _ in range(self.n_iterations):
            y_pred = X.dot(self.w)
            residual = y - y_pred
            # Recorded loss: mean half squared error + penalty. The penalty is a
            # scalar, so adding it inside np.mean equals adding it outside.
            loss = np.mean(0.5 * residual ** 2 + self.regularization(self.w))
            self.training_errors.append(loss)
            # NOTE(review): the data term of this gradient is *summed* over samples
            # while the recorded loss is a mean, and the bias w[0] is penalized too.
            # Kept as-is because learning_rate/alpha in this notebook are tuned to it.
            grad_w = -residual.dot(X) + self.regularization.grad(self.w)
            self.w -= self.learning_rate * grad_w

    def predict(self, X):
        """Return ``X @ w[1:] + w[0]`` for X of shape (n_samples, n_features)."""
        # Insert constant ones for bias weights
        X = np.insert(X, 0, 1, axis=1)
        y_pred = X.dot(self.w)
        return y_pred


class LHalfRegression(Regression):
    """Gradient-descent linear regression with an L1/2 weight penalty.

    Parameters
    ----------
    reg_factor : float
        Penalty strength, stored as ``self.regularization.alpha``; it can be
        changed between ``fit(..., reinit=False)`` calls.
    n_iterations : int
        Gradient-descent steps per ``fit`` call.
    learning_rate : float
        Gradient-descent step size.
    """

    def __init__(self, reg_factor, n_iterations=1000, learning_rate=0.001):
        super().__init__(n_iterations, learning_rate)
        self.regularization = l_half_regularization(alpha=reg_factor)

In [None]:
# prepare data
features_train = truncated_pipeline.transform(X_train)
features_test = truncated_pipeline.transform(X_test)

In [None]:
# Initialize and fit an unpenalized (alpha=0) L1/2 regressor as the starting point;
# the next cell raises alpha and continues from these weights.
# Seed NumPy's global RNG so the random weight initialization is reproducible.
np.random.seed(0)
reg = LHalfRegression(0, n_iterations=6000, learning_rate=0.0000001)
reg.fit(features_train, y_train)

In [None]:
# Manually continue training with a different penalty strength: reinit=False keeps
# the current weights and appends to reg.training_errors.
# Suggested consecutive alphas: 0, 3, 10, 30, 10
# NOTE: not idempotent - every run trains another n_iterations steps.
reg.regularization.alpha = 10
reg.fit(features_train, y_train, reinit=False)

In [None]:
# Current regressor weights; reg.w[0] is the bias (fit prepends a column of ones).
reg.w

In [None]:
# Scores of the L1/2 regressor on the held-out participants.
y_pred = reg.predict(features_test)

auroc, corr, r2 = (
    roc_auc_score(y_test, y_pred),
    np.corrcoef(y_test, y_pred)[0, 1],
    r2_score(y_test, y_pred),
)

auroc, corr, r2

In [None]:
# Scores of the L1/2 regressor on the training participants
# (compare with the held-out scores to gauge overfitting).
y_pred = reg.predict(features_train)

auroc, corr, r2 = (
    roc_auc_score(y_train, y_pred),
    np.corrcoef(y_train, y_pred)[0, 1],
    r2_score(y_train, y_pred),
)
auroc, corr, r2

In [None]:
# Loss recorded at every gradient-descent iteration (accumulates across
# reinit=False fits) - use it to check convergence.
px.scatter(reg.training_errors)

In [None]:
# Same visualization, using the L1/2 regressor's feature weights as coefficients
# (the bias reg.w[0] is excluded). channel_locations, channel_names and times are
# passed positionally, as in the other visualize_pipeline call in this notebook
# (they were missing here).
visualize_pipeline(
    truncated_pipeline,
    channel_locations,
    channel_names,
    times,
    one_pca=True,
    flip_mask=[-1, 1, 1, 1],
    clf_coefs_all=reg.w[1:],
)

In [None]:
# Features whose L1/2-regression weight is non-negligible (bias w[0] excluded).
# The 2-D plots below need at least two of them, and the 3-D plot needs three.
weight_threshold = 0.01
indices = np.where(np.abs(reg.w[1:]) > weight_threshold)[0]
assert 2 <= len(indices) <= 3, f"expected 2 or 3 selected features, got {len(indices)}"
indices

In [None]:
# Test-set values of the selected features, used as plot coordinates below.
xs = features_test[:, indices[0]]
ys = features_test[:, indices[1]]
# Always (re)define zs, so a stale value from an earlier run with three
# selected features cannot leak into the 3-D plot.
zs = features_test[:, indices[2]] if len(indices) >= 3 else None

In [None]:
# Plot test-set feature points for all participants, coloured by marker
# (ERROR vs CORRECT). Only every `skip`-th epoch is drawn, to keep the plot light.
feature_plot_2d = go.FigureWidget()
feature_plot_2d.update_layout(**base_layout)
max_amp = 4
feature_plot_2d.update_layout(
    width=600,
    height=600,
    xaxis_range=[-max_amp, max_amp],
    yaxis_range=[-max_amp, max_amp],
)
skip = 16
feature_plot_2d.add_scatter(
    x=xs[::skip],
    y=ys[::skip],
    marker_color=test_epochs["marker"][::skip],
    mode="markers",
    marker_size=4,
    marker_colorscale=blue_black_red,
)

In [None]:
# Plot the feature points of a single test participant:
# green = CORRECT epochs, red = ERROR epochs.
participant_num = 1
test_id = test_ids[participant_num]

error_mask = test_epochs["marker"] == ERROR
correct_mask = test_epochs["marker"] == CORRECT
id_mask = test_epochs["id"] == test_id

x_cor, y_cor = xs[id_mask & correct_mask], ys[id_mask & correct_mask]
x_err, y_err = xs[id_mask & error_mask], ys[id_mask & error_mask]

max_amp = 4
grouped_plot_2d = go.FigureWidget()
grouped_plot_2d.update_layout(**base_layout)
grouped_plot_2d.update_layout(
    width=600,
    height=600,
    xaxis_range=[-max_amp, max_amp],
    yaxis_range=[-max_amp, max_amp],
)
for x_vals, y_vals, color in ((x_cor, y_cor, "green"), (x_err, y_err, "red")):
    grouped_plot_2d.add_scatter(
        x=x_vals, y=y_vals, marker_color=color, mode="markers", marker_size=4
    )
grouped_plot_2d

In [None]:
# For each test participant, draw an arrow in feature space from the median of
# their ERROR epochs to the median of their CORRECT epochs.
# Arrows are coloured by the chosen column (value / 5 on the "hot" colormap);
# a hotter colour means a higher value.
# Participants lacking either ERROR or CORRECT test epochs are skipped: the
# median of an empty selection is nan and would produce a broken arrow.


@interact(column=Dropdown(value="Sex", options=epochs.columns))
def update_plots(column):
    """Redraw the ERROR -> CORRECT median arrows, coloured by ``column``."""
    error_mask = test_epochs["marker"] == ERROR
    correct_mask = test_epochs["marker"] == CORRECT

    arrow_plot_2d = go.FigureWidget()
    arrow_plot_2d.update_layout(**base_layout)
    max_amp = 2.7
    arrow_plot_2d.update_layout(
        width=600,
        height=600,
        xaxis_range=[-max_amp, max_amp],
        yaxis_range=[-max_amp, max_amp],
    )

    for test_id in test_ids:
        id_mask = test_epochs["id"] == test_id
        cor = (id_mask & correct_mask).to_numpy()
        err = (id_mask & error_mask).to_numpy()
        if not (cor.any() and err.any()):
            continue
        color_val = test_epochs.loc[id_mask, column].iloc[0] / 5
        # In plotly annotations (x, y) is the arrow head and (ax, ay) its tail.
        arrow_plot_2d.add_annotation(
            x=np.median(xs[cor]),
            y=np.median(ys[cor]),
            ax=np.median(xs[err]),
            ay=np.median(ys[err]),
            xref="x",
            yref="y",
            axref="x",
            ayref="y",
            text="",  # arrow only, no label
            showarrow=True,
            arrowhead=3,
            arrowsize=1,
            arrowwidth=1,
            arrowcolor=matplotlib.colors.rgb2hex(cm.hot(color_val)),
        )

    display(arrow_plot_2d)

In [None]:
# For each test participant, draw a segment in 3-D feature space from the median
# of their ERROR epochs to the median of their CORRECT epochs. The CORRECT end
# gets a marker, so the direction is visible.
# Segments are coloured by rumination (value / 5 on the "hot" colormap); a
# hotter colour means higher rumination.
# Participants lacking either ERROR or CORRECT test epochs are skipped, because
# their medians would be nan.

# zs is only meaningful when at least three features were selected.
assert len(indices) >= 3, "the 3-D plot needs at least three selected features"

error_mask = test_epochs["marker"] == ERROR
correct_mask = test_epochs["marker"] == CORRECT

arrow_plot_3d = go.FigureWidget()
arrow_plot_3d.update_layout(**base_layout)

for test_id in test_ids:
    id_mask = test_epochs["id"] == test_id
    cor = (id_mask & correct_mask).to_numpy()
    err = (id_mask & error_mask).to_numpy()
    if not (cor.any() and err.any()):
        continue
    rumination = test_epochs.loc[id_mask, "Rumination Full Scale"].iloc[0]
    arrow_plot_3d.add_scatter3d(
        x=[np.median(xs[err]), np.median(xs[cor])],
        y=[np.median(ys[err]), np.median(ys[cor])],
        z=[np.median(zs[err]), np.median(zs[cor])],
        line_color=matplotlib.colors.rgb2hex(cm.hot(rumination / 5)),
        marker_size=[0, 3],
    )

arrow_plot_3d