In [1]:
import pandas as pd
import numpy as np
import numpy.matlib
import matplotlib.pyplot as plt
import pdb
import scipy
from scipy.optimize import minimize, fmin
from scipy.stats import multivariate_normal
from tqdm.notebook import tqdm

In [9]:
""" 
Obtaining data from a given expt
"""
csv_test = pd.read_csv('../auditory_categorization_noContext/important_things_not_included_in_assets/allTrials.csv')
csv_data = pd.read_csv('auditory_categorization_prolific_online_data/human_auditory_categorization_91686_2022-01-19_20h56.51_f8da6647-2fb7-444d-9e8c-e52b3a37d4b2/6062c088821c76a49374e453_categorization_task_2021-05-17_23h24.18.467.csv');
                       

In [10]:
n_tones = 3
# Rows of the response log before FIRST_TRIAL_ROW are not test trials, and the
# final row is excluded as well (hence "- 47" originally).
# NOTE(review): presumably instructions/practice rows -- confirm against the raw CSV.
FIRST_TRIAL_ROW = 46
n_trials = csv_data.shape[0] - (FIRST_TRIAL_ROW + 1)

# Stimulus lookup table.  'Tones' and 'Tonekind' hold bracketed,
# comma-separated lists stored as strings (e.g. "[a, b, c]").
test_columns = list(csv_test.columns)
test_tones_name = test_columns.index('Name')
test_tones_col_idx = test_columns.index('Tones')
test_tones_cat_col_idx = test_columns.index('Tonekind')

df_names = csv_test.iloc[:, test_tones_name].values
df_tones = csv_test.iloc[:, test_tones_col_idx].values
df_tone_cat = csv_test.iloc[:, test_tones_cat_col_idx].values


def _parse_list_cell(cell):
    """Parse a bracketed, comma-separated string such as '[1, 2, 3]' into a float array."""
    return np.array(cell[1:-1].split(',')).astype(float)


trial_names = csv_data['Name'].to_numpy()

tones_array_orig = np.zeros((n_trials, n_tones))
tones_cat_array_orig = np.zeros((n_trials, n_tones))
tones_array_idxs_keep = []

for i_wav in range(n_trials):
    log_row = i_wav + FIRST_TRIAL_ROW
    stim_name = trial_names[log_row]
    if not isinstance(stim_name, str):
        continue  # no stimulus on this row (e.g. NaN)
    matches = np.flatnonzero(df_names == stim_name)  # look the stimulus up once
    if matches.size == 0:
        raise KeyError(f"Stimulus {stim_name!r} (log row {log_row}) not found in allTrials.csv")
    tones_array_orig[i_wav, :] = _parse_list_cell(df_tones[matches[0]])
    tones_cat_array_orig[i_wav, :] = _parse_list_cell(df_tone_cat[matches[0]])
    tones_array_idxs_keep.append(i_wav)

# Tones and categories are filled on exactly the same rows.
tones_cat_array_idxs_keep = list(tones_array_idxs_keep)

# Fancy indexing already returns copies.
df_tones = tones_array_orig[tones_array_idxs_keep, :]
df_tone_cat = tones_cat_array_orig[tones_cat_array_idxs_keep, :]
df_corrans = csv_data['corrAns'].to_numpy()[FIRST_TRIAL_ROW:][tones_array_idxs_keep]
df_keys = csv_data['test_resp.keys'].to_numpy()[FIRST_TRIAL_ROW:][tones_array_idxs_keep]

In [11]:
"""
Find no response cases in the expt
"""
no_response = np.intersect1d(np.where(df_keys!='h')[0],np.where(df_keys!='l')[0])
print("Did not respond to: ",no_response)

"""
Convert keys ['l','h'] to [0,1] and calculate accuracies
"""
corrans_num_orig = np.zeros_like(df_corrans)
corrans_num_orig[df_corrans == 'h'] = 1

keys_num_orig = np.zeros_like(df_keys)
keys_num_orig[df_keys == 'h'] = 1

