In [1]:
%reload_ext autoreload
%autoreload 2

import os
from pathlib import Path
import multiprocessing

import numpy as np
import jax
import numpyro

from hbmep.config import Config
from hbmep.model import Model
from hbmep.model.utils import Site as site

from hbmep_paper.model import HierarchicalBayesian, MaximumLikelihood
from hbmep_paper.utils import simulate, run_experiment

PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Keep two cores in reserve, but always request at least one host device:
# `cpu_count() - 2` is <= 0 on machines with two or fewer cores, which is not
# a usable device count for numpyro.set_host_device_count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()


In [2]:
# Resolve all paths relative to `root_path` (three directories above the
# notebook's working directory — presumably the repository root) rather than a
# machine-specific absolute path, so the notebook runs on other checkouts.
root_path = Path(os.getcwd()).parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/experiments.toml")

config = Config(toml_path=toml_path)
# Output directory for artefacts (copied config, dataset plot, recruitment
# curves, predictive checks). Trailing slash kept to match the previous value.
config.BUILD_DIR = os.path.join(root_path, "reports/hb/")


2023-08-01 11:12:00,051 - hbmep.config - INFO - Verifying configuration ...
2023-08-01 11:12:00,051 - hbmep.config - INFO - Success!


In [3]:
# Initialize the hbmep model from the config built above; the log line below
# reports which link function it was initialized with.
model = Model(config=config)

2023-08-01 11:12:00,112 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link


In [4]:
# Simulate a synthetic dataset with `simulate`, then keep only the selected
# participant(s) so that inference runs on a reduced dataset.
simulation_params = {
    "n_subject": 3,
    "n_feature0": 10,
    "n_repeats": 10
}
df, posterior_samples_true = simulate(model=model.model, **simulation_params)

# Simulated observations; axis 1 is indexed per row of `df` (see subsetting below).
obs = np.array(posterior_samples_true[site.obs])

SELECTED_PARTICIPANTS = [1]
# Plain boolean ndarray so it indexes both the DataFrame rows and the numpy
# observation array unambiguously.
ind = df["participant"].isin(SELECTED_PARTICIPANTS).to_numpy()

df = df[ind].reset_index(drop=True)
obs = obs[:, ind, ...]

assert obs.shape[1] == len(df), "simulated observations out of sync with df rows"

2023-08-01 11:12:00,153 - hbmep_paper.utils.utils - INFO - Simulating data ...


2023-08-01 11:12:03,823 - hbmep.utils.utils - INFO - func:predict took: 3.67 sec
2023-08-01 11:12:03,823 - hbmep.utils.utils - INFO - func:simulate took: 3.67 sec


In [5]:
# Process `df` with `model.load` (which also returns the encoder dict used for
# plotting), then use the final simulated draw as the observed response.
# NOTE(review): this reassigns `df` from its own processed output, so re-running
# this cell alone feeds an already-processed frame back into `model.load`.
df, encoder_dict = model.load(df=df)

last_draw = obs[-1]
df[model.model.response] = last_draw

2023-08-01 11:12:03,876 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/hb/
2023-08-01 11:12:03,876 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/hb/
2023-08-01 11:12:03,876 - hbmep.dataset.core - INFO - Processing data ...
2023-08-01 11:12:03,878 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [6]:
# Plot the processed dataset; the log reports it is saved as dataset.pdf in BUILD_DIR.
model.plot(df=df, encoder_dict=encoder_dict)

2023-08-01 11:12:03,950 - hbmep.dataset.core - INFO - Plotting dataset ...


2023-08-01 11:12:06,831 - hbmep.dataset.core - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/hb/dataset.pdf
2023-08-01 11:12:06,831 - hbmep.utils.utils - INFO - func:plot took: 2.88 sec


In [7]:
# Run MCMC on the subsetted dataset. Expensive: ~17 min on CPU in the recorded
# run — consider caching `posterior_samples` to disk to make Restart & Run All practical.
mcmc, posterior_samples = model.run_inference(df=df)

2023-08-01 11:12:06,883 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-08-01 11:29:20,851 - hbmep.utils.utils - INFO - func:run_inference took: 17 min and 13.97 sec


