In [2]:
import logging
# Give the root logger a default stream handler (no-op if it already has one),
# so that log records show up in the notebook output.
logging.basicConfig()
# Root logger threshold: INFO and above are emitted, DEBUG records are dropped.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Logger used by the code cells of this notebook; it has no level of its own
# and therefore inherits the INFO threshold from the root logger.
log = logging.getLogger(__name__)

In [3]:
import os
import sys

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, GroupKFold
from sklearn.metrics import accuracy_score

sys.path.insert(0, '/home/gemeinl/code/braindecode-features/')
from braindecode_features import read_features, prepare_features, score
from braindecode_features.utils import drop_window, _check_df_consistency

In [4]:
def cross_validate(
    df, clf, subject_id, only_last_fold, agg_func, windows_as_examples,
    out_path=None):
    """
    Run (cross-)validation on features and targets in df using estimator clf.

    The examples are split into 5 folds. For every evaluated fold the
    estimator is (re)fit on the training part, predictions are made on the
    validation part and scored with `accuracy_score` through `score`.

    Parameters
    ----------
    df: `pd.DataFrame`
        A feature DataFrame. Every column has to be of dtype float32 or int64.
    clf: sklearn.estimator
        A scikit-learn estimator. An estimator that exposes a `refit` method
        (e.g. auto-sklearn) is first fit once on all examples and afterwards
        refit on the training examples of every evaluated fold.
    subject_id: int
        Identifier of the subject. Stored in the results and used to build the
        `dataset_name` handed to estimators that expose `refit`.
    only_last_fold: bool
        Whether to only run the last fold of CV. Corresponds to 80/20 split.
    agg_func: callable or None
        Function to aggregate trial features, e.g. mean, median...
    windows_as_examples: bool
        Whether to consider compute windows as independent examples.
    out_path: str or None
        If not None, the result row of every evaluated fold is appended to
        'cv_results.csv' inside `os.path.dirname(os.path.dirname(out_path))`.
        The header is only written when that file does not exist yet.

    Returns
    -------
    `pd.DataFrame` or None
        One row per evaluated fold holding subject, fold, estimator name,
        predictions, targets, a model description, the feature names, the
        feature usage settings and the entries returned by `score`.
        None is returned (after logging an error) if df contains a column
        whose dtype is neither float32 nor int64.
    """
    invalid_cols = [
        (col, ty) for col, ty in df.dtypes.items() if ty not in ['float32', 'int64']]
    if invalid_cols:
        log.error(f'Only integer and float values are allowed to exist in the DataFrame. '
              f'Found {invalid_cols}. Please convert.')
        return None
    _check_df_consistency(df)
    n_splits = 5
    X, y, groups, feature_names = prepare_features(
        df=df,
        agg_func=agg_func,
        windows_as_examples=windows_as_examples,
    )
    if agg_func is not None or not windows_as_examples:
        # preserves order of examples but might split groups
        # therefore don't use when not aggregating and using windows as examples
        cv = KFold(n_splits=n_splits, shuffle=False)
    else:
        # does not preserve order of examples but guarantees not splitting groups
        cv = GroupKFold(n_splits=n_splits)

    # Duck-typed check that stands in for an isinstance test against
    # AutoSklearnClassifier / AutoSklearn2Classifier.
    uses_refit = hasattr(clf, 'refit')
    if uses_refit:
        # optimize hyperparameters on entire training data
        # NOTE(review): "entire" here includes the examples that later serve as
        # validation fold -- confirm that this is intended.
        dataset_name = '_'.join(['subject', str(subject_id)])
        clf = clf.fit(X=X, y=y, dataset_name=dataset_name)

    # The output file does not depend on the fold, so resolve it once.
    # NOTE(review): the file is placed two directory levels above out_path,
    # not in out_path itself -- confirm against the callers.
    out_file = None
    if out_path is not None:
        out_file = os.path.join(os.path.dirname(os.path.dirname(out_path)), 'cv_results.csv')

    # perform validation
    infos = []
    for fold_i, (train_is, valid_is) in enumerate(cv.split(X, y, groups)):
        if only_last_fold and fold_i != cv.n_splits - 1:
            continue
        X_train, y_train, groups_train = X[train_is], y[train_is], groups[train_is]
        X_valid, y_valid, groups_valid = X[valid_is], y[valid_is], groups[valid_is]
        log.debug(f'train shapes {X_train.shape}, {y_train.shape}, {groups_train.shape}')
        log.debug(f'valid shapes {X_valid.shape}, {y_valid.shape}, {groups_valid.shape}')
        if uses_refit:
            # for autosklearn, refit the ensemble found on the entire training set
            # on the training data of the cv split
            clf = clf.refit(X_train, y_train)
        else:
            clf = clf.fit(X_train, y_train)
        pred = clf.predict(X_valid)
        d = {
            'subject': subject_id,
            'fold': fold_i,
            'estimator': clf.__class__.__name__,
            'predictions': pred.tolist(),
            'targets': y_valid.tolist(),
            'model': clf.show_models() if hasattr(clf, 'show_models') else str(clf),
            'feature_names': feature_names.to_dict(),
            'windows_as_examples': windows_as_examples,
            'agg_func': agg_func.__name__ if agg_func is not None else agg_func,
        }
        # compute scores and merge them into the result row
        scores = score(
            score_func=accuracy_score,
            y=y_valid,
            y_pred=pred,
            y_groups=groups_valid,
        )
        d.update(scores)
        info = pd.DataFrame([pd.Series(d)])
        if out_file is not None:
            # append, writing the header only when the file is created
            info.to_csv(out_file, mode='a', header=not os.path.exists(out_file))
        log.info(info.tail(1))
        infos.append(info)
    return pd.concat(infos)

