In [1]:
# Load IPython's autoreload extension and use mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import logging

# Attach the default handler to the root logger and show messages from INFO
# level upwards. The level is set explicitly on the root logger because
# basicConfig() leaves an already configured root logger untouched.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [3]:
# standard library
import os
import sys

# third-party
import numpy as np
import pandas as pd
# show at most 8 columns when a data frame is displayed
pd.set_option('display.max_columns', 8)
import mne
# NOTE(review): presumably limits MNE's console output to errors -- confirm
mne.set_log_level('ERROR')

# NOTE(review): filterbank, create_windows_from_events and
# create_fixed_length_windows are not referenced in the cells visible here.
from braindecode.datasets import MOABBDataset
from braindecode.datasets.tuh import TUHAbnormal
from braindecode.datautil.preprocess import filterbank, preprocess, Preprocessor
from braindecode.datautil.windowers import create_windows_from_events, create_fixed_length_windows

# NOTE(review): hard-coded, machine-specific directory that is put first on the
# import path, presumably so that the braindecode_features import below
# resolves to that checkout -- adjust when running on another machine.
sys.path.insert(0, '/home/gemeinl/code/braindecode-features/')
from braindecode_features import extract_ds_features, save_features, filter_df

In [4]:
#------------------------------------------------------------------------------
# dataset selection and run settings
ds_name = 'TUHAbnormal'
out_path = None  # e.g. './tmp/' or '/home/lukas/Code/HGD/'
n_jobs = 2
agg_func = None  # e.g. 'mean'
if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
    subject_id = 1
else:
    assert ds_name in ['TUHAbnormal'], f'unsupported ds_name: {ds_name!r}'
    recording_ids = list(range(20))

#------------------------------------------------------------------------------
# parameters to optimize
sfreq = 100  # alternative: 250
# original trials have 4s duration
frequency_bands = [(4, 8), (8, 13), (13, 30)]
trial_start_offset_samples = 0  # alternative: int(0.5*sfreq)
window_size_samples = None  # alternative: 500
window_stride_samples = None  # alternative: 500

# (classifier + hyperparameters: placeholder, nothing is configured here)
#------------------------------------------------------------------------------
# fixed settings
factor = 1e6  # multiplicative scaling of the signals
max_abs_val = 800  # symmetric bound used to clip the scaled signals
trial_stop_offset_samples = 0

# Channels to pick, per dataset. The lookup table has its own name so that
# `sensors` only ever refers to the channel tuple of the selected dataset.
sensors_per_dataset = {
    'Schirrmeister2017': (
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CCP1h', 'CCP2h', 'CCP3h', 'CCP4h',
        'CCP5h', 'CCP6h', 'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6', 'CPP1h',
        'CPP2h', 'CPP3h', 'CPP4h', 'CPP5h', 'CPP6h', 'CPz', 'Cz', 'FC1', 'FC2',
        'FC3', 'FC4', 'FC5', 'FC6', 'FCC1h', 'FCC2h', 'FCC3h', 'FCC4h',
        'FCC5h', 'FCC6h', 'FCz', 'FFC1h', 'FFC2h', 'FFC3h', 'FFC4h', 'FFC5h',
        'FFC6h'
    ),
    'BNCI2014001': (
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', 'CP2', 'CP3', 'CP4', 'CPz',
        'Cz', 'FC1', 'FC2', 'FC3', 'FC4', 'FCz', 'Fz', 'P1', 'P2', 'POz', 'Pz'
    ),
    'TUHAbnormal': (
        'EEG A1-REF', 'EEG A2-REF', 'EEG C3-REF', 'EEG C4-REF',
        'EEG CZ-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG F7-REF',
        'EEG F8-REF', 'EEG FP1-REF', 'EEG FP2-REF', 'EEG FZ-REF',
        'EEG O1-REF', 'EEG O2-REF', 'EEG P3-REF', 'EEG P4-REF',
        'EEG PZ-REF', 'EEG T1-REF', 'EEG T2-REF', 'EEG T3-REF',
        'EEG T4-REF'
    ),
}
sensors = sensors_per_dataset[ds_name]

# argument handed to ds.split() for each dataset
train_eval_split = {
    'Schirrmeister2017': 'run',
    'BNCI2014001': 'session',
    'TUHAbnormal': 'train',
}
# name of the split that holds the evaluation data for each dataset
eval_name = {
    'Schirrmeister2017': 'test',
    'BNCI2014001': 'session_E',
    'TUHAbnormal': 'False',
}

