In [1]:
import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import beta, bernoulli, binom
import copy

In [2]:
# Configuration: prior hyper-parameters, the true coin bias, and the shared figure.
# Seeding NumPy's global RNG makes the coin flips reproducible: scipy.stats'
# `rvs` draws from that global state when no `random_state` is passed.
np.random.seed(1)

# Beta(a_prior, b_prior) prior over the coin bias; Beta(1, 1) is uniform on [0, 1].
a_prior = 1
b_prior = 1

# True probability of heads, and the Bernoulli PMF at outcomes 0 and 1.
bernoulli_p = 0.7
bernoulli_pmf = bernoulli.pmf([0,1], bernoulli_p)

# Running posterior parameters, incremented by posterior_func as coins are flipped.
# Plain ints are immutable, so direct assignment suffices (no copy needed).
a_current = a_prior
b_current = b_prior

# Evaluation grid on [0, 1] for Beta densities, and the (fixed) prior density.
x_ = np.linspace(0,1, 1000+1)
beta_prior = beta.pdf(x_, a_prior, b_prior)

# Font sizes; set before creating the figure so its ticks are built with them.
plt.rc('xtick', labelsize=16)
plt.rc('ytick', labelsize=16)
plt.rc('axes', labelsize=16)
legend_size=16
title_size=16

# Left panel: prior/posterior over p. Right panel: Bernoulli PMF and the data.
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
# Close it so Jupyter does not display the empty figure; the animation still draws on it.
plt.close(fig)

def prior_func():
    """Draw the first animation frame: the Beta prior and the true Bernoulli PMF.

    Also resets the running posterior parameters (`a_current`, `b_current`) to
    the prior. Every pass over the frames (saving the GIF, rendering the HTML
    preview, looping) then starts from the prior, not from wherever the
    previous pass stopped.
    """
    global a_current
    global b_current
    a_current = a_prior
    b_current = b_prior

    ax[0].cla()
    ax[1].cla()

    # Left panel: Beta prior over the coin bias; dashed green line marks the true p.
    ax[0].plot(x_, beta_prior, label=r'Prior with $\alpha={}$ and $\beta={}$'.format(a_prior, b_prior),
              color='blue')
    ax[0].set_ylim(0, 5)
    ax[0].set_xlim(-0.1, 1.1)
    ax[0].legend(loc='upper left', prop={'size': legend_size})
    ax[0].axvline(bernoulli_p, 0, 0.5, ls='--', color='green')

    # Right panel: Bernoulli PMF drawn as two bars at outcomes 0 and 1.
    ax[1].set_xlim(-0.1, 1.1)
    ax[1].set_ylim(0, 1.1)
    ax[1].vlines([0,1], 0, bernoulli_pmf, colors='b', lw=5, alpha=0.5,
                label=r'Bernoulli with $p={}$'.format(bernoulli_p))
    ax[1].legend(loc='upper left', prop={'size': legend_size})

    # Lay out the animated figure explicitly. `fig` was closed with plt.close(),
    # so plt.tight_layout() would act on (and create) a new, empty figure instead.
    fig.tight_layout()

    
def posterior_func(i):
    """Draw one posterior frame: flip the coin once and update the Beta posterior.

    Parameters
    ----------
    i : int
        Index used for the data-point label, which shows ``i + 1``. ``i`` is
        therefore expected to be the zero-based index of this data point.

    Side effects: draws one Bernoulli(bernoulli_p) sample, increments the
    global `a_current` (heads) or `b_current` (tails), and redraws both
    panels of `fig`.
    """
    global a_current
    global b_current

    # Flip the coin once.
    bernoulli_sample = bernoulli.rvs(bernoulli_p)

    # Conjugate update: a Beta(a, b) prior becomes Beta(a + 1, b) after heads
    # and Beta(a, b + 1) after tails.
    if bernoulli_sample == 1:
        a_current += 1
    else:
        b_current += 1

    beta_current = beta.pdf(x_, a_current, b_current)

    ax[0].cla()
    ax[1].cla()

    # Left panel: faded prior vs. current posterior; dashed green line marks the true p.
    ax[0].plot(x_, beta_prior, label=r'Prior with $\alpha={}$ and $\beta={}$'.format(a_prior, b_prior),
              color='blue', alpha=0.5)
    ax[0].plot(x_, beta_current, label=r'Posterior with $\alpha={}$ and $\beta={}$'.format(a_current, b_current),
             alpha=1, color='firebrick')
    ax[0].set_ylim(0, 5)
    ax[0].set_xlim(-0.1, 1.1)
    ax[0].legend(loc='upper left', prop={'size': legend_size})
    ax[0].axvline(bernoulli_p, 0, 0.5, ls='--', color='green')

    # Right panel: Bernoulli PMF plus a black line at the observed outcome.
    ax[1].set_xlim(-0.1, 1.1)
    ax[1].set_ylim(0, 1.1)
    ax[1].vlines([0,1], 0, bernoulli_pmf, colors='b', lw=5, alpha=0.5,
                label=r'Bernoulli with $p={}$'.format(bernoulli_p))
    # ymax is an axes fraction: 10/11 of the height reaches y = 1.0 with ylim (0, 1.1).
    ax[1].axvline(bernoulli_sample, 0, 10/11, color='black', label='Datapoint {}'.format(i+1))
    ax[1].legend(loc='upper left', prop={'size': legend_size})

    # Lay out the animated figure explicitly. `fig` was closed with plt.close(),
    # so plt.tight_layout() would act on (and create) a new, empty figure instead.
    fig.tight_layout()

In [3]:
def animate_func(i):
    """FuncAnimation callback: frame 0 shows the prior; frame i >= 1 adds data point i.

    `posterior_func` labels its sample as data point ``index + 1``, so frame i
    passes ``i - 1``. This keeps the label equal to the number of flips seen so
    far, which matches the posterior's alpha + beta - 2.
    Returns whatever the drawing function returns.
    """
    if i == 0:
        return prior_func()
    return posterior_func(i - 1)

In [4]:
from matplotlib import animation
from IPython.display import HTML

# One prior frame followed by (total_frames - 1) coin flips.
total_frames=30

# Animation setup
anim = animation.FuncAnimation(
    fig, func=animate_func, frames=total_frames, interval=1000, blit=False
)

# Every rendering pass replays all frames and mutates the global posterior
# parameters, so start each pass from the prior. Resetting before (not only
# after) saving keeps this cell correct when it is re-run on its own.
a_current, b_current = a_prior, b_prior
anim.save('Beta_Bernoulli.gif', dpi=300)

a_current, b_current = a_prior, b_prior
HTML(anim.to_jshtml())

<Figure size 432x288 with 0 Axes>