In [None]:
import time
import os
import mne
mne.set_log_level('ERROR')

from warnings import filterwarnings
filterwarnings('ignore')


from IPython.utils import io

import numpy as np
import pandas as pd


from braindecode.datautil.windowers import create_fixed_length_windows
from braindecode.datautil.serialization import  load_concat_dataset

from braindecode.datasets import BaseConcatDataset



## Data Loading

In [None]:
%%time 
from braindecode.datasets.tuh import TUHAbnormal

# Root folder holding the EDF files of the TUH Abnormal EEG corpus (v2.0.0).
data_path = '/data/datasets/TUH/EEG/tuh_eeg_abnormal/v2.0.0/edf/'

# Label attached to every recording. The notebook previously used this name
# without ever defining it (NameError on a fresh kernel). The options noted
# by the original author are 'age', 'gender' and 'pathology'.
# NOTE(review): 'pathology' is assumed to be the intended target - confirm.
target_name = 'pathology'

dataset = TUHAbnormal(
    path=data_path,
    # None puts no restriction on which recordings are read.
    # NOTE(review): expected to select the whole corpus - confirm against
    # the installed braindecode version.
    recording_ids=None,
    target_name=target_name,
    preload=False,  # signals are not preloaded when the dataset is built
    add_physician_reports=False,
)

## Data Preprocessing and saving


In [None]:
%%time
from braindecode.preprocessing import preprocess, Preprocessor, scale as multiply
import numpy as np
from copy import deepcopy

# Split on the 'train' column of the dataset description; the resulting
# groups are keyed by the string form of the boolean flag.
splits_by_train_flag = dataset.split('train')
whole_train_set = splits_by_train_flag['True']
whole_eval_set = splits_by_train_flag['False']

# The three lists below are each sorted so that position i refers to the same
# electrode in all of them; the mappings are built by pairing them up.
short_ch_names = sorted([
    'A1', 'A2',
    'C3', 'C4', 'Cz',
    'F3', 'F4', 'F7', 'F8', 'Fp1', 'Fp2', 'Fz',
    'O1', 'O2',
    'P3', 'P4', 'Pz',
    'T3', 'T4', 'T5', 'T6',
])

# Channel names as they appear in recordings whose names end in '-REF'.
ar_ch_names = sorted([
    'EEG A1-REF', 'EEG A2-REF',
    'EEG C3-REF', 'EEG C4-REF', 'EEG CZ-REF',
    'EEG F3-REF', 'EEG F4-REF', 'EEG F7-REF', 'EEG F8-REF',
    'EEG FP1-REF', 'EEG FP2-REF', 'EEG FZ-REF',
    'EEG O1-REF', 'EEG O2-REF',
    'EEG P3-REF', 'EEG P4-REF', 'EEG PZ-REF',
    'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF', 'EEG T6-REF',
])

# Channel names as they appear in recordings whose names end in '-LE'.
le_ch_names = sorted([
    'EEG A1-LE', 'EEG A2-LE',
    'EEG C3-LE', 'EEG C4-LE', 'EEG CZ-LE',
    'EEG F3-LE', 'EEG F4-LE', 'EEG F7-LE', 'EEG F8-LE',
    'EEG FP1-LE', 'EEG FP2-LE', 'EEG FZ-LE',
    'EEG O1-LE', 'EEG O2-LE',
    'EEG P3-LE', 'EEG P4-LE', 'EEG PZ-LE',
    'EEG T3-LE', 'EEG T4-LE', 'EEG T5-LE', 'EEG T6-LE',
])

# Guard against a name being dropped from one of the lists by accident.
assert len(short_ch_names) == len(ar_ch_names) == len(le_ch_names)

# Long (reference-specific) name -> short name, one dictionary per reference.
ar_ch_mapping = dict(zip(ar_ch_names, short_ch_names))
le_ch_mapping = dict(zip(le_ch_names, short_ch_names))
ch_mapping = {'ar': ar_ch_mapping, 'le': le_ch_mapping}



