In [None]:
# Reload imported modules automatically before each cell runs, so edits to local code are picked up.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook, using the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

## Introduction

In this notebook, we are going to walk through modelling situations
that involve some form of "cluster identification",
where the number of clusters isn't exactly known beforehand.

## Chinese Restaurant Process

The first customer always chooses the first table.

Each subsequent customer (the $n$th, for $n \geq 2$) sits at the first _unoccupied_ table
with probability $\frac{\alpha}{n-1+\alpha}$,
or joins a given _already occupied_ table
with probability $\frac{c}{n-1+\alpha}$.

Here:

- $n$ is the index of the arriving customer, counting the first customer as $n = 1$.
- $c$ is the number of people already sitting at that table.
- $\alpha$ is a parameter of the Chinese Restaurant Process.

## Let's simulate this!

In [None]:
import jax.numpy as np
from jax import jit
from jax.ops import index_update, index



def create_alpha_vector(alpha, table_assignments, n, current_open_table):
    """Probability mass the ``n``-th customer places on the first unoccupied table.

    :param alpha: concentration parameter of the Chinese Restaurant Process.
    :param table_assignments: array of per-table customer counts; only its
        shape and dtype are used here.
    :param n: index of the arriving customer.
    :param current_open_table: index of the first unoccupied table.
    :returns: array shaped like ``table_assignments`` that is zero everywhere
        except at ``current_open_table``, where it holds
        ``alpha / (n - 1 + alpha)``.
    """
    # `jax.ops.index_update` has been removed from JAX; `.at[...].set(...)` is
    # the supported functional update and returns a new array.
    v = np.zeros_like(table_assignments).at[current_open_table].set(alpha)
    return v / (n - 1 + alpha)

In [None]:
# One customer seated at table 0; the other nine tables are empty.
# (`.at[...].set(...)` is used because `jax.ops.index_update` has been removed from JAX.)
table_assignments = np.zeros(shape=(10,)).at[0].set(1)

# Mass that customer n=2 puts on the first open table (index 1) when alpha=5.
create_alpha_vector(5, table_assignments, 2, 1)

In [None]:

def create_occupied_vector(alpha, table_assignments, n, current_open_table):
    """Probability mass the ``n``-th customer places on each occupied table.

    Every per-table count ``c`` in ``table_assignments`` is mapped to
    ``c / (n - 1 + alpha)``. ``current_open_table`` is accepted but not used.
    """
    normaliser = n - 1 + alpha
    return table_assignments / normaliser

In [None]:
# Mass that customer n=2 places on each already-occupied table when alpha=5.
create_occupied_vector(alpha=5, table_assignments=table_assignments, n=2, current_open_table=1)

In [None]:
# Reset to a single customer at table 0 (functional update via `.at[...].set(...)`,
# since `jax.ops.index_update` has been removed from JAX).
table_assignments = np.zeros(shape=(10,)).at[0].set(1)

# Index of the first table that nobody is sitting at yet.
np.min(np.where(table_assignments == 0)[0])

In [None]:
from jax.scipy.special import logit
from jax import vmap, lax, jit
from jax.random import categorical, PRNGKey, split

# Create the PRNG key here so this cell does not rely on a later cell having run.
k = PRNGKey(42)

p = np.array([0.1, 0.8, 0.1])
# `categorical` expects (unnormalised) log-probabilities, i.e. log(p); the
# log-odds log(p / (1 - p)) would sample proportionally to p / (1 - p) instead.
logits = np.log(p)
categorical(k, logits)

In [None]:
k = PRNGKey(42)


def one_draw(k, p, zeros):
    """Draw one category from ``p`` and return it one-hot encoded.

    :param k: JAX PRNG key used for this single draw.
    :param p: 1-D array of category probabilities (entries may be zero).
    :param zeros: array of zeros with one slot per category.
    :returns: copy of ``zeros`` with a 1 at the sampled index.
    """
    # `categorical` takes log-probabilities; logit(p) = log(p / (1 - p)) would
    # sample proportionally to p / (1 - p) rather than p.
    logits = np.log(p)
    idx = categorical(k, logits)
    # Functional update; `jax.ops.index_update` has been removed from JAX.
    return np.asarray(zeros).at[idx].set(1)


