# EEG classification

Link to the working documentation. Please **update** it whenever you have an idea or a plan, or when something important has been done.


https://docs.google.com/document/d/1i7tLHpXD-uXY5BopIGGcu-KWAR1DZvSlxQbYRi4fvBI/edit?usp=sharing

### Imports

In [None]:
# %load_ext lab_black
import matplotlib
import matplotlib.pyplot as plt
import mne
import numpy as np
import pywt
from utils import tmax, tmin

# plt.style.use("dark_background")

In [None]:
def load_epochs_from_file(file, reject_bad_segments="auto"):
    """Build correct/error response epochs from a BrainVision recording.

    The header file and its marker file (same name, ``vmrk`` extension) are
    read from ``../data/``. The eight response markers are collapsed into two
    classes, ``correct_response`` (0) and ``error_response`` (1), and epochs
    from ``tmin`` to ``tmax`` are cut around them without baseline correction.

    Args:
        file: name of a header file (.vhdr) inside ``../data/``.
        reject_bad_segments: 'auto' | 'annot' | 'amplitude'

            'auto': epochs are created with ``reject_by_annotation=True`` and
            returned as they are.
            'annot': ``reject_by_annotation`` is off; the markers read from
            the .vmrk file are handed to ``reject_by_annotations``.
            'amplitude': ``reject_by_annotation`` is off; the epochs are
            handed to ``reject_by_amplitude``.

            Any other value (for example ``False``) also switches
            ``reject_by_annotation`` off and returns the epochs untouched.

    Returns:
        mne Epochs (preloaded)

    """
    data_dir = "../data/"

    # Recording plus the markers stored next to it.
    raw = mne.io.read_raw_brainvision(data_dir + file)
    annot_file = file[:-4] + "vmrk"
    raw.set_annotations(mne.read_annotations(data_dir + annot_file))

    # Response markers only.
    event_dict = {
        "Stimulus/RE*ex*1_n*1_c_1*R*FB": 10004,
        "Stimulus/RE*ex*1_n*1_c_1*R*FG": 10005,
        "Stimulus/RE*ex*1_n*1_c_2*R": 10006,
        "Stimulus/RE*ex*1_n*2_c_1*R": 10007,
        "Stimulus/RE*ex*2_n*1_c_1*R": 10008,
        "Stimulus/RE*ex*2_n*2_c_1*R*FB": 10009,
        "Stimulus/RE*ex*2_n*2_c_1*R*FG": 10010,
        "Stimulus/RE*ex*2_n*2_c_2*R": 10011,
    }
    merged_event_dict = {"correct_response": 0, "error_response": 1}
    codes_per_class = (
        ("correct_response", [10004, 10005, 10009, 10010]),
        ("error_response", [10006, 10007, 10008, 10011]),
    )

    # Recover the events from the annotations, then fold the eight marker
    # codes into the two class ids (correct first, error second).
    merged_events, _ = mne.events_from_annotations(raw, event_id=event_dict)
    for class_name, codes in codes_per_class:
        merged_events = mne.merge_events(
            merged_events,
            codes,
            merged_event_dict[class_name],
            replace_events=True,
        )

    # MNE's own annotation-based rejection is used in 'auto' mode only.
    temp_epochs = mne.Epochs(
        raw=raw,
        events=merged_events,
        event_id=merged_event_dict,
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        reject_by_annotation=(reject_bad_segments == "auto"),
        preload=True,
    )

    if reject_bad_segments == "annot":
        custom_annotations = get_annotations(annot_file)
        return reject_by_annotations(temp_epochs, custom_annotations)
    if reject_bad_segments == "amplitude":
        return reject_by_amplitude(temp_epochs)
    return temp_epochs

