In [1]:
%reload_ext autoreload
%autoreload 2

import os
from pathlib import Path
import multiprocessing

import jax
import pandas as pd
import numpy as np
import numpyro

from hbmep.config import Config
from hbmep.model import Model

from hbmep_paper.utils import simulate

# Pin both JAX and NumPyro to the same backend.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose all but two of the machine's cores as host devices.
# Clamp to at least one device: on hosts with two or fewer cores the plain
# subtraction would yield 0 or a negative count, which is not a valid
# device count to request.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
# Switch NumPyro/JAX to 64-bit precision.
numpyro.enable_x64()


#### Load config

In [2]:
# Directory three levels above the current working directory, made absolute.
root_path = (
    Path(os.getcwd())
    .parent.parent.parent
    .absolute()
)
# Location of the experiments TOML file under that directory (kept as a str).
toml_path = os.path.join(str(root_path), "configs/experiments.toml")


In [3]:
# Build the hbmep Config from the TOML file at `toml_path`.
# In the recorded run this logs a "Verifying configuration ..." step
# followed by "Success!".
config = Config(toml_path=toml_path)


2023-07-31 11:44:52,757 - hbmep.config - INFO - Verifying configuration ...
2023-07-31 11:44:52,757 - hbmep.config - INFO - Success!


In [4]:
# Instantiate the hbmep Model from `config`.
# The recorded run reports that it was initialized with the
# `rectified_logistic` link.
model = Model(config=config)

2023-07-31 11:44:52,792 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link


In [5]:
# Simulate a dataset from the model and hand it straight to `model.load`,
# which gives back the data frame together with an encoder dictionary.
df, encoder_dict = model.load(
    df=simulate(model=model.model)
)


2023-07-31 11:44:52,902 - hbmep_paper.utils.utils - INFO - Simulating data ...


2023-07-31 11:44:55,147 - hbmep.utils.utils - INFO - func:predict took: 2.24 sec
2023-07-31 11:44:55,161 - hbmep.utils.utils - INFO - func:simulate took: 2.36 sec
2023-07-31 11:44:55,161 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/experiments/
2023-07-31 11:44:55,162 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/experiments/
2023-07-31 11:44:55,163 - hbmep.dataset.core - INFO - Processing data ...
2023-07-31 11:44:55,163 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [6]:
# Plot the loaded dataset.
# In the recorded run the figure was saved as `dataset.pdf` in the
# reports directory.
model.plot(df=df, encoder_dict=encoder_dict)

2023-07-31 11:44:55,175 - hbmep.dataset.core - INFO - Plotting dataset ...
2023-07-31 11:45:00,009 - hbmep.dataset.core - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/experiments/dataset.pdf
2023-07-31 11:45:00,010 - hbmep.utils.utils - INFO - func:plot took: 4.84 sec


In [7]:
# Run inference on the loaded dataset; the call returns a pair that is
# unpacked into `mcmc` and `posterior_samples`.
# NOTE: this is the expensive step. The recorded run shows four progress
# bars of 10000 iterations each and took about 34 minutes.
mcmc, posterior_samples = model.run_inference(df=df)


2023-07-31 11:45:52,339 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-07-31 12:19:33,416 - hbmep.utils.utils - INFO - func:run_inference took: 33 min and 41.08 sec


In [8]:
# Print the posterior summary table with 95% intervals.
mcmc.print_summary(prob=0.95)



                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      0.56      0.01      0.56      0.55      0.57  48308.11      1.00
  H[0,1,0]      2.41      2.61      1.55      0.14      7.45  13755.01      1.00
  H[0,2,0]      1.30      1.52      0.84      0.06      3.99  11392.67      1.00
  H[1,0,0]      2.13      0.03      2.13      2.08      2.19  31473.29      1.00
  H[1,1,0]      3.01      2.63      2.19      0.50      8.05  12425.46      1.00
  H[1,2,0]      1.12      0.17      1.08      0.87      1.47   5813.68      1.00
  H[2,0,0]      1.91      0.02      1.91      1.88      1.94  33127.95      1.00
  H[2,1,0]      2.73      2.63      1.91      0.21      7.80  13956.79      1.00
  H[2,2,0]      1.27      1.33      0.86      0.32      3.40   9083.17      1.00
  H[3,0,0]      0.72      0.01      0.72      0.69      0.74  40837.43      1.00
  H[3,1,0]      2.38      2.66      1.55      0.05      7.38  14626.83      1.00
  H[3,2,0]      1.13      1

In [9]:
# Render the recruitment curves from the posterior samples.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-31 12:19:34,098 - hbmep.model.baseline - INFO - Generating predictions ...


2023-07-31 12:20:00,136 - hbmep.utils.utils - INFO - func:predict took: 26.03 sec
2023-07-31 12:20:00,233 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-07-31 12:20:08,897 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/experiments/recruitment_curves.pdf
2023-07-31 12:20:08,898 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 34.80 sec


In [10]:
# Render the posterior predictive check from the posterior samples.
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-31 12:20:09,047 - hbmep.model.baseline - INFO - Generating predictions ...
2023-07-31 12:20:35,106 - hbmep.utils.utils - INFO - func:predict took: 26.05 sec
2023-07-31 12:20:35,281 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-07-31 12:20:51,602 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/experiments/posterior_predictive_check.pdf
2023-07-31 12:20:51,604 - hbmep.utils.utils - INFO - func:render_predictive_check took: 42.56 sec
