In [1]:
# Standard library.
import os
import sys

# Third-party libraries; each configuration call directly follows the import
# of the library it configures.
import matplotlib.pyplot as plt
# NOTE(review): newer matplotlib releases reportedly expose this style as
# 'seaborn-v0_8' instead of 'seaborn' -- confirm against the installed version.
plt.style.use('seaborn')
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 8)  # cap the number of columns shown when a frame is displayed
import mne
mne.set_log_level('ERROR')  # only surface MNE log messages at ERROR level

from braindecode.datasets import MOABBDataset
from braindecode.datasets.tuh import TUHAbnormal
from braindecode.datautil.preprocess import filterbank, preprocess, Preprocessor
from braindecode.datautil.windowers import create_windows_from_events, create_fixed_length_windows

# Put the local braindecode-features checkout first on the module search path.
sys.path.insert(0, '/home/gemeinl/code/braindecode-features/')
# NOTE(review): wildcard import -- extract_windows_ds_features,
# save_features_by_trial and filter_df used further down are not imported
# anywhere else in this notebook, so they presumably come from here.
from braindecode_features import *

In [2]:
# --- run configuration -------------------------------------------------------
ds_name = 'TUHAbnormal'  # alternative: 'Schirrmeister2017'
subject_id = 1
out_path = None  # e.g. '/home/lukas/Code/HGD/'
n_jobs = 2
agg_func = None  # e.g. 'mean'

# --- settings to optimize ----------------------------------------------------
sfreq = 100  # alternative: 250
# original trials have 4s duration
frequency_bands = [(8, 13)]
trial_start_offset_samples = 0  # alternative: int(0.5*sfreq)
window_size_samples = 500
window_stride_samples = 500

# --- clf + hyperparams -------------------------------------------------------
factor = 1e6
max_abs_val = 800
trial_stop_offset_samples = None

# Channel selection per dataset, written as whitespace-separated names.
_sensors_by_dataset = {
    'Schirrmeister2017': tuple((
        'C1 C2 C3 C4 C5 C6 '
        'CCP1h CCP2h CCP3h CCP4h CCP5h CCP6h '
        'CP1 CP2 CP3 CP4 CP5 CP6 '
        'CPP1h CPP2h CPP3h CPP4h CPP5h CPP6h '
        'CPz Cz '
        'FC1 FC2 FC3 FC4 FC5 FC6 '
        'FCC1h FCC2h FCC3h FCC4h FCC5h FCC6h '
        'FCz '
        'FFC1h FFC2h FFC3h FFC4h FFC5h FFC6h'
    ).split()),
    'BNCI2014001': tuple((
        'C1 C2 C3 C4 C5 C6 '
        'CP1 CP2 CP3 CP4 CPz Cz '
        'FC1 FC2 FC3 FC4 FCz Fz '
        'P1 P2 POz Pz'
    ).split()),
    # TUH channel names follow the pattern 'EEG <electrode>-REF'.
    'TUHAbnormal': tuple(
        'EEG ' + electrode + '-REF'
        for electrode in (
            'A1 A2 C3 C4 CZ F3 F4 F7 F8 FP1 FP2 FZ '
            'O1 O2 P3 P4 PZ T1 T2 T3 T4'
        ).split()
    ),
}
# Only the channel tuple of the selected dataset is used from here on.
sensors = _sensors_by_dataset[ds_name]

# Description column each dataset is split on.
train_eval_split = dict(
    Schirrmeister2017='run',
    BNCI2014001='session',
    TUHAbnormal='train',
)
# Name of the split that is held out for evaluation.
eval_name = dict(
    Schirrmeister2017='test',
    BNCI2014001='session_E',
    TUHAbnormal='False',
)

In [3]:
# Load the raw recordings for the configured dataset.
# The MOABB dataset names must match the keys used in the configuration cell;
# the previous spelling 'BNCI14001' never matched 'BNCI2014001', which sent
# that dataset into the TUH branch and tripped the assertion there.
if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
    ds = MOABBDataset(
        dataset_name=ds_name,
        subject_ids=[subject_id],
    )
