# EEG classification

Link to the working documentation. Please **update** it whenever you have an idea or a plan, or when something important has been done.


https://docs.google.com/document/d/1i7tLHpXD-uXY5BopIGGcu-KWAR1DZvSlxQbYRi4fvBI/edit?usp=sharing

### Imports

In [None]:
# %load_ext lab_black
import matplotlib
import matplotlib.pyplot as plt
import mne
import numpy as np
import pywt
from utils import load_gonogo_responses, tmax, tmin

# plt.style.use("dark_background")

In [None]:
epochs = load_gonogo_responses()

### Data Visualisation

In [None]:
epochs.plot(n_epochs=1, event_colors={0: "g", 1: "m"})
None  # prevents doubled output

In [None]:
correct_response_epochs = epochs["correct_response"]
error_response_epochs = epochs["error_response"]


# Calculate the average of each event set
correct_response_evoked = correct_response_epochs.average()
error_response_evoked = error_response_epochs.average()

In [None]:
# Averages of two event sets

mne.viz.plot_compare_evokeds(
    dict(
        correct_response=correct_response_evoked, error_response=error_response_evoked
    ),
    legend="upper left",
    show_sensors="upper right",
    ylim=dict(eeg=[-10, 10]),
    invert_y=True,
    combine="mean"
)

In [None]:
# Averages of error response events per channel

error_response_evoked.plot_joint(picks="eeg")
error_response_evoked.plot_topomap(times=[0.0, 0.08, 0.1, 0.12, 0.2], ch_type="eeg")
None

In [None]:
# Difference wave per channel: correct response minus error response (weights 1 and -1)

evoked_diff = mne.combine_evoked(
    [correct_response_evoked, error_response_evoked], weights=[1, -1]
)
evoked_diff.plot_joint()
None

In [None]:
events_mean_dict = {}

for key in epochs.event_id.keys():
    mean_key = key + "_mean"
    events_mean_dict[mean_key] = epochs[key].get_data().mean(axis=0)

In [None]:
# Chart with averages of correct and error responses per channel

colors = ["b", "r", "g"]

plt.figure(figsize=(10, 10))


for (key, epoch), color in zip(events_mean_dict.items(), colors):
    # Shift each channel vertically (10 µV apart) so the traces do not overlap
    offsets = 1e-6 + 10e-6 * np.arange(epoch.shape[0])
    lines = plt.plot(epoch.T + offsets, color=color)
    # Label only the first trace so the legend has one entry per event type
    lines[0].set_label(key)

plt.yticks([])
plt.xticks(np.arange(0, 181, 181 / 8), np.arange(0, 800, 100))
plt.xlabel("milliseconds", fontsize=15)
plt.ylabel("channels", fontsize=15)
plt.legend(loc="upper left")
plt.show()

## Pre-processing

**Pre-processing done with Brain Vision Software:**

- Filter 0.05-25 Hz //noted as notch, but this range looks like a band-pass; to verify
- Baseline Correction //what baseline?
- Ocular Correction
- Artifact Rejection

## Feature extraction

The feature extraction method recommended for EEG data is the **Wavelet Transform** (especially the **Discrete Wavelet Transform**, DWT). It is better suited than the FFT to biomedical signals because it localizes non-stationary signals in both the time and frequency domains. The DWT decomposes the signal into five frequency bands.

https://en.wikipedia.org/wiki/Discrete_wavelet_transform

https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0173138&type=printable
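A minimal sketch of such a decomposition with `pywt.wavedec`, added here only to illustrate the idea. The `db4` wavelet, the 4 decomposition levels, the 256 Hz sampling rate (the value used for `signal_frequency` later in this notebook) and the synthetic signal are all assumptions, not choices made by this project. A 4-level DWT yields five coefficient arrays (one approximation and four detail bands), which is one way to get "five frequency bands".

```python
import numpy as np
import pywt

fs = 256  # assumed sampling rate in Hz
t = np.arange(0, 2, 1 / fs)

# Synthetic stand-in for one EEG channel: 6 Hz and 20 Hz components
signal = np.sin(2 * np.pi * 6 * t) + 0.5 * np.sin(2 * np.pi * 20 * t)

# 4-level decomposition returns [cA4, cD4, cD3, cD2, cD1]
coeffs = pywt.wavedec(signal, "db4", level=4)
band_names = ["cA4", "cD4", "cD3", "cD2", "cD1"]

# Simple per-band features: mean absolute value and standard deviation
features = {
    name: (np.mean(np.abs(c)), np.std(c)) for name, c in zip(band_names, coeffs)
}

for name, c in zip(band_names, coeffs):
    print(name, len(c), features[name])
```

At 256 Hz the nominal bands are roughly 0-8 Hz (cA4), 8-16 Hz (cD4), 16-32 Hz (cD3), 32-64 Hz (cD2) and 64-128 Hz (cD1). The per-band statistics could then be concatenated into a feature vector for each epoch and channel.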

In [None]:
# Prepare data for feature extraction: create X-data array and Y-labels array

X = epochs.get_data()  # shape: (n_epochs, n_channels, n_times)
# The third column of epochs.events holds the event id of each epoch
Y = epochs.events[:, 2].tolist()

### Continuous Wavelet Transform

In [None]:
# Mother wavelets list for continuous WT

MWT_list = pywt.wavelist(kind='continuous')

# Remove the cmor, fbsp and shan wavelets: they need their bandwidth and center
# frequency specified in the name (e.g. 'cmor1.5-1.0'); we do not handle that for now.
MWT_list.remove('cmor')
MWT_list.remove('fbsp')
MWT_list.remove('shan')

In [136]:
# Construct list of scales corresponding to pseudo-frequencies [128Hz-1Hz] for each wavelet
# Could help understand: https://www.researchgate.net/publication/267491844_Continuous_Wavelet_Transform_EEG_Features_of_Alzheimer%27s_Disease

# TODO: refactor

signal_frequency = 256


# Compute CWT coefficients for every wavelet in MWT_list for a given epoch (one channel)
def compute_coeffs(epoch):

    wavelets = {}

    for MWT in MWT_list:
        # Central frequency of the wavelet at scale 1 (normalized frequency)
        center_wavelet_frequency = pywt.scale2frequency(MWT, [1])[0]
        const = center_wavelet_frequency * signal_frequency

        # Construct scales: scale = const / frequency, so const/128 maps to 128 Hz
        # and const/1 to 1 Hz
        scales = np.arange(const / 128, const / 1, 0.1)

        # Compute coeffs
        coef, freqs = pywt.cwt(
            data=epoch,
            scales=scales,
            wavelet=MWT,
            sampling_period=1 / signal_frequency,
        )

        # Save coeffs from the MWT
        wavelets[MWT] = coef

    return wavelets

epoch = X[0][0]

coeffs_dict = compute_coeffs(epoch)
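
The scale construction in the cell above can be sanity-checked by converting the scales back to pseudo-frequencies. The sketch below repeats the same arithmetic for a single wavelet; `morl` is picked only as an example, and `pseudo_frequencies` is a name introduced here for illustration.

```python
import numpy as np
import pywt

signal_frequency = 256  # sampling rate in Hz, as in the cell above
wavelet = "morl"  # example wavelet; any entry of MWT_list would work the same way

# Same scale construction as in compute_coeffs
center_wavelet_frequency = pywt.scale2frequency(wavelet, [1])[0]
const = center_wavelet_frequency * signal_frequency
scales = np.arange(const / 128, const / 1, 0.1)

# scale2frequency returns normalized frequencies; multiply by the sampling rate to get Hz
pseudo_frequencies = pywt.scale2frequency(wavelet, scales) * signal_frequency

print(pseudo_frequencies[0], pseudo_frequencies[-1])
```

The first value should be 128 Hz and the last one just above 1 Hz, because `np.arange` excludes the stop value. Note that a fixed step of 0.1 in scale gives a very uneven frequency grid: dense at low frequencies and sparse at high ones.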