corrans_num = corrans_num_orig[:600]
keys_num = keys_num_orig[:600]
tones_array = df_tones[:600]
print("Got correct: ", np.sum(keys_num==corrans_num)/len(tones_array))
print("Got high correct: ", np.sum((keys_num)*(corrans_num))/np.sum(corrans_num))
print("Got low correct: ", np.sum((1-keys_num)*(1-corrans_num))/np.sum(1-corrans_num))

Did not respond to:  [  0 150]
Got correct:  0.6883333333333334
Got high correct:  0.7408637873754153
Got low correct:  0.6354515050167224


In [12]:
# Stimuli and responses of the analysed trials (responses: 1 = 'h', 0 otherwise).
allTrial_tones = np.array(tones_array, copy=True)
allTrial_behaviour = np.ravel(keys_num)

# Grids used by the observer model.
expt_tones = np.arange(90, 3000, 1)  # array of possible true tones
exptFreqSeqArray = np.arange(np.log10(expt_tones[0]), np.log10(expt_tones[-1]),
                             np.log10(1003/1000)*40)
log_freq_seq_array = np.arange(0.6, 4.7, 0.1)  # grid of possible latent log10 tones
log_freq_percept = np.arange(0.6, 4.7, 0.1)    # grid of possible perceptual log10 tones
expt_freq_seq_mid = np.median(exptFreqSeqArray)
# Presumably [mean, sd] of the nominal low/high log10-tone distributions
# (not used in this cell).
low_dist = [expt_freq_seq_mid - 0.15, 0.1]
high_dist = [expt_freq_seq_mid + 0.15, 0.1]

# Drop trials without a response.  `no_response` indexes the full response
# log, so ignore indices beyond the analysed trials (np.delete raises on them).
no_response = np.asarray(no_response)
no_response_in_range = no_response[no_response < len(allTrial_tones)]
idxs_with_response = np.delete(np.arange(len(allTrial_tones)), no_response_in_range)
allTrial_tones_responded = allTrial_tones[idxs_with_response, :]
allTrial_behaviour_responded = allTrial_behaviour[idxs_with_response]
allTrial_tones_cat_responded = df_tone_cat[idxs_with_response]

# Keep trials whose tones all have Tonekind 1 ("low" Gaussian) or all have
# Tonekind 2 ("high" Gaussian).  Columns: tone1, tone2, tone3, response,
# index into the responded trials.  Low trials come first.
is_all_low = np.all(allTrial_tones_cat_responded == 1, axis=1)
is_all_high = np.all(allTrial_tones_cat_responded == 2, axis=1)
responded_trial_idx = np.arange(len(allTrial_tones_responded))
allGaussLow = np.column_stack([allTrial_tones_responded[is_all_low],
                               allTrial_behaviour_responded[is_all_low],
                               responded_trial_idx[is_all_low]]).astype(float)
allGaussHigh = np.column_stack([allTrial_tones_responded[is_all_high],
                                allTrial_behaviour_responded[is_all_high],
                                responded_trial_idx[is_all_high]]).astype(float)
iAllGaussLow, iAllGaussHigh = len(allGaussLow), len(allGaussHigh)
AllGaussian = np.concatenate((allGaussLow, allGaussHigh), axis=0)
trial_tones = AllGaussian[:, :3]
trial_behaviour = AllGaussian[:, 3]

# Influence: p(respond high | tone value at position k), averaged over positions.
# A tone value that never occurs at a position stays NaN (instead of
# averaging an empty slice, which warns) and is skipped by nanmean.
unique_tones = np.unique(trial_tones)
position_means = np.full((trial_tones.shape[1], len(unique_tones)), np.nan)
for pos in range(trial_tones.shape[1]):
    for i_tone, tone in enumerate(unique_tones):
        hits = trial_tones[:, pos] == tone
        if hits.any():
            position_means[pos, i_tone] = trial_behaviour[hits].mean()
tone1_prob_behaviour, tone2_prob_behaviour, tone3_prob_behaviour = position_means

influence = np.nanmean(position_means, axis=0)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [13]:
def gaussian(x, mean, sigma):
    """Unnormalised Gaussian bump exp(-(x - mean)^2 / (2 sigma^2)), elementwise in x."""
    sq_dist = (x - mean) ** 2
    return np.exp(-sq_dist / (2 * sigma ** 2))

def Tones1dgrid(latentTones, sigma):
    """Unit-sum sensory-noise distribution for one tone on the percept grid.

    Places a Gaussian of width ``sigma`` at ``latentTones[0]`` on the global
    grid ``log_freq_percept`` and rescales it so its entries sum to 1.

    Returns a column array of shape (len(log_freq_percept), 1).
    """
    bump = gaussian(log_freq_percept, latentTones[0], sigma)
    column = bump[:, np.newaxis]
    return column * (1 / np.sum(column))

def Tones3dgrid(latentTones, sigma):
    """Joint sensory-noise distribution of three percepts given three latent tones.

    Each of ``latentTones[0]``, ``[1]`` and ``[2]`` gets an independent Gaussian
    of width ``sigma`` on the global grid ``log_freq_percept``, rescaled to sum
    to 1.  The result is their outer product, shape (P, P, P) with
    P = len(log_freq_percept): entry [i, j, k] = g0[i] * g1[j] * g2[k].
    """
    marginals = []
    for centre in (latentTones[0], latentTones[1], latentTones[2]):
        column = gaussian(log_freq_percept, centre, sigma)[:, np.newaxis]
        marginals.append((column * (1 / np.sum(column))).ravel())
    g0, g1, g2 = marginals
    return np.multiply.outer(np.multiply.outer(g0, g1), g2)  # p(T1,T2,T3 | latent tones)

def posterior_array(freq_input, n_tones, p_low, log_prior):
    """Gaussian category priors over latent tones and the posterior p(H | tones).

    Arguments:
    freq_input - grid of latent (log10) tones on which the priors are evaluated
    n_tones    - tones per trial; 1 gives column vectors of shape
                 (len(freq_input), 1), n >= 2 gives the n-fold outer product of
                 the 1-tone prior (one axis per tone)
    p_low      - prior probability of the low category p(L); p(H) = 1 - p_low
    log_prior  - [low mean, high mean, sigma]; the same sigma is used for both
                 the low and the high prior

    Returns (prior_dist_mixed_high, prior_dist_mixed_low, prior_tones_high,
             prior_tones_low, normalizer, posterior):
    prior_dist_mixed_*  - unit-sum 1-tone prior, shape (len(freq_input), 1)
    prior_tones_*       - joint prior p(T1, T2, ... | category)
    normalizer          - p(H) p(T|H) + p(L) p(T|L)
    posterior           - p(H | T); NaN where the normalizer underflows to 0

    Raises ValueError if n_tones < 1.
    """
    if n_tones < 1:
        raise ValueError(f"n_tones must be >= 1, got {n_tones}")

    log_prior_low_mean, log_prior_high_mean, log_prior_sigma = log_prior[0], log_prior[1], log_prior[2]
    prior_low = gaussian(x=freq_input, mean=log_prior_low_mean, sigma=log_prior_sigma)
    prior_high = gaussian(x=freq_input, mean=log_prior_high_mean, sigma=log_prior_sigma)
    prior_low = prior_low / prior_low.sum()     # normalise to unit sum
    prior_high = prior_high / prior_high.sum()
    prior_dist_mixed_low = np.expand_dims(prior_low, axis=1)
    prior_dist_mixed_high = np.expand_dims(prior_high, axis=1)

    if n_tones == 1:
        prior_tones_low = prior_dist_mixed_low
        prior_tones_high = prior_dist_mixed_high
    else:
        # Tones are i.i.d. given the category: p(T1..Tn | C) = prod_k p(Tk | C).
        prior_tones_low = prior_low
        prior_tones_high = prior_high
        for _ in range(n_tones - 1):
            prior_tones_low = np.multiply.outer(prior_tones_low, prior_low)
            prior_tones_high = np.multiply.outer(prior_tones_high, prior_high)

    normalizer = (1 - p_low) * prior_tones_high + p_low * prior_tones_low  # p(H)p(T|H) + p(L)p(T|L)
    # Far from both priors the normalizer underflows to 0; the resulting NaN
    # is expected, so silence the 0/0 warning instead of printing it.
    with np.errstate(divide='ignore', invalid='ignore'):
        posterior = prior_tones_high * (1 - p_low) / normalizer

    return prior_dist_mixed_high, prior_dist_mixed_low, prior_tones_high, prior_tones_low, normalizer, posterior

