In [1]:
import numpy as np
import pandas as pd
import pymc3 as pm
import arviz as az

from aesara import tensor as at
from matplotlib import pyplot as plt

In [2]:
old_faithful_df = pd.read_csv(pm.get_data("old_faithful.csv"))
waiting_times = old_faithful_df["waiting"]
waiting_times = ((waiting_times - waiting_times.mean())/waiting_times.std()).values

In [3]:
K = 20

In [None]:
def stick_breaking(betas):
    '''
    betas is a K-vector of iid draws from a Beta distribution
    '''
    sticks = at.concatenate(
        [
            [1],
            (1 - betas[:-1])
        ]
    )
    
    return at.mul(betas, at.cumprod(sticks))

In [None]:
with pm.Model() as model:
    alpha = pm.Gamma(name="alpha", alpha=1, beta=1)
    v = pm.Beta(name="v", alpha=1, beta=alpha, shape=(K,)) # beta=alpha kinda confusing here
    
    w = pm.Deterministic(name="w", var=stick_breaking(v))
    mu = pm.Normal(name="mu", mu=0, sigma=5)
    sigma = pm.InverseGamma(name="sigma", alpha=1, beta=1, shape=(K,))
    obs = pm.NormalMixture(name="theta", w=w, mu=mu, tau=1/sigma, observed=waiting_times)

In [None]:
SAMPLES = 20000
BURN = 10000

with model:
    step = pm.Metropolis(target_accept=0.9)
    trace = pm.sample(
        SAMPLES, 
        step=step, 
        tune=BURN, 
        chains=4, 
        random_seed=123, 
        return_inferencedata=True, 
        init="advi",
    )

In [None]:
az.plot_trace(trace)

In [35]:
with pm.Model() as model:
    w = pm.Dirichlet(name="w", a=[1]*30)
    trace = pm.sample(10000)

  trace = pm.sample(10000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [w]


Sampling 2 chains for 1_000 tune and 10_000 draw iterations (2_000 + 20_000 draws total) took 67 seconds.
There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
There were 3 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 25% for some parameters.


In [42]:
trace.get_values("w")

array([[0.00173831, 0.08497921, 0.05517845, ..., 0.00403032, 0.03469542,
        0.01529772],
       [0.00525818, 0.05597316, 0.04727625, ..., 0.05223651, 0.06529829,
        0.02224548],
       [0.00804903, 0.00866127, 0.06025815, ..., 0.03034349, 0.00897736,
        0.01864374],
       ...,
       [0.00262708, 0.04504959, 0.00019336, ..., 0.07143613, 0.03380029,
        0.04570397],
       [0.13230428, 0.01677348, 0.09757534, ..., 0.01527133, 0.00908755,
        0.02381493],
       [0.00071193, 0.03728508, 0.00309891, ..., 0.06866954, 0.0317229 ,
        0.02248025]])