In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
import os 
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import pandas as pd
from operator import sub, add
import glob
from random import shuffle

In [2]:
# Resolve the home directory portably so the Dropbox paths work on any OS.
home_path = os.path.expanduser('~')

_experiment_root = home_path + "/Dropbox/loki_1/fmri_experiment/"
exp_param_path = _experiment_root + "experimental_parameters/"  # CSV outputs are written here
img_path = _experiment_root + "images/symm_greebles/"           # greeble .tif stimuli

# Work from the parameter directory: later cells write some files relative to cwd.
os.chdir(exp_param_path)
os.getcwd()

'/home/coaxlab/Dropbox/loki_1/fmri_experiment/experimental_parameters'

In [3]:
# Image-selection design constants.
n_families, n_trials, test_prop, n_test_repetitions = (
    5,     # greeble families per sex
    600,   # training draws per sex
    .50,   # fraction of each family's images held out for test
    2,     # multiplier on the number of held-out images when sizing the test sequence
)

In [11]:
def select_images(img_path=img_path, n_families=n_families, n_trials=n_trials):
    """Collect view-1 greeble image paths for every family, split by sex.

    Parameters
    ----------
    img_path : str
        Directory prefix (including trailing separator) holding files that match
        ``m<family>~*v1.tif`` / ``f<family>~*v1.tif``.
    n_families : int
        Number of families; families are numbered ``1..n_families`` in file names.
    n_trials : int
        Unused; kept so existing calls keep working.

    Returns
    -------
    m_family_list, f_family_list : list of list of str
        Per-family lists of matching male / female image paths, sorted.
    family_indices : ndarray
        0-based family indices ``0..n_families-1``.
    nominal_families : ndarray
        1-based family numbers used in the file names.
    """
    family_indices = np.arange(0, n_families)
    nominal_families = np.arange(1, n_families + 1)

    # glob returns matches in arbitrary, filesystem-dependent order. Sort them so
    # the downstream random sampling is reproducible across machines and runs.
    m_family_list = [sorted(glob.glob(img_path + 'm' + str(family) + '~*v1.tif'))  # view 1 only
                     for family in nominal_families]
    f_family_list = [sorted(glob.glob(img_path + 'f' + str(family) + '~*v1.tif'))
                     for family in nominal_families]

    return m_family_list, f_family_list, family_indices, nominal_families



def check_greeble_samples(m_family_list, f_family_list, family_indices, nominal_families):
    """Count images per family and print a short consistency summary.

    Parameters
    ----------
    m_family_list, f_family_list : list of list
        Per-family male / female image paths (as returned by ``select_images``).
    family_indices : iterable of int
        Indices into the family lists to count.
    nominal_families : array-like
        Unused; accepted for call-site symmetry with ``select_images``.

    Returns
    -------
    m_family_sample_count, f_family_sample_count : list of int
        Number of male / female images in each family.

    Notes
    -----
    Reads the notebook-level ``n_trials`` (at call time) to report how many
    times each family is sampled.
    """
    m_counts = [len(m_family_list[idx]) for idx in family_indices]
    f_counts = [len(f_family_list[idx]) for idx in family_indices]

    # Male and female counts may differ within a family, so compare totals.
    greebles_per_family = np.unique([m + f for m, f in zip(m_counts, f_counts)])
    equal_per_family = len(greebles_per_family) == 1

    family_count = len(m_family_list)
    repetitions = n_trials // family_count  # each family is sampled this many times
    sexes_match = len(m_family_list) == len(f_family_list)

    print('Number of families equal across sex: ', sexes_match)
    print('Number of families: ', family_count, 'Number of family repetitions: ', repetitions)
    print('Number of greebles per family: ', greebles_per_family)
    print('Equal number of greebles per family: ', equal_per_family)

    return m_counts, f_counts

