In [1]:
%reload_ext autoreload
%autoreload 2

import os
from pathlib import Path
import multiprocessing

import numpy as np
import jax
import numpyro

from hbmep.config import Config
from hbmep.model import Model
from hbmep.model.utils import Site as site

from hbmep_paper.model import HierarchicalBayesian, MaximumLikelihood
from hbmep_paper.utils import simulate, run_experiment

# Run everything on the CPU backend for both JAX and NumPyro.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Leave two cores free, but never request fewer than one host device:
# cpu_count() - 2 is zero or negative on machines with one or two cores.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()


In [2]:
# Directory three levels above the notebook's working directory; the config
# file and the build directory are both located relative to it.
root_path = Path(os.getcwd()).parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/experiments.toml")

config = Config(toml_path=toml_path)

# Derive the build directory from root_path instead of hard-coding a
# user-specific absolute path, so the notebook runs on other machines.
# NOTE(review): assumes root_path resolves to the hbmep-paper checkout
# (the directory that contains both configs/ and reports/) -- confirm.
# The trailing "" keeps the trailing path separator of the original value.
config.BUILD_DIR = os.path.join(root_path, "reports", "hb-cp", "")


2023-08-01 13:07:12,183 - hbmep.config - INFO - Verifying configuration ...
2023-08-01 13:07:12,184 - hbmep.config - INFO - Success!


In [3]:
import jax.numpy as jnp
import numpyro.distributions as dist

from hbmep.model import Baseline
from hbmep.utils.constants import RECTIFIED_LOGISTIC


