In [1]:
%reload_ext autoreload
%autoreload 2

import os
from pathlib import Path
import multiprocessing

import jax
import jax.numpy as jnp
import pandas as pd
import numpy as np
import numpyro
import numpyro.distributions as dist

from hbmep.config import Config
from hbmep.model import Baseline
from hbmep.model.utils import Site as site
from hbmep.utils.constants import RECTIFIED_LOGISTIC

# Run everything on CPU. numpyro documents that the platform / host-device
# settings only take effect if applied before JAX initialises its backends,
# which is why they live in the first cell.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several host "devices" so numpyro can run MCMC chains in parallel.
# Leave two cores free for the rest of the system, but never request fewer
# than one device: on a machine with <= 2 cores, `cpu_count() - 2` would be
# 0 or negative, which is not a valid device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)

# Use 64-bit floats throughout.
numpyro.enable_x64()


#### Load config

In [2]:
# Repository root. This depends on the current working directory: it assumes
# the notebook runs three directory levels below the root — TODO confirm.
root_path = Path.cwd().parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/studentT.toml")

config = Config(toml_path=toml_path)

# Artefacts (config copy, rendered PDFs) are written here. The directory is
# resolved relative to the repository root instead of a machine-specific
# absolute path; set HBMEP_BUILD_DIR to override it.
# NOTE(review): assumes root_path is the hbmep-paper checkout, which matches
# the previously hard-coded "<repo>/reports/fix-3" location — confirm.
config.BUILD_DIR = os.environ.get(
    "HBMEP_BUILD_DIR", os.path.join(root_path, "reports", "fix-3")
)


2023-07-21 13:35:02,929 - hbmep.config - INFO - Verifying configuration ...
2023-07-21 13:35:02,929 - hbmep.config - INFO - Success!


In [3]:
class StudentT(Baseline):
    """Recruitment-curve model with a Student-t observation likelihood.

    The mean response ``mu`` is a rectified generalised-logistic curve of
    intensity. It equals ``L`` for intensities at or below the threshold
    ``a`` and approaches ``L + H`` as intensity grows, with rate ``b`` and
    shape ``v``. The likelihood has two mean-dependent terms:

    * the Student-t scale grows linearly with ``mu``;
    * the degrees of freedom are ``2 + 1/nu_1 + 1/(nu_2 * mu)``, so they
      always exceed 2.
    """
    LINK = "studentT"

    def __init__(self, config: Config):
        super().__init__(config=config)

    def _model(self, subject, features, intensity, response_obs=None):
        """numpyro model function.

        Args:
            subject: per-observation subject codes, used directly as indices
                into parameter arrays. Assumes a 1-D integer array encoded
                0..n_subject-1 — the encoding is done outside this cell.
            features: sequence whose first element holds per-observation
                feature0 codes (same indexing assumption as ``subject``).
            intensity: per-observation stimulation intensity. It is reshaped
                to a column and tiled to ``(n_data, n_response)``.
            response_obs: observed responses to condition on, or ``None`` to
                sample them.

        Returns:
            The value of the ``site.obs`` sample site.
        """
        # Repeat the intensity column once per response so that every
        # per-observation quantity below has shape (n_data, n_response).
        intensity = np.tile(intensity.reshape(-1, 1), (1, self.n_response))
        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                # Hyper-priors: one draw per (subject, response).
                mu_a = numpyro.sample(
                    site.mu_a, dist.TruncatedNormal(150, 50, low=0)
                )
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(50))
                sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.5))
                sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
                sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
                sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(10))

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    # Priors: one draw per (feature0, subject, response).
                    a = numpyro.sample(
                        site.a, dist.TruncatedNormal(mu_a, sigma_a, low=0)
                    )
                    b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))
                    L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                    H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                    v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                    # Coefficients of the mean-dependent df and scale.
                    nu_1 = numpyro.sample("nu_1", dist.HalfCauchy(1))
                    nu_2 = numpyro.sample("nu_2", dist.HalfCauchy(1))
                    sigma_1 = numpyro.sample("sigma_1", dist.HalfCauchy(1))
                    sigma_2 = numpyro.sample("sigma_2", dist.HalfCauchy(1))

        # Gather each observation's parameters -> shape (n_data, n_response).
        obs_idx = (feature0, subject)
        a_obs, b_obs = a[obs_idx], b[obs_idx]
        L_obs, H_obs, v_obs = L[obs_idx], H[obs_idx], v[obs_idx]

        # Mean curve. The max(0, .) clamps it to L for intensities <= a.
        growth = jnp.power(1 + H_obs, v_obs) - 1
        denominator = jnp.power(
            1 + growth * jnp.exp(-b_obs * (intensity - a_obs)),
            1 / v_obs,
        )
        mu = numpyro.deterministic(
            site.mu,
            L_obs + jnp.maximum(0, -1 + (H_obs + 1) / denominator),
        )

        # Scale grows linearly with the mean; df shrinks towards
        # 2 + 1/nu_1 as the mean grows.
        sigma = numpyro.deterministic(
            "sigma", (1 / sigma_1[obs_idx]) + (1 / sigma_2[obs_idx]) * mu
        )
        nu = numpyro.deterministic(
            "nu", 2 + (1 / nu_1[obs_idx]) + (1 / nu_2[obs_idx]) * (1 / mu)
        )

        # Responses are treated as one event vector per observation.
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.StudentT(df=nu, loc=mu, scale=sigma).to_event(1),
                obs=response_obs,
            )


