# EEG classification

### Imports

In [None]:
%load_ext lab_black
import matplotlib
import matplotlib.pyplot as plt
import mne
import numpy as np
from utils import load_gonogo_responses, tmax, tmin

# plt.style.use("dark_background")

## Constructing epoched data 

In [None]:
epochs = load_gonogo_responses()

### Data Visualisation

In [None]:
epochs.plot(
    n_epochs=1,
    event_colors={0: "g", 1: "m"},
)
None  # prevents doubled output

In [None]:
correct_response_epochs = epochs["correct_response"]
error_response_epochs = epochs["error_response"]


# Average each event set (one evoked response per condition)
correct_response_evoked = correct_response_epochs.average()
error_response_evoked = error_response_epochs.average()

In [None]:
# Averages of two event sets

mne.viz.plot_compare_evokeds(
    dict(
        correct_response=correct_response_evoked, error_response=error_response_evoked
    ),
    legend="upper left",
    show_sensors="upper right",
    ylim=dict(eeg=[-10, 10]),
    invert_y=True,
    combine="mean",
)

In [None]:
# Averages of error response events per channel

error_response_evoked.plot_joint(picks="eeg")
error_response_evoked.plot_topomap(times=[0.0, 0.08, 0.1, 0.12, 0.2], ch_type="eeg")
None

In [None]:
# Difference wave per channel: correct minus error (weights=[1, -1])

evoked_diff = mne.combine_evoked(
    [correct_response_evoked, error_response_evoked], weights=[1, -1]
)
evoked_diff.plot_joint()
None

In [None]:
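# Per-condition mean over epochs, shape (n_channels, n_times).
# get_data() is the public accessor; _data is private and may not be loaded.
events_mean_dict = {}

for key in epochs.event_id:
    events_mean_dict[key + "_mean"] = epochs[key].get_data().mean(axis=0)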

In [None]:
# Chart with averages of correct and error responses per channel;
# each channel is shifted vertically so the traces do not overlap

colors = ["b", "r", "g"]
times_ms = epochs.times * 1e3  # x-axis in ms relative to the event (starts at tmin)

plt.figure(figsize=(10, 10))

for color, (key, mean_data) in zip(colors, events_mean_dict.items()):
    # mean_data: (n_channels, n_times); 10 µV spacing between channel traces,
    # sized from the data instead of a hard-coded np.arange(...) range
    offsets = np.arange(mean_data.shape[0]) * 10e-6
    lines = plt.plot(times_ms, mean_data.T + offsets, color=color)
    lines[0].set_label(key)  # one legend entry per condition, not per channel

plt.yticks([])
plt.xlabel("milliseconds", fontsize=15)
plt.ylabel("channels", fontsize=15)
plt.legend(loc="upper left")
plt.show()

## Pre-processing

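**Pre-processing done with BrainVision software:**

- Filtering 0.05-25 Hz (this is a band-pass range; check whether a notch filter was also applied)
- Baseline correction (open question: which baseline interval?)
- Ocular correction
- Artifact rejection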
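Baseline correction itself is simple: for every epoch and channel, subtract the mean of a reference window. The sketch below assumes a `(None, 0)` window, i.e. everything before the event, which is also the default `baseline` of MNE's `Epochs` (and `epochs.apply_baseline((None, 0))` does the same on an MNE object). Which interval BrainVision actually used is still the open question above; the function name and the synthetic data are illustrative only.

```python
import numpy as np


def baseline_correct(data, times, baseline=(None, 0.0)):
    """Subtract the mean of a baseline window from every epoch and channel.

    data: array of shape (n_epochs, n_channels, n_times)
    times: sample times in seconds relative to the event, shape (n_times,)
    baseline: (start, stop) in seconds; None means "from the first sample" or
        "up to the last sample" (the window itself is an assumption)
    """
    start, stop = baseline
    mask = np.ones(times.shape, dtype=bool)
    if start is not None:
        mask &= times >= start
    if stop is not None:
        mask &= times <= stop
    # Mean over the baseline samples, kept as a trailing axis for broadcasting
    offset = data[..., mask].mean(axis=-1, keepdims=True)
    return data - offset


# Synthetic example: 3 epochs, 2 channels, about -0.2 to 0.5 s at 256 Hz,
# with a constant 5 µV offset that baseline correction should remove
rng = np.random.default_rng(0)
times = np.arange(-51, 128) / 256
data = rng.normal(scale=1e-6, size=(3, 2, times.size)) + 5e-6
corrected = baseline_correct(data, times)
```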

**TOTHINK**

- downsampling
- additional bandpass filter
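Both items could be prototyped on the arrays returned by `epochs.get_data()`. A minimal sketch, assuming SciPy is available: a zero-phase Butterworth band-pass (`butter` + `sosfiltfilt`) and integer-factor downsampling with `resample_poly`, which applies an anti-aliasing low-pass before decimating. The helper names, cut-offs and filter order are assumptions; MNE offers equivalents directly on the epochs object (`epochs.filter(l_freq, h_freq)`, `epochs.resample(sfreq)`), and filtering is generally safer on continuous data than on short epochs because of edge effects.

```python
import numpy as np
from scipy.signal import butter, resample_poly, sosfiltfilt


def bandpass(x, fs, low, high, order=4):
    """Zero-phase Butterworth band-pass along the last (time) axis."""
    sos = butter(order, [low, high], btype="bandpass", fs=fs, output="sos")
    # sosfiltfilt runs the filter forward and backward, so there is no phase shift
    return sosfiltfilt(sos, x, axis=-1)


def downsample(x, fs, factor):
    """Downsample by an integer factor; resample_poly low-passes before decimating."""
    return resample_poly(x, up=1, down=factor, axis=-1), fs / factor


# Synthetic check: 10 Hz "signal" plus 50 Hz "mains noise", 2 s at 256 Hz
fs = 256
t = np.arange(2 * fs) / fs
x = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 50 * t)
filtered = bandpass(x, fs, low=1.0, high=25.0)
downsampled, new_fs = downsample(filtered, fs, factor=2)
```

Downsampling also interacts with the DWT plan below: at 128 Hz a 5-level decomposition would end at a 0-2 Hz approximation band instead of 0-4 Hz, so the number of levels would have to change with the sampling rate.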

## Feature extraction

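A feature extraction method often recommended for EEG data is the **wavelet transform**, especially the **discrete wavelet transform** (DWT). For biomedical signals it is preferred over the FFT because EEG is non-stationary: wavelets localize signal content in both time and frequency, whereas the FFT describes only the frequency content of the whole window. A multilevel DWT splits the signal into frequency sub-bands that can be matched to the five classical EEG bands.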

https://en.wikipedia.org/wiki/Discrete_wavelet_transform

https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0173138&type=printable
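To make the decomposition concrete, here is a dependency-free sketch using the simplest discrete wavelet (Haar) in plain NumPy. The function names are hypothetical; in practice PyWavelets' `pywt.wavedec(signal, wavelet, level=5)` does this for any discrete wavelet and returns coefficients in the same `[cA5, cD5, ..., cD1]` order, although its boundary handling gives slightly different array lengths. Each level halves the frequency band, which is why five levels at 256 Hz bring the approximation band down to 0-4 Hz (delta).

```python
import numpy as np


def haar_dwt_step(x):
    """One level of the orthonormal Haar DWT: returns (approximation, detail)."""
    if x.size % 2:  # odd length: repeat the last sample (simple boundary choice)
        x = np.append(x, x[-1])
    even, odd = x[0::2], x[1::2]
    approx = (even + odd) / np.sqrt(2)  # low-pass half of the band
    detail = (even - odd) / np.sqrt(2)  # high-pass half of the band
    return approx, detail


def haar_wavedec(x, level):
    """Multilevel Haar DWT, ordered like pywt.wavedec: [cA_n, cD_n, ..., cD1]."""
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(level):
        approx, detail = haar_dwt_step(approx)
        details.append(detail)  # cD1 first, cD_n last
    return [approx] + details[::-1]


def subband_ranges(fs, level):
    """Nominal frequency range (Hz) of each array returned by haar_wavedec."""
    nyquist = fs / 2
    approx_band = (0.0, nyquist / 2**level)
    detail_bands = [(nyquist / 2**j, nyquist / 2 ** (j - 1)) for j in range(level, 0, -1)]
    return [approx_band] + detail_bands


# Example: one second of a synthetic signal at 256 Hz, 5 levels
fs = 256
x = np.random.default_rng(0).normal(size=fs)
coeffs = haar_wavedec(x, level=5)
for name, arr, band in zip(
    ["cA5", "cD5", "cD4", "cD3", "cD2", "cD1"], coeffs, subband_ranges(fs, 5)
):
    print(f"{name}: {arr.size:3d} coefficients, {band[0]:g}-{band[1]:g} Hz")
```

Because the Haar transform is orthonormal, the total energy of all coefficient arrays equals the energy of the input signal, so no information is lost by the decomposition.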

**Plan:**
- **5-level DWT**, because at a 256 Hz sampling rate it decomposes the signal into sub-bands matching the five EEG bands: delta (δ), theta (θ), alpha (α), beta (β), and gamma (γ)

    (The six sub-bands, namely cD1, cD2, cD3, cD4, cD5 and cA5, cover the frequency range of the band-limited EEG signal, where cA denotes the approximation coefficients and cD the detail coefficients of the decomposition.) At 256 Hz they nominally span cA5 0-4 Hz (δ), cD5 4-8 Hz (θ), cD4 8-16 Hz (α), cD3 16-32 Hz (β), cD2 32-64 Hz (γ) and cD1 64-128 Hz.

- choose the mother wavelet (**MWT**):
    - pick a few wavelet families to compare (still need to read up on which work best)


- **parametrize the classifier** by wavelet (filter) name (move vectorisation into the classifier)
    - consider: after decomposing each signal into frequency bands, **featurize each band separately** with a chosen function (mean, std, etc.)


- run the classifiers and print a **comparison** of all "parameter" combinations (classifier × MWT × small_function)
    - for the ANN, e.g. stop after ~1000 epochs, compare, and choose the best parameters
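The per-band featurization and the parameter grid from the plan above could look roughly like this. It is a sketch under assumptions: `band_features`, the statistics in `BAND_STATS`, the classifier labels and the wavelet list (valid PyWavelets names) are all illustrative choices, and the coefficient lists are expected in `pywt.wavedec` order.

```python
import itertools

import numpy as np

# Illustrative per-band summary functions ("small_function" in the plan);
# the names and the selection are assumptions, not a fixed choice.
BAND_STATS = {
    "mean": lambda c: float(np.mean(c)),
    "std": lambda c: float(np.std(c)),
    "mean_abs": lambda c: float(np.mean(np.abs(c))),
    "energy": lambda c: float(np.sum(c**2)),
}


def band_features(coeffs_per_channel, stat):
    """Feature vector for one epoch: one value per (channel, sub-band).

    coeffs_per_channel: for each channel, the list of DWT coefficient arrays
    (e.g. [cA5, cD5, cD4, cD3, cD2, cD1] as returned by pywt.wavedec).
    """
    summarize = BAND_STATS[stat]
    return np.array([summarize(band) for channel in coeffs_per_channel for band in channel])


# Hypothetical search space for the classifier x MWT x small_function comparison;
# the classifier labels are placeholders, the wavelet names are PyWavelets names.
classifiers = ["lda", "svm", "ann"]
wavelets = ["db4", "sym5", "coif3"]
grid = list(itertools.product(classifiers, wavelets, BAND_STATS))
print(len(grid), "configurations, e.g.", grid[0])
```

Each configuration would then extract features with its wavelet and statistic and score its classifier with cross-validation; keeping the same folds for every configuration makes the printed comparison fair.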

**TOTHINK** Feature selection
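One simple starting point for feature selection, given here only as an example, is to rank features by a univariate two-class Fisher score (squared difference of class means over the sum of class variances) and keep the top k. The function names and the label coding (0 = correct, 1 = error) are assumptions. Scores should be computed on the training folds only; ranking on the full dataset leaks test information into the selection.

```python
import numpy as np


def fisher_scores(X, y):
    """Two-class Fisher score per feature: (mean1 - mean0)^2 / (var1 + var0)."""
    X0, X1 = X[y == 0], X[y == 1]
    between = (X1.mean(axis=0) - X0.mean(axis=0)) ** 2
    within = X1.var(axis=0) + X0.var(axis=0)
    return between / np.maximum(within, 1e-12)  # guard against constant features


def top_k_features(X, y, k):
    """Indices of the k highest-scoring features, best first."""
    return np.argsort(fisher_scores(X, y))[::-1][:k]


# Synthetic example: 5 features, only feature 2 differs between the classes
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = np.repeat([0, 1], 100)
X[y == 1, 2] += 3.0
print(top_k_features(X, y, k=2))
```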

## Hello, Filip here
- https://en.wikipedia.org/wiki/Mexican_hat_wavelet seems the most sensible for us
- instead of an ANN, we could try implementing this: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8930304 because it seems powerful and fairly simple (in the sense that we might manage to make it interpretable), and it beats EEGNet
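The Mexican hat (Ricker) wavelet is the negative, normalized second derivative of a Gaussian, ψ(t) = 2/(√3 π^(1/4)) (1 - t²) exp(-t²/2). It is a continuous wavelet: it has no scaling function or finite orthogonal filter bank, so it cannot be plugged into the 5-level DWT planned above and would instead be used through a continuous wavelet transform (PyWavelets exposes it as `"mexh"` for `pywt.cwt`). The naive NumPy CWT below is a sketch with hypothetical function names and scales measured in samples; it shows a 10 Hz oscillation responding most strongly at the intermediate scale.

```python
import numpy as np


def mexican_hat(t):
    """L2-normalised Mexican hat (Ricker) wavelet: C * (1 - t^2) * exp(-t^2 / 2)."""
    c = 2.0 / (np.sqrt(3.0) * np.pi**0.25)
    return c * (1.0 - t**2) * np.exp(-(t**2) / 2.0)


def cwt_mexican_hat(x, scales):
    """Naive continuous wavelet transform of a 1-D signal; scales are in samples."""
    x = np.asarray(x, dtype=float)
    out = np.empty((len(scales), x.size))
    for i, a in enumerate(scales):
        # The wavelet is ~0 beyond 5 scales; also keep the kernel no longer than x
        half = min(int(np.ceil(5 * a)), (x.size - 1) // 2)
        t = np.arange(-half, half + 1) / a
        kernel = mexican_hat(t) / np.sqrt(a)  # 1/sqrt(a) normalisation across scales
        # The kernel is symmetric, so convolution equals correlation here
        out[i] = np.convolve(x, kernel, mode="same")
    return out


# 10 Hz oscillation sampled at 256 Hz for 2 s
fs = 256
x = np.sin(2 * np.pi * 10 * np.arange(2 * fs) / fs)
scales = np.array([2.0, 6.0, 16.0])
response = np.abs(cwt_mexican_hat(x, scales)).mean(axis=1)
for a, r in zip(scales, response):
    print(f"scale {a:g} samples: mean |W| = {r:.3f}")
```

The output is a full scale × time map rather than a handful of coefficient arrays, so the featurization step would need to summarize it differently than the per-band DWT features.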