In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run JAX / numpyro on the CPU backend.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several XLA host devices so MCMC chains can run in parallel on CPU.
# Two cores are left free for the rest of the system, but at least one device
# is always requested: cpu_count() - 2 is <= 0 on 1-2 core machines, which is
# not a valid device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()          # default to 64-bit floats
numpyro.enable_validation()   # validate distribution arguments / supports

logger = logging.getLogger(__name__)

In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class RectifiedLogistic(Baseline):
    """Hierarchical rectified-logistic recruitment-curve model.

    The curve ``fn`` is fit on the log scale of the response: ``_model``
    log-transforms observed responses and places a Normal likelihood around
    ``fn(...)``. Curve parameters are indexed by (subject, level of the first
    feature, response) through nested numpyro plates.
    """

    LINK = "rectified_logistic"

    def __init__(self, config: Config):
        super().__init__(config=config)
        # Columns identifying one fitted curve: every feature plus the subject.
        # presumably self.features is a list of column names set by Baseline — verify.
        self.combination_columns = self.features + [self.subject]

    def fn(self, x, a, b, v, L, l, H):
        """Rectified-logistic curve, evaluated element-wise with broadcasting.

        Returns ``L`` where ``x < a``. Where ``x >= a`` it returns::

            L - l + (H + l) / (1 + (((H + l) / l) ** v - 1) * exp(-b * (x - a))) ** (1 / v)

        which equals ``L`` at ``x == a`` (continuous at the threshold) and
        approaches ``L + H`` as ``x`` grows when ``b > 0``.

        Args:
            x: stimulation intensity.
            a: threshold; the curve is flat at ``L`` below it.
            b: rate multiplying ``(x - a)`` inside the exponential.
            v: shape exponent of the generalized logistic (must be non-zero).
            L: baseline value below threshold.
            l: offset controlling the curve just above threshold (must be
                non-zero, since ``(H + l) / l`` is computed).
            H: height of the saturation level above ``L``.
        """
        return (
            L
            + jnp.where(
                jnp.less(x, a),
                0.,
                -l + jnp.true_divide(
                    H + l,
                    jnp.power(
                        1
                        + jnp.multiply(
                            -1
                            + jnp.power(
                                jnp.true_divide(H + l, l),
                                v
                            ),
                            jnp.exp(-b * (x - a))
                        ),
                        1 / v
                    )
                )
            )
        )

    def fn_prime(self, x, *args):
        """Derivative of ``fn`` with respect to ``x``, evaluated element-wise.

        ``jax.grad`` only differentiates scalar-output functions, so it is
        wrapped in one ``jax.vmap`` per dimension of ``x``. Each vmap maps every
        positional argument over its leading axis, so the curve parameters in
        ``args`` must have the same shape as ``x`` (they are not broadcast).
        ``x`` must be a floating-point value/array.

        Args:
            x: intensities at which to evaluate the slope (scalar or array).
            *args: ``a, b, v, L, l, H`` as accepted by ``fn``.

        Returns:
            Array with the shape of ``x`` holding d fn / d x.
        """
        gradient = jax.grad(self.fn, argnums=0)
        # jax.numpy has no `vmap`; the transform lives in the top-level `jax` module.
        for _ in range(jnp.ndim(x)):
            gradient = jax.vmap(gradient)
        return gradient(x, *args)

    def _model(self, subject, features, intensity, response_obs=None):
        """numpyro model: hierarchical priors plus a log-scale Normal likelihood.

        Args:
            subject: ``(subject_index_per_row, n_subject)``.
            features: ``(feature_index_arrays, n_levels_per_feature)``; only the
                first feature is used by this model.
            intensity: ``(intensity_per_row, n_data)``.
            response_obs: observed responses on the original scale, or ``None``,
                in which case the ``"response_obs"`` site is sampled instead of
                observed. Observations are log-transformed before use, so they
                must be positive.

        Per-curve parameters (``a``, ``b``, ``v``, ``L``, ``l``, ``H``, ``g_1``,
        ``g_2``, ``p``) have shape ``(n_subject, n_feature0, n_response)`` from the
        plates at dims -3 / -2 / -1.
        """
        # The likelihood lives on the log scale of the response.
        if response_obs is not None: response_obs = jnp.log(response_obs)

        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        # One intensity column per response so x matches the (n_data, n_response) grid.
        # jnp.tile (not np.tile) keeps this valid if the model args are ever traced.
        intensity = intensity.reshape(-1, 1)
        intensity = jnp.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]


        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # Global (per-response) scales and locations shared across feature levels.
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(10))
            global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(10))

            L_location_global_location = numpyro.sample("L_location_global_location", dist.Normal(-5, 10))
            L_location_global_scale = numpyro.sample("L_location_global_scale", dist.HalfNormal(10))
            L_scale_global_scale = numpyro.sample("L_scale_global_scale", dist.HalfNormal(10))

            l_scale_global_scale = numpyro.sample("l_scale_global_scale", dist.HalfNormal(10))

            H_location_global_location = numpyro.sample("H_location_global_location", dist.Normal(0, 10))
            H_location_global_scale = numpyro.sample("H_location_global_scale", dist.HalfNormal(10))
            H_scale_global_scale = numpyro.sample("H_scale_global_scale", dist.HalfNormal(10))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(10))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(10))

            global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                # Hyper-priors (per feature level and response). Non-centred:
                # each "*_raw"/"*_base" site is rescaled by its global parent.
                mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=5))
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(scale=1 / 10))

                sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                sigma_b = numpyro.deterministic(site.sigma_b, global_sigma_b * sigma_b_raw)

                sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
                sigma_v = numpyro.deterministic(site.sigma_v, global_sigma_v * sigma_v_raw)

                L_location_base = numpyro.sample("L_location_base", dist.Normal(loc=0, scale=1))
                L_location = numpyro.deterministic(
                    "L_location",
                    L_location_global_location + jnp.multiply(L_location_global_scale, L_location_base)
                )
                L_scale_base = numpyro.sample("L_scale_base", dist.HalfNormal(scale=1))
                L_scale = numpyro.deterministic(
                    "L_scale",
                    jnp.multiply(L_scale_global_scale, L_scale_base)
                )

                l_scale_base = numpyro.sample("l_scale_base", dist.HalfNormal(scale=1))
                l_scale = numpyro.deterministic(
                    "l_scale",
                    jnp.multiply(l_scale_global_scale, l_scale_base)
                )

                H_location_base = numpyro.sample("H_location_base", dist.Normal(loc=0, scale=1))
                H_location = numpyro.deterministic(
                    "H_location",
                    H_location_global_location + jnp.multiply(H_location_global_scale, H_location_base)
                )
                H_scale_base = numpyro.sample("H_scale_base", dist.HalfNormal(scale=1))
                H_scale = numpyro.deterministic(
                    "H_scale",
                    jnp.multiply(H_scale_global_scale, H_scale_base)
                )

                sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                sigma_g_1 = numpyro.deterministic("sigma_g_1", global_sigma_g_1 * sigma_g_1_raw)

                sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                sigma_g_2 = numpyro.deterministic("sigma_g_2", global_sigma_g_2 * sigma_g_2_raw)

                sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
                sigma_p = numpyro.deterministic("sigma_p", global_sigma_p * sigma_p_raw)

                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    # Priors (per subject, feature level and response).
                    # a_raw ~ Gamma(mu_a, 1), so a = a_raw / sigma_a ~ Gamma(mu_a, rate=sigma_a).
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=mu_a, rate=1))
                    a = numpyro.deterministic(site.a, (1 / sigma_a) * a_raw)

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, sigma_b * b_raw)

                    v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    v = numpyro.deterministic(site.v, sigma_v * v_raw)

                    L_base = numpyro.sample("L_base", dist.Normal(loc=0, scale=1))
                    L = numpyro.deterministic(
                        site.L,
                        L_location + jnp.multiply(L_scale, L_base)
                    )

                    l_base = numpyro.sample("l_base", dist.HalfNormal(scale=1))
                    l = numpyro.deterministic(
                        "l",
                        jnp.multiply(l_scale, l_base)
                    )

                    H_base = numpyro.sample("H_base", dist.Normal(loc=0, scale=1))
                    H = numpyro.deterministic(
                        site.H,
                        H_location + jnp.multiply(H_scale, H_base)
                    )

                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfNormal(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, sigma_g_1 * g_1_raw)

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfNormal(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, sigma_g_2 * g_2_raw)

                    p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                    p = numpyro.deterministic("p", sigma_p * p_raw)

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.data, n_data, dim=-2):
                # Model: curve on the log scale, gathered per row by (subject, feature0).
                loc = numpyro.deterministic(
                    "loc",
                    self.fn(
                        x=intensity,
                        a=a[subject, feature0],
                        b=b[subject, feature0],
                        v=v[subject, feature0],
                        L=L[subject, feature0],
                        l=l[subject, feature0],
                        H=H[subject, feature0]
                    )
                )
                # Curve on the original response scale.
                mu = numpyro.deterministic(
                    site.mu,
                    jnp.exp(loc)
                )
                # NOTE(review): `mu` is already exp(loc), so jnp.exp(mu) ** p is a
                # double exponential (exp(p * exp(loc))) that overflows quickly for
                # large responses. Confirm whether jnp.power(mu, p) was intended.
                scale = numpyro.deterministic(
                    "scale",
                    g_1[subject, feature0] + jnp.multiply(g_2[subject, feature0], jnp.power(jnp.exp(mu), p[subject, feature0]))
                )

                # Observation: Normal likelihood on log(response).
                response = numpyro.sample(
                    "response_obs",
                    dist.Normal(loc=loc, scale=scale),
                    obs=response_obs
                )
                # Back-transform observed / sampled responses to the original scale.
                numpyro.deterministic(
                    site.obs,
                    jnp.exp(response)
                )


