In [14]:
import numpy as np
import pymc3 as pm

import aesara.tensor as at
from aesara.tensor.random.op import RandomVariable

from scipy import stats as st

import aesara

import arviz as az

%config InlineBackend.figure_format = "retina"
%matplotlib inline

In [2]:
# Number of stick-breaking fractions (K) and of synthetic observations (N).
K, N = 50, 50
# M is forwarded to DirichletProcess / create_dp_samples (meaning defined there —
# presumably a DP concentration; confirm). mu is the true mean of the synthetic data.
M, mu = 2, 2.0
# Seeded legacy generator so the synthetic data (and later draws that share it) are reproducible.
rng = np.random.RandomState(seed=34)
# Synthetic observations: N draws from Normal(mu, 2).
y = rng.normal(loc=mu, scale=2.0, size=(N,))

In [3]:
def stick_breaking(betas):
    """
    Turn stick-breaking fractions into mixture weights.

    betas ~ Beta(1, alpha)

    For K fractions b_0..b_{K-1}, weight k is b_k * prod_{j<k} (1 - b_j).
    The unbroken remainder, 1 - sum(weights), is appended as a final entry,
    so the result has K + 1 entries that sum to one.

    Parameters
    ----------
    betas : 1-d tensor of length K
        Break fractions (presumably in [0, 1], e.g. Beta draws).

    Returns
    -------
    1-d tensor of length K + 1.
    """
    # Stick length left before each break: [1, (1-b_0), (1-b_0)(1-b_1), ...].
    # Both pieces must go in ONE list: the signature is
    # `at.concatenate(tensor_list, axis=0)`. Passing them as two positional
    # arguments makes the cumprod the `axis`, and the single-element list is
    # returned unchanged. The weights would then silently collapse to `betas`.
    remaining = at.concatenate(
        [
            at.ones_like(betas[:1]),
            at.cumprod(1 - betas[:-1]),
        ]
    )

    product = at.mul(betas, remaining)

    # The leftover stick becomes the (K+1)-th weight.
    return at.concatenate(
        [
            product,
            [1 - at.sum(product)],
        ]
    )

In [61]:
with pm.Model() as model:
    # Concentration parameter of the Dirichlet process.
    concentration = pm.InverseGamma("concentration", alpha=0.1, beta=0.1)

    # K stick-breaking fractions -> K + 1 mixture weights.
    betas = pm.Beta("betas", 1., concentration, shape=(K,))
    weights = pm.Deterministic("weights", stick_breaking(betas))

    # One base-measure atom per weight (K + 1 of them).
    G0 = pm.Normal(name="G0", mu=0., sigma=3., shape=(K + 1,))

    # Indicator with p = concentration / (concentration + n_obs).
    # NOTE(review): presumably the Polya-urn probability of drawing a fresh atom
    # from G0 rather than reusing an observed one — confirm.
    # Plain division is used rather than `at.inv`, which later Aesara releases
    # renamed to `at.reciprocal`.
    idx = pm.Bernoulli(
        name="post-mixture-idx",
        p=concentration / (concentration + len(y)),
    )

    # TODO: map `idx` to posterior atoms (`y` vs. `G0`) and expose them as Deterministics.

    prior = pm.sample_prior_predictive(samples=1000)

In [66]:
# Average of all prior draws of the stick-breaking fractions.
np.mean(prior["betas"])

0.1774097049859114

In [None]:
# Alias the PyMC Normal distribution class so its RandomVariable op
# (`norm.rv_op`) can be inspected directly.
norm = pm.Normal

with pm.Model():
    # Throwaway model holding a single Normal(0, 1) variable.
    norm_obj = pm.Normal("norm", mu=0., sigma=1.)

In [None]:
# Call the op's NumPy-level sampler directly, outside any model. No loc/scale are
# passed, so NumPy's defaults presumably apply (confirm against the op's rng_fn).
# NOTE(review): this consumes draws from the shared `rng`, so later cells that use
# it (e.g. DirichletProcess) see a different stream depending on whether this ran.
sample_shape = [4, 2, 3]
norm.rv_op.rng_fn(rng, size=sample_shape)

