In [1]:
import jax
import jax.numpy as jnp
import numpyro

In [2]:
def model():
    _n = numpyro.sample('_n', numpyro.distributions.DiscreteUniform(low = 0, high = 1))
    n = numpyro.deterministic('n', 1 + _n)
    x = numpyro.sample('x', numpyro.distributions.Uniform(), sample_shape = (n,))

In [3]:
# inner_kernel = numpyro.infer.NUTS(model)
inner_kernel = numpyro.infer.HMC(model)

kernel = numpyro.infer.MixedHMC(inner_kernel)
# kernel = numpyro.infer.DiscreteHMCGibbs(inner_kernel)

mcmc = numpyro.infer.MCMC(kernel, num_warmup = 100, num_samples = 100)
mcmc.run(jax.random.PRNGKey(0))

sample: 100%|██████████| 200/200 [00:01<00:00, 164.72it/s, 9 steps of size 7.85e-01. acc. prob=0.94]


In [4]:
mcmc.get_samples()['x'].shape

(100, 1)

In [5]:
def model(max_dim):
    _n = numpyro.sample(
        '_n', numpyro.distributions.DiscreteUniform(low = 0, high = 1),
    )
    n = numpyro.deterministic('n', 1 + _n)
    
    _x = numpyro.sample(
        '_x', numpyro.distributions.Uniform(), sample_shape = (max_dim,),
    )
    x = numpyro.deterministic(
        'x', jnp.where(jnp.arange(max_dim) < n, _x, jnp.nan),
    )

In [6]:
inner_kernel = numpyro.infer.NUTS(model)
# inner_kernel = numpyro.infer.HMC(model)

# kernel = numpyro.infer.MixedHMC(inner_kernel)
kernel = numpyro.infer.DiscreteHMCGibbs(inner_kernel)

mcmc = numpyro.infer.MCMC(kernel, num_warmup = 100, num_samples = 100)
mcmc.run(jax.random.PRNGKey(0), max_dim = 2)

sample: 100%|██████████| 200/200 [00:01<00:00, 128.72it/s, 7 steps of size 5.35e-01. acc. prob=0.95]


In [7]:
mcmc.get_samples()

{'_n': Array([1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1,
        0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1,
        1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0], dtype=int32),
 '_x': Array([[5.32667823e-02, 5.42984962e-01],
        [7.29522705e-02, 4.84765112e-01],
        [5.31151116e-01, 2.56597310e-01],
        [6.52572393e-01, 4.54058528e-01],
        [7.51536191e-01, 5.04196107e-01],
        [1.11931421e-01, 3.61368567e-01],
        [1.90121442e-01, 2.84495354e-01],
        [5.26375696e-02, 6.91787839e-01],
        [3.24021094e-02, 7.22972274e-01],
        [5.18872067e-02, 9.81668055e-01],
        [8.45413685e-01, 4.82256353e-01],
        [5.64300478e-01, 4.21395868e-01],
        [5.64300478e-01, 4.21395868e-01],
        [4.29402500e-01, 5.91388226e-01],
        [2.20936283e-01, 2.89939493e-01],
   