In [2]:
import jax
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp

In [None]:
def change_point_generator(t:jax.Array, n_changepoints:int=24, weight_threshold:float=1e-4):
    """Sample candidate changepoint locations and their stick-breaking weights.

    Weights use a truncated stick-breaking construction::

        alpha_cp ~ Gamma(2, 1)
        beta[j]  ~ Beta(1, alpha_cp)
        w[j]     = beta[j] * prod_{i<j} (1 - beta[i])

    Weights that do not exceed ``weight_threshold`` are set to exactly zero, so
    negligible changepoints drop out of the trend.

    Locations are ``s = s_raw * max(t)`` with ``s_raw ~ Beta(1.1, 1.1)``, so they
    lie in ``[0, max(t)]``.

    NOTE(review): the location range assumes ``t`` starts at 0. If ``t`` has a
    different origin, some changepoints can fall outside the observed range.
    Consider ``min(t) + s_raw * (max(t) - min(t))`` in that case.

    Args:
        t: array of time points; only ``max(t)`` is used here.
        n_changepoints: number of candidate changepoints (size of both plates).
        weight_threshold: weights ``<= weight_threshold`` are zeroed. Defaults to 1e-4.

    Returns:
        Tuple ``(s, w)`` of changepoint locations and weights, each of length
        ``n_changepoints``. Both are also recorded as deterministic sites
        ``"changepoint_locations"`` and ``"changepoint_weights"``.
    """
    alpha_cp = numpyro.sample("alpha_cp", dist.Gamma(2.0, 1.0))
    with numpyro.plate("beta_cp_plate", n_changepoints):
        beta_cp = numpyro.sample("beta_cp", dist.Beta(1.0, alpha_cp))

    # Remaining stick length before each break: [1, (1-b0), (1-b0)(1-b1), ...].
    one_minus_beta = 1.0 - beta_cp
    cumprod_one_minus_beta = jnp.concatenate([
        jnp.array([1.0]),
        jnp.cumprod(one_minus_beta[:-1])
    ])
    w1 = beta_cp * cumprod_one_minus_beta
    # Hard-zero tiny weights; jnp.where keeps this traceable under JAX.
    w = jnp.where(w1 > weight_threshold, w1, 0.0)
    numpyro.deterministic("changepoint_weights", w)

    with numpyro.plate("changepoint_location_plate", n_changepoints):
        s_raw = numpyro.sample("s_raw", dist.Beta(1.1, 1.1))

    t_max = jnp.max(t)
    s = s_raw * t_max
    numpyro.deterministic("changepoint_locations", s)

    return s, w

def trend_with_changepoints_model(t:jax.Array, n_changepoints:int=24):
    """Prophet-style piecewise-linear trend with weighted changepoints.

    The trend is::

        trend(t) = (k + A @ wd) * t + (m + A @ (-s * wd))

    The terms are:

    - ``A[i, j]`` is 1.0 once ``t[i] >= s[j]``.
    - ``wd = delta * w`` are the per-changepoint rate adjustments, scaled by the
      stick-breaking weights from ``change_point_generator``.
    - The ``-s * wd`` offset shifts the intercept so the trend stays continuous
      at every changepoint.

    Priors: ``changepoint_scale ~ Gamma(2, 1)``,
    ``delta ~ Laplace(0, changepoint_scale)``, ``m ~ Normal(0, 1)`` and
    ``k ~ Laplace(0, 0.1)``.

    Args:
        t: 1-D array of time points.
        n_changepoints: number of candidate changepoints passed to
            ``change_point_generator``.

    Returns:
        Array with one trend value per element of ``t``.
    """
    s, w = change_point_generator(t, n_changepoints)

    changepoint_scale = numpyro.sample("changepoint_scale", dist.Gamma(2.0,1.0))
    with numpyro.plate("delta_plate", len(s)):
        delta = numpyro.sample("delta", dist.Laplace(0, changepoint_scale))

    m = numpyro.sample("m", dist.Normal(0, 1.0))
    k = numpyro.sample("k", dist.Laplace(0, 0.1))

    # Indicator matrix, shape (len(t), len(s)): 1.0 where t[i] has reached s[j].
    A = (t[:, None] >= s) * 1.0

    # Rate changes, down-weighted by the changepoint weights.
    wd = delta * w

    offset_ = m + jnp.dot(A, -s * wd)
    growth_ = k + jnp.dot(A, wd)

    # The rate must multiply t. The -s * wd term in offset_ is exactly what
    # cancels the rate jump at each s[j], keeping the line continuous.
    return growth_ * t + offset_


In [None]:
def trend_model(t:jax.Array, y:jax.Array):
