In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run everything on CPU; both JAX and NumPyro are pointed at the same platform.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several host "devices" so MCMC chains can run in parallel. Two cores are left
# free for the rest of the machine, but never request fewer than one device:
# cpu_count() - 2 is zero or negative on 1–2 core machines / CI runners.
# NOTE: set_host_device_count only takes effect if called before JAX initialises its backend.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class MixedEffects(Baseline):
    """Hierarchical single-feature model of a thresholded, saturating response curve.

    Every curve parameter is sampled per (subject, response) with a non-centred
    parameterisation (global scale -> group scale -> subject value). The mean
    response ``mu`` equals ``L`` up to the threshold ``a``. Above ``a`` it rises
    continuously from ``L`` towards ``L + H``. Observations are Gamma distributed
    with mean ``mu``.
    """

    LINK = "mixed_effects"

    # Curve parameters that share the non-centred hierarchy
    #   global_sigma_<p>_baseline -> sigma_<p>_baseline -> <p>_baseline (per subject),
    # listed in the order their sample sites are created, together with the
    # HalfNormal scale of each response-level global sigma.
    _HIERARCHICAL_PARAMS = (
        ("b", 100),
        ("v", 100),
        ("L", 1),
        ("l", 100),
        ("H", 5),
        ("g_1", 100),
        ("g_2", 100),
        ("p", 100),
    )
    # Parameters whose subject-level raw prior is heavy-tailed (HalfCauchy)
    # instead of HalfNormal.
    _HALF_CAUCHY_PARAMS = frozenset({"g_1", "g_2"})

    def __init__(self, config: Config):
        super().__init__(config=config)
        self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model for the mixed-effects recruitment curve.

        Args:
            subject: ``(indices, n_subject)``. Per-observation subject index and the
                number of subjects.
            features: ``(features, n_features)``. Only ``features[0]`` is used, as the
                per-observation feature index. ``n_features`` is ignored because
                ``n_feature0`` is fixed to 1.
            intensity: ``(intensity, n_data)``. Per-observation intensity and the
                number of observations.
            response_obs: Observed responses, presumably shaped ``(n_data, n_response)``
                (TODO confirm). Pass ``None`` to sample the ``obs`` site instead of
                conditioning on it.
        """
        subject, n_subject = subject
        features, _ = features
        intensity, n_data = intensity

        # One intensity column per response. Use jnp (not np) so this also works
        # when the model is traced under jit/vmap.
        intensity = jnp.tile(intensity.reshape(-1, 1), (1, self.n_response))

        # NOTE(review): JAX clamps out-of-range gather indices instead of raising,
        # so feature0 must only contain 0 while n_feature0 == 1.
        feature0 = features[0].reshape(-1,)

        n_baseline = 1
        n_feature0 = 1

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # Response-level global scales (dict keeps the original sampling order).
            global_sigma = {
                name: numpyro.sample(f"global_sigma_{name}_baseline", dist.HalfNormal(scale))
                for name, scale in self._HIERARCHICAL_PARAMS
            }

            with numpyro.plate("n_baseline", n_baseline, dim=-2):
                # Hyper-priors
                mu_a_baseline = numpyro.sample("mu_a_baseline", dist.HalfNormal(scale=5))
                sigma_a_baseline = numpyro.sample("sigma_a_baseline", dist.HalfNormal(scale=1 / 10))

                sigma = {}
                for name, _ in self._HIERARCHICAL_PARAMS:
                    sigma_raw = numpyro.sample(f"sigma_{name}_raw_baseline", dist.HalfNormal(scale=1))
                    sigma[name] = numpyro.deterministic(
                        f"sigma_{name}_baseline", global_sigma[name] * sigma_raw
                    )

                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    # Subject-level priors. The threshold `a` uses a scaled Gamma
                    # rather than the shared half-distribution hierarchy.
                    a_raw_baseline = numpyro.sample(
                        "a_raw_baseline", dist.Gamma(concentration=mu_a_baseline, rate=1)
                    )
                    a_baseline = numpyro.deterministic(
                        "a_baseline", (1 / sigma_a_baseline) * a_raw_baseline
                    )

                    baseline = {}
                    for name, _ in self._HIERARCHICAL_PARAMS:
                        raw_prior = (
                            dist.HalfCauchy(scale=1)
                            if name in self._HALF_CAUCHY_PARAMS
                            else dist.HalfNormal(scale=1)
                        )
                        raw = numpyro.sample(f"{name}_raw_baseline", raw_prior)
                        baseline[name] = numpyro.deterministic(
                            f"{name}_baseline", sigma[name] * raw
                        )

        # TODO: per-feature offsets on `a` (a_delta, with a penalty keeping
        # a_baseline + a_delta non-negative) are not implemented yet; they need n_feature0 > 1.

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    # With a single feature level the curve parameters are the baseline values.
                    a = numpyro.deterministic(site.a, a_baseline)

                    b = numpyro.deterministic(site.b, baseline["b"])
                    v = numpyro.deterministic(site.v, baseline["v"])

                    L = numpyro.deterministic(site.L, baseline["L"])
                    l = numpyro.deterministic("l", baseline["l"])

                    H = numpyro.deterministic(site.H, baseline["H"])

                    g_1 = numpyro.deterministic(site.g_1, baseline["g_1"])
                    g_2 = numpyro.deterministic(site.g_2, baseline["g_2"])

                    p = numpyro.deterministic("p", baseline["p"])

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.data, n_data, dim=-2):
                # Gather each parameter once per observation
                # (assumes subject / feature0 hold one index per observation).
                idx = (subject, feature0)
                a_obs, b_obs, v_obs = a[idx], b[idx], v[idx]
                L_obs, l_obs, H_obs = L[idx], l[idx], H[idx]
                g_1_obs, g_2_obs, p_obs = g_1[idx], g_2[idx], p[idx]

                # Model: mu == L for intensity <= a. Above a, the second term starts at
                # 0 (continuous at the threshold) and tends to H as intensity grows.
                denominator = jnp.power(
                    1
                    + (jnp.power((H_obs + l_obs) / l_obs, v_obs) - 1)
                    * jnp.exp(-b_obs * (intensity - a_obs)),
                    1 / v_obs,
                )
                mu = numpyro.deterministic(
                    site.mu,
                    L_obs
                    + jnp.where(
                        intensity <= a_obs,
                        0,
                        -l_obs + (H_obs + l_obs) / denominator,
                    ),
                )
                beta = numpyro.deterministic(
                    site.beta,
                    g_1_obs + g_2_obs * jnp.power(1 / (mu + 1), p_obs),
                )

                # Observation: Gamma with mean mu and variance mu / beta.
                numpyro.sample(
                    site.obs,
                    dist.Gamma(concentration=mu * beta, rate=beta),
                    obs=response_obs
                )


In [3]:
# Repository root; override with the HBMEP_PAPER_DIR environment variable so the
# notebook is not tied to one machine. The default reproduces the original path.
HBMEP_PAPER_DIR = os.environ.get("HBMEP_PAPER_DIR", "/home/vishu/repos/hbmep-paper")
toml_path = os.path.join(HBMEP_PAPER_DIR, "configs/paper/tms/mixed-effects/mixed_effects.toml")
config = Config(toml_path=toml_path)

model = MixedEffects(config=config)


2023-10-27 16:18:33,330 - hbmep.config - INFO - Verifying configuration ...
2023-10-27 16:18:33,330 - hbmep.config - INFO - Success!
2023-10-27 16:18:33,346 - hbmep.model.baseline - INFO - Initialized model with mixed_effects link


In [4]:
# Processed-data root; override with the HBMEP_DATA_DIR environment variable.
# The default reproduces the original absolute path.
HBMEP_DATA_DIR = os.environ.get("HBMEP_DATA_DIR", "/home/vishu/data/hbmep-processed")
src = os.path.join(HBMEP_DATA_DIR, "human/tms/data_pkpk_auc.csv")
df = pd.read_csv(src)

# Optionally restrict the fit to some subjects, e.g. SUBJECT_SUBSET = ["SCA01"].
# None keeps every subject.
SUBJECT_SUBSET = None
if SUBJECT_SUBSET is not None:
    ind = df[model.subject].isin(SUBJECT_SUBSET)
    df = df[ind].reset_index(drop=True).copy()

# The model uses a single feature level (n_feature0 = 1), so every row is put in level 0.
df[model.features[0]] = 0
df, encoder_dict = model.load(df=df)


2023-10-27 16:18:33,431 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/out/hbmep-paper/paper/tms/mixed-effects
2023-10-27 16:18:33,432 - hbmep.dataset.core - INFO - Copied config to /home/vishu/out/hbmep-paper/paper/tms/mixed-effects
2023-10-27 16:18:33,433 - hbmep.dataset.core - INFO - Processing data ...
2023-10-27 16:18:33,434 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# MCMC fit: the expensive step (just under 9 minutes on CPU in the recorded run).
(mcmc, posterior_samples) = model.run_inference(df=df)


2023-10-27 16:18:33,530 - hbmep.model.baseline - INFO - Running inference with mixed_effects ...


  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

2023-10-27 16:27:30,211 - hbmep.utils.utils - INFO - func:run_inference took: 8 min and 56.68 sec


In [6]:
# Posterior predictive on a prediction grid built from the observed data.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

# Both renderers share these inputs; the recruitment-curve plot also needs the posterior draws.
render_kwargs = dict(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_recruitment_curves(posterior_samples=posterior_samples, **render_kwargs)
model.render_predictive_check(**render_kwargs)


2023-10-27 16:27:30,349 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec
2023-10-27 16:27:32,087 - hbmep.utils.utils - INFO - func:predict took: 1.74 sec
2023-10-27 16:27:32,093 - hbmep.model.baseline - INFO - Rendering ...
2023-10-27 16:27:33,599 - hbmep.model.baseline - INFO - Saved to /home/vishu/out/hbmep-paper/paper/tms/mixed-effects/recruitment_curves.pdf
2023-10-27 16:27:33,600 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 1.51 sec
2023-10-27 16:27:33,609 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-10-27 16:27:35,228 - hbmep.model.baseline - INFO - Saved to /home/vishu/out/hbmep-paper/paper/tms/mixed-effects/posterior_predictive_check.pdf
2023-10-27 16:27:35,229 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 1.63 sec
2023-10-27 16:27:35,229 - hbmep.utils.utils - INFO - func:render_predictive_check took: 1.63 sec


In [7]:
# Posterior summary with 95% intervals; check n_eff and r_hat for convergence.
mcmc.print_summary(prob=0.95)


                                  mean       std    median      2.5%     97.5%     n_eff     r_hat
       H_raw_baseline[0,0,0]      0.31      0.12      0.29      0.11      0.58     45.55      1.08
       H_raw_baseline[1,0,0]      0.85      0.60      0.73      0.00      2.04    220.02      1.01
       H_raw_baseline[2,0,0]      0.34      0.14      0.32      0.12      0.63     45.96      1.07
       H_raw_baseline[3,0,0]      0.86      0.59      0.73      0.07      2.07    107.25      1.04
       H_raw_baseline[4,0,0]      1.12      0.49      1.03      0.32      2.15     57.44      1.04
       H_raw_baseline[5,0,0]      1.39      0.52      1.31      0.56      2.39     52.27      1.07
       H_raw_baseline[6,0,0]      0.75      0.53      0.63      0.05      1.76    119.64      1.02
       H_raw_baseline[7,0,0]      0.90      0.62      0.78      0.04      2.14    116.79      1.03
       L_raw_baseline[0,0,0]      0.48      0.17      0.46      0.19      0.83     14.45      1.21
       L_

In [8]:
numpyro_data = az.from_numpyro(mcmc)

# Model evaluation: expected log pointwise predictive density under LOO and WAIC.
logger.info("Evaluating model ...")

for criterion, elpd_attr, label in (
    (az.loo, "elpd_loo", "LOO"),
    (az.waic, "elpd_waic", "WAIC"),
):
    score = criterion(numpyro_data)
    logger.info(f"ELPD {label} (Log): {getattr(score, elpd_attr):.2f}")

2023-10-27 16:27:36,416 - __main__ - INFO - Evaluating model ...
2023-10-27 16:27:37,082 - __main__ - INFO - ELPD LOO (Log): 672.76
See http://arxiv.org/abs/1507.04544 for details
2023-10-27 16:27:37,111 - __main__ - INFO - ELPD WAIC (Log): 674.68


In [11]:
import pickle

# Persist the fitted model, sampler and draws together.
# Only reload this pickle from trusted locations: unpickling can execute code.
dest = os.path.join(model.build_dir, "inference.pkl")
artefacts = (model, mcmc, posterior_samples)
with open(dest, "wb") as fh:
    pickle.dump(artefacts, fh)
dest

'/home/vishu/out/hbmep-paper/paper/tms/mixed-effects/inference.pkl'

In [12]:
# Save the ArviZ InferenceData so diagnostics can be reloaded without re-running MCMC.
dest = os.path.join(
    model.build_dir,
    "numpyro_data.nc",
)
az.to_netcdf(numpyro_data, dest)


'/home/vishu/out/hbmep-paper/paper/tms/mixed-effects/numpyro_data.nc'