In [1]:
import logging

# Attach a default stderr handler to the root logger and let records of level
# INFO and above through (e.g. the per-domain progress messages emitted by the
# feature extraction).
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [2]:
import os
import sys

import numpy as np
import pandas as pd
# keep the very wide feature data frames readable: show at most 8 columns
pd.set_option('display.max_columns', 8)
import mne
# only show MNE messages of level ERROR and above
mne.set_log_level('ERROR')

from braindecode.datasets import MOABBDataset
from braindecode.datasets.tuh import TUHAbnormal
from braindecode.datautil.preprocess import filterbank, preprocess, Preprocessor
from braindecode.datautil.windowers import create_windows_from_events, create_fixed_length_windows

# NOTE(review): hardcoded, user-specific path to a local checkout of
# braindecode-features; adjust it (or install the package) before running
# this notebook on another machine.
sys.path.insert(0, '/home/gemeinl/code/braindecode-features/')
from braindecode_features import extract_ds_features, save_features, filter_df

In [3]:
ds_name = 'Schirrmeister2017'
# directory to save per-trial feature files to; None disables saving
out_path = None
n_jobs = 2
agg_func = None
if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
    subject_id = 1
else:
    assert ds_name in ['TUHAbnormal']
    recording_ids = list(range(20))

#------------------------------------------------------------------------------
# stuffs to optimize
# target sampling frequency (Hz) the recordings are resampled to
sfreq = 250
# frequency bands (Hz) in which band-specific features are computed
frequency_bands = [(4, 8), (8, 13), (13, 30), (30, 50)]
# original trials have 4s duration; shift the window start by 0.5 s
trial_start_offset_samples = int(0.5*sfreq)
trial_stop_offset_samples = 0
# non-overlapping windows of 500 samples (2 s at sfreq=250)
window_size_samples = 500
window_stride_samples = 500

#------------------------------------------------------------------------------
# signal scaling / clipping
# MNE stores EEG data in volts; multiply to obtain microvolts
factor = 1e6
# clip amplitudes to [-max_abs_val, max_abs_val] (microvolts after scaling)
max_abs_val = 800

# channels to keep, per dataset (kept in this order)
sensors_per_dataset = {
    'Schirrmeister2017': (
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CCP1h', 'CCP2h', 'CCP3h', 'CCP4h',
        'CCP5h', 'CCP6h', 'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6', 'CPP1h',
        'CPP2h', 'CPP3h', 'CPP4h', 'CPP5h', 'CPP6h', 'CPz', 'Cz', 'FC1', 'FC2',
        'FC3', 'FC4', 'FC5', 'FC6', 'FCC1h', 'FCC2h', 'FCC3h', 'FCC4h',
        'FCC5h', 'FCC6h', 'FCz', 'FFC1h', 'FFC2h', 'FFC3h', 'FFC4h', 'FFC5h',
        'FFC6h'
    ),
    'BNCI2014001': (
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', 'CP2', 'CP3', 'CP4', 'CPz',
        'Cz', 'FC1', 'FC2', 'FC3', 'FC4', 'FCz', 'Fz', 'P1', 'P2', 'POz', 'Pz'
    ),
    'TUHAbnormal': (
        'EEG A1-REF', 'EEG A2-REF', 'EEG C3-REF', 'EEG C4-REF',
        'EEG CZ-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG F7-REF',
        'EEG F8-REF', 'EEG FP1-REF', 'EEG FP2-REF', 'EEG FZ-REF',
        'EEG O1-REF', 'EEG O2-REF', 'EEG P3-REF', 'EEG P4-REF',
        'EEG PZ-REF', 'EEG T1-REF', 'EEG T2-REF', 'EEG T3-REF',
        'EEG T4-REF'
    ),
}
# channel names of the selected dataset
sensors = sensors_per_dataset[ds_name]
# description column by which each dataset is split into train/eval parts
train_eval_split = {
    'Schirrmeister2017': 'run',
    'BNCI2014001': 'session',
    'TUHAbnormal': 'train',
}
# key of the held-out evaluation split; presumably split keys are the
# stringified column values, hence 'False' for TUHAbnormal
eval_name = {
    'Schirrmeister2017': 'test',
    'BNCI2014001': 'session_E',
    'TUHAbnormal': 'False',
}

In [4]:
# Load the raw recordings of the configured dataset.
if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
    # MOABB datasets: load only the configured subject
    ds = MOABBDataset(
        dataset_name=ds_name,
        subject_ids=[subject_id],
    )
else:
    assert ds_name in ['TUHAbnormal']
    # The TUH data location can be overridden with the TUH_ABNORMAL_PATH
    # environment variable; the default is the original machine-specific path.
    ds = TUHAbnormal(
        path=os.environ.get(
            'TUH_ABNORMAL_PATH',
            '/data/schirrmr/gemeinl/tuh-abnormal-eeg/raw/v2.0.0/edf/',
        ),
        recording_ids=recording_ids,
        target_name='pathological',
        preload=False,
        add_physician_reports=True,
    )

In [5]:
# Look up, for the selected dataset, the description column to split by and
# the key of the held-out evaluation split.
split_indicator, eval_ds_name = train_eval_split[ds_name], eval_name[ds_name]

In [6]:
# Split the dataset (e.g. into train and test sessions/runs) and select the
# first split that is not the held-out evaluation split.
splits = ds.split(split_indicator)
non_eval_names = [name for name in splits if name != eval_ds_name]
if non_eval_names:
    split_name = non_eval_names[0]
elif splits:
    # Only the evaluation split is present (e.g. a recording subset without
    # training data): process it instead.
    split_name = eval_ds_name
else:
    # Fail loudly rather than silently reusing split_name / split_ds left
    # over from a previous execution of this cell.
    raise ValueError(f'Splitting by {split_indicator!r} yielded no splits.')
split_ds = splits[split_name]
print(split_name)

train


In [7]:
# TUHAbnormal split keys are the (stringified) values of the 'train'
# description column; rename them to 'train'/'eval' for the output directory.
# Names that are already mapped are left untouched, so re-running this cell
# cannot turn 'eval' into 'train'.
if ds_name in ['TUHAbnormal'] and split_name not in ('train', 'eval'):
    split_name = 'eval' if split_name == 'False' else 'train'

In [8]:
# Preprocessing shared by all datasets: keep the configured sensors (in that
# order), rescale by `factor`, clip to +-max_abs_val and resample to `sfreq`.
# NOTE(review): `preprocess` appears to modify `split_ds` in place (its result
# is not assigned but later cells use `split_ds`), so re-running this cell
# presumably rescales/clips/crops already processed data again — reload the
# dataset before re-running.
preprocessors = [
    Preprocessor(apply_on_array=False, fn='pick_channels', ch_names=sensors, ordered=True),
    Preprocessor(apply_on_array=True, fn=lambda arr: arr * factor),
    Preprocessor(apply_on_array=True, fn=lambda arr: np.clip(arr, -max_abs_val, max_abs_val)),
    Preprocessor(apply_on_array=False, fn='resample', sfreq=sfreq),
]
if ds_name in ['TUHAbnormal']:
    # TUH-specific steps: shorten channel names ('EEG O1-REF' -> 'O1') and
    # discard the first 60 s of every recording.
    preprocessors += [
        Preprocessor(
            apply_on_array=False,
            fn='rename_channels',
            mapping=lambda name: name[name.find(' ') + 1:name.find('-')],
        ),
        Preprocessor(apply_on_array=False, fn='crop', tmin=60),
    ]
preprocess(concat_ds=split_ds, preprocessors=preprocessors)

In [9]:
# Window parameters shared by all datasets.
windowing_params = dict(
    drop_last_window=False,
    window_size_samples=window_size_samples,
    window_stride_samples=window_stride_samples,
)
if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
    # trial-based datasets: offsets relative to the trials
    windowing_params.update(
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=trial_stop_offset_samples,
    )
else:
    assert ds_name in ['TUHAbnormal']
    # Continuous recordings; these keys presumably go to braindecode's
    # create_fixed_length_windows, where a stop offset of None means
    # "until the end of the recording" — a stop offset of 0 is mapped to None.
    windowing_params.update(
        # map boolean pathological targets to integer labels
        mapping={False: 0, True: 1},
        start_offset_samples=trial_start_offset_samples,
        stop_offset_samples=trial_stop_offset_samples or None,
    )

In [10]:
# Compute the feature data frame for the preprocessed split (one row per
# window, see the display below); kmax of the Higuchi fractal dimension is
# set explicitly to 3.
feature_params = {'Time__higuchi_fractal_dimension__kmax': 3}
df = extract_ds_features(
    ds=split_ds,
    windowing_params=windowing_params,
    frequency_bands=frequency_bands,
    params=feature_params,
    n_jobs=n_jobs,
)