In [9]:
# Per-site posterior summary with 95% intervals (2.5% / 97.5% columns);
# check that r_hat is ~1.00 and n_eff is large before trusting the fit.
mcmc.print_summary(prob=.95)


                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      1.73      0.05      1.72      1.62      1.84  22464.84      1.00
  H[1,0,0]      1.05      0.02      1.05      1.01      1.08  28465.17      1.00
  H[2,0,0]      3.50      1.36      3.05      2.36      6.19   4193.27      1.00
  H[3,0,0]      1.12      0.02      1.12      1.08      1.15  23343.08      1.00
  H[4,0,0]      1.57      0.05      1.57      1.48      1.68  13004.32      1.00
  H[5,0,0]      3.84      0.05      3.84      3.74      3.94  27868.07      1.00
  H[6,0,0]      4.01      2.55      3.33      0.96      8.96  16587.49      1.00
  H[7,0,0]      3.85      0.04      3.85      3.78      3.92  24332.47      1.00
  H[8,0,0]      8.83      0.16      8.81      8.52      9.15  12538.41      1.00
  H[9,0,0]      1.88      0.02      1.88      1.85      1.92  24618.94      1.00
  L[0,0,0]      0.00      0.00      0.00      0.00      0.00  31003.95      1.00
  L[1,0,0]      0.05      0

In [10]:
# Render fitted recruitment curves; the log reports the output is saved as
# recruitment_curves.pdf in BUILD_DIR.
model.render_recruitment_curves(df=df, encoder_dict=encoder_dict, posterior_samples=posterior_samples)


2023-08-01 11:30:01,349 - hbmep.model.baseline - INFO - Generating predictions ...


2023-08-01 11:30:15,783 - hbmep.utils.utils - INFO - func:predict took: 14.43 sec
2023-08-01 11:30:15,856 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-08-01 11:30:23,055 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/hb/recruitment_curves.pdf
2023-08-01 11:30:23,055 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 21.71 sec


In [11]:
# Render the posterior predictive check; the log reports the output is saved as
# posterior_predictive_check.pdf in BUILD_DIR.
model.render_predictive_check(df=df, encoder_dict=encoder_dict, posterior_samples=posterior_samples)


2023-08-01 11:30:23,128 - hbmep.model.baseline - INFO - Generating predictions ...


2023-08-01 11:30:37,665 - hbmep.utils.utils - INFO - func:predict took: 14.53 sec
2023-08-01 11:30:37,770 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-08-01 11:30:46,702 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/hb/posterior_predictive_check.pdf
2023-08-01 11:30:46,704 - hbmep.utils.utils - INFO - func:render_predictive_check took: 23.58 sec


(10, 10, 3, 1)

In [25]:
# Compare the posterior-mean threshold `a` with the ground-truth `a` that
# generated the observed data.
combinations = model.model._make_combinations(df=df, columns=model.model.combination_columns)
# Reverse each combination and append a full slice, so `a_mean[idx]` selects
# the vector of per-muscle values for that combination.
combinations = [(*tuple(c)[::-1], slice(None)) for c in combinations]

# Ground truth. Indexing assumes axes are (draw, feature0, participant, muscle)
# — TODO confirm against `simulate`.
#   TRUE_DRAW: the last draw, i.e. the same draw attached as the response (`obs[-1]`);
#              using -1 instead of a hard-coded 9 keeps this correct if the number
#              of simulated draws changes.
#   TRUE_PARTICIPANT: the participant kept when subsetting the simulated data.
TRUE_DRAW = -1
TRUE_PARTICIPANT = 1
a_true = np.array(posterior_samples_true[site.a])
# Keep all muscles and flatten so the layout matches `a` below (combination-major).
a_true = a_true[TRUE_DRAW, :, TRUE_PARTICIPANT, :].reshape(-1)

a_samples = np.array(posterior_samples[site.a])     # n_posterior_samples x ... x n_muscles
a_mean = a_samples.mean(axis=0)                      # ... x n_muscles
a = np.array([a_mean[c] for c in combinations])      # n_combinations x n_muscles
a = a.reshape(-1)

assert a.shape == a_true.shape, (a.shape, a_true.shape)

In [26]:
# Ground-truth `a` used to simulate the observed data (computed in the previous cell).
a_true

array([237.73228229, 232.88597116, 229.95561155, 236.35789639,
       231.92548811, 240.19545285, 232.17914188, 240.53928484,
       231.7539535 , 230.77401365])

In [27]:
# Posterior-mean estimate of `a`, flattened so it is intended to line up
# element-wise with `a_true` above for a side-by-side comparison.
a

array([236.72332566, 232.76817896, 229.92469533, 237.42854246,
       232.36914194, 240.28812928, 231.95209277, 240.73974195,
       231.66133812, 230.91878115])