In [1]:
# Reload edited modules (e.g. the local braindecode_features checkout added to
# sys.path below) automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [2]:
import logging

# Root logger: attach a default stderr handler (if none exists yet) and let it
# emit DEBUG records so the braindecode_features progress messages show up in
# the cell output; the "test" record confirms the setup.
logger = logging.getLogger()
logging.basicConfig()
logger.setLevel(logging.DEBUG)
logger.debug("test")

DEBUG:root:test


In [3]:
import os
import sys

import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 8)
import mne
mne.set_log_level('ERROR')

from braindecode.datasets import MOABBDataset
from braindecode.datasets.tuh import TUHAbnormal
from braindecode.datautil.preprocess import filterbank, preprocess, Preprocessor
from braindecode.datautil.windowers import create_windows_from_events, create_fixed_length_windows

sys.path.insert(0, '/home/gemeinl/code/braindecode-features/')
from braindecode_features import extract_windows_ds_features, save_features, filter_df

In [4]:
# Dataset selection and general run settings
ds_name = 'TUHAbnormal'
out_path = None  # e.g. './tmp/'
n_jobs = 2
agg_func = None  # e.g. 'mean'
if ds_name not in ('Schirrmeister2017', 'BNCI2014001'):
    assert ds_name in ('TUHAbnormal',)
    recording_ids = list(range(20))
else:
    subject_id = 1

#------------------------------------------------------------------------------
# Parameters to optimize
sfreq = 100  # alternative: 250
# original trials have 4 s duration
frequency_bands = [(4, 8), (8, 13)]  # (low, high) pairs
trial_start_offset_samples = 0  # alternative: int(0.5*sfreq)
window_size_samples = None  # alternative: 500
window_stride_samples = None  # alternative: 500

#------------------------------------------------------------------------------
# Fixed settings
factor = 1e6  # signals are multiplied by this factor during preprocessing
max_abs_val = 800  # scaled signals are clipped to [-max_abs_val, max_abs_val]
trial_stop_offset_samples = 0
# Channels to keep; only the tuple of the selected dataset is retained.
sensors = {
    'Schirrmeister2017': (
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6',
        'CCP1h', 'CCP2h', 'CCP3h', 'CCP4h', 'CCP5h', 'CCP6h',
        'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6',
        'CPP1h', 'CPP2h', 'CPP3h', 'CPP4h', 'CPP5h', 'CPP6h',
        'CPz', 'Cz',
        'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6',
        'FCC1h', 'FCC2h', 'FCC3h', 'FCC4h', 'FCC5h', 'FCC6h',
        'FCz',
        'FFC1h', 'FFC2h', 'FFC3h', 'FFC4h', 'FFC5h', 'FFC6h',
    ),
    'BNCI2014001': (
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6',
        'CP1', 'CP2', 'CP3', 'CP4', 'CPz', 'Cz',
        'FC1', 'FC2', 'FC3', 'FC4', 'FCz', 'Fz',
        'P1', 'P2', 'POz', 'Pz',
    ),
    'TUHAbnormal': (
        'EEG A1-REF', 'EEG A2-REF', 'EEG C3-REF', 'EEG C4-REF', 'EEG CZ-REF',
        'EEG F3-REF', 'EEG F4-REF', 'EEG F7-REF', 'EEG F8-REF', 'EEG FP1-REF',
        'EEG FP2-REF', 'EEG FZ-REF', 'EEG O1-REF', 'EEG O2-REF', 'EEG P3-REF',
        'EEG P4-REF', 'EEG PZ-REF', 'EEG T1-REF', 'EEG T2-REF', 'EEG T3-REF',
        'EEG T4-REF',
    ),
}[ds_name]
# Description column used to split each dataset into training/evaluation parts
train_eval_split = {
    'Schirrmeister2017': 'run',
    'BNCI2014001': 'session',
    'TUHAbnormal': 'train',
}
# Key of the evaluation split; split keys are strings, hence 'False' for TUH
eval_name = {
    'Schirrmeister2017': 'test',
    'BNCI2014001': 'session_E',
    'TUHAbnormal': 'False',
}

In [5]:
# Load the selected dataset: a subset of TUH Abnormal recordings (without
# preloading the signals), or a single subject of a MOABB dataset.
if ds_name not in ['Schirrmeister2017', 'BNCI2014001']:
    assert ds_name in ['TUHAbnormal']
    ds = TUHAbnormal(
        path='/data/schirrmr/gemeinl/tuh-abnormal-eeg/raw/v2.0.0/edf/',
        recording_ids=recording_ids,
        target_name='pathological',
        preload=False,
        add_physician_reports=True,
    )
