In [8]:
import os

import pymc as pm
import numpy as np
import pandas as pd

import arviz as az

import matplotlib.pyplot as plt

from scipy.stats import norm

from pymc.distributions import NoDistribution
from pymc.distributions.distribution import Discrete, NoDistribution
from pymc.model import modelcontext
from pymc import Mixture

from pymc_experimental import dp

from aesara.tensor.random.basic import RandomVariable

In [2]:
# Location of the tab-separated "faithful" data file. Set the FAITHFUL_TSV
# environment variable to point elsewhere instead of editing this cell; the
# original absolute path is kept as the fallback so existing runs are unchanged.
FAITHFUL_TSV = os.environ.get("FAITHFUL_TSV", "/Users/larryshamalama/Downloads/faithful.tsv")
faithful = pd.read_csv(FAITHFUL_TSV, sep="\t")

# Standardize the "waiting" column to zero mean and unit population
# standard deviation (numpy's default ddof=0).
waiting = faithful["waiting"].to_numpy()
data = (waiting - waiting.mean()) / waiting.std()

# An empty column or a zero std would silently produce NaN/inf here.
assert data.size > 0 and np.isfinite(data).all(), "standardized waiting times are empty or non-finite"

In [3]:
class DirichletProcessMixture:
    """Truncated Dirichlet-process mixture built from stick-breaking weights.

    Follows the PyMC distribution calling convention: calling the class builds
    a named mixture variable, while ``DirichletProcessMixture.dist`` builds an
    unnamed one. Both forward everything to ``_dirichlet_process_mixture``;
    ``__new__`` returns that helper's result, not an instance of this class.
    """

    def __new__(cls, name, G0, alpha, K, **kwargs):
        """Build the mixture under ``name``; extra kwargs (e.g. ``observed``) are forwarded."""
        return _dirichlet_process_mixture(name=name, G0=G0, alpha=alpha, K=K, **kwargs)

    @classmethod
    def dist(cls, name, G0, alpha, K, **kwargs):
        """Build the unnamed (``Mixture.dist``) form of the mixture.

        ``name`` is accepted only for signature parity with ``__new__`` and is
        ignored: the helper is always called with ``name=None``.
        """
        return _dirichlet_process_mixture(name=None, G0=G0, alpha=alpha, K=K, **kwargs)


def _dirichlet_process_mixture(*, name, G0, alpha, K, weights_name="sbw", **kwargs):
    """Build a truncated Dirichlet-process mixture from stick-breaking weights.

    Parameters
    ----------
    name : str or None
        Name of the mixture variable. When ``None``, an unnamed
        ``Mixture.dist`` is returned instead of ``Mixture(name, ...)``.
    G0 :
        Component distribution(s), passed to ``Mixture`` as its components.
    alpha :
        Concentration parameter forwarded to ``pm.StickBreakingWeights``.
    K : int
        Truncation level forwarded to ``pm.StickBreakingWeights``.
    weights_name : str, default "sbw"
        Name of the stick-breaking weights variable. The default preserves the
        previous hardcoded name; pass a distinct value when building more than
        one mixture in the same model, since PyMC variable names must be unique.
    **kwargs
        Forwarded unchanged to ``Mixture`` / ``Mixture.dist`` (e.g. ``observed``).

    Returns
    -------
    The result of ``Mixture(name, weights, G0, **kwargs)`` or
    ``Mixture.dist(weights, G0, **kwargs)``.
    """
    # The weights are always created as a *named* variable, including on the
    # ``name is None`` path, so this needs an active PyMC model context.
    weights = pm.StickBreakingWeights(weights_name, alpha, K)

    if name is None:
        return Mixture.dist(weights, G0, **kwargs)
    return Mixture(name, weights, G0, **kwargs)

In [4]:
# Truncation level of the stick-breaking construction. StickBreakingWeights(alpha, K)
# yields K + 1 weights (the sbw posterior below has 12 entries for K = 11), so the
# component parameters carry K + 1 entries as well.
K = 11
# Fixed seed so this (very long) sampling run is reproducible.
RANDOM_SEED = 20220616

