In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
    Created on Thu Jul 4 2024
    
    @author: Yaning
"""

%matplotlib qt
import matplotlib.animation as animation
from matplotlib.widgets import Slider, Button
import matplotlib as mpl
from matplotlib import pyplot as plt
import scipy.interpolate as inter
import numpy as np
import scipy.stats as stats
import torch

In [3]:
def _time_perception(delay, b):
    """Return the time-perception factor ``1 / (1 + b * exp(-0.1 * delay))``.

    Evaluated element-wise, so ``delay`` may be a scalar or a numpy array.
    With ``b >= 0`` the result lies in (0, 1] and equals 1 when ``b == 0``.
    """
    return 1 / (1 + b * np.exp(-delay * 0.1))


def _posterior_mean(prior_mean, ll_value, prior_sigma, ll_sigma):
    """Return the mean of the product of N(prior_mean, prior_sigma**2) and N(ll_value, ll_sigma**2).

    This is the precision-weighted average of the two means. Inputs may be
    scalars or broadcastable numpy arrays.
    """
    prior_var = prior_sigma**2
    ll_var = ll_sigma**2
    return (prior_mean * ll_var + ll_value * prior_var) / (ll_var + prior_var)


def acquire_ydata(params):
    """Recompute every curve of the interactive figure from the model parameters.

    Parameters
    ----------
    params : sequence
        ``[a, prior_sigma, ll_value, es_sigma, delay, b]``, in the same order as
        ``params_names``:

        * ``a`` scales the prior mean ``a / delay * 10``.
        * ``prior_sigma`` is the prior s.d.
        * ``ll_value`` and ``es_sigma`` are the mean and s.d. scale of the
          larger-later (LL) estimate.
        * ``delay`` is the delay shown in the pdf panel.
        * ``b`` is the time-perception parameter.

    The function reads the module-level reward grid ``x``, the delay grid
    ``delays`` and the smaller-sooner reward ``ss_value``. It returns nothing;
    results are written to module-level globals:

    * ``prior``, ``likelihood``, ``multiply``: pdfs on ``x`` at ``delay``
      (``multiply`` is the Gaussian posterior).
    * ``time_percep``, ``prior_decre``, ``estimation``, ``percentage``: numpy
      arrays over ``delays``. They hold the time-perception factor, the prior
      mean, the posterior-mean LL estimate, and the probability of choosing LL
      over ``ss_value``.
    """
    global prior, likelihood, multiply
    global estimation, percentage, time_percep, prior_decre

    a = params[0]
    prior_sigma = params[1]
    ll_value = params[2]
    es_sigma = params[3]
    delay = params[4]
    b = params[5]

    # --- pdf panel: everything evaluated at the single slider-selected delay ---
    prior_mean = a / delay * 10
    prior = stats.norm.pdf(x, prior_mean, prior_sigma)
    # NOTE(review): the drawn likelihood uses s.d. 0.1*delay*es_sigma, while the
    # posterior below uses t*es_sigma. The drawn "posterior" is therefore not the
    # product of the two drawn curves. Kept unchanged; confirm which width is intended.
    likelihood = stats.norm.pdf(x, ll_value, 0.1 * delay * es_sigma)

    ll_sigma = _time_perception(delay, b) * es_sigma
    post_mean = _posterior_mean(prior_mean, ll_value, prior_sigma, ll_sigma)
    post_sigma = np.sqrt((ll_sigma**2 * prior_sigma**2) / (ll_sigma**2 + prior_sigma**2))
    multiply = stats.norm.pdf(x, post_mean, post_sigma)

    # --- delay panels: the same model evaluated over the whole delay axis ---
    delay_grid = np.asarray(delays, dtype=float)
    time_percep = _time_perception(delay_grid, b)
    prior_decre = a / delay_grid * 10
    estimation = _posterior_mean(prior_decre, ll_value, prior_sigma, time_percep * es_sigma)
    # Two-option softmax exp(e) / (exp(e) + exp(ss)), rewritten as a logistic.
    # The value is the same, but it cannot become inf/inf = nan when e is large.
    # Overflow of exp(ss - e) for very negative e correctly yields 0.
    with np.errstate(over='ignore'):
        percentage = 1.0 / (1.0 + np.exp(ss_value - estimation))


# --- Initial parameters and evaluation grids ---------------------------------
# Slider labels; their order defines the layout of `params` below.
params_names = ['μ$_{prior}$(a)', 'σ$_{prior}$', 'μ$_{es}$', 'σ$_{es}$', 'delay', 'TP(b)']
# Initial values: [a, prior σ, LL value (€), LL σ scale, delay (days), TP b]
params = [1, 3, 50, 4, 10, 5]
# Delay axis (days) for the estimation / choice / prior-mean / time panels.
delays = np.linspace(1, 122, 1000)
# Smaller-sooner (SS) reward (€) that the LL option is compared against.
ss_value = 20
# Reward axis (€) on which the prior / likelihood / posterior pdfs are evaluated.
x = np.linspace(-20, 80, 1000)

# Compute every curve for the initial parameters (fills the module-level globals).
acquire_ydata(params)

# --- Figure layout -------------------------------------------------------------
# 6x4 grid: ax1/ax2 span columns 0-1 and ax3-ax5 occupy column 3. No subplot
# uses column 2, where the slider widgets are placed (x=0.55 in figure coords).
# Do not call plt.show() here. With a non-interactive or inline backend it would
# render (and possibly close) the figure before curves and widgets are attached.
fig = plt.figure(constrained_layout=False, figsize=(15, 15))
gs = fig.add_gridspec(nrows=6, ncols=4, left=0.05, right=0.95, wspace=0.35, hspace=0.55)

# Left side: pdf panel (top) and LL estimate vs delay (bottom).
# NOTE(review): ylim [0, 55] can clip the estimate. The μ_es slider reaches 60,
# and the prior mean a/delay*10 can exceed that at short delays.
ax1 = fig.add_subplot(gs[:3, :2], ylim=[0, 0.5], xlabel='reward value (€)', ylabel='pdf')
ax2 = fig.add_subplot(gs[3:, :2], ylim=[0, 55], xlabel='delay (days)', ylabel='LL estimation value (€)')

# Right column: choice probability, prior mean and time perception vs delay.
# The choice curve is a probability in [0, 1], so it is labelled as such rather than in %.
# NOTE(review): ylim [0, 10] on ax4 clips a/delay*10 once the `a` slider exceeds 1.
ax3 = fig.add_subplot(gs[:2, 3:], ylim=[0, 1.1], xlabel='delay (days)', ylabel='choose LL probability')
ax4 = fig.add_subplot(gs[2:4, 3:], ylim=[0, 10], xlabel='delay (days)', ylabel='prior mean value (€)')
ax5 = fig.add_subplot(gs[4:, 3:], ylim=[0, 1.1], xlabel='delay (days)', ylabel='normalised subjective time perception')


def update(val):
    """Slider ``on_changed`` callback: sync ``params`` with the sliders and redraw.

    ``val`` (the value of the slider that moved) is ignored. Every slider is
    re-read, so ``params`` always mirrors the full widget state. ``params`` is
    mutated in place, the curves are recomputed with ``acquire_ydata``, and each
    line handle receives its new y data before a redraw is queued.
    """
    n_params = len(params)
    for idx in range(n_params):
        params[idx] = sliders[idx].val
    acquire_ydata(params)

    # Line handles paired with the module-level arrays that acquire_ydata refreshed.
    refreshed = (
        (m1, prior),
        (m2, likelihood),
        (m3, multiply),
        (m4, estimation),
        (m5, percentage),
        (m6, prior_decre),
        (m7, time_percep),
    )
    for line, ydata in refreshed:
        line.set_ydata(ydata)

    # Let the GUI event loop repaint when it is idle rather than blocking here.
    fig.canvas.draw_idle()



def reset(event):
    """'Reset' button callback: put every slider back to its initial value.

    ``event`` is the click event passed to ``on_clicked`` handlers and is not used.
    Resetting a slider presumably fires its ``on_changed`` callback, which
    refreshes ``params`` and the lines. The curve data is then recomputed once
    more from ``params``, and a redraw is queued.
    """
    slider_count = len(params)
    for idx in range(slider_count):
        sliders[idx].reset()

    acquire_ydata(params)
    # Repaint when the GUI event loop is idle.
    fig.canvas.draw_idle()


# --- Initial curves --------------------------------------------------------------
# Keep the Line2D handles: `update` replaces their y data in place.
m1, = ax1.plot(x, prior, label='prior')
m2, = ax1.plot(x, likelihood, label='likelihood')
m3, = ax1.plot(x, multiply, label='posterior')
m4, = ax2.plot(delays, estimation)
m5, = ax3.plot(delays, percentage)
m6, = ax4.plot(delays, prior_decre)
m7, = ax5.plot(delays, time_percep)

# Dashed reference lines at the smaller-sooner reward.
ax1.axvline(x=ss_value, linewidth=1, label='ss_value', linestyle='--')
ax2.axhline(y=ss_value, linewidth=1, label='ss_value', linestyle='--')
ax1.legend(loc=2, prop={'size': 10})

# --- Sliders -----------------------------------------------------------------------
# (min, max) per parameter index; indices not listed use (0, 10). The bounds are
# looked up from this table rather than assigned to module-level `min` / `max`,
# which would shadow the builtins for the rest of the kernel session.
SLIDER_BOUNDS = {
    1: (0.1, 10),   # σ_prior: starts above 0 so the normal pdf has a positive width
    2: (0, 60),     # μ_es (LL value, €)
    3: (0.1, 10),   # σ_es: starts above 0 for the same reason
    4: (1, 122),    # delay (days): >= 1 because the prior mean divides by it
    5: (0, 20),     # TP(b)
}

# Matplotlib widgets must stay referenced to remain responsive. The sliders are
# kept in the module-level `sliders` list, which `update` and `reset` also read.
sliders = []
for i in range(len(params)):
    # One slider per parameter, stacked downwards in figure coordinates.
    slider_ax = fig.add_axes([0.55, 0.6 - i * 0.05, 0.12, 0.02])
    vmin, vmax = SLIDER_BOUNDS.get(i, (0, 10))
    slider = Slider(slider_ax, params_names[i], vmin, vmax, valinit=params[i])
    slider.on_changed(update)
    sliders.append(slider)

# Reset button directly below the last slider; `bres` keeps it referenced.
axres = fig.add_axes([0.55, 0.6 - len(params) * 0.05, 0.12, 0.02])
bres = Button(axres, 'Reset')
bres.on_clicked(reset)

# Show the figure once everything (curves, sliders, button) is attached.
plt.show()