In [5]:
# --- experiment configuration -------------------------------------------------
subject_id = 1
train_or_test = 'train'
path = f'./tmp/{train_or_test}'
n_jobs = 1
seed = 20210408
np.random.seed(seed)
only_last_fold = True
# autosklearn is only used when n_min is an integer, which is then taken as
# its training time in minutes
n_min = None  # None/2/20
drop_a_window = 1

# --- estimators to evaluate ---------------------------------------------------
rf = RandomForestClassifier(n_estimators=750, random_state=seed)
clfs = [rf]
if n_min is not None:
    # imported lazily because the import raises warnings
    from autosklearn.experimental.askl2 import AutoSklearn2Classifier
    from autosklearn.classification import AutoSklearnClassifier
    asc = AutoSklearnClassifier(
        time_left_for_this_task=60*n_min,
        ml_memory_limit=8192,
        n_jobs=n_jobs,
        seed=seed,
        initial_configurations_via_metalearning=0,
    )
    clfs.append(asc)

# --- load the features --------------------------------------------------------
df = read_features(path=path, n_jobs=n_jobs)
# potentially drop first window (for HGD, better decoding when starting at 500ms?)
if drop_a_window is not None:
    df = drop_window(df, drop_a_window)

# --- evaluate every estimator with every feature usage ------------------------
# the two lists are paired element-wise: (aggregation function, windows as examples)
agg_funcs = [None, np.mean, np.median]
windows_as_examples_ = [True, False, False]
# arguments that are identical for every call of cross_validate
shared_kwargs = dict(
    df=df,
    subject_id=subject_id,
    only_last_fold=only_last_fold,
    out_path=None,
)
all_results = []
for clf in clfs:
    # TODO: add filter_df of feature dataframe
    for agg_func, windows_as_examples in zip(agg_funcs, windows_as_examples_):
        results = cross_validate(
            clf=clf,
            agg_func=agg_func,
            windows_as_examples=windows_as_examples,
            **shared_kwargs,
        )
        all_results.append(results)
all_results = pd.concat(all_results)

INFO:numexpr.utils:Note: NumExpr detected 20 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
INFO:__main__:   subject  fold               estimator  \
0        1     4  RandomForestClassifier   

                                         predictions  \
0  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...   

                                             targets  \
0  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...   

                                               model  \
0  RandomForestClassifier(bootstrap=True, ccp_alp...   

                                       feature_names  windows_as_examples  \
0  {'Domain': {0: 'Cross-frequency', 1: 'Cross-fr...                 True   

  agg_func  window_accuracy_score  trial_accuracy_score  
0     None               0.659453                 0.625  
INFO:__main__:   subject  fold               estimator            predictions  \
0        1     4  RandomForestClassifi

In [6]:
# Display the concatenated cross-validation results gathered for every
# estimator and feature usage setting.
all_results

Unnamed: 0,subject,fold,estimator,predictions,targets,model,feature_names,windows_as_examples,agg_func,window_accuracy_score,trial_accuracy_score
0,1,4,RandomForestClassifier,"[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","RandomForestClassifier(bootstrap=True, ccp_alp...","{'Domain': {0: 'Cross-frequency', 1: 'Cross-fr...",True,,0.659453,0.625
0,1,4,RandomForestClassifier,"[1, 1, 1, 1, 1, 0, 0]","[1, 1, 0, 0, 0, 0, 0]","RandomForestClassifier(bootstrap=True, ccp_alp...","{'Domain': {0: 'Cross-frequency', 1: 'Cross-fr...",False,mean,0.571429,0.571429
0,1,4,RandomForestClassifier,"[1, 1, 1, 1, 1, 0, 0]","[1, 1, 0, 0, 0, 0, 0]","RandomForestClassifier(bootstrap=True, ccp_alp...","{'Domain': {0: 'Cross-frequency', 1: 'Cross-fr...",False,median,0.571429,0.571429