In [3]:
# Load the run configuration, override the MCMC budget for this notebook,
# and build the model from it.
# NOTE(review): absolute, user-specific path — make it configurable before sharing.
toml_path = "/home/vishu/repos/hbmep-paper/configs/paper/tms/link-comparison/proc-2023-10-27/uninjured-link-comparison/rectified_logistic.toml"
config = Config(toml_path=toml_path)

mcmc_overrides = {"num_warmup": 5000, "num_samples": 1000}
for param_name, param_value in mcmc_overrides.items():
    config.MCMC_PARAMS[param_name] = param_value

model = RectifiedLogistic(config=config)


2023-11-06 14:16:59,747 - hbmep.config - INFO - Verifying configuration ...
2023-11-06 14:16:59,747 - hbmep.config - INFO - Success!
2023-11-06 14:16:59,763 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link


In [5]:
# Load the processed TMS dataset and keep only the subjects fit in this run.
# NOTE(review): absolute, user-specific path — make it configurable before sharing.
src = "/home/vishu/data/hbmep-processed/human/tms/data_pkpk_auc_proc-2023-10-27.csv"
df = pd.read_csv(src)

# NOTE(review): the config targets an "uninjured" link comparison, but no
# `df[model.features[0]].isin(["Uninjured"])` filter is applied here — confirm
# that every feature level of these subjects is meant to be fit.
subset = ["SCA01", "SCA02", "SCA03"]
ind = df[model.subject].isin(subset)
# Boolean indexing already returns a new frame, so no extra .copy() is needed.
df = df[ind].reset_index(drop=True)

