In [1]:
# Load (or reload) IPython's autoreload extension in mode 2, which re-imports
# modules before each cell executes, so edits to the imported project
# packages take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

import os
import logging
import multiprocessing
from pathlib import Path
from tqdm import tqdm

import arviz as az
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns

import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import scipy.stats as stats
import numpyro
from numpyro.diagnostics import hpdi

from hbmep.config import Config
# from hbmep_paper.simulator import HierarchicalBayesianModel
from hbmep.model.utils import Site as site

# Pin both JAX and NumPyro to the CPU backend.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose one host device per core so chains can run in parallel, leaving two
# cores free for the rest of the system. Clamp to at least one device: on a
# machine with one or two cores the raw subtraction would request a zero or
# negative device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)

# Use 64-bit floats and turn on distribution argument validation.
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)

In [None]:
import numpyro.distributions as dist
from hbmep.model import Baseline
from hbmep_paper.utils.constants import HBM


class HierarchicalBayesianModel(Baseline):
    """Hierarchical Bayesian recruitment-curve model registered under the ``HBM`` link.

    Every (subject, feature0, response) combination receives its own curve
    parameters ``a, b, v, L, H`` and noise parameters ``g_1, g_2``. Their
    scales are shared per (feature0, response) and, one level above, through
    the ``global_sigma_*`` sites that are common to all combinations.
    """

    LINK = HBM

    def __init__(self, config: Config):
        """Initialise the base model and record the grouping columns.

        Args:
            config: Configuration object, forwarded unchanged to ``Baseline``.
        """
        super().__init__(config=config)
        # Feature columns first, subject column last.
        self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model: sample the curve parameters and condition on the data.

        The expected response is

            mu = L + max(0, -1 + (H + 1) / (1 + ((1 + H)^v - 1) * exp(-b (x - a)))^(1 / v))

        which equals ``L`` at ``x == a`` and approaches ``L + H`` as ``x``
        grows (for ``b > 0``). Observations follow a Gamma distribution with
        concentration ``mu * beta`` and rate ``beta`` (hence mean ``mu``),
        where ``beta = g_1 + g_2 / mu``.

        Args:
            subject: Integer subject index per observation; indexes the
                leading axis of the subject-level parameters.
            features: Sequence of per-observation feature index arrays; only
                ``features[0]`` is used.
            intensity: Intensity value per observation; reshaped to a column
                and tiled across ``self.n_response`` columns.
            response_obs: Observed responses to condition on, or ``None`` to
                sample them. Presumably shaped (n_data, n_response) — confirm
                against the caller.

        Returns:
            The value of the ``site.obs`` sample site.
        """

        def half_normal_scaled(raw_site, scaled_site, scale):
            # Unit HalfNormal draw registered at `raw_site`, multiplied by
            # `scale` and recorded as a deterministic at `scaled_site`.
            unit = numpyro.sample(raw_site, dist.HalfNormal(scale=1))
            return numpyro.deterministic(scaled_site, scale * unit)

        def exponential_inv_scaled(raw_site, scaled_site, scale):
            # Unit Exponential draw registered at `raw_site`, multiplied by
            # 1 / `scale` and recorded as a deterministic at `scaled_site`.
            unit = numpyro.sample(raw_site, dist.Exponential(rate=1))
            return numpyro.deterministic(scaled_site, (1 / scale) * unit)

        # One intensity column per response so it lines up with `mu`.
        intensity = np.tile(intensity.reshape(-1, 1), (1, self.n_response))
        feature0 = features[0].reshape(-1,)

        # Sizes are taken from the data: distinct indices define plate sizes.
        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        # Top-level scales shared by every (feature0, response) combination.
        global_sigma_a = numpyro.sample("global_sigma_a", dist.HalfNormal(50))

        global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
        global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))

        global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))
        global_sigma_H = numpyro.sample("global_sigma_H", dist.HalfNormal(5))

        global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(10))
        global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(10))

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                # Hyper-priors: one value per (feature0, response).
                mu_a = numpyro.sample(site.mu_a, dist.TruncatedNormal(20, 50, low=0))

                sigma_a = half_normal_scaled("sigma_a_raw", site.sigma_a, global_sigma_a)
                sigma_b = half_normal_scaled("sigma_b_raw", site.sigma_b, global_sigma_b)
                sigma_v = half_normal_scaled("sigma_v_raw", site.sigma_v, global_sigma_v)
                sigma_L = half_normal_scaled("sigma_L_raw", site.sigma_L, global_sigma_L)
                sigma_H = half_normal_scaled("sigma_H_raw", site.sigma_H, global_sigma_H)

                sigma_g_1 = exponential_inv_scaled("sigma_g_1_raw", "sigma_g_1", global_sigma_g_1)
                sigma_g_2 = exponential_inv_scaled("sigma_g_2_raw", "sigma_g_2", global_sigma_g_2)

                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    # Priors: one value per (subject, feature0, response).
                    a = numpyro.sample(site.a, dist.TruncatedNormal(mu_a, sigma_a, low=0))

                    b = half_normal_scaled("b_raw", site.b, sigma_b)
                    v = half_normal_scaled("v_raw", site.v, sigma_v)
                    L = half_normal_scaled("L_raw", site.L, sigma_L)
                    H = half_normal_scaled("H_raw", site.H, sigma_H)

                    g_1 = exponential_inv_scaled("g_1_raw", site.g_1, sigma_g_1)
                    g_2 = exponential_inv_scaled("g_2_raw", site.g_2, sigma_g_2)

        # Model: pick each observation's parameters by (subject, feature0);
        # every gathered array has one row per observation and one column
        # per response.
        a_obs = a[subject, feature0]
        b_obs = b[subject, feature0]
        v_obs = v[subject, feature0]
        L_obs = L[subject, feature0]
        H_obs = H[subject, feature0]
        g_1_obs = g_1[subject, feature0]
        g_2_obs = g_2[subject, feature0]

        growth = jnp.power(1 + H_obs, v_obs) - 1
        decay = jnp.exp(-b_obs * (intensity - a_obs))
        denominator = jnp.power(1 + growth * decay, 1 / v_obs)

        mu = numpyro.deterministic(
            site.mu,
            L_obs + jnp.maximum(0, -1 + (H_obs + 1) / denominator)
        )
        beta = numpyro.deterministic(
            site.beta,
            g_1_obs + g_2_obs * (1 / mu)
        )

        # Observation: the response axis is folded into the event shape so
        # the data plate spans observations only.
        with numpyro.plate(site.data, n_data):
            observed = numpyro.sample(
                site.obs,
                dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
                obs=response_obs
            )
        return observed

In [3]:
# The repository root is three directory levels above the notebook's working
# directory; the model configuration is resolved relative to it.
root_path = Path.cwd().parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/paper/tms/mixed-effects.toml")

# Build the configuration from the TOML file and instantiate the model
# defined in the previous cell.
config = Config(toml_path=toml_path)

model = HierarchicalBayesianModel(config=config)


2023-10-02 12:02:50,936 - hbmep.config - INFO - Verifying configuration ...
2023-10-02 12:02:50,936 - hbmep.config - INFO - Success!
2023-10-02 12:02:50,950 - hbmep.model.baseline - INFO - Initialized model with hierarchical_bayesian link


In [4]:
# Location of the processed TMS dataset. The original absolute path remains
# the default; set the HBMEP_TMS_DATA environment variable to run the
# notebook against a copy stored elsewhere.
src = os.environ.get(
    "HBMEP_TMS_DATA",
    "/home/vishu/data/hbmep-processed/human/tms/data.csv"
)
df = pd.read_csv(src)

# Optional: restrict the fit to a subset of subjects.
# subset = ["SCA01"]
# ind = df[model.subject].isin(subset)
# df = df[ind].reset_index(drop=True).copy()

# Overwrite the first feature column with a constant so every row falls into
# a single feature level.
df[model.features[0]] = 0

# Rescale the response column(s) by a factor of 1000
# (presumably a unit conversion — TODO confirm the source units).
df[model.response] = df[model.response] * 1000

# Hand the frame to the model; it returns the processed frame together with
# an encoder mapping that the rendering calls below take as `encoder_dict`.
df, encoder_dict = model.load(df=df)


2023-10-02 12:02:51,066 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/human/tms/fit/mixed-effects
2023-10-02 12:02:51,067 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/human/tms/fit/mixed-effects
2023-10-02 12:02:51,067 - hbmep.dataset.core - INFO - Processing data ...
2023-10-02 12:02:51,069 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# Fit the model to the prepared frame; yields the MCMC object (summarised in
# the next cell) and the posterior samples used for prediction and plots.
# NOTE(review): no random seed is set in this notebook — presumably handled
# inside run_inference or the config; confirm for reproducibility.
mcmc, posterior_samples = model.run_inference(df=df)


2023-10-02 12:02:51,341 - hbmep.model.baseline - INFO - Running inference with hierarchical_bayesian ...


  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

2023-10-02 12:03:22,148 - hbmep.utils.utils - INFO - func:run_inference took: 30.81 sec


In [6]:
# Print per-site posterior summaries with a 95% interval (the 2.5% / 97.5%
# columns) alongside the n_eff and r_hat convergence diagnostics.
mcmc.print_summary(prob=.95)


                                 mean       std    median      2.5%     97.5%     n_eff     r_hat
          H_baseline[0,0,0]   1126.20     93.40   1121.66    945.48   1306.01    627.00      1.00
          H_baseline[1,0,0]   1358.10   1228.92   1052.25      0.00   3678.53    665.77      1.01
          H_baseline[2,0,0]   1208.59     87.39   1205.76   1043.81   1387.03    988.68      1.00
          H_baseline[3,0,0]    570.45    100.82    556.69    385.41    772.62    678.44      1.00
          H_baseline[4,0,0]   3085.92    314.38   3060.13   2541.06   3745.68    423.65      1.01
          H_baseline[5,0,0]   4564.03    559.74   4526.56   3478.27   5689.11    882.79      1.01
          H_baseline[6,0,0]    435.20    113.57    411.14    269.05    654.40    747.46      1.00
          H_baseline[7,0,0]    424.92    207.69    378.03    213.65    771.11    209.46      1.02
          L_baseline[0,0,0]      6.84      0.48      6.84      5.80      7.73    887.11      1.00
          L_baselin

In [7]:
# Build the prediction inputs from the fitted data and draw posterior
# predictive samples for them.
prediction_df = model.make_prediction_dataset(df=df)

posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples
)

# Render the fitted recruitment curves.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive
)

# Render the posterior predictive check from the same predictions.
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive
)


2023-10-02 12:03:27,873 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.00 sec


2023-10-02 12:03:29,464 - hbmep.utils.utils - INFO - func:predict took: 1.59 sec
2023-10-02 12:03:29,470 - hbmep.model.baseline - INFO - Rendering ...
2023-10-02 12:03:30,897 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/human/tms/fit/mixed-effects/recruitment_curves.pdf
2023-10-02 12:03:30,898 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 1.43 sec
2023-10-02 12:03:30,906 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-10-02 12:03:32,432 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/human/tms/fit/mixed-effects/posterior_predictive_check.pdf
2023-10-02 12:03:32,433 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 1.53 sec
2023-10-02 12:03:32,433 - hbmep.utils.utils - INFO - func:render_predictive_check took: 1.54 sec


In [8]:
import pickle

# Persist the fitted model, the MCMC object and the posterior samples as a
# single pickled tuple inside the model's build directory.
# NOTE(review): HierarchicalBayesianModel is defined in this notebook, and
# pickle stores classes by reference, so the class must be defined again
# under the same name before this file can be loaded in a fresh kernel.
dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, "wb") as f:
    pickle.dump((model, mcmc, posterior_samples), f)

In [9]:
import pickle

# Round-trip check: read the tuple back from `dest`, which is set by the
# previous cell. pickle.load can execute arbitrary code, so only use it on
# files written by this notebook.
with open(dest, "rb") as g:
    model_, mcmc_, posterior_samples_ = pickle.load(g)
