In [None]:
%load_ext lab_black
import os
import pickle

import pywt
import mne
import scipy
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from ipywidgets import Dropdown, FloatRangeSlider, IntSlider, FloatSlider, interact
from sklearn.model_selection import train_test_split
from sklearn.decomposition import FastICA
from sklearn.metrics import roc_auc_score

from utils import *

# Load data

In [None]:
# Load the epochs (only the train set).
# `epochs` is a list over participants; each entry is indexed by ERROR / CORRECT
# and holds that class's epochs, indexed EPOCH x CHANNEL x TIMEPOINT:
#   epochs[participant][ERROR or CORRECT] -> EPOCH x CHANNEL x TIMEPOINT
# (NOTE(review): presumably not a single 5D array because participants have
# different epoch counts — confirm in load_all_epochs)

pickled_data = "../data/train_epochs.p"
if os.path.isfile(pickled_data):
    # NOTE: pickle.load can execute arbitrary code — only load caches this
    # notebook produced itself.
    with open(pickled_data, "rb") as pickle_file:
        epochs = pickle.load(pickle_file)
else:
    epochs = load_all_epochs()
    # Cache the data loaded by MNE to save on loading times later. The `with`
    # block guarantees the file is flushed and closed, so a later run never
    # reads a truncated pickle.
    with open(pickled_data, "wb") as pickle_file:
        pickle.dump(epochs, pickle_file)

# sort participants by the number of errors, descending
# this way the best participants are first
epochs.sort(reverse=True, key=lambda e: len(e[ERROR]))

print("participants\t", len(epochs))
print("p0 error\t", epochs[0][ERROR].shape)
print("p0 correct\t", epochs[0][CORRECT].shape)

In [None]:
# Metadata shared by all recordings, read once from a single reference
# recording: the epoch time axis, channel names/positions, and one display
# colour per channel.
_mne_epochs = load_epochs_from_file("../data/responses/GNG_AA0303-64 el.vhdr")
times = _mne_epochs.times

_channel_info = _mne_epochs.info["chs"]
channel_names = [ch["ch_name"] for ch in _channel_info]
# the first three entries of each channel's "loc" are used as its 3-D position
channel_locations = np.array([ch["loc"][:3] for ch in _channel_info])

# Colour channels by position: shift and scale each axis to [0, 1], then map
# to 0-255 (floored) and use (x, y, z) as (r, g, b).
_loc_shifted = channel_locations - channel_locations.min(axis=0)
_loc_unit = _loc_shifted / _loc_shifted.max(axis=0)
_rgb = _loc_unit * 255 // 1
channel_colors = [f"rgb({r:.0f},{g:.0f},{b:.0f})" for r, g, b in _rgb]

# log2 of the CWT frequencies, for plotting CWTs on a log-frequency axis
log_freq = np.log2(get_frequencies())

# Train and test

In [None]:
def train(err, cor, mwt, cwt_density=2, ica_n_components=4, ica_random_state=0):
    """Fit a single-feature error/correct classifier on one participant's epochs.

    Pipeline: CWT of every channel of every epoch -> FastICA on the raw
    (time-domain) signals to obtain spatial filters -> pick the ICA component
    and CWT cell whose spatially filtered value separates the classes best ->
    fit a normal distribution to that value for each class.

    Parameters
    ----------
    err, cor : array-like
        Error / correct epochs, indexed EPOCH x CHANNEL x TIMEPOINT.
    mwt : str
        Wavelet name passed to ``cwt`` (e.g. "mexh").
    cwt_density : int
        Passed to ``cwt``. ``predict`` must use the same value, otherwise the
        returned ``index`` may address a different CWT cell.
    ica_n_components : int
        Number of ICA components to compute.
    ica_random_state : int, RandomState instance or None
        Seed for FastICA's initialisation. It is fixed by default so that
        re-running the notebook gives the same components; pass None for
        unseeded behaviour.

    Returns
    -------
    ica_components : list of ndarray
        ICA spatial filters (one weight per channel), sorted by their best
        separation, descending. Only ``ica_components[0]`` is used.
    index : tuple
        Index (as returned by ``get_best_separation``) of the best-separating
        cell of the spatially filtered CWT — presumably (frequency, timepoint).
    err_distr, cor_distr : callable
        Normal pdfs fitted to that feature for error / correct epochs.
    """

    def _cwt_epochs(epochs):
        # -> 4D array EPOCH x CHANNEL x FREQUENCY x TIMEPOINT
        return np.array(
            [[cwt(ch_signal, mwt, cwt_density) for ch_signal in epoch] for epoch in epochs]
        )

    err_cwts = _cwt_epochs(err)
    cor_cwts = _cwt_epochs(cor)

    # compute ICA on all epochs glued together along time:
    # concat.shape == (num_of_channels, num_of_epochs * timepoints)
    joined_epochs = np.concatenate((err, cor))
    concat = np.concatenate(joined_epochs, axis=1)
    ica = FastICA(n_components=ica_n_components, random_state=ica_random_state)
    # sklearn expects samples x features, i.e. time x channel
    ica.fit(concat.T)
    # ica.components_.shape == (n_components, num_of_channels)

    # score every component by how well it separates; keep its best index too
    scored = []
    for comp in ica.components_:
        comp_index, separations = get_best_separation(err_cwts, cor_cwts, comp)
        scored.append((separations.max(), comp_index, comp))
    # Sort on the score only: plain tuple comparison would fall through to
    # comparing ndarrays on ties and raise "truth value ... is ambiguous".
    scored.sort(key=lambda item: item[0], reverse=True)
    ica_components = [comp for _, _, comp in scored]

    # best separating wavelet feature of the best component
    _, index, spatial_filter = scored[0]
    err_end = filter_(err_cwts, spatial_filter)[:, index[0], index[1]]
    cor_end = filter_(cor_cwts, spatial_filter)[:, index[0], index[1]]

    # fit normal distributions to error and correct final features
    _params = scipy.stats.norm.fit(err_end)
    err_distr = scipy.stats.norm(*_params).pdf
    _params = scipy.stats.norm.fit(cor_end)
    cor_distr = scipy.stats.norm(*_params).pdf

    return ica_components, index, err_distr, cor_distr