else:
    assert ds_name in ['TUHAbnormal']
    ds = TUHAbnormal(
        path='/data/schirrmr/gemeinl/tuh-abnormal-eeg/raw/v2.0.0/edf/',
        recording_ids=range(20),
        target_name='pathological',
        preload=False,
        add_physician_reports=True,
    )

In [4]:
# Dataset-specific split column and the name of the evaluation split.
split_indicator, eval_ds_name = train_eval_split[ds_name], eval_name[ds_name]

In [5]:
# Split the dataset and keep the first split that is not the evaluation split.
splits = ds.split(split_indicator)
for split_name, split_ds in splits.items():
    # Stop at the first split whose name differs from the evaluation split.
    # If every split matches, the loop variables keep the last split seen.
    if split_name != eval_ds_name:
        break
print(split_name)

True


In [6]:
# Translate the stringified boolean split keys of TUHAbnormal ('True'/'False'
# values of the 'train' column) into readable split names.
# Only the raw keys are translated, so re-running this cell leaves an already
# translated 'train' / 'eval' untouched. Previously a second run turned
# 'eval' into 'train'.
if ds_name in ['TUHAbnormal'] and split_name in ('True', 'False'):
    split_name = 'eval' if split_name == 'False' else 'train'

In [7]:
def _tuh_short_channel_name(ch):
    """Return the electrode part of a TUH channel name.

    'EEG O1-REF' -> 'O1'. Text up to and including the first space is dropped,
    as is everything from the first '-' onwards. Names without a space or
    without a '-' are left intact on that side.
    """
    return ch.split(' ', 1)[-1].split('-', 1)[0]


# some preprocessing steps, applied in list order
preprocessors = [
    # keep only the configured sensors, in the configured order
    Preprocessor(
        apply_on_array=False,
        fn='pick_channels',
        ch_names=sensors,
        ordered=True,
    ),
    # rescale the signals by the configured factor
    Preprocessor(
        apply_on_array=True,
        fn=lambda x: x * factor,
    ),
    # clip the rescaled signals to [-max_abs_val, max_abs_val]
    Preprocessor(
        apply_on_array=True,
        fn=lambda x: np.clip(x, -max_abs_val, max_abs_val),
    ),
    # resample to the configured sampling frequency
    Preprocessor(
        apply_on_array=False,
        fn='resample',
        sfreq=sfreq,
    ),
]
if ds_name in ['TUHAbnormal']:
    preprocessors.extend([
        # EEG O1-REF -> O1
        # (the former slice started at the space itself and produced ' O1'
        # with a leading blank)
        Preprocessor(
            apply_on_array=False,
            fn='rename_channels',
            mapping=_tuh_short_channel_name,
        ),
        # discard first 60s
        Preprocessor(
            apply_on_array=False,
            fn='crop',
            tmin=60,
        ),
    ])
preprocessors.append(
    # used by connectivity, time, and cross-frequency domain features
    # not used by dft, cwt domain features
    Preprocessor(
        apply_on_array=False,
        fn=filterbank,
        # bands are passed sorted by their lower edge
        frequency_bands=sorted(frequency_bands, key=lambda b: b[0]),
        drop_original_signals=False,
    )
)

In [8]:
# apply some preprocessing
preprocess(
    concat_ds=split_ds,
    preprocessors=preprocessors,
)
# extract compute windows
# The two windowers name their offset arguments differently, so the offsets
# are passed through branch-specific keyword arguments. Previously the
# fixed-length names were passed to both windowers, which made the
# events-based call fail with an unexpected-keyword TypeError.
if ds_name in ['Schirrmeister2017', 'BNCI2014001']:
    create_windows = create_windows_from_events
    kwargs = {
        'trial_start_offset_samples': trial_start_offset_samples,
        # NOTE(review): events-based windowing presumably needs an integer
        # stop offset; None is mapped to 0 (no offset) -- confirm against
        # the installed braindecode version.
        'trial_stop_offset_samples': (
            0 if trial_stop_offset_samples is None
            else trial_stop_offset_samples
        ),
    }
else:
    assert ds_name in ['TUHAbnormal']
    create_windows = create_fixed_length_windows
    kwargs = {
        'start_offset_samples': trial_start_offset_samples,
        'stop_offset_samples': trial_stop_offset_samples,
        # boolean target values -> integer class labels
        'mapping': {False: 0, True: 1},
    }