In [5]:
def select_training_test_samples(m_family_sample_count, f_family_sample_count, test_prop=test_prop,
                                n_test_repetitions=n_test_repetitions, m_family_list=None,
                                f_family_list=None, n_trials=None):
    """Split greeble images into disjoint training/test pools and draw trial sequences.

    For each family and sex, ``int(test_prop * count)`` images are held out for
    test. The rest are drawn without replacement into the training pool.
    Training sequences are then drawn with replacement from that pool, and a
    test sequence is drawn with replacement from the held-out images.

    Parameters
    ----------
    m_family_sample_count, f_family_sample_count : list of int
        Number of male / female images per family (from ``check_greeble_samples``).
    test_prop : float
        Fraction of each family's images (per sex) held out for testing.
    n_test_repetitions : int
        Test-sequence length is ``n_test_repetitions`` times the number of
        held-out images.
    m_family_list, f_family_list : list of list of str, optional
        Per-family image paths. Defaults to the notebook-level variables of the
        same name, which the original implementation read implicitly.
    n_trials : int, optional
        Training draws per sex. Defaults to the notebook-level ``n_trials``.

    Returns
    -------
    tuple
        ``(training_vec_choice, test_vec_choice, test_greeble_vec,
        training_vec_choice_m, training_vec_choice_f,
        summed_training_greebles_m, summed_training_greebles_f)``
    """
    # Fall back to the notebook globals the original relied on, so existing
    # calls behave the same.
    notebook_globals = globals()
    if m_family_list is None:
        m_family_list = notebook_globals['m_family_list']
    if f_family_list is None:
        f_family_list = notebook_globals['f_family_list']
    if n_trials is None:
        n_trials = notebook_globals['n_trials']

    family_indices = range(len(m_family_sample_count))

    # Image counts differ across sex, so hold-out sizes are computed per sex.
    n_test_per_m_fam = [int(test_prop * count) for count in m_family_sample_count]
    n_test_per_f_fam = [int(test_prop * count) for count in f_family_sample_count]

    # Materialise as lists. In Python 3, `map` returns a one-shot iterator that
    # cannot be indexed, so the original raised TypeError below.
    n_training_per_m_fam = [c - t for c, t in zip(m_family_sample_count, n_test_per_m_fam)]
    n_training_per_f_fam = [c - t for c, t in zip(f_family_sample_count, n_test_per_f_fam)]

    # Test-sequence length as a plain int. This equals the original
    # per-family-count * repetitions * n_families when every family holds out
    # the same number of images, and stays well-defined when they don't.
    n_test_trials = int(n_test_repetitions * (sum(n_test_per_m_fam) + sum(n_test_per_f_fam)))

    training_greebles_m = [list(np.random.choice(m_family_list[i], n_training_per_m_fam[i], replace=False))
                           for i in family_indices]
    training_greebles_f = [list(np.random.choice(f_family_list[i], n_training_per_f_fam[i], replace=False))
                           for i in family_indices]

    # Flatten across families: the full lists are used to derive the held-out
    # test images, and the training lists for per-trial sampling.
    summed_m_family_list = [img for family in m_family_list for img in family]
    summed_f_family_list = [img for family in f_family_list for img in family]
    summed_training_greebles_m = [img for family in training_greebles_m for img in family]
    summed_training_greebles_f = [img for family in training_greebles_f for img in family]

    test_greebles_m = list(np.setdiff1d(summed_m_family_list, summed_training_greebles_m))
    test_greebles_f = list(np.setdiff1d(summed_f_family_list, summed_training_greebles_f))
    test_greeble_vec = test_greebles_f + test_greebles_m
    shuffle(test_greeble_vec)  # interleave m/f randomly

    training_vec_choice_m = list(np.random.choice(summed_training_greebles_m, n_trials, replace=True))
    training_vec_choice_f = list(np.random.choice(summed_training_greebles_f, n_trials, replace=True))

    training_vec_choice = training_vec_choice_m + training_vec_choice_f
    # Sampled with replacement, so each test image appears n_test_repetitions
    # times only on average (open design question in the original).
    test_vec_choice = list(np.random.choice(test_greeble_vec, n_test_trials, replace=True))

    return (training_vec_choice, test_vec_choice, test_greeble_vec,
            training_vec_choice_m, training_vec_choice_f,
            summed_training_greebles_m, summed_training_greebles_f)

In [6]:
def check_training_test_samples(training_vec_choice, test_vec_choice):
    """Print whether the training and test image sequences are disjoint.

    Prints ``True`` when no image appears in both sequences. Returns None.
    """
    disjoint = set(training_vec_choice).isdisjoint(test_vec_choice)
    print('Training set does not overlap with the test set: ', disjoint)
    return None

In [7]:
def parse_images(training_vec_choice_f, training_vec_choice_m, test_greeble_vec, img_path=img_path):
    """Strip the image-directory prefix from each path and wrap the names in DataFrames.

    Each path is split on ``img_path`` and the piece after the first occurrence
    is kept. Returns ``(f_training, m_training, test)`` single-column DataFrames.
    """
    def _strip_dir(paths):
        return pd.DataFrame([path.split(img_path)[1] for path in paths])

    return (_strip_dir(training_vec_choice_f),
            _strip_dir(training_vec_choice_m),
            _strip_dir(test_greeble_vec))

In [8]:
def print_greeble_training_set(m_training_greebles_parsed, f_training_greebles_parsed, exp_param_path=exp_param_path):
    """Write paired male/female training images to ``<exp_param_path>training_greeble_images.csv``.

    Row i pairs the i-th male and i-th female image. Columns are ``m_image`` and
    ``f_image``, and no index column is written.
    """
    paired = np.column_stack([m_training_greebles_parsed, f_training_greebles_parsed])
    out_file = exp_param_path + 'training_greeble_images.csv'
    # NOTE(review): index_label should have no effect with index=False; kept as in the original.
    pd.DataFrame(paired, columns=['m_image', 'f_image']).to_csv(
        out_file, header=True, index_label=True, index=False)

In [9]:
def print_greeble_test_set(test_greebles_parsed, exp_param_path=exp_param_path):
    """Write the test image list to ``<exp_param_path>test_greeble_images.csv``.

    The single column is named ``image``, and no index column is written.
    """
    image_column = np.array(test_greebles_parsed)
    out_file = exp_param_path + 'test_greeble_images.csv'
    pd.DataFrame(image_column, columns=['image']).to_csv(
        out_file, header=True, index_label=True, index=False)