def predict(epochs, mwt, ica_component, index, err_distr, cor_distr, cwt_density=2):
    """Return, for each epoch, the posterior probability that it is CORRECT.

    Parameters
    ----------
    epochs : array-like
        Epochs to classify, indexed EPOCH x CHANNEL x TIMEPOINT.
    mwt : str
        Wavelet name; must match the one used in ``train``.
    ica_component : sequence of ndarray
        ICA spatial filters as returned by ``train`` (best first); only the
        first one is used.
    index : tuple
        Feature index into the spatially filtered CWT, as returned by ``train``.
    err_distr, cor_distr : callable
        Class-conditional pdfs of the feature, as returned by ``train``.
    cwt_density : int
        Passed to ``cwt``. It must match the value used in ``train``, otherwise
        ``index`` may point at a different CWT cell. It defaults to ``train``'s
        default.

    Returns
    -------
    ndarray
        P(correct | feature) per epoch, assuming equal class priors. Epochs
        where both pdfs evaluate to 0 (far out in both tails) get 0.5 instead
        of NaN, which ``roc_auc_score`` would reject.
    """
    # 4D array EPOCH x CHANNEL x FREQUENCY x TIMEPOINT
    cwts = np.array(
        [[cwt(ch_signal, mwt, cwt_density) for ch_signal in epoch] for epoch in epochs]
    )

    # Use the filters passed in as an argument, never a notebook-level global.
    spatial_filter = ica_component[0]
    end = filter_(cwts, spatial_filter)[:, index[0], index[1]]

    # Bayes' rule with equal priors; guard against 0/0 when both pdfs underflow
    cor_p = cor_distr(end)
    err_p = err_distr(end)
    total = cor_p + err_p
    return np.divide(
        cor_p, total, out=np.full_like(total, 0.5, dtype=float), where=total > 0
    )

In [None]:
# Per-participant evaluation: train on 60% of each class and report AUROC on
# the held-out 40%. Only the first N_PARTICIPANTS entries of `epochs` are used;
# thanks to the earlier sort, these are the participants with the most errors.
N_PARTICIPANTS = 10
TEST_SIZE = 0.4
# alternative wavelet: complex Morlet, e.g. f"cmor{bandwidth}-1" with bandwidth = 0.5
mwt = "mexh"

aurocs = []
for participant, (err, cor) in enumerate(epochs[:N_PARTICIPANTS]):
    # split out test sets (fixed seed -> reproducible splits)
    err_train, err_test = train_test_split(err, test_size=TEST_SIZE, random_state=0)
    cor_train, cor_test = train_test_split(cor, test_size=TEST_SIZE, random_state=0)

    # train
    ica_components, index, err_distr, cor_distr = train(err_train, cor_train, mwt)

    # test: ERROR is labelled 0 and CORRECT 1, matching predict(), which
    # returns the probability of CORRECT
    y_true = np.concatenate((np.zeros(len(err_test)), np.ones(len(cor_test))))
    test_data = np.concatenate((err_test, cor_test))
    y_pred = predict(test_data, mwt, ica_components, index, err_distr, cor_distr)

    auroc = roc_auc_score(y_true, y_pred)
    aurocs.append(auroc)

    print(
        f"participant: {participant:3}    errors/corrects: {len(err):3}/{len(cor):3}    AUROC: {auroc:.3f}"
    )

print(f"mean AUROC: {np.mean(aurocs):.3f}")