In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run everything on the CPU backend; JAX and NumPyro are both configured.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Leave two cores free for the rest of the system, but never request fewer
# than one host device: cpu_count() - 2 is <= 0 on 1- or 2-core machines.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class RectifiedLogistic(Baseline):
    """Hierarchical NumPyro model whose mean response follows a rectified-logistic curve.

    Curve parameters are sampled per (feature0 level, subject, response) and
    the observations are Gamma-distributed around the curve (see ``_model``).
    """
    LINK = "rectified_logistic"

    def __init__(self, config: Config):
        """Initialise the base model with the given configuration."""
        super().__init__(config=config)

    def fn(self, x, a, b, v, L, l, H):
        """Evaluate the rectified-logistic curve at ``x``.

        The value is ``L`` wherever ``x < a``; elsewhere it is::

            L - l + (H + l) / (1 + (((H + l) / l) ** v - 1) * exp(-b * (x - a))) ** (1 / v)

        For positive ``l``, ``H``, ``v`` and ``b`` this equals ``L`` at
        ``x == a`` and approaches ``L + H`` as ``x`` grows.

        Args:
            x: Input at which to evaluate the curve (broadcast against the parameters).
            a: Threshold below which the curve is flat at ``L``.
            b: Rate inside the exponential.
            v: Exponent shaping the logistic part.
            L: Offset added everywhere.
            l: Offset inside the logistic part.
            H: Height of the curve above ``L`` in the limit of large ``x``.

        Returns:
            Array of curve values with the broadcast shape of the inputs.
        """
        below_threshold = jnp.less(x, a)

        # Evaluate the logistic branch at a clamped input. For x < a the term
        # exp(-b * (x - a)) can overflow; jnp.where would still return 0. for
        # the value, but the overflow turns the gradient into NaN. Clamping x
        # to a on the discarded side keeps that branch finite without
        # changing any returned value.
        x_safe = jnp.where(below_threshold, a, x)

        logistic = -l + jnp.true_divide(
            H + l,
            jnp.power(
                1
                + jnp.multiply(
                    -1
                    + jnp.power(
                        jnp.true_divide(H + l, l),
                        v
                    ),
                    jnp.exp(-b * (x_safe - a))
                ),
                1 / v
            )
        )
        return L + jnp.where(below_threshold, 0., logistic)

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model definition.

        Args:
            subject: Pair ``(subject_index, n_subject)``.
            features: Pair ``(feature_indices, n_features)``; only the first
                entry of each is used here.
            intensity: Pair ``(intensity_values, n_data)``.
            response_obs: Observed responses passed as ``obs`` to the
                likelihood, or ``None`` to sample them.
        """
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        # Repeat the intensity column once per response so it lines up with
        # the (n_data, n_response) parameter arrays indexed below.
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # Global scales, one per response
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
            global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))

            global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))
            global_sigma_l = numpyro.sample("global_sigma_l", dist.HalfNormal(100))
            global_sigma_H = numpyro.sample("global_sigma_H", dist.HalfNormal(5))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                # Hyper-priors: each subject-level scale is the global scale
                # multiplied by a unit HalfNormal draw.
                mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=150))
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(100))

                sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                sigma_b = numpyro.deterministic(site.sigma_b, jnp.multiply(global_sigma_b, sigma_b_raw))

                sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
                sigma_v = numpyro.deterministic(site.sigma_v, jnp.multiply(global_sigma_v, sigma_v_raw))

                sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                sigma_L = numpyro.deterministic(site.sigma_L, jnp.multiply(global_sigma_L, sigma_L_raw))

                sigma_l_raw = numpyro.sample("sigma_l_raw", dist.HalfNormal(scale=1))
                sigma_l = numpyro.deterministic("sigma_l", jnp.multiply(global_sigma_l, sigma_l_raw))

                sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                sigma_H = numpyro.deterministic(site.sigma_H, jnp.multiply(global_sigma_H, sigma_H_raw))

                sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                sigma_g_1 = numpyro.deterministic("sigma_g_1", jnp.multiply(global_sigma_g_1, sigma_g_1_raw))

                sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                sigma_g_2 = numpyro.deterministic("sigma_g_2", jnp.multiply(global_sigma_g_2, sigma_g_2_raw))

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    # Priors: curve parameters per (feature0, subject, response).
                    # a_raw ~ Gamma(sigma_a, 1) rescaled by mu_a / sigma_a,
                    # so a has mean mu_a.
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=sigma_a, rate=1))
                    a = numpyro.deterministic(site.a, jnp.true_divide(jnp.multiply(a_raw, mu_a), sigma_a))

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, jnp.multiply(sigma_b, b_raw))

                    v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    v = numpyro.deterministic(site.v, jnp.multiply(sigma_v, v_raw))

                    L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    L = numpyro.deterministic(site.L, jnp.multiply(sigma_L, L_raw))

                    l_raw = numpyro.sample("l_raw", dist.HalfNormal(scale=1))
                    l = numpyro.deterministic("l", jnp.multiply(sigma_l, l_raw))

                    H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                    H = numpyro.deterministic(site.H, jnp.multiply(sigma_H, H_raw))

                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, jnp.multiply(sigma_g_1, g_1_raw))

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, jnp.multiply(sigma_g_2, g_2_raw))

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.data, n_data, dim=-2):
                # Model: mean response from the curve, using each data row's
                # own (feature0, subject) parameters.
                mu = numpyro.deterministic(
                    site.mu,
                    self.fn(
                        x=intensity,
                        a=a[feature0, subject],
                        b=b[feature0, subject],
                        v=v[feature0, subject],
                        L=L[feature0, subject],
                        l=l[feature0, subject],
                        H=H[feature0, subject]
                    )
                )
                beta = numpyro.deterministic(
                    site.beta, jnp.true_divide(g_2[feature0, subject], g_1[feature0, subject] + mu)
                )

                # Observation: Gamma with mean mu (concentration / rate) and rate beta.
                numpyro.sample(
                    site.obs,
                    dist.Gamma(concentration=jnp.multiply(mu, beta), rate=beta),
                    obs=response_obs
                )


In [3]:
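# NOTE: Disabled alternative model, kept for reference only; nothing in this cell executes.
# It log-transforms the observed responses, uses Normal location/scale hierarchies for L and H,
# exponentiates the curve to get mu, and observes the log-responses with tfd.ExpGamma.
# import numpyro.distributions as dist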
# from hbmep.model import Baseline
# from tensorflow_probability.substrates.jax import distributions as tfd


# class RectifiedLogistic(Baseline):
#     LINK = "rectified_logistic"

#     def __init__(self, config: Config):
#         super(RectifiedLogistic, self).__init__(config=config)
#         # self.combination_columns = self.features + [self.subject]

#     def fn(self, x, a, b, v, L, l, H):
#         return (
#             L
#             + jnp.where(
#                 jnp.less(x, a),
#                 0.,
#                 -l + jnp.true_divide(
#                     H + l,
#                     jnp.power(
#                         1
#                         + jnp.multiply(
#                             -1
#                             + jnp.power(
#                                 jnp.true_divide(H + l, l),
#                                 v
#                             ),
#                             jnp.exp(-b * (x - a))
#                         ),
#                         1 / v
#                     )
#                 )
#             )
#         )

#     def _model(self, subject, features, intensity, response_obs=None):
#         if response_obs is not None: response_obs = np.log(response_obs)

#         subject, n_subject = subject
#         features, n_features = features
#         intensity, n_data = intensity

#         intensity = intensity.reshape(-1, 1)
#         intensity = np.tile(intensity, (1, self.n_response))

#         feature0 = features[0].reshape(-1,)
#         n_feature0 = n_features[0]

#         with numpyro.plate(site.n_response, self.n_response, dim=-1):
#             global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
#             global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))

#             L_location_global_location = numpyro.sample("L_location_global_location", dist.Normal(-5, 10))
#             L_location_global_scale = numpyro.sample("L_location_global_scale", dist.HalfNormal(10))
#             L_scale_global_scale = numpyro.sample("L_scale_global_scale", dist.HalfNormal(10))

#             global_sigma_l = numpyro.sample("global_sigma_l", dist.HalfNormal(100))

#             H_location_global_location = numpyro.sample("H_location_global_location", dist.Normal(-5, 10))
#             H_location_global_scale = numpyro.sample("H_location_global_scale", dist.HalfNormal(10))
#             H_scale_global_scale = numpyro.sample("H_scale_global_scale", dist.HalfNormal(10))

#             global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
#             global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

#             # global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

#             with numpyro.plate(site.n_subject, n_subject, dim=-2):
#                 """ Hyper-priors """
#                 mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=150))
#                 sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(100))

#                 sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
#                 sigma_b = numpyro.deterministic(site.sigma_b, jnp.multiply(global_sigma_b, sigma_b_raw))

#                 sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
#                 sigma_v = numpyro.deterministic(site.sigma_v, jnp.multiply(global_sigma_v, sigma_v_raw))

#                 L_location_raw = numpyro.sample("L_location_raw", dist.Normal(loc=0, scale=1))
#                 L_location = numpyro.deterministic("L_location", L_location_global_location + jnp.multiply(L_location_global_scale, L_location_raw))
#                 L_scale_raw = numpyro.sample("L_scale_raw", dist.HalfNormal(scale=1))
#                 L_scale = numpyro.deterministic("L_scale", jnp.multiply(L_scale_global_scale, L_scale_raw))

#                 sigma_l_raw = numpyro.sample("sigma_l_raw", dist.HalfNormal(scale=1))
#                 sigma_l = numpyro.deterministic("sigma_l", jnp.multiply(global_sigma_l, sigma_l_raw))

#                 H_location_raw = numpyro.sample("H_location_raw", dist.Normal(loc=0, scale=1))
#                 H_location = numpyro.deterministic("H_location", H_location_global_location + jnp.multiply(H_location_global_scale, H_location_raw))
#                 H_scale_raw = numpyro.sample("H_scale_raw", dist.HalfNormal(scale=1))
#                 H_scale = numpyro.deterministic("H_scale", jnp.multiply(H_scale_global_scale, H_scale_raw))

#                 sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
#                 sigma_g_1 = numpyro.deterministic("sigma_g_1", jnp.multiply(global_sigma_g_1, sigma_g_1_raw))

#                 sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
#                 sigma_g_2 = numpyro.deterministic("sigma_g_2", jnp.multiply(global_sigma_g_2, sigma_g_2_raw))

#                 # sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
#                 # sigma_p = numpyro.deterministic("sigma_p", jnp.multiply(global_sigma_p, sigma_p_raw))

#                 with numpyro.plate("n_feature0", n_feature0, dim=-3):
#                     """ Priors """
#                     a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=sigma_a, rate=1))
#                     a = numpyro.deterministic(site.a, jnp.true_divide(jnp.multiply(a_raw, mu_a), sigma_a))

#                     b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
#                     b = numpyro.deterministic(site.b, jnp.multiply(sigma_b, b_raw))

#                     v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
#                     v = numpyro.deterministic(site.v, jnp.multiply(sigma_v, v_raw))

#                     L_raw = numpyro.sample("L_raw", dist.Normal(loc=0, scale=1))
#                     L = numpyro.deterministic(site.L, L_location + jnp.multiply(L_scale, L_raw))

#                     l_raw = numpyro.sample("l_raw", dist.HalfNormal(scale=1))
#                     l = numpyro.deterministic("l", jnp.multiply(sigma_l, l_raw))

#                     H_raw = numpyro.sample("H_raw", dist.Normal(loc=0, scale=1))
#                     H = numpyro.deterministic(site.H, H_location + jnp.multiply(H_scale, H_raw))

#                     g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
#                     g_1 = numpyro.deterministic(site.g_1, jnp.multiply(sigma_g_1, g_1_raw))

#                     g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
#                     g_2 = numpyro.deterministic(site.g_2, jnp.multiply(sigma_g_2, g_2_raw))

#                     # p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
#                     # p = numpyro.deterministic("p", jnp.multiply(sigma_p, p_raw))

#         with numpyro.plate(site.n_response, self.n_response, dim=-1):
#             with numpyro.plate(site.data, n_data, dim=-2):
#                 """ Model """
#                 _mu = numpyro.deterministic(
#                     "_mu",
#                     self.fn(
#                         x=intensity,
#                         a=a[feature0, subject],
#                         b=b[feature0, subject],
#                         v=v[feature0, subject],
#                         L=L[feature0, subject],
#                         l=l[feature0, subject],
#                         H=H[feature0, subject]
#                     )
#                 )
#                 mu = numpyro.deterministic(
#                     site.mu,
#                     jnp.exp(_mu)
#                 )
#                 # c = numpyro.deterministic(
#                 #     "c", g_1[subject, feature0] + jnp.true_divide(
#                 #         g_2[subject, feature0],
#                 #         jnp.power(mu, 1 + p[subject, feature0])
#                 #     )
#                 # )
#                 beta = numpyro.deterministic(
#                     site.beta, g_1[feature0, subject] + jnp.true_divide(g_2[feature0, subject], mu)
#                 )

#                 """ Observation """
#                 # numpyro.sample(
#                 #     site.obs,
#                 #     dist.Gamma(concentration=jnp.multiply(mu, beta), rate=beta),
#                 #     obs=response_obs
#                 # )
#                 log_obs = numpyro.sample(
#                     "log_obs",
#                     tfd.ExpGamma(concentration=jnp.multiply(mu, beta), rate=beta),
#                     obs=response_obs
#                 )
#                 numpyro.deterministic(
#                     site.obs,
#                     jnp.exp(log_obs)
#                 )


In [4]:
# Experiment configuration file. Set the HBMEP_TOML_PATH environment variable
# to run on another machine without editing this cell; the default is the
# original location.
toml_path = os.environ.get(
    "HBMEP_TOML_PATH",
    "/home/vishu/repos/hbmep-paper/configs/paper/rats/J_RCML_000/link-comparison/rectified_logistic.toml"
)

# Sampler settings for this run, named so they are easy to find and tune.
NUM_WARMUP = 5000
NUM_SAMPLES = 1000

config = Config(toml_path=toml_path)
# Write this variant's artefacts to a "constant" sub-directory of the configured build dir.
config.BUILD_DIR = os.path.join(config.BUILD_DIR, "constant")
config.MCMC_PARAMS["num_warmup"] = NUM_WARMUP
config.MCMC_PARAMS["num_samples"] = NUM_SAMPLES
# config.MCMC_PARAMS["thinning"] = 4

model = RectifiedLogistic(config=config)


2023-11-08 12:48:00,325 - hbmep.config - INFO - Verifying configuration ...
2023-11-08 12:48:00,326 - hbmep.config - INFO - Success!
2023-11-08 12:48:00,339 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link


In [5]:
# Processed dataset. Set the HBMEP_DATA_PATH environment variable to point at
# another copy; the default is the original location.
src = os.environ.get(
    "HBMEP_DATA_PATH",
    "/home/vishu/data/hbmep-processed/J_RCML_000/data.csv"
)
df = pd.read_csv(src)

# Optional: uncomment one of the snippets below to fit on a subset only.
# # # model.mep_matrix_path = None
# subset = ["amap01", "amap02", "amap03", "amap04"]
# subset = [("amap01", "-C5L"), ("amap01", "-C6L"), ("amap01", "-C7L")]
# # subset = [("amap01", "-C5L"), ("amap01", "-C6L")]
# ind = df[model.combination_columns].apply(tuple, axis=1).isin(subset)

# subset = ["amap01"]
# ind = df[model.subject].isin(subset)
# df = df[ind].reset_index(drop=True).copy()

# model.load returns the processed frame together with the encoders that are
# passed to the rendering calls later on.
df, encoder_dict = model.load(df=df)


2023-11-08 12:48:00,451 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/constant
2023-11-08 12:48:00,451 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/constant
2023-11-08 12:48:00,452 - hbmep.dataset.core - INFO - Processing data ...
2023-11-08 12:48:00,454 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [6]:
# Fit the model to the loaded data. This is the expensive step of the
# notebook: the run recorded below took 1 hr 58 min, so avoid re-running it
# unnecessarily.
mcmc, posterior_samples = model.run_inference(df=df)


2023-11-08 12:48:00,497 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

2023-11-08 14:46:46,426 - hbmep.utils.utils - INFO - func:run_inference took: 1 hr and 58 min


In [11]:
# Build the prediction dataset and draw posterior-predictive samples on it.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples
)

# Render the fitted recruitment curves (saved as a PDF in the build directory).
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive
)

# Render the posterior predictive check (also saved as a PDF).
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive
)


2023-11-08 15:12:34,360 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec


2023-11-08 15:12:56,869 - hbmep.utils.utils - INFO - func:predict took: 22.51 sec
2023-11-08 15:12:56,950 - hbmep.model.baseline - INFO - Rendering ...
2023-11-08 15:13:37,018 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/constant/recruitment_curves.pdf
2023-11-08 15:13:37,019 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 40.13 sec
2023-11-08 15:13:37,140 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-11-08 15:14:17,655 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/constant/posterior_predictive_check.pdf
2023-11-08 15:14:17,656 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 40.64 sec
2023-11-08 15:14:17,657 - hbmep.utils.utils - INFO - func:render_predictive_check took: 40.64 sec


In [10]:
# Print per-site posterior summaries (mean, std, median, 95% interval bounds,
# n_eff and r_hat) to check the fit and its convergence diagnostics.
mcmc.print_summary(prob=.95)



                         mean       std    median      2.5%     97.5%     n_eff     r_hat
       H_raw[0,0,0]      1.68      0.50      1.62      0.81      2.69    169.80      1.03
       H_raw[0,1,0]      0.96      0.28      0.94      0.45      1.53    147.82      1.02
       H_raw[0,2,0]      1.26      0.33      1.24      0.58      1.86     66.77      1.07
       H_raw[0,3,0]      1.05      0.41      0.97      0.39      1.87    210.85      1.03
       H_raw[0,4,0]      1.20      0.19      1.19      0.82      1.61     71.73      1.08
       H_raw[0,5,0]      1.62      0.52      1.53      0.73      2.63    268.57      1.01
       H_raw[0,6,0]      1.16      0.34      1.11      0.55      1.86    172.84      1.02
       H_raw[0,7,0]      1.22      0.43      1.14      0.49      2.06    127.21      1.03
       H_raw[1,0,0]      1.40      0.52      1.33      0.56      2.41    275.96      1.01
       H_raw[1,1,0]      0.97      0.53      0.86      0.19      2.04    468.17      1.01
       H_

In [9]:
# Convert the fitted sampler to an ArviZ object for model evaluation.
numpyro_data = az.from_numpyro(mcmc)

# Model evaluation
logger.info("Evaluating model ...")

# Keep the two criteria under separate names so the LOO result is not
# overwritten by the WAIC result.
loo_score = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {loo_score.elpd_loo:.2f}")

waic_score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {waic_score.elpd_waic:.2f}")

# Backward compatibility: `score` used to end up holding the WAIC result.
score = waic_score



2023-11-08 15:06:59,610 - __main__ - INFO - Evaluating model ...
2023-11-08 15:07:06,423 - __main__ - INFO - ELPD LOO (Log): 5999.53
See http://arxiv.org/abs/1507.04544 for details
2023-11-08 15:07:06,865 - __main__ - INFO - ELPD WAIC (Log): 6053.96


In [11]:
# Persist the fitted objects so the long inference run need not be repeated.
# `pickle` is already imported in the first cell, so it is not re-imported here.
# NOTE: pickle stores classes by reference, and RectifiedLogistic is defined
# in this notebook, so the class must be defined again before the file is
# loaded in a fresh kernel.
dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, "wb") as f:
    pickle.dump((model, mcmc, posterior_samples), f)


In [12]:
# Save the ArviZ data next to the other build artefacts.
# Note that `dest` is rebound here: it no longer points at inference.pkl.
dest = os.path.join(model.build_dir, "numpyro_data.nc")
az.to_netcdf(data=numpyro_data, filename=dest)


'/home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/numpyro_data.nc'

In [None]:
# import pickle

# with open(dest, "rb") as g:
#     model_, mcmc_, posterior_samples_ = pickle.load(g)


In [16]:
# Number of sites (keys) stored in the posterior samples.
len(posterior_samples)

46