In [1]:
%reload_ext autoreload
%autoreload 2

import os
from pathlib import Path
import multiprocessing

import jax
import jax.numpy as jnp
import pandas as pd
import numpy as np
import numpyro
import numpyro.distributions as dist

from hbmep.config import Config
from hbmep.model import Baseline
from hbmep.model.utils import Site as site
from hbmep.utils.constants import RECTIFIED_LOGISTIC

# Select the CPU backend for both JAX and NumPyro.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Use all but two of the available cores as host devices, but never request
# fewer than one: on machines with <= 2 cores the plain subtraction would
# yield 0 or a negative device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()


#### Load config

In [2]:
# Three directory levels above the current working directory
# (presumably the repository root -- verify when moving the notebook).
root_path = Path(os.getcwd()).parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/J_RCML_000.toml")

config = Config(toml_path=toml_path)

# Derive the build directory from root_path instead of hard-coding an
# absolute path inside one user's home directory, so the notebook can run
# from any checkout location.
config.BUILD_DIR = os.path.join(root_path, "reports/J_RCML_000/notebook")


2023-07-18 15:35:10,731 - hbmep.config - INFO - Verifying configuration ...
2023-07-18 15:35:10,731 - hbmep.config - INFO - Success!


In [3]:
class RectifiedLogistic(Baseline):
    """Rectified-logistic recruitment-curve model layered on ``Baseline``.

    The constructor stores the entries of ``config.PRIORS`` as attributes.
    Note that ``_model`` below does not read those attributes; its prior
    hyper-parameters are written as numeric literals.
    """

    def __init__(self, config: Config):
        """Initialise the base class, set the link name and store the priors."""
        super().__init__(config=config)
        self.link = RECTIFIED_LOGISTIC

        priors = config.PRIORS

        # Threshold-related entries.
        self.mu_a = priors[site.mu_a]
        self.sigma_a = priors[site.sigma_a]

        # Scale entries for b, L, H and v.
        self.sigma_b = priors[site.sigma_b]
        self.sigma_L = priors[site.sigma_L]
        self.sigma_H = priors[site.sigma_H]
        self.sigma_v = priors[site.sigma_v]

        # Entries for the terms that enter the Gamma rate.
        self.g_1 = priors[site.g_1]
        self.g_2 = priors[site.g_2]
        self.p = priors[site.p]

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model: hierarchical priors, curve mean and Gamma likelihood.

        Args:
            subject: per-observation subject codes, used to index the sampled
                parameters (assumed to be integer codes -- TODO confirm).
            features: sequence whose first element holds the per-observation
                codes of the first feature.
            intensity: per-observation stimulation intensities; reshaped to a
                column and repeated ``self.n_response`` times.
            response_obs: observed responses, or ``None`` to sample from the
                model (presumably shaped ``(n_data, n_response)``).

        Returns:
            The value of the ``site.obs`` sample site.
        """
        # One intensity column per response: (n_data, n_response).
        intensity = np.tile(intensity.reshape(-1, 1), (1, self.n_response))
        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1), \
                numpyro.plate(site.n_subject, n_subject, dim=-2):
            # Hyper-priors
            mu_a = numpyro.sample(site.mu_a, dist.TruncatedNormal(150, 20, low=0))
            sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(20))

            sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.5))

            sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
            sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
            sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(5))

            with numpyro.plate("n_feature0", n_feature0, dim=-3):
                # Priors (the order of the sample statements is kept as is)
                a = numpyro.sample(
                    site.a, dist.TruncatedNormal(mu_a, sigma_a, low=0)
                )
                b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

                L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                p = numpyro.sample(site.p, dist.HalfCauchy(10))

                g_1 = numpyro.sample(site.g_1, dist.HalfCauchy(20))
                g_2 = numpyro.sample(site.g_2, dist.HalfCauchy(20))

        # Model
        # Pick, for every observation, the parameters of its
        # (feature0, subject) pair; the response axis is kept.
        pair = (feature0, subject)
        a_obs, b_obs = a[pair], b[pair]
        L_obs, H_obs, v_obs = L[pair], H[pair], v[pair]

        growth = jnp.power(1 + H_obs, v_obs) - 1
        decay = jnp.exp(-b_obs * (intensity - a_obs))
        denominator = jnp.power(1 + growth * decay, 1 / v_obs)

        # Offset L plus the logistic term clipped at zero.
        mu = numpyro.deterministic(
            site.mu,
            L_obs + jnp.maximum(0, -1 + (H_obs + 1) / denominator)
        )
        beta = numpyro.deterministic(
            site.beta,
            g_1[pair] + g_2[pair] * jnp.power(1 / (mu + 1), p[pair])
        )

        # Gamma with concentration mu * beta and rate beta; the response axis
        # is declared an event dimension.
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.Gamma(mu * beta, beta).to_event(1),
                obs=response_obs
            )


# Instantiate the model class defined in this cell with the loaded config.
model = RectifiedLogistic(config=config)

In [4]:
df = pd.read_csv(model.csv_path)

# Filter: keep only the rows whose compound_position is "-C5M",
# then renumber the index from zero.
ind = df.compound_position.isin(["-C5M"])
df = df.loc[ind].reset_index(drop=True)

df, encoder_dict = model.load(df=df)

2023-07-18 15:35:10,926 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/J_RCML_000/notebook
2023-07-18 15:35:10,927 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/J_RCML_000/notebook
2023-07-18 15:35:10,927 - hbmep.dataset.core - INFO - Processing data ...
2023-07-18 15:35:10,928 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# NOTE(review): this whole cell is an inactive, commented-out variant of the
# RectifiedLogistic class defined earlier in this notebook. Compared with the
# active class it uses HalfNormal(.1) for sigma_b, HalfNormal(10) for p, and
# 1 / mu instead of 1 / (mu + 1) inside beta.
# class RectifiedLogistic(Baseline):
#     def __init__(self, config: Config):
#         super(RectifiedLogistic, self).__init__(config=config)
#         self.link = RECTIFIED_LOGISTIC

#         self.mu_a = config.PRIORS[site.mu_a]
#         self.sigma_a = config.PRIORS[site.sigma_a]

#         self.sigma_b = config.PRIORS[site.sigma_b]

#         self.sigma_L = config.PRIORS[site.sigma_L]
#         self.sigma_H = config.PRIORS[site.sigma_H]
#         self.sigma_v = config.PRIORS[site.sigma_v]

#         self.g_1 = config.PRIORS[site.g_1]
#         self.g_2 = config.PRIORS[site.g_2]

#         self.p = config.PRIORS[site.p]

#     def _model(self, subject, features, intensity, response_obs=None):
#         intensity = intensity.reshape(-1, 1)
#         intensity = np.tile(intensity, (1, self.n_response))

#         feature0 = features[0].reshape(-1,)

#         n_data = intensity.shape[0]
#         n_subject = np.unique(subject).shape[0]
#         n_feature0 = np.unique(feature0).shape[0]

#         with numpyro.plate(site.n_response, self.n_response, dim=-1):
#             with numpyro.plate(site.n_subject, n_subject, dim=-2):
#                 """ Hyper-priors """
#                 mu_a = numpyro.sample(
#                     site.mu_a,
#                     dist.TruncatedNormal(150, 20, low=0)
#                 )
#                 sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(20))

#                 sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(.1))

#                 sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(.05))
#                 sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
#                 sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(5))

#                 with numpyro.plate("n_feature0", n_feature0, dim=-3):
#                     """ Priors """
#                     a = numpyro.sample(
#                         site.a,
#                         dist.TruncatedNormal(mu_a, sigma_a, low=0)
#                     )
#                     b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

#                     L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
#                     H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
#                     v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

#                     g_1 = numpyro.sample(
#                         site.g_1, dist.HalfCauchy(20)
#                     )
#                     g_2 = numpyro.sample(
#                         site.g_2, dist.HalfCauchy(20)
#                     )

#                     p = numpyro.sample(site.p, dist.HalfNormal(10))

#         """ Model """
#         mu = numpyro.deterministic(
#             site.mu,
#             L[feature0, subject] + \
#             jnp.maximum(
#                 0,
#                 -1 + \
#                 (H[feature0, subject] + 1) / \
#                 jnp.power(
#                     1 + \
#                     (jnp.power(1 + H[feature0, subject], v[feature0, subject]) - 1) * \
#                     jnp.exp(-b[feature0, subject] * (intensity - a[feature0, subject])),
#                     1 / v[feature0, subject]
#                 )
#             )
#         )
#         beta = numpyro.deterministic(
#             site.beta,
#             g_1[feature0, subject] + \
#             g_2[feature0, subject] * jnp.power(1 / mu, p[feature0, subject])
#         )

#         with numpyro.plate(site.data, n_data):
#             return numpyro.sample(
#                 site.obs,
#                 dist.Gamma(mu * beta, beta).to_event(1),
#                 obs=response_obs
#             )


In [6]:
# Fit the model to df. The first return value is summarised in the next cell;
# the posterior samples feed the rendering calls further down.
mcmc, posterior_samples = model.run_inference(df=df)


2023-07-18 15:35:11,100 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-07-18 15:54:02,457 - hbmep.utils.utils - INFO - func:run_inference took: 18 min and 51.36 sec


In [7]:
# Posterior summary table, requested with prob=0.95.
mcmc.print_summary(prob=0.95)



                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      3.02      2.96      1.95      0.78      9.00   5907.67      1.00
  H[0,0,1]      4.15      4.05      2.80      0.26     12.20  10644.48      1.00
  H[0,0,2]      1.27      1.69      0.76      0.45      3.84   4154.69      1.00
  H[0,1,0]      5.77      4.25      4.47      1.04     14.29  10581.94      1.00
  H[0,1,1]      0.11      0.32      0.09      0.06      0.14    498.78      1.01
  H[0,1,2]      3.57      3.92      2.21      0.13     11.33  11053.62      1.00
  H[0,2,0]      3.60      0.22      3.59      3.20      4.04   6761.58      1.00
  H[0,2,1]      4.62      4.07      3.30      0.68     12.72   8709.71      1.00
  H[0,2,2]      0.49      0.48      0.45      0.22      0.70   1647.59      1.00
  H[0,3,0]      0.90      0.06      0.90      0.80      1.01   7626.00      1.00
  H[0,3,1]      0.75      0.18      0.73      0.57      0.95   1403.32      1.00
  H[0,3,2]      2.19      1

In [8]:
# Render the recruitment curves from the posterior samples.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-18 15:54:03,214 - hbmep.model.baseline - INFO - Rendering recruitment curves ...


2023-07-18 15:54:30,308 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/J_RCML_000/notebook/recruitment_curves.pdf
2023-07-18 15:54:30,309 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 27.10 sec


In [9]:
# Render the posterior predictive check from the same posterior samples.
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-18 15:54:30,623 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...


2023-07-18 15:55:03,452 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/J_RCML_000/notebook/posterior_predictive_check.pdf
2023-07-18 15:55:03,455 - hbmep.utils.utils - INFO - func:render_predictive_check took: 32.83 sec