In [16]:
# define mle function
def errorFn(modelProbHighGivenPercept, exptInfluenceFunction):
    """Elementwise residual (model - experiment) after dropping singleton axes."""
    model = np.squeeze(modelProbHighGivenPercept)
    observed = np.squeeze(exptInfluenceFunction)
    return model - observed
    
def MLE(params):
    """Sum-of-squared-errors cost of the single-tone observer model.

    params = [log_prior_low_mean, log_prior_high_mean, log_prior_sigma,
              sigma_sensory, prob_low]

    Despite its name this is not a likelihood.  The model observer perceives a
    tone with Gaussian noise (sigma_sensory) and answers "high" when
    p(H | percept) > 0.5; for every value in ``unique_tones`` the predicted
    p("high") is compared with the measured ``influence`` curve and the sum of
    squared residuals is returned.

    Reads the notebook globals log_freq_seq_array, log_freq_percept,
    unique_tones and influence.
    """
    low_mean, high_mean, prior_sigma, sensory_sigma, p_low = (params[k] for k in range(5))

    _, _, latent_given_high, latent_given_low, _, _ = posterior_array(
        log_freq_seq_array, n_tones=1, p_low=p_low,
        log_prior=[low_mean, high_mean, prior_sigma])

    # Marginalise the latent tone: p(percept | C) = sum_t p(percept | t) p(t | C).
    percept_given_high = np.zeros((len(log_freq_percept), 1))
    percept_given_low = np.zeros((len(log_freq_percept), 1))
    for idx, latent in enumerate(log_freq_seq_array):
        noise = Tones1dgrid([latent], sigma=sensory_sigma)
        percept_given_high += noise * latent_given_high[idx]
        percept_given_low += noise * latent_given_low[idx]

    weighted_high = percept_given_high * (1 - p_low)
    high_given_percept = weighted_high / (weighted_high + percept_given_low * p_low)
    says_high = high_given_percept > 0.5

    # Predicted p("high") per tone: percept mass falling where the model says "high".
    predicted = np.zeros((len(unique_tones), 1))
    for row, tone in enumerate(unique_tones):
        noise = Tones1dgrid([np.log10(tone)], sigma=sensory_sigma)
        predicted[row] = np.sum(np.multiply(says_high, noise))

    residual = errorFn(predicted, influence)
    return np.sum(residual ** 2)

In [17]:
"""
New optimization algorithm: uses scipy.optimize.fmin. 
Crude grid initially and then find minimum using the function.
"""
guess_VaryingSigma = np.asarray([0.02,1])
nll_VaryingSigma = np.zeros(len(guess_VaryingSigma))
thetas_VaryingSigma = np.zeros((len(guess_VaryingSigma),4))
for s in range(len(guess_VaryingSigma)):
    guess_low_mean = np.arange(2.1,2.71,0.15); guess_high_mean = np.arange(2.7,3.31,0.15); 
    guess_sensory_sigma = np.arange(0.05,0.4,0.01); guess_p_low = np.arange(0.4,0.61,0.05)

    # Constraining guesses of means of low and high distributions based on observed behaviour in figure shown above. 

    neg_ll_array = np.zeros((len(guess_low_mean), len(guess_high_mean),
                             len(guess_sensory_sigma), len(guess_p_low)))
    for lm in tqdm(range(len(guess_low_mean))):
        for hm in tqdm(range(len(guess_high_mean)), leave=False, desc="High mean"):
            for ss in range(len(guess_sensory_sigma)):
                for pl in range(len(guess_p_low)):
                    params = [guess_low_mean[lm], guess_high_mean[hm], guess_VaryingSigma[s], \
                              guess_sensory_sigma[ss], guess_p_low[pl]]
                    # print(lm, hm, pb)
                    neg_ll_array[lm,hm,ss,pl] = MLE(params) 

    """
    Means and p_back corresponding to the least negative log likelihood value
    """
    idxs = np.where(neg_ll_array == np.amin(neg_ll_array)) 
    best_thetas = np.array([guess_low_mean[idxs[0]], guess_high_mean[idxs[1]], 
                            guess_sensory_sigma[idxs[2]], guess_p_low[idxs[3]]])
    
    print(guess_VaryingSigma[s], neg_ll_array[idxs], best_thetas)
    #nll_VaryingSigma[s] = neg_ll_array[idxs]
    #thetas_VaryingSigma[s,:] = best_thetas

  0%|          | 0/5 [00:00<?, ?it/s]