In [12]:
# Gather view-1 image paths per family, then sanity-check the per-family counts.
_image_selection = select_images()
m_family_list, f_family_list, family_indices, nominal_families = _image_selection
m_family_sample_count, f_family_sample_count = check_greeble_samples(*_image_selection)

Number of families equal across sex:  True
Number of families:  5 Number of family repetitions:  120
Number of greebles per family:  [16]
Equal number of greebles per family:  True


In [65]:
# Hold out test images and draw the training / test trial sequences.
_sample_split = select_training_test_samples(m_family_sample_count, f_family_sample_count)
(training_vec_choice, test_vec_choice, test_greeble_vec,
 training_vec_choice_m, training_vec_choice_f,
 summed_training_greebles_m, summed_training_greebles_f) = _sample_split

In [66]:
# Confirm held-out test images never appear in training, then strip the
# directory prefix from every path for the output CSVs.
check_training_test_samples(training_vec_choice, test_vec_choice)
(f_training_greebles_parsed,
 m_training_greebles_parsed,
 test_greebles_parsed) = parse_images(training_vec_choice_f, training_vec_choice_m, test_greeble_vec)

('Training set does not overlap with the test set: ', True)


In [67]:
# Write training_greeble_images.csv (paired m/f image per trial).
print_greeble_training_set(m_training_greebles_parsed=m_training_greebles_parsed,
                           f_training_greebles_parsed=f_training_greebles_parsed)

In [68]:
# Write test_greeble_images.csv (held-out images, shuffled across sex).
print_greeble_test_set(test_greebles_parsed=test_greebles_parsed)

In [None]:
def define_prob_vectors(n_trials = 800, n_targets = 2, hc_p=.65,
                        mc_p=.75, lc_p=.85, print_pooled_c_prob=1):
    """Draw per-trial binary reward assignments for three conflict levels.

    For each condition, target 1 ("t1") is the rewarded target on a trial with
    probability ``p`` (``hc_p`` high, ``mc_p`` moderate, ``lc_p`` low conflict);
    otherwise target 0 ("t0") is.

    Parameters
    ----------
    n_trials : int
        Number of trials to generate.
    n_targets : int
        Unused; kept for interface compatibility.
    hc_p, mc_p, lc_p : float
        Probability that target 1 is rewarded in each condition.
    print_pooled_c_prob : int
        When 1, plot the running proportion of target-1 rewards per condition.

    Returns
    -------
    tuple
        ``(lc_binary_t0, lc_binary_t1, mc_binary_t0, mc_binary_t1,
        hc_binary_t0, hc_binary_t1, hc_cumulative_p, mc_cumulative_p,
        lc_cumulative_p, n_trials)``.
        The binary vectors are boolean arrays of length ``n_trials`` with
        ``*_t0 == ~*_t1``. The cumulative vectors hold the running proportion
        of target-1 rewards up to each trial.
    """
    def _draw_t1(p):
        # One uniform draw per trial. The legacy global RNG yields the same
        # values whether drawn one at a time or as a vector, so seeded runs
        # match the original per-trial loops.
        return np.random.uniform(0, 1, n_trials) < p

    # Draw order (moderate, low, high) matches the original implementation.
    mc_binary_t1 = _draw_t1(mc_p)
    lc_binary_t1 = _draw_t1(lc_p)
    hc_binary_t1 = _draw_t1(hc_p)

    # Exactly one of the two targets is rewarded on every trial.
    lc_binary_t0 = ~lc_binary_t1
    mc_binary_t0 = ~mc_binary_t1
    hc_binary_t0 = ~hc_binary_t1

    print('lc_p(rewarding_target) :', np.sum(lc_binary_t1)/n_trials)
    print('lc_p(unrewarding_target) :', np.sum(lc_binary_t0)/n_trials)

    print('mc_p(rewarding_target) :', np.sum(mc_binary_t1)/n_trials)
    print('hc_p(rewarding_target) :', np.sum(hc_binary_t1)/n_trials)

    t_range = np.arange(1, n_trials + 1)
    hc_cumulative_p = np.cumsum(hc_binary_t1)/t_range
    mc_cumulative_p = np.cumsum(mc_binary_t1)/t_range
    lc_cumulative_p = np.cumsum(lc_binary_t1)/t_range

    if print_pooled_c_prob == 1:
        plt.rcParams['font.size'] = 18
        fig, ax = plt.subplots()
        ax.plot(hc_cumulative_p, 'r.', label='high conflict')
        ax.plot(mc_cumulative_p, 'b.', label='moderate conflict')
        ax.plot(lc_cumulative_p, 'g.', label='low conflict')
        ax.legend()
        ax.set_ylabel('cumulative p(max_reward)')
        ax.set_xlabel('t')
        ax.set_ylim([0.5, 1.01])

    return (lc_binary_t0, lc_binary_t1, mc_binary_t0, mc_binary_t1, hc_binary_t0, hc_binary_t1,
            hc_cumulative_p, mc_cumulative_p, lc_cumulative_p, n_trials)

