In [1]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before every cell runs, so edits to the project packages are picked up live.
%reload_ext autoreload
%autoreload 2

import os
from pathlib import Path
import multiprocessing

import jax
import numpyro

from hbmep.config import Config
from hbmep.model import Model
from hbmep.model.utils import Site as site

from hbmep_paper.model import Simulator
from hbmep_paper.utils import simulate

# Run JAX / NumPyro on the CPU.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several host (CPU) devices to XLA so that pmap-based parallel
# execution (e.g. parallel MCMC chains) has more than one device.
# Two cores are left free for the rest of the system, but at least one
# device is always requested: on a 1- or 2-core machine
# `cpu_count() - 2` would be <= 0, which is not a valid device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)

# Use 64-bit floats and validate distribution arguments / sample values.
numpyro.enable_x64()
numpyro.enable_validation()

In [2]:
# Repository root: assumed to be four directories above the notebook's
# working directory (TODO confirm if the notebook is moved).
root_path = Path.cwd().parent.parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/experiments.toml")

config = Config(toml_path=toml_path)
# Store build artefacts inside the repository instead of a hardcoded,
# user-specific absolute path. NOTE(review): assumes `root_path` is the
# hbmep-paper repository root, so this resolves to the previously hardcoded
# directory on the original machine. The trailing slash is kept as before.
config.BUILD_DIR = os.path.join(root_path, "reports/experiments-debug/gen/")


2023-08-03 14:10:51,809 - hbmep.config - INFO - Verifying configuration ...
2023-08-03 14:10:51,810 - hbmep.config - INFO - Success!


In [3]:
simulator = Simulator(config=config)

# Keyword arguments forwarded to `simulate` that control the size of the
# synthetic dataset.
simulation_params = dict(
    n_subject=5,
    n_feature0=15,
    n_repeats=100,
    downsample_rate=1,
)
df, posterior_samples_true = simulate(model=simulator, **simulation_params)

2023-08-03 14:10:51,844 - hbmep.model.baseline - INFO - Initialized model with hierarchical_bayesian link
2023-08-03 14:10:51,845 - hbmep_paper.utils.utils - INFO - Simulating data ...


2023-08-03 14:10:57,094 - hbmep.utils.utils - INFO - func:predict took: 5.24 sec
2023-08-03 14:10:57,108 - hbmep.utils.utils - INFO - func:simulate took: 5.26 sec


In [4]:
import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from hbmep.config import Config
from hbmep.model import Baseline
from hbmep.model.utils import Site as site

from hbmep_paper.utils.constants import HBM