In [None]:
def reject_by_amplitude(epochs, amplitude=3e-5):
    """Clear every (epoch, channel) trace whose amplitude exceeds a threshold.

    Args:
        epochs: object whose ``get_data()`` result is iterated as
            epochs -> channels -> samples (presumably mne Epochs).
        amplitude: threshold; a trace is flagged when
            ``peak_to_peak_amplitude(trace)`` is strictly greater than it.

    Returns:
        Whatever ``clear_bads`` returns for the flagged pairs.

    """
    flagged_pairs = [
        (epoch_index, channel_index)
        for epoch_index, epoch in enumerate(epochs.get_data())
        for channel_index, channel_data in enumerate(epoch)
        if peak_to_peak_amplitude(channel_data) > amplitude
    ]

    return clear_bads(epochs, flagged_pairs)

In [None]:
def peak_to_peak_amplitude(signal):
    """Return the peak-to-peak amplitude of a signal.

    The value is the difference between the largest and the smallest sample,
    expressed in the same unit as the samples. (The previous implementation
    took max - min of the FFT magnitude spectrum, which is not a peak-to-peak
    amplitude: a constant signal scored twice its value and a pure sine scored
    half of its true peak-to-peak range.)

    Args:
        signal: sequence or array of samples.

    Returns:
        float: ``max(signal) - min(signal)``; ``0.0`` for an empty signal.

    """
    samples = np.asarray(signal, dtype=float)

    # An empty trace has no extent; report 0 so it is never flagged.
    if samples.size == 0:
        return 0.0

    return float(samples.max() - samples.min())

In [None]:
# Load the same recording once per rejection mode.
file = "GNG_BK0504-64 el.vhdr"

# 'annot': channels are cleared according to the markers in the .vmrk file.
data_bad_annot = load_epochs_from_file(file=file, reject_bad_segments="annot")
# 'amplitude': channels are cleared according to the amplitude threshold.
data_bad_amplitude = load_epochs_from_file(file=file, reject_bad_segments="amplitude")
# 'auto': rejection is left to MNE (reject_by_annotation=True).
data_clear = load_epochs_from_file(file=file, reject_bad_segments="auto")

In [None]:
# Spot check: amplitude of channel 4 in the epoch at index 1, then plot it.
# NOTE(review): `data_bad` is assigned in a later cell - run that cell first.
inspected_epoch = data_bad[1]
data_temp = inspected_epoch._data[0, 4]
amplitude = peak_to_peak_amplitude(data_temp)
print("amp: {}".format(amplitude))

inspected_epoch.plot(n_epochs=2, event_colors={0: "g", 1: "m"})
None

In [None]:
def reject_by_annotations(
    epochs, annotations, description="Bad Interval/Bad Amplitude"
):
    """Clear the channels flagged by matching annotations.

    Args:
        epochs: epochs object passed through to ``clear_bads``.
        annotations: iterable of mappings with at least a ``"description"``
            key (the remaining keys are consumed by
            ``get_epochs_channel_index``).
        description: description string that marks an annotation as bad.

    Returns:
        Whatever ``clear_bads`` returns for the collected
        (epoch_index, channel_index) pairs.

    """
    flagged_pairs = []
    # Counts "New Segment/" markers seen so far; -1 until the first one.
    segment_index = -1

    for marker in annotations:
        label = marker["description"]

        if label == "New Segment/":
            segment_index += 1

        if label == description:
            flagged_pairs.extend(get_epochs_channel_index(marker, segment_index))

    return clear_bads(epochs, flagged_pairs)

In [None]:
# Map one annotation item to the list of (epoch_index, channel_index) pairs it covers.


def get_epochs_channel_index(item, current_segment_index):
    """List the (epoch_index, channel_index) pairs covered by an annotation.

    Args:
        item: mapping with ``"onset"``, ``"duration"`` and ``"channel_num"``
            (onset and duration presumably in data points - confirm against
            the marker file reader).
        current_segment_index: index used as the first covered epoch.
            NOTE(review): the caller passes its running count of
            "New Segment/" markers here - confirm that this equals the epoch
            index of the annotation's onset.

    Returns:
        list of ``(epoch_index, channel_index)`` tuples running from
        ``current_segment_index`` to ``get_epoch_index(onset + duration)``
        inclusive; empty when the end index is smaller than the start index.

    """
    onset = item["onset"]
    duration = item["duration"]
    # NOTE(review): a channel_num of 0 yields index -1 (the last channel when
    # used as an array index) - confirm whether 0 can occur in the markers.
    channel_index = item["channel_num"] - 1

    last_epoch_index = get_epoch_index(onset + duration)

    return [
        (epoch_index, channel_index)
        for epoch_index in range(current_segment_index, last_epoch_index + 1)
    ]

