In [None]:
import numpy as np
from datasets import get_dataset
from model_fitting import MultiCCA
from decodingCurveSupervised import decodingCurveSupervised
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
from analyse_datasets import analyse_dataset, analyse_datasets, debug_test_dataset, debug_test_single_dataset
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Fetch the 'plos_one' collection: get_dataset hands back a loader, the list of data
# files and (presumably) the data root directory, then report what was found.
dataset_loader, dataset_files, dataroot = get_dataset('plos_one')
n_datasets = len(dataset_files)
print("Got {} datasets\n{}".format(n_datasets, dataset_files))

In [None]:
# Run all 'plos_one' datasets with cross-validation to estimate system performance, with:
#   dataset:string - dataset to load, using datasets.get_dataset.  See datasets.py for the full list of datasets
#   loader_args:dict - ofs:float - output sample rate in Hz (80Hz in this run)
#                      stopband:list - list of stop-bands applied during loading before slicing.
#                            (0,3) = high-pass at 3Hz,   (25,-1) = stop between 25 and inf = low-pass at 25Hz
#                            together these mean band-pass 3-25Hz
#   preprocess_args:dict - set of parameters to pass to the preprocessor (if any)
#   model:string - type of model to fit, including: 'cca','fwd','bwd','ridge','lr','svr','svc','sklearn' see model_fitting for full list.
#                  Note: different models may require different parameters in clsfr_args
#   clsfr_args:dict - tau_ms:float - stimulus response length in milliseconds
#                     evtlabs:list - list of brain-events to use (see stim2event.py for full possibilities)  're'=rising-edge, 'fe'=falling-edge
#                     rank:int - rank of the decomposition to fit
#                     reg:float - presumably a regularisation strength for the model -- confirm in model_fitting
analyse_datasets('plos_one',loader_args=dict(ofs=80,stopband=((0,3),(25,-1))),
                 model='cca',clsfr_args=dict(tau_ms=450,evtlabs=('re','fe'),rank=1,reg=0.02))
# When it's done it will make a summary plot of the decoding curves over all the datasets.

# Results log: config -> score (@ value, presumably a p-value).
# NOTE: the best logged config is tau=500, rank=3; this cell currently runs tau=450, rank=1.
# bp=3-25, tau=500, rank=3, reg=.02, ofs=80 -> .75 (@.09)
# bp=3-25, tau=450, rank=1, reg=.02, ofs=80 -> .71 (@.12 : 28,9,27,24)
# bp=3-25, tau=450, rank=1, reg=.02, ofs=80, evtlabs=('re','fe','anyfe') -> .67 (@.14)

In [None]:
# Same 'plos_one' data (60Hz output rate, 3-30Hz band-pass), fitted with model='lr' instead of 'cca'.
analyse_datasets(
    'plos_one',
    model='lr',
    loader_args={'ofs': 60, 'stopband': ((0, 3), (30, -1))},
    clsfr_args={'tau_ms': 350, 'evtlabs': ('re', 'fe')},
)

In [None]:
# Try different modelling parameters, e.g. 2-bit brain responses:
#   '00'=low, '11'=high, '01'=rising-edge, '10'=falling-edge
analyse_datasets(
    'plos_one',
    model='cca',
    loader_args={'ofs': 60, 'stopband': ((0, 3), (30, -1))},
    clsfr_args={'tau_ms': 350, 'evtlabs': ('00', '01', '10', '11'), 'rank': 5},
)

In [None]:
# WOW! That killed performance!  Suspect that too many event types introduced numerical issues...
# Tweak the condition number in the CCA matrix inverses with the rcond parameter to address the
# numerical degeneracies...  An alternative is to use reg.
# NOTE: rank is also reduced here (5 -> 1) compared with the previous 2-bit run, so the change in
#       score cannot be attributed to rcond alone.
analyse_datasets('plos_one',loader_args=dict(ofs=60,stopband=((0,3),(30,-1))),
                 model='cca',clsfr_args=dict(tau_ms=350,evtlabs=('00','01','10','11'),rank=1,rcond=(1e-6,1e-4)))
# Not as good as re-fe alone, but OK.

In [None]:
# Fetch the 'lowlands' collection and report how many data files it contains.
dataset_loader, dataset_files, dataroot = get_dataset('lowlands')
n_datasets = len(dataset_files)
print("Got {} datasets".format(n_datasets))


