In [None]:
def sanity_check(V=1, Vp=0, N=20, sigma=1, plot=True, get_gain=False):
    """Simulate noisy estimates of two values and the expected gain from evaluating them.

    Draws ``N`` independent pairs of estimates: one from a normal distribution
    with mean ``V`` and one from a normal distribution with mean ``Vp``, both
    with standard deviation ``sigma``. Within each pair the larger estimate is
    put first. The two possible assignments of the ordered pair to
    ``(V, Vp)`` are then weighted by their Gaussian likelihoods, and the gain
    is ``V`` minus the likelihood-weighted average of ``V`` and ``Vp``.

    Args:
        V: Mean of the distribution the first estimate is drawn from.
        Vp: Mean of the distribution the second estimate is drawn from.
        N: Number of estimate pairs to draw.
        sigma: Standard deviation of both estimate distributions. The
            estimates are divided by it, so it must be non-zero.
        plot: If true (and ``get_gain`` is false), plot the per-pair gain
            together with the absolute difference between the two estimates.
        get_gain: If true, return the gains and skip plotting entirely.

    Returns:
        The length-``N`` array of gains when ``get_gain`` is true, otherwise
        ``None``.
    """
    # Draw order matters for reproducibility under a fixed seed: V first, Vp second.
    Vh_1_list = np.random.normal(V, sigma, N)
    Vh_2_list = np.random.normal(Vp, sigma, N)

    # Order every pair so that `higher` holds the larger of the two estimates.
    higher = np.maximum(Vh_1_list, Vh_2_list)
    lower = np.minimum(Vh_1_list, Vh_2_list)

    # Unnormalised likelihoods of the two possible assignments:
    #   p1: the larger estimate came from V and the smaller one from Vp
    #   p2: the larger estimate came from Vp and the smaller one from V
    p1 = np.exp(-0.5 * (((higher - V) / sigma) ** 2 + ((lower - Vp) / sigma) ** 2))
    p2 = np.exp(-0.5 * (((higher - Vp) / sigma) ** 2 + ((lower - V) / sigma) ** 2))

    # V minus the likelihood-weighted average of V and Vp.
    gain = V - (V * p1 + Vp * p2) / (p1 + p2)

    if get_gain:
        return gain
    if not plot:
        return None

    fig, ax1 = plt.subplots()
    fig.suptitle(f"V={V}, V'={Vp}, sigma={sigma}")

    ax1.plot(gain, label="expected gain from eval", color="r")
    ax1.set_ylabel("Expected gain from eval")

    # Second y-axis: the two curves live on different scales.
    diffs = np.abs(Vh_1_list - Vh_2_list)
    ax2 = ax1.twinx()
    ax2.plot(diffs, label="diff between vhats", color="b")
    ax2.set_ylabel("Difference between vhats")
    fig.legend()
    plt.show()
    return None

In [None]:
# Baseline run with every default (V=1, Vp=0, N=20, sigma=1); shows the plot.
sanity_check()

In [None]:
# Nearly equal true values (V=1.1 vs Vp=1) with the default N and sigma.
sanity_check(V=1.1, Vp=1)

In [None]:
# Same gap of 1 between V and Vp as the defaults, but shifted up to 11 vs 10.
sanity_check(V=11, Vp=10)

In [None]:
# Lower noise (sigma=0.5) and fewer pairs (N=10); V and Vp keep their defaults.
sanity_check(sigma=0.5, N=10)

In [None]:
# Mean gain over many sampled pairs, swept across the noise level sigma
# (V and Vp stay at the sanity_check defaults).
num = 20
num_samples = 1000  # pairs drawn per sigma value
sigmas = np.linspace(0.1, 10, num=num)
gains = np.zeros(num)
for i, s in enumerate(sigmas):
    sample_gains = sanity_check(sigma=s, get_gain=True, N=num_samples)
    gains[i] = sum(sample_gains) / num_samples

plt.plot(sigmas, gains)
plt.title("Average gain as a function of sigma for V=1, V'=0")
plt.xlabel("Sigma")
plt.ylabel("Average gain")
plt.show()

In [None]:
# Average gain as a function of V, fixing Vp at 0 (sigma stays at its default of 1).
num = 50
num_samples = 10000  # pairs drawn per V value; hoisted out of the loop
Vs = np.linspace(0.1, 10, num=num)
gains = np.zeros(num)
for i, v in enumerate(Vs):
    gains[i] = sum(sanity_check(V=v, get_gain=True, N=num_samples))/num_samples

# The swept variable is V, so the x-axis and title must say V, not sigma.
plt.plot(Vs, gains)
plt.xlabel("V")
plt.ylabel("Average gain")
plt.title("Average gain as a function of V for sigma=1, V'=0")
plt.show()

In [None]:
# Mean gain over many sampled pairs, swept across V with Vp left at 0
# and the noise level fixed at sigma=2.
num = 50
num_samples = 10000  # pairs drawn per V value
Vs = np.linspace(0.1, 10, num=num)
gains = np.zeros(num)
for i, v in enumerate(Vs):
    sample_gains = sanity_check(V=v, sigma=2, get_gain=True, N=num_samples)
    gains[i] = sum(sample_gains) / num_samples

plt.plot(Vs, gains)
plt.title("Average gain as a function of V for sigma=2, V'=0")
plt.xlabel("V")
plt.ylabel("Average gain")
plt.show()

In [None]:
# Mean gain over many sampled pairs, swept across V with Vp left at 0
# and the noise level fixed at sigma=0.2.
num = 50
num_samples = 10000  # pairs drawn per V value
Vs = np.linspace(0.1, 10, num=num)
gains = np.zeros(num)
for i, v in enumerate(Vs):
    sample_gains = sanity_check(V=v, sigma=0.2, get_gain=True, N=num_samples)
    gains[i] = sum(sample_gains) / num_samples

plt.plot(Vs, gains)
plt.title("Average gain as a function of V for sigma=0.2, V'=0")
plt.xlabel("V")
plt.ylabel("Average gain")
plt.show()