In [2]:
%matplotlib qt
import matplotlib.animation as animation
from matplotlib.widgets import Slider, Button
import matplotlib as mpl
from matplotlib import pyplot as plt
import scipy.interpolate as inter
import numpy as np
import scipy.stats as stats
import torch

# function for getting new y data from updated params
def acquire_ydata(params):
    """Recompute every plotted curve from the current parameter values.

    Args:
        params: Sequence of five numbers in the order
            ``(mean_u, std_u, mean_es, std_es, delay)``.

    Reads the module-level ``x`` (grid for the pdf curves), ``delays``
    (grid for the delay sweeps) and ``ss_value``.

    Writes these module-level names (nothing is returned):
        prior:      pdf of N(mean_u, std_u**2) on ``x``.
        likelihood: pdf of N(mean_es, 0.1 * delay * std_es**2) on ``x``.
        multiply:   pdf of the normalised product of the two curves above.
        estimation: posterior mean for every entry of ``delays``.
        deviation:  posterior standard deviation for every entry of ``delays``.
        percentage: exp(estimation) / (exp(estimation) + exp(ss_value)).

    ``estimation``, ``deviation`` and ``percentage`` are stored as NumPy
    arrays.
    """
    global prior, likelihood, multiply, estimation, percentage, deviation

    mean_u, std_u, mean_es, std_es, delay = params
    prior_var = std_u ** 2

    def combine(likelihood_var):
        """Return (mean, std) of the normalised prior * likelihood product."""
        total_var = likelihood_var + prior_var
        mean = (mean_u * likelihood_var + mean_es * prior_var) / total_var
        std = (prior_var * likelihood_var / total_var) ** 0.5
        return mean, std

    # The likelihood *variance* is 0.1 * delay * std_es**2, so the curve must be
    # drawn with the square root of that as its standard deviation. Otherwise
    # the "posterior" curve is not the product of the two curves shown with it.
    likelihood_var = 0.1 * delay * std_es ** 2

    prior = stats.norm.pdf(x, mean_u, std_u)
    likelihood = stats.norm.pdf(x, mean_es, likelihood_var ** 0.5)
    multiply = stats.norm.pdf(x, *combine(likelihood_var))

    # Same posterior formulas, evaluated for the whole delay grid at once.
    delay_grid = np.asarray(delays, dtype=float)
    estimation, deviation = combine(0.1 * delay_grid * std_es ** 2)

    # Algebraically equal to exp(e) / (exp(e) + exp(ss_value)), but exp(e) is
    # never formed, so large estimates cannot overflow.
    percentage = 1.0 / (1.0 + np.exp(ss_value - estimation))


# Grids the curves are evaluated on: `x` for the pdf panel, `delays` for the
# three delay-sweep panels.
x = np.linspace(-20, 80, num=1000)
delays = np.linspace(1, 122, num=1000)

# Reference value drawn as a line in two panels and used in the choice curve.
ss_value = 20

# Starting parameter values; params[k] belongs to params_names[k].
params_names = ['mean_u', 'std_u', 'mean_es', 'std_es', 'delay']
params = [0, 3, 50, 4, 10]

# Populate the curve globals once so the first plot has data to draw.
acquire_ydata(params)



# Figure layout: a 4x5 grid in which no subplot occupies column 2.
fig = plt.figure(figsize=(20, 10), constrained_layout=False)
gs = fig.add_gridspec(nrows=4, ncols=5, left=0.05, right=0.95, wspace=0.35, hspace=0.35)

# Panels are created in the order ax1..ax4:
#   ax1 / ax2 -> left column (top / bottom)
#   ax3 / ax4 -> right column (top / bottom)
ax1, ax2, ax3, ax4 = [
    fig.add_subplot(cell, ylim=y_range, xlabel=x_text, ylabel=y_text)
    for cell, y_range, x_text, y_text in (
        (gs[:2, :2], [0, 0.5], 'reward value', 'pdf'),
        (gs[2:, :2], [0, 55], 'delay', 'estimation value'),
        (gs[:2, 3:], [0, 1.1], 'delay', 'choose LL percentage'),
        (gs[2:, 3:], [0, 9], 'delay', 'posterior standard deviation'),
    )
]

plt.show()


def update(val):
    """Slider callback: pull slider values, recompute and redraw the curves.

    Args:
        val: Value supplied by the widget callback; not used, because every
            slider is re-read on each call.
    """
    # Slice assignment writes into the existing module-level list rather than
    # rebinding the name, so no `global` declaration is required.
    params[:] = [sliders[slot].val for slot in range(len(params))]

    acquire_ydata(params)

    # The curve globals were rebound by acquire_ydata just above, so they must
    # be looked up only after that call.
    refreshed = (
        (m1, prior),
        (m2, likelihood),
        (m3, multiply),
        (m4, estimation),
        (m5, percentage),
        (m6, deviation),
    )
    for line, ydata in refreshed:
        line.set_ydata(ydata)

    # Defer the actual repaint until the GUI event loop is idle.
    fig.canvas.draw_idle()



def reset(event):
    """Button callback: reset every parameter slider, then recompute and redraw.

    Args:
        event: Event object supplied by the button callback; not used.
    """
    # One slider per entry of `params`, addressed by position.
    for slot in range(len(params)):
        sliders[slot].reset()

    acquire_ydata(params)

    # Defer the actual repaint until the GUI event loop is idle.
    fig.canvas.draw_idle()


# Line artists are kept in m1..m6 so the slider callback can swap their y-data.
# The three pdf curves share the top-left panel.
m1, m2, m3 = (
    ax1.plot(x, curve, label=curve_label)[0]
    for curve, curve_label in (
        (prior, 'prior'),
        (likelihood, 'likelihood'),
        (multiply, 'posterior'),
    )
)

# One curve per delay-sweep panel.
m4 = ax2.plot(delays, estimation)[0]
m5 = ax3.plot(delays, percentage)[0]
m6 = ax4.plot(delays, deviation)[0]

# Mark ss_value: vertical on the pdf panel, horizontal on the estimation panel.
ax1.axvline(ss_value, linewidth=1, label='ss_value')
ax2.axhline(ss_value, linewidth=1, label='ss_value')
ax1.legend(loc=2, prop={'size': 10})

# (lower, upper) slider limits per parameter; any name not listed here falls
# back to (0, 10). Local names are used instead of `min` / `max` so the
# builtins are not shadowed for the rest of the session.
slider_limits = {
    'mean_u': (0, 50),
    'std_u': (0.1, 10),
    'mean_es': (0, 50),
    'std_es': (0.1, 10),
    'delay': (1, 122),
}

# One slider per parameter, stacked downwards in figure coordinates. The
# widgets are collected in `sliders`, which the update/reset callbacks index
# by position.
sliders = []
for row, (name, start) in enumerate(zip(params_names, params)):
    slider_ax = plt.axes([0.44, 0.6 - row * 0.05, 0.12, 0.02])
    lower, upper = slider_limits.get(name, (0, 10))
    sliders.append(Slider(slider_ax, name, lower, upper, valinit=start))

# Attach the callback only once every slider exists, because `update` reads
# all of them on each call.
for slider in sliders:
    slider.on_changed(update)

# The reset button takes the next row below the last parameter slider
# (row index == number of parameters).
axres = plt.axes([0.44, 0.6 - len(params) * 0.05, 0.12, 0.02])
bres = Button(axres, 'Reset')
bres.on_clicked(reset)

plt.show()