In [1]:
%reload_ext autoreload
%autoreload 2

import os
import logging
import multiprocessing

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

import arviz as az
import jax
import jax.numpy as jnp
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run everything on the host CPU; both JAX and NumPyro are told explicitly.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose (cores - 2) host devices to NumPyro so two cores stay free, but never
# request fewer than one device: on a 1- or 2-core machine the raw subtraction
# would be <= 0, which is not a valid device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class RectifiedLogistic(Baseline):
    """Hierarchical model whose mean response follows a rectified-logistic curve.

    The mean ``mu`` is flat (equal to ``L``) below the threshold ``a`` and rises
    along a generalised logistic towards ``L + H`` above it (see :meth:`fn`).
    Observations are Gamma distributed with mean ``mu``. This variant has no
    outlier mixture component; the mixture code is kept commented out below.
    """

    LINK = "rectified_logistic"

    def __init__(self, config: Config):
        """Initialise the base model and record the grouping columns.

        Args:
            config: hbmep configuration object forwarded to ``Baseline``.
        """
        # Zero-argument super() stays valid when this notebook cell is
        # re-executed and the class object is rebound.
        super().__init__(config=config)
        # Feature columns followed by the subject column.
        # NOTE(review): `self.features` / `self.subject` are presumably set by
        # Baseline.__init__ from the config -- confirm.
        self.combination_columns = self.features + [self.subject]

    def fn(self, x, a, b, v, L, l, H):
        """Evaluate the rectified-logistic curve.

        For ``x < a`` the result is ``L``. For ``x >= a`` it is::

            L - l + (H + l) / (1 + (((H + l) / l) ** v - 1) * exp(-b * (x - a))) ** (1 / v)

        which equals ``L`` at ``x == a`` and tends to ``L + H`` as ``x`` grows
        (for positive ``b``).

        Args:
            x: Input (intensity) values; broadcast against the parameters.
            a: Threshold below which the curve is flat.
            b: Growth rate of the logistic section.
            v: Shape exponent of the logistic section.
            L: Baseline value returned below the threshold.
            l: Offset that shifts the logistic so it meets ``L`` at ``x == a``.
            H: Height of the saturation level above ``L``.

        Returns:
            Array of curve values with the broadcast shape of the inputs.
        """
        below_threshold = jnp.less(x, a)

        # "Double where": inside the masked-out region feed the logistic branch
        # x == a instead of the real x. The forward value is unchanged (the
        # outer where discards that branch), but exp(-b * (x - a)) can no
        # longer overflow for x far below a, which would otherwise turn the
        # zero cotangent of the discarded branch into NaN gradients.
        x_safe = jnp.where(below_threshold, a, x)

        ratio_pow = jnp.power(jnp.true_divide(H + l, l), v)
        decay = jnp.exp(jnp.multiply(-b, x_safe - a))
        denominator = jnp.power(
            1 + jnp.multiply(-1 + ratio_pow, decay),
            jnp.true_divide(1, v)
        )
        logistic_part = -l + jnp.true_divide(H + l, denominator)

        return L + jnp.where(below_threshold, 0., logistic_part)

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model: hierarchical priors, curve, and Gamma likelihood.

        Args:
            subject: Pair ``(subject_index, n_subject)``.
            features: Pair ``(feature_index_arrays, n_features)``; only the
                first feature (index 0) is used here.
            intensity: Pair ``(intensity_values, n_data)``.
            response_obs: Observed responses, or ``None`` to sample them
                (e.g. for prediction).
        """
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        # Column vector repeated once per response, so it lines up with the
        # (n_data, n_response) parameter arrays gathered below.
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # Global priors: one scale per response, shared by all feature levels.
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
            global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))

            global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))
            global_sigma_l = numpyro.sample("global_sigma_l", dist.HalfNormal(100))
            global_sigma_H = numpyro.sample("global_sigma_H", dist.HalfNormal(5))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                # Hyper-priors: per feature level. Each scale is sampled as a
                # unit HalfNormal and multiplied by its global scale.
                a_mean = numpyro.sample("a_mean", dist.HalfNormal(scale=50))
                a_shape = numpyro.sample("a_shape", dist.HalfNormal(scale=100))

                sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                sigma_b = numpyro.deterministic(site.sigma_b, jnp.multiply(global_sigma_b, sigma_b_raw))

                sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
                sigma_v = numpyro.deterministic(site.sigma_v, jnp.multiply(global_sigma_v, sigma_v_raw))

                sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                sigma_L = numpyro.deterministic(site.sigma_L, jnp.multiply(global_sigma_L, sigma_L_raw))

                sigma_l_raw = numpyro.sample("sigma_l_raw", dist.HalfNormal(scale=1))
                sigma_l = numpyro.deterministic("sigma_l", jnp.multiply(global_sigma_l, sigma_l_raw))

                sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                sigma_H = numpyro.deterministic(site.sigma_H, jnp.multiply(global_sigma_H, sigma_H_raw))

                sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                sigma_g_1 = numpyro.deterministic("sigma_g_1", jnp.multiply(global_sigma_g_1, sigma_g_1_raw))

                sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                sigma_g_2 = numpyro.deterministic("sigma_g_2", jnp.multiply(global_sigma_g_2, sigma_g_2_raw))

                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    # Priors: per (subject, feature level, response).
                    # a_raw ~ Gamma(a_shape, 1) rescaled by a_mean / a_shape,
                    # so the threshold has mean a_mean.
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=a_shape, rate=1))
                    a = numpyro.deterministic(site.a, jnp.true_divide(jnp.multiply(a_raw, a_mean), a_shape))

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, jnp.multiply(sigma_b, b_raw))

                    v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    v = numpyro.deterministic(site.v, jnp.multiply(sigma_v, v_raw))

                    L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    L = numpyro.deterministic(site.L, jnp.multiply(sigma_L, L_raw))

                    l_raw = numpyro.sample("l_raw", dist.HalfNormal(scale=1))
                    l = numpyro.deterministic("l", jnp.multiply(sigma_l, l_raw))

                    H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                    H = numpyro.deterministic(site.H, jnp.multiply(sigma_H, H_raw))

                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, jnp.multiply(sigma_g_1, g_1_raw))

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, jnp.multiply(sigma_g_2, g_2_raw))

        # """ Outlier Distribution """
        # outlier_dist_shape = numpyro.sample("outlier_dist_loc", dist.HalfNormal(5))
        # outlier_dist_rate = numpyro.sample("outlier_dist_scale", dist.HalfNormal(1))

        # """ Mixture """
        # if response_obs is not None:
        #     outlier_prob = numpyro.sample("outlier_prob", dist.Uniform(0.0, 0.2))
        # else: # Turn off mixture when predicting
        #     outlier_prob = numpyro.deterministic("outlier_prob", 0.)

        # mixing_distribution = dist.Categorical(probs=jnp.array([1 - outlier_prob, outlier_prob]))

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.data, n_data, dim=-2):
                # Model: gather each observation's parameters by its
                # (subject, feature level) indices and evaluate the curve.
                mu = numpyro.deterministic(
                    site.mu,
                    self.fn(
                        x=intensity,
                        a=a[subject, feature0],
                        b=b[subject, feature0],
                        v=v[subject, feature0],
                        L=L[subject, feature0],
                        l=l[subject, feature0],
                        H=H[subject, feature0]
                    )
                )
                # Rate of the Gamma likelihood: g_1 + g_2 / mu.
                beta = numpyro.deterministic(
                    site.beta,
                    g_1[subject, feature0] + jnp.true_divide(g_2[subject, feature0], mu)
                )

                # """ Mixture """
                # component_distributions = [
                #     dist.Gamma(concentration=jnp.multiply(mu, beta), rate=beta),
                #     dist.Gamma(concentration=outlier_dist_shape, rate=outlier_dist_rate)
                # ]
                # Mixture = dist.MixtureGeneral(
                #     mixing_distribution=mixing_distribution,
                #     component_distributions=component_distributions
                # )

                # Observation: Gamma(concentration=mu * beta, rate=beta) has
                # mean mu and variance mu / beta.
                numpyro.sample(
                    site.obs,
                    dist.Gamma(concentration=jnp.multiply(mu, beta), rate=beta),
                    obs=response_obs
                )


In [13]:
def f(a, **kwargs):
    """Print the value passed under the keyword ``b``.

    ``a`` is accepted but not used; a missing ``b`` raises ``KeyError``.
    """
    value = kwargs["b"]
    print(value)


f(a=1, b=2)

2


In [None]:
# NOTE(review): this cell has no execution count and sits above the cell that
# assigns `posterior_samples` (the `model.run_inference(...)` call), so on a
# fresh top-to-bottom run it raises NameError. As the last expression of the
# cell it would also display the whole posterior-sample object.
posterior_samples

In [3]:
# Experiment configuration file. The default is the original author's local
# checkout; set the HBMEP_TOML_PATH environment variable to run the notebook
# from another machine without editing this cell.
toml_path = os.environ.get(
    "HBMEP_TOML_PATH",
    "/home/vishu/repos/hbmep-paper/configs/paper/tms/injury-threshold-comparison/fcr.toml",
)

config = Config(toml_path=toml_path)
# Keep this variant's artefacts in their own sub-directory of the build dir.
config.BUILD_DIR = os.path.join(config.BUILD_DIR, "non_mixture")

# Sampler budget for this run (overrides whatever the TOML file specifies).
config.MCMC_PARAMS["num_warmup"] = 500
config.MCMC_PARAMS["num_samples"] = 1000

model = RectifiedLogistic(config=config)


2023-11-09 13:59:03,394 - hbmep.config - INFO - Verifying configuration ...
2023-11-09 13:59:03,395 - hbmep.config - INFO - Success!


2023-11-09 13:59:03,418 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link


In [4]:
# Processed dataset. The default is the original author's local copy; set the
# HBMEP_DATA_PATH environment variable to point at the CSV on another machine.
src = os.environ.get(
    "HBMEP_DATA_PATH",
    "/home/vishu/data/hbmep-processed/human/tms/data_pkpk_auc_proc-2023-10-27.csv",
)
df = pd.read_csv(src)

# Prefix the labels of the first feature column with an ordinal.
# NOTE(review): presumably this fixes the order in which `model.load` encodes
# the levels (Uninjured before SCI) -- confirm against the encoder.
df[model.features[0]] = df[model.features[0]].replace(
    {
        "Uninjured": "01_Uninjured",
        "SCI": "02_SCI",
    }
)

# ind = df[model.features[0]].isin(["Uninjured"])
# df = df[ind].reset_index(drop=True).copy()

# subset = ["SCA02"]
# # subset = ["SCA01", "SCA02", "SCA03", "SCS01", "SCS02", "SCS03"]
# ind = df[model.subject].isin(subset)
# df = df[ind].reset_index(drop=True).copy()

# `model.load` returns the processed frame together with the encoders used.
df, encoder_dict = model.load(df=df)
# encoder_dict[model.features[0]].inverse_transform([0, 1])


2023-11-09 13:59:03,487 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/tms/injury-threshold-comparison/fcr/non_mixture
2023-11-09 13:59:03,487 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/tms/injury-threshold-comparison/fcr/non_mixture
2023-11-09 13:59:03,488 - hbmep.dataset.core - INFO - Processing data ...
2023-11-09 13:59:03,489 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# Fit the model to the loaded data. The call returns the MCMC object (converted
# with `az.from_numpyro` further down) and the posterior samples (reused for
# prediction and plotting). Warmup/sample counts come from config.MCMC_PARAMS.
mcmc, posterior_samples = model.run_inference(df=df)


2023-11-09 13:59:03,579 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

2023-11-09 14:03:07,732 - hbmep.utils.utils - INFO - func:run_inference took: 4 min and 4.15 sec


In [8]:
# Build the prediction dataset from the observed data and draw the posterior
# predictive on it.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

# Render the recruitment curves, then the posterior predictive check.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)


2023-11-09 14:07:56,270 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.00 sec
2023-11-09 14:07:59,106 - hbmep.utils.utils - INFO - func:predict took: 2.84 sec
2023-11-09 14:07:59,117 - hbmep.model.baseline - INFO - Rendering ...
2023-11-09 14:08:01,898 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/injury-threshold-comparison/fcr/non_mixture/recruitment_curves.pdf
2023-11-09 14:08:01,899 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 2.79 sec
2023-11-09 14:08:01,913 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-11-09 14:08:04,925 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/injury-threshold-comparison/fcr/non_mixture/posterior_predictive_check.pdf
2023-11-09 14:08:04,926 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 3.03 sec
2023-11-09 14:08:04,926 - hbmep.utils.utils - INFO - func:render_predictive_check took: 3.

In [9]:
# Wrap the fitted MCMC object for ArviZ.
numpyro_data = az.from_numpyro(mcmc)

# Model evaluation: log the ELPD estimated by LOO and by WAIC.
logger.info("Evaluating model ...")

score = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")

# `score` is deliberately rebound: from here on it holds the WAIC result.
score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")


2023-11-09 14:08:06,123 - __main__ - INFO - Evaluating model ...
2023-11-09 14:08:07,519 - __main__ - INFO - ELPD LOO (Log): 3387.78
See http://arxiv.org/abs/1507.04544 for details
2023-11-09 14:08:07,578 - __main__ - INFO - ELPD WAIC (Log): 3392.15