windows_ds = create_windows(
    concat_ds=split_ds,
    drop_last_window=False,
    window_size_samples=window_size_samples,
    window_stride_samples=window_stride_samples,
    **kwargs,
)

In [9]:
# Peek at the first window: signal array, target and window indices.
x, y, ind = next(iter(windows_ds))
x.shape, y, ind

((42, 500), 0, [0, 6000, 6500])

In [10]:
# Compute the feature data frame for all windows, using the configured
# frequency bands and n_jobs.
# NOTE(review): n_jobs presumably controls the number of parallel workers --
# confirm in braindecode_features.
df = extract_windows_ds_features(
    windows_ds=windows_ds,
    frequency_bands=frequency_bands,
    n_jobs=n_jobs,
)

Connectivity
CWT
DFT
Time


In [11]:
# save each trial feature matrix (windows x features) to an individual file
# TODO: subject_id is unsuitable for TUHAbnormal
# Saving is skipped while out_path is None.
if out_path is not None:
    save_features_by_trial(
        df=df, 
        out_path=out_path, 
        subject_id=subject_id, 
        split_name=split_name,
    )

In [12]:
# Display the feature data frame; the number of displayed columns is capped by
# the 'display.max_columns' option set in the first cell.
df

Domain,Dataset,Window,Target,Time,...,Connectivity,Connectivity,Connectivity,Connectivity
Feature,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,covariance,...,phase_locking_value,phase_locking_value,phase_locking_value,phase_locking_value
Channel,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,A1- A2,...,T1- T4,T2- T3,T2- T4,T3- T4
Frequency,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,8-13,...,8-13,8-13,8-13,8-13
0,0,0,0,28.291420,...,0.592420,0.532948,0.846600,0.441852
1,0,1,0,13.167811,...,0.656590,0.465723,0.865962,0.508563
2,0,2,0,91.953445,...,0.725725,0.666351,0.887772,0.593228
3,0,3,0,161.218643,...,0.857858,0.756994,0.921302,0.805921
4,0,4,0,92.587456,...,0.744106,0.740304,0.888831,0.726527
...,...,...,...,...,...,...,...,...,...
1550,4,267,1,14.145298,...,0.496300,0.361156,0.881787,0.326279
1551,4,268,1,24.762482,...,0.608351,0.448643,0.841559,0.514064
1552,4,269,1,3.577626,...,0.218010,0.212520,0.716193,0.182217
1553,4,270,1,15.999822,...,0.575744,0.443076,0.823722,0.451964


In [13]:
# inspect subsets of the feature data frame
# NOTE(review): with exact_match=False and level_to_consider=None this
# presumably keeps every column whose label contains 'DFT' on any column
# level -- confirm against filter_df in braindecode_features.
filter_df(
    df=df, 
    query='DFT', 
    exact_match=False, 
    level_to_consider=None,
)

Domain,Dataset,Window,Target,DFT,DFT,DFT,DFT,DFT,DFT
Feature,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,maximum,...,variance,variance,variance,variance
Channel,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,A1,...,T1,T2,T3,T4
Frequency,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,8-13,...,8-13,8-13,8-13,8-13
0,0,0,0,759.156738,...,29115.685547,18230.708984,22081.833984,11348.553711
1,0,1,0,465.723907,...,7807.271973,6927.698730,7436.895020,10083.588867
2,0,2,0,1703.783936,...,92804.265625,97898.093750,107554.726562,98841.726562
3,0,3,0,2034.640869,...,149876.921875,153210.343750,142015.765625,133009.984375
4,0,4,0,1330.977661,...,46093.613281,50036.425781,72451.765625,56893.003906
...,...,...,...,...,...,...,...,...,...
1550,4,267,1,776.797729,...,18496.423828,14862.626953,18806.304688,17486.054688
1551,4,268,1,851.201660,...,33119.800781,20150.591797,20888.955078,11225.412109
1552,4,269,1,341.137695,...,5283.733398,3395.334961,4863.229004,3671.004639
1553,4,270,1,878.735168,...,15437.017578,8641.436523,19099.152344,9620.468750