model = StudentT(config=config)

In [4]:
# class Laplace(Baseline):
#     LINK = "laplace"

#     def __init__(self, config: Config):
#         super(Laplace, self).__init__(config=config)

#     def _model(self, subject, features, intensity, response_obs=None):
#         intensity = intensity.reshape(-1, 1)
#         intensity = np.tile(intensity, (1, self.n_response))

#         feature0 = features[0].reshape(-1,)

#         n_data = intensity.shape[0]
#         n_subject = np.unique(subject).shape[0]
#         n_feature0 = np.unique(feature0).shape[0]

#         nu_1_scale = numpyro.sample("nu_1_scale", dist.HalfCauchy(5))
#         nu_2_scale = numpyro.sample("nu_2_scale", dist.HalfCauchy(5))
#         sigma_1_scale = numpyro.sample("sigma_1_scale", dist.HalfCauchy(5))
#         sigma_2_scale = numpyro.sample("sigma_2_scale", dist.HalfCauchy(5))

#         with numpyro.plate(site.n_response, self.n_response, dim=-1):
#             with numpyro.plate(site.n_subject, n_subject, dim=-2):
#                 """ Hyper-priors """
#                 mu_a = numpyro.sample(
#                     site.mu_a,
#                     dist.TruncatedNormal(150, 100, low=0)
#                 )
#                 sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(20))

#                 sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.5))

#                 sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
#                 sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
#                 sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(5))

#                 with numpyro.plate("n_feature0", n_feature0, dim=-3):
#                     """ Priors """
#                     a = numpyro.sample(
#                         site.a,
#                         dist.TruncatedNormal(mu_a, sigma_a, low=0)
#                     )
#                     b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

#                     L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
#                     H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
#                     v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

#                     nu_1 = numpyro.sample("nu_1", dist.HalfCauchy(nu_1_scale))
#                     nu_2 = numpyro.sample("nu_2", dist.HalfCauchy(nu_2_scale))

#                     sigma_1 = numpyro.sample("sigma_1", dist.HalfCauchy(sigma_1_scale))
#                     sigma_2 = numpyro.sample("sigma_2", dist.HalfCauchy(sigma_2_scale))

#         """ Model """
#         mu = numpyro.deterministic(
#             site.mu,
#             L[feature0, subject]
#             + jnp.maximum(
#                 0,
#                 -1
#                 + (H[feature0, subject] + 1)
#                 / jnp.power(
#                     1
#                     + (jnp.power(1 + H[feature0, subject], v[feature0, subject]) - 1)
#                     * jnp.exp(-b[feature0, subject] * (intensity - a[feature0, subject])),
#                     1 / v[feature0, subject]
#                 )
#             )
#         )
#         sigma = numpyro.deterministic(
#             "sigma", sigma_1[feature0, subject] + sigma_2[feature0, subject] * mu
#         )
#         nu = numpyro.deterministic(
#             "nu", 1 + nu_1[feature0, subject] + nu_2[feature0, subject] * mu
#         )

#         with numpyro.plate(site.data, n_data):
#             return numpyro.sample(
#                 site.obs,
#                 dist.AsymmetricLaplace(loc=mu - (1 - jnp.power(nu, 2)) / (nu * sigma), scale=sigma, asymmetry=nu).to_event(1),
#                 obs=response_obs
#             )


# model = Laplace(config=config)

In [5]:
# class RectifiedLogistic(Baseline):
#     LINK = RECTIFIED_LOGISTIC

#     def __init__(self, config: Config):
#         super(RectifiedLogistic, self).__init__(config=config)

#     def _model(self, subject, features, intensity, response_obs=None):
#         intensity = intensity.reshape(-1, 1)
#         intensity = np.tile(intensity, (1, self.n_response))

#         feature0 = features[0].reshape(-1,)

#         n_data = intensity.shape[0]
#         n_subject = np.unique(subject).shape[0]
#         n_feature0 = np.unique(feature0).shape[0]

#         with numpyro.plate(site.n_response, self.n_response, dim=-1):
#             with numpyro.plate(site.n_subject, n_subject, dim=-2):
#                 """ Hyper-priors """
#                 mu_a = numpyro.sample(
#                     site.mu_a,
#                     dist.TruncatedNormal(150, 50, low=0)
#                 )
#                 sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(50))

