In [1]:
import jax
import jax.numpy as jnp
import numpyro

In [2]:
def run_sampler(kernel, *args, **kwargs):
    mcmc = numpyro.infer.MCMC(kernel, num_warmup = 100, num_samples = 100)
    mcmc.run(jax.random.PRNGKey(0), *args, **kwargs)
    mcmc.print_summary();
    return mcmc.get_samples()

## model with discrete latent variable

In [3]:
def model():
    x = numpyro.sample('x', numpyro.distributions.DiscreteUniform())

#### NUTS

In [4]:
kernel = numpyro.infer.NUTS(model)

In [5]:
run_sampler(kernel)

  mcmc.run(jax.random.PRNGKey(0), *args, **kwargs)
sample: 100%|██████████| 200/200 [00:00<00:00, 217.74it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


ValueError: max() arg is an empty sequence

#### HMC

In [6]:
kernel = numpyro.infer.HMC(model)

In [7]:
run_sampler(kernel)

  mcmc.run(jax.random.PRNGKey(0), *args, **kwargs)
sample: 100%|██████████| 200/200 [00:00<00:00, 247.72it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


ValueError: max() arg is an empty sequence

#### MixedHMC

In [8]:
inner_kernel = numpyro.infer.HMC(model)
kernel = numpyro.infer.MixedHMC(inner_kernel)

In [9]:
run_sampler(kernel)

sample: 100%|██████████| 200/200 [00:01<00:00, 106.60it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         x      0.42      0.50      0.00      0.00      1.00     98.50      0.99






{'x': Array([0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0,
        0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1,
        1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0,
        0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0], dtype=int32)}

#### DiscreteHMCGibbs - NUTS

In [10]:
inner_kernel = numpyro.infer.NUTS(model)
kernel = numpyro.infer.DiscreteHMCGibbs(inner_kernel)

In [11]:
run_sampler(kernel)

sample: 100%|██████████| 200/200 [00:01<00:00, 132.10it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         x      0.52      0.50      1.00      0.00      1.00     89.64      0.99






{'x': Array([1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1,
        0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1,
        1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0], dtype=int32)}

#### DiscreteHMCGibbs - HMC

In [12]:
inner_kernel = numpyro.infer.HMC(model)
kernel = numpyro.infer.DiscreteHMCGibbs(inner_kernel)

In [13]:
run_sampler(kernel)

sample: 100%|██████████| 200/200 [00:01<00:00, 112.92it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         x      0.52      0.50      1.00      0.00      1.00     89.64      0.99






{'x': Array([1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1,
        0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1,
        1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0], dtype=int32)}

## data selection via discrete latent variable

In [14]:
data = jax.random.uniform(jax.random.PRNGKey(1), (10,))
data

Array([0.7551559 , 0.3129729 , 0.12388372, 0.548188  , 0.4223112 ,
       0.30576992, 0.82008433, 0.95633745, 0.3566252 , 0.55691683],      dtype=float32)

In [15]:
def model(data):
    i = numpyro.sample('i', numpyro.distributions.DiscreteUniform(low = 0, high = data.size - 1))
    x = data[i]
    # below does nothing to the posterior, just an example of how sampled indices can be used downstream
    numpyro.factor('log_prob', numpyro.distributions.Uniform().log_prob(x))

#### NUTS

In [16]:
kernel = numpyro.infer.NUTS(model)

In [17]:
run_sampler(kernel, data)

  mcmc.run(jax.random.PRNGKey(0), *args, **kwargs)
sample: 100%|██████████| 200/200 [00:00<00:00, 219.64it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


ValueError: max() arg is an empty sequence

#### HMC

In [18]:
kernel = numpyro.infer.HMC(model)

In [19]:
run_sampler(kernel, data)

  mcmc.run(jax.random.PRNGKey(0), *args, **kwargs)
sample: 100%|██████████| 200/200 [00:00<00:00, 214.85it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


ValueError: max() arg is an empty sequence

#### MixedHMC

In [20]:
inner_kernel = numpyro.infer.HMC(model)
kernel = numpyro.infer.MixedHMC(inner_kernel)

In [21]:
run_sampler(kernel, data)

sample: 100%|██████████| 200/200 [00:01<00:00, 108.33it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         i      4.43      2.68      4.00      0.00      8.00    115.96      1.01






{'i': Array([2, 7, 4, 5, 4, 7, 0, 6, 5, 8, 4, 5, 8, 4, 9, 4, 1, 5, 5, 0, 3, 6,
        9, 2, 4, 6, 0, 1, 3, 6, 8, 7, 2, 8, 9, 2, 7, 5, 2, 8, 7, 2, 8, 8,
        6, 7, 4, 3, 2, 3, 8, 1, 7, 0, 6, 3, 1, 7, 9, 6, 2, 2, 7, 3, 0, 4,
        6, 3, 0, 6, 5, 8, 1, 2, 2, 1, 5, 7, 6, 3, 2, 4, 2, 4, 8, 0, 2, 1,
        9, 8, 2, 3, 5, 3, 1, 8, 2, 8, 5, 4], dtype=int32)}

#### DiscreteHMCGibbs - NUTS

In [22]:
inner_kernel = numpyro.infer.NUTS(model)
kernel = numpyro.infer.DiscreteHMCGibbs(inner_kernel)

In [23]:
run_sampler(kernel, data)

sample: 100%|██████████| 200/200 [00:01<00:00, 127.30it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         i      4.35      2.74      5.00      1.00      9.00     37.27      0.99






{'i': Array([5, 0, 8, 3, 3, 1, 0, 5, 0, 3, 3, 5, 2, 5, 2, 0, 3, 1, 7, 3, 6, 1,
        6, 5, 6, 9, 2, 7, 2, 1, 8, 7, 9, 3, 5, 5, 2, 5, 9, 2, 4, 5, 5, 1,
        4, 4, 5, 8, 9, 9, 0, 5, 6, 7, 2, 7, 6, 9, 6, 3, 9, 8, 6, 6, 9, 5,
        2, 5, 7, 0, 7, 8, 6, 2, 1, 1, 4, 2, 3, 3, 9, 9, 2, 3, 5, 1, 6, 5,
        7, 2, 5, 5, 3, 1, 1, 5, 1, 0, 7, 0], dtype=int32)}

#### DiscreteHMCGibbs - HMC

In [24]:
inner_kernel = numpyro.infer.HMC(model)
kernel = numpyro.infer.DiscreteHMCGibbs(inner_kernel)

In [25]:
run_sampler(kernel, data)

sample: 100%|██████████| 200/200 [00:01<00:00, 108.74it/s, 1 steps of size 4.65e+18. acc. prob=1.00]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         i      4.35      2.74      5.00      1.00      9.00     37.27      0.99






{'i': Array([5, 0, 8, 3, 3, 1, 0, 5, 0, 3, 3, 5, 2, 5, 2, 0, 3, 1, 7, 3, 6, 1,
        6, 5, 6, 9, 2, 7, 2, 1, 8, 7, 9, 3, 5, 5, 2, 5, 9, 2, 4, 5, 5, 1,
        4, 4, 5, 8, 9, 9, 0, 5, 6, 7, 2, 7, 6, 9, 6, 3, 9, 8, 6, 6, 9, 5,
        2, 5, 7, 0, 7, 8, 6, 2, 1, 1, 4, 2, 3, 3, 9, 9, 2, 3, 5, 1, 6, 5,
        7, 2, 5, 5, 3, 1, 1, 5, 1, 0, 7, 0], dtype=int32)}