## [**Dirichlet Process Mixture Models in Pyro**](http://pyro.ai/examples/dirichlet_process_mixture.html#Dirichlet-Process-Mixture-Models-in-Pyro)

The prototypical example of Bayesian nonparametrics in practice is the Dirichlet Process Mixture Model (DPMM). A DPMM lets a practitioner build a mixture model when the number of distinct clusters in the data is unknown; in other words, the number of clusters is allowed to grow as more data are observed. This makes the DPMM well suited to exploratory data analysis, where little is known about the data in advance; this presentation aims to demonstrate that.

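Formally, a random distribution $G$ over a space $\Omega$ is a Dirichlet process with concentration parameter $\alpha$ and base distribution $G_0$ if, for every finite partition $\Omega_1, \dots, \Omega_k$ of $\Omega$:

$$(G(\Omega_1), \dots, G(\Omega_k)) \sim \text{Dir}(\alpha G_0(\Omega_1), \dots, \alpha G_0(\Omega_k))$$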

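The Chinese Restaurant Process offers an intuitive view of the clustering that a Dirichlet process induces. Imagine a restaurant with infinitely many tables (indexed by the positive integers) that accepts customers one at a time. The $n$th customer chooses their seat according to the following probabilities:

1. With probability $\frac{n_t}{\alpha + n - 1}$, sit at an occupied table $t$, where $n_t$ is the number of customers already seated there
2. With probability $\frac{\alpha}{\alpha + n - 1}$, sit at a new, unoccupied table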
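
The seating process is easy to simulate. The sketch below is plain standard-library Python rather than Pyro, and the function name `sample_crp` and its arguments are hypothetical, chosen only for illustration. It assumes the standard rule in which an occupied table is chosen with weight equal to its current head count and a new table with weight $\alpha$.

```python
import random


def sample_crp(num_customers, alpha, seed=0):
    """Simulate table assignments under a Chinese Restaurant Process.

    Hypothetical helper for illustration; `alpha` must be positive.
    """
    rng = random.Random(seed)
    counts = []       # counts[t] = number of customers seated at table t
    assignments = []  # assignments[n] = table index chosen by customer n
    for _ in range(num_customers):
        # Unnormalized weights: head count for each occupied table,
        # alpha for opening a new table (the last entry).
        weights = counts + [alpha]
        table = rng.choices(range(len(weights)), weights=weights)[0]
        if table == len(counts):
            counts.append(1)   # a new table is opened
        else:
            counts[table] += 1
        assignments.append(table)
    return assignments, counts


assignments, counts = sample_crp(num_customers=200, alpha=1.0)
print(f"{len(counts)} occupied tables, sizes: {counts}")
```

Rerunning with a larger `alpha` tends to open more tables, which is the sense in which $\alpha$ controls how readily new clusters appear.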

## [The Stick-Breaking Method (Sethuraman, 1994)](http://pyro.ai/examples/dirichlet_process_mixture.html#The-Stick-Breaking-Method-(Sethuraman,-1994))

1. $\beta_i \sim \text{Beta}(1, \alpha)$ for $i \in \mathbb{N}$
2. $\theta_i \sim G_0$ for $i \in \mathbb{N}$
3. $\pi_i(\beta_{1:\infty}) = \beta_i \prod_{j<i} (1-\beta_j)$
4. $z_n \sim \pi(\beta_{1:\infty})$ and then $x_n \sim f(\theta_{z_n})$
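
Steps 1 and 3 can be sketched in a few lines. The example below is a minimal illustration in standard-library Python, not the implementation used later with Pyro; the name `stick_breaking_weights` and the `truncation` argument are assumptions introduced here. Because an infinite sequence cannot be stored, the sketch truncates the process at a fixed number of components and assigns whatever stick length is left to the last one.

```python
import random


def stick_breaking_weights(alpha, truncation, seed=0):
    """Draw mixture weights from a truncated stick-breaking construction.

    Hypothetical helper for illustration; `truncation` is the number of
    components kept, and the final component absorbs the leftover mass.
    """
    rng = random.Random(seed)
    weights = []
    remaining = 1.0  # length of stick not yet broken off: prod_{j<i} (1 - beta_j)
    for _ in range(truncation - 1):
        beta = rng.betavariate(1.0, alpha)  # beta_i ~ Beta(1, alpha)
        weights.append(beta * remaining)    # pi_i = beta_i * prod_{j<i} (1 - beta_j)
        remaining *= 1.0 - beta
    weights.append(remaining)  # truncation: last weight takes the rest
    return weights


weights = stick_breaking_weights(alpha=1.0, truncation=10)
print([round(w, 3) for w in weights], sum(weights))
```

A small `alpha` concentrates most of the mass in the first few weights, while a large `alpha` spreads it across many components.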

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.distributions import constraints

import pyro
from pyro.distributions import *
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam

assert pyro.__version__.startswith('1.8.3')
pyro.set_rng_seed(0)

In [5]:
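# Synthetic data: 50 points from each of four 2-D Gaussians with identity covariance.
# Two clusters are well separated (means at -8 and 8); the other two overlap.
data = torch.cat(
    (
        MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]),
        MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]),
        MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]),
        MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50]),
    )
)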

In [7]:
data.size()

torch.Size([200, 2])