INFO:braindecode_features.feature_extraction:Computing features of domain: Time.
INFO:numexpr.utils:Note: NumExpr detected 20 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
INFO:braindecode_features.feature_extraction:Computing features of domain: Fourier.
INFO:braindecode_features.feature_extraction:Computing features of domain: Wavelet.
INFO:braindecode_features.feature_extraction:Computing features of domain: Hilbert.
INFO:braindecode_features.feature_extraction:Computing features of domain: Cross-frequency.


In [11]:
# Feature frame: the Description columns (Trial, Window, Target) are followed by
# features indexed by Domain, Feature, Channel and Frequency band.
df

Domain,Description,Description,Description,Cross-frequency,...,Time,Time,Time,Time
Feature,Trial,Window,Target,cross_frequency_coupling,...,zero_crossings_derivative,zero_crossings_derivative,zero_crossings_derivative,zero_crossings_derivative
Channel,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,C1,...,FFC3h,FFC4h,FFC5h,FFC6h
Frequency,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,"4-8, 8-13",...,30-50,30-50,30-50,30-50
0,0,0,2,1.085079,...,161.0,154.0,176.0,158.0
1,0,1,2,1.145345,...,170.0,154.0,172.0,149.0
2,1,0,0,1.681981,...,170.0,158.0,169.0,156.0
3,1,1,0,1.098230,...,164.0,154.0,176.0,158.0
4,2,0,1,0.576200,...,161.0,171.0,187.0,153.0
...,...,...,...,...,...,...,...,...,...
635,317,1,0,0.684891,...,166.0,165.0,173.0,163.0
636,318,0,3,1.785293,...,153.0,161.0,175.0,149.0
637,318,1,3,1.085817,...,159.0,161.0,177.0,150.0
638,319,0,1,1.200460,...,161.0,153.0,171.0,154.0


In [12]:
# Column dtypes: description columns are int64, feature columns float32.
df.dtypes

Domain           Feature                    Channel  Frequency
Description      Trial                                              int64
                 Window                                             int64
                 Target                                             int64
Cross-frequency  cross_frequency_coupling   C1       4-8, 8-13    float32
                                            C2       4-8, 8-13    float32
                                                                   ...   
Time             zero_crossings_derivative  FFC2h    30-50        float32
                                            FFC3h    30-50        float32
                                            FFC4h    30-50        float32
                                            FFC5h    30-50        float32
                                            FFC6h    30-50        float32
Length: 14313, dtype: object

In [13]:
# save each trial feature matrix (windows x features) to an individual file
# under <out_path>/<subject_id>/<split_name> (MOABB datasets) or
# <out_path>/<split_name> (TUHAbnormal); nothing is saved if out_path is None.
if out_path is not None:
    # Use a separate variable instead of reassigning `out_path`, so re-running
    # this cell does not nest another subdirectory level each time.
    if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
        split_out_path = os.path.join(out_path, str(subject_id), split_name)
    else:
        assert ds_name in ['TUHAbnormal']
        split_out_path = os.path.join(out_path, split_name)
    # exist_ok avoids the race between an existence check and creation
    os.makedirs(split_out_path, exist_ok=True)
    save_features(
        df=df,
        out_path=split_out_path,
    )

In [14]:
# Inspect a subset of the feature frame: the Wavelet-domain features (the
# Description columns are kept as well, see the output). The query is matched
# non-exactly; level_to_consider=None presumably searches all column levels.
filter_df(df=df, query='Wavelet', exact_match=False, level_to_consider=None)

Domain,Description,Description,Description,Wavelet,Wavelet,Wavelet,Wavelet,Wavelet,Wavelet
Feature,Trial,Window,Target,bounded_variation,...,variance,variance,variance,variance
Channel,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,C1,...,FFC3h,FFC4h,FFC5h,FFC6h
Frequency,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,4-8,...,30-50,30-50,30-50,30-50
0,0,0,2,13.595449,...,11.432487,10.353985,55.365276,14.341510
1,0,1,2,13.643919,...,9.656889,7.770929,73.783691,11.824708
2,1,0,0,14.216390,...,8.197099,8.491804,78.199631,15.087831
3,1,1,0,12.713604,...,6.412235,7.976724,52.003956,12.782948
4,2,0,1,15.126334,...,9.740189,7.249357,51.466599,13.490342
...,...,...,...,...,...,...,...,...,...
635,317,1,0,16.639980,...,8.508876,6.273733,49.232372,9.705858
636,318,0,3,14.300038,...,13.547613,10.303136,53.098301,15.806651
637,318,1,3,12.928943,...,10.618900,6.853260,46.035866,10.188142
638,319,0,1,14.945504,...,9.731249,10.312473,49.923561,14.931082
