In [None]:
%load_ext lab_black
import os
import pickle
from time import time

import pywt
import mne
import scipy
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from sklearn.decomposition import FastICA
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from utils import *

# Load epochs and recording metadata

In [None]:
# Load the epochs (train set only).
# `epochs` is a list indexed as
# PARTICIPANTS x [ERROR, CORRECT] x EPOCH x CHANNEL x TIMEPOINT

pickled_data = "../data/train_epochs.p"
if os.path.isfile(pickled_data):
    # NOTE: pickle.load can execute arbitrary code -- only load a cache this notebook wrote.
    with open(pickled_data, "rb") as cache_file:
        epochs = pickle.load(cache_file)
else:
    epochs = load_all_epochs()
    # Cache the MNE-loaded data to save on loading times later.
    # The `with` block guarantees the file is flushed and closed.
    with open(pickled_data, "wb") as cache_file:
        pickle.dump(epochs, cache_file)

# Sort participants by their number of error epochs, descending,
# so the participants with the most errors come first.
epochs.sort(reverse=True, key=lambda e: len(e[ERROR]))

print("participants\t", len(epochs))
print("p0 error\t", epochs[0][ERROR].shape)
print("p0 correct\t", epochs[0][CORRECT].shape)


# Get metadata (time axis, channel names/locations) from one recording.
_mne_epochs = load_epochs_from_file("../data/responses/GNG_AA0303-64 el.vhdr")
times = _mne_epochs.times

_channel_info = _mne_epochs.info["chs"]
channel_locations = np.array([ch["loc"][:3] for ch in _channel_info])
channel_names = [ch["ch_name"] for ch in _channel_info]

# Map each channel's 3D position to an RGB colour:
# min-max normalise every axis to [0, 1], then scale to 0-255.
channel_colors = channel_locations - channel_locations.min(axis=0)
axis_ranges = channel_colors.max(axis=0)
axis_ranges[axis_ranges == 0] = 1  # avoid 0/0 when all channels share a coordinate
channel_colors = channel_colors / axis_ranges
channel_colors = channel_colors * 255 // 1
channel_colors = [f"rgb({c[0]:.0f},{c[1]:.0f},{c[2]:.0f})" for c in channel_colors]

log_freq = np.log2(get_frequencies())  # for plotting CWT

# Train and test

In [None]:
def train(
    X,
    y,
    mwt="mexh",
    cwt_density=2,
    ica_n_components=4,
    wavelet_choice="single",
    random_state=None,
):
    """Learn ICA spatial filters and, per filter, the most separating CWT coefficient.

    Parameters
    ----------
    X : array, EPOCHS x CHANNELS x TIMEPOINTS
    y : array, EPOCHS -- labels compared against ERROR and CORRECT
    mwt : mother wavelet name, forwarded to `cwt`
    cwt_density : forwarded to `cwt`
    ica_n_components : number of ICA components, i.e. candidate spatial filters
    wavelet_choice : only "single" (pick the one best CWT coefficient) is implemented
    random_state : forwarded to FastICA; pass an int for reproducible filters

    Returns
    -------
    list of (separation, spatial_filter, wavelet_weights, err_pdf, cor_pdf)
    tuples, sorted by separation, descending.

    Raises
    ------
    NotImplementedError : wavelet_choice == "LDA" (placeholder, not implemented)
    ValueError : any other unknown wavelet_choice
    """
    # Validate first, so a bad argument fails before the expensive ICA fit.
    if wavelet_choice == "LDA":
        raise NotImplementedError("wavelet_choice='LDA' is not implemented yet")
    if wavelet_choice != "single":
        raise ValueError("wrong wavelet_choice argument")

    # Compute ICA on all epochs glued together along time.
    concat = np.concatenate(X, axis=1)
    # concat.shape == (num_of_channels, epochs * timepoints)
    ica = FastICA(n_components=ica_n_components, random_state=random_state)
    ica.fit(concat.T)  # samples x channels; only the learned filters are needed
    # ica.components_.shape == (n_components, num_of_channels)

    features = []
    for spatial_filter in ica.components_:
        # apply the spatial filter -> EPOCHS x TIMEPOINTS
        X_filtered = filter_(X, spatial_filter)

        # apply CWT -> EPOCH x FREQUENCY x TIMEPOINT
        X_cwts = np.array([cwt(epoch, mwt, cwt_density) for epoch in X_filtered])

        # Find the best separating wavelet coefficient.
        # `get_separations` returns a FREQUENCY x TIMEPOINT score; argmax => higher is better.
        separations = get_separations(X_cwts[y == ERROR], X_cwts[y == CORRECT])
        index = np.unravel_index(separations.argmax(), separations.shape)
        wavelet_weights = np.zeros_like(separations)
        wavelet_weights[index] = 1

        # project each epoch's CWT onto the weights -> one scalar feature per epoch
        X_end = np.tensordot(X_cwts, wavelet_weights, axes=([1, 2], [0, 1]))

        # Fit a normal distribution to the error and to the correct final feature.
        err_distr = scipy.stats.norm(*scipy.stats.norm.fit(X_end[y == ERROR])).pdf
        cor_distr = scipy.stats.norm(*scipy.stats.norm.fit(X_end[y == CORRECT])).pdf

        features.append(
            (separations[index], spatial_filter, wavelet_weights, err_distr, cor_distr)
        )

    # Sort by separation only, descending. Comparing whole tuples would fall
    # through to numpy arrays on ties and raise ValueError.
    features.sort(key=lambda feature: feature[0], reverse=True)
    return features


