# EEG classification

Link to the working documentation. Please **update** it whenever you have an idea or a plan, or when something important has been done.


https://docs.google.com/document/d/1i7tLHpXD-uXY5BopIGGcu-KWAR1DZvSlxQbYRi4fvBI/edit?usp=sharing

### Imports

In [None]:
# %load_ext lab_black
import matplotlib
import matplotlib.pyplot as plt
import mne
import numpy as np
import pywt
from utils import tmax, tmin

# plt.style.use("dark_background")

In [None]:
def load_epochs_from_file(file, reject_bad_segments=True):
    """Load response-locked epochs from a BrainVision header file.

    The recording is read together with its marker file, the response
    markers are collapsed into two classes (correct / error) and epochs
    from ``tmin`` to ``tmax`` around each response are extracted.

    Args:
        file: path (``str`` or path-like) to a header file (.vhdr). The
            marker file is expected next to it, with the same name and a
            ``.vmrk`` extension.
        reject_bad_segments: bool, forwarded to ``mne.Epochs`` as
            ``reject_by_annotation``. If True (default), epochs that overlap
            segments annotated as bad are rejected automatically. If False,
            no epoch is dropped because of those annotations.

    Returns:
        mne Epochs, preloaded and without baseline correction, labelled
        ``"correct_response"`` (0) or ``"error_response"`` (1).

    """
    from pathlib import Path

    # Import the BrainVision data into an MNE Raw object
    raw = mne.io.read_raw_brainvision(file)

    # Construct annotation filename: same name as the header file, with a
    # .vmrk suffix. with_suffix swaps the real extension rather than chopping
    # off a fixed number of characters, and also works for Path inputs.
    annot_file = str(Path(file).with_suffix(".vmrk"))

    # Read in the event information as MNE annotations
    annotations = mne.read_annotations(annot_file)

    # Add the annotations to our raw object so we can use them with the data
    raw.set_annotations(annotations)

    # Map with response markers only
    event_dict = {
        "Stimulus/RE*ex*1_n*1_c_1*R*FB": 10004,
        "Stimulus/RE*ex*1_n*1_c_1*R*FG": 10005,
        "Stimulus/RE*ex*1_n*1_c_2*R": 10006,
        "Stimulus/RE*ex*1_n*2_c_1*R": 10007,
        "Stimulus/RE*ex*2_n*1_c_1*R": 10008,
        "Stimulus/RE*ex*2_n*2_c_1*R*FB": 10009,
        "Stimulus/RE*ex*2_n*2_c_1*R*FG": 10010,
        "Stimulus/RE*ex*2_n*2_c_2*R": 10011,
    }

    # Map for merged correct/error response markers
    merged_event_dict = {"correct_response": 0, "error_response": 1}

    # Reconstruct the original events from Raw object; only the markers
    # listed in event_dict are turned into events
    events, _ = mne.events_from_annotations(raw, event_id=event_dict)

    # Merge correct/error response events: each group of marker codes is
    # replaced by the single code of its merged class
    merged_events = mne.merge_events(
        events,
        [10004, 10005, 10009, 10010],
        merged_event_dict["correct_response"],
        replace_events=True,
    )
    merged_events = mne.merge_events(
        merged_events,
        [10006, 10007, 10008, 10011],
        merged_event_dict["error_response"],
        replace_events=True,
    )

    # Read epochs
    epochs = mne.Epochs(
        raw=raw,
        events=merged_events,
        event_id=merged_event_dict,
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        reject_by_annotation=reject_bad_segments,
        preload=True,
    )

    return epochs

In [None]:
# Exploratory load of a single recording, keeping every remaining marker type
file = "GNG_BK0504-64 el.vhdr"

# Import the BrainVision data into an MNE Raw object
raw = mne.io.read_raw_brainvision("../data/" + file)

# Construct annotation filename (header name with the "vhdr" ending swapped
# for "vmrk")
annot_file = file[:-4] + "vmrk"

