In [1]:
"""
Created on Mon May  2 13:13:08 2022

Data functions for handling the TUH Abnormal dataset

Each method has its own description in its header section.

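The methods defined in this file are:
    import_tuh_abnormal
    filter_only_adults
    create_ch_mapping
    select_by_channels
    custom_crop
    custom_rename_channels
    create_preproc_pipeline
    preprocess_dataset
    import_and_preprocess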

@author: Kitti
"""

# Import required packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne

from braindecode.datasets.tuh import TUHAbnormal, TUH
from braindecode.preprocessing import preprocess, Preprocessor, create_fixed_length_windows, scale as multiply

# Matplotlib 3.6 renamed the bundled seaborn style sheets to 'seaborn-v0_8*'
# and later releases dropped the old names, so pick whichever one this
# installation actually provides instead of hard-coding 'seaborn'.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
mne.set_log_level('ERROR')  # avoid messages every time a window is extracted

In [2]:


def import_tuh_abnormal(path, n_jobs=1, target_name='pathological', preload=False, add_physician_reports=True):
    """Load the TUH Abnormal dataset and keep recordings with all mapped channels.

    Parameters
    ----------
    path
        Forwarded to ``TUHAbnormal`` as ``path``.
    n_jobs : int
        Number of parallel jobs passed to ``TUHAbnormal``. Forced to 1 when
        ``TUH`` is the mock class (its ``__name__`` is ``'_TUHMock'``).
    target_name, preload, add_physician_reports
        Forwarded unchanged to ``TUHAbnormal``.

    Returns
    -------
    The result of ``select_by_channels`` applied to the loaded dataset, using
    the names and mapping from ``create_ch_mapping``.
    """
    # Mock dataset can't be loaded in parallel
    using_mock = TUH.__name__ == '_TUHMock'
    effective_jobs = 1 if using_mock else n_jobs

    loaded = TUHAbnormal(
        path=path,
        recording_ids=None,
        target_name=target_name,
        preload=preload,
        add_physician_reports=add_physician_reports,
        n_jobs=effective_jobs,
    )

    names, mapping = create_ch_mapping()
    return select_by_channels(loaded, names, mapping)

In [None]:
def filter_only_adults(tuh_abnormal, min_age=18):
    """Keep only the recordings whose ``age`` is at least ``min_age``.

    Parameters
    ----------
    tuh_abnormal
        Dataset object exposing ``split("age")``; the call must return a
        mapping from an age key (convertible with ``int``) to a sub-dataset.
    min_age : int
        Inclusive lower age bound. Defaults to 18.

    Returns
    -------
    BaseConcatDataset
        Concatenation of every age group that satisfies the bound, in the
        order in which ``split`` yields them.

    Raises
    ------
    ValueError
        If an age key cannot be converted with ``int``.
    """
    # Imported locally: the notebook's import cell does not bring this name
    # into scope, so the original call raised NameError.
    from braindecode.datasets import BaseConcatDataset

    adult_groups = [
        group
        for age_key, group in tuh_abnormal.split("age").items()
        if int(age_key) >= min_age
    ]
    return BaseConcatDataset(adult_groups)

In [3]:

def create_ch_mapping():
    """Build the short channel names and the rename tables for both references.

    Returns
    -------
    short_ch_names : list of str
        The 21 short electrode labels, sorted alphabetically.
    ch_mapping : dict
        ``{'ar': {...}, 'le': {...}}``. Each inner dict maps a long channel
        name (``'EEG <name>-REF'`` for ``'ar'``, ``'EEG <name>-LE'`` for
        ``'le'``) to its short label.
    """
    short_ch_names = [
        'A1', 'A2',
        'FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
        'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'FZ', 'CZ', 'PZ']
    short_ch_names.sort()

    # The long names only differ from the short ones by a fixed prefix and a
    # reference suffix, so they sort in the same order as the short names.
    ch_mapping = {
        key: {'EEG ' + label + suffix: label for label in short_ch_names}
        for key, suffix in (('ar', '-REF'), ('le', '-LE'))
    }
    return short_ch_names, ch_mapping


def select_by_channels(ds, short_ch_names, ch_mapping):
    """Return the subset of ``ds`` whose recordings contain every mapped channel.

    Parameters
    ----------
    ds
        Object with a ``datasets`` sequence (each item exposing
        ``raw.ch_names``) and a ``split`` method accepting a list of indices.
    short_ch_names
        Not used by this function; kept in the signature for its callers.
    ch_mapping : dict
        Mapping with keys ``'ar'`` and ``'le'``; the keys of each inner
        mapping are the channel names a recording must contain.

    Returns
    -------
    The ``'0'`` entry of ``ds.split(indices)`` for the matching indices.
    """
    def _has_all_channels(recording):
        present = recording.raw.ch_names
        # The first channel's suffix decides which reference table applies.
        montage = 'ar' if present[0].endswith('-REF') else 'le'
        wanted = set(ch_mapping[montage].keys())
        return wanted.issubset(set(present))

    keep = [idx for idx, rec in enumerate(ds.datasets) if _has_all_channels(rec)]
    return ds.split(keep)['0']


def custom_crop(raw, tmin=0.0, tmax=None, include_tmax=True):
    """Crop ``raw`` in place to ``tmin``..``tmax``, clamping ``tmax`` to the recording.

    By default mne fails if ``tmax`` is bigger than the recording duration,
    so the requested ``tmax`` is limited to the time of the last sample. The
    cropped recording can therefore be shorter than ``tmax - tmin``.

    Parameters
    ----------
    raw
        Object exposing ``n_times``, ``info['sfreq']`` and
        ``crop(tmin=, tmax=, include_tmax=)``.
    tmin : float
        Start of the crop, forwarded to ``raw.crop``.
    tmax : float or None
        Requested end of the crop. ``None`` means "up to the last sample".
    include_tmax : bool
        Forwarded to ``raw.crop``.

    Returns
    -------
    None
        ``raw.crop`` is called for its effect on ``raw``.
    """
    last_sample_time = (raw.n_times - 1) / raw.info['sfreq']
    # The default tmax=None previously reached min() and raised TypeError
    # (a float cannot be compared with None); treat it as "to the end".
    tmax = last_sample_time if tmax is None else min(last_sample_time, tmax)
    raw.crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)
    
    
def custom_rename_channels(raw, mapping):
    """Rename the channels of ``raw`` using the table that matches its reference.

    Channel names depend on the referencing (le: ``EEG 01-LE``, ar:
    ``EEG 01-REF``), and mne fails if the mapping contains keys for channels
    that are not present in the raw, so only the matching table is applied.

    Parameters
    ----------
    raw
        Object exposing ``ch_names`` and ``rename_channels``.
    mapping : dict
        Mapping with keys ``'le'`` and ``'ar'`` holding the rename tables.

    Raises
    ------
    AssertionError
        If the suffix of the first channel name is neither ``le`` nor ``ref``
        (case-insensitive).
    """
    # Suffix of the first channel name -> key of the rename table to use.
    suffix_to_key = {'le': 'le', 'ref': 'ar'}
    suffix = raw.ch_names[0].rsplit('-', 1)[-1].lower()
    assert suffix in suffix_to_key, 'unexpected referencing'
    raw.rename_channels(mapping[suffix_to_key[suffix]])
    
def create_preproc_pipeline(tmin, tmax, sfreq, clipping, ch_mapping, short_ch_names, include_tmax=True):
    """Assemble the ordered list of ``Preprocessor`` steps for the recordings.

    Parameters
    ----------
    tmin, tmax, include_tmax
        Passed to the ``custom_crop`` step.
    sfreq
        Target sampling frequency passed to the ``'resample'`` step.
    clipping
        Symmetric clip limit; values are clipped to ``[-clipping, clipping]``
        after the scaling step.
    ch_mapping
        Passed to the ``custom_rename_channels`` step as ``mapping``.
    short_ch_names
        Channel names passed to the ``'pick_channels'`` step.

    Returns
    -------
    list
        The seven ``Preprocessor`` objects, in the order they are to be applied.
    """
    # Crop each recording to the requested time span.
    crop_step = Preprocessor(custom_crop, tmin=tmin, tmax=tmax,
                             include_tmax=include_tmax, apply_on_array=False)
    # String names refer to mne Raw class functions.
    reref_step = Preprocessor('set_eeg_reference', ref_channels='average', ch_type='eeg')
    # Rename channels to the short channel names.
    rename_step = Preprocessor(custom_rename_channels, mapping=ch_mapping,
                               apply_on_array=False)
    pick_step = Preprocessor('pick_channels', ch_names=short_ch_names, ordered=True)
    # Scale the signals to microvolt.
    scale_step = Preprocessor(multiply, factor=1e6, apply_on_array=True)
    # Clip outlier values to +/- clipping.
    clip_step = Preprocessor(np.clip, a_min=-clipping, a_max=clipping, apply_on_array=True)
    resample_step = Preprocessor('resample', sfreq=sfreq)

    return [
        crop_step,
        reref_step,
        rename_step,
        pick_step,
        scale_step,
        clip_step,
        resample_step,
    ]

In [4]:
def preprocess_dataset(tuh_abnormal, tmin, tmax, sfreq, clipping, n_jobs=1, include_tmax=True):
    """Build the preprocessing pipeline and apply it to ``tuh_abnormal``.

    Parameters
    ----------
    tuh_abnormal
        Dataset passed to ``preprocess`` as ``concat_ds``.
    tmin, tmax, sfreq, clipping
        Forwarded to ``create_preproc_pipeline``.
    n_jobs : int
        Number of parallel jobs passed to ``preprocess``.
    include_tmax : bool
        Forwarded to ``create_preproc_pipeline``. Previously fixed to True
        inside this function; the default keeps that behaviour.

    Returns
    -------
    The value returned by ``preprocess`` (called with ``save_dir=None``).
    """
    short_ch_names, ch_mapping = create_ch_mapping()
    preprocessors = create_preproc_pipeline(
        tmin, tmax, sfreq, clipping, ch_mapping, short_ch_names,
        include_tmax=include_tmax)

    tuh_preproc = preprocess(
        concat_ds=tuh_abnormal,
        preprocessors=preprocessors,
        n_jobs=n_jobs,
        save_dir=None
    )

    return tuh_preproc

In [5]:

def import_and_preprocess(path, tmin, tmax, sfreq, clipping, include_tmax=True, n_jobs=1):
    """Load the TUH Abnormal dataset from ``path`` and preprocess it in one call.

    Parameters
    ----------
    path
        Forwarded to ``import_tuh_abnormal``.
    tmin, tmax, sfreq, clipping, include_tmax
        Forwarded to ``create_preproc_pipeline``.
    n_jobs : int
        Number of parallel jobs used for both loading and preprocessing.

    Returns
    -------
    The value returned by ``preprocess`` (called with ``save_dir=None``).
    """
    # import_tuh_abnormal already applies select_by_channels to what it
    # loads, so a second selection pass here would be redundant.
    tuh_abnormal_selected = import_tuh_abnormal(path, n_jobs=n_jobs)

    short_ch_names, ch_mapping = create_ch_mapping()
    preprocessors = create_preproc_pipeline(
        tmin, tmax, sfreq, clipping, ch_mapping, short_ch_names,
        include_tmax=include_tmax)

    # Call preprocess directly with the pipeline built above. The previous
    # call, preprocess_dataset(dataset, preprocessors, n_jobs), did not match
    # that function's signature and raised TypeError.
    tuh_abnormal_preproc = preprocess(
        concat_ds=tuh_abnormal_selected,
        preprocessors=preprocessors,
        n_jobs=n_jobs,
        save_dir=None,
    )

    return tuh_abnormal_preproc

