In [1]:
import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import norm
import copy

In [2]:
def update_v(var0, var, n=1):
    """Posterior variance of a normal mean under a normal prior.

    Parameters
    ----------
    var0 : float
        Prior variance of the mean.
    var : float
        Known variance of a single observation.
    n : int, optional
        Number of observations (default 1).

    Returns
    -------
    float
        ``1 / (1/var0 + n/var)``: precisions add, so the posterior precision
        is the prior precision plus ``n`` times the per-observation precision.
    """
    return 1 / (1 / var0 + n / var)


def update_m(x, mu0, var0, var, n=None):
    """Posterior mean of a normal mean under a normal prior.

    Parameters
    ----------
    x : scalar or array-like
        Observation(s); only their sum enters the update.
    mu0 : float
        Prior mean.
    var0 : float
        Prior variance of the mean.
    var : float
        Known variance of a single observation.
    n : int, optional
        Number of observations summarised by ``x``. When omitted it is
        inferred from ``x`` (1 for a scalar, the element count for an
        array-like), so the count can no longer silently disagree with the
        data. Pass it explicitly only if ``x`` is not the raw sample.

    Returns
    -------
    float
        Posterior variance times the precision-weighted sum
        ``mu0/var0 + sum(x)/var``. With no observations (``x`` empty) this
        reduces to the prior mean.
    """
    if n is None:
        # np.size is 1 for a scalar and the number of elements otherwise.
        n = np.size(x)
    new_v = update_v(var0, var, n)
    weighted_sum = mu0 / var0 + np.sum(x) / var
    return new_v * weighted_sum

In [3]:
# Plot setup: model parameters, evaluation grids and the shared figure.
np.random.seed(1)  # seed NumPy's global RNG so the sampled datapoints repeat

# Prior on the unknown mean: N(m_prior, v_prior).
m_prior, v_prior = 0, 1

# Distribution the datapoints are drawn from: N(m_likelihood, v_likelihood).
m_likelihood, v_likelihood = 2, 1.5

# Running posterior parameters, initialised to the prior. Plain assignment is
# enough here because the prior parameters are immutable numbers.
m_current = m_prior
v_current = v_prior

# Evaluation grids (1001 points on [-3, 7]) and the two fixed densities.
x_ = np.linspace(-3, 7, 1001)
x_likelihood = np.linspace(-3, 7, 1001)
norm_prior = norm.pdf(x_, loc=m_prior, scale=np.sqrt(v_prior))
norm_likelihood = norm.pdf(x_likelihood, loc=m_likelihood, scale=np.sqrt(v_likelihood))

# One row, two panels: ax[0] for prior/posterior, ax[1] for the sampling density.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# Font sizes; same effect as calling plt.rc(group, labelsize=16) per group.
plt.rcParams.update({
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'axes.labelsize': 16,
})
legend_size = 16
title_size = 16
# Close the current figure; presumably so the notebook does not also render a
# static copy of it next to the animation.
plt.close()

# Samples drawn so far; the frame-drawing functions append to this list.
datapoints = []

def prior_func():
    """Draw the opening frame: prior density on the left, sampling density on the right.

    Both panels of the module-level ``ax`` are cleared before drawing, so the
    function can be called repeatedly. It reads the module-level grids,
    densities and parameters, and returns ``None``.
    """
    left, right = ax[0], ax[1]
    for panel in (left, right):
        panel.cla()

    # Left panel: the prior, with a dashed marker at the true mean.
    prior_label = r'Prior with $\mu={}$ and $\sigma^2={}$'.format(m_prior, v_prior)
    left.plot(x_, norm_prior, label=prior_label, color='blue')
    left.set_ylim(0, 0.6)
    left.set_xlim(-3, 7)
    # axvline's 2nd/3rd arguments are fractions of the axes height, not data values.
    left.axvline(m_likelihood, 0, 0.5, ls='--', color='green')
    left.legend(loc='upper left', prop={'size': legend_size})

    # Right panel: the distribution the datapoints will be sampled from.
    likelihood_label = r'Normal with $\mu={}$ and $\sigma^2={}$'.format(
        m_likelihood, v_likelihood)
    right.plot(x_likelihood, norm_likelihood, label=likelihood_label, color='blue')
    right.axvline(m_likelihood, 0, 0.7, color='firebrick', label='True mean', lw=2)
    right.set_xlim(-3, 7)
    right.set_ylim(0, 0.5)
    right.legend(loc='upper left', prop={'size': legend_size})

    plt.tight_layout()
    
