In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

cpu_count = multiprocessing.cpu_count() - 2
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)

In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


In [3]:
rng_key = jax.random.PRNGKey(0)

# Create data
N = 100
mu = 10
mu00, mu11 = dist.Normal(mu, 1).sample(rng_key, (2,))
data00 = dist.Normal(mu00, 1).sample(rng_key, (N,))
data11 = dist.Normal(mu11, 1).sample(rng_key, (N,))

df = \
    pd.DataFrame([[0, 0]] * N + [[1, 2]] * N, columns=["subject", "is_diabetic"])
df["response"] = jnp.concatenate([data00, data11])

subject = df["subject"].values
is_diabetic = df["is_diabetic"].values
response = df["response"].values

n_unique_subject = np.max(subject) + 1
n_unique_is_diabetic = np.max(is_diabetic) + 1

print("shapes:", subject.shape, is_diabetic.shape, response.shape)
print("n_unique:", n_unique_subject, n_unique_is_diabetic)



shapes: (200,) (200,) (200,)
n_unique: 2 3


In [4]:
def hiearchical_model_with_unused_parameters(subject, n_unique_subject, is_diabetic, n_unique_is_diabetic, response=None):
    global_mu_estimate = numpyro.sample("global_mu_estimate", dist.Normal(0, 100))
    with numpyro.plate("n_unique_is_diabetic", n_unique_is_diabetic, dim=-1):
        with numpyro.plate("n_unique_subject", n_unique_subject, dim=-2):
            mu_estimate = numpyro.sample("mu_estimate", dist.Normal(global_mu_estimate, 10))

    with numpyro.plate("data", len(subject)):
        numpyro.sample("obs", dist.Normal(mu_estimate[subject, is_diabetic], 1), obs=response)


def hiearchical_model_for_sanity_check(subject, n_unique_subject, response=None):
    global_mu_estimate = numpyro.sample("global_mu_estimate", dist.Normal(0, 100))
    with numpyro.plate("n_unique_subject", n_unique_subject, dim=-1):
        mu_estimate = numpyro.sample("mu_estimate", dist.Normal(global_mu_estimate, 10))

    with numpyro.plate("data", len(subject)):
        numpyro.sample("obs", dist.Normal(mu_estimate[subject], 1), obs=response)


In [5]:
from numpyro.infer import MCMC, NUTS

num_warmup, num_samples = 500, 2000

# Run NUTS
sampler = NUTS(hiearchical_model_with_unused_parameters)
mcmc = MCMC(sampler, num_warmup=num_warmup, num_samples=num_samples, num_chains=4)
mcmc.run(rng_key, subject=subject, n_unique_subject=n_unique_subject, is_diabetic=is_diabetic, n_unique_is_diabetic=n_unique_is_diabetic, response=response)
posterior_samples = mcmc.get_samples()
mcmc.print_summary()


  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]


                          mean       std    median      5.0%     95.0%     n_eff     r_hat
  global_mu_estimate     10.24      6.90     10.14     -1.43     21.19   3426.38      1.00
    mu_estimate[0,0]     11.74      0.10     11.74     11.56     11.89  10828.78      1.00
    mu_estimate[0,1]     10.18     12.03     10.07     -9.58     29.65   4882.29      1.00
    mu_estimate[0,2]     10.20     12.27     10.29     -8.65     31.68   5005.39      1.00
    mu_estimate[1,0]     10.15     12.27     10.24    -10.79     29.36   5080.25      1.00
    mu_estimate[1,1]     10.23     12.03     10.14     -8.29     30.99   4424.68      1.00
    mu_estimate[1,2]      9.17      0.10      9.17      9.00      9.33  10201.05      1.00

Number of divergences: 0


In [6]:
num_warmup, num_samples = 500, 2000

# Run NUTS
sampler = NUTS(hiearchical_model_for_sanity_check)
mcmc = MCMC(sampler, num_warmup=num_warmup, num_samples=num_samples, num_chains=4)
mcmc.run(rng_key, subject=subject, n_unique_subject=n_unique_subject, response=response)
posterior_samples = mcmc.get_samples()
mcmc.print_summary()

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]


                          mean       std    median      5.0%     95.0%     n_eff     r_hat
  global_mu_estimate     10.32      7.05     10.37     -1.21     21.43   8386.26      1.00
      mu_estimate[0]     11.74      0.10     11.74     11.58     11.91   7317.59      1.00
      mu_estimate[1]      9.17      0.10      9.17      9.00      9.33   8295.34      1.00

Number of divergences: 0


In [7]:
mu00, mu11

(Array(11.81608667, dtype=float64), Array(9.24511516, dtype=float64))