def f(carry, x):
    """``lax.scan`` body: emit one one-hot draw and advance the PRNG key.

    :param carry: tuple ``(k, p)`` of the current key and the probabilities.
    :param x: row of zeros that receives the one-hot draw.
    :returns: ``((next_key, p), draw)``.
    """
    k, p = carry
    # Split first so the key used for sampling is never also used for splitting.
    k, subkey = split(k)
    draw = one_draw(subkey, p, x)
    return (k, p), draw


def multinomial(k, n, p):
    """Draw ``n`` categorical samples from ``p`` and return per-category counts.

    :param k: JAX PRNG key.
    :param n: number of draws (a Python int; it fixes the scan length).
    :param p: 1-D array of category probabilities.
    :returns: array of length ``len(p)`` whose entries sum to ``n``.
    """
    # One row of zeros per draw, so the number of draws follows ``n``
    # instead of a fixed constant.
    a = np.zeros(shape=(n, len(p)))
    (k, p), draws = lax.scan(f, (k, p), a)
    return np.sum(draws, axis=0)

# Tabulate draws from a three-category distribution with probabilities (0.3, 0.3, 0.4).
multinomial(k, 1000, np.array([0.3, 0.3, 0.4]))

In [None]:
# Keep the per-category counts in a variable, then display them.
draws = multinomial(k, 1000, np.array([0.3, 0.3, 0.4]))
draws

In [None]:
import numpy as onp
from tqdm import tqdm

alpha = 3
# Table 0 starts with the first customer; allow up to alpha * 10 tables.
# (`.at[...].set(...)` is used because `jax.ops.index_update` has been removed from JAX.)
table_assignments = np.zeros(shape=(alpha * 10,)).at[0].set(1)

current_open_table = np.min(np.where(table_assignments == 0)[0])
n_customers = 1000
for n in tqdm(range(2, n_customers+1)):
    # Seating probabilities for customer n: mass on the first open table plus
    # mass on each occupied table.
    prob_vect = (
        create_alpha_vector(alpha, table_assignments, n, current_open_table)
        + create_occupied_vector(alpha, table_assignments, n, current_open_table)
    )
    # Draw with a fresh subkey so no key is used both to sample and to split.
    k, subkey = split(k)
    assignment_vect = one_draw(subkey, prob_vect, np.zeros_like(prob_vect))
    # Both operands are 1-D with one slot per table, so no squeeze is needed.
    table_assignments = table_assignments + assignment_vect
    # NOTE(review): presumably fails if every table becomes occupied, since
    # there is then no zero entry to take the minimum over — confirm the
    # table budget is large enough for the chosen alpha and n_customers.
    current_open_table = np.min(np.where(table_assignments == 0)[0])
prob_vect

## Stick-breaking process

In [None]:
import jax
import numpy as onp

# Taken from https://stats.stackexchange.com/questions/396315/coding-a-simple-stick-breaking-process-in-python

def stick_breaking(k, num_weights, alpha):
    """Sample truncated stick-breaking weights with NumPy.

    :param k: PRNG key given as an array of unsigned integers (for example the
        result of ``PRNGKey``). It seeds NumPy's generator, so the same key
        always yields the same weights.
    :param num_weights: number of weights (stick pieces) to generate.
    :param alpha: second shape parameter of the ``Beta(1, alpha)`` draws.
    :returns: NumPy array of ``num_weights`` non-negative weights whose sum is
        at most 1.
    """
    # Seed from the key instead of drawing from NumPy's global, unseeded RNG,
    # so that the key argument actually controls the result.
    rng = onp.random.default_rng(onp.asarray(k))
    betas = rng.beta(1, alpha, size=(num_weights,))
    # Scale each break by the fraction of the stick still left before it.
    betas[1:] *= onp.cumprod(1 - betas[:-1])
    return betas