High mean:   0%|          | 0/5 [00:00<?, ?it/s]

  posterior = prior_tones_high*(1-p_low)/normalizer


High mean:   0%|          | 0/5 [00:00<?, ?it/s]

High mean:   0%|          | 0/5 [00:00<?, ?it/s]

High mean:   0%|          | 0/5 [00:00<?, ?it/s]

High mean:   0%|          | 0/5 [00:00<?, ?it/s]

  probHighgivenPercept = LikelihoodPerceptgivenHigh*(1-prob_low)/\


0.02 [0.18798702 0.18798702 0.18798702 0.18798702 0.18798702 0.18798702
 0.18798702 0.18798702 0.18798702 0.18798702 0.18798702 0.18798702
 0.18798702 0.18798702 0.18798702 0.18798702 0.18798702 0.18798702
 0.18798702 0.18798702 0.18798702 0.18798702 0.18798702 0.18798702
 0.18798702 0.18798702 0.18798702] [[2.1  2.1  2.1  2.1  2.1  2.1  2.1  2.25 2.25 2.25 2.25 2.25 2.25 2.25
  2.25 2.4  2.4  2.4  2.4  2.4  2.4  2.4  2.55 2.55 2.55 2.55 2.7 ]
 [3.15 3.15 3.15 3.15 3.3  3.3  3.3  2.85 3.   3.   3.   3.   3.15 3.15
  3.15 2.7  2.7  2.85 2.85 3.   3.   3.   2.7  2.85 2.85 3.   2.85]
 [0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29
  0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29]
 [0.45 0.5  0.55 0.6  0.4  0.45 0.5  0.6  0.45 0.5  0.55 0.6  0.4  0.45
  0.5  0.55 0.6  0.5  0.55 0.4  0.45 0.5  0.5  0.45 0.5  0.4  0.45]]


  0%|          | 0/5 [00:00<?, ?it/s]

High mean:   0%|          | 0/5 [00:00<?, ?it/s]

High mean:   0%|          | 0/5 [00:00<?, ?it/s]

High mean:   0%|          | 0/5 [00:00<?, ?it/s]

High mean:   0%|          | 0/5 [00:00<?, ?it/s]

High mean:   0%|          | 0/5 [00:00<?, ?it/s]

1.0 [0.18798702 0.18798702 0.18798702 0.18798702 0.18798702 0.18798702
 0.18798702 0.18798702] [[2.1  2.1  2.25 2.25 2.4  2.4  2.55 2.55]
 [3.15 3.3  3.   3.15 2.85 3.   2.7  2.85]
 [0.29 0.29 0.29 0.29 0.29 0.29 0.29 0.29]
 [0.5  0.5  0.5  0.5  0.5  0.5  0.5  0.5 ]]