In [None]:
# Convert a position in samples into the index of the fixed-length epoch containing it.


def get_epoch_index(onset, freq=256):
    """Return the index of the epoch that contains a sample position.

    The epoch length in samples is ``int((tmax - tmin) * freq)``, with
    ``tmin`` / ``tmax`` taken from the module-level import, and the position
    is floor-divided by that length. The arithmetic is done in samples, so
    ``onset`` must be a sample position, not a time in seconds.

    Args:
        onset: position in samples, counted from the start of the data.
        freq: sampling frequency in Hz used to size one epoch. Defaults to
            256, the value that used to be hard-coded here.

    Returns:
        int: zero-based epoch index.

    """
    segment_duration = int((tmax - tmin) * freq)

    return int(onset // segment_duration)

In [None]:
def clear_bads(epochs, bads, replacement=0):
    """Overwrite flagged channel traces with a constant, in place.

    Args:
        epochs: mne Epochs; its data are loaded if necessary and modified
            in place.
        bads: iterable of ``(epoch_index, channel_index)`` pairs to clear.
        replacement: value written to every sample of a flagged trace.

    Returns:
        The same ``epochs`` object, with the flagged traces overwritten.

    """
    # Write into the object's own buffer. ``get_data()`` hands back a copy in
    # current MNE releases, so edits made to its result never reach `epochs`.
    epochs.load_data()
    eeg_data = epochs._data
    overlapped_epochs_set = set()

    for epoch_index, channel_index in bads:
        overlapped_epochs_set.add(epoch_index)
        eeg_data[epoch_index, channel_index] = replacement

    print("Amount of overlapped epochs: {}".format(len(overlapped_epochs_set)))
    print("Overlapped epochs: {}".format(overlapped_epochs_set))

    return epochs

In [None]:
file = "GNG_BK0504-64 el.vhdr"

# `False` is none of 'auto' / 'annot' / 'amplitude': the loader then switches
# reject_by_annotation off and returns the epochs without any rejection step.
data_bad = load_epochs_from_file(file=file, reject_bad_segments=False)
# data_clear = load_epochs_from_file(file=file, reject_bad_segments=True)

In [None]:
# Plot the epoch at index 4 of the unrejected data (event 0 green, event 1 magenta).
data_bad[4].plot(n_epochs=10, event_colors={0: "g", 1: "m"})
None  # keeps the cell from echoing the value of the call above

### Getting channel info from vmrk file

In [None]:
import re
import os
import os.path as op


def _read_vmrk(fname):
    """Read marker entries, including their channel numbers, from a vmrk file.

    Parameters
    ----------
    fname : str
        vmrk file to be read.

    Returns
    -------
    onset : array, shape (n_annots,)
        Zero-based marker positions in data points (the file is one-based).
    duration : array, shape (n_annots,)
        Marker sizes in data points; 0 when the field is not a plain number.
    description : array, shape (n_annots,)
        ``"<type>/<description>"`` of each marker.
    channel_num : array, shape (n_annots,)
        Channel number stored with each marker.

    All four arrays are empty when the file has no ``[Marker Infos]`` block.
    """
    with open(fname, "rb") as fid:
        raw_bytes = fid.read()

    # The marker lines are ASCII, but other blocks (e.g. file names) need not
    # be, so the file is decoded with its declared codepage. The setting is
    # looked up in an ASCII-only view of the bytes.
    cp_setting = re.search(
        r"Codepage=(.+)",
        raw_bytes.decode("ascii", "ignore"),
        re.IGNORECASE | re.MULTILINE,
    )
    codepage = cp_setting.group(1).strip() if cp_setting else "utf-8"
    # BrainAmp Recorder writes "ANSI", which Python knows as cp1252.
    if codepage.upper() == "ANSI":
        codepage = "cp1252"

    try:
        txt = raw_bytes.decode(codepage)
    except (UnicodeDecodeError, LookupError):
        # Undecodable bytes or an unknown codepage name: fall back to
        # Latin-1, the implicit standard of older recordings.
        txt = raw_bytes.decode("latin-1")

    # Extract the Marker Infos block.
    m = re.search(r"\[Marker Infos\]", txt, re.IGNORECASE)
    if not m:
        return np.array([]), np.array([]), np.array([]), np.array([])

    mk_txt = txt[m.end() :]
    # Cut the block at the next "[Section]" header, if there is one.
    # MULTILINE is required for "^" to match at line starts, and "\s*"
    # tolerates "\r\n" line endings.
    m = re.search(r"^\[.*\]\s*$", mk_txt, re.MULTILINE)
    if m:
        mk_txt = mk_txt[: m.start()]

    # Extract event information.
    items = re.findall(r"^Mk\d+=(.*)", mk_txt, re.MULTILINE)
    onset, duration, description, channel_num = [], [], [], []
    for info in items:
        mtype, mdesc, this_onset, this_duration, this_channel_num = info.split(",")[
            :5
        ]
        # commas in mtype and mdesc are stored as "\1"; convert them back
        mtype = mtype.replace(r"\1", ",")
        mdesc = mdesc.replace(r"\1", ",")

        duration.append(int(this_duration) if this_duration.isdigit() else 0)
        onset.append(int(this_onset) - 1)  # BV is 1-indexed, not 0-indexed
        description.append(mtype + "/" + mdesc)
        channel_num.append(int(this_channel_num))

    return (
        np.array(onset),
        np.array(duration),
        np.array(description),
        np.array(channel_num),
    )

In [None]:
def get_annotations(file):
    """Return the markers of a vmrk file as a list of dicts.

    Args:
        file: name of a marker file (.vmrk) inside ``../data/``.

    Returns:
        list of dicts, one per marker, with the keys ``"onset"``,
        ``"duration"``, ``"description"`` and ``"channel_num"``.

    """
    field_names = ["onset", "duration", "description", "channel_num"]

    onset, duration, description, channel_num = _read_vmrk("../data/" + file)

    return [
        dict(zip(field_names, row))
        for row in zip(onset, duration, description, channel_num)
    ]

### Data Visualisation

In [None]:
# NOTE(review): `epochs` is not assigned by any cell visible here - presumably
# a result of load_epochs_from_file(...); confirm before running.
epochs.plot(n_epochs=1, event_colors={0: "g", 1: "m"})
None  # prevents doubled output

In [None]:
# Split the epochs by class using the event names.
correct_response_epochs = epochs["correct_response"]
error_response_epochs = epochs["error_response"]


# Calculate averages of events sets
correct_response_evoked = correct_response_epochs.average()
error_response_evoked = error_response_epochs.average()

In [None]:
# Overlay the averages of the two event sets; combine="mean" with an inverted y axis.

mne.viz.plot_compare_evokeds(
    {
        "correct_response": correct_response_evoked,
        "error_response": error_response_evoked,
    },
    legend="upper left",
    show_sensors="upper right",
    ylim={"eeg": [-10, 10]},
    invert_y=True,
    combine="mean",
)

In [None]:
# Averages of error response events per channel

# Joint plot of the EEG channels, then topographic maps at five time points.
error_response_evoked.plot_joint(picks="eeg")
error_response_evoked.plot_topomap(times=[0.0, 0.08, 0.1, 0.12, 0.2], ch_type="eeg")
None

In [None]:
# Averages of merged event sets (diff between error and correct) per channel

# Weights [1, -1] on [correct, error]: the result is correct minus error.
evoked_diff = mne.combine_evoked(
    [correct_response_evoked, error_response_evoked], weights=[1, -1]
)
evoked_diff.plot_joint()
None

In [None]:
# Mean over epochs (axis 0) of each event class, keyed by "<event name>_mean".
events_mean_dict = {
    key + "_mean": epochs[key]._data.mean(axis=(0)) for key in epochs.event_id
}

In [None]:
# Stacked chart of the per-class channel averages (one colour per class).

colors = ["b", "r", "g"]

# Vertical offsets that separate the channel traces from each other.
channel_offsets = np.arange(start=1e-6, step=10e-6, stop=301e-7)

plt.figure(figsize=(10, 10))

for color_index, (key, epoch) in enumerate(events_mean_dict.items()):
    plt.plot(epoch.T + channel_offsets, label=key, color=colors[color_index])

plt.yticks([])
plt.xticks(np.arange(0, 181, 181 / 8), np.arange(0, 800, 100))
plt.xlabel("milliseconds", fontsize=15)
plt.ylabel("channels", fontsize=15)
plt.legend(loc="upper left")
plt.show()

## Pre-processing

**Pre-processing done with Brain Vision Software:**

- Notch filter  0.05-25
- Baseline Correction //what baseline?
- Ocular Correction
- Artifact Rejection

## Feature extraction

The feature extraction method recommended for EEG data is the **Wavelet Transform** (especially the **Discrete Wavelet Transform**). It is better than the FFT for biomedical signals because it localizes non-stationary signals in both the time and frequency domains. The DWT decomposes the signal into five frequency bands.

https://en.wikipedia.org/wiki/Discrete_wavelet_transform

https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0173138&type=printable

In [None]:
# Feature-extraction inputs: X holds the epoch data, Y the event id of each epoch.

X = epochs._data
# A single-epoch selection is queried for the first value of its event_id map.
Y = [list(epochs[i].event_id.values())[0] for i in range(len(epochs))]

### Continuous Wavelet Transform

In [None]:
# Mother wavelets available for the continuous WT

MWT_list = pywt.wavelist(kind="continuous")

# cmor, fbsp and shan need an explicit bandwidth / center specification,
# which is not handled for now, so they are taken out of the list.
for excluded_wavelet in ("cmor", "fbsp", "shan"):
    MWT_list.remove(excluded_wavelet)

In [None]:
# For each wavelet, build the scales whose pseudo-frequencies span 128 Hz down to 1 Hz.
# Background reading: https://www.researchgate.net/publication/267491844_Continuous_Wavelet_Transform_EEG_Features_of_Alzheimer%27s_Disease

# TODO: refactor

signal_frequency = 256


def compute_coeffs(epoch):
    """Compute CWT coefficients of one signal for every wavelet in MWT_list.

    Args:
        epoch: signal passed to ``pywt.cwt`` as ``data`` (the notebook passes
            one channel of one epoch).

    Returns:
        dict mapping each wavelet name to the coefficient array returned by
        ``pywt.cwt`` for that wavelet.

    """
    coeffs_by_wavelet = {}

    for wavelet_name in MWT_list:
        # scale = frequency_at_scale_1 * sampling_rate / pseudo_frequency
        scale_numerator = (
            pywt.scale2frequency(wavelet_name, [1])[0] * signal_frequency
        )

        # From the 128 Hz scale up to (not including) the 1 Hz scale, in steps of 0.1.
        scales = np.arange(scale_numerator / 128, scale_numerator / 1, 0.1).tolist()

        # Only the coefficients are kept; the frequencies are discarded.
        coef, _ = pywt.cwt(
            data=epoch,
            scales=scales,
            wavelet=wavelet_name,
            sampling_period=1 / signal_frequency,
        )
        coeffs_by_wavelet[wavelet_name] = coef

    return coeffs_by_wavelet


# First channel of the first epoch in X.
epoch = X[0][0]

# Coefficients of that single trace for every wavelet in MWT_list.
coeffs_dict = compute_coeffs(epoch)