# Purpose
* playing around with `infer_discrete`
* The examples below show iterations based on the discussion on the [Pyro forum](https://forum.pyro.ai/t/mcmc-get-samples-returns-empty-dict/3086)
* If you need something similar, look at [`Predictive`](http://num.pyro.ai/en/latest/utilities.html#predictive) and [Example: Bayesian Models of Annotation](http://num.pyro.ai/en/latest/examples/annotation.html) from [Support infer_discrete for Predictive (#1086)](https://github.com/pyro-ppl/numpyro/commit/003424bb3c57e44b433991cc73ddbb557bf31f3c)

In [1]:
import jax
import jax.numpy as jnp
import numpyro
from numpyro.contrib.funsor import config_enumerate, infer_discrete
import numpyro.distributions as dist
from numpyro.infer.util import Predictive
import pandas as pd
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
%load_ext watermark

In [3]:
%watermark -v -m -p jax,numpy,pandas,numpyro

Python implementation: CPython
Python version       : 3.8.11
IPython version      : 7.18.1

jax    : 0.2.19
numpy  : 1.20.3
pandas : 1.3.2
numpyro: 0.7.2

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 4.19.193-1-MANJARO
Machine     : x86_64
Processor   : 
CPU cores   : 4
Architecture: 64bit



In [4]:
%watermark -gb

Git hash: 472176c9c68598205edff7afdce8a21355406f60

Git branch: master



In [5]:
# MCMC settings passed to `MCMC(...)` in the cells below.
num_samples = 1000  # sampling iterations per chain
num_warmup = 1000  # warm-up iterations per chain
num_chains = 4

## DiscreteHMCGibbs

In [6]:
# PRNG key handed to `mcmc.run` in this section.
key = jax.random.PRNGKey(2)

# Probability passed to the Bernoulli prior of the "murderer" site.
guess = 0.7


def mystery(guess):
    """Generative model: draw a murderer, then a weapon given the murderer.

    Args:
        guess: probability passed to the Bernoulli prior on "murderer".

    Returns:
        Tuple ``(murderer, weapon)`` holding the values of the two sample sites.
    """
    # Row i holds the weapon probabilities used when murderer == i.
    cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    culprit = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon_dist = dist.Categorical(cpt[culprit])
    return culprit, numpyro.sample("weapon", weapon_dist)


# Observe weapon == 0. Categorical values are integer category indices, so the
# observation is the int `0` rather than the float `0.0`.
conditioned_model = numpyro.handlers.condition(mystery, {"weapon": 0})

# DiscreteHMCGibbs wraps the NUTS kernel so that the discrete "murderer" site
# is sampled and shows up in the summary printed below.
nuts_kernel = NUTS(conditioned_model)
kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)

mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
mcmc.run(key, guess)

mcmc.print_summary()

# Forward-run the conditioned model for comparison. `condition` only fixes the
# value of "weapon"; "murderer" is still drawn from Bernoulli(guess), so the
# table below reflects the prior on "murderer", not the posterior above.
with numpyro.handlers.seed(rng_seed=0):
    forward_draws = [
        tuple(
            value.item() if hasattr(value, "item") else value
            for value in conditioned_model(guess)
        )
        for _ in range(num_samples)
    ]

samples = pd.DataFrame(forward_draws, columns=["murderer", "weapon"])

print(pd.crosstab(samples.murderer, samples.weapon, normalize="all"))

sample: 100%|██████████| 2000/2000 [00:03<00:00, 558.63it/s, 1 steps of size 3.40e+38. acc. prob=1.00] 
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6269.45it/s, 1 steps of size 3.40e+38. acc. prob=1.00]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6598.84it/s, 1 steps of size 3.40e+38. acc. prob=1.00]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 5979.06it/s, 1 steps of size 3.40e+38. acc. prob=1.00]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  murderer      0.34      0.48      0.00      0.00      1.00  12998.62      1.00

weapon      0.0
murderer       
0         0.327
1         0.673


## `infer_discrete`

In [7]:
# MCMC settings for this section, passed to `MCMC(...)` below.
num_samples = 1000
num_warmup = 1000
num_chains = 4

