In [1]:
%reload_ext autoreload
%autoreload 2

import os
from pathlib import Path
import multiprocessing

import numpy as np
import jax
import numpyro

from hbmep.config import Config
from hbmep.model import Model
from hbmep.model import RectifiedLogistic
from hbmep.model.utils import Site as site

from hbmep_paper.utils import simulate, run_experiment

# Run JAX / NumPyro on CPU for this notebook.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several virtual CPU devices so MCMC chains can run in parallel,
# leaving two cores free for the OS / notebook. Clamp to at least one device:
# on machines with <= 2 cores `cpu_count() - 2` would otherwise be 0 or negative.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()


#### Load config

In [2]:
# Repository root: three directory levels above the notebook's working directory.
root_path = Path.cwd().parent.parent.parent
toml_path = os.path.join(root_path, "configs/experiments.toml")


In [3]:
# Initialize and validate configuration
config = Config(toml_path=toml_path)

# Store artefacts under the repository's reports directory instead of a
# hard-coded, user-specific absolute path.
# NOTE(review): assumes `root_path` is the hbmep-paper repository root; the
# previous hard-coded value was /home/vishu/repos/hbmep-paper/reports/experiments-2/
config.BUILD_DIR = os.path.join(root_path, "reports/experiments-2/")


2023-08-01 13:55:12,269 - hbmep.config - INFO - Verifying configuration ...
2023-08-01 13:55:12,269 - hbmep.config - INFO - Success!


In [4]:
# Initialize model from the validated configuration (the log shows which link function it uses)
model = Model(config=config)

2023-08-01 13:55:12,335 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link


In [5]:
# Simulate a synthetic dataset from the model's generative process.
# `posterior_samples_true` holds the simulated draws used as ground truth;
# `obs` is its `site.obs` entry converted to a NumPy array.
simulation_params = {
    "n_subject": 3,
    "n_feature0": 10,
    "n_repeats": 10
}
df, posterior_samples_true = simulate(model=model.model, **simulation_params)

obs = np.array(posterior_samples_true[site.obs])

# Optionally restrict the analysis to one (participant, compound_position)
# combination, e.g. SUBSET = (0, 0). Leave as None to keep the full dataset.
SUBSET = None
if SUBSET is not None:
    ind = df[["participant", "compound_position"]].apply(tuple, axis=1).isin([SUBSET])
    df = df[ind].reset_index(drop=True)
    # assumes obs is indexed (draw, row, ...) with rows aligned to df — TODO confirm
    obs = obs[:, ind.to_numpy(), :]

2023-08-01 13:55:12,400 - hbmep_paper.utils.utils - INFO - Simulating data ...
2023-08-01 13:55:14,540 - hbmep.utils.utils - INFO - func:predict took: 2.14 sec
2023-08-01 13:55:14,540 - hbmep.utils.utils - INFO - func:simulate took: 2.14 sec


In [6]:
# Process the simulated frame; the returned encoder dictionary is what `model.plot` consumes.
df, encoder_dict = model.load(df=df)

# Use the final simulated draw as the observed response.
# NOTE(review): positional assignment assumes `model.load` preserves row order — confirm.
final_draw = obs[-1]
df[model.model.response] = final_draw

2023-08-01 13:55:14,632 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/experiments-2/
2023-08-01 13:55:14,633 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/experiments-2/
2023-08-01 13:55:14,633 - hbmep.dataset.core - INFO - Processing data ...
2023-08-01 13:55:14,635 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [7]:
# Plot the processed dataset; per the log output it is saved as dataset.pdf in the build directory.
model.plot(df=df, encoder_dict=encoder_dict)

2023-08-01 13:55:14,836 - hbmep.dataset.core - INFO - Plotting dataset ...
2023-08-01 13:55:19,693 - hbmep.dataset.core - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/experiments-2/dataset.pdf
2023-08-01 13:55:19,694 - hbmep.utils.utils - INFO - func:plot took: 4.86 sec


In [8]:
# Fit the model with MCMC. This is the expensive step: the recorded run took
# ~40 minutes on CPU (see the log output of this cell).
mcmc, posterior_samples = model.run_inference(df=df)

2023-08-01 13:55:19,763 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-08-01 14:35:00,887 - hbmep.utils.utils - INFO - func:run_inference took: 39 min and 41.12 sec


