In [1]:
import warnings
warnings.filterwarnings("ignore", category= UserWarning)
warnings.filterwarnings("ignore", category= FutureWarning)
warnings.filterwarnings("ignore", category= RuntimeWarning)

In [2]:
import mne
mne.set_log_level("CRITICAL")
import numpy as np
import double_dipper
from double_dipper import dataset, constants, io, ml
from double_dipper.constants import problem, strategy_prompt

In [3]:
from functools import reduce
def chain(*functions):
    """Compose ``functions`` into a single one-argument callable.

    The returned callable passes its argument to the first function, feeds
    that result to the second, and so on from left to right, returning the
    last result. With no functions it returns its argument unchanged.
    """
    def chained(args):
        # Left fold: ``value`` is the result produced so far.
        return reduce(lambda value, step: step(value), functions, args)
    return chained

In [4]:
def labeller(meta):
    """Turn the ``"strategy"`` entry of ``meta`` into a binary class label.

    Returns 0 when the strategy starts with "fact", 1 when it starts with
    "procedure" (both compared case-insensitively), and None when the
    strategy is None or matches neither prefix.
    """
    reported = meta["strategy"]
    if reported is None:
        return None
    lowered = reported.lower()
    # The position of the matching prefix is the class label.
    for label, prefix in enumerate(("fact", "procedure")):
        if lowered.startswith(prefix):
            return label
    return None
def divider(meta):
    """Return the ``"epoch"`` entry of the metadata mapping ``meta``."""
    return meta["epoch"]

In [5]:
def gen_dset(subjNo, split=.7):
    """Load one subject's data and split it into train and test arrays.

    File pairs are read from ``cleaned/main/{subjNo}`` and grouped with
    ``io.partition(divider, labeller, ...)``. The ``"x"`` and ``"y"`` arrays of
    every group are concatenated along axis 0, visiting the group keys in
    sorted order. The first ``split`` fraction of rows (truncated to an
    integer count) forms the training set and the rest the test set; rows
    are not shuffled before splitting.

    Returns:
        The tuple ``(trainX, trainY, testX, testY)``.
    """
    file_pairs = io.filePairs(f"cleaned/main/{subjNo}")
    groups = io.partition(divider, labeller, file_pairs)
    ordered_keys = sorted(groups.keys())

    def stack(field):
        # Join one field of every group, in sorted key order.
        return np.concatenate([groups[key][field] for key in ordered_keys], axis=0)

    X = stack("x")
    Y = stack("y")
    cut = int(len(X) * split)
    return (X[:cut], Y[:cut], X[cut:], Y[cut:])

In [6]:
# Subject number whose data this notebook loads.
SUBJ_NO = 10
# Train/test arrays for that subject, using gen_dset's default split fraction.
trX, trY, tsX, tsY = gen_dset(SUBJ_NO)

In [7]:
def problem_window(X):
    """Slice the last axis of ``X`` down to the problem window.

    The window starts at sample ``problem.delay * constants.sfreq`` and stops
    before sample ``strategy_prompt.delay * constants.sfreq``.
    NOTE(review): presumably the delays are in seconds and ``sfreq`` in Hz,
    with an integer product — confirm against ``double_dipper.constants``.
    """
    start, stop = (
        event.delay * constants.sfreq for event in (problem, strategy_prompt)
    )
    return X[..., start:stop]

In [8]:
def bandpass_filt(X, l_freq=1, h_freq=32):
    """Band-pass filter ``X`` with ``mne.filter.filter_data``.

    The sampling frequency is taken from ``constants.sfreq``.

    Args:
        X: Data array handed to ``mne.filter.filter_data`` unchanged.
        l_freq: Lower pass-band edge; defaults to 1, the previously
            hard-coded value.
        h_freq: Upper pass-band edge; defaults to 32, the previously
            hard-coded value.

    Returns:
        Whatever ``mne.filter.filter_data`` returns for these arguments.
    """
    return mne.filter.filter_data(X, constants.sfreq, l_freq=l_freq, h_freq=h_freq)

In [9]:
def flatten_end(X):
    """Collapse every axis after the first into one, giving a 2-D array.

    Args:
        X: Array with at least one axis.

    Returns:
        ``X`` reshaped to ``(X.shape[0], product of the remaining axes)``.
        A 1-D input of length ``n`` becomes shape ``(n, 1)``.
    """
    n_rows = X.shape[0]
    # Force an integer dtype: for a 1-D input the trailing shape is empty and
    # np.prod(()) would otherwise return the float 1.0, which reshape rejects.
    n_cols = int(np.prod(X.shape[1:], dtype=np.intp))
    return X.reshape(n_rows, n_cols)

In [15]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from imblearn.over_sampling import ADASYN, SMOTE
# Candidate feature pipelines handed to ml.grid_search. List order matters:
# NOTE(review): the position here appears to be the "feature_selector=" index
# printed in the grid-search output — confirm in ml.grid_search.
#   0: flatten only
#   1: band-pass filter, then flatten
#   2: band-pass filter, cut to the problem window, then flatten
feature_selectors = [flatten_end,
                     chain(bandpass_filt, flatten_end),
                     chain(bandpass_filt, problem_window, flatten_end)
]
# Classifier classes (not instances).
# NOTE(review): presumably ml.grid_search instantiates them itself — confirm.
models = [LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis]
# Resampling options: None, or a zero-argument factory that builds a new
# SMOTE / ADASYN object each time it is called.
# NOTE(review): None presumably means "no resampling", and no random_state is
# passed, so resampled results may differ between runs — confirm.
resamplers = [None, lambda: SMOTE(n_jobs=4), lambda: ADASYN(n_jobs=4)]

In [16]:
# Run the grid search over the feature selectors, resamplers and models.
# NOTE(review): the two returned values are presumably the best index
# combination and its confusion matrix — confirm in ml.grid_search.
inds, conf = ml.grid_search(
    trX, trY, tsX, tsY,
    feature_selectors, resamplers, models,
)

feature_selector=0,resampler=0,model=0: 	precision=0.333, recall=0.167, f1=0.167
New best achieved

feature_selector=1,resampler=0,model=0: 	precision=0.467, recall=0.583, f1=0.583
New best achieved

feature_selector=2,resampler=0,model=0: 	precision=0.250, recall=0.250, f1=0.250



  prec = conf[1, 1] / np.sum(conf[:, 1])


feature_selector=0,resampler=1,model=0: 	precision=nan, recall=0.000, f1=nan

feature_selector=1,resampler=1,model=0: 	precision=0.375, recall=0.500, f1=0.500



  prec = conf[1, 1] / np.sum(conf[:, 1])


feature_selector=2,resampler=1,model=0: 	precision=nan, recall=0.000, f1=nan

feature_selector=0,resampler=2,model=0: 	precision=0.500, recall=0.250, f1=0.250

feature_selector=1,resampler=2,model=0: 	precision=0.350, recall=0.583, f1=0.583

feature_selector=2,resampler=2,model=0: 	precision=0.333, recall=0.250, f1=0.250



  prec = conf[1, 1] / np.sum(conf[:, 1])


feature_selector=0,resampler=0,model=1: 	precision=nan, recall=0.000, f1=nan

feature_selector=1,resampler=0,model=1: 	precision=0.538, recall=0.583, f1=0.583



  prec = conf[1, 1] / np.sum(conf[:, 1])


feature_selector=2,resampler=0,model=1: 	precision=nan, recall=0.000, f1=nan

feature_selector=0,resampler=1,model=1: 	precision=0.125, recall=0.083, f1=0.083

feature_selector=1,resampler=1,model=1: 	precision=0.333, recall=0.417, f1=0.417

feature_selector=2,resampler=1,model=1: 	precision=0.231, recall=0.250, f1=0.250



  prec = conf[1, 1] / np.sum(conf[:, 1])


feature_selector=0,resampler=2,model=1: 	precision=nan, recall=0.000, f1=nan

feature_selector=1,resampler=2,model=1: 	precision=0.273, recall=0.250, f1=0.250

feature_selector=2,resampler=2,model=1: 	precision=nan, recall=0.000, f1=nan



  prec = conf[1, 1] / np.sum(conf[:, 1])