In [5]:
# Instantiate the dataset object that matches the configured dataset name.
if ds_name not in ('Schirrmeister2017', 'BNCI2014001'):
    # anything that is not loaded through MOABBDataset has to be TUHAbnormal
    assert ds_name in ['TUHAbnormal']
    ds = TUHAbnormal(
        path='/data/schirrmr/gemeinl/tuh-abnormal-eeg/raw/v2.0.0/edf/',
        recording_ids=recording_ids,
        target_name='pathological',
        preload=False,
        add_physician_reports=True,
    )
else:
    ds = MOABBDataset(dataset_name=ds_name, subject_ids=[subject_id])

In [6]:
# Look up how the selected dataset is split and which split is held out.
split_indicator, eval_ds_name = train_eval_split[ds_name], eval_name[ds_name]

In [7]:
# Split the dataset and stop at the first split whose name differs from the
# evaluation split. If no such split exists, the loop variables keep the last
# split that was iterated.
splits = ds.split(split_indicator)
for split_name, split_ds in splits.items():
    if split_name != eval_ds_name:
        break
print(split_name)

True


In [8]:
if ds_name in ['TUHAbnormal']:
    # The TUH splits are named by the strings 'True' / 'False' (see the value
    # printed after splitting). Translate them to 'train' / 'eval' with a
    # lookup that passes every other value through unchanged, so running this
    # cell again keeps an already translated 'eval' instead of turning it
    # into 'train'.
    split_name = {'True': 'train', 'False': 'eval'}.get(split_name, split_name)

In [9]:
# Preprocessing steps, listed in the order in which they are handed to
# preprocess().
preprocessors = [
    # keep only the sensors configured for the selected dataset
    Preprocessor(
        fn='pick_channels',
        apply_on_array=False,
        ch_names=sensors,
        ordered=True,
    ),
    # scale the signals by `factor`
    Preprocessor(
        fn=lambda signals: signals * factor,
        apply_on_array=True,
    ),
    # clip the scaled signals to [-max_abs_val, max_abs_val]
    Preprocessor(
        fn=lambda signals: np.clip(signals, -max_abs_val, max_abs_val),
        apply_on_array=True,
    ),
    # resample to the configured sampling frequency
    Preprocessor(
        fn='resample',
        apply_on_array=False,
        sfreq=sfreq,
    ),
]

# TUHAbnormal needs two extra steps
if ds_name in ['TUHAbnormal']:
    preprocessors += [
        # strip prefix and reference from the channel name: EEG O1-REF -> O1
        Preprocessor(
            fn='rename_channels',
            apply_on_array=False,
            mapping=lambda name: name[name.find(' ')+1:name.find('-')],
        ),
        # drop the first 60s of every recording
        Preprocessor(
            fn='crop',
            apply_on_array=False,
            tmin=60,
        ),
    ]

# NOTE(review): the return value is not used, so preprocess() presumably
# modifies split_ds in place; re-running this cell would then apply all steps
# a second time -- confirm.
preprocess(concat_ds=split_ds, preprocessors=preprocessors)

In [10]:
# Windowing parameters shared by all datasets.
windowing_params = dict(
    drop_last_window=False,
    window_size_samples=window_size_samples,
    window_stride_samples=window_stride_samples,
)
if ds_name not in ['Schirrmeister2017', 'BNCI2014001']:
    assert ds_name in ['TUHAbnormal']
    windowing_params.update(
        # map boolean pathological targets to integer
        mapping={False: 0, True: 1},
        start_offset_samples=trial_start_offset_samples,
        stop_offset_samples=None,
        # fixed values that replace the configured window size / stride
        window_size_samples=1000,
        window_stride_samples=1000,
    )
else:
    # trial-based datasets: offsets relative to trial start / stop
    windowing_params.update(
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=trial_stop_offset_samples,
    )

In [11]:
# Compute the feature data frame for the preprocessed split.
df = extract_ds_features(
    ds=split_ds,
    windowing_params=windowing_params,
    frequency_bands=frequency_bands,
    # NOTE(review): presumably the kmax setting of the Time-domain Higuchi
    # fractal dimension feature -- confirm against braindecode_features
    params={'Time__higuchi_fractal_dimension__kmax': 3},
    n_jobs=n_jobs,
)

