In [1]:
%reload_ext autoreload
%autoreload 2

import os
import logging
import multiprocessing
from pathlib import Path
from tqdm import tqdm

import arviz as az
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns

import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import scipy.stats as stats
import numpyro
from numpyro.diagnostics import hpdi

from hbmep.config import Config
from hbmep.distributions import GeneralizedExtremeValue as GEV
# from hbmep_paper.simulator import HierarchicalBayesianModel
from hbmep.model.utils import Site as site

# Run JAX / NumPyro on the CPU backend.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several host CPU devices so MCMC chains can run in parallel, leaving two
# cores free. Clamp to at least one device: `cpu_count() - 2` is <= 0 on machines
# with one or two cores, which is not a valid device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline
from hbmep_paper.utils.constants import HBM


class HierarchicalBayesianModel(Baseline):
    """Hierarchical generalized-logistic recruitment-curve model with two feature columns.

    Every (subject, feature0, feature1, response) cell has its own curve parameters
    ``a, b, v, L, H`` and Gamma-noise parameters ``g_1, g_2, p``. Most per-cell
    parameters are drawn as ``scale * raw`` with ``raw ~ HalfNormal(1)`` (or
    HalfCauchy(1) for ``g_1``/``g_2``). The scale is a per-(feature0, feature1, response)
    hyper-parameter, itself ``global_scale * raw`` with a per-response global scale.
    ``a`` is drawn as ``Gamma(mu_a, 1) / sigma_a``.

    Parameter arrays are laid out as ``(n_subject, n_feature0, n_feature1, n_response)``
    and indexed as ``param[subject, feature0, feature1]``.
    """
    LINK = HBM

    def __init__(self, config: Config):
        super().__init__(config=config)
        self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model for the recruitment curves.

        Args:
            subject: integer-encoded subject code per observation (used as an index).
            features: sequence whose first two entries are integer-encoded feature
                codes per observation; only ``features[0]`` and ``features[1]`` are used.
            intensity: stimulation intensity per observation.
            response_obs: observed responses, or ``None`` to sample them.

        Returns:
            The ``site.obs`` sample with shape ``(n_data, n_response)``.
        """
        # Every response channel is evaluated at the same intensity.
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        feature1 = features[1].reshape(-1,)

        n_data = intensity.shape[0]
        # Size each plate by (largest code + 1) rather than the number of distinct codes,
        # so every index used below is in range even if some level is absent from this
        # batch. JAX clamps out-of-range gather indices silently instead of raising.
        n_subject = int(np.max(subject)) + 1
        n_feature0 = int(np.max(feature0)) + 1
        n_feature1 = int(np.max(feature1)) + 1

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # Global (per-response) scales of the hierarchy
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
            global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))

            global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))
            global_sigma_H = numpyro.sample("global_sigma_H", dist.HalfNormal(5))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

            # feature0 -> dim -3 and feature1 -> dim -2, so arrays are
            # (subject, feature0, feature1, response), matching the indexing below.
            with numpyro.plate("n_feature0", n_feature0, dim=-3):
                with numpyro.plate("n_feature1", n_feature1, dim=-2):
                    # Hyper-priors
                    mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=5))
                    sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(scale=1))

                    sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                    sigma_b = numpyro.deterministic(site.sigma_b, global_sigma_b * sigma_b_raw)

                    sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
                    sigma_v = numpyro.deterministic(site.sigma_v, global_sigma_v * sigma_v_raw)

                    sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                    sigma_L = numpyro.deterministic(site.sigma_L, global_sigma_L * sigma_L_raw)

                    sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                    sigma_H = numpyro.deterministic(site.sigma_H, global_sigma_H * sigma_H_raw)

                    sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                    sigma_g_1 = numpyro.deterministic("sigma_g_1", global_sigma_g_1 * sigma_g_1_raw)

                    sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                    sigma_g_2 = numpyro.deterministic("sigma_g_2", global_sigma_g_2 * sigma_g_2_raw)

                    sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
                    sigma_p = numpyro.deterministic("sigma_p", global_sigma_p * sigma_p_raw)

                    with numpyro.plate(site.n_subject, n_subject, dim=-4):
                        # Priors
                        # (1 / sigma_a) * Gamma(mu_a, 1) is Gamma(mu_a, rate=sigma_a)
                        a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=mu_a, rate=1))
                        a = numpyro.deterministic(site.a, (1 / sigma_a) * a_raw)

                        b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                        b = numpyro.deterministic(site.b, sigma_b * b_raw)

                        v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                        v = numpyro.deterministic(site.v, sigma_v * v_raw)

                        L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                        L = numpyro.deterministic(site.L, sigma_L * L_raw)

                        H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                        H = numpyro.deterministic(site.H, sigma_H * H_raw)

                        g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                        g_1 = numpyro.deterministic(site.g_1, sigma_g_1 * g_1_raw)

                        g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                        g_2 = numpyro.deterministic(site.g_2, sigma_g_2 * g_2_raw)

                        p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                        p = numpyro.deterministic("p", sigma_p * p_raw)

        # Model: generalized logistic (Richards) curve
        #   mu = L + H / (1 + v * exp(-b * (x - a))) ** (1 / v)
        # Gathering with the index tuple gives (n_data, n_response) arrays.
        idx = (subject, feature0, feature1)
        mu = numpyro.deterministic(
            site.mu,
            L[idx]
            + (
                H[idx]
                / jnp.power(
                    1 + v[idx] * jnp.exp(-b[idx] * (intensity - a[idx])),
                    1 / v[idx]
                )
            )
        )
        # Gamma rate; larger beta means less noise around mu
        beta = numpyro.deterministic(
            site.beta,
            g_1[idx] + g_2[idx] * ((1 / mu) ** p[idx])
        )

        # Observation: Gamma(mu * beta, beta) has mean mu; responses are one event
        # per observation.
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
                obs=response_obs
            )


In [3]:
# Repository root. This assumes the notebook's working directory is four levels below
# the repo root (e.g. notebooks/paper/rats/SHIE) — TODO confirm if the notebook moves.
root_path = Path(os.getcwd()).parent.parent.parent.parent.absolute()
# Keep the second component relative: os.path.join discards everything before an
# absolute component, which previously made root_path dead code and pinned the
# notebook to one machine.
toml_path = os.path.join(root_path, "configs/paper/rats/SHIE/generalized-logistic.toml")
assert os.path.isfile(toml_path), f"Config not found: {toml_path}"

config = Config(toml_path=toml_path)
# Optional MCMC overrides:
# config.MCMC_PARAMS["num_warmup"] = 8000
# config.MCMC_PARAMS["num_samples"] = 4000

model = HierarchicalBayesianModel(config=config)


2023-10-11 17:11:17,863 - hbmep.config - INFO - Verifying configuration ...
2023-10-11 17:11:17,863 - hbmep.config - INFO - Success!
2023-10-11 17:11:17,885 - hbmep.model.baseline - INFO - Initialized model with hierarchical_bayesian link


In [4]:
# Preprocessed L_SHIE data (machine-specific absolute path).
src = "/home/vishu/data/hbmep-processed/L_SHIE/data.csv"
df = pd.read_csv(src)

# NOTE(review): presumably disables loading MEP waveforms for plotting — TODO confirm.
model.mep_matrix_path = None

# Restrict the fit to the selected subject(s).
subset = ["amap01"]
ind = df[model.subject].isin(subset)
# reset_index already returns a new frame, so no extra .copy() is needed.
df = df[ind].reset_index(drop=True)
assert not df.empty, f"No rows for subjects {subset} in {src}"

# Optional: drop rows with zero pulse amplitude.
# ind = df.pulse_amplitude.isin([0])
# df = df[~ind].reset_index(drop=True).copy()

df, encoder_dict = model.load(df=df)


2023-10-11 17:11:17,940 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/rats/SHIE/generalized-sigmoid
2023-10-11 17:11:17,940 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/rats/SHIE/generalized-sigmoid
2023-10-11 17:11:17,941 - hbmep.dataset.core - INFO - Processing data ...
2023-10-11 17:11:17,942 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# Fit the model with MCMC; keep the sampler (for diagnostics) and its posterior draws.
inference_output = model.run_inference(df=df)
mcmc, posterior_samples = inference_output


2023-10-11 17:11:18,031 - hbmep.model.baseline - INFO - Running inference with hierarchical_bayesian ...


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

2023-10-11 17:13:56,121 - hbmep.utils.utils - INFO - func:run_inference took: 2 min and 38.09 sec


In [6]:
mcmc.print_summary(prob=.95)

# The recruitment curves rendered below are only trustworthy if the chains mixed.
# Warn about every site whose worst R-hat exceeds the conventional 1.01 cut-off.
R_HAT_THRESHOLD = 1.01
site_summaries = numpyro.diagnostics.summary(
    mcmc.get_samples(group_by_chain=True), prob=.95, group_by_chain=True
)
unconverged_sites = {
    site_name: float(np.nanmax(site_summary["r_hat"]))
    for site_name, site_summary in site_summaries.items()
    if np.nanmax(site_summary["r_hat"]) > R_HAT_THRESHOLD
}
if unconverged_sites:
    logger.warning(
        "%d site(s) have r_hat > %.2f (worst: %s); consider more warmup/samples "
        "or reparameterizing before trusting the posterior.",
        len(unconverged_sites),
        R_HAT_THRESHOLD,
        max(unconverged_sites, key=unconverged_sites.get),
    )



                          mean       std    median      2.5%     97.5%     n_eff     r_hat
      H_raw[0,0,0,0]      1.03      0.46      1.07      0.25      1.88      3.49      1.71
      H_raw[0,0,0,1]      0.98      0.50      0.90      0.22      1.93      5.04      1.43
      H_raw[0,0,0,2]      1.69      0.71      1.51      0.68      3.11      4.84      1.58
      H_raw[0,0,0,3]      0.93      0.45      0.80      0.31      1.85      4.14      1.96
      H_raw[0,0,0,4]      0.95      0.37      0.87      0.40      1.58      5.36      1.40
      H_raw[0,0,0,5]      0.70      0.42      0.54      0.16      1.43      3.44      1.80
      H_raw[0,0,1,0]      1.10      0.98      0.64      0.24      3.25      2.19      3.95
      H_raw[0,0,1,1]      0.85      0.45      0.84      0.11      1.58      3.85      1.76
      H_raw[0,0,1,2]      1.17      0.50      1.19      0.24      2.06      3.65      2.03
      H_raw[0,0,1,3]      1.05      0.48      1.14      0.23      1.77      5.49      1.7

In [8]:
# Build the prediction grid and its posterior predictive, then render both report PDFs.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

# Arguments common to both renderers.
shared_render_kwargs = {
    "df": df,
    "encoder_dict": encoder_dict,
    "prediction_df": prediction_df,
    "posterior_predictive": posterior_predictive,
}
model.render_recruitment_curves(posterior_samples=posterior_samples, **shared_render_kwargs)
model.render_predictive_check(**shared_render_kwargs)


2023-10-12 11:24:06,320 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec


2023-10-12 11:24:16,032 - hbmep.utils.utils - INFO - func:predict took: 9.71 sec
2023-10-12 11:24:16,071 - hbmep.model.baseline - INFO - Rendering ...
2023-10-12 11:24:27,904 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/SHIE/generalized-sigmoid/recruitment_curves.pdf
2023-10-12 11:24:27,904 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 11.86 sec
2023-10-12 11:24:27,945 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-10-12 11:24:39,041 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/SHIE/generalized-sigmoid/posterior_predictive_check.pdf
2023-10-12 11:24:39,042 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 11.14 sec
2023-10-12 11:24:39,042 - hbmep.utils.utils - INFO - func:render_predictive_check took: 11.14 sec
