In [1]:
import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import beta, bernoulli, binom
import copy

In [2]:
# Configuration: prior, data-generating process, evaluation grid and figure.
np.random.seed(1)  # makes the Binomial draws in posterior_func reproducible

# Beta(a_prior, b_prior) prior on the coin's success probability
# (Beta(1, 1) is the uniform prior).
a_prior = 1
b_prior = 1

# True data-generating process: every animation frame observes n_binom flips
# with success probability p_binom.
n_binom = 3
p_binom = 0.35
x_binom = np.arange(0, n_binom+1)  # support {0, ..., n_binom}
binom_pmf = binom.pmf(x_binom, n_binom, p_binom)

# Current posterior parameters, advanced by posterior_func. Ints are immutable,
# so a plain assignment is enough (no copy needed).
a_current = a_prior
b_current = b_prior

# Evaluation grid on [0, 1] for the Beta densities.
x_ = np.linspace(0, 1, 1000+1)
beta_prior = beta.pdf(x_, a_prior, b_prior)

# Font sizes are set before the figure is created so its axes pick them up.
plt.rc('xtick', labelsize=16)
plt.rc('ytick', labelsize=16)
plt.rc('axes', labelsize=16)
legend_size = 16
title_size = 16

fig, ax = plt.subplots(1, 2, figsize=(15, 5))
# Close the figure so it is not shown as a static image at the end of this
# cell; the animation below still draws into `fig`.
plt.close(fig)


def prior_func():
    """Draw frame 0 of the animation: the prior alone, before any data.

    Left panel: the Beta(a_prior, b_prior) prior density plus a dashed green
    line at the true success probability p_binom.
    Right panel: the Binomial(n_binom, p_binom) pmf as a stem plot.

    Draws into the module-level ``fig``/``ax``; returns None.
    """
    ax[0].cla()
    ax[1].cla()

    # Beta prior
    ax[0].plot(x_, beta_prior, label=r'Prior with $\alpha={}$ and $\beta={}$'.format(a_prior, b_prior),
               color='blue')
    ax[0].set_ylim(0, 5)
    ax[0].set_xlim(-0.1, 1.1)
    ax[0].legend(loc='upper left', prop={'size': legend_size})
    ax[0].axvline(p_binom, 0, 0.5, ls='--', color='green')

    # Binomial pmf as a stem plot. The 'o' format already means "markers, no
    # connecting line"; also passing marker= makes matplotlib warn that the
    # marker is redundantly defined.
    ax[1].plot(x_binom, binom_pmf, 'o', label=r'Binomial with $n= {}$ and $p={}$'.format(n_binom, p_binom),
               markersize=10, color='blue')
    ax[1].vlines(x_binom, 0, binom_pmf, colors='b', lw=3, alpha=0.5)
    # Limits derived from the data so the panel still fits if n_binom/p_binom change.
    ax[1].set_xlim(-0.1, n_binom + 0.1)
    ax[1].set_ylim(0, binom_pmf.max() + 0.15)
    ax[1].legend(loc='upper right', prop={'size': legend_size})

    # Target the animated figure explicitly: `fig` was closed in the setup cell,
    # so plt.tight_layout() would create and lay out an unrelated empty figure.
    fig.tight_layout()
    