def predict(epochs, features, mwt="mexh", cwt_density=2):
    """Return, per epoch, the posterior probability that the epoch is CORRECT.

    Only the best feature (features[0], as sorted by `train`) is used.
    `mwt` and `cwt_density` must match the values used in `train`.

    Parameters
    ----------
    epochs : array, EPOCHS x CHANNELS x TIMEPOINTS
    features : list returned by `train`

    Returns
    -------
    array, EPOCHS -- P(correct | feature) under equal priors. The value is 0.5
    where both class likelihoods are zero (far tails underflow), instead of NaN.
    """
    _, spatial_filter, wavelet_weights, err_distr, cor_distr = features[0]

    filtered = filter_(epochs, spatial_filter)

    cwts = np.array([cwt(epoch, mwt, cwt_density) for epoch in filtered])
    # EPOCH x FREQUENCY x TIMEPOINT

    end = np.tensordot(cwts, wavelet_weights, axes=([1, 2], [0, 1]))

    # Bayes rule with equal priors; evaluate each pdf once.
    cor_likelihood = np.asarray(cor_distr(end), dtype=float)
    err_likelihood = np.asarray(err_distr(end), dtype=float)
    evidence = cor_likelihood + err_likelihood
    # Far from both fitted means both pdfs can underflow to 0 -> 0/0 = NaN,
    # which roc_auc_score rejects. Report "undecided" (0.5) there instead.
    predictions = np.full(evidence.shape, 0.5)
    np.divide(cor_likelihood, evidence, out=predictions, where=evidence > 0)
    return predictions

In [None]:
# Cross-validated AUROC per participant.
N_PARTICIPANTS = 3  # `epochs` is sorted by error count, so these are the richest participants
N_SPLITS = 4  # stratified folds; each class needs at least N_SPLITS epochs
MWT = "mexh"  # mother wavelet; alternative: complex Morlet, e.g. f"cmor{0.5}-1"

start = time()
print("participant            AUROC   err/corr")
aurocs = []
auroc_sems = []
for participant, err_cor in enumerate(epochs[:N_PARTICIPANTS]):
    # index by ERROR/CORRECT, consistent with how `epochs` is sorted above
    err, cor = err_cor[ERROR], err_cor[CORRECT]
    X = np.concatenate((err, cor))
    y = np.concatenate((np.full(len(err), ERROR), np.full(len(cor), CORRECT)))

    aurocs_personal = []
    skf = StratifiedKFold(n_splits=N_SPLITS)
    for train_index, test_index in skf.split(X, y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        features = train(X_train, y_train, MWT, wavelet_choice="single")
        y_pred = predict(X_test, features, MWT)  # P(correct)

        # the positive class is CORRECT, matching what `predict` returns
        auroc = roc_auc_score(y_test == CORRECT, y_pred)
        aurocs_personal.append(auroc)

    aurocs.append(np.mean(aurocs_personal))
    auroc_sems.append(scipy.stats.sem(aurocs_personal))

    print(
        f"{participant:11}    "
        f"{aurocs[-1]:.3f} ± {auroc_sems[-1]:.3f}    "
        f"{len(err):3}/{len(cor):3}"
    )

print(f"\ntraining time: {(time() - start) / 60:.0f} min")
# standard error of the mean of independent per-participant means: sqrt(sum sem_i^2) / n
total_sem = np.sqrt(np.sum(np.square(auroc_sems))) / len(auroc_sems)
print(f"mean AUROC: {np.mean(aurocs):.3f} ± {total_sem:.3f}")

In [None]:
# px.scatter(y=aurocs)