In [10]:
# Summarise the posterior: mean, std, median, 95% interval, n_eff and r_hat per parameter.
mcmc.print_summary(prob=0.95)


                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      1.37      0.96      1.14      0.18      3.24  19147.66      1.00
  H[0,1,0]      1.73      0.05      1.72      1.62      1.84  19871.45      1.00
  H[0,2,0]      2.39      3.38      1.11      0.00      9.30   7114.01      1.00
  H[1,0,0]      0.79      0.02      0.79      0.75      0.84  26267.49      1.00
  H[1,1,0]      1.05      0.02      1.05      1.01      1.08  25583.88      1.00
  H[1,2,0]      2.36      3.34      1.06      0.00      9.06   6755.63      1.00
  H[2,0,0]      0.33      0.01      0.33      0.31      0.34  25415.15      1.00
  H[2,1,0]      3.49      1.36      3.05      2.37      6.11   4601.65      1.00
  H[2,2,0]      2.64      3.37      1.37      0.03      9.43   6298.28      1.00
  H[3,0,0]      0.34      0.01      0.34      0.33      0.36  14814.42      1.00
  H[3,1,0]      1.12      0.02      1.12      1.08      1.15  19636.46      1.00
  H[3,2,0]      2.39      3

In [11]:
# NOTE(review): consider moving this import into the top imports cell.
import arviz as az

# Record the ArviZ version used for the conversion in the next cell.
az.__version__

'0.15.1'

In [15]:
# Convert the fitted MCMC run into an ArviZ InferenceData object for diagnostics.
numpyro_data = az.from_numpyro(mcmc)

In [19]:
# Inspect the shape of the posterior samples for site "a" (recorded output: (24000, 10, 3, 1)).
# NOTE(review): presumably the leading axis indexes posterior draws — confirm axis meanings.
posterior_samples["a"].shape

(24000, 10, 3, 1)

In [None]:
# Benchmark each candidate model against the simulated ground truth.
# NOTE(review): the captured output of this cell is timestamped 2023-07-31,
# earlier than the other cells (2023-08-01), so it is stale relative to them.
models = [RectifiedLogistic]

run_experiment(
    config=config,
    models=models,
    df=df,
    posterior_samples_true=posterior_samples_true,
)

2023-07-31 15:19:56,754 - hbmep.model.baseline - INFO - Initialized model with baseline link
2023-07-31 15:19:56,755 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/experiments-2/
2023-07-31 15:19:56,755 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/experiments-2/
2023-07-31 15:19:56,755 - hbmep.dataset.core - INFO - Processing data ...
2023-07-31 15:19:56,756 - hbmep.utils.utils - INFO - func:load took: 0.00 sec
2023-07-31 15:19:56,758 - hbmep_paper.utils.utils - INFO - 


2023-07-31 15:19:56,758 - hbmep_paper.utils.utils - INFO -  Experiment: 1/2, Model: 1/1 (rectified_logistic) 
2023-07-31 15:19:56,758 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link


2023-07-31 15:19:56,773 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-07-31 15:29:11,901 - hbmep.utils.utils - INFO - func:run_inference took: 9 min and 15.13 sec
2023-07-31 15:29:11,904 - hbmep_paper.utils.utils - INFO - Experiment: 1/2, Model: 1/1 (rectified_logistic), MSE: 0.5332164992237368, MAE: 0.493328597916945
2023-07-31 15:29:11,905 - hbmep_paper.utils.utils - INFO - Saving artefacts to /home/vishu/repos/hbmep-paper/reports/experiments-2/experiment_0/rectified_logistic
2023-07-31 15:29:11,905 - hbmep.model.baseline - INFO - Generating predictions ...
2023-07-31 15:29:17,100 - hbmep.utils.utils - INFO - func:predict took: 5.19 sec
2023-07-31 15:29:17,121 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-07-31 15:29:18,890 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/experiments-2/experiment_0/rectified_logistic/recruitment_curves.pdf
2023-07-31 15:29:18,891 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 6.99 sec
2023-07-31 15:29:18,891 - hbmep.model.baseline - INFO - Sa

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-07-31 15:36:30,399 - hbmep.utils.utils - INFO - func:run_inference took: 6 min and 55.15 sec
2023-07-31 15:36:30,401 - hbmep_paper.utils.utils - INFO - Experiment: 2/2, Model: 1/1 (rectified_logistic), MSE: 0.6190056204100881, MAE: 0.6837121217823968
2023-07-31 15:36:30,402 - hbmep_paper.utils.utils - INFO - Saving artefacts to /home/vishu/repos/hbmep-paper/reports/experiments-2/experiment_1/rectified_logistic
2023-07-31 15:36:30,402 - hbmep.model.baseline - INFO - Generating predictions ...
2023-07-31 15:36:35,531 - hbmep.utils.utils - INFO - func:predict took: 5.12 sec
2023-07-31 15:36:35,552 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-07-31 15:36:37,179 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/experiments-2/experiment_1/rectified_logistic/recruitment_curves.pdf
2023-07-31 15:36:37,179 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 6.78 sec
2023-07-31 15:36:37,179 - hbmep.model.baseline - INFO - S