In [1]:
%reload_ext autoreload
%autoreload 2

import os
import logging
import multiprocessing
from pathlib import Path

import numpyro
from hbmep.config import MepConfig
from hbmep.dataset import MepDataset
from hbmep.models import Model

# Run JAX/numpyro on CPU only.
numpyro.set_platform("cpu")

# Expose several host CPU devices so MCMC chains can run in parallel.
# Two cores are held back (presumably for the OS / notebook process), but the
# count is clamped to at least 1: on machines with <= 2 cores
# `cpu_count() - 2` would otherwise be 0 or negative, which is not a valid
# device count. Per the numpyro docs this must be called before JAX
# initializes its backend, hence its placement in the setup cell.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)

# Use 64-bit floating point in JAX.
numpyro.enable_x64()

FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


#### Load config


In [2]:
# The config file sits in the parent of the notebook's working directory
root_path = Path.cwd().parent
toml_path = str(root_path / "config.toml")
logger.info(f"Toml path - {toml_path}")

# Parse the TOML file; MepConfig verifies the configuration when constructed
config = MepConfig(toml_path=toml_path)


2023-07-11 09:13:36,178 - __main__ - INFO - Toml path - /home/vishu/repos/hbmep/config.toml
2023-07-11 09:13:36,178 - hbmep.config - INFO - Verifying configuration ...
2023-07-11 09:13:36,179 - hbmep.config - INFO - Success!


#### Load data and preprocess

In [3]:
# Create the dataset wrapper from the validated config
data = MepDataset(
    config=config,
)

# Preprocess: build() yields the processed frame plus the encoder mapping
# that the plotting and model methods below take as `encoder_dict`
(df, encoder_dict) = data.build()


2023-07-11 09:13:36,197 - hbmep.dataset.core - INFO - Initialized /home/vishu/repos/hbmep/reports/test_run_01 for storing artefacts
2023-07-11 09:13:36,198 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep/reports/test_run_01
2023-07-11 09:13:36,198 - hbmep.dataset.core - INFO - Reading data from /home/vishu/data/mock.csv ...
2023-07-11 09:13:36,200 - hbmep.dataset.core - INFO - Processing data ...
2023-07-11 09:13:36,202 - hbmep.utils.utils - INFO - func:build took: 0.00 sec


#### Visualize dataset

In [4]:
# Plot the preprocessed dataset; the output PDF location is logged
data.plot(
    df=df,
    encoder_dict=encoder_dict,
)


2023-07-11 09:13:37,115 - hbmep.dataset.core - INFO - Saved to /home/vishu/repos/hbmep/reports/test_run_01/dataset.pdf
2023-07-11 09:13:37,115 - hbmep.utils.utils - INFO - func:plot took: 0.90 sec


#### Initialize model

In [5]:
# Build the model from the same config used for the dataset
model = Model(
    config=config,
)


2023-07-11 09:14:59,486 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-07-11 09:14:59,486 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-07-11 09:14:59,486 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
2023-07-11 09:14:59,487 - jax._src.xla_bridge - INFO - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.


Prior predictive check: before fitting, we draw samples from the model to check whether it correctly encodes our prior knowledge.

In [6]:
# Called without posterior samples, this renders the prior predictive check
# (written to prior_predictive_check.pdf, as logged below)
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
)


2023-07-11 09:15:11,452 - hbmep.models.baseline - INFO - Saved to /home/vishu/repos/hbmep/reports/test_run_01/prior_predictive_check.pdf
2023-07-11 09:15:11,452 - hbmep.utils.utils - INFO - func:render_predictive_check took: 3.93 sec


#### Run MCMC inference

In [7]:
# Fit the model with MCMC. This is the expensive step (about 4.5 min in the
# recorded run). Returns the MCMC object and the posterior samples.
(mcmc, posterior_samples) = model.run_inference(
    df=df,
)


2023-07-11 09:15:53,416 - hbmep.models.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-07-11 09:20:27,584 - hbmep.utils.utils - INFO - func:run_inference took: 4 min and 34.17 sec


#### Diagnostics

In [8]:
# Posterior summary with 95% intervals (the 2.5% / 97.5% columns);
# n_eff and r_hat are the usual mixing / convergence diagnostics
mcmc.print_summary(prob=0.95)



                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      6.01      0.10      6.01      5.82      6.21  27135.74      1.00
  H[0,0,1]      3.78      0.06      3.78      3.66      3.89  35528.31      1.00
  H[0,1,0]     15.38     24.24      6.37      0.00     61.71  15822.76      1.00
  H[0,1,1]      2.51      7.24      0.85      0.00      9.28  13582.87      1.00
  H[1,0,0]      4.28      0.19      4.27      3.93      4.68  16206.16      1.00
  H[1,0,1]      1.49      0.04      1.49      1.42      1.56  27128.11      1.00
  H[1,1,0]     15.59     24.64      6.36      0.00     63.18  15866.57      1.00
  H[1,1,1]      2.44      6.42      0.85      0.00      9.22  12944.19      1.00
  H[2,0,0]     11.04     15.44      6.23      0.00     37.90  17663.81      1.00
  H[2,0,1]      7.37     11.79      3.72      0.00     26.94  16387.25      1.00
  H[2,1,0]      0.35      0.02      0.35      0.31      0.39  18217.75      1.00
  H[2,1,1]      0.56      0

#### Plot recruitment curves

In [9]:
# Render recruitment curves from the posterior samples
# (written to recruitment_curves.pdf, as logged below)
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-11 09:20:39,288 - hbmep.models.baseline - INFO - Saved to /home/vishu/repos/hbmep/reports/test_run_01/recruitment_curves.pdf
2023-07-11 09:20:39,288 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 10.58 sec


#### Posterior predictive check

We can now supply the posterior samples to the `render_predictive_check` method to inspect how well the model explains the data.

In [10]:
# Same method as the prior check, now given the posterior samples
# (written to posterior_predictive_check.pdf, as logged below)
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-11 09:20:51,900 - hbmep.models.baseline - INFO - Saved to /home/vishu/repos/hbmep/reports/test_run_01/posterior_predictive_check.pdf
2023-07-11 09:20:51,902 - hbmep.utils.utils - INFO - func:render_predictive_check took: 12.54 sec
