# CRP MCMC inference
Adapted from Tamara Broderick's code on https://github.com/tbroderick/bnp_tutorial

Extended to infer the concentration parameter $\alpha$ with the algorithm of West (1992), using a Gamma(2, 4) prior on $\alpha$.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import asyncio
from crp_mcmc import GaussianMixture, CRPGibbs, run_mcmc
%matplotlib widget

## Define a Gaussian mixture to generate clustered data

In [2]:
# Four-component Gaussian mixture; the first argument is the common sd (= 1)
# and the list gives the mixing frequencies, which sum to 1.
mixture_weights = [0.3, 0.4, 0.2, 0.1]
gm = GaussianMixture(1, mixture_weights)
data = gm.rvs(1000)  # draw 1000 data points

## Run the fast Numba implementation of the MCMC Gibbs sampler

### $\alpha$ constant

In [3]:
%%time
z, probs, alphas = run_mcmc(data, alpha=0.01, max_iter=1000)

Wall time: 56.9 s


In [4]:
colors = ["blue", "orange", "green", "yellow", "red", "purple", "black"]
# Compact the sampler's labels to 0..K-1 so any gaps in the label numbering
# cannot index past the end of the palette.
labels = np.asarray(z).astype(int).ravel()
_, cluster_idx = np.unique(labels, return_inverse=True)
# Cycle the palette rather than raising IndexError when there are more clusters
# than colours (colours are then reused for different clusters).
point_colors = [colors[k % len(colors)] for k in cluster_idx]
points = np.asarray(data)
fig, ax = plt.subplots(figsize=(12, 5))
# A single vectorised scatter call instead of one artist per data point.
ax.scatter(points[:, 0], points[:, 1], c=point_colors)
ax.set(
    title=r"Cluster assignments ($\alpha = 0.01$, 4-component data)",
    xlabel="dimension 0",
    ylabel="dimension 1",
);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### $\alpha$ inferred

In [5]:
%%time
z, probs, alphas = run_mcmc(data, max_iter=1000)

Wall time: 1min 1s


In [6]:
# Monte Carlo estimate of the mean of alpha after discarding the first 500
# entries as burn-in (assumes one alpha sample per iteration — TODO confirm).
burn_in = 500
alphas[burn_in:].mean()

0.5995496881493255

In [7]:
colors = ["blue", "orange", "green", "yellow", "red", "purple", "black", "teal"]
# Compact the sampler's labels to 0..K-1 so any gaps in the label numbering
# cannot index past the end of the palette.
labels = np.asarray(z).astype(int).ravel()
_, cluster_idx = np.unique(labels, return_inverse=True)
# Cycle the palette rather than raising IndexError when there are more clusters
# than colours (colours are then reused for different clusters).
point_colors = [colors[k % len(colors)] for k in cluster_idx]
points = np.asarray(data)
fig, ax = plt.subplots(figsize=(12, 5))
# A single vectorised scatter call instead of one artist per data point.
ax.scatter(points[:, 0], points[:, 1], c=point_colors)
ax.set(
    title=r"Cluster assignments ($\alpha$ inferred, 4-component data)",
    xlabel="dimension 0",
    ylabel="dimension 1",
);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Nine clusters on a 3 × 3 grid

In [8]:
# Nine components centred on a 3 x 3 grid with spacing 6 (same first argument,
# sd = 1, as the first mixture).
grid_means = np.array([[6, 6], [6, 0], [6, -6], [0, -6], [0, 0], [0, 6], [-6, 6], [-6, 0], [-6, -6]])
n_components = len(grid_means)
# Equal mixing frequencies must sum to 1; nine weights of 0.1 would only sum to 0.9.
mixture_weights = [1 / n_components] * n_components
assert np.isclose(sum(mixture_weights), 1.0)
gm = GaussianMixture(1, mixture_weights, mu=grid_means)
data = gm.rvs(1000)  # create 1000 data points

### $\alpha$ constant

In [9]:
%%time
z, probs, alphas = run_mcmc(data, alpha=0.01, max_iter=1000)

Wall time: 1min 28s


In [10]:
colors = ["blue", "orange", "green", "yellow", "red", "purple", "black", "teal", "magenta", "brown", "lime"]
# Compact the sampler's labels to 0..K-1 so any gaps in the label numbering
# cannot index past the end of the palette.
labels = np.asarray(z).astype(int).ravel()
_, cluster_idx = np.unique(labels, return_inverse=True)
# Cycle the palette rather than raising IndexError when there are more clusters
# than colours (colours are then reused for different clusters).
point_colors = [colors[k % len(colors)] for k in cluster_idx]
points = np.asarray(data)
fig, ax = plt.subplots(figsize=(12, 5))
# A single vectorised scatter call instead of one artist per data point.
ax.scatter(points[:, 0], points[:, 1], c=point_colors)
ax.set(
    title=r"Cluster assignments ($\alpha = 0.01$, 9-component data)",
    xlabel="dimension 0",
    ylabel="dimension 1",
);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### $\alpha$ inferred with 9 clusters

In [11]:
%%time
z, probs, alphas = run_mcmc(data, max_iter=1000)

Wall time: 1min 45s


In [12]:
colors = ["blue", "orange", "green", "yellow", "red", "purple", "black", "teal", "magenta", "brown", "lime"]
# Compact the sampler's labels to 0..K-1 so any gaps in the label numbering
# cannot index past the end of the palette.
labels = np.asarray(z).astype(int).ravel()
_, cluster_idx = np.unique(labels, return_inverse=True)
# Cycle the palette rather than raising IndexError when there are more clusters
# than colours (colours are then reused for different clusters).
point_colors = [colors[k % len(colors)] for k in cluster_idx]
points = np.asarray(data)
fig, ax = plt.subplots(figsize=(12, 5))
# A single vectorised scatter call instead of one artist per data point.
ax.scatter(points[:, 0], points[:, 1], c=point_colors)
ax.set(
    title=r"Cluster assignments ($\alpha$ inferred, 9-component data)",
    xlabel="dimension 0",
    ylabel="dimension 1",
);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Live animation
Watch the clustering process update live on a smaller data set of 100 points.

In [13]:
# Separate 100-point data set for the live animation. Distinct names keep the
# global `data`/`gm` (which pair with `z` from the runs above) intact, so
# re-running an earlier plotting cell does not mix 100 points with 1000 labels.
live_gm = GaussianMixture(1, [0.4, 0.3, 0.2, 0.1])
live_data = live_gm.rvs(100)
# NOTE(review): meaning of the second CRPGibbs argument (1) is not visible here — confirm.
gb = CRPGibbs(live_data, 1)

In [14]:
fig = plt.figure()
loop = asyncio.get_event_loop()
# Keep a reference to the task: the event loop only holds weak references to
# tasks, so an unreferenced task can be garbage-collected mid-run.
# Call `live_task.cancel()` to stop the animation early.
# NOTE(review): positional args (0.01, 100, True) are presumably alpha, the
# iteration count and a flag — confirm against CRPGibbs.run_live.
live_task = loop.create_task(gb.run_live(0.01, 100, True, fig=fig))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …