In [None]:
import numpy as np
import pandas as pd
import pymc3 as pm
import arviz as az

from aesara import tensor as at
from matplotlib import pyplot as plt

In [None]:
old_faithful_df = pd.read_csv(pm.get_data("old_faithful.csv"))
# z-score the waiting times (pandas .std() uses ddof=1) and hand the model a
# plain numpy array rather than a Series.
waiting_times = (
    old_faithful_df["waiting"]
    .pipe(lambda s: (s - s.mean()) / s.std())
    .to_numpy()
)

In [None]:
K: int = 20  # truncation level: number of stick-breaking components / mixture components

In [None]:
def stick_breaking(betas):
    '''
    Map stick-breaking proportions to mixture weights (truncated DP prior).

    Parameters
    ----------
    betas : 1-D tensor of length K
        iid draws from a Beta distribution; ``betas[k]`` is the fraction of
        the stick still remaining at step k that is broken off for component k.

    Returns
    -------
    1-D tensor of length K
        Non-negative mixture weights that sum to 1.

    Notes
    -----
    The raw weights ``betas[k] * prod_{j<k} (1 - betas[j])`` of a stick
    truncated after K breaks sum to ``1 - prod_k (1 - betas[k])``, which is
    strictly less than 1. pymc's ``Mixture``/``NormalMixture`` logp checks
    that the weights sum to 1. The truncated weights are therefore
    renormalised so that the leftover mass is spread proportionally.
    '''
    # Length of stick remaining before each break:
    # [1, (1-b0), (1-b0)(1-b1), ..., prod_{j<K-1} (1-b_j)]
    remaining = at.cumprod(at.concatenate([[1], 1 - betas[:-1]]))
    weights = betas * remaining
    # Absorb the mass dropped by truncation so the weights form a simplex.
    return weights / at.sum(weights)

In [None]:
with pm.Model() as model:
    # Concentration of the Dirichlet process. Larger alpha makes each
    # Beta(1, alpha) break smaller, which spreads weight over more components.
    alpha = pm.Gamma(name="alpha", alpha=1, beta=1)
    # Stick-breaking proportions v_k ~ Beta(1, alpha). The keyword ``beta`` is
    # the Beta distribution's second shape parameter; it is set to the DP
    # concentration ``alpha`` defined above.
    v = pm.Beta(name="v", alpha=1, beta=alpha, shape=(K,))
    w = pm.Deterministic(name="w", var=stick_breaking(v))

    # One location per component. A single shared mu would leave components
    # differing only in scale, unable to represent the two modes of the data.
    mu = pm.Normal(name="mu", mu=0, sigma=5, shape=(K,))
    # NOTE: despite its name, ``sigma`` acts as a per-component *variance*:
    # the precision passed below is tau = 1 / sigma (not 1 / sigma**2).
    sigma = pm.InverseGamma(name="sigma", alpha=1, beta=1, shape=(K,))
    obs = pm.NormalMixture(name="theta", w=w, mu=mu, tau=1/sigma, observed=waiting_times)

In [None]:
SAMPLES = 20000  # posterior draws kept per chain
BURN = 10000     # tuning (burn-in) draws per chain; discarded by pm.sample by default

with model:
    # Random-walk Metropolis for all free variables. ``target_accept`` is a
    # NUTS option; Metropolis tunes its proposal scale instead. ``init`` only
    # affects automatic NUTS initialisation and is ignored when ``step`` is
    # given. Neither is passed here. To use NUTS instead, drop ``step`` and
    # pass ``target_accept=0.9, init="advi"`` to pm.sample.
    step = pm.Metropolis()
    trace = pm.sample(
        draws=SAMPLES,
        step=step,
        tune=BURN,
        chains=4,
        random_seed=123,
        return_inferencedata=True,
    )

In [None]:
# Trace and marginal posterior for every posterior variable. The trailing ``;``
# stops the returned array of Axes from being printed under the figure.
az.plot_trace(trace);