In [None]:
import os
import numpy as np

In [None]:
from paths import rootdir as datadir
from prepare import param_generator

In [None]:
import graphs

In [None]:
# param_generator() yields the candidate values of each parameter (used as
# slider options further down) and the list of experiments, each one a
# (idx, n_trial, cue_freq, rew_freq) tuple.
n_trials, cue_freqs, rew_freqs, params = param_generator()

# Experiment index keyed by its (n_trial, cue_freq, rew_freq) combination;
# load_records uses this index to build the experiment's data file name.
exps = dict(((nt, cf, rf), exp_idx) for exp_idx, nt, cf, rf in params)

In [None]:
def load_records(n_trial, cue_freq, rew_freq):
    """Return the stored records of one parameter combination.

    The combination is looked up in ``exps`` to get its experiment index,
    which selects the file ``<datadir>/data/data_sgstim.<index>.npy``.
    Singleton dimensions of the loaded array are squeezed out.

    Raises KeyError if the combination is not a key of ``exps``.
    """
    exp_idx = exps[n_trial, cue_freq, rew_freq]
    npy_path = os.path.join(datadir, 'data', 'data_sgstim.{}'.format(exp_idx) + '.npy')
    loaded = np.load(npy_path)
    return loaded.squeeze()
    
def graph_experiment(n_trial, cue_freq, rew_freq, fig=None, lines=None):
    """Plot the mean performance (+/- one std) of one experiment.

    The records are loaded with ``load_records``. The mean and standard
    deviation of ``records["best"]`` are taken over axis 0 (presumably the
    independent simulation runs -- TODO confirm). Only the entries after the
    first ``n_trial`` ones are plotted: mean as a solid line, mean +/- std as
    dashed lines.

    Parameters
    ----------
    n_trial, cue_freq, rew_freq
        Parameter combination identifying the experiment (a key of ``exps``).
        ``n_trial`` is also the number of leading entries left out of the plot.
    fig : Bokeh figure, optional
        Figure to draw on. If None when new lines are drawn, a new figure
        is created. When existing ``lines`` are updated, its title is updated
        to describe the new experiment.
    lines : tuple of three line renderers, optional
        ``(line_down, line_mean, line_up)`` as returned by a previous call.
        When given, their data is replaced and the notebook is refreshed
        with ``push_notebook`` instead of drawing new lines.

    Returns
    -------
    ``(fig, (line_down, line_mean, line_up))`` when new lines are drawn,
    None when existing ``lines`` are updated.
    """
    records = load_records(n_trial, cue_freq, rew_freq)
    best = records["best"]
    perf_mean = np.mean(best, axis=0)[n_trial:]
    perf_std = np.std(best, axis=0)[n_trial:]
    # The x axis follows the actual number of plotted points rather than a
    # hard-coded 100, so experiments of any length plot (and update) correctly.
    x = np.arange(len(perf_mean))
    title = ("Performance [n_trial={}, A_freq={}, "
             "A_rew={}]").format(n_trial, cue_freq, rew_freq)

    if lines is None:
        if fig is None:
            fig = graphs.figure(y_range=[0.0, 1.2], title=title,
                                plot_width=900, plot_height=400, tools="")
        line_down = fig.line(x, perf_mean - perf_std, line_dash='dashed')
        line_mean = fig.line(x, perf_mean)
        line_up = fig.line(x, perf_mean + perf_std, line_dash='dashed')
        return fig, (line_down, line_mean, line_up)

    line_down, line_mean, line_up = lines
    # Replace the whole data dict (x and y together) so both columns always
    # have the same length, even when n_trial changes the number of points.
    line_down.data_source.data = dict(x=x, y=perf_mean - perf_std)
    line_mean.data_source.data = dict(x=x, y=perf_mean)
    line_up.data_source.data = dict(x=x, y=perf_mean + perf_std)
    if fig is not None:
        # Keep the title in sync with the experiment now displayed. Bokeh >= 0.12
        # stores it as a Title model; older versions use a plain string.
        if hasattr(fig.title, 'text'):
            fig.title.text = title
        else:
            fig.title = title
    graphs.io.push_notebook()

In [None]:
# Draw the first experiment of params; the figure and its line renderers are
# kept so the interactive widget below can update them in place.
(idx, n_trial, cue_freq, rew_freq) = params[0]
fig, lines = graph_experiment(n_trial=n_trial, cue_freq=cue_freq,
                              rew_freq=rew_freq)
graphs.show(fig)

In [None]:
# Each slider starts at the experiment plotted above (n_trial, cue_freq and
# rew_freq from params[0]) rather than at its first option. Moving a slider
# re-runs graph_experiment, which updates the existing figure's lines in place.
graphs.interact(
    graph_experiment,
    n_trial=graphs.SelectionSlider(description='n_trial', options=list(n_trials),
                                   value=n_trial),
    cue_freq=graphs.SelectionSlider(description='A_freq', options=list(cue_freqs),
                                    value=cue_freq),
    rew_freq=graphs.SelectionSlider(description='A_rew', options=list(rew_freqs),
                                    value=rew_freq),
    fig=graphs.fixed(fig),
    lines=graphs.fixed(lines))