with pm.Model() as model:
    # Dirichlet-process concentration parameter.
    alpha = pm.Gamma("alpha", 0.5, 0.5)
    # Per-component locations and scales of the Normal base distribution.
    mu = pm.Normal(name="mu", sigma=5., shape=(K + 1,))
    # NOTE(review): Gamma(0.5, 0.5) has unbounded density as sigma -> 0, which can
    # let a component collapse onto a single observation. The ~1e-7 posterior std of
    # every parameter in this run suggests the chain barely moved; consider a scale
    # prior bounded away from 0 (TODO confirm).
    sigma = pm.Gamma("sigma", 0.5, 0.5, shape=(K + 1,))
    G0 = pm.Normal.dist(mu, sigma=sigma)
    dpm = DirichletProcessMixture(name="dpm", G0=G0, alpha=alpha, K=K, observed=data)

    # With a single chain, R-hat cannot be computed; consider chains >= 2 when runtime allows.
    trace = pm.sample(
        chains=1,
        draws=10000,
        tune=5000,
        target_accept=0.95,
        random_seed=RANDOM_SEED,
    )

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [alpha, mu, sigma, sbw]


Sampling 1 chain for 5_000 tune and 10_000 draw iterations (5_000 + 10_000 draws total) took 3186 seconds.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.


In [5]:
# Evaluation grid for density plots: 1001 evenly spaced points on [-3, 3] (the
# data are standardized), stored as a column of shape (1001, 1) — presumably so it
# broadcasts against per-component parameter vectors.
x_plot = np.linspace(-3, 3, num=1001).reshape(-1, 1)


def plot_density_draw(x):
    """Placeholder for plotting a single posterior density draw.

    TODO: not implemented — currently ignores ``x`` and returns ``None``.
    """
    return None

In [32]:
# Std of alpha over all posterior draws; a value near 0 means the draws were
# essentially constant. Reads the variable directly instead of converting the
# whole InferenceData with to_dict().
trace.posterior["alpha"].values.std()

3.6958526294053764e-07

In [31]:
# Per-component std of mu across draws; the array is laid out (chain, draw, component),
# so axis=1 reduces over draws and leaves (chain, component). Reads the variable
# directly instead of converting the whole InferenceData with to_dict().
trace.posterior["mu"].values.std(axis=1)

array([[2.81109998e-07, 3.08383721e-07, 1.55565394e-11, 2.84426470e-07,
        2.29628556e-07, 2.84734742e-07, 2.91951755e-07, 2.03092807e-07,
        3.53129726e-07, 2.99230899e-07, 4.81625276e-07, 3.84175740e-07]])

In [33]:
# Per-component std of sigma across draws (axis=1 is the draw axis of the
# (chain, draw, component) array). Reads the variable directly instead of
# converting the whole InferenceData with to_dict().
trace.posterior["sigma"].values.std(axis=1)

array([[2.97571435e-07, 9.67111384e-08, 8.93882248e-18, 2.81197033e-08,
        1.26392294e-07, 4.32155678e-08, 1.25917042e-07, 6.40499659e-08,
        1.18677854e-07, 3.27073419e-07, 1.99484262e-07, 1.61897729e-07]])

In [35]:
# Per-weight std of the stick-breaking weights across draws (axis=1 is the draw
# axis of the (chain, draw, weight) array). Reads the variable directly instead
# of converting the whole InferenceData with to_dict().
trace.posterior["sbw"].values.std(axis=1)

array([[6.16854841e-08, 8.06791411e-08, 2.86831773e-08, 2.18036939e-10,
        1.31955586e-10, 1.19398764e-10, 3.38888050e-11, 4.18155248e-11,
        6.35481532e-15, 1.68854923e-15, 1.08520871e-16, 6.19462697e-17]])