In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run on the CPU backend; JAX and NumPyro are both pointed at the same platform.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Leave two cores free for the rest of the system, but never request fewer than
# one host device: cpu_count() - 2 is zero or negative on 1-2 core machines.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()         # default to 64-bit floating point
numpyro.enable_validation()  # turn on NumPyro's distribution validation checks

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class RectifiedLogistic(Baseline):
    """Hierarchical recruitment-curve model with a rectified-logistic link.

    Extends hbmep's ``Baseline`` by supplying the link function (``fn``) and
    the NumPyro model (``_model``). Curve parameters are sampled per
    (first-feature level, subject, response); their scales are sampled per
    (subject, response) and are themselves scaled by per-response global
    scales.
    """
    LINK = "rectified_logistic"

    def __init__(self, config: Config):
        """Initialise the model from an hbmep ``Config``."""
        # Zero-argument super() keeps working if this cell is re-run and the
        # global name ``RectifiedLogistic`` is rebound to a new class object.
        super().__init__(config=config)

    def fn(self, x, a, b, v, L, l, H):
        """Evaluate the rectified-logistic curve at intensity ``x``.

        Returns ``L`` where ``x < a`` and, where ``x >= a``::

            L - l + (H + l) / (1 + (((H + l) / l) ** v - 1) * exp(-b * (x - a))) ** (1 / v)

        The two branches agree at ``x = a`` (both give ``L``), and for
        ``b > 0`` the curve tends to ``L + H`` as ``x`` grows.

        Args:
            x: Stimulation intensity.
            a: Threshold below which the curve is flat at ``L``.
            b: Rate inside the exponential term.
            v: Exponent shaping the logistic term.
            L: Offset added to the whole curve.
            l: Offset inside the logistic term (it cancels at ``x = a``).
            H: Height of the curve above ``L`` in the limit of large ``x``.

        Returns:
            Array of curve values; the arguments are combined with jax.numpy
            element-wise operations and broadcast accordingly.

        NOTE(review): ``jnp.where`` evaluates the upper branch for every
        element, including ``x < a`` where ``exp(-b * (x - a))`` can become
        very large. Confirm this does not yield non-finite gradients.
        """
        ratio_pow = jnp.power(jnp.true_divide(H + l, l), v)
        decay = jnp.exp(jnp.multiply(-b, x - a))
        denominator = jnp.power(
            1 + jnp.multiply(-1 + ratio_pow, decay),
            jnp.true_divide(1, v)
        )
        above_threshold = -l + jnp.true_divide(H + l, denominator)
        return L + jnp.where(jnp.less(x, a), 0., above_threshold)

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model: hierarchical priors, curve mean and Gamma likelihood.

        Args:
            subject: Pair ``(subject_index, n_subject)``. The index array is
                used to select each data row's subject-level parameters.
            features: Pair ``(feature_indices, n_features)``. Only the first
                feature (``features[0]`` / ``n_features[0]``) is used.
            intensity: Pair ``(intensity, n_data)``. The intensity is reshaped
                to a column and repeated across the ``self.n_response``
                responses.
            response_obs: Observed responses passed as ``obs`` to the
                likelihood, or ``None`` to sample them. Presumably shaped
                ``(n_data, n_response)`` - TODO confirm against the caller.
        """
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        # One intensity column per response. jnp.tile (rather than np.tile)
        # also accepts JAX arrays and tracers, not only concrete NumPy input.
        intensity = intensity.reshape(-1, 1)
        intensity = jnp.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # Global scales: one value per response.
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
            global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))

            global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))
            global_sigma_l = numpyro.sample("global_sigma_l", dist.HalfNormal(100))
            global_sigma_H = numpyro.sample("global_sigma_H", dist.HalfNormal(5))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                # Hyper-priors: one value per (subject, response).
                mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=150))
                # sigma_a is used below as the Gamma concentration for `a`.
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(100))

                # Each remaining scale is a unit HalfNormal draw multiplied by
                # the matching global scale.
                sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                sigma_b = numpyro.deterministic(
                    site.sigma_b, jnp.multiply(global_sigma_b, sigma_b_raw)
                )

                sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
                sigma_v = numpyro.deterministic(
                    site.sigma_v, jnp.multiply(global_sigma_v, sigma_v_raw)
                )

                sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                sigma_L = numpyro.deterministic(
                    site.sigma_L, jnp.multiply(global_sigma_L, sigma_L_raw)
                )

                sigma_l_raw = numpyro.sample("sigma_l_raw", dist.HalfNormal(scale=1))
                sigma_l = numpyro.deterministic(
                    "sigma_l", jnp.multiply(global_sigma_l, sigma_l_raw)
                )

                sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                sigma_H = numpyro.deterministic(
                    site.sigma_H, jnp.multiply(global_sigma_H, sigma_H_raw)
                )

                sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                sigma_g_1 = numpyro.deterministic(
                    "sigma_g_1", jnp.multiply(global_sigma_g_1, sigma_g_1_raw)
                )

                sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                sigma_g_2 = numpyro.deterministic(
                    "sigma_g_2", jnp.multiply(global_sigma_g_2, sigma_g_2_raw)
                )

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    # Priors: one value per (feature0 level, subject, response).
                    # a_raw ~ Gamma(sigma_a, 1) has mean sigma_a, so rescaling
                    # by mu_a / sigma_a gives `a` a conditional mean of mu_a.
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=sigma_a, rate=1))
                    a = numpyro.deterministic(
                        site.a, jnp.true_divide(jnp.multiply(a_raw, mu_a), sigma_a)
                    )

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, jnp.multiply(sigma_b, b_raw))

                    v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    v = numpyro.deterministic(site.v, jnp.multiply(sigma_v, v_raw))

                    L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    L = numpyro.deterministic(site.L, jnp.multiply(sigma_L, L_raw))

                    l_raw = numpyro.sample("l_raw", dist.HalfNormal(scale=1))
                    l = numpyro.deterministic("l", jnp.multiply(sigma_l, l_raw))

                    H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                    H = numpyro.deterministic(site.H, jnp.multiply(sigma_H, H_raw))

                    # The noise parameters use heavier-tailed HalfCauchy draws.
                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, jnp.multiply(sigma_g_1, g_1_raw))

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, jnp.multiply(sigma_g_2, g_2_raw))

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.data, n_data, dim=-2):
                # Model: indexing with [feature0, subject] picks, for every
                # data row, the parameters of that row's feature level and subject.
                mu = numpyro.deterministic(
                    site.mu,
                    self.fn(
                        x=intensity,
                        a=a[feature0, subject],
                        b=b[feature0, subject],
                        v=v[feature0, subject],
                        L=L[feature0, subject],
                        l=l[feature0, subject],
                        H=H[feature0, subject]
                    )
                )
                # Rate of the Gamma likelihood: a constant term plus a term
                # inversely proportional to the mean.
                beta = numpyro.deterministic(
                    site.beta,
                    g_1[feature0, subject] + jnp.true_divide(g_2[feature0, subject], mu)
                )

                # Observation: Gamma(mu * beta, beta) has mean mu.
                numpyro.sample(
                    site.obs,
                    dist.Gamma(concentration=jnp.multiply(mu, beta), rate=beta),
                    obs=response_obs
                )


In [3]:
# Location of the experiment configuration. Set the HBMEP_TOML_PATH environment
# variable to run on another machine; without it the original path is used.
toml_path = os.environ.get(
    "HBMEP_TOML_PATH",
    "/home/vishu/repos/hbmep-paper/configs/paper/rats/J_RCML_000/link-comparison/rectified_logistic.toml",
)

config = Config(toml_path=toml_path)
# Keep this run's artefacts in their own sub-directory of the configured build dir.
config.BUILD_DIR = os.path.join(config.BUILD_DIR, "constant_3")
# Sampler budget: 5000 warm-up iterations followed by 1000 retained samples.
config.MCMC_PARAMS["num_warmup"] = 5000
config.MCMC_PARAMS["num_samples"] = 1000
# config.MCMC_PARAMS["thinning"] = 4

model = RectifiedLogistic(config=config)


2023-11-08 16:21:42,036 - hbmep.config - INFO - Verifying configuration ...
2023-11-08 16:21:42,036 - hbmep.config - INFO - Success!
2023-11-08 16:21:42,057 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link


In [4]:
# Processed dataset. Set the HBMEP_DATA_PATH environment variable to point at a
# different copy; without it the original path is used.
src = os.environ.get(
    "HBMEP_DATA_PATH",
    "/home/vishu/data/hbmep-processed/J_RCML_000/data.csv",
)
df = pd.read_csv(src)

# Optional, disabled: restrict the fit to a subset of the data before loading.
# # # model.mep_matrix_path = None
#
# (a) by combination tuples
# subset = ["amap01", "amap02", "amap03", "amap04"]
# subset = [("amap01", "-C5L"), ("amap01", "-C6L"), ("amap01", "-C7L")]
# subset = [("amap01", "-C5L"), ("amap01", "-C6L")]
# ind = df[model.combination_columns].apply(tuple, axis=1).isin(subset)
#
# (b) by subject
# subset = ["amap01"]
# ind = df[model.subject].isin(subset)
# df = df[ind].reset_index(drop=True).copy()

# model.load returns the processed frame together with an encoder dictionary;
# both are passed to the rendering calls further down.
df, encoder_dict = model.load(df=df)


2023-11-08 16:21:42,103 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/constant_3
2023-11-08 16:21:42,104 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/constant_3
2023-11-08 16:21:42,105 - hbmep.dataset.core - INFO - Processing data ...
2023-11-08 16:21:42,108 - hbmep.utils.utils - INFO - func:load took: 0.01 sec


In [5]:
# Fit the model to `df`. Returns the `mcmc` object (summarised and converted to
# ArviZ further down) and `posterior_samples` (used for prediction and plots).
# The recorded progress bars count 6000 iterations each, matching the
# 5000 warm-up + 1000 sampling iterations set on the config.
mcmc, posterior_samples = model.run_inference(df=df)


2023-11-08 16:21:42,126 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

In [7]:
# Build the prediction dataset and draw posterior predictive samples on it.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

# Render the fitted recruitment curves and the posterior predictive check.
# Per the recorded log, each call saves a PDF into the build directory.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)


2023-11-09 09:16:51,824 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec


2023-11-09 09:17:14,099 - hbmep.utils.utils - INFO - func:predict took: 22.27 sec
2023-11-09 09:17:14,162 - hbmep.model.baseline - INFO - Rendering ...
2023-11-09 09:17:52,355 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/constant_3/recruitment_curves.pdf
2023-11-09 09:17:52,356 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 38.26 sec
2023-11-09 09:17:52,478 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-11-09 09:18:31,538 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/constant_3/posterior_predictive_check.pdf
2023-11-09 09:18:31,540 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 39.18 sec
2023-11-09 09:18:31,540 - hbmep.utils.utils - INFO - func:render_predictive_check took: 39.18 sec


In [6]:
# Print the posterior summary table with 95% intervals (prob=.95).
# NOTE(review): in the recorded output some H_raw entries have r_hat above 1.1
# (up to 1.18) and n_eff below 50, so the chains may not have converged - check
# before relying on the downstream scores.
mcmc.print_summary(prob=.95)



                         mean       std    median      2.5%     97.5%     n_eff     r_hat
       H_raw[0,0,0]      1.79      0.54      1.74      0.83      2.87    174.82      1.03
       H_raw[0,1,0]      0.93      0.27      0.91      0.42      1.42     95.40      1.03
       H_raw[0,2,0]      1.21      0.33      1.19      0.61      1.92     38.24      1.11
       H_raw[0,3,0]      1.06      0.43      0.98      0.41      1.95    141.11      1.03
       H_raw[0,4,0]      1.22      0.22      1.22      0.75      1.65     44.08      1.18
       H_raw[0,5,0]      1.66      0.50      1.58      0.80      2.68    271.83      1.00
       H_raw[0,6,0]      1.06      0.34      1.01      0.50      1.75    193.02      1.01
       H_raw[0,7,0]      1.22      0.44      1.14      0.52      2.17     99.39      1.04
       H_raw[1,0,0]      1.36      0.50      1.29      0.52      2.31    279.63      1.00
       H_raw[1,1,0]      1.01      0.54      0.90      0.21      2.10    324.52      1.01
       H_

In [8]:
# Convert the NumPyro run into ArviZ's data structure for model evaluation.
numpyro_data = az.from_numpyro(mcmc)

# Model evaluation
logger.info("Evaluating model ...")

# PSIS leave-one-out estimate of the expected log pointwise predictive density.
score = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")

# WAIC estimate of the same quantity; `score` is rebound to the WAIC result.
score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")



2023-11-09 09:22:51,384 - __main__ - INFO - Evaluating model ...
2023-11-09 09:22:58,340 - __main__ - INFO - ELPD LOO (Log): 6128.19
See http://arxiv.org/abs/1507.04544 for details
2023-11-09 09:22:58,879 - __main__ - INFO - ELPD WAIC (Log): 6180.08


In [11]:
# Serialise the fitted model, the MCMC object and the posterior draws into a
# single pickle file inside the model's build directory.
# (`pickle` is already imported in the notebook's first cell.)
dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, mode="wb") as f:
    pickle.dump((model, mcmc, posterior_samples), f)


In [12]:
# Save the ArviZ data as NetCDF in the build directory. `dest` is rebound here,
# so it no longer refers to the pickle file written earlier. As the cell's last
# expression, the call's return value (a file path in the recorded output) is shown.
# NOTE(review): the recorded output path lacks the "constant_3" sub-directory
# that was appended to config.BUILD_DIR - it looks like it came from an earlier
# run; re-run the notebook top to bottom to refresh it.
dest = os.path.join(model.build_dir, "numpyro_data.nc")
az.to_netcdf(numpyro_data, dest)


'/home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/numpyro_data.nc'

In [None]:
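# To reload the pickled artefacts later (only unpickle files you created yourself):
# import pickle
#
# dest = os.path.join(model.build_dir, "inference.pkl")  # `dest` currently points at the .nc file
# with open(dest, "rb") as g:
#     model_, mcmc_, posterior_samples_ = pickle.load(g)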


In [16]:
# Number of entries (sites) stored in the posterior samples mapping.
len(posterior_samples)

46