In [None]:
import os
import sys

# Put local development checkouts of braindecode and braindecode-features in
# front of any installed versions. The defaults are the original author's
# paths; override them via environment variables instead of editing the
# notebook. Insertion order matches the original: braindecode-features ends up
# first on sys.path, braindecode second.
for _repo_path in (
    os.environ.get('BRAINDECODE_PATH', '/home/lukas/Code/braindecode/'),
    os.environ.get(
        'BRAINDECODE_FEATURES_PATH', '/home/lukas/Code/braindecode-features/'),
):
    sys.path.insert(0, _repo_path)
del _repo_path

In [None]:
import logging

# Attach a default stderr handler to the root logger and let records of level
# INFO and above through. (A leftover `logging.debug("test")` call was removed:
# at INFO level it never produced output.)
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
import os

import numpy as np
import pandas as pd
# Show at most 8 columns when a DataFrame is displayed.
pd.set_option('display.max_columns', 8)
import mne
# Silence MNE output below ERROR level.
mne.set_log_level('ERROR')

from braindecode.datasets import MOABBDataset
from braindecode.datasets.tuh import TUHAbnormal
# Uncomment the next line to replace TUHAbnormal with braindecode's mock class
# (presumably for running without the real TUH data).
#from braindecode.datasets.tuh import _TUHAbnormalMock as TUHAbnormal  # delete to use actual data
# `scale` is imported under the name `multiply`; it is used below to rescale
# the signals by `factor`.
from braindecode.preprocessing import (
    filterbank, preprocess, Preprocessor, create_windows_from_events, 
    create_fixed_length_windows, scale as multiply)

from braindecode_features import extract_ds_features, save_features, filter_df

In [None]:
# Dataset to process: 'BNCI2014001', 'Schirrmeister2017' or 'TUHAbnormal'.
ds_name = 'BNCI2014001'
# Root directory for the saved features; set to None to skip saving.
out_path = './tmp/'
n_jobs = 2
agg_func = None
if ds_name in ['BNCI2014001']:
    subject_id = 1
    # presumably (low, high) band edges in Hz
    frequency_bands = [(4, 13), (13, 38)]
elif ds_name in ['Schirrmeister2017']:
    subject_id = 1
    # TODO: add frequency bands. Until they exist, fail here. Otherwise the
    # notebook only fails with a NameError in the feature extraction cell,
    # after the (slow) download and preprocessing.
    raise NotImplementedError(
        'frequency_bands are not defined for Schirrmeister2017 yet')
elif ds_name in ['TUHAbnormal']:
    frequency_bands = [(4, 8), (8, 13), (13, 30), (30, 50)]
else:
    # Explicit error rather than `assert`, which is stripped under `python -O`.
    raise ValueError('Unsupported dataset: {!r}'.format(ds_name))

#------------------------------------------------------------------------------
# parameters to optimize
# target sampling frequency in Hz (recordings are resampled to it below)
sfreq = 250
# original trials have 4s duration; the offset corresponds to 0.5s at `sfreq`
# (presumably windows start 0.5s after trial onset)
trial_start_offset_samples = int(0.5*sfreq)
# window length and stride in samples (100 samples = 0.4s at 250 Hz)
window_size_samples = 100
window_stride_samples = 100

# clf + hyperparams
#------------------------------------------------------------------------------
# signals are multiplied by `factor` (presumably V -> uV) and then clipped to
# [-max_abs_val, max_abs_val]
factor = 1e6
max_abs_val = 800
trial_stop_offset_samples = 0
# channels to keep for each dataset, in this order
sensors_per_dataset = {
    'Schirrmeister2017': (
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CCP1h', 'CCP2h', 'CCP3h', 'CCP4h',
        'CCP5h', 'CCP6h', 'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6', 'CPP1h',
        'CPP2h', 'CPP3h', 'CPP4h', 'CPP5h', 'CPP6h', 'CPz', 'Cz', 'FC1', 'FC2',
        'FC3', 'FC4', 'FC5', 'FC6', 'FCC1h', 'FCC2h', 'FCC3h', 'FCC4h',
        'FCC5h', 'FCC6h', 'FCz', 'FFC1h', 'FFC2h', 'FFC3h', 'FFC4h', 'FFC5h',
        'FFC6h',
    ),
    'BNCI2014001': (
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', 'CP2', 'CP3', 'CP4', 'CPz',
        'Cz', 'FC1', 'FC2', 'FC3', 'FC4', 'FCz', 'Fz', 'P1', 'P2', 'POz', 'Pz',
    ),
    'TUHAbnormal': (
        'EEG A1-REF', 'EEG A2-REF', 'EEG C3-REF', 'EEG C4-REF',
        'EEG CZ-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG F7-REF',
        'EEG F8-REF', 'EEG FP1-REF', 'EEG FP2-REF', 'EEG FZ-REF',
        'EEG O1-REF', 'EEG O2-REF', 'EEG P3-REF', 'EEG P4-REF',
        'EEG PZ-REF', 'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF',
        'EEG T6-REF',
    ),
}
# channel names of the selected dataset
sensors = sensors_per_dataset[ds_name]
# description property used to split each dataset (passed to `ds.split`)
train_eval_split = {
    'Schirrmeister2017': 'run',
    'BNCI2014001': 'session',
    'TUHAbnormal': 'train',
}
# name of the split that holds the evaluation recordings
eval_name = {
    'Schirrmeister2017': 'test',
    'BNCI2014001': 'session_E',
    'TUHAbnormal': 'False',
}