# Fail loudly on a mistyped / absent subject instead of silently fitting fewer.
missing_subjects = sorted(set(subset) - set(df[model.subject]))
assert not missing_subjects, f"Subjects not found in {src}: {missing_subjects}"

# Responses stay on their raw scale: the model log-transforms them itself.
df, encoder_dict = model.load(df=df)
# df[model.response] = np.log(df[model.response])


2023-11-06 14:16:59,974 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/tms/link-comparison/proc-2023-10-27/uninjured-link-comparison/rectified_logistic
2023-11-06 14:16:59,974 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/tms/link-comparison/proc-2023-10-27/uninjured-link-comparison/rectified_logistic
2023-11-06 14:16:59,975 - hbmep.dataset.core - INFO - Processing data ...
2023-11-06 14:16:59,976 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [4]:
# model.plot(df=df, encoder_dict=encoder_dict)


In [6]:
# Fit the model by MCMC (warmup / sample counts come from config.MCMC_PARAMS).
# Yields the numpyro MCMC object and the posterior samples used for prediction.
inference_output = model.run_inference(df=df)
mcmc, posterior_samples = inference_output


2023-11-06 14:17:00,503 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

2023-11-06 14:19:13,071 - hbmep.utils.utils - INFO - func:run_inference took: 2 min and 12.57 sec


