In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

In [None]:
# Parameter grids for the root heatmap: ps is kept below 0.5, pr above 0.5
# (presumably switch and reward probabilities, matching psw/prew used later).
psarr = np.linspace(start=0.01, stop=0.45, num=20)
prarr = np.linspace(start=0.55, stop=1, num=25)

In [None]:
# Default 'xy' indexing: both grids have shape (len(prarr), len(psarr)),
# so rows vary with pr and columns vary with ps.
ps, pr = np.meshgrid(psarr, prarr, indexing='xy')

In [None]:
def root_finder(ps, pr):
    '''
    Return both roots of the fixed-point quadratic of update_pt.

    Setting pt = update_pt(ps, pr, pt) and clearing the denominator gives
        (2*pr - 1) * pt**2 + (1 - 2*pr + ps) * pt + (pr - 1) * ps = 0.
    Works element-wise on scalars or broadcastable NumPy arrays.

    Returns (root_plus, root_minus): the '+sqrt' and '-sqrt' branches of the
    quadratic formula. For 0.5 < pr < 1 and 0 < ps < 0.5, root_plus lies in
    (0.5, 1) and root_minus is negative. No special cases are handled:
    pr == 0.5 makes the leading coefficient zero (division by zero), and a
    negative discriminant gives NaN from np.sqrt.
    '''
    quad = 2 * pr - 1
    lin = -(2 * pr - ps - 1)
    const = (pr - 1) * ps
    sqrt_disc = np.sqrt(lin**2 - 4 * quad * const)
    neg_lin = -lin
    root_plus = (neg_lin + sqrt_disc) / 2 / quad
    root_minus = (neg_lin - sqrt_disc) / 2 / quad
    return root_plus, root_minus

In [None]:
# Both quadratic roots, evaluated element-wise over the (pr, ps) grids.
root1, root2 = root_finder(ps=ps, pr=pr)

In [None]:
# root1 has rows indexed by pr and columns by ps (meshgrid 'xy' layout), so the
# transpose puts pr on x and ps on y. `extent` makes the ticks show parameter
# values instead of array indices, padded by half a grid step so each pixel
# centre sits on its grid point.
pr_half_step = (prarr[1] - prarr[0]) / 2
ps_half_step = (psarr[1] - psarr[0]) / 2
plt.imshow(root1.T, origin='lower', aspect='auto',
           extent=[prarr[0] - pr_half_step, prarr[-1] + pr_half_step,
                   psarr[0] - ps_half_step, psarr[-1] + ps_half_step])
plt.colorbar(label='root1 (+sqrt root)')
plt.xlabel('pr')
plt.ylabel('ps')
plt.title('root1 from root_finder(ps, pr)');

In [None]:
def update_pt(ps, pr, pt):
    '''
    Advance pt by one trial.

    Algebraically this reweights pt by pr (and 1 - pt by 1 - pr), normalises,
    then mixes with probability ps: with post = pr*pt / den it equals
    (1 - ps)*post + ps*(1 - post). Works element-wise on scalars or arrays.
    No guard is applied when the denominator is zero
    (e.g. pr == 1 and pt == 0).
    '''
    miss = 1 - pr
    not_pt = 1 - pt
    numerator = ps * miss * not_pt + (1 - ps) * pr * pt
    denominator = pr * pt + miss * not_pt
    return numerator / denominator

def simulate_pt_sequence(ps, pr, ptinit, ntrials):
    '''
    Iterate update_pt for ntrials steps starting from ptinit.

    Returns a NumPy array of length ntrials + 1 (for ntrials >= 0) whose
    element 0 is ptinit and element k is pt after k updates, with ps and pr
    held fixed throughout.
    '''
    trajectory = [ptinit]
    for _ in range(ntrials):
        trajectory.append(update_pt(ps, pr, trajectory[-1]))
    return np.array(trajectory)
        
    
    