def posterior_func(i):
    """Draw frame ``i`` (>= 1): observe one Binomial sample and update the posterior.

    Draws ``binom_sample ~ Binomial(n_binom, p_binom)`` from numpy's global RNG
    and applies the conjugate Beta-Binomial update (successes are added to
    ``a_current``, failures to ``b_current``).
    Left panel: the prior (faded) and the current posterior.
    Right panel: the Binomial pmf with the new sample highlighted in black.

    Parameters
    ----------
    i : int
        Frame index; the legend labels the sample "Datapoint i+1".

    Side effects: mutates the globals ``a_current``/``b_current``, advances the
    global numpy RNG, and redraws ``ax[0]``/``ax[1]`` of ``fig``. Returns None.
    """
    global a_current
    global b_current

    # Observe the data: n_binom coin flips, binom_sample of them successes.
    binom_sample = binom.rvs(n_binom, p_binom)

    # Conjugate update: successes add to alpha, failures add to beta.
    a_current += binom_sample
    b_current += n_binom - binom_sample

    beta_current = beta.pdf(x_, a_current, b_current)

    ax[0].cla()
    ax[1].cla()

    # Beta prior and posterior
    ax[0].plot(x_, beta_prior, label=r'Prior with $\alpha={}$ and $\beta={}$'.format(a_prior, b_prior),
               color='blue', alpha=0.5)
    ax[0].plot(x_, beta_current, label=r'Posterior with $\alpha={}$ and $\beta={}$'.format(a_current, b_current),
               alpha=1, color='firebrick')
    ax[0].set_ylim(0, 5)
    ax[0].set_xlim(-0.1, 1.1)
    ax[0].legend(loc='upper left', prop={'size': legend_size})
    ax[0].axvline(p_binom, 0, 0.5, ls='--', color='green')

    # Binomial pmf as a stem plot ('o' alone means markers without a line;
    # adding marker= too triggers matplotlib's redundancy warning).
    ax[1].plot(x_binom, binom_pmf, 'o', label=r'Binomial with $n= {}$ and $p={}$'.format(n_binom, p_binom),
               markersize=10, color='blue')
    ax[1].vlines(x_binom, 0, binom_pmf, colors='b', lw=3, alpha=0.5)
    # Highlight the observed sample with a black stem in data coordinates, so
    # its height equals the pmf value for any n_binom/p_binom (an axes-fraction
    # ymax only lines up for one particular y-limit). zorder 1.5 puts it above
    # the blue stems (a collection, zorder 1) and below the markers (zorder 2).
    ax[1].plot([binom_sample, binom_sample], [0, binom.pmf(binom_sample, n_binom, p_binom)],
               color='black', lw=4, zorder=1.5, label='Datapoint {}'.format(i+1))
    ax[1].set_xlim(-0.1, n_binom + 0.1)
    ax[1].set_ylim(0, binom_pmf.max() + 0.15)
    ax[1].legend(loc='upper right', prop={'size': legend_size})

    # Target the animated figure explicitly: `fig` was closed in the setup cell,
    # so plt.tight_layout() would create and lay out an unrelated empty figure.
    fig.tight_layout()
    

In [3]:
def animate_func(i):
    """Frame callback for FuncAnimation.

    Frame 0 shows only the prior via ``prior_func()``. Every later frame ``i``
    delegates to ``posterior_func(i)``, which observes one new sample.
    The delegate's return value is passed through; with ``blit=False`` the
    animation does not use it.
    """
    if i != 0:
        return posterior_func(i)
    return prior_func()

In [4]:
from matplotlib import animation
from IPython.display import HTML

total_frames = 30

# Frame 0 shows the prior; each later frame observes one Binomial sample and
# updates the posterior (see animate_func).
anim = animation.FuncAnimation(
    fig, func=animate_func, frames=total_frames, interval=1000, blit=False
)

# posterior_func mutates the globals a_current/b_current and consumes numpy's
# global RNG, and every render (save, to_jshtml) replays all frames. Snapshot
# the RNG now and rewind it, together with the posterior parameters, before
# each render so the GIF and the inline HTML show the same sequence of draws.
_rng_state = np.random.get_state()


def _reset_animation_state():
    """Reset the posterior to the prior and rewind numpy's global RNG."""
    global a_current, b_current
    a_current, b_current = a_prior, b_prior
    np.random.set_state(_rng_state)


_reset_animation_state()
# Pillow ships with matplotlib, so the GIF does not depend on ffmpeg.
anim.save('Beta_Binomial.gif', dpi=300, writer='pillow')
_reset_animation_state()
HTML(anim.to_jshtml())

<Figure size 432x288 with 0 Axes>