else:
    ds = MOABBDataset(dataset_name=ds_name, subject_ids=[subject_id])

In [6]:
# Description column to split on, and the key of the evaluation split.
split_indicator, eval_ds_name = train_eval_split[ds_name], eval_name[ds_name]

In [7]:
# Split the dataset by the description column `split_indicator` and pick the
# split to process: the first split that is not the evaluation split, or the
# evaluation split itself when it is the only one.
splits = ds.split(split_indicator)
if not splits:
    # otherwise split_name/split_ds would stay unbound, or silently keep the
    # values from an earlier run of this cell
    raise ValueError(f'ds.split({split_indicator!r}) returned no splits')
split_name, split_ds = next(
    ((name, sub_ds) for name, sub_ds in splits.items() if name != eval_ds_name),
    # fallback: the only split present is the evaluation split
    list(splits.items())[-1],
)
print(split_name)

True


In [8]:
# TUH splits are keyed by the string value of the boolean 'train' column;
# translate them to readable names. Names that were already translated are
# kept, so re-running this cell no longer turns 'eval' into 'train'.
if ds_name in ['TUHAbnormal']:
    split_name = {'True': 'train', 'False': 'eval'}.get(split_name, split_name)

In [9]:
# Preprocessing steps, applied in order: channel selection, scaling by
# `factor`, clipping to [-max_abs_val, max_abs_val] and resampling to `sfreq`.
preprocessors = [
    Preprocessor(
        apply_on_array=False,
        fn='pick_channels',
        ch_names=sensors,
        ordered=True,
    ),
    Preprocessor(
        apply_on_array=True,
        fn=lambda x: x * factor,
    ),
    Preprocessor(
        apply_on_array=True,
        fn=lambda x: np.clip(x, -max_abs_val, max_abs_val),
    ),
    Preprocessor(
        apply_on_array=False,
        fn='resample',
        sfreq=sfreq,
    ),
]
# Additional preprocessing for TUHAbnormal
if ds_name in ['TUHAbnormal']:
    preprocessors.extend([
        # keep the text between the first space and the first '-':
        # EEG O1-REF -> O1
        Preprocessor(
            apply_on_array=False,
            fn='rename_channels',
            mapping=lambda ch: ch[ch.find(' ')+1:ch.find('-')],
        ),
        # discard first 60s
        Preprocessor(
            apply_on_array=False,
            fn='crop',
            tmin=60,
        ),
    ])
# NOTE: no filterbank preprocessor is appended here. Filtering into
# `frequency_bands` appears to happen inside extract_windows_ds_features (its
# log reports "Filtering ..." before computing per-band features).

In [10]:
# Apply the preprocessors to split_ds.
# NOTE: split_ds appears to be modified in place (the return value is not kept
# and split_ds is used afterwards), so re-running this cell would apply every
# step again (e.g. a second scaling by `factor`, another 60 s crop for TUH);
# re-run from the dataset loading cell instead.
preprocess(
    preprocessors=preprocessors,
    concat_ds=split_ds,
)

In [11]:
# define windowing parameters (forwarded to extract_windows_ds_features)
windowing_params = {
    'drop_last_window': False,
    'window_size_samples': window_size_samples,
    'window_stride_samples': window_stride_samples,
}
if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
    windowing_params['trial_start_offset_samples'] = trial_start_offset_samples
    windowing_params['trial_stop_offset_samples'] = trial_stop_offset_samples
else:
    assert ds_name in ['TUHAbnormal']
    # map boolean pathological targets to integer
    windowing_params['mapping'] = {False: 0, True: 1}
    windowing_params['start_offset_samples'] = trial_start_offset_samples
    windowing_params['stop_offset_samples'] = None
    # TUH recordings are cut into fixed-size windows. Honour the sizes from the
    # configuration cell and only fall back to 1000 samples (10 s at
    # sfreq=100) when they are left as None. The stride defaults to the window
    # size, i.e. non-overlapping windows.
    tuh_window_size = (
        window_size_samples if window_size_samples is not None else 1000
    )
    windowing_params['window_size_samples'] = tuh_window_size
    windowing_params['window_stride_samples'] = (
        window_stride_samples if window_stride_samples is not None
        else tuh_window_size
    )