In [None]:
# define mle function
def MLE_fmin(params, return_probabilities=False, verbose=True):
    """Negative log-likelihood of the observed responses under the 3-tone model.

    params = [log_prior_low_mean, log_prior_high_mean, log_prior_sigma,
              sigma_sensory, prob_low]   (the vector scipy's fmin varies)

    Model: Gaussian category priors from posterior_array, Gaussian sensory
    noise (sigma_sensory) on log_freq_percept, and the observer answers "high"
    wherever p(H | percepts) > 0.5.  For each trial in ``trial_tones`` the
    probability of a "high" answer is the percept mass in that region.  The
    cost is -sum log(p(observed response) + 1e-7).

    Parameters
    ----------
    params : sequence of 5 floats (see above).
    return_probabilities : bool, default False
        If True, return (neg_ll, probability_high) where probability_high has
        shape (len(trial_tones), 1).
    verbose : bool, default True
        Print params and cost on every call.

    Returns np.inf when the per-trial probabilities are not finite (e.g. the
    sensory Gaussian underflows for tiny sigma_sensory), so the optimiser
    steers away instead of stopping.  Reads the notebook globals
    log_freq_seq_array, log_freq_percept, trial_tones and trial_behaviour.
    """
    log_prior_low_mean, log_prior_high_mean, log_prior_sigma, sigma_sensory, prob_low = \
        params[0], params[1], params[2], params[3], params[4]
    eps = 1e-7

    # posterior_array's 3-tone prior is the outer product of the 1-tone prior
    # with itself, and Tones3dgrid's noise is an outer product of independent
    # 1-D Gaussians.  Hence p(percepts | C) = l_C(p1) l_C(p2) l_C(p3) with
    # l_C = K @ prior_C, K[:, t] = p(percept | latent t).  This replaces a
    # triple loop over latent tones (O(L^3 * P^3)) with an O(L * P) product;
    # results agree up to floating-point rounding.
    _, _, prior_high_1d, prior_low_1d, _, _ = posterior_array(
        log_freq_seq_array, n_tones=1, p_low=prob_low,
        log_prior=[log_prior_low_mean, log_prior_high_mean, log_prior_sigma])
    noise_kernel = np.hstack([Tones1dgrid([t], sigma=sigma_sensory) for t in log_freq_seq_array])
    like_high_1d = (noise_kernel @ prior_high_1d).ravel()
    like_low_1d = (noise_kernel @ prior_low_1d).ravel()
    like_high = np.multiply.outer(np.multiply.outer(like_high_1d, like_high_1d), like_high_1d)
    like_low = np.multiply.outer(np.multiply.outer(like_low_1d, like_low_1d), like_low_1d)

    with np.errstate(divide='ignore', invalid='ignore'):
        weighted_high = like_high * (1 - prob_low)
        probHighgivenPercept = weighted_high / (weighted_high + like_low * prob_low)
    says_high = probHighgivenPercept > 0.5  # NaN entries (0/0) compare False -> "low"

    probability_high = np.zeros((len(trial_tones), 1))
    for i_trial, tones in enumerate(trial_tones):
        percept_dist = Tones3dgrid(np.log10(tones[:3]), sigma=sigma_sensory)
        probability_high[i_trial] = np.sum(says_high * percept_dist)

    p_high = probability_high[:, 0]
    if not np.all(np.isfinite(p_high)):
        neg_ll = np.inf
    else:
        responded_high = np.asarray(trial_behaviour).astype(bool)
        log_p_response = np.where(responded_high,
                                  np.log(p_high + eps),       # observer chose high
                                  np.log(1 - p_high + eps))   # observer chose low
        neg_ll = -np.sum(log_p_response)

    if verbose:
        print(params, neg_ll)
    if return_probabilities:
        return neg_ll, probability_high
    return neg_ll

"""
Optimization using neadler mead method and a simplex algorithm
"""
minimum_nll = scipy.optimize.fmin(MLE_fmin, [2.4,3,0.02,0.25,0.5], maxiter=10000, maxfun=10000, 
                                  xtol=0.01, ftol=0.01)

print(minimum_nll)

In [None]:
# Influence plots the way we currently understand them (11-17-2020):
# p(respond high | tone value at position k) for the behaviour, and the same
# quantity for the 3-tone model at the fmin parameters.


