# Purpose
* playing around with `infer_discrete`

In [1]:
import jax
import jax.numpy as jnp
import numpyro
from numpyro.contrib.funsor import config_enumerate, infer_discrete
import numpyro.distributions as dist
import pandas as pd
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

In [2]:
num_samples = 1000
num_warmup = 1000
num_chains = 4

## DiscreteHMCGibbs

In [3]:
key = jax.random.PRNGKey(2)

guess = 0.7


def mystery(guess):
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]))
    return murderer, weapon


conditioned_model = numpyro.handlers.condition(mystery, {"weapon": 0.0})

nuts_kernel = NUTS(conditioned_model)

kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)

mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
mcmc.run(key, guess)

mcmc.print_summary()

with numpyro.handlers.seed(rng_seed=0):
    samples = []
    for _ in range(1000):
        samples.append(
            tuple(
                [
                    sample.item() if hasattr(sample, "item") else sample
                    for sample in conditioned_model(guess)
                ]
            )
        )

samples = pd.DataFrame(samples, columns=["murderer", "weapon"])

print(pd.crosstab(samples.murderer, samples.weapon, normalize="all"))

sample: 100%|██████████| 2000/2000 [00:03<00:00, 619.70it/s, 1 steps of size 3.40e+38. acc. prob=1.00] 
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6618.43it/s, 1 steps of size 3.40e+38. acc. prob=1.00]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6594.09it/s, 1 steps of size 3.40e+38. acc. prob=1.00]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6769.77it/s, 1 steps of size 3.40e+38. acc. prob=1.00]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  murderer      0.34      0.48      0.00      0.00      1.00  12998.62      1.00

weapon      0.0
murderer       
0         0.327
1         0.673


## `infer_discrete`

In [4]:
num_samples = 1000
num_warmup = 1000
num_chains = 4

In [5]:
# caution: `*data` within infer_discrete_model is a global variable
def infer_discrete_model(rng_key, samples):
    conditioned_model = numpyro.handlers.condition(model, data=samples)
    infer_discrete_model = infer_discrete(
        config_enumerate(conditioned_model), rng_key=rng_key
    )
    with numpyro.handlers.trace() as tr:
        infer_discrete_model(*data)

    return {
        name: site["value"]
        for name, site in tr.items()
        if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel"
    }

In [6]:
def model(guess, weapon):
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]), obs=weapon)

nuts_kernel = NUTS(model)

data = (guess, 0.)

# caution: HMC will marginalize all the discrete variables, for `model` results in an empty dict from mcmc.get_samples()
mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
mcmc.run(key, *data)

sample: 100%|██████████| 2000/2000 [00:01<00:00, 1000.61it/s, 1 steps of size 3.40e+38. acc. prob=1.00]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6671.17it/s, 1 steps of size 3.40e+38. acc. prob=1.00]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7082.04it/s, 1 steps of size 3.40e+38. acc. prob=1.00]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6653.30it/s, 1 steps of size 3.40e+38. acc. prob=1.00]


In [7]:
num_samples = 4000

In [8]:
posterior_samples = mcmc.get_samples()
discrete_samples = jax.vmap(infer_discrete_model)(
    jax.random.split(jax.random.PRNGKey(1), num_samples), posterior_samples
)

In [9]:
posterior_samples

{}

In [10]:
discrete_samples["murderer"].mean(), discrete_samples["murderer"].std()

(DeviceArray(0.353, dtype=float32), DeviceArray(0.47790274, dtype=float32))