In [None]:
%load_ext lab_black
import os
import pickle
import inspect
from time import time

import pywt
import mne
import scipy
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import xxhash
from cachier import cachier
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
from sklearn.decomposition import FastICA

from utils import *

In [None]:
# ignore FastICA did not converge warnings
# TODO investigate why doesn't it converge
import warnings

warnings.filterwarnings("ignore")

# Load data

#### Read the data into a DataFrame in which each epoch is a single record (row).

In [None]:
# Cached load: unpickle the epochs DataFrame if the cache file exists,
# otherwise build it with `create_df_data` and write the cache for next time.
df_name = "go_nogo_df"
pickled_data_filename = "../data/" + df_name + ".pkl"
info_filename = "../data/Demographic_Questionnaires_Behavioral_Results_N=163.csv"

# Check if data is already loaded
if os.path.isfile(pickled_data_filename):
    print("Pickled file found. Loading pickled data...")
    # NOTE: unpickling can execute arbitrary code - only load caches this
    # notebook wrote itself
    epochs = pd.read_pickle(pickled_data_filename)
else:
    print("Pickled file not found. Loading data...")
    epochs = create_df_data(info_filename=info_filename)
    epochs.name = df_name
    # save loaded data to the same path the check above reads from
    epochs.to_pickle(pickled_data_filename)

display(epochs)

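#### Sort participants by their number of error responses, in descending order. This way the participants with the most error epochs (presumably the best ones for training, since they provide the most error-class examples) come first.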

In [None]:
# Add the number of error / correct responses of each participant as new
# columns, repeated on every row (epoch) of that participant. Vectorised:
# compare the whole column once, then sum the boolean flags per participant.
epochs["error_sum"] = (
    epochs["marker"].eq(ERROR).astype("int64").groupby(epochs["id"]).transform("sum")
)
epochs["correct_sum"] = (
    epochs["marker"].eq(CORRECT).astype("int64").groupby(epochs["id"]).transform("sum")
)

# mergesort for stable sorting: rows with equal error counts keep their
# original relative order
epochs = epochs.sort_values("error_sum", ascending=False, kind="mergesort")

display(epochs)

#### Get metadata (epoch time axis, channel names, channel locations and plot colors) from a single recording

In [None]:
# Read one recording just for its metadata: the epoch time axis, channel
# names and 3-D channel locations (assumes every recording uses the same
# montage and epoch timing - TODO confirm).
_mne_epochs = load_epochs_from_file("../data/responses/GNG_AA0303-64 el.vhdr")
times = _mne_epochs.times

_channel_info = _mne_epochs.info["chs"]
channel_names = [ch["ch_name"] for ch in _channel_info]
channel_locations = np.array([ch["loc"][:3] for ch in _channel_info])

# Turn each (x, y, z) location into an RGB colour: shift every axis so its
# minimum is 0, scale it so its maximum is 1, stretch to 0..255 and floor.
channel_colors = channel_locations - channel_locations.min(axis=0)
channel_colors = np.floor(channel_colors / channel_colors.max(axis=0) * 255)
channel_colors = [f"rgb({c[0]:.0f},{c[1]:.0f},{c[2]:.0f})" for c in channel_colors]

log_freq = np.log2(get_frequencies())  # for plotting CWT

# Train and test

In [None]:
def _numpy_hasher(args, kwargs):
    """Build a cachier cache key for a call to `vectorize`.

    The call's arguments are bound to `vectorize`'s signature (with defaults
    filled in), so positional and keyword calls yield the same key. NumPy
    arrays are unhashable and are replaced by their dtype, shape and a 128-bit
    xxhash digest of their raw bytes.

    Returns a tuple of ``(argument_name, hashable_value)`` pairs sorted by name.
    """

    def make_hashable(value):
        if isinstance(value, np.ndarray):
            # The raw bytes alone are ambiguous: arrays sharing a buffer but
            # differing in dtype or shape would collide, so key on all three.
            # largest hash to minimize collisions
            # NOTE: object-dtype arrays would hash pointers, not contents.
            digest = xxhash.xxh128_digest(value.tobytes())
            return (value.dtype.str, value.shape, digest)
        return value

    bound = inspect.signature(vectorize).bind(*args, **kwargs)
    bound.apply_defaults()
    return tuple(
        (name, make_hashable(value)) for name, value in sorted(bound.arguments.items())
    )


def _fit_wavelet_weights(X_cwts, y, wv_weighting, wv_weighting_n_components):
    """Learn weightings of the wavelet coefficients of one filtered signal.

    X_cwts has shape EPOCH x FREQUENCY x TIMEPOINT and y holds the epoch markers.
    Returns weights shaped WAVELET_COMPONENT x FREQUENCY x TIMEPOINT.
    Raises ValueError for an unknown `wv_weighting`.
    """
    cwt_shape = X_cwts.shape[1:]  # FREQUENCY x TIMEPOINT shape
    X_flattened = X_cwts.reshape(X_cwts.shape[0], -1)

    if wv_weighting == "single":
        # find the best separating wavelet coefficient ('single' = one component)
        separations = get_separations(X_cwts[y == ERROR], X_cwts[y == CORRECT])
        # separations are shaped FREQUENCY x TIMEPOINT
        index = np.unravel_index(separations.argmax(), separations.shape)
        wv_weights = np.zeros((1, *separations.shape))
        wv_weights[(0, *index)] = 1
    elif wv_weighting == "PCA":
        pca = PCA(n_components=wv_weighting_n_components)
        pca.fit(X_flattened)
        wv_weights = pca.components_  # COMPONENT x (FREQUENCY*TIMEPOINT)
    elif wv_weighting == "ICA":
        wv_ica = FastICA(n_components=wv_weighting_n_components, tol=0.001)
        wv_ica.fit(X_flattened)
        wv_weights = wv_ica.components_  # COMPONENT x (FREQUENCY*TIMEPOINT)
    elif wv_weighting == "LDA":
        # The discriminant directions are the COLUMNS of scalings_, shaped
        # (FREQUENCY*TIMEPOINT) x rank with rank <= n_classes - 1, so transpose
        # them into rows. n_components is not passed to LDA: it only affects
        # transform() and raises for values above n_classes - 1.
        lda = LinearDiscriminantAnalysis()
        lda.fit(X_flattened, y)
        wv_weights = lda.scalings_.T[:wv_weighting_n_components]
    else:
        raise ValueError(f"wrong wv_weighting argument: {wv_weighting!r}")

    # unflatten; the component count is inferred because it is 1 for "single"
    # and may be smaller than requested for LDA
    return np.reshape(wv_weights, (-1, *cwt_shape))


@cachier(pickle_reload=False, hash_params=_numpy_hasher)
def vectorize(
    X,
    y,
    mwt="mexh",
    cwt_density=2,
    ica_n_components=3,
    wv_weighting="single",
    wv_weighting_n_components=3,
):
    """Extract ICA + wavelet features from epochs (results cached by cachier).

    Parameters
    ----------
    X : array of shape EPOCHS x CHANNELS x TIMEPOINTS
    y : array of shape EPOCHS with the epoch markers (ERROR / CORRECT)
    mwt, cwt_density : wavelet identifier and density passed to `cwt`
    ica_n_components : number of ICA spatial filters to fit
    wv_weighting : "single", "PCA", "ICA" or "LDA" - how the wavelet
        coefficients are weighted into features
    wv_weighting_n_components : number of wavelet components per spatial
        filter (ignored for "single"; LDA yields at most n_classes - 1)

    Returns
    -------
    features : array of shape EPOCH x ICA_COMP x WAVELET_COMP
    params : list with one ``(spatial_filter, wv_weights)`` tuple per ICA
        component, wv_weights shaped WAVELET_COMP x FREQUENCY x TIMEPOINT

    Raises
    ------
    ValueError : for an unknown `wv_weighting`.
    """
    # fit ICA on all epochs concatenated in time
    concat = np.concatenate(X, axis=1)
    # concat.shape == (num_of_channels, EPOCHS * TIMEPOINTS)
    ica = FastICA(n_components=ica_n_components)
    ica.fit(concat.T)
    # ica.components_.shape == (n_components, num_of_channels)

    params = []
    features = []
    for spatial_filter in ica.components_:
        # apply the spatial filter -> EPOCHS x TIMEPOINTS
        X_filtered = np.tensordot(X, spatial_filter, axes=([1], [0]))
        # continuous wavelet transform -> EPOCH x FREQUENCY x TIMEPOINT
        X_cwts = np.array([cwt(epoch, mwt, cwt_density) for epoch in X_filtered])

        wv_weights = _fit_wavelet_weights(
            X_cwts, y, wv_weighting, wv_weighting_n_components
        )
        # project every epoch's CWT onto each weighting -> EPOCH x WAVELET_COMP
        one_channel_features = np.tensordot(X_cwts, wv_weights, axes=([1, 2], [1, 2]))

        params.append((spatial_filter, wv_weights))
        features.append(one_channel_features)

    # ICA_COMP x EPOCH x WAVELET_COMP -> EPOCH x ICA_COMP x WAVELET_COMP
    features = np.array(features).transpose((1, 0, 2))
    return features, params


def train(
    X,
    y,
    mwt="mexh",
    cwt_density=2,
    ica_n_components=3,
    wv_weighting="single",
    wv_weighting_n_components=3,
):
    """Fit the feature extractor and an LDA classifier on one training set.

    X is shaped EPOCHS x CHANNELS x TIMEPOINTS and y is shaped EPOCHS; all other
    arguments are forwarded unchanged to `vectorize`.

    Returns ``(params, clf)``: the per-ICA-component ``(spatial_filter,
    wv_weights)`` tuples produced by `vectorize` and the fitted classifier.
    """
    epoch_features, params = vectorize(
        X, y, mwt, cwt_density, ica_n_components, wv_weighting, wv_weighting_n_components
    )

    # EPOCH x ICA_COMP x WAVELET_COMP -> EPOCH x (ICA_COMP*WAVELET_COMP)
    n_epochs = epoch_features.shape[0]
    flat_features = epoch_features.reshape(n_epochs, -1)

    # TODO maybe balance class sizes or priors somehow?
    clf = LinearDiscriminantAnalysis().fit(flat_features, y)
    return params, clf


def predict(epochs, params, clf, mwt="mexh", cwt_density=2):
    """Return the classifier's probability of ``clf.classes_[1]`` per epoch.

    Re-applies the feature extraction learned by `train`: for every
    ``(spatial_filter, wv_weights)`` pair in `params`, the epochs
    (EPOCHS x CHANNELS x TIMEPOINTS) are spatially filtered, wavelet-transformed
    with ``cwt(epoch, mwt, cwt_density)`` and projected onto the wavelet
    weights. `mwt` and `cwt_density` should match the values used in training.
    Returns an array of shape EPOCHS.
    """

    def channel_features(spatial_filter, wv_weights):
        # EPOCHS x TIMEPOINTS after spatial filtering
        filtered = np.tensordot(epochs, spatial_filter, axes=([1], [0]))
        # EPOCH x FREQUENCY x TIMEPOINT
        cwts = np.array([cwt(epoch, mwt, cwt_density) for epoch in filtered])
        # EPOCH x WAVELET_COMP
        return np.tensordot(cwts, wv_weights, axes=([1, 2], [1, 2]))

    # stacking on axis 1 yields EPOCH x ICA_COMP x WAVELET_COMP directly
    stacked = np.stack(
        [channel_features(spatial_filter, wv_weights) for spatial_filter, wv_weights in params],
        axis=1,
    )
    # flatten into EPOCH x (ICA_COMP*WAVELET_COMP)
    flat = stacked.reshape(stacked.shape[0], -1)
    return clf.predict_proba(flat)[:, 1]

In [None]:
def benchmark(epochs, test_on_train_set=False, verbose=False, **hyperparams):
    """Per-participant 4-fold cross-validated AUROC of the train/predict pipeline.

    Parameters
    ----------
    epochs : DataFrame with one row per epoch and at least the columns "id",
        "epoch" (CHANNELS x TIMEPOINTS array) and "marker"; with
        ``verbose=True`` also "error_sum" and "correct_sum".
    test_on_train_set : if True, each fold is scored on its own training data
        (an overfitting sanity check, not a real test score).
    verbose : print a per-participant table, the training time and the mean.
    **hyperparams : forwarded to `train`; "mwt" and "cwt_density", when given,
        are also forwarded to `predict` so test features match training.

    Returns
    -------
    str : ``"<mean AUROC> ± <standard error>"`` with three decimals.
    """
    # imported explicitly: a bare `import scipy` does not load scipy.stats on
    # older SciPy versions
    from scipy.stats import sem

    start = time()
    if verbose:
        print("participant            AUROC   err/corr")
    aurocs = []
    auroc_sems = []

    # predict() must rebuild features with the same wavelet settings train()
    # used. Forward only the settings the caller passed; otherwise predict's
    # defaults apply, and they equal train's defaults.
    predict_kwargs = {
        key: hyperparams[key] for key in ("mwt", "cwt_density") if key in hyperparams
    }

    # group by the column name (not a 1-element list) so get_group() accepts a
    # scalar participant id
    grouped = epochs.groupby("id")
    for participant_id in epochs["id"].unique():
        participant_df = grouped.get_group(participant_id)

        X = np.array(participant_df["epoch"].to_list())

        # you can change y set in a easy way ---> y=np.array(participant_df["column_name"].to_list())
        y = np.array(participant_df["marker"].to_list())

        aurocs_personal = []
        # stratified 4-fold cross-validation within one participant
        skf = StratifiedKFold(n_splits=4)
        for train_index, test_index in skf.split(X, y):
            X_train, X_test = X[train_index], X[test_index]
            y_train, y_test = y[train_index], y[test_index]
            if test_on_train_set:
                X_test = X_train
                y_test = y_train

            # train
            params, clf = train(X_train, y_train, **hyperparams)

            # test
            y_pred = predict(X_test, params, clf, **predict_kwargs)

            auroc = roc_auc_score(y_test, y_pred)
            aurocs_personal.append(auroc)

        aurocs.append(np.mean(aurocs_personal))
        auroc_sems.append(sem(aurocs_personal))

        if verbose:
            error_size = participant_df["error_sum"].iloc[0]
            correct_size = participant_df["correct_sum"].iloc[0]

            print(
                f"{participant_id:11}    "
                f"{aurocs[-1]:.3f} ± {auroc_sems[-1]:.3f}    "
                f"{error_size:3}/{correct_size:3}"
            )

    # standard error of the mean of independent per-participant means:
    # sqrt(sum(sem_i ** 2)) / N
    total_sem = sum(np.array(auroc_sems) ** 2) ** (1 / 2) / len(auroc_sems)
    mean_auroc = f"{np.mean(aurocs):.3f} ± {total_sem:.3f}"
    if verbose:
        print(f"\ntraining time: {(time() - start) / 60:.0f} min")
        print("mean AUROC: " + mean_auroc)
    return mean_auroc

In [None]:
# Baseline: one best-separating wavelet coefficient per ICA spatial filter.
# (The PCA weighting is evaluated by the sweeps below.)
print("single wavelet choice")
auroc = benchmark(epochs, mwt="mexh", wv_weighting="single", verbose=True)
print(auroc)

In [None]:
print("finding the best number of PCA wavelet components")
# sweep 1..6 PCA wavelet components while keeping 3 ICA spatial filters
for n_wv_components in range(1, 7):
    auroc = benchmark(
        epochs,
        mwt="mexh",
        wv_weighting="PCA",
        wv_weighting_n_components=n_wv_components,
        ica_n_components=3,
    )
    print(f"{n_wv_components}   {auroc}")

In [None]:
print("finding the best number of ICA components")
# sweep 1..8 ICA spatial filters while keeping 3 PCA wavelet components
for n_ica_components in range(1, 9):
    auroc = benchmark(
        epochs,
        mwt="mexh",
        wv_weighting="PCA",
        wv_weighting_n_components=3,
        ica_n_components=n_ica_components,
    )
    print(f"{n_ica_components}   {auroc}")

In [None]:
# NOTE(review): `aurocs` is only a local variable inside benchmark(), which
# returns a formatted string, so these plots would fail if uncommented;
# benchmark would first need to return the per-participant scores.
# px.scatter(y=aurocs)
# px.scatter(y=sorted(aurocs))