In [2]:
%reload_ext autoreload
%autoreload 2

import os
import logging
import multiprocessing
from pathlib import Path
from tqdm import tqdm

import arviz as az
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns

import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import scipy.stats as stats
import numpyro
from numpyro.diagnostics import hpdi

from hbmep.config import Config
# from hbmep_paper.simulator import HierarchicalBayesianModel
from hbmep.model.utils import Site as site

# Pin both JAX and NumPyro to the same backend for this notebook.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several host (CPU) devices to XLA while keeping two cores free.
# Clamp to at least one device: `cpu_count() - 2` is zero or negative on
# machines with one or two cores, which is not a valid device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [28]:
import numpyro.distributions as dist
from hbmep.model import Baseline
from hbmep_paper.utils.constants import HBM


class HierarchicalBayesianModel(Baseline):
    """Model built on `Baseline` and registered under the `HBM` link.

    Curve parameters are sampled per (feature0, subject, response) and the
    observations follow a Gamma likelihood whose mean is `mu`.
    """
    LINK = HBM

    def __init__(self, config: Config):
        super().__init__(config=config)

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model definition.

        Args:
            subject: integer subject index for every data point; used to
                index the sampled parameters.
            features: sequence whose first entry holds the feature0 index
                for every data point.
            intensity: stimulation intensity per data point; reshaped to a
                column and tiled to `(n_data, self.n_response)`.
            response_obs: observed responses, or None to sample them.

        Returns:
            The value of the `site.obs` sample site.
        """
        # One intensity column per response.
        intensity = np.tile(intensity.reshape(-1, 1), (1, self.n_response))
        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        # NOTE: the order of the sample statements is kept as-is on purpose;
        # it determines how the PRNG key is consumed.
        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                # Hyper-priors
                mu_a = numpyro.sample(
                    site.mu_a, dist.TruncatedNormal(150, 50, low=0)
                )
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(50))

                sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.1))

                sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
                sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
                sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(10))

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    # Priors
                    a = numpyro.sample(
                        site.a, dist.TruncatedNormal(mu_a, sigma_a, low=0)
                    )
                    b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

                    L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                    H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                    v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                    g_1 = numpyro.sample(site.g_1, dist.Exponential(0.01))
                    g_2 = numpyro.sample(site.g_2, dist.Exponential(0.01))

        # Model: pick out each data point's parameters once.
        idx = (feature0, subject)
        a_i, b_i = a[idx], b[idx]
        L_i, H_i, v_i = L[idx], H[idx], v[idx]
        g_1_i, g_2_i = g_1[idx], g_2[idx]

        decay = jnp.exp(-b_i * (intensity - a_i))
        denominator = jnp.power(
            1 + (jnp.power(1 + H_i, v_i) - 1) * decay,
            1 / v_i
        )
        # Offset L plus a curve that is clipped at zero from below.
        mu = numpyro.deterministic(
            site.mu,
            L_i + jnp.maximum(0, -1 + (H_i + 1) / denominator)
        )
        beta = numpyro.deterministic(
            site.beta,
            g_1_i + g_2_i * (1 / mu) ** 2
        )

        # Observation: the last axis (responses) is treated as the event dim.
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
                obs=response_obs
            )

In [29]:
# Walk three levels up from the kernel's working directory; this presumably
# lands on the repository root that holds `configs/` - adjust if the notebook moves.
root_path = Path.cwd().parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/simulation.toml")

# Build the model from the simulation configuration.
config = Config(toml_path=toml_path)
model = HierarchicalBayesianModel(config=config)



2023-09-14 19:50:13,770 - hbmep.config - INFO - Verifying configuration ...
2023-09-14 19:50:13,770 - hbmep.config - INFO - Success!
2023-09-14 19:50:13,771 - hbmep.model.baseline - INFO - Initialized model with hierarchical_bayesian link


In [30]:
# Simulate a dataset from the model with the settings below.
simulation_params = dict(
    n_subject=10,
    n_feature0=1,
    n_draws=5,
    n_repeats=10,
)

df_full, posterior_samples_true = model.simulate(**simulation_params)

# Copy the simulated observations into a NumPy array so they can be sliced.
obs = np.array(posterior_samples_true[site.obs])


2023-09-14 19:50:13,817 - hbmep.model.baseline - INFO - Simulating data ...


2023-09-14 19:50:17,813 - hbmep.utils.utils - INFO - func:predict took: 3.99 sec
2023-09-14 19:50:17,815 - hbmep.utils.utils - INFO - func:simulate took: 4.00 sec


In [31]:
# Work on a copy so re-running this cell leaves `df_full` untouched.
df = df_full.copy()
# Fill the response column(s) with a single slice of the simulated
# observations. The two leading indices are hard-coded to 1 - presumably a
# draw index and a repeat index; TODO confirm against the axis order that
# `model.simulate` returns.
df[model.response] = obs[1, 1, ...]

# `model.load` hands back the frame together with `encoder_dict`, both of
# which `model.plot` needs to render the dataset.
df, encoder_dict = model.load(df=df)
model.plot(df=df, encoder_dict=encoder_dict)

2023-09-14 19:50:17,842 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/simulation/
2023-09-14 19:50:17,843 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/simulation/
2023-09-14 19:50:17,844 - hbmep.dataset.core - INFO - Processing data ...
2023-09-14 19:50:17,845 - hbmep.utils.utils - INFO - func:load took: 0.00 sec
2023-09-14 19:50:17,847 - hbmep.dataset.core - INFO - Plotting dataset ...


2023-09-14 19:50:23,040 - hbmep.dataset.core - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/simulation/dataset.pdf
2023-09-14 19:50:23,041 - hbmep.utils.utils - INFO - func:plot took: 5.20 sec