INFO:braindecode_features.feature_extraction:Computing features of domain: Time.
INFO:numexpr.utils:Note: NumExpr detected 20 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
INFO:braindecode_features.feature_extraction:Computing features of domain: Fourier.
INFO:braindecode_features.feature_extraction:Computing features of domain: Wavelet.
INFO:braindecode_features.feature_extraction:Computing features of domain: Hilbert.
INFO:braindecode_features.feature_extraction:Computing features of domain: Cross-frequency.


In [12]:
# display the extracted feature data frame
df

Domain,Description,Description,Description,Cross-frequency,...,Time,Time,Time,Time
Feature,Trial,Window,Target,cross_frequency_coupling,...,zero_crossings_derivative,zero_crossings_derivative,zero_crossings_derivative,zero_crossings_derivative
Channel,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,A1,...,T1,T2,T3,T4
Frequency,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,"4-8, 8-13",...,13-30,13-30,13-30,13-30
0,0,0,0,4.219507,...,463.0,459.0,476.0,469.0
1,0,1,0,12.065466,...,453.0,455.0,469.0,460.0
2,0,2,0,6.827751,...,463.0,472.0,479.0,475.0
3,0,3,0,2.356303,...,521.0,523.0,523.0,533.0
4,0,4,0,5.835598,...,517.0,507.0,509.0,515.0
...,...,...,...,...,...,...,...,...,...
773,4,131,1,5.634732,...,442.0,455.0,454.0,442.0
774,4,132,1,4.889182,...,454.0,452.0,478.0,462.0
775,4,133,1,4.482789,...,452.0,432.0,484.0,445.0
776,4,134,1,4.330865,...,439.0,464.0,470.0,458.0


In [13]:
# list the dtype of every column of the feature data frame
df.dtypes

Domain           Feature                    Channel  Frequency
Description      Trial                                              int64
                 Window                                             int64
                 Target                                             int64
Cross-frequency  cross_frequency_coupling   A1       4-8, 8-13    float32
                                            A2       4-8, 8-13    float32
                                                                   ...   
Time             zero_crossings_derivative  PZ       13-30        float32
                                            T1       13-30        float32
                                            T2       13-30        float32
                                            T3       13-30        float32
                                            T4       13-30        float32
Length: 3468, dtype: object

In [15]:
# save each trial feature matrix (windows x features) to an individual file
if out_path is not None:
    # Build the target directory in its own variable: `out_path` keeps the
    # configured base directory, so running this cell again does not nest the
    # sub-directories a second time.
    if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
        features_out_path = os.path.join(out_path, str(subject_id), split_name)
    else:
        assert ds_name in ['TUHAbnormal']
        features_out_path = os.path.join(out_path, split_name)
    # create missing directories; an existing directory is not an error
    os.makedirs(features_out_path, exist_ok=True)
    save_features(
        df=df,
        out_path=features_out_path,
    )

In [16]:
# Look at a subset of the feature data frame: the columns selected by the
# query 'Wavelet' (non-exact match, no restriction to a column level).
filter_df(
    df=df,
    query='Wavelet',
    level_to_consider=None,
    exact_match=False,
)

Domain,Description,Description,Description,Wavelet,Wavelet,Wavelet,Wavelet,Wavelet,Wavelet
Feature,Trial,Window,Target,bounded_variation,...,variance,variance,variance,variance
Channel,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,A1,...,T1,T2,T3,T4
Frequency,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,4-8,...,13-30,13-30,13-30,13-30
0,0,0,0,41.007496,...,81.537766,83.283997,93.796097,72.612091
1,0,1,0,43.805557,...,106.145592,99.305328,133.705170,94.965050
2,0,2,0,43.136700,...,74.915779,75.702995,91.128815,67.470360
3,0,3,0,51.004570,...,53.703457,51.714802,89.093330,51.403843
4,0,4,0,50.064571,...,69.412254,62.997101,141.036667,83.123924
...,...,...,...,...,...,...,...,...,...
773,4,131,1,43.303131,...,31.228409,30.446785,30.790102,35.002201
774,4,132,1,49.453873,...,22.973703,22.279510,23.362446,26.495647
775,4,133,1,49.167305,...,18.523458,23.434082,21.858273,29.097565
776,4,134,1,46.854084,...,24.882755,20.160500,25.414705,27.663832
