In [None]:
# x_loc = channel_locations.T[0]
# y_loc = channel_locations.T[1]
# mne.viz.plot_topomap(ica.components_[0], np.stack((x_loc, y_loc), axis=-1))

In [None]:
%load_ext lab_black
import os
import pickle
import inspect
import itertools
from time import time
from copy import deepcopy

import pywt
import mne
import scipy
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import xxhash
from cachier import cachier
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
from sklearn.decomposition import FastICA
from ipywidgets import HBox, VBox
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact

from utils import *
from architecture import *

In [None]:
# Print numpy arrays with 3 decimal places to keep cell outputs compact.
np.set_printoptions(precision=3)

# Ignore only the "FastICA did not converge" warnings. A blanket
# `filterwarnings("ignore")` would also hide unrelated warnings
# (deprecations, numerical problems, ...) from the rest of the notebook.
# TODO investigate why FastICA doesn't converge
import warnings

# `message` is a regex matched against the start of the warning text
warnings.filterwarnings("ignore", message="FastICA did not converge")

# Load data

#### The data is read into a DataFrame in which each epoch is a single record.

In [None]:
df_name = "go_nogo_df"
pickled_data_filename = "../data/" + df_name + ".pkl"
info_filename = "../data/Demographic_Questionnaires_Behavioral_Results_N=163.csv"

# Build the DataFrame from the raw files only when no pickled copy exists yet;
# otherwise reuse the pickled copy.
if not os.path.isfile(pickled_data_filename):
    print("Pickled file not found. Loading data...")
    epochs = create_df_data(info_filename=info_filename)
    epochs.name = df_name
    # store the freshly built DataFrame so the next run can load the pickle
    epochs.to_pickle("../data/" + epochs.name + ".pkl")
else:
    print("Pickled file found. Loading pickled data...")
    epochs = pd.read_pickle(pickled_data_filename)

# epochs

#### Sort participants by their number of errors, in descending order, so that the best participants (those with the most errors) come first.

In [None]:
# For every participant count the error / correct responses and store the
# counts next to each of that participant's epochs.
grouped = epochs.groupby("id")
for _column, _marker in (("error_sum", ERROR), ("correct_sum", CORRECT)):
    epochs[_column] = grouped[["marker"]].transform(
        lambda x, _marker=_marker: (x.values == _marker).sum()
    )

# a stable sort (mergesort) keeps the existing order among equal error counts
epochs = epochs.sort_values(by="error_sum", ascending=False, kind="mergesort")
# epochs

#### Get metadata

In [None]:
# Read the time axis and the channel description from a single recording.
_mne_epochs = load_epochs_from_file("../data/responses/GNG_AA0303-64 el.vhdr")
times = _mne_epochs.times

_channel_info = _mne_epochs.info["chs"]
channel_names = []
_locations = []
for _ch in _channel_info:
    channel_names.append(_ch["ch_name"])
    _locations.append(_ch["loc"][:3])  # keep only the first three coordinates
channel_locations = np.array(_locations)

# Turn every channel position into an RGB colour: rescale each axis to
# [0, 1], stretch it to [0, 255] and round down.
_shifted = channel_locations - channel_locations.min(axis=0)
_unit = _shifted / _shifted.max(axis=0)
_rgb = np.floor_divide(_unit * 255, 1)
channel_colors = [f"rgb({c[0]:.0f},{c[1]:.0f},{c[2]:.0f})" for c in _rgb]

log_freq = np.log2(get_frequencies())  # for plotting CWT

# Train and test

In [None]:
# Directory used by the pipelines to cache fitted transformers. It is resolved
# from the current user's home directory instead of a hard-coded absolute
# path, so the notebook also runs on other machines / accounts.
cachedir = os.path.join(os.path.expanduser("~"), ".erpinator_cache")

# steps = steps_simple
steps = steps_parallel_pca
# StandardScaler doesn't seem to change anything for LDA
# (slicing builds a new list, so the imported step list itself is not mutated)
steps = steps[:-2] + [("lda", LinearDiscriminantAnalysis())]
# steps = steps[:-1] + [("knr", KNeighborsRegressor())]
# steps = steps[:-1] + [("lasso", Lasso())]

steps.pop(3)  # remove CWT
# steps.pop(3)  # remove PCA

# the step keeps the name "ica" so that the `ica__*` parameters and the
# visualisation cells below keep working
steps[1] = ("ica", PCA(random_state=0))  # replace ICA with PCA