class RectifiedLogistic(Baseline):
    """Hierarchical model with a rectified-logistic link and Gamma observations.

    Hyper-priors are sampled inside the (subject, response) plates, curve
    parameters inside the additional ``n_feature0`` plate, and the observed
    response follows a Gamma distribution whose mean is the curve value ``mu``.
    """

    LINK = RECTIFIED_LOGISTIC

    def __init__(self, config: Config):
        """Forward ``config`` to the ``Baseline`` initialiser."""
        super().__init__(config=config)

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model definition.

        Args:
            subject: per-observation index into the subject axis of the
                parameter arrays.
            features: sequence of feature arrays; only ``features[0]`` is used,
                as the per-observation index into the ``n_feature0`` axis.
            intensity: per-observation intensity values; reshaped to a column
                and tiled to one column per response.
            response_obs: observed responses passed as ``obs`` to the
                observation site, or ``None`` to sample it.

        Returns:
            The value of the ``site.obs`` sample site.
        """
        # One intensity column per response.
        intensity = np.tile(intensity.reshape(-1, 1), (1, self.n_response))
        feature0 = features[0].reshape(-1,)

        num_obs = intensity.shape[0]
        num_subjects = np.unique(subject).shape[0]
        num_feature0 = np.unique(feature0).shape[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1), \
                numpyro.plate(site.n_subject, num_subjects, dim=-2):
            # Hyper-priors
            mu_a = numpyro.sample(site.mu_a, dist.TruncatedNormal(150, 50, low=0))
            sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(50))

            sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.1))

            sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
            sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
            sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(10))

            with numpyro.plate("n_feature0", num_feature0, dim=-3):
                # Priors (the order of the sample statements is kept as is)
                a = numpyro.sample(
                    site.a, dist.TruncatedNormal(mu_a, sigma_a, low=0)
                )
                b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

                L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                g_1 = numpyro.sample(site.g_1, dist.Exponential(0.01))
                g_2 = numpyro.sample(site.g_2, dist.Exponential(0.01))

        # Gather every observation's own parameters once instead of repeating
        # the same [feature0, subject] lookup inside the formula.
        at = (feature0, subject)
        a_obs, b_obs, v_obs = a[at], b[at], v[at]
        L_obs, H_obs = L[at], H[at]
        g_1_obs, g_2_obs = g_1[at], g_2[at]

        with numpyro.plate(site.n_response, self.n_response, dim=-1), \
                numpyro.plate(site.data, num_obs, dim=-2):
            # Model: L + max(0, -1 + (H + 1) / (1 + ((1 + H)^v - 1) e^{-b(x - a)})^{1/v})
            scale = jnp.power(1 + H_obs, v_obs) - 1
            decay = jnp.exp(-b_obs * (intensity - a_obs))
            denominator = jnp.power(1 + scale * decay, 1 / v_obs)
            mu = numpyro.deterministic(
                site.mu,
                L_obs + jnp.maximum(0, -1 + (H_obs + 1) / denominator)
            )
            beta = numpyro.deterministic(
                site.beta,
                g_1_obs + g_2_obs * (1 / mu)
            )

        # Observation: Gamma with mean mu (concentration / rate = mu); the
        # response axis is treated as the event dimension.
        with numpyro.plate(site.data, num_obs):
            likelihood = dist.Gamma(concentration=mu * beta, rate=beta).to_event(1)
            return numpyro.sample(site.obs, likelihood, obs=response_obs)


In [4]:
model = RectifiedLogistic(config=config)

# Simulate a dataset from the model
simulation_params = dict(n_subject=3, n_feature0=10, n_repeats=10)
df, posterior_samples_true = simulate(model=model, **simulation_params)
obs = np.array(posterior_samples_true[site.obs])

# Keep only compound positions 0-8, applying the same row mask to the frame
# and to the second axis of the simulated observations.
ind = df.compound_position.isin(list(range(9)))
df = df[ind].reset_index(drop=True)
obs = obs[:, ind, :]

df, encoder_dict = model.load(df=df)

# Use the last simulated draw as the response columns
df[model.response] = obs[-1, ...]

print(df.shape)

2023-08-01 13:07:12,370 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link
2023-08-01 13:07:12,371 - hbmep_paper.utils.utils - INFO - Simulating data ...


2023-08-01 13:07:15,556 - hbmep.utils.utils - INFO - func:predict took: 3.18 sec
2023-08-01 13:07:15,557 - hbmep.utils.utils - INFO - func:simulate took: 3.19 sec
2023-08-01 13:07:15,558 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/hb-cp/
2023-08-01 13:07:15,559 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/hb-cp/
2023-08-01 13:07:15,559 - hbmep.dataset.core - INFO - Processing data ...
2023-08-01 13:07:15,561 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


(2430, 4)


In [5]:
# Run inference on the prepared frame. This is the slow step of the notebook:
# the log of this run reports about 36 minutes.
mcmc, posterior_samples = model.run_inference(df=df)

2023-08-01 13:07:15,618 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-08-01 13:43:13,332 - hbmep.utils.utils - INFO - func:run_inference took: 35 min and 57.71 sec


In [7]:
# Per-site summary table (mean, std, median, 2.5%/97.5% bounds, n_eff, r_hat)
mcmc.print_summary(prob=.95)


                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      1.46      1.06      1.19      0.16      3.54  15859.33      1.00
  H[0,1,0]      1.72      0.05      1.72      1.62      1.83  19323.58      1.00
  H[0,2,0]      2.39      3.38      1.10      0.00      9.18   7360.66      1.00
  H[1,0,0]      0.79      0.02      0.79      0.75      0.84  26480.45      1.00
  H[1,1,0]      1.05      0.02      1.05      1.01      1.08  24409.16      1.00
  H[1,2,0]      2.34      3.35      1.06      0.00      9.02   6798.68      1.00
  H[2,0,0]      0.33      0.01      0.33      0.31      0.34  24222.87      1.00
  H[2,1,0]      3.55      1.63      3.03      2.33      6.51   3989.68      1.00
  H[2,2,0]      2.63      3.41      1.39      0.02      9.31   6866.62      1.00
  H[3,0,0]      0.34      0.01      0.34      0.33      0.36  15436.37      1.00
  H[3,1,0]      1.12      0.02      1.12      1.08      1.15  20063.45      1.00
  H[3,2,0]      2.40      3

In [None]:
# Render the recruitment curves from the fitted posterior; per the log the
# output is saved as recruitment_curves.pdf in the build directory.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-08-01 12:20:06,025 - hbmep.model.baseline - INFO - Generating predictions ...


2023-08-01 12:20:22,094 - hbmep.utils.utils - INFO - func:predict took: 16.06 sec
2023-08-01 12:20:22,153 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-08-01 12:20:27,405 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/hb-cp/recruitment_curves.pdf
2023-08-01 12:20:27,405 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 21.38 sec


In [None]:
# Render the posterior predictive check; per the log the output is saved as
# posterior_predictive_check.pdf in the build directory.
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-08-01 12:20:27,429 - hbmep.model.baseline - INFO - Generating predictions ...
2023-08-01 12:20:43,380 - hbmep.utils.utils - INFO - func:predict took: 15.95 sec
2023-08-01 12:20:43,480 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-08-01 12:20:53,592 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/hb-cp/posterior_predictive_check.pdf
2023-08-01 12:20:53,594 - hbmep.utils.utils - INFO - func:render_predictive_check took: 26.16 sec


In [None]:
# Turn every combination into an index tuple: its entries in reverse order,
# followed by a full slice for the remaining (last) axis.
combinations = [
    (*list(combo)[::-1], slice(None))
    for combo in model._make_combinations(df=df, columns=model.combination_columns)
]

In [None]:
# Inspect the shape of the simulated draws stored under site.a
posterior_samples_true[site.a].shape

(10, 10, 3, 1)

In [None]:
# Exploratory: copy the simulated site.a draws into a NumPy array and check its shape
a_true = np.array(posterior_samples_true[site.a]) 

a_true.shape

(10, 10, 3, 1)

In [None]:
# Exploratory slice: index 9 on the first axis, 2 on the third, 0 on the last.
# This rebinds a_true to a 1-D array, so the cell cannot be re-run on its own
# result; a_true is rebuilt from posterior_samples_true in a later cell.
a_true = a_true[9, :, 2, 0]

In [None]:
# Confirm the sliced array is one-dimensional
a_true.shape

(10,)

In [None]:
def _flatten_by_combination(values, combinations):
    """Index ``values`` with each entry of ``combinations`` and flatten to 1-D.

    Args:
        values: array indexed by the tuples in ``combinations``.
        combinations: iterable of index tuples.

    Returns:
        1-D NumPy array of the gathered values, in ``combinations`` order.
    """
    return np.array([values[c] for c in combinations]).reshape(-1, )


# Ground truth: take the LAST simulated draw, i.e. the same draw that
# `obs[-1, ...]` used when the response columns of df were filled. The
# previous hard-coded index 9 was only correct while n_repeats == 10.
a_true = np.array(posterior_samples_true[site.a])       # n_repeats x ... x n_muscles
a_true = _flatten_by_combination(a_true[-1, ...], combinations)

# Estimate: posterior mean over the sample axis, gathered in the same order.
a = np.array(posterior_samples[site.a])     # n_posterior_samples x ... x n_muscles
a = _flatten_by_combination(a.mean(axis=0), combinations)

In [None]:
# Simulated values of `a`, gathered per combination and flattened
a_true

array([275.42469838, 340.28460268, 262.262484  , 323.92206672,
       344.05620336, 172.30172065, 195.74205704, 251.44240172,
       271.29628746, 344.83168449])

In [None]:
# Posterior-mean estimates of `a`, gathered with the same combinations as a_true
a

array([330.58424049, 332.24919439, 287.82485427, 312.69409475,
       340.68593192, 163.95536421, 195.46873313, 252.2635708 ,
       265.39433827, 334.82226092])