In [1]:
%reload_ext autoreload
%autoreload 2

import os
import logging
import multiprocessing
from pathlib import Path
from tqdm import tqdm

import arviz as az
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns

import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import scipy.stats as stats
import numpyro
from numpyro.diagnostics import hpdi

from hbmep.config import Config
# from hbmep_paper.simulator import HierarchicalBayesianModel
from hbmep.model.utils import Site as site

# Point both JAX and NumPyro at the same backend.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Use all but two of the available cores as host devices, but never request
# fewer than one: on a machine with one or two cores the plain subtraction
# would produce a device count of 0 or below.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)

In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline
from hbmep_paper.utils.constants import HBM


class HierarchicalBayesianModel(Baseline):
    """Hierarchical Bayesian model defined on top of ``Baseline``.

    Hyper-priors are sampled inside the response and ``n_feature0`` plates;
    the curve parameters (``a``, ``b``, ``L``, ``H``, ``v``) and the noise
    parameters (``g_1``, ``g_2``) are sampled inside an additional subject
    plate. Observations follow a Gamma distribution whose mean is ``mu``
    (concentration ``mu * beta``, rate ``beta``).
    """

    LINK = HBM

    def __init__(self, config: Config):
        """Initialise the base model and set ``combination_columns``."""
        super().__init__(config=config)
        self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model definition.

        Args:
            subject: Per-observation index array used to select each
                observation's row of the subject-level parameters.
            features: Indexable collection; only ``features[0]`` is used, as
                the per-observation index into the ``n_feature0`` axis.
            intensity: Intensity values; reshaped to a column and repeated
                once per response.
            response_obs: Value passed as ``obs`` to the observation site.
                Defaults to ``None``.

        Returns:
            The value of the ``site.obs`` sample site.
        """
        # One intensity column per response.
        intensity = np.tile(intensity.reshape(-1, 1), (1, self.n_response))

        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        # NOTE: the order of the sample statements below is kept exactly as
        # in the original model definition.
        with numpyro.plate(site.n_response, self.n_response, dim=-1), \
                numpyro.plate("n_feature0", n_feature0, dim=-2):
            # Hyper-priors
            mu_a = numpyro.sample(
                site.mu_a,
                dist.TruncatedNormal(50, 50, low=0, high=100)
            )
            sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(50))

            sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.1))

            sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.5))
            sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
            sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(10))

            with numpyro.plate(site.n_subject, n_subject, dim=-3):
                # Priors; ``a`` is truncated to [0, 100].
                a = numpyro.sample(
                    site.a,
                    dist.TruncatedNormal(mu_a, sigma_a, low=0, high=100)
                )
                b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

                L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                g_1 = numpyro.sample(site.g_1, dist.Exponential(0.01))
                g_2 = numpyro.sample(site.g_2, dist.Exponential(0.01))

        # Model
        # Select each observation's parameters once instead of re-indexing
        # inside the formula.
        idx = (subject, feature0)
        a_obs, b_obs, v_obs = a[idx], b[idx], v[idx]
        L_obs, H_obs = L[idx], H[idx]

        exp_term = jnp.exp(-b_obs * (intensity - a_obs))
        denominator = jnp.power(
            1 + (jnp.power(1 + H_obs, v_obs) - 1) * exp_term,
            1 / v_obs
        )
        mu = numpyro.deterministic(
            site.mu,
            L_obs + jnp.maximum(0, -1 + (H_obs + 1) / denominator)
        )
        beta = numpyro.deterministic(
            site.beta,
            g_1[idx] + g_2[idx] * (1 / mu) ** 2
        )

        # Observation
        with numpyro.plate(site.data, n_data):
            # Gamma with mean ``mu``; the last axis is declared an event
            # dimension.
            likelihood = dist.Gamma(
                concentration=mu * beta, rate=beta
            ).to_event(1)
            return numpyro.sample(site.obs, likelihood, obs=response_obs)

In [3]:
# The config is looked up relative to the directory three levels above the
# current working directory.
root_path = Path.cwd().parent.parent.parent.absolute()
toml_path = os.path.join(
    root_path,
    "configs/human/tms/hbm-chains.toml",
)

# Build the model from the TOML configuration.
config = Config(toml_path=toml_path)
model = HierarchicalBayesianModel(config=config)


2023-09-20 10:07:36,907 - hbmep.config - INFO - Verifying configuration ...
2023-09-20 10:07:36,908 - hbmep.config - INFO - Success!
2023-09-20 10:07:36,921 - hbmep.model.baseline - INFO - Initialized model with hierarchical_bayesian link


In [4]:
# NOTE(review): machine-specific absolute path; other users must point this
# at their own copy of the data.
src = "/home/vishu/data/hbmep-processed/human/tms/data.csv"

# Read the CSV and pass it straight to the model's loader, which returns the
# data frame used from here on together with an encoder dictionary.
df, encoder_dict = model.load(df=pd.read_csv(src))


2023-09-20 10:07:36,963 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/human/tms/hbm-chains
2023-09-20 10:07:36,964 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/human/tms/hbm-chains
2023-09-20 10:07:36,965 - hbmep.dataset.core - INFO - Processing data ...
2023-09-20 10:07:36,966 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [6]:
# Run inference; the recorded log output shows this took about 7 minutes.
(mcmc, posterior_samples) = model.run_inference(
    df=df,
)


2023-09-20 10:07:37,044 - hbmep.model.baseline - INFO - Running inference with hierarchical_bayesian ...


  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

2023-09-20 10:14:50,574 - hbmep.utils.utils - INFO - func:run_inference took: 7 min and 13.53 sec


In [7]:
# Print the summary table using 95% intervals.
mcmc.print_summary(prob=0.95)


                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      5.43      5.01      4.14      0.00     14.88    295.87      1.01
  H[0,0,1]      4.75      4.48      3.50      0.01     13.44    265.19      1.01
  H[0,0,2]      4.02      3.84      2.87      0.02     11.59    334.43      1.00
  H[0,0,3]      3.19      3.30      2.10      0.01     10.16    244.16      1.00
  H[0,0,4]      2.24      2.84      1.32      0.00      6.95    136.01      1.01
  H[0,0,5]      2.43      2.93      1.46      0.00      8.17    163.84      1.01
  H[0,1,0]      9.74      2.68      9.13      6.05     15.27    174.37      1.02
  H[0,1,1]      6.62      0.65      6.56      5.39      7.87    190.44      1.03
  H[0,1,2]      2.16      3.11      0.93      0.00      8.44     31.88      1.08
  H[0,1,3]      0.66      0.05      0.66      0.57      0.76    227.49      1.02
  H[0,1,4]      2.02      2.60      1.20      0.17      6.42    130.92      1.02
  H[0,1,5]      1.43      2

In [56]:
# Build the prediction dataset from the loaded data and display its shape.
prediction_df = model.make_prediction_dataset(
    df=df,
)
prediction_df.shape

2023-09-20 11:47:17,750 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.00 sec


(600, 3)

In [57]:
# Predict on the prediction dataset using the posterior samples.
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

2023-09-20 11:47:24,602 - hbmep.utils.utils - INFO - func:predict took: 5.58 sec


In [58]:
# Render the recruitment curves; per the recorded log the result is saved as
# recruitment_curves.pdf in the reports directory.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)


2023-09-20 11:47:24,754 - hbmep.model.baseline - INFO - Rendering ...


2023-09-20 11:47:31,066 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/human/tms/hbm-chains/recruitment_curves.pdf
2023-09-20 11:47:31,067 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 6.33 sec


In [59]:
# Render the posterior predictive check; per the recorded log the result is
# saved as posterior_predictive_check.pdf in the reports directory.
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)


2023-09-20 11:47:31,175 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...


2023-09-20 11:47:38,722 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/human/tms/hbm-chains/posterior_predictive_check.pdf
2023-09-20 11:47:38,724 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 7.58 sec
2023-09-20 11:47:38,724 - hbmep.utils.utils - INFO - func:render_predictive_check took: 7.58 sec


In [71]:
import pickle

# Persist the model, the MCMC object and the posterior samples together in a
# single pickle inside the model's build directory.
dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, "wb") as f:
    pickle.dump(
        (model, mcmc, posterior_samples),
        f,
    )

In [72]:
import pickle

# Read the pickle back into separate names (trailing underscore) so the
# in-memory objects are not overwritten.
# NOTE: pickle.load can execute arbitrary code; only load files you created.
with open(dest, "rb") as g:
    (model_, mcmc_, posterior_samples_) = pickle.load(g)