def first_switch_id(ps, pr, ptinit, ntrials):
    '''
    Index of the first simulated pt that is strictly greater than 0.5.

    Element 0 of the simulated sequence is ptinit itself, so any hit is at
    index >= 1. Returns -1 when pt never exceeds 0.5 within ntrials updates.
    ptinit has to be < 0.5 (enforced with an assert).
    '''
    assert ptinit < 0.5
    crossings = np.nonzero(simulate_pt_sequence(ps, pr, ptinit, ntrials) > 0.5)[0]
    return crossings[0] if len(crossings) else -1
    
    

In [None]:
# Sanity check for the empty-match path that first_switch_id relies on:
# no sample in [0, 1] exceeds 2, so the index array should come back empty.
a = np.linspace(start=0, stop=1, num=100)
b = np.flatnonzero(a > 2)

In [None]:
# Number of matches above the impossible threshold (expected 0).
b.size

In [None]:
# One pt trajectory per (ps, pr = 1 - ps) setting; the counts live in named
# constants so the palette size, the simulation length and the threshold line
# cannot drift apart.
N_CURVES = 10   # number of ps values plotted
N_TRIALS = 5    # updates simulated per curve
PT_INIT = 0.1   # starting pt; must stay below 0.5 for first_switch_id

colors = sns.palettes.color_palette('Blues', N_CURVES)
idxlst = []  # first_switch_id result per curve, in plotting order
for i, p in enumerate(np.linspace(0.01, 0.45, N_CURVES)):
    sim1 = simulate_pt_sequence(p, 1 - p, PT_INIT, N_TRIALS)
    idxlst.append(first_switch_id(p, 1 - p, PT_INIT, N_TRIALS))
    plt.plot(sim1, color=colors[i], label=f'ps={p:.2f}')

plt.hlines(0.5, 0, N_TRIALS, linestyles='--')
plt.xlabel('trial')
plt.ylabel('pt')
plt.legend(fontsize='small');

In [None]:
prewlst = np.linspace(0.55, 0.99, 10)
pswitchlst = np.linspace(0.01, 0.45, 15)
N_SWITCH_TRIALS = 10  # updates simulated per grid point before giving up

# Rows follow prewlst, columns follow pswitchlst. NaN marks grid points where pt
# never exceeded 0.5 within N_SWITCH_TRIALS, so imshow leaves them blank instead
# of colouring a -1 sentinel as if it were a real (very early) switch index.
idxarr = np.full((len(prewlst), len(pswitchlst)), np.nan)
for idr, prew in enumerate(prewlst):
    for ids, psw in enumerate(pswitchlst):
        # Local names on purpose: assigning `root1` here would overwrite the
        # grid of roots computed earlier (breaking a re-run of the root1
        # heatmap cell), and `_` would clobber IPython's last-output variable.
        steady_pt, _unused_root = root_finder(psw, prew)
        # steady_pt is the fixed point of update_pt in (0.5, 1) for these
        # ranges, so its mirror 1 - steady_pt satisfies first_switch_id's
        # ptinit < 0.5 requirement.
        idx = first_switch_id(psw, prew, 1 - steady_pt, N_SWITCH_TRIALS)
        if idx >= 0:
            idxarr[idr, ids] = idx

In [None]:
# idxarr rows follow prewlst (y axis) and columns follow pswitchlst (x axis).
# Pad the extent by half a grid step on each side so every pixel centre sits
# on its (psw, prew) value; using the bare min/max put the grid endpoints on
# the image edges, shifting every pixel by half a step.
psw_half_step = (pswitchlst[1] - pswitchlst[0]) / 2
prew_half_step = (prewlst[1] - prewlst[0]) / 2
plt.imshow(idxarr, origin='lower', aspect='auto',
           extent=[min(pswitchlst) - psw_half_step, max(pswitchlst) + psw_half_step,
                   min(prewlst) - prew_half_step, max(prewlst) + prew_half_step])
plt.colorbar(label='first index with pt > 0.5')
plt.xlabel('psw')
plt.ylabel('prew')
plt.title('first_switch_id starting from 1 - root1');

In [None]:
# Display the raw grid of first_switch_id results: rows follow prewlst,
# columns follow pswitchlst.
idxarr