In [1]:
import warnings
warnings.filterwarnings("ignore", category= UserWarning)
warnings.filterwarnings("ignore", category= FutureWarning)
warnings.filterwarnings("ignore", category= RuntimeWarning)

In [2]:
import mne
mne.set_log_level("CRITICAL")
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from imblearn.over_sampling import SMOTE

In [3]:
import double_dipper
from double_dipper import dataset, constants, io, ml
from double_dipper.constants import problem, strategy_prompt
from double_dipper.features import chain, time_window, bandpass_filter, psd, flatten_end

In [4]:
def labeller(meta):
    strat = meta["strategy"]
    if strat is None: return None
    if strat.lower().startswith("fact"):        return 0
    elif strat.lower().startswith("procedure"): return 1
    else:                                       return None
divider = lambda meta: (meta["id"], meta["epoch"])

In [5]:
subjNos = [i for i in range(1, 12) if i != 5]
pairs = io.filePairs(*[f"cleaned/main/{subjNo}" for subjNo in subjNos])
dset = io.partition(divider, labeller, pairs)

In [6]:
def subj_data(subjNo):
    subj_set = dataset.subset(lambda meta: meta[0] == subjNo, dset)
    keys = sorted(subj_set.keys(), key=lambda k: k[1])
    X = np.concatenate([subj_set[k]["x"] for k in keys], axis=0)
    Y = np.concatenate([subj_set[k]["y"] for k in keys], axis=0)
    return (X, Y)

In [7]:
feature_selector = chain(
    time_window(0, strategy_prompt.delay),
    bandpass_filter(1, 45),
    psd(1, 45, add=True),
    flatten_end
)

In [8]:
mySMOTE = lambda: SMOTE(random_state=0, n_jobs=4)

In [9]:
splits = [.2, .4,.6, .8]
zeros = lambda: np.zeros([len(subjNos), len(splits)])
prec = zeros()
rec = zeros()
f1  = zeros()
for (i, subjNo) in enumerate(subjNos):
    print(f"SUBJECT {subjNo}")
    (X, Y) = subj_data(subjNo)
    X = feature_selector(X)
    (prec[i, :], rec[i,:]) = ml.temporal_cross_validation(X, Y,
                                                resampler=mySMOTE,
                                                model=LinearDiscriminantAnalysis,
                                                splits=splits)
    print(f"precision={prec[i, :]}")
    print(f"recall={rec[i, :]}")

SUBJECT 1
precision=[1.  0.5 0.  0. ]
recall=[0.05555556 0.14285714 0.         0.        ]
SUBJECT 2
precision=[0.         0.4        0.5        0.33333333]
recall=[0.         0.125      0.08333333 0.14285714]
SUBJECT 3
precision=[0. 0. 0. 0.]
recall=[0. 0. 0. 0.]
SUBJECT 4
precision=[0.31147541 0.5        0.4        0.5       ]
recall=[0.61290323 0.29166667 0.22222222 0.11111111]
SUBJECT 6
precision=[0.30337079 0.8        0.         0.5       ]
recall=[0.96428571 0.19047619 0.         0.1       ]
SUBJECT 7
precision=[0. 0. 0. 0.]
recall=[0. 0. 0. 0.]
SUBJECT 8
precision=[0. 0. 0. 0.]
recall=[0. 0. 0. 0.]
SUBJECT 9
precision=[0.37931034 0.5        0.6        0.5       ]
recall=[0.26190476 0.29032258 0.27272727 0.36363636]
SUBJECT 10
precision=[0.38461538 0.7        0.6        0.5       ]
recall=[0.32258065 0.29166667 0.35294118 0.5       ]
SUBJECT 11
precision=[0.  0.5 0.  0. ]
recall=[0.         0.07692308 0.         0.        ]