In [None]:
def assign_reward_values(lc_binary_t0, lc_binary_t1, mc_binary_t0, mc_binary_t1, hc_binary_t0, hc_binary_t1,n_trials, r_mu=3, r_std=1, n_trials_expected = 600):
    """Draw Gaussian reward magnitudes and route each trial's reward to one target.

    A Normal(``r_mu``, ``r_std``) magnitude is drawn per trial for each
    condition. On trials where the boolean ``*_binary_t0`` mask is True the
    magnitude goes to target 0, otherwise to target 1. The other target gets 0.
    The ``*_binary_t1`` arguments are not read.

    Returns
    -------
    tuple
        ``(hc_rewards, mc_rewards, lc_rewards, hc_rewards_t0, hc_rewards_t1,
        mc_rewards_t0, mc_rewards_t1, lc_rewards_t0, lc_rewards_t1,
        mu_rewards, std_rewards)``. ``mu_rewards`` / ``std_rewards`` are constant
        vectors of length ``n_trials_expected``.
    """
    mu_rewards = np.repeat(r_mu, n_trials_expected)
    std_rewards = np.repeat(r_std, n_trials_expected)

    # Draw order (hc, mc, lc) is preserved so seeded runs reproduce.
    hc_rewards, mc_rewards, lc_rewards = [
        np.random.normal(loc=r_mu, scale=r_std, size=n_trials) for _ in range(3)]

    def _route(rewards, t0_mask):
        # (target-0 rewards, target-1 rewards); the unrewarded slot stays 0.
        return np.where(t0_mask, rewards, 0.0), np.where(t0_mask, 0.0, rewards)

    hc_rewards_t0, hc_rewards_t1 = _route(hc_rewards, hc_binary_t0)
    mc_rewards_t0, mc_rewards_t1 = _route(mc_rewards, mc_binary_t0)
    lc_rewards_t0, lc_rewards_t1 = _route(lc_rewards, lc_binary_t0)

    return (hc_rewards, mc_rewards, lc_rewards,
            hc_rewards_t0, hc_rewards_t1, mc_rewards_t0, mc_rewards_t1,
            lc_rewards_t0, lc_rewards_t1, mu_rewards, std_rewards)

In [None]:
def assign_changepoint_indices(n_trials, lv_lambda=35, mv_lambda=25, hv_lambda=15):
    """Draw changepoint trial indices for low / moderate / high volatility.

    For each condition, ``int(n_trials / lambda)`` Poisson(lambda) inter-change
    intervals are drawn and cumulatively summed. Indices at or beyond
    ``n_trials`` are dropped. Because the number of draws is fixed, the last
    changepoint may fall some distance before ``n_trials``.

    Returns
    -------
    hv_lam, mv_lam, lv_lam : ndarray of int
        Changepoint indices per condition (high volatility first).
    """
    def _draw(lam):
        intervals = np.random.poisson(lam=lam, size=int(n_trials / lam))
        return np.cumsum(intervals)

    # Draw order (lv, mv, hv) is kept so seeded runs reproduce.
    lv_cps = _draw(lv_lambda)
    mv_cps = _draw(mv_lambda)
    hv_cps = _draw(hv_lambda)

    return tuple(cps[cps < n_trials] for cps in (hv_cps, mv_cps, lv_cps))

In [None]:
def write_changepoints(lc_rewards_t0, mc_rewards_t0, hc_rewards_t0, 
                       lc_rewards_t1, mc_rewards_t1, hc_rewards_t1,
                       hv_lam, mv_lam, lv_lam, n_trials,
                       lv_lambda=30, mv_lambda=20, hv_lambda=10):
    """Apply changepoint-driven target swaps and summarise the resulting epochs.

    Rewards are paired as ``(t0, t1)`` columns. On any trial preceded (inclusive)
    by an odd number of changepoints, the two columns are swapped. Low and high
    conflict use the moderate-volatility changepoints. Moderate conflict is
    produced for all three volatility schedules. The ``*_lambda`` arguments are
    unused.

    Returns
    -------
    tuple
        ``(lc_rewards_flipped, mc_rewards_flipped, mc_rewards_flipped_lv,
        mc_rewards_flipped_hv, hc_rewards_flipped, cp_lv_epoch_len,
        cp_mv_epoch_len, cp_hv_epoch_len, cp_lv_epoch_idx, cp_hv_epoch_idx,
        cp_mv_epoch_idx, lv_cp_vec, mv_cp_vec, hv_cp_vec)``.
        Epoch index lists start at 0 and end at ``n_trials``. The ``*_cp_vec``
        entries are boolean changepoint markers of length ``n_trials``.
    """
    def _flip_after_changepoints(reward_arr, changepoints):
        # Count changepoints at or before each trial; odd count -> swap columns.
        trial_idx = np.arange(len(reward_arr))[:, None]
        n_passed = (trial_idx >= np.asarray(changepoints)[None, :]).sum(axis=1)
        return np.array([row[::-1] if passed % 2 else row
                         for row, passed in zip(reward_arr, n_passed)])

    def _cp_marker(changepoints):
        marker = np.zeros(n_trials, dtype=bool)
        marker[changepoints] = True
        return marker

    def _epoch_bounds(marker):
        return [0] + list(np.flatnonzero(marker)) + [n_trials]

    lc_pairs = np.column_stack((lc_rewards_t0, lc_rewards_t1))
    mc_pairs = np.column_stack((mc_rewards_t0, mc_rewards_t1))
    hc_pairs = np.column_stack((hc_rewards_t0, hc_rewards_t1))

    lc_rewards_flipped = _flip_after_changepoints(lc_pairs, mv_lam)
    mc_rewards_flipped = _flip_after_changepoints(mc_pairs, mv_lam)
    mc_rewards_flipped_lv = _flip_after_changepoints(mc_pairs, lv_lam)
    mc_rewards_flipped_hv = _flip_after_changepoints(mc_pairs, hv_lam)
    hc_rewards_flipped = _flip_after_changepoints(hc_pairs, mv_lam)

    lv_cp_vec, mv_cp_vec, hv_cp_vec = _cp_marker(lv_lam), _cp_marker(mv_lam), _cp_marker(hv_lam)

    cp_lv_epoch_idx = _epoch_bounds(lv_cp_vec)
    cp_hv_epoch_idx = _epoch_bounds(hv_cp_vec)
    cp_mv_epoch_idx = _epoch_bounds(mv_cp_vec)

    cp_lv_epoch_len = np.diff(cp_lv_epoch_idx)
    cp_mv_epoch_len = np.diff(cp_mv_epoch_idx)
    cp_hv_epoch_len = np.diff(cp_hv_epoch_idx)

    return (lc_rewards_flipped, mc_rewards_flipped, mc_rewards_flipped_lv, mc_rewards_flipped_hv,
            hc_rewards_flipped, cp_lv_epoch_len, cp_mv_epoch_len, cp_hv_epoch_len,
            cp_lv_epoch_idx, cp_hv_epoch_idx, cp_mv_epoch_idx, lv_cp_vec, mv_cp_vec, hv_cp_vec)

