In [1]:
import numpy as np
import matplotlib.pyplot as plt

import copy
from scipy.stats import dirichlet, multinomial

In [2]:
# plotting devices
from scipy.special import gammaln, xlogy
import matplotlib.tri as tri

# Vertices of the 2D triangle used to draw the 3-component simplex:
# an equilateral triangle with unit side (0.75**0.5 == sqrt(3)/2 is its height).
corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
triangle = tri.Triangulation(corners[:, 0], corners[:, 1])

# Module-level refined mesh (4 subdivisions). The plotting functions visible
# in this notebook build their own local refiner/trimesh instead of reading these.
refiner = tri.UniformTriRefiner(triangle)
trimesh = refiner.refine_triangulation(subdiv=4)


# midpoints[i] is the midpoint of the edge joining the two corners other than
# corner i, i.e. the edge opposite corner i. xy2bc projects onto the
# corner-to-midpoint directions built from these.
midpoints = [(corners[(i + 1) % 3] + corners[(i + 2) % 3]) / 2.0 \
             for i in range(3)]
def xy2bc(xy, tol=1.e-3):
    '''Map a 2D Cartesian point to barycentric coordinates on the triangle.

    For every corner, the offset of `xy` from the opposite edge midpoint is
    projected onto the midpoint-to-corner direction and scaled by 1/0.75.
    Each coordinate is then clipped into [tol, 1 - tol].
    '''
    weights = []
    for vertex, opposite_mid in zip(corners, midpoints):
        projection = (vertex - opposite_mid).dot(xy - opposite_mid)
        weights.append(projection / 0.75)
    return np.clip(weights, tol, 1.0 - tol)

class Dirichlet(object):
    '''Dirichlet density with concentration parameters `alpha`.'''

    def __init__(self, alpha):
        '''Store `alpha` and the log of the normalising constant.'''
        self._alpha = np.array(alpha)
        # log B(alpha) = sum(log Gamma(alpha_k)) - log Gamma(sum(alpha_k))
        log_gamma_sum = np.sum(gammaln(alpha))
        log_gamma_total = gammaln(np.sum(alpha))
        self._coef = log_gamma_sum - log_gamma_total

    def pdf(self, x):
        '''Return the density at `x`.

        `x` holds the components along axis 0: either a single point of
        shape (K,) or a (K, n) array with one point per column. The
        transposes let `alpha - 1` broadcast along the component axis.
        '''
        # xlogy gives 0 for 0 * log(0), so alpha_k == 1 is safe at x_k == 0.
        weighted_logs = xlogy(self._alpha - 1, x.T).T
        log_kernel = np.sum(weighted_logs, 0)
        return np.exp(-self._coef + log_kernel)
    
def draw_pdf_contours(dist, nlevels=200, subdiv=5, **kwargs):
    '''Draw filled contours of `dist.pdf` over the simplex triangle and show them.

    Parameters
    ----------
    dist : object with a `pdf(x)` method taking barycentric coordinates.
    nlevels : number of contour levels passed to `plt.tricontourf`.
    subdiv : refinement depth passed to `refine_triangulation`.
    **kwargs : forwarded unchanged to `plt.tricontourf`.
    '''
    refiner = tri.UniformTriRefiner(triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    # One density evaluation per mesh vertex. The barycentric coordinates are
    # only needed for this, so they are not kept in a separate list.
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]

    plt.tricontourf(trimesh, pvals, nlevels, **kwargs)
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')
    plt.show()
    
#draw_pdf_contours(Dirichlet([2, 2, 2]))

In [3]:
# Shared state and figure for the Dirichlet-multinomial animation.
# Seed NumPy's global RNG; multinomial.rvs below is called without its own
# random_state (presumably so the sampled draws are repeatable -- confirm).
np.random.seed(1)

# Category probabilities of the multinomial the samples are drawn from.
multi_p = [0.6, 0.25, 0.15] 

# Dirichlet concentration parameters: the fixed prior and a running copy
# that posterior_func updates through `global a_current`.
a_prior = [2,2,2]
a_current = copy.copy(a_prior)

# Bar / line positions, one per category.
x_multi = np.arange(0,3)