# parameter grid; every key is `<step name>__<parameter name>`
regressor_params = dict(
    ica__n_components=[2],
    #     cwt__mwt=["morl"],
    #     cwt__octaves=[4],
    pca__n_components=[3],
    # featurize__power__cwt__mwt=["cmor0.5-1"],
    # featurize__power__pca__n_components=[3],
    # featurize__shape__cwt__mwt=["mexh"],
    # featurize__shape__pca__n_components=[3],
    #     svr__C=[0.1],
    #     knr__n_neighbors=[11],
    #     lasso__alpha=[0.2, 0.5, 1],
    lda__solver=["lsqr"],  # to turn off scaling, to simplify visualizing
)
steps

### Separate model for each person

In [None]:
%%time


# Train and evaluate a separate model for every participant with a
# stratified 2-fold cross-validation, and print one AUROC row per participant.
print("participant            AUROC   err/corr")
aurocs = []
auroc_sems = []
pipelines = []

# group data by participants' ids
# (grouping by the plain column name lets `get_group` take the scalar id;
# with a length-1 list key recent pandas versions expect a tuple instead)
grouped = epochs.groupby("id")
for participant_id in epochs["id"].unique():
    participant_df = grouped.get_group(participant_id)

    X = np.array(participant_df["epoch"].to_list())
    y = np.array(participant_df["marker"].to_list())

    aurocs_personal = []
    pipelines_personal = []
    skf = StratifiedKFold(n_splits=2)
    for train_index, test_index in skf.split(X, y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        # Build a fresh pipeline for every fold. Re-fitting one shared object
        # would make every entry of `pipelines_personal` refer to the model
        # fitted on the last fold only.
        pipeline = Pipeline(deepcopy(steps), memory=cachedir)
        pipeline.set_params(**ParameterGrid(regressor_params)[0])
        pipeline.fit(X_train, y_train)

        # a classifier is scored on the probability of class 1,
        # any other final step on its raw prediction
        if isinstance(pipeline.steps[-1][1], LinearDiscriminantAnalysis):
            y_pred = pipeline.predict_proba(X_test)[:, 1]
        else:
            y_pred = pipeline.predict(X_test)
        # corr = np.corrcoef(y_test, y_pred)[0][1]
        # r2 = r2_score(y_test, y_pred)
        auroc = roc_auc_score(y_test, y_pred)
        aurocs_personal.append(auroc)
        pipelines_personal.append(pipeline)

    aurocs.append(np.mean(aurocs_personal))
    auroc_sems.append(scipy.stats.sem(aurocs_personal))
    pipelines.append(pipelines_personal)

    # the counts are constant within a participant, so the first row is enough
    error_size = participant_df["error_sum"].iloc[0]
    correct_size = participant_df["correct_sum"].iloc[0]
    print(
        f"{participant_id:11}    "
        f"{aurocs[-1]:.3f} ± {auroc_sems[-1]:.3f}    "
        f"{error_size:3}/{correct_size:3}"
    )

# standard error of the mean of the per-participant AUROCs:
# sqrt(sum of squared SEMs) / number of participants
total_sem = sum(np.array(auroc_sems) ** 2) ** (1 / 2) / len(auroc_sems)
mean_auroc = f"{np.mean(aurocs):.3f} ± {total_sem:.3f}"
print("mean AUROC: " + mean_auroc)

### One model for all people

In [None]:
def custom_gridsearch(steps, cv, regressor_params, memory, X=None, y=None):
    """Cross-validate the pipeline for every combination in a parameter grid.

    For each parameter combination a fresh pipeline is fitted on every fold of
    an (unshuffled) `KFold` split, and AUROC, correlation and R2 between the
    true and the predicted labels are printed per fold, followed by the
    parameters and the mean ± SEM of the three scores.

    Parameters
    ----------
    steps : list
        Pipeline steps; they are deep-copied for every fitted pipeline.
    cv : int
        Number of `KFold` splits.
    regressor_params : dict
        Parameter grid, expanded with `ParameterGrid`.
    memory : str or None
        Passed to `Pipeline(memory=...)`.
    X, y : array_like, optional
        Data and labels. When omitted, the notebook-level `X` and `y`
        variables are used, which keeps calls without data working.

    Returns
    -------
    list
        Fitted pipelines, one per fold, for the LAST parameter combination of
        the grid only. Empty when the grid has no combinations.
    """
    # fall back to the notebook globals for callers that do not pass the data
    if X is None:
        X = globals()["X"]
    if y is None:
        y = globals()["y"]

    print("AUROC   corr     r2")

    # parameter combinations in grid order (shuffling is currently disabled)
    all_params = list(ParameterGrid(regressor_params))
    # shuffle(all_params)

    # defined up front so that an empty grid returns [] instead of raising
    pipelines = []
    for params in all_params:
        pipelines = []
        scores = []
        kf = KFold(n_splits=cv)
        for train_index, test_index in kf.split(X, y):
            X_train, X_test = X[train_index], X[test_index]
            y_train, y_test = y[train_index], y[test_index]

            pipeline = Pipeline(deepcopy(steps), memory=memory)
            pipeline.set_params(**params)
            pipeline.fit(X_train, y_train)

            # a classifier is scored on the probability of class 1,
            # any other final step on its raw prediction
            if isinstance(pipeline.steps[-1][1], LinearDiscriminantAnalysis):
                y_pred = pipeline.predict_proba(X_test)[:, 1]
            else:
                y_pred = pipeline.predict(X_test)
            corr = np.corrcoef(y_test, y_pred)[0][1]
            r2 = r2_score(y_test, y_pred)
            auroc = roc_auc_score(y_test, y_pred)  # it's different in classification!

            scores.append([auroc, corr, r2])
            print(f"{auroc:.3f}  {corr:.3f}  {r2:.3f}")

            pipelines.append(pipeline)

        # print scores
        print(f"{str(params):126}")
        means = np.mean(scores, axis=0)
        sems = scipy.stats.sem(scores, axis=0)
        for mean, sem in zip(means, sems):
            print(f"{mean:5.2f}±{sem:4.2f}", end="   ")
        print()

    # note that it returns pipelines only for last parameters in the grid
    return pipelines

In [None]:
# Epochs and labels of all participants together, as plain numpy arrays.
X = np.asarray(epochs["epoch"].tolist())  # [::20]
y = np.asarray(epochs["marker"].tolist())  # [::20]

In [None]:
%%time

# keep the fitted pipelines (one per fold) for the visualisation cells below
pipelines = custom_gridsearch(
    steps=steps,
    cv=2,
    regressor_params=regressor_params,
    memory=cachedir,
)

# Visualize components

In [None]:
scale = 0.6


def plot_ica_comp(ica_comp):
    """Draw one spatial component as coloured markers at the channel positions.

    Parameters
    ----------
    ica_comp : sequence
        One value per channel, in the order of `channel_names`; it is used as
        the marker colour.

    Returns
    -------
    go.FigureWidget
        Scatter plot of the channels' x/y locations.
    """
    xs, ys, zs = channel_locations.T

    # sort by z_loc for prettier printing: channels are added in ascending z
    ordered = sorted(zip(zs, xs, ys, channel_names, ica_comp))
    _, xs_sorted, ys_sorted, names_sorted, weights_sorted = zip(*ordered)

    side = 350 * scale
    fig = go.FigureWidget()
    fig.update_layout(**base_layout)
    fig.update_layout(width=side, height=side)
    fig.add_scatter(
        x=xs_sorted,
        y=ys_sorted,
        text=names_sorted,
        mode="markers",
        marker=dict(
            color=weights_sorted,
            size=42 * scale,
            colorscale=blue_black_red,
        ),
    )
    return fig


# def plot_pca_comps_on_cwt(pca_comps):
#     amplitude = 0.1
#     fig = go.FigureWidget(make_subplots(rows=len(pca_comps)))
#     fig.update_layout(**base_layout)
#     fig.update_layout(height=350, width=600)
#     fig.update_xaxes(visible=False)
#     for i, comp in enumerate(pca_comps):
#         comp = comp.reshape(-1, timepoints_count)
#         fig.add_heatmap(
#             z=comp,
#             x=times,
#             row=i + 1,
#             col=1,
#             zmin=-amplitude,
#             zmax=amplitude,
#             y=log_freq,
#             colorscale=blue_black_red,
#         )
#     return fig


def plot_pca_shape(pca_comps, mwt, clf_coefs):
    """Plot the time-domain shape of every PCA component and their sum.

    Each component is weighted by its classifier coefficient and drawn as one
    trace over `times`; a thick yellow trace shows the sum of all the
    weighted shapes.

    Parameters
    ----------
    pca_comps : sequence of arrays
        PCA components. Without a CWT each one is already a time course; with
        a CWT each one is reshaped to rows of `timepoints_count` amplitudes,
        one row per frequency returned by `get_frequencies()`.
    mwt : str or None
        Mother wavelet passed to `get_wavelet`, or None when the pipeline has
        no CWT step.
    clf_coefs : sequence of float
        One classifier weight per PCA component.

    Returns
    -------
    go.FigureWidget
    """
    # CWT+PCA in practice, multiplies by this shape
    fig = go.FigureWidget()
    fig.update_layout(**base_layout)
    fig.update_layout(height=350 * scale, width=600 * scale)

    # running sum of the weighted time-domain shapes
    total = np.zeros_like(times)
    for i, comp in enumerate(pca_comps):
        if mwt is not None:
            # TODO test this block, if this is really the shape
            # rebuild the time course as a sum of wavelets, one per
            # (frequency, latency) pair, scaled by the component's amplitude
            comp = comp.reshape(-1, timepoints_count)
            acc = np.zeros_like(times)
            for amps_for_freq, freq in zip(comp, get_frequencies()):
                for amp, latency in zip(amps_for_freq, times):
                    wv = get_wavelet(latency, freq, times, mwt)
                    acc += wv * amp
        else:
            acc = np.copy(comp)

        # weight by the component importance from LDA
        acc *= clf_coefs[i]
        fig.add_scatter(x=times, y=acc)

        # Sum the shapes that were just drawn. Summing the raw `comp * coef`
        # only works without a CWT: with a CWT the raw components are not
        # time courses and do not have the length of `times`.
        total += acc

    # show also the sum of all pca comps weighted by importance
    fig.add_scatter(x=times, y=total, line_width=5, line_color="yellow")

    return fig

In [None]:
# visualizing for steps with separate PCA for each ICA component

split_index = 0
fitted_steps = dict(pipelines[split_index].steps)
ica = fitted_steps["ica"]
pcas = fitted_steps["pca"].PCAs
lda = fitted_steps["lda"]

# one row of classifier coefficients per ICA component
n_ica_comps = len(ica.components_)
clf_coefs_for_each_ica_comp = lda.coef_[0].reshape(n_ica_comps, -1)
# NOTE: the transposed unflattening, lda.coef_[0].reshape(-1, len(ica.components_)).T,
# was tested visually to be the wrong one

for ica_comp_num, ica_comp in enumerate(ica.components_):
    pca_comps = pcas[ica_comp_num].components_
    clf_coefs = clf_coefs_for_each_ica_comp[ica_comp_num]
    # a mother wavelet exists only when the pipeline has a CWT step
    mwt = fitted_steps["cwt"].mwt if "cwt" in fitted_steps else None

    # spatial map on the left, weighted PCA shapes on the right
    topography = plot_ica_comp(ica_comp)
    shapes = plot_pca_shape(pca_comps, mwt, clf_coefs)
    display(HBox([topography, shapes]))
    # display(HBox([plot_ica_comp(ica_comp), plot_pca_comps_on_cwt(pca_comps)]))

In [None]:
# visualizing for steps with only one PCA

split_index = 0
fitted_steps = dict(pipelines[split_index].steps)
ica = fitted_steps["ica"]
pca = fitted_steps["pca"]
lda = fitted_steps["lda"]

# cut every PCA component into one chunk per ICA component
n_pca_comps, n_ica_comps = len(pca.components_), len(ica.components_)
pca_comps_separated = pca.components_.reshape(n_pca_comps, n_ica_comps, -1)

for ica_comp_num, ica_comp in enumerate(ica.components_):
    pca_comps = pca_comps_separated[:, ica_comp_num, :]
    clf_coefs = lda.coef_[0]
    # a mother wavelet exists only when the pipeline has a CWT step
    mwt = fitted_steps["cwt"].mwt if "cwt" in fitted_steps else None

    # spatial map on the left, weighted PCA shapes on the right
    topography = plot_ica_comp(ica_comp)
    shapes = plot_pca_shape(pca_comps, mwt, clf_coefs)
    display(HBox([topography, shapes]))
    # display(HBox([plot_ica_comp(ica_comp), plot_pca_comps_on_cwt(pca_comps)]))

In [None]:
# check that PCA+LDA can be replaced by just a dot product with a shape computed from weighted PCA comps

# which cross-validation split and which ICA component to check
split_index = 0
ica_comp_num = 0

fitted_pipeline = pipelines[split_index]
fitted_steps = dict(fitted_pipeline.steps)
pcas = fitted_steps["pca"].PCAs
pca_comps = pcas[ica_comp_num].components_
lda = fitted_steps["lda"]

# do ICA steps
# Run the first three steps of the SAME split that is inspected above
# (split 0 used to be hard-coded here regardless of `split_index`).
X_ = X
for _, _step in fitted_pipeline.steps[:3]:
    X_ = _step.transform(X_)

# LDA coefficients belonging to the chosen ICA component, assuming the
# coefficients are laid out as one equally sized chunk per ICA component
clf_coefs = lda.coef_[0].reshape(len(fitted_steps["ica"].components_), -1)[
    ica_comp_num
]

# compute a shape from the weighted PCA comps
acc = np.zeros_like(times)
for comp, coef in zip(pca_comps, clf_coefs):
    acc += comp * coef

# dot product of the shape with every epoch of the chosen component
# NOTE(review): assumes the first axis of X_ indexes ICA components - confirm
features = np.array([np.sum(acc * epoch) for epoch in X_[ica_comp_num]])

real_features = fitted_pipeline.decision_function(X)

corr = np.corrcoef(features, real_features)[0][1]
np.isclose(corr, 1)
# if True, the dot product gives the same as the normal pipeline execution

# Testing ICA stability

In [None]:
# cv should be 2!
# Otherwise the components will be stable, but only trivially:
#  - the splits are trained on overlapping data, so it is no wonder they are similar

In [None]:
def correlations(a0, a1):
    """Find correlation matrix between 2 matrices.
    It's similar to np.corrcoef, but it doesn't subtract the mean,
    when calculating the sum of squares (i.e. it is the cosine similarity
    between the rows of `a0` and the rows of `a1`).

    Parameters
    ----------
    a0, a1 : array_like
        2-D arrays containing multiple variables and observations.
        Each row represents a variable, and each column a single
        observation of all those variables.
        Their number of columns must be equal.
        Nested lists are accepted, and a 1-D input is treated as a single
        variable (one row).

    Returns
    -------
    numpy.ndarray
        Array of shape (rows of a0, rows of a1); element [i, j] compares
        row i of `a0` with row j of `a1`. A row consisting only of zeros
        yields NaN entries, because its sum of squares is zero.
    """
    # accept anything array-like, as the docstring promises
    a0 = np.atleast_2d(a0)
    a1 = np.atleast_2d(a1)

    cov = a0 @ a1.T
    sum_of_squares0 = np.sum(a0 * a0, axis=1).reshape(-1, 1)
    sum_of_squares1 = np.sum(a1 * a1, axis=1).reshape(1, -1)
    # (n, 1) @ (1, m) is the outer product of the two sums of squares
    return cov / np.sqrt(sum_of_squares0 @ sum_of_squares1)


def factor_similarity(a0, a1):
    """Return a single score telling how similar two sets of factors are.

    The score does not change when the factors are reordered or rescaled
    (including a flipped sign).
    """
    # only the magnitude matters: a factor's sign may be flipped
    similarity = np.abs(correlations(a0, a1))

    # best match of every factor along each axis, so the order is irrelevant
    best_per_column = similarity.max(axis=0)
    best_per_row = similarity.max(axis=1)

    # when a row or a column has two candidates the two averages differ;
    # report the more pessimistic one
    # TODO? a more robust way would be to generate permutations and check them
    return min(best_per_column.mean(), best_per_row.mean())

In [None]:
# components of the fitted second step (index 1) of every split's pipeline
spatial_filters = []
for _fitted in pipelines:
    spatial_filters.append(_fitted.steps[1][1].components_)

In [None]:
print("correlations between factors found in the first, and the second split")
# rows: factors of the first split, columns: factors of the second split
correlations(a0=spatial_filters[0], a1=spatial_filters[1])

In [None]:
# single similarity score between the factors of the first two splits
factor_similarity(a0=spatial_filters[0], a1=spatial_filters[1])

In [None]:
# print(
#     "similarity measures between factors found in each pair of splits, for a single participant"
# )
# similarities = np.array(
#     [
#         [factor_similarity(sf_i, sf_j) for sf_i in spatial_filters]
#         for sf_j in spatial_filters
#     ]
# )
# print(similarities)
# print("mean", similarities.mean())

In [None]:
# # mne plotting for comparison
# x, y, z = channel_locations.T
# mne.viz.plot_topomap(
#     spatial_filters[participant][2], np.stack((x, y), axis=-1)
# )

In [None]:
# # try to find corresponding components
# best_similarity = 0
# for perm in itertools.permutations(range(3)):
#     perm = list(perm)
#     diag = corr[perm].diagonal()
#     similarity = abs(diag).mean()
#     if similarity > best_similarity:
#         best_similarity = similarity
#         best_perm = perm

# print(best_similarity)
# print(best_perm)
# corr[best_perm]