In [None]:
# Load the raw recordings: MOABB datasets for a single subject, TUHAbnormal
# lazily (preload=False) together with the physician reports.
if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
    ds = MOABBDataset(
        dataset_name=ds_name,
        subject_ids=[subject_id],
    )
elif ds_name in ['TUHAbnormal']:
    # Root of the TUH Abnormal EDF files. The default is the original
    # machine-specific path; set TUH_ABNORMAL_PATH to use another location.
    tuh_path = os.environ.get(
        'TUH_ABNORMAL_PATH',
        '/data/schirrmr/gemeinl/tuh-abnormal-eeg/raw/v2.0.0/edf/',
    )
    ds = TUHAbnormal(
        path=tuh_path,
        recording_ids=None,
        target_name='pathological',
        preload=False,
        add_physician_reports=True,
    )
else:
    # Explicit error rather than `assert`, which is stripped under `python -O`.
    raise ValueError('Unsupported dataset: {!r}'.format(ds_name))

In [None]:
# Inspect the per-recording metadata. `split_indicator` below is expected to
# name one of its columns (TODO confirm for each dataset).
ds.description

In [None]:
# Look up the property to split on and the name of the evaluation split for
# the selected dataset.
split_indicator, eval_ds_name = train_eval_split[ds_name], eval_name[ds_name]

In [None]:
# Partition the dataset by `split_indicator` and keep the first split that is
# NOT the evaluation split. If the evaluation split is the only one, the loop
# runs to completion and that split is kept instead.
splits = ds.split(split_indicator)
for split_name, split_ds in splits.items():
    if split_name != eval_ds_name:
        break
print(split_name)

In [None]:
# Map the dataset-specific split name to 'eval' or 'train'; the result is used
# as the output sub-directory when saving features. Every dataset's evaluation
# split name comes from `eval_name` (via `eval_ds_name`), so one rule covers
# all of them. This includes Schirrmeister2017 (eval split 'test'), which was
# previously left unmapped.
split_name = 'eval' if split_name == eval_ds_name else 'train'

In [None]:
# some preprocessing steps
preprocessors = [
    Preprocessor(
        apply_on_array=False,
        fn='pick_channels', 
        ch_names=sensors, 
        ordered=True,
    ),
    Preprocessor(
        apply_on_array=True,
        fn=multiply,
        factor=factor,
    ),
    Preprocessor(
        apply_on_array=True,
        fn=np.clip, 
        a_min=-max_abs_val, 
        a_max=max_abs_val,
    ),
    Preprocessor(
        apply_on_array=False,
        fn='resample', 
        sfreq=sfreq,
    ),
]
# Additional preprocessing for TUHAbnormal
if ds_name in ['TUHAbnormal']:
    preprocessors.extend([
        # EEG O1-REF -> O1
        Preprocessor(
            apply_on_array=False,
            fn='rename_channels',
            mapping=lambda ch: ch[ch.find(' ')+1:ch.find('-')],
        ),
        # discard first 60s
        Preprocessor(
            apply_on_array=False,
            fn='crop',
            tmin=60,
        ),
    ])
# apply some preprocessing
preprocess(
    concat_ds=split_ds,
    preprocessors=preprocessors,
    #n_jobs=n_jobs,  # wait for braindecode PR277
)

In [None]:
# define windowing parameters
windowing_params = {
    'drop_last_window': False,
    'window_size_samples': window_size_samples,
    'window_stride_samples': window_stride_samples,
}
if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
    windowing_params['trial_start_offset_samples'] = trial_start_offset_samples
    windowing_params['trial_stop_offset_samples'] = trial_stop_offset_samples
else:
    assert ds_name in ['TUHAbnormal']
    # map boolean pathological targets to integer 
    windowing_params['mapping'] = {False: 0, True: 1}

In [None]:
# Per-feature settings; the key presumably follows a
# '<domain>__<feature>__<parameter>' scheme (here: kmax of the Higuchi
# fractal dimension).
feature_params = {'Time__higuchi_fractal_dimension__kmax': 3}
df = extract_ds_features(
    ds=split_ds,
    frequency_bands=frequency_bands,
    n_jobs=n_jobs,
    params=feature_params,
    windowing_params=windowing_params,
)

In [None]:
# Inspect the extracted feature table; the display is limited to 8 columns by
# the pandas option set with the imports.
df

In [None]:
# Check the column dtypes of the feature table before it is saved.
df.dtypes

In [None]:
# save each trial feature matrix (windows x features) to an individual file
# Save each trial feature matrix (windows x features) to an individual file.
# Target directory: <out_path>/<subject_id>/<split_name>/ for MOABB datasets,
# or <out_path>/<split_name>/ for TUHAbnormal.
if out_path is not None:
    # Store the target directory in a new variable instead of reassigning
    # `out_path`. Reassigning it made re-runs of this cell nest the path again
    # (e.g. ./tmp/1/train/1/train).
    if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
        features_out_path = os.path.join(out_path, str(subject_id), split_name)
    elif ds_name in ['TUHAbnormal']:
        features_out_path = os.path.join(out_path, split_name)
    else:
        raise ValueError('Unsupported dataset: {!r}'.format(ds_name))
    # exist_ok replaces the separate exists() check and avoids its race.
    os.makedirs(features_out_path, exist_ok=True)
    save_features(
        df=df,
        out_path=features_out_path,
    )