In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run JAX / NumPyro on the CPU backend.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several host CPU devices to XLA so MCMC chains can run in parallel,
# leaving two cores free. Clamp to at least one device: cpu_count() - 2 is
# <= 0 on machines with two or fewer cores, which is not a valid device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()  # use 64-bit floats
numpyro.enable_validation()  # validate distribution arguments/support

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline

class ReLU(Baseline):
    """Hierarchical rectified-linear-unit (ReLU) recruitment-curve model.

    For every (subject, feature0, response) cell the expected response is

        mu = L + b * (intensity - a)   if intensity > a
        mu = L                          otherwise

    i.e. a flat offset ``L`` up to the threshold ``a`` and a linear rise with
    slope ``b`` above it. Observations are Gamma distributed with mean ``mu``
    (concentration ``mu * beta``, rate ``beta``), where the rate
    ``beta = g_1 + g_2 * (1 / (mu + 1)) ** p`` varies with ``mu``.
    """

    LINK = "RectifiedLinearUnit"

    def __init__(self, config: Config):
        super().__init__(config=config)
        # Feature columns plus the subject column; presumably the set of
        # columns that identifies one recruitment curve (verify in Baseline).
        self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro generative model.

        Args:
            subject: ``(subject_idx, n_subject)`` pair. ``subject_idx`` indexes
                the per-subject parameters inside the data plate, so it
                presumably has one entry per data row.
            features: ``(feature_idx, n_features)`` pair; only the first
                feature (``feature_idx[0]`` / ``n_features[0]``) is used.
            intensity: ``(intensity_values, n_data)`` pair.
            response_obs: Observed responses for the ``site.obs`` site, or
                ``None`` to have NumPyro sample that site instead.
        """
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        # Broadcast intensity to one column per response: (n_data, n_response).
        # jnp.tile (rather than np.tile) keeps this valid when intensity is a
        # JAX tracer, e.g. when model arguments are jitted.
        intensity = jnp.tile(intensity.reshape(-1, 1), (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # Global, per-response scales shared by all feature0 levels.
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))

            global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                # Hyper-priors. Each sigma_* below is a unit HalfNormal draw
                # multiplied by the matching global scale.
                mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=5))
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(scale=1))

                sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                sigma_b = numpyro.deterministic(site.sigma_b, global_sigma_b * sigma_b_raw)

                sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                sigma_L = numpyro.deterministic(site.sigma_L, global_sigma_L * sigma_L_raw)

                sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                sigma_g_1 = numpyro.deterministic("sigma_g_1", global_sigma_g_1 * sigma_g_1_raw)

                sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                sigma_g_2 = numpyro.deterministic("sigma_g_2", global_sigma_g_2 * sigma_g_2_raw)

                sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
                sigma_p = numpyro.deterministic("sigma_p", global_sigma_p * sigma_p_raw)

                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    # Priors. Scaling a Gamma(mu_a, rate=1) draw by 1 / sigma_a
                    # is equivalent to a ~ Gamma(concentration=mu_a, rate=sigma_a).
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=mu_a, rate=1))
                    a = numpyro.deterministic(site.a, (1 / sigma_a) * a_raw)

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, sigma_b * b_raw)

                    L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    L = numpyro.deterministic(site.L, sigma_L * L_raw)

                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, sigma_g_1 * g_1_raw)

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, sigma_g_2 * g_2_raw)

                    p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                    p = numpyro.deterministic("p", sigma_p * p_raw)

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.data, n_data, dim=-2):
                # Model: offset L, plus slope b once intensity exceeds threshold a.
                mu = numpyro.deterministic(
                    site.mu,
                    L[subject, feature0]
                    + jnp.where(
                        intensity <= a[subject, feature0],
                        0,
                        b[subject, feature0] * (intensity - a[subject, feature0])
                    )
                )
                beta = numpyro.deterministic(
                    site.beta,
                    g_1[subject, feature0] + g_2[subject, feature0] * ((1 / (mu + 1)) ** p[subject, feature0])
                )

                # Observation: Gamma with mean mu (concentration / rate = mu).
                numpyro.sample(
                    site.obs,
                    dist.Gamma(concentration=mu * beta, rate=beta),
                    obs=response_obs
                )


In [3]:
# Model configuration is read from a TOML file.
# NOTE(review): absolute, machine-specific path; adjust for your checkout.
toml_path = "/home/vishu/repos/hbmep-paper/configs/paper/tms/ReLU.toml"

config = Config(toml_path=toml_path)

# Optional overrides for a longer MCMC run (uncomment to use):
# config.MCMC_PARAMS["num_warmup"] = 40000
# config.MCMC_PARAMS["num_samples"] = 60000
# config.MCMC_PARAMS["thinning"] = 4

model = ReLU(config=config)


2023-10-17 10:44:53,245 - hbmep.config - INFO - Verifying configuration ...
2023-10-17 10:44:53,246 - hbmep.config - INFO - Success!


2023-10-17 10:44:53,444 - hbmep.model.baseline - INFO - Initialized model with RectifiedLinearUnit link


In [4]:
# Read the raw CSV and let the model process it; model.load returns the
# processed frame together with an encoder dictionary.
src = "/home/vishu/data/hbmep-processed/human/tms/data_pkpk_auc.csv"
df, encoder_dict = model.load(df=pd.read_csv(src))


2023-10-17 10:44:53,479 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/tms/ReLU
2023-10-17 10:44:53,480 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/tms/ReLU
2023-10-17 10:44:53,482 - hbmep.dataset.core - INFO - Processing data ...
2023-10-17 10:44:53,483 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# Run MCMC on the processed data; returns the sampler and its posterior draws.
mcmc, posterior_samples = model.run_inference(
    df=df,
)


2023-10-17 10:44:53,515 - hbmep.model.baseline - INFO - Running inference with RectifiedLinearUnit ...


  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

In [6]:
# Per-site posterior summary with 95% intervals, n_eff and r_hat.
mcmc.print_summary(prob=0.95)



                         mean       std    median      2.5%     97.5%     n_eff     r_hat
       L_raw[0,0,0]      0.67      0.59      0.50      0.00      1.81     42.07      1.08
       L_raw[0,0,1]      0.88      0.60      0.72      0.04      2.12     66.65      1.03
       L_raw[0,0,2]      0.75      0.57      0.66      0.00      1.83     16.27      1.17
       L_raw[0,0,3]      0.69      0.54      0.56      0.00      1.72     27.03      1.21
       L_raw[0,0,4]      0.79      0.61      0.67      0.01      1.96     24.03      1.12
       L_raw[0,0,5]      0.79      0.52      0.67      0.03      1.78     72.43      1.09
       L_raw[0,1,0]      0.34      0.10      0.33      0.18      0.53     27.88      1.27
       L_raw[0,1,1]      0.65      0.26      0.57      0.25      1.22      4.64      1.99
       L_raw[0,1,2]      0.79      0.27      0.78      0.33      1.25      3.06      2.29
       L_raw[0,1,3]      0.74      0.19      0.70      0.43      1.18     13.78      1.36
       L_

In [7]:
# Build the prediction dataset and draw the posterior predictive on it.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

# Render fitted recruitment curves and the posterior predictive check.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)


2023-10-12 14:55:12,629 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec
2023-10-12 14:55:19,963 - hbmep.utils.utils - INFO - func:predict took: 7.33 sec
2023-10-12 14:55:19,985 - hbmep.model.baseline - INFO - Rendering ...
2023-10-12 14:55:29,692 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/ReLU/recruitment_curves.pdf
2023-10-12 14:55:29,693 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 9.73 sec
2023-10-12 14:55:29,733 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-10-12 14:55:39,718 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/ReLU/posterior_predictive_check.pdf
2023-10-12 14:55:39,719 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 10.03 sec
2023-10-12 14:55:39,719 - hbmep.utils.utils - INFO - func:render_predictive_check took: 10.03 sec


In [8]:
# Convert the MCMC run into an ArviZ InferenceData object.
numpyro_data = az.from_numpyro(mcmc)

# Model evaluation
logger.info("Evaluating model ...")


2023-10-12 14:56:28,138 - __main__ - INFO - Evaluating model ...


In [9]:
# PSIS-LOO estimate of expected log predictive density (log scale).
score = az.loo(numpyro_data)
# Report the standard error alongside the point estimate: ELPD differences
# between models are only meaningful relative to their SE.
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f} (SE: {score.se:.2f})")
# ArviZ sets `warning` when Pareto-k diagnostics flag unreliable estimates;
# surface it in the log instead of relying on a stray stderr warning.
if score.warning:
    logger.warning("LOO: Pareto k diagnostic is high for some observations; ELPD LOO may be unreliable")

  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
2023-10-12 14:57:31,182 - __main__ - INFO - ELPD LOO (Log): 8905.49


In [10]:
# WAIC estimate of expected log predictive density (log scale).
score = az.waic(numpyro_data)
# Report the standard error alongside the point estimate.
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f} (SE: {score.se:.2f})")
# ArviZ sets `warning` when the posterior variance of the pointwise log
# predictive density is large, i.e. WAIC may be failing; prefer LOO then.
if score.warning:
    logger.warning("WAIC: posterior variance of log predictive densities is high; ELPD WAIC may be unreliable")

See http://arxiv.org/abs/1507.04544 for details
2023-10-12 14:58:26,507 - __main__ - INFO - ELPD WAIC (Log): 8916.92


In [13]:
# Persist the model, the MCMC object and the posterior draws together so the
# analysis can be reloaded without re-running inference. (`pickle` is already
# imported in the first cell.)
# NOTE: only unpickle this file from a trusted source; pickle.load can execute
# arbitrary code.
dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, "wb") as f:
    pickle.dump((model, mcmc, posterior_samples), f)


In [12]:
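# Reload a saved inference run (uncomment to use; `dest` is set in the save
# cell). Only unpickle files you trust: pickle.load can execute arbitrary code.
# import pickle

# with open(dest, "rb") as g:
#     model_, mcmc_, posterior_samples_ = pickle.load(g)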