class HierarchicalBayesian(Baseline):
    """Hierarchical Bayesian recruitment-curve model with a Gamma likelihood.

    For every (feature0 level, subject, response) cell the model samples curve
    parameters ``a, b, v, L, H`` and dispersion terms ``g_1, g_2``. The scales
    of ``a, b, L, H, v`` are governed by per-(subject, response) hyper-priors.
    The expected response ``mu`` equals ``L`` for intensities up to ``a`` and
    rises towards ``L + H`` as the intensity grows.
    """
    LINK = HBM

    def __init__(self, config: Config):
        super().__init__(config=config)

    @staticmethod
    def _n_levels(labels):
        """Number of plate slots needed so every label in ``labels`` is a valid index."""
        labels = np.asarray(labels)
        return int(labels.max()) + 1 if labels.size else 0

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model.

        Args:
            subject: integer-encoded subject label per observation; used
                directly as an array index (assumed 0-based — TODO confirm
                against ``Baseline.load``).
            features: sequence whose first element holds the integer-encoded
                feature0 label per observation (also used as an index).
            intensity: stimulation intensity per observation.
            response_obs: observed responses with one column per response, or
                None to sample them.

        Returns:
            The value of the ``site.obs`` sample site.
        """
        intensity = intensity.reshape(-1, 1)
        # Repeat the intensity column once per response -> (n_data, n_response).
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        # Plates are sized by the largest label rather than the number of
        # distinct labels. The labels index the parameter arrays below, and
        # JAX clamps out-of-range gather indices instead of raising. An
        # undersized plate (labels not contiguous from 0) would therefore
        # silently read the wrong parameters. For contiguous 0-based labels
        # the size is unchanged.
        n_subject = self._n_levels(subject)
        n_feature0 = self._n_levels(feature0)

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                # Hyper-priors: one draw per (subject, response).
                mu_a = numpyro.sample(
                    site.mu_a,
                    dist.TruncatedNormal(150, 50, low=0)
                )
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(50))

                sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.1))

                sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
                sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
                sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(10))

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    # Priors: shape (n_feature0, n_subject, n_response).
                    a = numpyro.sample(
                        site.a,
                        dist.TruncatedNormal(mu_a, sigma_a, low=0)
                    )
                    b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

                    L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                    H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                    v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                    g_1 = numpyro.sample(site.g_1, dist.Exponential(0.01))
                    g_2 = numpyro.sample(site.g_2, dist.Exponential(0.01))

        # Gather each observation's parameters. Indexing the
        # (n_feature0, n_subject, n_response) arrays with the per-observation
        # labels gives (n_data, n_response) arrays aligned with `intensity`.
        idx = (feature0, subject)
        a_obs, b_obs, v_obs = a[idx], b[idx], v[idx]
        L_obs, H_obs = L[idx], H[idx]

        # Model: mu == L_obs at intensity == a_obs, is rectified to L_obs below
        # a_obs, and approaches L_obs + H_obs as intensity grows.
        mu = numpyro.deterministic(
            site.mu,
            L_obs
            + jnp.maximum(
                0,
                -1
                + (H_obs + 1)
                / jnp.power(
                    1
                    + (jnp.power(1 + H_obs, v_obs) - 1)
                    * jnp.exp(-b_obs * (intensity - a_obs)),
                    1 / v_obs
                )
            )
        )
        # Gamma rate; concentration = mu * beta, so the Gamma mean is mu.
        beta = numpyro.deterministic(
            site.beta,
            g_1[idx] + g_2[idx] * (1 / (mu + 1))
        )

        # Observation: the response axis is treated as a single event per row.
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
                obs=response_obs
            )


model = HierarchicalBayesian(config=config)
df, encoder_dict = model.load(df=df)

# Replace the response column(s) with the first simulated draw of the
# observation site, assigned positionally to the loaded frame.
df[model.response] = posterior_samples_true[site.obs][0]

2023-08-03 14:10:57,130 - hbmep.model.baseline - INFO - Initialized model with hierarchical_bayesian link
2023-08-03 14:10:57,130 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/experiments-debug/gen/
2023-08-03 14:10:57,131 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/experiments-debug/gen/
2023-08-03 14:10:57,131 - hbmep.dataset.core - INFO - Processing data ...
2023-08-03 14:10:57,133 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# Smallest stimulation intensity in the dataset.
df.loc[:, model.intensity].min()

4.0

In [6]:
# Boolean mask selecting a single (participant, compound_position) group.
ind = (df["compound_position"] == 2) & (df["participant"] == 0)

# reset_index() already returns a new frame, so no extra copy is needed.
t = df.loc[ind].reset_index()

In [9]:
# Reciprocal of a minimum response value observed in an earlier run (3.2e-5).
1 / 3.2e-5

31250.0

In [7]:
# Smallest value of each response column.
df.loc[:, model.response].min()

biceps_auc    0.000032
dtype: float64

In [12]:
# The model's Gamma likelihood only supports strictly positive observations,
# and NumPyro validation is enabled. Fail fast here (this also catches NaN)
# rather than deep inside MCMC, then still display the per-column minimum.
assert (df[model.response] > 0).to_numpy().all(), "responses must be strictly positive"
df[model.response].min()

biceps_auc    0.000016
dtype: float64

In [11]:
# Column-wise minimum over the whole dataset.
df.min(axis="index")

participant          0.000000
compound_position    0.000000
pulse_amplitude      0.000000
biceps_auc           0.000016
dtype: float64

In [10]:
# Column-wise minimum over the selected group.
t.min(axis="index")

index                180.000000
participant            0.000000
compound_position      2.000000
pulse_amplitude        0.000000
biceps_auc             0.000544
dtype: float64

In [5]:
# Plot the dataset; the log shows the PDF is saved under the build directory.
model.plot(encoder_dict=encoder_dict, df=df)

2023-08-03 13:55:34,989 - hbmep.dataset.core - INFO - Plotting dataset ...


2023-08-03 13:55:55,005 - hbmep.dataset.core - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/experiments-debug/3/dataset.pdf
2023-08-03 13:55:55,006 - hbmep.utils.utils - INFO - func:plot took: 20.02 sec


In [6]:
# Fit the model; returns the MCMC object and the posterior samples.
(mcmc, posterior_samples) = model.run_inference(df=df)

2023-08-03 13:55:55,023 - hbmep.model.baseline - INFO - Running inference with hierarchical_bayesian ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
# Posterior summary with 95% intervals.
mcmc.print_summary(prob=0.95)


                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]    112.72     79.32     97.83      4.30    241.61      4.83      1.39
  H[0,1,0]      1.40      0.37      1.21      1.10      2.25      6.18      1.48
  H[1,0,0]     57.84     49.20     41.17      0.30    156.82     75.57      1.07
  H[1,1,0]     58.65     61.10     42.98      0.01    177.44      5.23      1.31
  H[2,0,0]      5.28      1.17      4.73      3.95      7.59      2.68      2.04
  H[2,1,0]     45.11     45.23     32.64      0.15    137.81     16.48      1.11
  L[0,0,0]      0.01      0.00      0.01      0.01      0.01      7.12      1.25
  L[0,1,0]      0.01      0.00      0.01      0.01      0.01     21.07      1.13
  L[1,0,0]      0.07      0.00      0.07      0.06      0.07     30.08      1.11
  L[1,1,0]      0.01      0.00      0.01      0.01      0.01     27.69      1.12
  L[2,0,0]      0.00      0.00      0.00      0.00      0.00    190.83      1.02
  L[2,1,0]      0.00      0

In [None]:
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-08-02 09:49:09,277 - hbmep.model.baseline - INFO - Generating predictions ...
2023-08-02 09:49:14,397 - hbmep.utils.utils - INFO - func:predict took: 5.12 sec
2023-08-02 09:49:14,422 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-08-02 09:49:16,106 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/mle/recruitment_curves.pdf
2023-08-02 09:49:16,107 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 6.83 sec


In [None]:
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-08-02 09:49:16,146 - hbmep.model.baseline - INFO - Generating predictions ...
2023-08-02 09:49:21,272 - hbmep.utils.utils - INFO - func:predict took: 5.12 sec
2023-08-02 09:49:21,312 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-08-02 09:49:24,535 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/mle/posterior_predictive_check.pdf
2023-08-02 09:49:24,536 - hbmep.utils.utils - INFO - func:render_predictive_check took: 8.39 sec
