In [2]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run JAX/NumPyro on CPU; extra host "devices" let NumPyro run MCMC chains in parallel.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Leave two cores free for the rest of the machine, but never request fewer than one
# device: on a host with <= 2 cores `cpu_count() - 2` would be 0 or negative.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
# Only takes effect if called before JAX initializes its CPU backend, so keep it in
# this first cell.
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class Logistic5(Baseline):
    """Hierarchical recruitment-curve model using a five-parameter logistic link.

    For every (subject, feature0, response) triple the model draws curve
    parameters ``a, b, v, L, H`` and noise parameters ``g_1, g_2, p``. Except
    for ``a`` (see ``_model``), each is non-centred: a unit-scale ``*_raw``
    draw times a per-(feature0, response) scale, which is itself a unit-scale
    draw times a per-response global scale.
    """
    LINK = "Logistic5"

    def __init__(self, config: Config):
        """Initialize via ``Baseline`` and record the curve-defining columns.

        Args:
            config: hbmep ``Config`` forwarded unchanged to ``Baseline``.
        """
        super(Logistic5, self).__init__(config=config)
        # All configured feature columns plus the subject column; in this
        # notebook's run: ['compound_position', 'compound_charge_params', 'participant'].
        self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro generative model for the responses.

        Parameters live on plates with dims (-3, -2, -1) =
        (subject, feature0, response), so ``param[subject, feature0]`` yields an
        array of shape ``(n_data, n_response)`` aligned with the tiled intensity.

        Args:
            subject: integer subject code per observation, used directly as an
                index into the subject plate (sized by the number of unique
                codes) — assumes codes are exactly 0..n_subject-1; TODO confirm
                against the encoder used by ``model.load``.
            features: sequence of encoded feature arrays; only ``features[0]``
                is used by this model.
            intensity: stimulation intensity per observation; reshaped to a
                column and tiled across ``self.n_response`` responses.
            response_obs: observed responses, shape ``(n_data, n_response)``,
                or None to sample them (e.g. for prediction).

        Returns:
            The value of the ``site.obs`` sample site.
        """
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        # NOTE(review): combination_columns also includes the second feature
        # (compound_charge_params), but only features[0] enters the model, so
        # rows differing only in later features share parameters — confirm intended.
        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        # NOTE: keep the order of the numpyro.sample calls. Under numpyro's seed
        # handler each site receives the next PRNG key, so reordering changes
        # initialization/draws and breaks bit-for-bit reproducibility.
        # Per-response global scales, shape (n_response,).
        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
            global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))

            global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))
            global_sigma_H = numpyro.sample("global_sigma_H", dist.HalfNormal(5))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

            # Per-(feature0, response) scales, shape (n_feature0, n_response),
            # each a unit half-normal draw times the matching global scale.
            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                """ Hyper-priors """
                mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=10))
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(scale=.5))

                sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                sigma_b = numpyro.deterministic(site.sigma_b, global_sigma_b * sigma_b_raw)

                sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
                sigma_v = numpyro.deterministic(site.sigma_v, global_sigma_v * sigma_v_raw)

                sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                sigma_L = numpyro.deterministic(site.sigma_L, global_sigma_L * sigma_L_raw)

                sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                sigma_H = numpyro.deterministic(site.sigma_H, global_sigma_H * sigma_H_raw)

                sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                sigma_g_1 = numpyro.deterministic("sigma_g_1", global_sigma_g_1 * sigma_g_1_raw)

                sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                sigma_g_2 = numpyro.deterministic("sigma_g_2", global_sigma_g_2 * sigma_g_2_raw)

                sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
                sigma_p = numpyro.deterministic("sigma_p", global_sigma_p * sigma_p_raw)

                # Per-(subject, feature0, response) parameters,
                # shape (n_subject, n_feature0, n_response).
                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    """ Priors """
                    # a = a_raw / sigma_a with a_raw ~ Gamma(mu_a, rate=1), i.e.
                    # a ~ Gamma(concentration=mu_a, rate=sigma_a). Despite its name,
                    # mu_a is a shape parameter; E[a] = mu_a / sigma_a.
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=mu_a, rate=1))
                    a = numpyro.deterministic(site.a, (1 / sigma_a) * a_raw)

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, sigma_b * b_raw)

                    v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    v = numpyro.deterministic(site.v, sigma_v * v_raw)

                    L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    L = numpyro.deterministic(site.L, sigma_L * L_raw)

                    H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                    H = numpyro.deterministic(site.H, sigma_H * H_raw)

                    # NOTE(review): HalfCauchy raws times HalfNormal(100) global scales
                    # give very heavy-tailed noise priors. The recorded print_summary
                    # shows r_hat >> 1, and these priors are a plausible contributor.
                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, sigma_g_1 * g_1_raw)

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, sigma_g_2 * g_2_raw)

                    p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                    p = numpyro.deterministic("p", sigma_p * p_raw)

        # Richards-type generalized logistic curve, per observation and response:
        #   mu = L + H / (1 + v * exp(-b * (x - a))) ** (1 / v)
        # With b > 0, mu -> L at low intensity and mu -> L + H at high intensity.
        """ Model """
        mu = numpyro.deterministic(
            site.mu,
            L[subject, feature0]
            + (
                H[subject, feature0]
                / jnp.power(
                    1 + v[subject, feature0] * jnp.exp(-b[subject, feature0] * (intensity - a[subject, feature0])),
                    1 / v[subject, feature0]
                )
            )
        )
        # Gamma rate; the noise level depends on mu through g_2 * (1 / (mu + 1)) ** p.
        beta = numpyro.deterministic(
            site.beta,
            g_1[subject, feature0] + g_2[subject, feature0] * ((1 / (mu + 1)) ** p[subject, feature0])
        )

        # Gamma(concentration=mu * beta, rate=beta) has mean mu and variance mu / beta.
        # to_event(1) makes the response axis part of the event, so this plate
        # covers only the observation axis.
        """ Observation """
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
                obs=response_obs
            )


In [3]:
# Repository root = four directory levels above the notebook's working directory;
# the model configuration TOML lives under it.
root_path = Path(os.getcwd())
for _level in range(4):
    root_path = root_path.parent
root_path = root_path.absolute()
toml_path = os.path.join(root_path, "configs/paper/rats/J_RCML_000/Logistic5.toml")

config = Config(toml_path=toml_path)
# Optional MCMC overrides (disabled; uncomment to lengthen the run):
# config.MCMC_PARAMS["num_warmup"] = 4000
# config.MCMC_PARAMS["num_samples"] = 6000
# config.MCMC_PARAMS["thinning"] = 4

model = Logistic5(config=config)


2023-10-13 15:12:36,175 - hbmep.config - INFO - Verifying configuration ...
2023-10-13 15:12:36,175 - hbmep.config - INFO - Success!
2023-10-13 15:12:36,189 - hbmep.model.baseline - INFO - Initialized model with Logistic5 link


In [4]:
# Processed dataset. Set HBMEP_J_RCML_000_CSV to point at a local copy instead of
# relying on one machine's absolute home-directory path.
src = os.environ.get(
    "HBMEP_J_RCML_000_CSV",
    "/home/vishu/data/hbmep-processed/J_RCML_000/data.csv",
)
df = pd.read_csv(src)

# NOTE(review): presumably disables MEP-matrix loading for this run — confirm in Baseline.
model.mep_matrix_path = None

# Restrict the fit to a subset of subjects. Fail loudly if any are absent, so the
# model is never silently fit to fewer subjects than requested.
subset = ["amap01", "amap02"]
ind = df[model.subject].isin(subset)
missing = sorted(set(subset) - set(df.loc[ind, model.subject]))
if missing:
    raise ValueError(f"Subjects {missing} not found in column {model.subject!r} of {src}")
df = df[ind].reset_index(drop=True).copy()

df, encoder_dict = model.load(df=df)


2023-10-13 15:12:36,271 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/Logistic5
2023-10-13 15:12:36,272 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/Logistic5
2023-10-13 15:12:36,273 - hbmep.dataset.core - INFO - Processing data ...
2023-10-13 15:12:36,274 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# Render the loaded dataset (saved as dataset.pdf under the model's build directory).
model.plot(encoder_dict=encoder_dict, df=df)

2023-10-13 15:12:36,347 - hbmep.model.baseline - INFO - Rendering ...
2023-10-13 15:12:46,864 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/Logistic5/dataset.pdf
2023-10-13 15:12:46,864 - hbmep.utils.utils - INFO - func:plot took: 10.52 sec


In [7]:
# Fit the model with MCMC (about 7 min 43 s on CPU in the recorded run).
inference_output = model.run_inference(df=df)
mcmc, posterior_samples = inference_output


2023-10-13 15:14:38,171 - hbmep.model.baseline - INFO - Running inference with Logistic5 ...


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

2023-10-13 15:22:21,019 - hbmep.utils.utils - INFO - func:run_inference took: 7 min and 42.85 sec


In [9]:
from numpyro.diagnostics import summary as numpyro_summary

mcmc.print_summary(prob=.95)

# The recorded summary shows r_hat up to ~8.8 and n_eff ~2, i.e. the chains have
# not mixed. Surface that explicitly instead of letting the LOO/WAIC, plots and
# pickled results below silently rely on unconverged draws.
R_HAT_MAX = 1.05

diagnostics = numpyro_summary(mcmc.get_samples(group_by_chain=True), prob=.95, group_by_chain=True)
r_hats = [np.ravel(np.asarray(stats["r_hat"])) for stats in diagnostics.values()]
r_hats = np.concatenate(r_hats) if r_hats else np.array([])
r_hats = r_hats[~np.isnan(r_hats)]
if r_hats.size and r_hats.max() > R_HAT_MAX:
    logger.warning(
        f"MCMC has not converged: max r_hat = {r_hats.max():.2f} (> {R_HAT_MAX}); "
        "increase num_warmup/num_samples or reparameterize before using the posterior."
    )



                         mean       std    median      2.5%     97.5%     n_eff     r_hat
       H_raw[0,0,0]      1.56      1.10      1.69      0.02      3.23      2.03      8.83
       H_raw[0,0,1]      0.65      0.27      0.66      0.28      1.14      3.98      2.27
       H_raw[0,0,2]      0.70      0.22      0.64      0.35      1.07      3.95      2.04
       H_raw[0,0,3]      1.11      0.36      1.13      0.51      1.69      2.52      3.14
       H_raw[0,0,4]      1.40      0.41      1.40      0.75      2.07      2.42      3.09
       H_raw[0,0,5]      1.14      0.22      1.06      0.77      1.53      2.80      2.40
       H_raw[0,1,0]      0.93      0.92      0.56      0.01      2.51      3.34      2.33
       H_raw[0,1,1]      0.84      0.42      0.69      0.24      1.69      2.91      3.34
       H_raw[0,1,2]      0.60      0.25      0.57      0.23      1.06      3.51      2.04
       H_raw[0,1,3]      0.81      0.48      0.66      0.32      1.75      2.94      2.10
       H_

In [10]:
# Columns set as combination_columns in Logistic5.__init__ (features + subject).
# Note that only the first feature column is used inside Logistic5._model.
model.combination_columns

['compound_position', 'compound_charge_params', 'participant']

In [10]:
# Build the prediction dataset and draw posterior predictive samples on it, then
# render recruitment curves and the posterior predictive check (saved as
# recruitment_curves.pdf and posterior_predictive_check.pdf in the build directory).
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(df=prediction_df, posterior_samples=posterior_samples)

render_kwargs = dict(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_recruitment_curves(posterior_samples=posterior_samples, **render_kwargs)
model.render_predictive_check(**render_kwargs)


2023-10-13 15:22:46,259 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.00 sec
2023-10-13 15:23:14,522 - hbmep.utils.utils - INFO - func:predict took: 28.26 sec
2023-10-13 15:23:14,601 - hbmep.model.baseline - INFO - Rendering ...
2023-10-13 15:23:57,457 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/Logistic5/recruitment_curves.pdf
2023-10-13 15:23:57,458 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 42.94 sec
2023-10-13 15:23:57,615 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-10-13 15:24:47,060 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/Logistic5/posterior_predictive_check.pdf
2023-10-13 15:24:47,062 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 49.60 sec
2023-10-13 15:24:47,062 - hbmep.utils.utils - INFO - func:render_predictive_check took: 49.60 sec


In [7]:
# Model evaluation: convert the MCMC run to an ArviZ InferenceData for LOO/WAIC.
# NOTE(review): the recorded outputs of this cell and the next two are timestamped
# 12:23, earlier than the 15:14 inference run above. They come from a previous
# kernel session, so re-run before trusting the ELPD numbers.
numpyro_data = az.from_numpyro(mcmc)
logger.info("Evaluating model ...")


2023-10-13 12:23:11,551 - __main__ - INFO - Evaluating model ...


In [8]:
# PSIS-LOO estimate of the expected log predictive density. Alongside the point
# estimate, report its standard error and Pareto-k diagnostics: the recorded run
# emitted numerical warnings here, and PSIS-LOO is unreliable when observations
# have Pareto k > 0.7 (Vehtari, Gelman & Gabry, 2017).
PARETO_K_THRESHOLD = 0.7

score = az.loo(numpyro_data, pointwise=True)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")
logger.info(f"ELPD LOO SE: {score.se:.2f}, p_loo: {score.p_loo:.2f}")

n_high_k = int(np.sum(np.asarray(score.pareto_k) > PARETO_K_THRESHOLD))
if score.warning or n_high_k:
    logger.warning(
        f"PSIS-LOO may be unreliable: {n_high_k} of {score.n_data_points} "
        f"observations have Pareto k > {PARETO_K_THRESHOLD}"
    )

  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
  return umr_sum(a, axis, dtype, out, keepdims, initial, where)
2023-10-13 12:23:16,771 - __main__ - INFO - ELPD LOO (Log): 4581.48


In [9]:
# WAIC as a cross-check of LOO. The recorded run warned that WAIC may be failing
# (posterior variance of the pointwise log predictive density too large), so log
# the standard error and make that reliability flag explicit.
score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")
logger.info(f"ELPD WAIC SE: {score.se:.2f}, p_waic: {score.p_waic:.2f}")
if score.warning:
    logger.warning(
        "WAIC may be unreliable: posterior variance of the log predictive densities "
        "exceeds 0.4 for some observations; prefer PSIS-LOO."
    )

See http://arxiv.org/abs/1507.04544 for details
2023-10-13 12:23:16,850 - __main__ - INFO - ELPD WAIC (Log): 4610.55


In [13]:
# Persist the fitted model, MCMC object and posterior draws next to the other
# artefacts (`pickle` is already imported in the setup cell).
dest = os.path.join(model.build_dir, "inference.pkl")
payload = (model, mcmc, posterior_samples)
with open(dest, "wb") as fh:
    pickle.dump(payload, fh)


In [12]:
# To restore a saved run later, uncomment below. Only load pickles you created
# yourself: unpickling can execute arbitrary code from the file.
# import pickle
#
# with open(dest, "rb") as g:
#     model_, mcmc_, posterior_samples_ = pickle.load(g)