In [10]:
# Build the prediction dataset (presumably an intensity grid per curve), draw
# posterior-predictive samples on it, and render the recruitment-curve and
# predictive-check figures.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(df=prediction_df, posterior_samples=posterior_samples)

render_kwargs = dict(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_recruitment_curves(posterior_samples=posterior_samples, **render_kwargs)
model.render_predictive_check(**render_kwargs)


2023-11-06 14:52:33,892 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec


2023-11-06 14:52:34,269 - hbmep.utils.utils - INFO - func:predict took: 0.38 sec
2023-11-06 14:52:34,272 - hbmep.model.baseline - INFO - Rendering ...
2023-11-06 14:52:34,758 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/link-comparison/proc-2023-10-27/uninjured-link-comparison/rectified_logistic/recruitment_curves.pdf
2023-11-06 14:52:34,758 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 0.49 sec
2023-11-06 14:52:34,761 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-11-06 14:52:35,308 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/link-comparison/proc-2023-10-27/uninjured-link-comparison/rectified_logistic/posterior_predictive_check.pdf
2023-11-06 14:52:35,308 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 0.55 sec
2023-11-06 14:52:35,308 - hbmep.utils.utils - INFO - func:render_predictive_check took: 0.55 sec


In [9]:
# Posterior summary with 95% intervals. Check r_hat (should be ~1.0) and n_eff
# before trusting the fit — the recorded output has r_hat of 1.45-1.82 for
# H_base, i.e. the chains had not mixed in that run.
SUMMARY_PROB = 0.95
mcmc.print_summary(prob=SUMMARY_PROB)



                                   mean       std    median      2.5%     97.5%     n_eff     r_hat
                H_base[0,0,0]      0.67      0.81      0.66     -1.11      1.63      4.99      1.45
                H_base[1,0,0]     -0.58      1.08     -0.81     -2.86      1.22      3.07      1.82
                H_base[2,0,0]     -0.25      0.67     -0.27     -1.74      0.58      4.43      1.50
         H_location_base[0,0]      0.26      0.48      0.41     -1.08      0.95     34.71      1.16
H_location_global_location[0]      1.20      5.85      1.00    -12.84     14.65     84.90      1.03
   H_location_global_scale[0]     20.13     28.83      6.31      0.26     87.36      6.00      1.35
            H_scale_base[0,0]      0.93      0.78      0.74      0.01      2.81      7.36      1.77
      H_scale_global_scale[0]     17.71     31.76      3.63      0.37     94.14     10.59      1.23
                L_base[0,0,0]      0.13      0.55      0.14     -0.77      1.17     20.14      1.21

In [9]:
numpyro_data = az.from_numpyro(mcmc)

# Model evaluation: expected log pointwise predictive density via PSIS-LOO and WAIC.
logger.info("Evaluating model ...")

# Separate names: the LOO and WAIC results expose different fields.
loo_score = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {loo_score.elpd_loo:.2f}")
if loo_score.warning:
    logger.warning("LOO estimate may be unreliable: some Pareto k values are high")

waic_score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {waic_score.elpd_waic:.2f}")
if waic_score.warning:
    logger.warning("WAIC estimate may be unreliable: posterior variance of the log predictive density is high for some points")


2023-11-06 14:51:35,654 - __main__ - INFO - Evaluating model ...
2023-11-06 14:51:35,956 - __main__ - INFO - ELPD LOO (Log): -113.63
See http://arxiv.org/abs/1507.04544 for details
2023-11-06 14:51:35,964 - __main__ - INFO - ELPD WAIC (Log): -113.11


In [None]:
# Persist the model, MCMC object and posterior samples for later reuse
# (`pickle` is already imported in the first cell).
# Only unpickle this file from a trusted source — loading runs arbitrary code.
dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, "wb") as f:
    pickle.dump((model, mcmc, posterior_samples), f)


In [None]:
# Export the ArviZ InferenceData to NetCDF in the build directory; the cell
# displays the written file path.
# NOTE(review): the recorded output path does not match the build directory
# logged earlier, so that output is stale — re-run to refresh it.
netcdf_name = "numpyro_data.nc"
dest = os.path.join(model.build_dir, netcdf_name)
az.to_netcdf(numpyro_data, dest)


'/home/vishu/repos/hbmep-paper/reports/paper/tms/link-comparison/rectified_logistic/numpyro_data.nc'