def custom_rename_channels(raw, mapping):
    """Rename the channels of ``raw`` in place to their short names.

    The reference is read from the suffix of the first channel name
    (e.g. ``EEG 01-LE`` or ``EEG 01-REF``) and selects which sub-mapping
    to apply. Only the mapping for that reference is passed on, since mne
    fails if the mapping has keys for channels that are not present in the
    raw.

    Parameters
    ----------
    raw : object exposing ``ch_names`` and ``rename_channels``
        Recording whose channels are renamed.
    mapping : dict
        ``{'le': {...}, 'ar': {...}}``; each value maps an original channel
        name to its new name.

    Raises
    ------
    AssertionError
        If the suffix of the first channel name is neither 'le' nor 'ref'.
    """
    suffix = raw.ch_names[0].rsplit('-', 1)[-1].lower()
    assert suffix in ('le', 'ref'), 'unexpected referencing'
    # A '-REF' suffix is stored under the 'ar' key of the mapping.
    mapping_key = {'le': 'le', 'ref': 'ar'}[suffix]
    raw.rename_channels(mapping[mapping_key])


def custom_crop(raw, tmin=0.0, tmax=None, include_tmax=True):
    """Crop ``raw`` in place to the interval ``tmin`` - ``tmax`` (seconds).

    The result can be shorter than requested if the recording has a lower
    duration than ``tmax``: by default mne fails if ``tmax`` is bigger than
    the duration, so ``tmax`` is clamped to the time of the last sample.

    Parameters
    ----------
    raw : object exposing ``n_times``, ``info['sfreq']`` and ``crop``
        Recording to crop.
    tmin : float
        Start of the interval in seconds.
    tmax : float | None
        End of the interval in seconds. ``None`` means "up to the last
        sample"; previously this default raised ``TypeError`` because it was
        compared against a float.
    include_tmax : bool
        Forwarded unchanged to ``raw.crop``.
    """
    last_sample_time = (raw.n_times - 1) / raw.info['sfreq']
    if tmax is None:
        tmax = last_sample_time
    else:
        tmax = min(last_sample_time, tmax)
    raw.crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)


# Cropping window handed to custom_crop: start after the first minute and
# stop at minute `n_max_minutes` of each recording.
n_max_minutes = 21
tmin = 1 * 60  # seconds
tmax = n_max_minutes * 60  # seconds
sfreq = 100  # value passed to the 'resample' step

# Steps are applied to every recording in the order listed.
preprocessors = [
    # Crop first so that the later steps work on less data.
    Preprocessor(custom_crop, tmin=tmin, tmax=tmax, include_tmax=False,
                 apply_on_array=False),

    # Map reference-specific channel names to short names, then keep only
    # those channels in the fixed order of `short_ch_names`.
    Preprocessor(custom_rename_channels, mapping=ch_mapping,
                 apply_on_array=False),
    Preprocessor('pick_channels', ch_names=short_ch_names, ordered=True),

    # Scale by 1e6 (presumably V -> uV; confirm) and clip the scaled signal
    # to [-800, 800].
    Preprocessor(multiply, factor=1e6, apply_on_array=True),
    Preprocessor(np.clip, a_min=-800, a_max=800, apply_on_array=True),

    Preprocessor('set_eeg_reference', ref_channels='average', ch_type='eeg'),

    Preprocessor('resample', sfreq=sfreq),
    Preprocessor('set_meas_date', meas_date=None),
]

# The pipeline must run exactly once per dataset. The notebook used to run
# the in-memory variant on the training set and then the save-to-disk variant
# on the same, already modified, object: the second pass cropped another
# minute and then failed in custom_rename_channels, whose assertion no longer
# recognises the renamed channels. Choose one variant with this switch.
SAVE_PREPROCESSED = True  # False: preprocess in memory only, nothing saved

preprocessing_targets = [
    (whole_train_set, '/home/data/preprocessed_TUAB/final_train/'),
    (whole_eval_set, '/home/data/preprocessed_TUAB/final_eval/'),
]

for concat_ds, save_dir in preprocessing_targets:
    if SAVE_PREPROCESSED:
        # Create the output folder up front; harmless if it already exists.
        # NOTE(review): braindecode appears to expect the folder to exist
        # before saving - confirm against the installed version.
        os.makedirs(save_dir, exist_ok=True)
        preprocess(
            concat_ds=concat_ds,
            preprocessors=preprocessors,
            n_jobs=4,
            save_dir=save_dir,
        )
    else:
        preprocess(concat_ds, preprocessors)