# `max_num_classes` is only assigned in a later cell, so set it before first use here.
max_num_classes = 30
stick_breaking(k, num_weights=max_num_classes, alpha=3).sum()


In [None]:
# `betas` only exists inside `stick_breaking`, so draw the weights here before summing them.
betas = stick_breaking(k, num_weights=30, alpha=3)
betas.sum()

In [None]:
import matplotlib.pyplot as plt

# Number of weights requested from the stick-breaking sampler below.
max_num_classes = 30
# Advance the notebook's PRNG key before drawing the weights.
k, _ = split(k)

def stick_breaking_jax(k, num_weights, alpha):
    """Sample truncated stick-breaking weights with JAX.

    :param k: JAX PRNG key; a key derived from it is used for the Beta draws.
    :param num_weights: number of weights (stick pieces) to generate.
    :param alpha: second shape parameter of the ``Beta(1, alpha)`` draws.
    :returns: array of ``num_weights`` weights; the first is the raw Beta
        draw and each later one is its Beta draw times the product of
        ``1 - beta`` over all earlier draws.
    """
    k, _ = split(k)
    betas = jax.random.beta(k, a=1, b=alpha, shape=(num_weights,))
    # Fraction of the stick still left before each break after the first.
    remaining = np.cumprod(1 - betas[:-1])
    # Functional update; `jax.ops.index_update` has been removed from JAX.
    return betas.at[1:].set(remaining * betas[1:])


# Draw one set of weights with alpha = 2 and plot each weight against its index.
weights = stick_breaking_jax(k, max_num_classes, 2)
plt.plot(weights)

The $\alpha$ parameter controls how many components end up with appreciable weight: larger values of $\alpha$ spread the weight over more components.

## Generate a Gaussian mixture from weights

In [None]:
from jax.random import categorical
from jax.scipy.special import logit

k, _ = split(k)
n_observations = 300
# `categorical` expects log-probabilities, so pass log(weights); logit(weights)
# would sample proportionally to w / (1 - w) and over-favour the heavier components.
indices = categorical(k, np.log(weights), shape=(n_observations,))
indices

$X \sim N(\mu, \sigma)$ is equivalent to:

$$ \hat{X} \sim N(0, 1) $$
$$ X = \sigma\hat{X} + \mu$$

In [None]:
from jax.random import normal

mus = np.linspace(0, 350, num=max_num_classes)
sigmas = np.ones(shape=(max_num_classes,)) * 2
# `k` was already consumed when sampling `indices`; draw the noise with a
# fresh subkey so it is independent of the component assignments.
k, noise_key = split(k)
mus[indices] + sigmas[indices] * normal(noise_key, shape=(n_observations,))

In [None]:
def dp_mixture_gaussian(k, alpha, max_num_classes, num_observations, mus, sigmas):
    """Sample from a truncated Dirichlet-process mixture of 1-D Gaussians.

    :param k: JAX PRNG key; it is split into independent keys for the weights,
        the component assignments and the Gaussian noise.
    :param alpha: concentration parameter passed to ``stick_breaking_jax``.
    :param max_num_classes: number of mixture weights to generate.
    :param num_observations: number of samples to return.
    :param mus: array of component means, indexed by component.
    :param sigmas: array of component standard deviations, indexed by component.
    :returns: array of ``num_observations`` draws.

    ``mus`` and ``sigmas`` should have one entry per component
    (``max_num_classes``): JAX clamps out-of-range indices instead of raising,
    so shorter arrays would silently reuse their last entry.
    """
    # Three samplers need three distinct keys; sharing one key would correlate them.
    weights_key, indices_key, noise_key = split(k, 3)
    weights = stick_breaking_jax(weights_key, num_weights=max_num_classes, alpha=alpha)
    # `categorical` expects log-probabilities, not log-odds.
    indices = categorical(indices_key, np.log(weights), shape=(num_observations,))
    # Use the `num_observations` argument rather than the notebook-level `n_observations`.
    noise = normal(noise_key, shape=(num_observations,))
    return mus[indices] + sigmas[indices] * noise