In [12]:
# Extract features of all domains (domains=None) from windows of the
# preprocessed split_ds, cut according to windowing_params. n_jobs is pinned
# to 1 here rather than taken from the configuration cell.
df = extract_windows_ds_features(
    windows_ds=split_ds,
    windowing_params=windowing_params,
    frequency_bands=frequency_bands,
    domains=None,  # or a subset of ['Fourier', 'Wavelet', 'Time', 'Hilbert', 'Cross-frequency']
    params={'Time__higuchi_fractal_dimension__kmax': 3},
    n_jobs=1,
)

DEBUG:braindecode_features.feature_extraction:got 5 datasets
INFO:braindecode_features.feature_extraction:Computing features of domain: Hilbert.
DEBUG:braindecode_features.utils:Filtering ...
INFO:braindecode_features.domains.hilbert:Extracting ...
DEBUG:braindecode_features.domains.hilbert:Transforming ...
DEBUG:braindecode_features.utils:Windowing ...
DEBUG:braindecode_features.utils:got 113 windows
DEBUG:braindecode_features.domains.hilbert:hilbert in (4, 8) before union (113, 21, 1000)
DEBUG:braindecode_features.domains.hilbert:Transforming ...
DEBUG:braindecode_features.utils:Windowing ...
DEBUG:braindecode_features.utils:got 113 windows
DEBUG:braindecode_features.domains.hilbert:hilbert in (8, 13) before union (113, 21, 1000)
DEBUG:braindecode_features.domains.hilbert:feature shape (113, 420)
DEBUG:braindecode_features.domains.hilbert:Transforming ...
DEBUG:braindecode_features.utils:Windowing ...
DEBUG:braindecode_features.utils:got 295 windows
DEBUG:braindecode_features.domains

In [15]:
# Display the feature frame (the rendering is truncated; display.max_columns was set to 8).
df

Domain,Description,Description,Description,Time,...,Hilbert,Hilbert,Hilbert,Hilbert
Feature,Trial,Window,Target,covariance,...,phase_locking_value,phase_locking_value,phase_locking_value,phase_locking_value
Channel,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,A1-A2,...,T1-T4,T2-T3,T2-T4,T3-T4
Frequency,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,4-8,...,8-13,8-13,8-13,8-13
0,0,0,0,30.955677,...,0.622721,0.498675,0.857576,0.474217
1,0,1,0,62.680031,...,0.789897,0.703161,0.905746,0.690007
2,0,2,0,47.013363,...,0.699455,0.545796,0.830926,0.601407
3,0,3,0,11.812249,...,0.524165,0.472973,0.677856,0.485061
4,0,4,0,24.312698,...,0.685401,0.611790,0.786605,0.578159
...,...,...,...,...,...,...,...,...,...
773,4,131,1,7.390909,...,0.601557,0.555658,0.889447,0.529076
774,4,132,1,4.752991,...,0.671412,0.559649,0.869743,0.560742
775,4,133,1,8.111958,...,0.522633,0.456299,0.897723,0.389256
776,4,134,1,4.507317,...,0.412542,0.329900,0.778118,0.336767


In [16]:
# Inspect the dtype of every column (Description columns vs. feature columns).
df.dtypes

Domain       Feature              Channel  Frequency
Description  Trial                                        int64
             Window                                       int64
             Target                                       int64
Time         covariance           A1-A2    4-8          float32
                                  A1-C3    4-8          float32
                                                         ...   
Hilbert      phase_locking_value  T1-T3    8-13         float32
                                  T1-T4    8-13         float32
                                  T2-T3    8-13         float32
                                  T2-T4    8-13         float32
                                  T3-T4    8-13         float32
Length: 2292, dtype: object

In [101]:
from braindecode_features.feature_extraction import (
    get_feature_functions_and_extraction_routines, 
    _build_transformer_list,
    _params_to_domain_params,
)
from braindecode_features.utils import _get_unfiltered_chs
from braindecode_features.feature_extraction import _finalize_df
from sklearn.pipeline import FeatureUnion

In [104]:
def create_base_ds_features(
    base_ds, frequency_bands, window_size_samples,
    window_stride_samples, n_jobs=1, params=None, domains=None,
):
    """Extract features of the requested domains from a dataset.

    Experimental variant of extract_windows_ds_features: for every feature
    domain, one sklearn FeatureUnion is built from that domain's feature
    functions and the domain's extraction routine is run on `base_ds`.

    Parameters
    ----------
    base_ds :
        Dataset the extraction routines are applied to. It is passed on as
        their `windows_ds` argument (this notebook passes `split_ds`, the same
        object it gives extract_windows_ds_features as `windows_ds`).
    frequency_bands : list of tuple
        Frequency bands forwarded to every extraction routine.
    window_size_samples, window_stride_samples :
        NOTE(review): currently unused; kept for interface compatibility.
    n_jobs : int
        `n_jobs` of each domain's FeatureUnion.
    params : dict | None
        Flat parameters such as {'Time__higuchi_fractal_dimension__kmax': 3}.
        They are grouped per domain with _params_to_domain_params and set on
        the FeatureUnion of the matching domain.
    domains : list of str | None
        Domains to compute; None computes all available domains.

    Returns
    -------
    The feature frame built by _finalize_df from the per-domain frames.

    Raises
    ------
    ValueError
        If no domain produced features, e.g. when only 'Cross-frequency' was
        requested with fewer than two frequency bands.
    """
    feature_functions, extraction_routines = get_feature_functions_and_extraction_routines()
    if domains is not None:
        feature_functions = {domain: feature_functions[domain] for domain in domains}
        extraction_routines = {domain: extraction_routines[domain] for domain in domains}
    if params is not None:
        params = _params_to_domain_params(params=params)
    domain_dfs = {}
    # extract features by domain, since each domain has its very own routine
    for domain, extraction_routine in extraction_routines.items():
        # cross-frequency features relate bands to each other, so they need
        # at least two frequency bands
        if domain == 'Cross-frequency' and len(frequency_bands) < 2:
            continue
        print(f'Computing features of domain: {domain}.')
        fu = FeatureUnion(
            transformer_list=_build_transformer_list(feature_functions[domain]),
            n_jobs=n_jobs,
        )
        # set params
        if params is not None and domain in params:
            fu.set_params(**params[domain])
        # extract features of one domain at a time from the dataset passed to
        # this function (not from a notebook-global `windows_ds`)
        domain_dfs[domain] = extraction_routine(
            windows_ds=base_ds,
            frequency_bands=frequency_bands,
            fu=fu,
        )
    if not domain_dfs:
        raise ValueError(
            f'No features were extracted for domains={domains!r} with '
            f'frequency_bands={frequency_bands!r}.'
        )
    # concatenate domain dfs and make final df pretty
    return _finalize_df(
        dfs=list(domain_dfs.values()),
    )

In [105]:
# Compute only the Time-domain features with the experimental
# create_base_ds_features function.
feature_df = create_base_ds_features(
    base_ds=split_ds,
    frequency_bands=frequency_bands,
    domains=['Time'],
    params={'Time__higuchi_fractal_dimension__kmax': 3},
    window_size_samples=window_size_samples,
    window_stride_samples=window_stride_samples,
    n_jobs=n_jobs,
)

Computing features of domain: Time.
['4-8', '8-13']


In [106]:
d = [['ji'], list()]  # scratch input: one non-empty and one empty sub-list

In [112]:
any(not item for item in d)  # True as soon as one sub-list is empty

True

In [19]:
# Saving is disabled unless out_path is set to a directory here; this
# assignment overrides the value from the configuration cell.
out_path = None
# save each trial feature matrix (windows x features) to an individual file
if out_path is not None:
    if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
        out_path = os.path.join(out_path, str(subject_id), split_name)
    else:
        assert ds_name in ['TUHAbnormal']
        out_path = os.path.join(out_path, split_name)
    # exist_ok avoids the race between checking for and creating the directory
    os.makedirs(out_path, exist_ok=True)
    save_features(
        df=df,
        out_path=out_path,
    )

In [21]:
# Inspect a subset of the feature frame: all columns whose labels match
# 'Description' (Trial, Window, Target), searched on every column level.
filter_df(
    df=df,
    query='Description',
    level_to_consider=None,
    exact_match=False,
)

Domain,Description,Description,Description
Feature,Trial,Window,Target
Channel,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
Frequency,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3
0,0,0,0
1,0,1,0
2,0,2,0
3,0,3,0
4,0,4,0
...,...,...,...
1550,4,267,1
1551,4,268,1
1552,4,269,1
1553,4,270,1