In [None]:
def write_obs_changepoints(lc_rewards_flipped, mc_rewards_flipped,mc_rewards_flipped_lv, mc_rewards_flipped_hv,hc_rewards_flipped, n_trials): 
    """Mark trials whose rewarded-target pattern differs from the previous trial.

    These are *observed* reward-identity switches, not the generative
    changepoints. For each reward array (trials x targets), a trial t >= 1 is
    marked 1.0 when its nonzero pattern differs from trial t-1.

    Returns
    -------
    tuple of ndarray
        ``(lc_obs_cp_vec, mc_obs_cp_vec, mc_obs_cp_vec_lv, mc_obs_cp_vec_hv,
        hc_obs_cp_vec)``, each a float vector of length ``n_trials``.
    """
    def _observed_switches(rewards_flipped):
        is_rewarded = rewards_flipped != 0
        changed_rows = np.where(is_rewarded[1:] != is_rewarded[:-1])[0]
        marks = np.zeros(n_trials)
        marks[changed_rows + 1] = 1
        return marks

    return tuple(_observed_switches(rewards) for rewards in
                 (lc_rewards_flipped, mc_rewards_flipped, mc_rewards_flipped_lv,
                  mc_rewards_flipped_hv, hc_rewards_flipped))

In [None]:
def _epoch_cumulative_p(rewards_flipped, epoch_idx, epoch_len):
    """Per-epoch running proportion of trials whose column-0 reward is > 0 (one array per epoch)."""
    per_epoch = []
    for e in range(len(epoch_len)):
        segment = rewards_flipped[epoch_idx[e]:epoch_idx[e + 1], 0]
        # NOTE(review): `> 0` counts a negative reward draw as "no reward",
        # whereas write_obs_changepoints uses `!= 0`. Kept as-is to preserve the
        # written values — confirm which is intended.
        per_epoch.append(np.cumsum(segment > 0) / np.arange(1, epoch_len[e] + 1))
    return per_epoch