# Use one component count for both the parameters and the sampler: `mus` and
# `sigmas` need an entry for every component index the sampler can return.
n_components = 45
mus = np.linspace(0, 350, num=n_components)
sigmas = np.ones(shape=(n_components,))

draws = dp_mixture_gaussian(k, alpha=0.7, max_num_classes=n_components, num_observations=100, mus=mus, sigmas=sigmas)

In [None]:
# Histogram of the mixture draws.
plt.hist(x=draws)

In [None]:
def ecdf_scatter(data):
    """Scatter-plot the empirical CDF of ``data`` and show the figure.

    The sorted values go on the x-axis and the cumulative fractions
    ``1/len(data), 2/len(data), ..., 1`` on the y-axis.
    """
    n_points = len(data)
    sorted_values = np.sort(data)
    cumulative_fraction = np.arange(1, n_points + 1) / n_points
    plt.scatter(sorted_values, cumulative_fraction)
    plt.show()
    
# Empirical CDF of the mixture draws.
ecdf_scatter(data=draws)

## Generate Multiple MvNormals

In [None]:
from jax import random as npr

# Advance the key, then draw 30 samples from a single 2-D Gaussian with
# correlated coordinates, and report the resulting array shape.
k, _ = split(k)
draws = npr.multivariate_normal(
    k,
    mean=np.array([1, 3]),
    cov=np.array([[1, 0.8], [0.8, 1]]),
    shape=(30,),
)
draws.shape

In [None]:
from sklearn.datasets import make_spd_matrix
max_num_classes = 45
num_states = 2
# One mean vector of length `num_states` per component, evenly spaced over [0, 1000].
means = np.linspace(0, 1000, max_num_classes * num_states).reshape(max_num_classes, num_states)
# Seed each covariance draw so re-running the notebook produces the same matrices.
cov = np.stack([make_spd_matrix(num_states, random_state=i) for i in range(max_num_classes)])

In [None]:
alpha = 1
n_observations = 300
# Separate keys for the weights and for the component assignments, plus a new running key.
k, weights_key, indices_key = split(k, 3)
weights = stick_breaking_jax(weights_key, num_weights=max_num_classes, alpha=alpha)
# `categorical` expects log-probabilities, so pass log(weights) rather than logit(weights).
indices = categorical(indices_key, np.log(weights), shape=(n_observations,))
indices

In [None]:
from functools import partial

def generate_mvnorm_func(means, covs):
    """Build a sampler that draws from one indexed multivariate normal.

    :param means: array of mean vectors, indexed by component along axis 0.
    :param covs: array of covariance matrices, indexed the same way.
    :returns: function ``mvnorm(key, idx)`` that returns one draw from the
        multivariate normal with mean ``means[idx]`` and covariance
        ``covs[idx]``.
    """
    def mvnorm(key, idx):
        k, _ = split(key)
        # Index the `covs` argument (not a notebook-level `cov`) so the
        # covariances passed to this factory are the ones actually used.
        return npr.multivariate_normal(k, mean=means[idx], cov=covs[idx])
    return mvnorm
k = PRNGKey(42)

# One key per observation from a single vectorised split (instead of a Python
# loop of pairwise splits); the first key is kept as the notebook's running key.
ks = split(k, n_observations + 1)
k, ks = ks[0], ks[1:]
print(ks.shape)

mvnorm = generate_mvnorm_func(means, cov)
draws = vmap(mvnorm)(ks, indices)


Generate some data now.