In [None]:
# Run the lowlands analysis -- this may take a while!
#   90Hz output rate, 3-25Hz band-pass, 450ms response, rank-1 CCA with reg=.02
analyse_datasets(
    'lowlands',
    model='cca',
    loader_args={'ofs': 90, 'stopband': ((0, 3), (25, -1))},
    clsfr_args={'tau_ms': 450, 'evtlabs': ('re', 'fe'), 'rank': 1, 'reg': .02},
)
# Results log: config -> score (@ value)
# bp=3-25, rank=1, reg=.02, ofs=60, tau=500 -> .56
# bp=3-25, rank=1, reg=.02, ofs=60, tau=450 -> .56 (@29)
# bp=3-25, rank=1, reg=.02, ofs=80, tau=450 -> .53 (@30)
# bp=3-25, rank=1, reg=.02, ofs=90, tau=450 -> .56 (@28 : 44,23,42,39)

In [None]:
# Run the lowlands analysis -- this may take a while!
# NOTE(review): this cell repeats the previous lowlands run with exactly the same arguments
#               (ofs=90, 3-25Hz band-pass, tau=450, rank=1, reg=.02); either delete it or
#               change a parameter so the extra (slow) run tells you something new.
analyse_datasets(
    'lowlands',
    model='cca',
    loader_args={'ofs': 90, 'stopband': ((0, 3), (25, -1))},
    clsfr_args={'tau_ms': 450, 'evtlabs': ('re', 'fe'), 'rank': 1, 'reg': .02},
)
# bp=3-25, rank=1, reg=.02, ofs=90, tau=450 -> .56 (@.28)

In [None]:
# Fetch the 'p300_prn' collection and report how many data files it contains.
dataset_loader, dataset_files, dataroot = get_dataset('p300_prn')
n_datasets = len(dataset_files)
print("Got {} datasets".format(n_datasets))

In [None]:
# P300 'rc_5_flash' condition: 32Hz output rate, 1-12Hz band-pass, 750ms response,
# rank-3 CCA on 're' and 'anyre' events.
analyse_datasets('p300_prn',dataset_args=dict(label='rc_5_flash'),
                 loader_args=dict(ofs=32,stopband=((0,1),(12,-1)),subtriallen=None),
                 model='cca',clsfr_args=dict(tau_ms=750,evtlabs=('re','anyre'),rank=3,reg=.02))
# Results log.
# NOTE(review): all three entries below record rank=1, but the call above uses rank=3, so these
#               scores do not describe the current configuration.  The identical configs also
#               give different scores (33/28/29), presumably run-to-run variation -- confirm.
# bp=1-12, ofs=32, tau=750, evtlabs=('re','anyre'), rank=1 -> 33 (@10)
# bp=1-12, ofs=32, tau=750, evtlabs=('re','anyre'), rank=1 -> 28 (@08)
# bp=1-12, ofs=32, tau=750, evtlabs=('re','anyre'), rank=1 -> 29 (@09)

In [None]:
# Debug run (via debug_test_single_dataset) of the 'rc_5_flash' CCA configuration.
# Unlike analyse_datasets, this helper is given the classifier settings as flat keyword
# arguments rather than a clsfr_args dict.
debug_test_single_dataset(
    'p300_prn',
    dataset_args={'label': 'rc_5_flash'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    model='cca',
    tau_ms=750,
    evtlabs=('re', 'anyre'),
    rank=3,
    reg=.02,
)

In [None]:
# P300 'nan_rc_5_flip' condition: 32Hz output rate, 1-8Hz band-pass, rank-6 CCA.
# model='cca' is passed explicitly, like the other cells that set rank/reg, instead of
# relying on analyse_datasets' default model (assumed to be 'cca' -- confirm).
# NOTE(review): the label carries a 'nan_' prefix, while the lr cells use 'rc_5_flip';
#               confirm this is the intended label.
analyse_datasets('p300_prn',dataset_args=dict(label='nan_rc_5_flip'),
                 loader_args=dict(ofs=32,stopband=((0,1),(8,-1)),subtriallen=None),
                 model='cca',clsfr_args=dict(tau_ms=750,evtlabs=('re','anyre'),rank=6,reg=.02))
# Results log.
# NOTE(review): none of the entries below used rank=6 (the current setting).
# bp=1-8, tau=750, rank=3, reg=.02 -> 73 (@.07)
# bp=1-12, tau=750, rank=3, reg=.02 -> 72 (@.07)
# bp=3-25, tau=750, rank=3, reg=.02 -> 66 (@.16)
# bp=1-25, tau=750, rank=3, reg=.02 -> 69 (@.08)
# bp=1-8, tau=750, rank=1, reg=.02 -> 68 (@.12)

In [None]:
# P300 'prn_5_flip' condition: 32Hz, 1-12Hz band-pass, rank-3 CCA on 're' and 'anyre' events.
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'prn_5_flip'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    model='cca',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'), 'rank': 3, 'reg': .02},
)
# bp=1-12, tau=750, rank=3, reg=.02 -> XX (@.07)   (score not recorded; p-vals are too pessimistic)

