In [None]:
import os
import numpy as np

In [None]:
from paths import rootdir as datadir
from prepare import param_generator

In [None]:
import graphs

In [None]:
(n_trials, cue_freqs, rew_freqs, params) = param_generator()
# Lookup table: (n_trial, cue_freq, rew_freq) -> experiment index used in the data file name.
# A generator expression is used so its loop variables stay out of the notebook namespace.
exps = dict(
    ((n_trial, cue_freq, rew_freq), idx)
    for idx, n_trial, cue_freq, rew_freq in params
)

In [None]:
def load_records(n_trial, cue_freq, rew_freq):
    """Load the saved records for one parameter combination.

    The experiment index is looked up in the module-level ``exps`` table, and
    ``<datadir>/data/data_sgstim.<index>.npy`` is read. Singleton dimensions are
    removed with ``np.squeeze``.

    Raises
    ------
    KeyError
        If ``(n_trial, cue_freq, rew_freq)`` is not a key of ``exps``.
    """
    experiment_id = exps[(n_trial, cue_freq, rew_freq)]
    npy_name = 'data_sgstim.{}'.format(experiment_id) + '.npy'
    raw = np.load(os.path.join(datadir, 'data', npy_name))
    return np.squeeze(raw)

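### Best Choice Graph

Mean of the recorded `best` values (averaged over the first axis), shown from index `n_trial` onward.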

In [None]:
def graph_best(n_trial, cue_freq, rew_freq, fig=None, lines=None):
    """Display graph of best choice.

    Plots the mean of ``records["best"]`` over its first axis, keeping only the
    values from index ``n_trial`` onward.

    Parameters
    ----------
    n_trial, cue_freq, rew_freq :
        Experiment parameters, forwarded to ``load_records``.
    fig : optional
        Existing figure to reuse. A new one is created when None.
    lines : dict, optional
        ``{'mean': glyph}`` returned by a previous call. When given, the glyph's
        data source is updated in place and pushed to the notebook instead of
        drawing a new line.

    Returns
    -------
    (fig, lines)
    """
    records = load_records(n_trial, cue_freq, rew_freq)

    # NOTE(review): assumes records["best"] is (runs, trials); axis 0 is averaged.
    P_mean = np.mean(records["best"], axis=0)
    y = P_mean[n_trial:]
    # Derive x from the data; the previous hard-coded range(0, 100) only matched
    # when exactly 100 values remained after n_trial.
    x = list(range(len(y)))

    if fig is None:
        fig = graphs.figure(y_range=[0.0, 1.2], plot_width=900, plot_height=400, tools="")

    if lines is None:
        lines = {}
        lines['mean'] = fig.line(x, y)
    else:
        # Replace both columns at once: y's length changes with n_trial, so
        # updating 'y' alone would leave 'x' with a different length.
        lines['mean'].data_source.data = dict(x=x, y=y)
        graphs.io.push_notebook()
    fig.title.text = "Best choice [n_trial={}, A_freq={}, A_rew={}]".format(n_trial, cue_freq, rew_freq)

    return fig, lines

def update_best(n_trial, cue_freq, rew_freq, fig=None, lines=None):
    """Widget callback that redraws the best-choice graph.

    All arguments are forwarded to ``graph_best``. Its ``(fig, lines)`` result
    is discarded, so this function returns None.
    """
    graph_best(n_trial, cue_freq, rew_freq, lines=lines, fig=fig)
    return None

In [None]:
# Draw the best-choice figure for the first parameter combination; the
# interactive cell below updates this same figure.
(idx, n_trial, cue_freq, rew_freq) = params[0]
fig, lines = graph_best(n_trial, cue_freq, rew_freq, fig=None, lines=None)
graphs.show(fig)

In [None]:
# Selectors for the experiment parameters; `fig` and `lines` are pinned with
# `fixed` so update_best redraws the figure shown above.
graphs.interact(
    update_best,
    fig=graphs.fixed(fig),
    lines=graphs.fixed(lines),
    n_trial=graphs.select('n_trial', n_trials),
    cue_freq=graphs.select('A_freq', cue_freqs),
    rew_freq=graphs.select('A_rew', rew_freqs),
)

### Chosen Cue Graph

Per-trial fraction of choices of cue A, cue B, and no choice, drawn as stacked bands around zero.

In [None]:
def cue_count(cues, choices=(-1, 0, 1)):
    """Return, for each choice value, the fraction of rows of ``cues`` equal to it.

    Parameters
    ----------
    cues : array-like
        Choice values; statistics are computed along axis 0 (one value per column).
    choices : iterable, optional
        Values to tabulate. Defaults to ``(-1, 0, 1)``.

    Returns
    -------
    dict
        ``{choice: array}``, where ``array[j]`` is the fraction of rows whose
        column ``j`` equals ``choice``.
    """
    cues = np.asarray(cues)
    # np.mean always yields floats; `int_sum / int` would floor under Python 2.
    return {c: np.mean(cues == c, axis=0) for c in choices}
    
def rew_count(cues, rews, choices=(-1, 0, 1)):
    """Return, per choice value, the reward-weighted fraction of rows selecting it.

    For each ``c`` in ``choices`` this is ``sum((cues == c) * rews, axis=0)``
    divided by the number of rows of ``cues``.

    Parameters
    ----------
    cues : array-like
        Choice values, one row per record.
    rews : array-like
        Rewards, broadcastable against ``cues``.
    choices : iterable, optional
        Values to tabulate. Defaults to ``(-1, 0, 1)``.

    Returns
    -------
    dict
        ``{choice: array}`` with one entry per column.
    """
    cues = np.asarray(cues)
    # float() keeps this true division even under Python 2 with integer rewards.
    n_rows = float(cues.shape[0])
    return {c: ((cues == c) * rews).sum(0) / n_rows for c in choices}
    

def graph_cue(n_trial, cue_freq, rew_freq, single_phase=True, fig=None, lines=None):
    """Display, per trial, the fraction of choices of cue A, cue B and no choice.

    The three fractions are drawn as stacked bands. "No choice" (-1) is centred
    on 0, "A" (0) is stacked above it and "B" (1) below it.

    Parameters
    ----------
    n_trial, cue_freq, rew_freq :
        Experiment parameters, forwarded to ``load_records``.
    single_phase : bool
        If True, plot every column of ``records["cue"]``. If False, drop the
        first ``n_trial`` columns.
    fig : optional
        Existing figure to reuse. A new one is created when None.
    lines : dict, optional
        ``{choice: patch}`` returned by a previous call. When given, the patches'
        data sources are updated in place and pushed to the notebook.

    Returns
    -------
    (fig, lines)
    """
    records = load_records(n_trial, cue_freq, rew_freq)

    cues = records["cue"]
    if not single_phase:
        cues = cues[:, n_trial:]
    cue_table = cue_count(cues)

    if fig is None:
        fig = graphs.figure(y_range=[-1.3, 1.3],
                            plot_width=900, plot_height=400, tools="")

    # Each band is a closed polygon: one edge is walked left-to-right, then the
    # other edge is walked back right-to-left.
    n_steps = cues.shape[1]
    x = list(range(1, n_steps + 1)) + list(range(n_steps, 0, -1))
    half_none = 0.5 * np.asarray(cue_table[-1], dtype=float)
    ys = {
        -1: np.concatenate([half_none, -half_none[::-1]]),
        0: np.concatenate([half_none, (half_none + cue_table[0])[::-1]]),
        1: np.concatenate([-half_none, (-half_none - cue_table[1])[::-1]]),
    }

    if lines is None:
        lines = {}
        lines[ 0] = fig.patch(x, ys[ 0], legend='A', fill_color="#fa6900", fill_alpha=0.5, line_color="#fa6900")
        lines[ 1] = fig.patch(x, ys[ 1], legend='B', fill_color="#69d2e7", fill_alpha=0.5, line_color="#69d2e7")
        lines[-1] = fig.patch(x, ys[-1], legend='no choice', fill_color="#aaaaaa", fill_alpha=0.5, line_color="#aaaaaa")
    else:
        for choice in (-1, 0, 1):
            # Replace the whole dict so 'x' and 'y' are never transiently of
            # different lengths (their length changes with n_trial/single_phase).
            lines[choice].data_source.data = dict(x=x, y=ys[choice])
        graphs.io.push_notebook()
    fig.title.text = "Cue choice [n_trial={}, A_freq={}, A_rew={}]".format(n_trial, cue_freq, rew_freq)

    return fig, lines

def update_cue(n_trial, cue_freq, rew_freq, single_phase=True, fig=None, lines=None):
    """Widget callback that redraws the chosen-cue graph.

    All arguments are forwarded to ``graph_cue``. Its ``(fig, lines)`` result
    is discarded, so this function returns None.
    """
    graph_cue(n_trial, cue_freq, rew_freq,
              lines=lines, fig=fig, single_phase=single_phase)
    return None

In [None]:
# Draw the chosen-cue figure for the first parameter combination; the
# interactive cell below updates this same figure.
(idx, n_trial, cue_freq, rew_freq) = params[0]
fig, lines = graph_cue(n_trial, cue_freq, rew_freq, single_phase=True, fig=None, lines=None)
graphs.show(fig)

In [None]:
# Selectors for the experiment parameters plus the `single_phase` toggle;
# `fig` and `lines` are pinned with `fixed` so update_cue redraws the figure above.
graphs.interact(
    update_cue,
    fig=graphs.fixed(fig),
    lines=graphs.fixed(lines),
    n_trial=graphs.select('number of trials for single phase', n_trials),
    cue_freq=graphs.select('A freq. during single phase', cue_freqs),
    rew_freq=graphs.select('A reward', rew_freqs),
    single_phase=True,
)