In [None]:
from jax.random import multivariate_normal
mus = np.array(
    [
        [-10, 3],
        [5, 5],
        [10, 1],
    ]
)
# plt.scatter(mus[:, 0], mus[:, 1])

# Seeded so the three covariance matrices are the same on every run.
covs = np.stack([make_spd_matrix(n_dim=2, random_state=i) for i in range(3)])

# 150 per-observation keys from one vectorised split (instead of a Python loop
# of pairwise splits); the first key is kept as the notebook's running key.
ks = split(k, 151)
k, ks = ks[0], ks[1:]

# 50 observations from each of the three components.
indices = np.array([0] * 50 + [1] * 50 + [2] * 50)

mvnorm = generate_mvnorm_func(mus, covs)
draws = vmap(mvnorm)(ks, indices)
draws.shape

In [None]:
# Scatter the first coordinate of each draw against the second.
plt.scatter(x=draws[:, 0], y=draws[:, 1])

In [None]:
import pymc3 as pm

In [None]:
# Draft of a mixture model with 2-D Gaussian components; this cell is unfinished.
with pm.Model() as model:
    # FIXME(review): `max_num_components` is not defined anywhere in the visible
    # notebook (earlier cells use `max_num_classes`) — confirm the intended name.
    # Mus should be Gaussian priors of shape (max_num_components, 2)
    mus = pm.Normal("mus", mu=0, sigma=5, shape=(max_num_components, 2))
    # FIXME(review): only a name is passed here; the other LKJCholeskyCov call in
    # this notebook also supplies `n`, `eta` and `sd_dist`.
    covs = pm.LKJCholeskyCov("covs", )
    
    # MvNormal component distributions that we can index into.
    comp_dists = pm.MvNormal.dist(mu=mus, cov=covs, shape=(max_num_components, 2))

In [None]:
import numpy as onp
# Simulate 10 observations from a 3-D Gaussian with a known covariance.
mu = onp.zeros(3)
true_cov = onp.array([[1.0, 0.5, 0.1],
                     [0.5, 2.0, 0.2],
                     [0.1, 0.2, 1.0]])
data = onp.random.multivariate_normal(mu, true_cov, 10)

with pm.Model() as model:
    # Stand-alone distribution object used as the prior on the standard
    # deviations; it is passed to LKJCholeskyCov rather than named in the model.
    sd_dist = pm.HalfCauchy.dist(beta=2.5, shape=3)
    chol_packed = pm.LKJCholeskyCov('chol_packed',
        n=3, eta=2, sd_dist=sd_dist)
    chol = pm.expand_packed_triangular(3, chol_packed)
    # Register a named, observed variable. `.dist(...)` only builds a
    # stand-alone distribution object, so `data` would never enter the model.
    vals = pm.MvNormal('vals', mu=mu, chol=chol, observed=data)
    
    

In [None]:
# Sample from the model defined in the previous cell, requesting 2000 draws.
with model:
    trace = pm.sample(draws=2000)

In [None]:
import arviz as az


# Trace plots for the sampled variables.
az.plot_trace(data=trace)

In [None]:
# TODO: unfinished cell. A bare `with model` (no colon, no body) is a
# SyntaxError, so it is completed as a no-op until the intended code is added.
with model:
    pass

In [None]:
def stick_breaking(beta):
    """Turn stick-breaking fractions ``beta`` into weights.

    Each entry of ``beta`` is multiplied by the share of the stick remaining
    before it: 1 for the first entry, then the running product of
    ``1 - beta`` over the earlier entries.

    NOTE(review): ``tt`` is not imported anywhere in the visible notebook
    (presumably ``theano.tensor`` — confirm), and this definition reuses the
    name of the NumPy-based ``stick_breaking(k, num_weights, alpha)`` defined
    earlier in the notebook.
    """
    remaining_after_each = tt.extra_ops.cumprod(1 - beta)[:-1]
    portion_remaining = tt.concatenate([[1], remaining_after_each])
    return beta * portion_remaining