# Read in the event information as MNE annotations
annot = mne.read_annotations("../data/" + annot_file)

# Add the annotations to our raw object so we can use them with the data
raw.set_annotations(annot)

# No event_id mapping is passed here, so MNE assigns the event ids itself
events, event_ids = mne.events_from_annotations(raw)

# Remove these two marker types from the mapping so that no epochs are
# created for them below
event_ids.pop("New Segment/")
event_ids.pop("Time 0/")

# Read epochs
epochs = mne.Epochs(
    raw=raw,
    events=events,
    event_id=event_ids,
    tmin=tmin,
    tmax=tmax,
    baseline=None,
    reject_by_annotation=False,
    preload=True,
)

# Display the epochs summary as the cell output
epochs

In [None]:
def find_bad_intervals(annotations, description="Bad Interval/Bad Amplitude"):
    """Collect channel/epoch information for every annotation with a given label.

    Args:
        annotations: iterable of annotation items, each indexable by
            ``"description"``.
        description: annotation label to look for.

    Returns:
        list with one ``get_epochs_channel_index`` result per matching
        annotation, in iteration order.

    """
    return [
        get_epochs_channel_index(entry)
        for entry in annotations
        if entry["description"] == description
    ]

In [None]:
def get_epochs_channel_index(item):
    """Map one annotation item to the channel and the epochs it touches.

    Args:
        item: annotation item providing ``"onset"`` and ``"duration"``
            (passed on to ``get_epoch_index``, which takes times in seconds
            from the start of the run).

    Returns:
        tuple ``(channel_num, epochs_indexes)``. ``epochs_indexes`` is the set
        of all epoch indexes from the epoch containing the onset up to and
        including the epoch containing the end of the annotation.

    """
    onset = item["onset"]
    duration = item["duration"]
    # mock: the channel is hard-coded until it can be taken from the item
    channel_num = 40  # item['channel_num']

    bad_interval_start_index = get_epoch_index(onset)
    bad_interval_end_index = get_epoch_index(onset + duration)

    # Cover the epochs in between as well: an interval longer than one epoch
    # would otherwise leave the fully covered middle epochs unmarked. The two
    # end points are always kept, whatever their order.
    epochs_indexes = set(
        range(bad_interval_start_index, bad_interval_end_index + 1)
    ) | {bad_interval_start_index, bad_interval_end_index}

    return (channel_num, epochs_indexes)