def calc_empirical_cprob(cp_lv_epoch_len,cp_mv_epoch_len, cp_hv_epoch_len, cp_mv_epoch_idx, cp_lv_epoch_idx, cp_hv_epoch_idx,
                        lc_rewards_flipped, mc_rewards_flipped_hv,mc_rewards_flipped_lv, hc_rewards_flipped, n_trials, print_epoch_cprob=1, hc_p=.65,
                        mc_p=.75, lc_p=.85):
    """Compute the per-epoch running reward proportion for each condition.

    Low/high conflict use the moderate-volatility epochs. Low/high volatility
    use the moderate-conflict rewards flipped on their own epochs. The running
    proportion restarts at every epoch boundary.

    Parameters
    ----------
    cp_*_epoch_len, cp_*_epoch_idx :
        Epoch lengths and boundaries (boundaries start at 0 and end at the
        session length), as returned by ``write_changepoints``.
    lc_rewards_flipped, mc_rewards_flipped_hv, mc_rewards_flipped_lv, hc_rewards_flipped : ndarray
        Trials x 2 reward arrays; only column 0 is inspected.
    n_trials, mc_p :
        Unused; kept for interface compatibility.
    print_epoch_cprob : int
        When 1, plot the low/high conflict traces with ``p`` and ``1-p`` guides.
    hc_p, lc_p : float
        Reference probabilities drawn as horizontal lines in the plot.

    Returns
    -------
    tuple of ndarray
        ``(lc_cprob_epoch, hc_cprob_epoch, hv_cprob_epoch, lv_cprob_epoch)``,
        always. Previously the return was inside the plotting branch, so
        ``print_epoch_cprob != 1`` returned None.
    """
    hc_test = _epoch_cumulative_p(hc_rewards_flipped, cp_mv_epoch_idx, cp_mv_epoch_len)
    lc_test = _epoch_cumulative_p(lc_rewards_flipped, cp_mv_epoch_idx, cp_mv_epoch_len)
    lv_test = _epoch_cumulative_p(mc_rewards_flipped_lv, cp_lv_epoch_idx, cp_lv_epoch_len)
    hv_test = _epoch_cumulative_p(mc_rewards_flipped_hv, cp_hv_epoch_idx, cp_hv_epoch_len)

    # Whole-epoch proportion (last running value) of each epoch.
    peak_p_reward_hc = [epoch[-1] for epoch in hc_test]
    peak_p_reward_lc = [epoch[-1] for epoch in lc_test]
    peak_p_reward_hv = [epoch[-1] for epoch in hv_test]
    peak_p_reward_lv = [epoch[-1] for epoch in lv_test]

    # Mean over epochs 0, 2, 4, ... (an even number of preceding changepoints).
    print(np.mean(peak_p_reward_hc[::2]), np.mean(peak_p_reward_lc[::2]),
          np.mean(peak_p_reward_lv[::2]), np.mean(peak_p_reward_hv[::2]))

    lc_cprob_epoch = np.hstack(lc_test).flatten()
    hc_cprob_epoch = np.hstack(hc_test).flatten()
    hv_cprob_epoch = np.hstack(hv_test).flatten()
    lv_cprob_epoch = np.hstack(lv_test).flatten()

    if print_epoch_cprob == 1:
        plt.rcParams['figure.figsize'] = (10, 10)
        plt.rcParams['font.size'] = 18
        # Both panels in one figure: the original called plt.show() between
        # subplots, which pushed the second panel into a new figure.
        fig, (ax_lc, ax_hc) = plt.subplots(2, 1)
        ax_lc.plot(lc_cprob_epoch, 'b.-')
        ax_lc.axhline(lc_p, alpha=0.5, color='k')
        ax_lc.axhline(1 - lc_p, alpha=0.5, color='k')
        ax_lc.set_title('low conflict')
        ax_lc.set_xlabel('t')
        ax_lc.set_ylabel('cumulative p(r|correct) per epoch')
        ax_hc.plot(hc_cprob_epoch, 'r.-')
        ax_hc.axhline(hc_p, alpha=0.5, color='k')
        ax_hc.axhline(1 - hc_p, alpha=0.5, color='k')
        ax_hc.set_title('high conflict')
        ax_hc.set_xlabel('t')
        ax_hc.set_ylabel('cumulative p(r|correct) per epoch')
        plt.show()

    return (lc_cprob_epoch, hc_cprob_epoch, hv_cprob_epoch, lv_cprob_epoch)

In [None]:
def calc_expected_val(n_trials):
    """Print the expected net points for the low- and high-conflict conditions.

    Assumes a learner who picks the high-value option at the performance
    criterion from Frank, Woroch & Curran (2005): .65 for the low-conflict
    (.85 reward) pair and .5 for the high-conflict (.65 reward) pair. Each
    reward is worth 3 points, and each trial costs 1 point. Returns None.
    """
    pts_per_trial = 3
    session_pts = n_trials*pts_per_trial
    total_cost = -1*n_trials

    def _expected(crit_high, p_high):
        crit_low = 1 - crit_high
        p_low = 1 - p_high
        high_val_pts = crit_high*p_high*session_pts
        low_val_pts = crit_low*p_low*session_pts
        return high_val_pts + low_val_pts + total_cost

    expected_val_lc = _expected(.65, .85)
    expected_val_hc = _expected(.5, .65)

    print('expected_val low conflict ', expected_val_lc,'\nexpected_val high conflict ', expected_val_hc)
    return None