# Left panel: Dirichlet density on the simplex. Right panel: multinomial bars.
fig, ax = plt.subplots(1, 2, figsize=(15, 5))

# NOTE(review): these rc settings are applied after the axes were created.
plt.rc('xtick', labelsize=16)
plt.rc('ytick', labelsize=16)
plt.rc('axes', labelsize=16)
legend_size=16
# title_size is not referenced in the visible cells.
title_size=16
# Closes the current pyplot figure; `fig` and `ax` are still used by the
# drawing functions below (presumably to avoid a static copy in the cell
# output -- confirm).
plt.close()

def _draw_panels(alpha, sample=None, sample_label=None, nlevels=200, subdiv=5):
    '''Redraw both animation panels on the shared `fig` / `ax`.

    Parameters
    ----------
    alpha : Dirichlet concentration parameters shown in the left panel.
    sample : optional multinomial draw; when given it is overlaid on the
        right panel as black vertical lines.
    sample_label : legend label for `sample`.
    nlevels : number of contour levels for the density plot.
    subdiv : refinement depth of the triangular mesh.
    '''
    ax[0].cla()
    ax[1].cla()

    # Left panel: Dirichlet density evaluated at every vertex of the mesh.
    dist = Dirichlet(alpha)
    refiner = tri.UniformTriRefiner(triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]

    ax[0].tricontourf(trimesh, pvals, nlevels)
    ax[0].axis('equal')
    ax[0].set_xlim(0, 1)
    ax[0].set_ylim(0, 0.75**0.5)
    ax[0].axis('off')
    ax[0].text(0.75, 0.75, r'$\alpha={}$'.format(alpha), fontsize=20)

    # Right panel: the generating multinomial and, if present, the new draw.
    ax[1].bar(x_multi, multi_p, label=r'Multinomial with $p={}$'.format(multi_p),
              color='blue')
    if sample is not None:
        ax[1].vlines(x_multi, 0, sample, color='black', label=sample_label,
                     lw=5)
    # Each draw has 3 trials, so a count never exceeds 3.
    ax[1].set_ylim(0, 3.1)
    ax[1].legend(loc='upper right', prop={'size': legend_size})

    # `fig` was closed in pyplot, so plt.tight_layout() would act on a new,
    # empty current figure instead; lay out the animated figure directly.
    fig.tight_layout()


def prior_func():
    '''Draw the prior frame: Dirichlet(a_prior) and the multinomial bars.

    Also resets the running posterior parameters `a_current` to the prior,
    so that replaying the animation from frame 0 starts from the prior again.
    '''
    global a_current

    a_current = list(a_prior)
    _draw_panels(a_prior)


def posterior_func(i):
    '''Draw one multinomial sample, update the posterior and redraw.

    `i` is the animation frame index. Frame 0 is the prior (see
    animate_func), so frame `i` shows the i-th datapoint.
    '''
    global a_current

    # One observation = counts from 3 multinomial trials.
    multi_sample = multinomial.rvs(3, multi_p)
    # Conjugate update: add the observed counts to the concentration
    # parameters. tolist() yields plain Python ints, so the alpha label
    # prints as e.g. [4, 3, 2] whatever NumPy's scalar repr is.
    a_current = (np.asarray(a_current) + multi_sample).tolist()

    _draw_panels(a_current, sample=multi_sample,
                 sample_label='Datapoint {}'.format(i))

In [4]:
def animate_func(i):
    '''Frame callback: frame 0 renders the prior, later frames a posterior update.'''
    if i != 0:
        return posterior_func(i)
    return prior_func()

In [5]:
from matplotlib import animation
from IPython.display import HTML

total_frames=20

# posterior_func accumulates into the global `a_current`. Reset it to the
# prior before rendering so that re-running this cell does not start from
# the posterior left over by a previous render.
a_current = copy.copy(a_prior)

# Animation setup
anim = animation.FuncAnimation(
    fig, func=animate_func, frames=total_frames, interval=1000, blit=False
)
anim.save('Dirichlet_Multinomial.gif', dpi=300)

# Saving ran every frame and advanced `a_current`; start the HTML render
# from the prior again.
a_current = copy.copy(a_prior)
HTML(anim.to_jshtml())

<Figure size 432x288 with 0 Axes>