In [None]:
# P300 'rc_10_flip' condition: 32Hz, 1-12Hz band-pass, rank-3 CCA on 're' and 'anyre' events.
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'rc_10_flip'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    model='cca',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'), 'rank': 3, 'reg': .02},
)
# bp=1-12, ofs=32, tau=750, evtlabs=('re','anyre'), rank=3 = 81 @06 (19,114,17,140)

In [None]:
# P300 'prn_10_flip' condition: 32Hz, 1-12Hz band-pass, rank-3 CCA on 're' and 'anyre' events.
# (No result recorded for this configuration yet.)
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'prn_10_flip'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    model='cca',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'), 'rank': 3, 'reg': .02},
)

In [None]:
# P300 'prn_15_flip' condition: 32Hz, 1-12Hz band-pass, rank-3 CCA on 're' and 'anyre' events.
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'prn_15_flip'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    model='cca',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'), 'rank': 3, 'reg': .02},
)
# bp=1-12, ofs=32, tau=750, evtlabs=('re','anyre'), rank=3 -> 76 @04 (23,46,20,37)

In [None]:
# model='lr' on the 'prn_5_flip' condition, with the preprocessor's bad-channel / bad-trial
# thresholds set to None and whitening switched off.
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'prn_5_flip'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    preprocess_args={'badChannelThresh': None, 'badTrialThresh': None, 'whiten': False},
    model='lr',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'),
                'ignore_unlabelled': True, 'center': True},
)

In [None]:
# model='lr' on the 'rc_5_flash' condition, with the preprocessor's bad-channel / bad-trial
# thresholds set to None and whitening switched off.
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'rc_5_flash'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    preprocess_args={'badChannelThresh': None, 'badTrialThresh': None, 'whiten': False},
    model='lr',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'),
                'ignore_unlabelled': True, 'center': True},
)

In [None]:
# model='lr' on the 'rc_5_flip' condition, with the preprocessor's bad-channel / bad-trial
# thresholds set to None and whitening switched off.
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'rc_5_flip'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    preprocess_args={'badChannelThresh': None, 'badTrialThresh': None, 'whiten': False},
    model='lr',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'),
                'ignore_unlabelled': True, 'center': True},
)

In [None]:
# model='lr' on the 'rc_10_flip' condition, with the preprocessor's bad-channel / bad-trial
# thresholds set to None and whitening switched off.
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'rc_10_flip'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    preprocess_args={'badChannelThresh': None, 'badTrialThresh': None, 'whiten': False},
    model='lr',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'),
                'ignore_unlabelled': True, 'center': True},
)

In [None]:
# model='lr' on the 'prn_10_flip' condition, with the preprocessor's bad-channel / bad-trial
# thresholds set to None and whitening switched off.
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'prn_10_flip'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    preprocess_args={'badChannelThresh': None, 'badTrialThresh': None, 'whiten': False},
    model='lr',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'),
                'ignore_unlabelled': True, 'center': True},
)

In [None]:
# model='lr' on the 'prn_15_flip' condition, with the preprocessor's bad-channel / bad-trial
# thresholds set to None and whitening switched off.
analyse_datasets(
    'p300_prn',
    dataset_args={'label': 'prn_15_flip'},
    loader_args={'ofs': 32, 'stopband': ((0, 1), (12, -1)), 'subtriallen': None},
    preprocess_args={'badChannelThresh': None, 'badTrialThresh': None, 'whiten': False},
    model='lr',
    clsfr_args={'tau_ms': 750, 'evtlabs': ('re', 'anyre'),
                'ignore_unlabelled': True, 'center': True},
)

In [None]:
    analyse_datasets("openBMI_ERP",clsfr_args=dict(tau_ms=700,evtlabs=('re','ntre'),rank=5),
                     loader_args=dict(ofs=30,stopband=((0,1),(12,-1))))
# bp=1-12, evtlabs=('re','ntre'), tau=700, rank=5 -> 