In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

cpu_count = multiprocessing.cpu_count() - 2
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class Logistic3(Baseline):
    LINK = "Logistic3"

    def __init__(self, config: Config):
        super(Logistic3, self).__init__(config=config)
        self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))

            global_sigma_H = numpyro.sample("global_sigma_H", dist.HalfNormal(5))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                """ Hyper-priors """
                mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=5))
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(scale=1))

                sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                sigma_b = numpyro.deterministic(site.sigma_b, global_sigma_b * sigma_b_raw)

                sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                sigma_H = numpyro.deterministic(site.sigma_H, global_sigma_H * sigma_H_raw)

                sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                sigma_g_1 = numpyro.deterministic("sigma_g_1", global_sigma_g_1 * sigma_g_1_raw)

                sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                sigma_g_2 = numpyro.deterministic("sigma_g_2", global_sigma_g_2 * sigma_g_2_raw)

                sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
                sigma_p = numpyro.deterministic("sigma_p", global_sigma_p * sigma_p_raw)

                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    """ Priors """
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=mu_a, rate=1))
                    a = numpyro.deterministic(site.a, (1 / sigma_a) * a_raw)

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, sigma_b * b_raw)

                    H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                    H = numpyro.deterministic(site.H, sigma_H * H_raw)

                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, sigma_g_1 * g_1_raw)

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, sigma_g_2 * g_2_raw)

                    p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                    p = numpyro.deterministic("p", sigma_p * p_raw)

        """ Model """
        mu = numpyro.deterministic(
            site.mu,
            H[subject, feature0]
            / (1 + jnp.exp(-b[subject, feature0] * (intensity - a[subject, feature0])))
        )
        beta = numpyro.deterministic(
            site.beta,
            g_1[subject, feature0] + g_2[subject, feature0] * ((1 / (mu + 1)) ** p[subject, feature0])
        )

        """ Observation """
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
                obs=response_obs
            )


In [3]:
root_path = Path(os.getcwd()).parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/paper/tms/Logistic3.toml")

config = Config(toml_path=toml_path)
# config.MCMC_PARAMS["num_warmup"] = 40000
# config.MCMC_PARAMS["num_samples"] = 60000
# config.MCMC_PARAMS["thinning"] = 4

model = Logistic3(config=config)


2023-10-12 16:06:36,682 - hbmep.config - INFO - Verifying configuration ...
2023-10-12 16:06:36,683 - hbmep.config - INFO - Success!
2023-10-12 16:06:36,713 - hbmep.model.baseline - INFO - Initialized model with Logistic3 link


In [4]:
src = "/home/vishu/data/hbmep-processed/human/tms/data_pkpk_auc.csv"
df = pd.read_csv(src)

# subset = ["SCA01"]
# ind = df[model.subject].isin(subset)
# df = df[ind].reset_index(drop=True).copy()

df, encoder_dict = model.load(df=df)


2023-10-12 16:06:36,775 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/tms/Logistic3
2023-10-12 16:06:36,775 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/tms/Logistic3
2023-10-12 16:06:36,776 - hbmep.dataset.core - INFO - Processing data ...
2023-10-12 16:06:36,777 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
mcmc, posterior_samples = model.run_inference(df=df)


2023-10-12 16:06:37,022 - hbmep.model.baseline - INFO - Running inference with Logistic3 ...


  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

2023-10-12 16:21:06,299 - hbmep.utils.utils - INFO - func:run_inference took: 14 min and 29.28 sec


In [6]:
mcmc.print_summary(prob=.95)



                         mean       std    median      2.5%     97.5%     n_eff     r_hat
       H_raw[0,0,0]      0.79      0.60      0.67      0.00      1.95   1526.44      1.00
       H_raw[0,0,1]      0.80      0.61      0.68      0.00      1.96   1563.10      1.00
       H_raw[0,0,2]      0.78      0.59      0.68      0.00      1.92   1099.43      1.00
       H_raw[0,0,3]      0.80      0.60      0.67      0.00      1.95   1032.31      1.00
       H_raw[0,0,4]      0.79      0.60      0.69      0.00      1.96   1181.82      1.00
       H_raw[0,0,5]      0.79      0.61      0.66      0.00      2.00   1317.94      1.00
       H_raw[0,1,0]      0.80      0.24      0.78      0.35      1.28    196.80      1.01
       H_raw[0,1,1]      0.52      0.15      0.51      0.25      0.82    177.85      1.02
       H_raw[0,1,2]      0.14      0.13      0.11      0.01      0.34    339.53      1.00
       H_raw[0,1,3]      0.74      0.24      0.73      0.27      1.17    150.07      1.02
       H_

In [7]:
model.mep_matrix_path = None

prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(df=prediction_df, posterior_samples=posterior_samples)

model.render_recruitment_curves(df=df, encoder_dict=encoder_dict, posterior_samples=posterior_samples, prediction_df=prediction_df, posterior_predictive=posterior_predictive)
model.render_predictive_check(df=df, encoder_dict=encoder_dict, prediction_df=prediction_df, posterior_predictive=posterior_predictive)


2023-10-12 16:21:06,672 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec
2023-10-12 16:21:13,918 - hbmep.utils.utils - INFO - func:predict took: 7.25 sec
2023-10-12 16:21:13,938 - hbmep.model.baseline - INFO - Rendering ...
2023-10-12 16:21:22,072 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/Logistic3/recruitment_curves.pdf
2023-10-12 16:21:22,073 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 8.15 sec
2023-10-12 16:21:22,113 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-10-12 16:21:31,696 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/Logistic3/posterior_predictive_check.pdf
2023-10-12 16:21:31,697 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 9.62 sec
2023-10-12 16:21:31,698 - hbmep.utils.utils - INFO - func:render_predictive_check took: 9.62 sec


In [8]:
((posterior_predictive[site.mu] * posterior_predictive[site.beta]) == 0).sum()

Array(0, dtype=int64)

In [9]:
numpyro_data = az.from_numpyro(mcmc)

""" Model evaluation """
logger.info("Evaluating model ...")


ValueError: Gamma distribution got invalid concentration parameter.

In [None]:
score = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")

2023-10-12 16:05:54,556 - __main__ - INFO - ELPD LOO (Log): 1257.97


In [None]:
score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")

See http://arxiv.org/abs/1507.04544 for details
2023-10-12 16:05:54,665 - __main__ - INFO - ELPD WAIC (Log): 1258.94


In [None]:
import pickle

dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, "wb") as f:
    pickle.dump((model, mcmc, posterior_samples), f)


In [None]:
# import pickle

# with open(dest, "rb") as g:
#     model_, mcmc_, posterior_samples_ = pickle.load(g)