#                 sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.1))

#                 sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
#                 sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
#                 sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(10))

#                 with numpyro.plate("n_feature0", n_feature0, dim=-3):
#                     """ Priors """
#                     a = numpyro.sample(
#                         site.a,
#                         dist.TruncatedNormal(mu_a, sigma_a, low=0)
#                     )
#                     b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

#                     L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
#                     H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
#                     v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

#                     g_1 = numpyro.sample("g_1", dist.Exponential(1 / 100))
#                     g_2 = numpyro.sample("g_2", dist.Exponential(1 / 100))

#         """ Model """
#         mu = numpyro.deterministic(
#             site.mu,
#             L[feature0, subject]
#             + jnp.maximum(
#                 0,
#                 -1
#                 + (H[feature0, subject] + 1)
#                 / jnp.power(
#                     1
#                     + (jnp.power(1 + H[feature0, subject], v[feature0, subject]) - 1)
#                     * jnp.exp(-b[feature0, subject] * (intensity - a[feature0, subject])),
#                     1 / v[feature0, subject]
#                 )
#             )
#         )
#         beta = numpyro.deterministic(
#             site.beta,
#             g_1[feature0, subject] + g_2[feature0, subject] * (1 / jnp.log(mu + 1))
#         )

#         with numpyro.plate(site.data, n_data):
#             return numpyro.sample(
#                 site.obs,
#                 dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
#                 obs=response_obs
#             )


# model = RectifiedLogistic(config=config)

In [6]:
# Restrict the analysis to a single combination of the model's
# `combination_columns` values (presumably participant and segment).
sub = [("amap06", "-C8L")]

df = pd.read_csv(model.csv_path)
ind = df[model.combination_columns].apply(tuple, axis=1).isin(sub)
# Non-inplace reset_index returns a fresh frame, so no extra .copy() is needed.
df = df.loc[ind].reset_index(drop=True)
# Fail fast with a clear message rather than an obscure error during inference.
assert not df.empty, f"No rows in {model.csv_path} match {sub}"

df, encoder_dict = model.load(df=df)


2023-07-21 13:35:03,329 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/fix-3
2023-07-21 13:35:03,329 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/fix-3
2023-07-21 13:35:03,330 - hbmep.dataset.core - INFO - Processing data ...
2023-07-21 13:35:03,331 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [7]:
# Fit the model. Returns the sampler (used below for `print_summary`) and the
# posterior draws consumed by the rendering cells.
mcmc, posterior_samples = model.run_inference(
    df=df
)


2023-07-21 13:35:03,382 - hbmep.model.baseline - INFO - Running inference with studentT ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-07-21 13:35:49,449 - hbmep.utils.utils - INFO - func:run_inference took: 46.07 sec


In [8]:
# Posterior summary with 95% intervals. Check r_hat is close to 1 and n_eff
# is large for every site before interpreting the fit.
# NOTE(review): the run saved with this notebook shows r_hat 1.22 and
# n_eff ~21 for a[0,0,1], which suggests that site has not converged.
mcmc.print_summary(prob=0.95)



                    mean       std    median      2.5%     97.5%     n_eff     r_hat
      H[0,0,0]      4.22      0.13      4.22      3.97      4.47  11309.34      1.00
      H[0,0,1]      1.47      0.02      1.46      1.44      1.50  14290.18      1.00
      L[0,0,0]      0.04      0.01      0.04      0.02      0.07  13071.88      1.00
      L[0,0,1]      0.02      0.00      0.02      0.02      0.03  21529.32      1.00
      a[0,0,0]    160.69      3.07    160.69    155.21    166.46  13555.34      1.00
      a[0,0,1]    166.93      4.97    164.54    162.95    180.73     21.20      1.22
      b[0,0,0]      0.02      0.00      0.02      0.02      0.03   7390.69      1.00
      b[0,0,1]      0.10      0.05      0.08      0.04      0.18   1630.67      1.01
   nu_1[0,0,0]     16.96    473.63      2.21      0.00     26.59   8625.86      1.00
   nu_1[0,0,1]     61.75   2202.60      5.60      0.35     77.50  14921.11      1.00
   nu_2[0,0,0]    171.77  18199.88      5.70      0.01    112.96

In [9]:
# Render recruitment curves from the posterior draws; the log reports the
# PDF written under config.BUILD_DIR.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-21 13:35:49,626 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-07-21 13:35:52,786 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/fix-3/recruitment_curves.pdf
2023-07-21 13:35:52,786 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 3.16 sec


In [10]:
# Render the posterior predictive check; the log reports the PDF written
# under config.BUILD_DIR.
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-21 13:35:52,828 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-07-21 13:35:56,313 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/fix-3/posterior_predictive_check.pdf
2023-07-21 13:35:56,318 - hbmep.utils.utils - INFO - func:render_predictive_check took: 3.49 sec