def _influence_by_position(values, tones, tone_values):
    """Mean of ``values`` over trials whose tone at position k equals each tone value.

    Returns shape (n_positions, len(tone_values)); NaN where a tone value never
    occurs at a position.
    """
    values = np.ravel(values)
    out = np.full((tones.shape[1], len(tone_values)), np.nan)
    for pos in range(tones.shape[1]):
        for i, tone in enumerate(tone_values):
            hits = tones[:, pos] == tone
            if hits.any():
                out[pos, i] = values[hits].mean()
    return out


def _model_prob_high_per_trial(params, tones):
    """Per-trial p(model answers "high") for the 3-tone observer model of MLE_fmin.

    Gaussian priors from posterior_array, Gaussian sensory noise on
    log_freq_percept, and "high" wherever p(H | percepts) > 0.5.  Both the
    3-tone prior and the noise are outer products of 1-D terms, so
    p(percepts | C) is the outer product of 1-D likelihoods.
    Returns shape (len(tones), 1).
    """
    low_mean, high_mean, prior_sigma, sensory_sigma, p_low = \
        params[0], params[1], params[2], params[3], params[4]
    _, _, prior_high_1d, prior_low_1d, _, _ = posterior_array(
        log_freq_seq_array, n_tones=1, p_low=p_low,
        log_prior=[low_mean, high_mean, prior_sigma])
    kernel = np.hstack([Tones1dgrid([t], sigma=sensory_sigma) for t in log_freq_seq_array])
    like_high_1d = (kernel @ prior_high_1d).ravel()
    like_low_1d = (kernel @ prior_low_1d).ravel()
    like_high = np.multiply.outer(np.multiply.outer(like_high_1d, like_high_1d), like_high_1d)
    like_low = np.multiply.outer(np.multiply.outer(like_low_1d, like_low_1d), like_low_1d)
    with np.errstate(divide='ignore', invalid='ignore'):
        weighted_high = like_high * (1 - p_low)
        says_high = weighted_high / (weighted_high + like_low * p_low) > 0.5
    prob = np.zeros((len(tones), 1))
    for i, trial in enumerate(tones):
        prob[i] = np.sum(says_high * Tones3dgrid(np.log10(trial[:3]), sigma=sensory_sigma))
    return prob


unique_tones = np.unique(trial_tones)
log_unique_tones = np.log10(unique_tones)

behaviour_by_pos = _influence_by_position(trial_behaviour, trial_tones, unique_tones)
tone1_prob_behaviour, tone2_prob_behaviour, tone3_prob_behaviour = behaviour_by_pos
influence1, = plt.plot(log_unique_tones, tone1_prob_behaviour, label='Influence of Tone 1')
influence2, = plt.plot(log_unique_tones, tone2_prob_behaviour, label='Influence of Tone 2')
influence3, = plt.plot(log_unique_tones, tone3_prob_behaviour, label='Influence of Tone 3')
# Separate name for the line handle: `influence` is the array MLE reads and
# must not be overwritten by a Line2D.
influence_line, = plt.plot(log_unique_tones, np.nanmean(behaviour_by_pos, axis=0),
                           'k', label='Average Influence')

# Model prediction per trial at the fmin parameters, summarised the same way
# as the behaviour (each position uses its own tone column; positions where a
# tone never occurs are skipped, exactly as for the behavioural average).
probability_high = _model_prob_high_per_trial(minimum_nll, trial_tones)
model_by_pos = _influence_by_position(probability_high, trial_tones, unique_tones)
mnll_influence, = plt.plot(log_unique_tones, np.nanmean(model_by_pos, axis=0), 'k.',
                           label='p(B_H|T) given fmin parameters')

plt.legend(handles=[influence_line, mnll_influence])
plt.ylim([-0.2, 1.1])
plt.xlabel('log10(Tones)')
plt.ylabel('p(B_H|T)')
# plt.savefig('figures/experimenter=mark_categorization_task_2020-12-28_18h57.17.409_plow_additional_parameter.png')
plt.show()