In [1]:
# (Re)load IPython's autoreload extension and use mode 2, so edited modules are
# re-imported before each cell executes.
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run on CPU; the same platform string is handed to both JAX and NumPyro.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Leave two cores free, but never request fewer than one host device:
# on machines with two or fewer cores the plain subtraction would yield
# zero or a negative device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class Logistic4(Baseline):
    """Hierarchical four-parameter logistic recruitment-curve model.

    The expected response is ``mu = L + H * sigmoid(b * (intensity - a))``.
    Observations follow a Gamma distribution with concentration ``mu * beta``
    and rate ``beta`` (hence mean ``mu``), where
    ``beta = g_1 + g_2 * (1 / (mu + 1)) ** p``.

    All curve parameters are sampled inside nested plates over response,
    feature0, feature1 and subject, so each has one value per
    (subject, feature1, feature0, response) combination.
    """
    LINK = "Logistic4"

    def __init__(self, config: Config):
        """Initialise the base model and set the combination columns.

        Zero-argument ``super()`` is used so that re-running this cell in a
        live kernel (which rebinds the name ``Logistic4``) cannot break
        construction of the class.
        """
        super().__init__(config=config)
        self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model definition.

        Args:
            subject: Index array used to select the per-subject parameters.
                Assumed to be 1-D with one entry per data row and encoded as
                0..n_subject-1 -- TODO confirm against the caller.
            features: Sequence whose first two entries are index arrays for
                feature0 and feature1. Exactly two features are used.
            intensity: Stimulation intensities; reshaped to a column and
                tiled across the ``self.n_response`` response columns.
            response_obs: Observed responses, or ``None`` when sampling from
                the model without conditioning.

        Returns:
            The value of the observation sample site.
        """
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        feature1 = features[1].reshape(-1,)

        # Plate sizes are taken from the number of distinct labels present.
        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]
        n_feature1 = np.unique(feature1).shape[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # Per-response global scales shared by every feature combination.
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))

            global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))
            global_sigma_H = numpyro.sample("global_sigma_H", dist.HalfNormal(5))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                with numpyro.plate("n_feature1", n_feature1, dim=-3):
                    # Hyper-priors: each scale is a unit HalfNormal draw
                    # multiplied by the matching global scale.
                    mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=5))
                    sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(scale=1))

                    sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                    sigma_b = numpyro.deterministic(site.sigma_b, global_sigma_b * sigma_b_raw)

                    sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                    sigma_L = numpyro.deterministic(site.sigma_L, global_sigma_L * sigma_L_raw)

                    sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                    sigma_H = numpyro.deterministic(site.sigma_H, global_sigma_H * sigma_H_raw)

                    sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                    sigma_g_1 = numpyro.deterministic("sigma_g_1", global_sigma_g_1 * sigma_g_1_raw)

                    sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                    sigma_g_2 = numpyro.deterministic("sigma_g_2", global_sigma_g_2 * sigma_g_2_raw)

                    sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
                    sigma_p = numpyro.deterministic("sigma_p", global_sigma_p * sigma_p_raw)

                    with numpyro.plate(site.n_subject, n_subject, dim=-4):
                        # Priors: raw unit-scale draws rescaled by the
                        # hyper-parameters sampled above.
                        a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=mu_a, rate=1))
                        a = numpyro.deterministic(site.a, (1 / sigma_a) * a_raw)

                        b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                        b = numpyro.deterministic(site.b, sigma_b * b_raw)

                        L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                        L = numpyro.deterministic(site.L, sigma_L * L_raw)

                        H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                        H = numpyro.deterministic(site.H, sigma_H * H_raw)

                        g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                        g_1 = numpyro.deterministic(site.g_1, sigma_g_1 * g_1_raw)

                        g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                        g_2 = numpyro.deterministic(site.g_2, sigma_g_2 * g_2_raw)

                        p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                        p = numpyro.deterministic("p", sigma_p * p_raw)

        # Model. Each parameter is indexed down to one row per data point;
        # the index order follows the plate dims (-4, -3, -2).
        idx = (subject, feature1, feature0)

        # jax.nn.sigmoid is used instead of 1 / (1 + exp(-z)): the hand-written
        # form overflows exp(-z) for very negative z, which still gives a
        # finite value but a NaN gradient under autodiff.
        mu = numpyro.deterministic(
            site.mu,
            L[idx] + H[idx] * jax.nn.sigmoid(b[idx] * (intensity - a[idx]))
        )
        beta = numpyro.deterministic(
            site.beta,
            g_1[idx] + g_2[idx] * ((1 / (mu + 1)) ** p[idx])
        )

        # Observation: the response axis is folded into the event shape, so
        # the data plate sees one multivariate observation per row.
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
                obs=response_obs
            )


In [3]:
# Locate the repository root by climbing four directory levels from the
# current working directory.
notebook_dir = Path.cwd()
root_path = notebook_dir.parent.parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/paper/rats/SHIE/Logistic4.toml")

config = Config(toml_path=toml_path)

# Sampler settings overridden for this run.
mcmc_overrides = {"num_warmup": 4000, "num_samples": 6000, "thinning": 4}
for param_name, param_value in mcmc_overrides.items():
    config.MCMC_PARAMS[param_name] = param_value

model = Logistic4(config=config)


2023-10-13 13:37:34,245 - hbmep.config - INFO - Verifying configuration ...
2023-10-13 13:37:34,245 - hbmep.config - INFO - Success!
2023-10-13 13:37:34,259 - hbmep.model.baseline - INFO - Initialized model with Logistic4 link


In [4]:
# Location of the processed L_SHIE dataset. Set the HBMEP_DATA_SRC environment
# variable to point at the CSV on machines where the default path does not
# exist; the default is unchanged from the original run.
src = os.environ.get(
    "HBMEP_DATA_SRC", "/home/vishu/data/hbmep-processed/L_SHIE/data.csv"
)
df = pd.read_csv(src)

# NOTE(review): presumably clears any MEP-matrix input for this run -- confirm.
model.mep_matrix_path = None

# Optional: uncomment to fit a subset of subjects only.
# subset = ["amap01"]
# ind = df[model.subject].isin(subset)
# df = df[ind].reset_index(drop=True).copy()

# `load` returns the processed frame together with an encoder dictionary.
df, encoder_dict = model.load(df=df)


2023-10-13 13:37:34,343 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/rats/SHIE/Logistic4
2023-10-13 13:37:34,344 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/rats/SHIE/Logistic4
2023-10-13 13:37:34,345 - hbmep.dataset.core - INFO - Processing data ...
2023-10-13 13:37:34,346 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# Run inference on the loaded data; keeps both the MCMC object (used for the
# summary and ArviZ conversion below) and the posterior samples.
mcmc, posterior_samples = model.run_inference(df=df)


2023-10-13 13:37:34,425 - hbmep.model.baseline - INFO - Running inference with Logistic4 ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
# Posterior summary table using a 95% interval.
mcmc.print_summary(prob=0.95)



                          mean       std    median      2.5%     97.5%     n_eff     r_hat
      H_raw[0,0,0,0]      0.99      0.55      0.87      0.17      2.04     83.57      1.05
      H_raw[0,0,0,1]      1.27      0.54      1.20      0.28      2.23    147.55      1.03
      H_raw[0,0,0,2]      1.06      0.55      0.92      0.24      2.10     90.36      1.05
      H_raw[0,0,0,3]      0.81      0.48      0.73      0.03      1.71     36.46      1.10
      H_raw[0,0,0,4]      0.98      0.52      0.87      0.20      2.11     56.58      1.06
      H_raw[0,0,0,5]      1.08      0.53      1.01      0.21      2.08    134.50      1.04
      H_raw[0,0,1,0]      0.76      0.56      0.62      0.05      1.94    171.37      1.01
      H_raw[0,0,1,1]      0.43      0.40      0.30      0.02      1.24    137.99      1.02
      H_raw[0,0,1,2]      0.28      0.35      0.15      0.00      1.00     84.62      1.06
      H_raw[0,0,1,3]      0.82      0.51      0.68      0.02      1.83     98.19      1.0

In [None]:
model.mep_matrix_path = None

# Build the prediction dataset and evaluate the posterior predictive on it.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

# Render the recruitment curves, then the posterior predictive check.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)


2023-10-13 13:23:49,027 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec


2023-10-13 13:23:56,555 - hbmep.utils.utils - INFO - func:predict took: 7.53 sec
2023-10-13 13:23:56,585 - hbmep.model.baseline - INFO - Rendering ...
2023-10-13 13:24:06,942 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/SHIE/Logistic4/recruitment_curves.pdf
2023-10-13 13:24:06,943 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 10.38 sec
2023-10-13 13:24:07,032 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-10-13 13:24:18,178 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/SHIE/Logistic4/posterior_predictive_check.pdf
2023-10-13 13:24:18,179 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 11.24 sec
2023-10-13 13:24:18,180 - hbmep.utils.utils - INFO - func:render_predictive_check took: 11.24 sec


In [None]:
# Convert the NumPyro MCMC object to ArviZ's data structure for the
# evaluation cells that follow.
numpyro_data = az.from_numpyro(mcmc)

# Model evaluation
logger.info("Evaluating model ...")


2023-10-13 13:24:19,865 - __main__ - INFO - Evaluating model ...


In [None]:
# Leave-one-out cross-validation estimate from ArviZ; log its ELPD.
score = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")

2023-10-13 13:24:22,085 - __main__ - INFO - ELPD LOO (Log): 4655.62


In [None]:
# WAIC estimate from ArviZ; note this rebinds `score` from the LOO cell.
score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")

See http://arxiv.org/abs/1507.04544 for details
2023-10-13 13:24:22,144 - __main__ - INFO - ELPD WAIC (Log): 4675.73


In [None]:
# Persist the model, MCMC object and posterior samples as a single pickle
# inside the model's build directory. `pickle` is already imported in the
# notebook's first cell, so it is not re-imported here.
dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, "wb") as f:
    pickle.dump((model, mcmc, posterior_samples), f)


In [12]:
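# Optional: uncomment to reload the results pickled to `dest`.
# Only unpickle files you created yourself -- pickle can execute arbitrary code.
# import pickle

# with open(dest, "rb") as g:
#     model_, mcmc_, posterior_samples_ = pickle.load(g)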