In [8]:
def infer_discrete_model(rng_key, samples, model_fn=None, model_args=None):
    """Collect the values of the enumerated discrete sites for one RNG key.

    Conditions the model on ``samples``, wraps it with ``config_enumerate`` and
    ``infer_discrete``, runs it once under a trace, and gathers the sample
    sites that are marked for parallel enumeration.

    Args:
        rng_key: PRNG key passed to ``infer_discrete``.
        samples: dict of site values to condition the model on
            (e.g. the result of ``mcmc.get_samples()``).
        model_fn: model to run. Defaults to the notebook-level ``model``.
        model_args: positional arguments for the model. Defaults to the
            notebook-level ``data`` tuple.

    Returns:
        Dict mapping site name to traced value for every sample site whose
        ``infer`` config has ``enumerate == "parallel"``.
    """
    # Caution: when not passed explicitly, `model` and `data` are looked up as
    # notebook globals at call time, so they must be defined before calling.
    if model_fn is None:
        model_fn = model
    if model_args is None:
        model_args = data

    conditioned = numpyro.handlers.condition(model_fn, data=samples)
    # Named differently from the enclosing function to avoid shadowing it.
    discrete_model = infer_discrete(config_enumerate(conditioned), rng_key=rng_key)
    with numpyro.handlers.trace() as tr:
        discrete_model(*model_args)

    return {
        name: site["value"]
        for name, site in tr.items()
        if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel"
    }

In [9]:
def model(guess, weapon):
    """Murder-mystery model with the weapon supplied as an observation.

    Args:
        guess: probability passed to the Bernoulli prior on "murderer".
        weapon: value passed as ``obs`` to the "weapon" sample site.
    """
    # Row i holds the weapon probabilities used when murderer == i.
    cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    culprit = numpyro.sample("murderer", dist.Bernoulli(guess))
    numpyro.sample("weapon", dist.Categorical(cpt[culprit]), obs=weapon)

nuts_kernel = NUTS(model)

# Positional arguments for `model`: (guess, weapon). The observed weapon is the
# integer category index `0` (Categorical values are indices), not the float `0.`.
data = (guess, 0)

# Caution: HMC/NUTS marginalizes all discrete latent variables, so for `model`
# `mcmc.get_samples()` returns an empty dict.
mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
mcmc.run(key, *data)

sample: 100%|██████████| 2000/2000 [00:02<00:00, 916.79it/s, 1 steps of size 3.40e+38. acc. prob=1.00] 
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6917.77it/s, 1 steps of size 3.40e+38. acc. prob=1.00]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6864.45it/s, 1 steps of size 3.40e+38. acc. prob=1.00]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6259.95it/s, 1 steps of size 3.40e+38. acc. prob=1.00]


In [10]:
# Empty for `model`: its only latent site ("murderer") is discrete and was
# marginalized out by NUTS, so there is nothing to return.
posterior_samples = mcmc.get_samples()
posterior_samples

{}

In [11]:
# NOTE(review): rebinds `num_samples`; in the next cell it is the number of PRNG
# keys (one `infer_discrete` draw each), not the per-chain MCMC setting.
num_samples = 4000

In [12]:
# One PRNG key per draw; vmap runs `infer_discrete_model` once per key.
# The empty dict is the (empty) set of site values to condition on.
discrete_keys = jax.random.split(jax.random.PRNGKey(1), num_samples)
batched_infer = jax.vmap(infer_discrete_model)
discrete_samples = batched_infer(discrete_keys, {})

In [13]:
# Mean and standard deviation of the `infer_discrete` draws for "murderer".
discrete_samples["murderer"].mean(), discrete_samples["murderer"].std()

(DeviceArray(0.353, dtype=float32), DeviceArray(0.47790274, dtype=float32))

## Using Predictive

In [14]:
# PRNG key passed to `predictive(...)` in this section.
key = jax.random.PRNGKey(3)

# Probability passed to the Bernoulli prior of the "murderer" site.
guess = 0.7


def mystery(guess):
    """Generative model: draw a murderer, then a weapon given the murderer.

    Same model as the `mystery` defined in the DiscreteHMCGibbs section.

    Args:
        guess: probability passed to the Bernoulli prior on "murderer".

    Returns:
        Tuple ``(murderer, weapon)`` holding the values of the two sample sites.
    """
    # Row i holds the weapon probabilities used when murderer == i.
    cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    culprit = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon_dist = dist.Categorical(cpt[culprit])
    return culprit, numpyro.sample("weapon", weapon_dist)


# Observe weapon == 0. Categorical values are integer category indices, so the
# observation is the int `0` rather than the float `0.0`.
conditioned_model = numpyro.handlers.condition(mystery, {"weapon": 0})

# `infer_discrete=True` asks Predictive to sample the discrete "murderer" site
# from its posterior given the observation, instead of from the prior.
predictive = Predictive(conditioned_model, num_samples=1000, infer_discrete=True)
samples = predictive(key, guess)
samples["murderer"].mean(), samples["murderer"].std()

(DeviceArray(0.356, dtype=float32), DeviceArray(0.47881523, dtype=float32))