In [None]:
with pm.Model() as model:
    # Observed data as a constant tensor so it can be indexed by a random variable.
    y_tensor = at.as_tensor_variable(y)

    # Categorical over 4 categories; only y[0]..y[3] can ever be selected.
    idx = pm.Categorical(name="cat", p=np.array([0.1, 0.2, 0.3, 0.4]))
    y_atoms = pm.Deterministic(name="y_atoms", var=y_tensor[idx])

    prior = pm.sample_prior_predictive(samples=1000)

In [None]:
# Distinct observations picked by the categorical index (expected: at most 4).
y_atom_draws = prior["y_atoms"]
np.unique(y_atom_draws)

In [None]:
class DirichletProcess:
    """Fit DP hyperparameters to the empirical structure of a set of samples.

    On construction, ``self.samples`` is produced by ``create_dp_samples(M, rng, K)``.
    ``run`` converts the empirical atom frequencies of those samples back into
    stick-breaking fractions via ``stick_glueing`` and fits a small PyMC model
    for ``α`` (the Beta(1, α) concentration) and ``µ`` (the mean of the atoms).

    NOTE(review): ``create_dp_samples`` and ``stick_glueing`` are not defined in
    any visible cell; on a fresh kernel this class raises NameError unless they
    are defined elsewhere.

    Parameters
    ----------
    M :
        Forwarded to ``create_dp_samples`` (presumably the generating
        concentration — confirm).
    rng : numpy.random.RandomState
        Passed to ``create_dp_samples`` and used to seed the PyMC model in ``run``.
    K :
        Forwarded to ``create_dp_samples`` (presumably a truncation level — confirm).
    """

    def __init__(self, M, rng, K):
        self.M = M
        self.K = K

        self.samples = create_dp_samples(M, rng, K)
        self.rng = rng

    @staticmethod
    def _sorted_empirical_weights(samples):
        """Return unique atoms and their empirical weights, most frequent first."""
        atoms, counts = np.unique(samples, return_counts=True)
        # Ascending argsort reversed -> descending by count.
        order = np.argsort(counts)[::-1]
        counts = counts[order]
        return atoms[order], counts / counts.sum()

    def run(self, draws=2000, chains=1):
        """Fit ``α`` and ``µ`` and store the result of ``pm.sample`` in ``self.posterior``.

        Parameters
        ----------
        draws, chains : int
            Forwarded to ``pm.sample``. The defaults are the original hard-coded values.
        """
        # stick-glueing assumes decreasing weights
        atoms, weights = self._sorted_empirical_weights(self.samples)

        # Presumably one beta per break, with the last weight implied by
        # 1 - sum(weights) — confirm against stick_glueing.
        recovered_betas = stick_glueing(weights)

        # Seed from this instance's generator rather than the notebook-global `rng`.
        with pm.Model(rng_seeder=self.rng):
            α = pm.Uniform("α", 0., 10.)

            β = pm.Beta("β", 1., α, observed=recovered_betas)
            µ = pm.Normal("µ", mu=0., sigma=5.)

            G0 = pm.Normal("G0", mu=µ, sigma=3., observed=atoms)

            self.posterior = pm.sample(draws=draws, chains=chains)

In [None]:
# Generate DP samples with the shared generator, then fit α / µ to them.
dp = DirichletProcess(M=M, rng=rng, K=K)
dp.run()

In [None]:
# Trace plots of the fitted posterior. The trailing `;` hides the axes repr.
# Assigning to `_` would stop IPython from updating its `_`/`__`/`___` output cache.
pm.plot_trace(dp.posterior);

In [None]:
# Posterior mean of the recovered concentration α.
alpha_draws = dp.posterior.to_dict()["posterior"]["α"]
alpha_draws.mean()

In [None]:
# Raw samples produced by create_dp_samples in DirichletProcess.__init__.
dp_samples = dp.samples
dp_samples