In [None]:
def get_epoch_index(onset):
    """Return the index of the epoch that contains a given time point.

    The time is converted to a sample position with the sampling frequency of
    the notebook-level ``raw`` and floor-divided by the epoch length in
    samples (derived from ``tmax - tmin``).

    NOTE(review): this arithmetic treats the recording as equally long epochs
    laid back-to-back from time zero - confirm that matches how the epochs
    were actually cut.

    Args:
        onset: time in seconds from the start of the run.

    Returns:
        int epoch index.

    """
    sampling_rate = raw.info["sfreq"]
    samples_per_epoch = int((tmax - tmin) * sampling_rate)
    sample_position = onset * sampling_rate

    return int(sample_position // samples_per_epoch)

In [None]:
def clear_bad_channels(epochs, bad_intervals):
    """Blank out bad channels in the affected epochs, in place.

    The samples of every listed (epoch, channel) pair are replaced by NaN so
    they can be recognised and skipped later.

    Args:
        epochs: preloaded epochs whose samples are available as
            ``epochs._data``, indexed as [epoch, channel, time] (the same
            access pattern used elsewhere in this notebook).
        bad_intervals: iterable of ``(channel, epoch_indexes)`` pairs, as
            produced by ``find_bad_intervals``.

    Returns:
        None; ``epochs._data`` is modified directly.

    """
    for channel, epochs_index in bad_intervals:
        for index in epochs_index:
            # Write into the underlying array. Going through
            # epochs[index].get_data() would modify a temporary object rather
            # than the epochs passed in, and a text placeholder cannot be
            # stored in a numeric array, so NaN is used as the marker.
            epochs._data[index, channel] = np.nan

In [None]:
# Collect one (channel, epoch indexes) pair per "Bad Amplitude" annotation
bad_channels = find_bad_intervals(raw.annotations)
# print(bad_channels)
# clear_bad_channels(epochs, bad_channels)

# a = [np.zeros(181)]

# Display the samples stored at channel index 40 of epoch 3
epochs[3]._data[0][40]

# epochs[3].get_data()[0][40] = []

In [None]:
# Display the samples stored at channel index 40 of epoch 3
epochs[3]._data[0][40]

In [None]:
# Scratch values: the epoch and channel indexes inspected in the cells above
epoch_index = 3
channel_num = 40

# NOTE(review): this loop only indexes each epoch and discards the result
for i in range(len(epochs)):
    epochs[i]

In [None]:
def reject_bad_channel(epochs, epoch_index, channel_num):
    """Placeholder for rejecting one channel of one epoch; does nothing yet."""

In [None]:
# Browse the epochs five at a time; event ids 0 and 1 get the colors "g" and "m"
epochs.plot(n_epochs=5, event_colors={0: "g", 1: "m"})

In [None]:
# NOTE(review): load_gonogo_responses is neither defined nor imported in the
# visible part of this notebook - confirm where it comes from
# (load_epochs_from_file looks like the related loader).
epochs = load_gonogo_responses()

### Data Visualisation

In [None]:
# Browse the epochs one at a time; event ids 0 and 1 get the colors "g" and "m"
epochs.plot(n_epochs=1, event_colors={0: "g", 1: "m"})
None  # prevents doubled output

In [None]:
# Select the epochs of each response type by event name
correct_response_epochs = epochs["correct_response"]
error_response_epochs = epochs["error_response"]


# Calculate averages of events sets
correct_response_evoked = correct_response_epochs.average()
error_response_evoked = error_response_epochs.average()

In [None]:
# Averages of two event sets, drawn in one figure for comparison

mne.viz.plot_compare_evokeds(
    # one entry per condition; the keyword names are used as labels
    dict(
        correct_response=correct_response_evoked, error_response=error_response_evoked
    ),
    legend="upper left",
    show_sensors="upper right",
    ylim=dict(eeg=[-10, 10]),
    invert_y=True,
    # channels are combined by their mean
    combine="mean",
)

In [None]:
# Averages of error response events per channel

# Joint plot restricted to the EEG channels
error_response_evoked.plot_joint(picks="eeg")
# Topographic maps at five chosen time points
error_response_evoked.plot_topomap(times=[0.0, 0.08, 0.1, 0.12, 0.2], ch_type="eeg")
None  # prevents doubled output

In [None]:
# Difference of the two averaged event sets per channel:
# correct minus error (weights 1 and -1)

evoked_diff = mne.combine_evoked(
    [correct_response_evoked, error_response_evoked], weights=[1, -1]
)
evoked_diff.plot_joint()
None  # prevents doubled output

In [None]:
# Mean over all epochs (axis 0) for every event type, stored under
# "<event name>_mean"
events_mean_dict = {}

for key in epochs.event_id.keys():
    mean_key = key + "_mean"
    events_mean_dict[mean_key] = epochs[key]._data.mean(axis=(0))

In [None]:
# Chart with averages of correct and error responses per channel

colors = ["b", "r", "g"]
color_iterator = 0

plt.figure(figsize=(10, 10))


for key in events_mean_dict:
    epoch = events_mean_dict[key]
    # One vertical offset per channel (starting at 1e-6, spaced 10e-6 apart)
    # so the traces are stacked instead of drawn on top of each other. The
    # number of offsets follows the data; a fixed-length offset vector only
    # works for exactly three channels.
    channel_offsets = 1e-6 + 10e-6 * np.arange(epoch.shape[0])
    plt.plot(
        epoch.T + channel_offsets,
        label=key,
        # wrap around so more than len(colors) entries cannot overflow the list
        color=colors[color_iterator % len(colors)],
    )
    color_iterator = color_iterator + 1

plt.yticks([])
# NOTE(review): the tick positions assume 181 samples per epoch and the
# labels assume an 800 ms window - confirm against tmin/tmax and sfreq
plt.xticks(np.arange(0, 181, 181 / 8), np.arange(0, 800, 100))
plt.xlabel("milliseconds", fontsize=15)
plt.ylabel("channels", fontsize=15)
plt.legend(loc="upper left")
plt.show()

## Pre-processing

**Pre-processing done with Brain Vision Software:**

- Notch filter  0.05-25
- Baseline Correction //what baseline?
- Ocular Correction
- Artifact Rejection

## Feature extraction

The feature extraction method recommended for EEG data is the **Wavelet Transform** (especially the **Discrete Wavelet Transform**). It is better than the FFT for biomedical signals because it localizes non-stationary signals in both the time and frequency domains. The DWT decomposes the signal into five frequency bands.

https://en.wikipedia.org/wiki/Discrete_wavelet_transform

https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0173138&type=printable

In [None]:
# Prepare data for feature extraction: create X-data array and Y-labels array

# X: the preloaded epochs array
X = epochs._data

# Y: the event code of every epoch, read directly from the third column of
# the events array. This avoids building a one-epoch Epochs object per label
# and does not depend on how indexing filters the event_id mapping.
Y = epochs.events[:, 2].tolist()

### Continuous Wavelet Transform

In [None]:
# Mother wavelets list for continuous WT

MWT_list = pywt.wavelist(kind="continuous")

# Remove the cmor, fbsp and shan wavelets: they need an extra specification
# (bandwidth and center frequency) that we do not want to handle for now.
MWT_list.remove("cmor")
MWT_list.remove("fbsp")
MWT_list.remove("shan")

In [None]:
# Construct list of scales corresponding to pseudo-frequencies [128Hz-1Hz] for each wavelet
# Could help understand: https://www.researchgate.net/publication/267491844_Continuous_Wavelet_Transform_EEG_Features_of_Alzheimer%27s_Disease

# TODO: refactor

# Sampling frequency used to convert between scales and frequencies
# NOTE(review): hard-coded - presumably meant to match raw.info["sfreq"]; confirm
signal_frequency = 256


def compute_coeffs(epoch):
    """Compute CWT coefficients of one signal for every wavelet in ``MWT_list``.

    For each mother wavelet, the scales run in steps of 0.1 between the scale
    matching a 128 Hz pseudo-frequency and the scale matching 1 Hz, given the
    notebook-level ``signal_frequency``.

    Args:
        epoch: samples of a given epoch (one channel).

    Returns:
        dict mapping each wavelet name to the coefficients returned by
        ``pywt.cwt`` for that wavelet.

    """
    coeffs_by_wavelet = {}

    for wavelet_name in MWT_list:
        # Frequency of the wavelet at scale 1, multiplied by the sampling
        # frequency: dividing this by a target frequency gives its scale.
        unit_scale_frequency = pywt.scale2frequency(wavelet_name, [1])[0]
        scale_numerator = unit_scale_frequency * signal_frequency

        # construct scales
        scale_grid = np.arange(scale_numerator / 128, scale_numerator / 1, 0.1)

        # compute coeffs (the returned frequencies are not needed here)
        coef, _ = pywt.cwt(
            data=epoch,
            scales=scale_grid.tolist(),
            wavelet=wavelet_name,
            sampling_period=1 / signal_frequency,
        )

        # Save coeffs from the MWT
        coeffs_by_wavelet[wavelet_name] = coef

    return coeffs_by_wavelet


# Try it on a single signal: index 0 along the first two axes of X
epoch = X[0][0]

# Coefficients for every wavelet in MWT_list, keyed by wavelet name
coeffs_dict = compute_coeffs(epoch)