def posterior_func(i):
    """Sample one new datapoint, update the posterior and redraw both panels.

    Parameters
    ----------
    i : int
        Animation frame index. Kept for the frame-callback interface; the
        observation count used in the update is ``len(datapoints)``, not a
        value derived from ``i``. Frame 0 shows only the prior, so
        ``i + 1`` would overcount the data by one.

    Side effects
    ------------
    Appends the new sample to the module-level ``datapoints`` list, rebinds
    the module-level ``m_current`` / ``v_current`` to the posterior
    parameters, and clears and redraws both panels of ``ax``.
    """
    # Show how the prior shifts given new data.
    global m_current
    global v_current

    norm_sample = norm.rvs(m_likelihood, np.sqrt(v_likelihood))
    datapoints.append(norm_sample)
    n_obs = len(datapoints)  # number of samples actually observed so far

    # Update parameters: posterior given *all* datapoints, starting from the prior.
    v_current = update_v(v_prior, v_likelihood, n=n_obs)
    m_current = update_m(datapoints, m_prior, v_prior, v_likelihood, n=n_obs)
    norm_posterior = norm.pdf(x_, m_current, np.sqrt(v_current))

    ax[0].cla()
    ax[1].cla()

    # Left panel: faded prior, current posterior, dashed marker at the true mean.
    ax[0].plot(x_, norm_prior, label=r'Prior with $\mu={}$ and $\sigma^2={}$'.format(m_prior, v_prior),
              color='blue', alpha=0.5)
    ax[0].plot(x_, norm_posterior, label=r'Posterior with $\mu={:.02f}$ and $\sigma^2={:.02f}$'.format(m_current, v_current),
              color='firebrick')
    ax[0].set_ylim(0, 0.6)
    ax[0].set_xlim(-3, 7)
    # axvline's 2nd/3rd arguments are fractions of the axes height, not data values.
    ax[0].axvline(m_likelihood, 0, 0.5, ls='--', color='green')
    ax[0].legend(loc='upper left', prop={'size': legend_size})

    # Right panel: sampling density, the new datapoint, and the earlier ones.
    ax[1].plot(x_likelihood, norm_likelihood, label=r'Normal with $\mu={}$ and $\sigma^2={}$'.format(m_likelihood,
              v_likelihood), color='blue')
    ax[1].axvline(norm_sample, 0, 0.55, color='black', label='Datapoint {}'.format(n_obs), lw=5)
    ax[1].axvline(m_likelihood, 0, 0.7, color='firebrick', label='True mean', lw=2)
    # Only samples drawn before this frame are "previous"; skip the tick marks
    # (and their legend entry) when there are none yet.
    previous_points = datapoints[:-1]
    if previous_points:
        ax[1].vlines(previous_points, 0, 0.1, color='black', label='previous datapoints')
    ax[1].set_xlim(-3,7)
    ax[1].set_ylim(0, 0.5)
    ax[1].legend(loc='upper left', prop={'size': legend_size})
    plt.tight_layout()

In [4]:
def animate_func(i):
    """Per-frame drawing callback: frame 0 shows the prior, later frames update it.

    Returns whatever the selected drawing function returns.
    """
    if i != 0:
        return posterior_func(i)
    return prior_func()

In [5]:
from matplotlib import animation
from IPython.display import HTML

total_frames = 30


def _reset_animation_state():
    """Put the sampler and the running posterior back to their starting point.

    Each rendering pass (GIF export, inline HTML) replays every frame, and
    every posterior frame draws a fresh sample and appends it to
    ``datapoints``. Resetting before *each* pass keeps a pass from inheriting
    samples from an earlier one (including an earlier run of this cell), and
    reseeding makes both passes show the same sequence of datapoints.
    """
    global m_current, v_current
    np.random.seed(1)  # same seed as the plot-setup cell
    datapoints.clear()
    m_current, v_current = m_prior, v_prior


# Animation setup
anim = animation.FuncAnimation(
    fig, func=animate_func, frames=total_frames, interval=1000, blit=False
)

_reset_animation_state()
anim.save('Normal_Normal.gif', dpi=300)

_reset_animation_state()
HTML(anim.to_jshtml())

<Figure size 432x288 with 0 Axes>