In [None]:
def slice_trials(hc_rewards_flipped, hc_cumulative_p, hc_cprob_epoch, hc_obs_cp_vec, lc_rewards_flipped, lc_cumulative_p, lc_cprob_epoch, mv_cp_vec, lc_obs_cp_vec, mc_rewards_flipped_lv, mc_cumulative_p,lv_cprob_epoch, lv_cp_vec, mc_obs_cp_vec_lv, mc_rewards_flipped_hv,hv_cprob_epoch, hv_cp_vec, mc_obs_cp_vec_hv, mu_rewards, std_rewards, n_input_trials=300): 
    """Shorten every per-trial series to two blocks of ``n_input_trials`` trials.

    Each series keeps its first ``n_input_trials`` trials. It then jumps to the
    first changepoint strictly after trial ``n_input_trials`` in the relevant
    volatility schedule and appends the ``n_input_trials`` trials from there.
    Conflict series use the moderate-volatility changepoint. ``*_lv`` / ``*_hv``
    series use the low / high one. ``mc_cumulative_p`` is spliced twice (once per
    volatility schedule). ``mu_rewards`` and ``std_rewards`` are passed through.
    Raises IndexError if a schedule has no changepoint after ``n_input_trials``.
    """
    next_cp = []
    for cp_vec in (mv_cp_vec, lv_cp_vec, hv_cp_vec):
        cp_positions = np.asarray(np.argwhere(cp_vec))
        next_cp.append(cp_positions[cp_positions > n_input_trials][0])
        print(next_cp)
    mv_slice, lv_slice, hv_slice = next_cp

    def _splice(series, start):
        # First block + block beginning at `start`, joined along the trial axis.
        return np.concatenate((series[:n_input_trials], series[start:start + n_input_trials]), axis=0)

    return (_splice(hc_rewards_flipped, mv_slice),
            _splice(hc_cumulative_p, mv_slice),
            _splice(hc_cprob_epoch, mv_slice),
            _splice(hc_obs_cp_vec, mv_slice),
            _splice(lc_rewards_flipped, mv_slice),
            _splice(lc_cumulative_p, mv_slice),
            _splice(lc_cprob_epoch, mv_slice),
            _splice(mv_cp_vec, mv_slice),
            _splice(lc_obs_cp_vec, mv_slice),
            _splice(mc_rewards_flipped_lv, lv_slice),
            _splice(mc_cumulative_p, lv_slice),
            _splice(mc_cumulative_p, hv_slice),
            _splice(lv_cprob_epoch, lv_slice),
            _splice(lv_cp_vec, lv_slice),
            _splice(mc_obs_cp_vec_lv, lv_slice),
            _splice(mc_rewards_flipped_hv, hv_slice),
            _splice(hv_cprob_epoch, hv_slice),
            _splice(hv_cp_vec, hv_slice),
            _splice(mc_obs_cp_vec_hv, hv_slice),
            mu_rewards, std_rewards)

In [None]:
def print_trial_structure(hc_rewards_flipped, hc_cumulative_p, hc_cprob_epoch, hc_obs_cp_vec, lc_rewards_flipped, lc_cumulative_p, lc_cprob_epoch, mv_cp_vec, lc_obs_cp_vec, mc_rewards_flipped_lv, mc_cumulative_p_lv,mc_cumulative_p_hv, lv_cprob_epoch, lv_cp_vec, mc_obs_cp_vec_lv, mc_rewards_flipped_hv,hv_cprob_epoch, hv_cp_vec, mc_obs_cp_vec_hv, mu_rewards, std_rewards, filenames = ['test_highC', 'test_lowC', 'test_lowV', 'test_highV']):
    """Write one CSV of per-trial task parameters for each of four conditions.

    Files ``<name>.csv`` are written (relative to the current working directory
    unless a name contains a directory) for, in order, high conflict, low
    conflict, low volatility and high volatility. Columns:

    * ``r_t1, r_t2`` — reward on each target.
    * ``c_prob`` — running reward proportion.
    * ``c_prob_epoch_t0`` / ``c_prob_epoch_t1`` — per-epoch running proportion
      and its complement.
    * ``cp`` — changepoint flag.
    * ``obs_cp`` — observed reward-identity switch.
    * ``mu_rewards``, ``std_rewards``.

    Both conflict files use the moderate-volatility changepoints
    (``mv_cp_vec``). All values are written with ``%f``.
    """
    # One header for all four files. The original lc and lv headers had stray
    # spacing differences ("c_prob_epoch_t1,cp", "obs_cp,mu_rewards"), so the
    # parsed column names did not match across files.
    header = "r_t1, r_t2, c_prob, c_prob_epoch_t0, c_prob_epoch_t1, cp, obs_cp, mu_rewards, std_rewards"

    conditions = [
        (filenames[0], hc_rewards_flipped, hc_cumulative_p, hc_cprob_epoch, mv_cp_vec, hc_obs_cp_vec),
        (filenames[1], lc_rewards_flipped, lc_cumulative_p, lc_cprob_epoch, mv_cp_vec, lc_obs_cp_vec),
        (filenames[2], mc_rewards_flipped_lv, mc_cumulative_p_lv, lv_cprob_epoch, lv_cp_vec, mc_obs_cp_vec_lv),
        (filenames[3], mc_rewards_flipped_hv, mc_cumulative_p_hv, hv_cprob_epoch, hv_cp_vec, mc_obs_cp_vec_hv),
    ]
    for filename, rewards, cumulative_p, cprob_epoch, cp_vec, obs_cp_vec in conditions:
        columns = (rewards[:, 0], rewards[:, 1], cumulative_p, cprob_epoch, 1 - cprob_epoch,
                   cp_vec, obs_cp_vec, mu_rewards, std_rewards)
        task_parameters = np.array(columns).T  # trials x columns
        np.savetxt(filename + '.csv', task_parameters, header=header, delimiter=',',
                   comments='', fmt='%f')

In [None]:
# Subject indices to generate schedules for; one set of CSVs per subject.
subjects = np.arange(1)

# Shortest acceptable gap (in trials) between changepoints.
min_epoch_length = 10

In [None]:

def _count_short_epochs(cp_vec, min_len):
    """Return how many epochs in a boolean changepoint vector are shorter than `min_len`.

    The session start (trial 0) and end (``len(cp_vec)``) count as epoch
    boundaries in addition to the flagged changepoints.
    """
    # np.union1d sorts and de-duplicates, so a changepoint flagged at trial 0
    # does not create a spurious zero-length epoch.
    boundaries = np.union1d([0, len(cp_vec)], np.flatnonzero(cp_vec))
    return int(np.sum(np.diff(boundaries) < min_len))


for s_idx in subjects:
    filenames = [fn + '_' + str(s_idx) for fn in ['hc', 'lc', 'lv', 'hv']]
    failed = True
    # Regenerate the whole schedule until no volatility condition has more than
    # one epoch shorter than `min_epoch_length`.
    while failed:
        (lc_binary_t0, lc_binary_t1, mc_binary_t0,mc_binary_t1, hc_binary_t0, hc_binary_t1, hc_cumulative_p, mc_cumulative_p, lc_cumulative_p, n_trials) = define_prob_vectors()

        (hc_rewards, mc_rewards, lc_rewards, 
                   hc_rewards_t0, hc_rewards_t1, mc_rewards_t0, mc_rewards_t1, 
                   lc_rewards_t0, lc_rewards_t1, mu_rewards, std_rewards) = assign_reward_values(lc_binary_t0, lc_binary_t1, mc_binary_t0, mc_binary_t1, hc_binary_t0, hc_binary_t1,n_trials)

        (hv_lam, mv_lam, lv_lam) = assign_changepoint_indices(n_trials)

        (lc_rewards_flipped, mc_rewards_flipped,mc_rewards_flipped_lv, mc_rewards_flipped_hv,hc_rewards_flipped, cp_lv_epoch_len, cp_mv_epoch_len, cp_hv_epoch_len, cp_lv_epoch_idx, cp_hv_epoch_idx,cp_mv_epoch_idx, lv_cp_vec, mv_cp_vec, hv_cp_vec) = write_changepoints(lc_rewards_t0, mc_rewards_t0, hc_rewards_t0, 
        lc_rewards_t1, mc_rewards_t1, hc_rewards_t1,
        hv_lam, mv_lam, lv_lam,n_trials)

        (lc_obs_cp_vec,mc_obs_cp_vec,mc_obs_cp_vec_lv,
        mc_obs_cp_vec_hv,hc_obs_cp_vec) = write_obs_changepoints(lc_rewards_flipped, mc_rewards_flipped,mc_rewards_flipped_lv, mc_rewards_flipped_hv,hc_rewards_flipped,n_trials)

        (lc_cprob_epoch, hc_cprob_epoch,hv_cprob_epoch,lv_cprob_epoch) = calc_empirical_cprob(cp_lv_epoch_len,cp_mv_epoch_len, cp_hv_epoch_len, cp_mv_epoch_idx, cp_lv_epoch_idx, cp_hv_epoch_idx,
                                lc_rewards_flipped, mc_rewards_flipped_hv,mc_rewards_flipped_lv, hc_rewards_flipped,n_trials)

        (hc_rewards_flipped, hc_cumulative_p, hc_cprob_epoch, hc_obs_cp_vec, lc_rewards_flipped, lc_cumulative_p, lc_cprob_epoch, mv_cp_vec, lc_obs_cp_vec, mc_rewards_flipped_lv, mc_cumulative_p_lv,mc_cumulative_p_hv,lv_cprob_epoch, lv_cp_vec, mc_obs_cp_vec_lv, mc_rewards_flipped_hv,hv_cprob_epoch, hv_cp_vec, mc_obs_cp_vec_hv, mu_rewards, std_rewards)  = slice_trials(hc_rewards_flipped, hc_cumulative_p, hc_cprob_epoch, hc_obs_cp_vec, lc_rewards_flipped, lc_cumulative_p, lc_cprob_epoch, mv_cp_vec, lc_obs_cp_vec, mc_rewards_flipped_lv, mc_cumulative_p,lv_cprob_epoch, lv_cp_vec, mc_obs_cp_vec_lv, mc_rewards_flipped_hv,hv_cprob_epoch, hv_cp_vec, mc_obs_cp_vec_hv, mu_rewards, std_rewards, n_input_trials=300)

        failed = False
        for cp_vec in [mv_cp_vec, lv_cp_vec, hv_cp_vec]:
            # The original padded the boundaries with np.insert/np.append but
            # discarded their return values (they do not modify in place), so
            # the first and last epochs were never checked. It also hard-coded
            # the session length as 600.
            n_short = _count_short_epochs(cp_vec, min_epoch_length)
            failed |= n_short > 1  # tolerate a single short epoch, as before
            print(n_short)

    print_trial_structure(hc_rewards_flipped, hc_cumulative_p, hc_cprob_epoch, hc_obs_cp_vec, lc_rewards_flipped, lc_cumulative_p, lc_cprob_epoch, mv_cp_vec, lc_obs_cp_vec, mc_rewards_flipped_lv, mc_cumulative_p_lv,mc_cumulative_p_hv, lv_cprob_epoch, lv_cp_vec, mc_obs_cp_vec_lv, mc_rewards_flipped_hv,hv_cprob_epoch, hv_cp_vec, mc_obs_cp_vec_hv, mu_rewards, std_rewards, filenames)