In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

cpu_count = multiprocessing.cpu_count() - 2
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [3]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class MixtureModel(Baseline):
    LINK = "mixture_model"

    def __init__(self, config: Config):
        super(MixtureModel, self).__init__(config=config)
        # self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
            global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))

            global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))
            global_sigma_l = numpyro.sample("global_sigma_l", dist.HalfNormal(100))
            global_sigma_H = numpyro.sample("global_sigma_H", dist.HalfNormal(5))

            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                """ Hyper-priors """
                mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=10))
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(scale=10 / 150))

                sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                sigma_b = numpyro.deterministic(site.sigma_b, global_sigma_b * sigma_b_raw)

                sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
                sigma_v = numpyro.deterministic(site.sigma_v, global_sigma_v * sigma_v_raw)

                sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                sigma_L = numpyro.deterministic(site.sigma_L, global_sigma_L * sigma_L_raw)

                sigma_l_raw = numpyro.sample("sigma_l_raw", dist.HalfNormal(scale=1))
                sigma_l = numpyro.deterministic("sigma_l", global_sigma_l * sigma_l_raw)

                sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                sigma_H = numpyro.deterministic(site.sigma_H, global_sigma_H * sigma_H_raw)

                sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                sigma_g_1 = numpyro.deterministic("sigma_g_1", global_sigma_g_1 * sigma_g_1_raw)

                sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                sigma_g_2 = numpyro.deterministic("sigma_g_2", global_sigma_g_2 * sigma_g_2_raw)

                sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
                sigma_p = numpyro.deterministic("sigma_p", global_sigma_p * sigma_p_raw)

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    """ Priors """
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=mu_a, rate=1))
                    a = numpyro.deterministic(site.a, (1 / sigma_a) * a_raw)

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, sigma_b * b_raw)

                    v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    v = numpyro.deterministic(site.v, sigma_v * v_raw)

                    L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    L = numpyro.deterministic(site.L, sigma_L * L_raw)

                    l_raw = numpyro.sample("l_raw", dist.HalfNormal(scale=1))
                    l = numpyro.deterministic("l", sigma_l * l_raw)

                    H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                    H = numpyro.deterministic(site.H, sigma_H * H_raw)

                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, sigma_g_1 * g_1_raw)

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, sigma_g_2 * g_2_raw)

                    p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                    p = numpyro.deterministic("p", sigma_p * p_raw)

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.data, n_data, dim=-2):
                """ Model """
                mu = numpyro.deterministic(
                    site.mu,
                    L[feature0, subject]
                    + jnp.where(
                        intensity <= a[feature0, subject],
                        0,
                        -l[feature0, subject]
                        + (
                            (H[feature0, subject] + l[feature0, subject])
                            / jnp.power(
                                1
                                + (
                                    (
                                        -1
                                        + jnp.power(
                                            (H[feature0, subject] + l[feature0, subject]) / l[feature0, subject],
                                            v[feature0, subject]
                                        )
                                    )
                                    * jnp.exp(-b[feature0, subject] * (intensity - a[feature0, subject]))
                                ),
                                1 / v[feature0, subject]
                            )
                        )
                    )
                )
                beta = numpyro.deterministic(
                    site.beta,
                    g_1[feature0, subject] + g_2[feature0, subject] * jnp.power(1 / (mu + 1), p[feature0, subject])
                )

                """ Observation """
                numpyro.sample(
                    site.obs,
                    dist.Gamma(concentration=mu * beta, rate=beta),
                    obs=response_obs
                )


In [4]:
toml_path = "/home/vishu/repos/hbmep-paper/configs/paper/rats/J_RCML_000/model-comparison.toml"

config = Config(toml_path=toml_path)
config.BUILD_DIR = os.path.join(config.BUILD_DIR, "mixture_model")

model = RectifiedLogistic(config=config)


2023-10-25 12:20:01,636 - hbmep.config - INFO - Verifying configuration ...
2023-10-25 12:20:01,636 - hbmep.config - INFO - Success!
2023-10-25 12:20:01,653 - hbmep.model.baseline - INFO - Initialized model with rectified_logistic link


In [5]:
src = "/home/vishu/data/hbmep-processed/J_RCML_000/data.csv"
df = pd.read_csv(src)

# model.mep_matrix_path = None
# subset = ["amap01", "amap02", "amap03", "amap04"]
# subset = ["amap01", "ampa0"]

# ind = df[model.subject].isin(subset)
# df = df[ind].reset_index(drop=True).copy()

df, encoder_dict = model.load(df=df)


2023-10-25 12:20:01,772 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic
2023-10-25 12:20:01,772 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic
2023-10-25 12:20:01,773 - hbmep.dataset.core - INFO - Processing data ...
2023-10-25 12:20:01,775 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [6]:
mcmc, posterior_samples = model.run_inference(df=df)


2023-10-25 12:20:01,827 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

2023-10-25 15:15:49,124 - hbmep.utils.utils - INFO - func:run_inference took: 2 hr and 55 min


In [7]:
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(df=prediction_df, posterior_samples=posterior_samples)

model.render_recruitment_curves(df=df, encoder_dict=encoder_dict, posterior_samples=posterior_samples, prediction_df=prediction_df, posterior_predictive=posterior_predictive)
model.render_predictive_check(df=df, encoder_dict=encoder_dict, prediction_df=prediction_df, posterior_predictive=posterior_predictive)


2023-10-25 15:15:49,219 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec


2023-10-25 15:16:12,134 - hbmep.utils.utils - INFO - func:predict took: 22.91 sec
2023-10-25 15:16:12,430 - hbmep.model.baseline - INFO - Rendering ...
2023-10-25 15:16:48,301 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/recruitment_curves.pdf
2023-10-25 15:16:48,306 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 36.17 sec
2023-10-25 15:16:48,802 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-10-25 15:17:29,483 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/posterior_predictive_check.pdf
2023-10-25 15:17:29,489 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 41.18 sec
2023-10-25 15:17:29,490 - hbmep.utils.utils - INFO - func:render_predictive_check took: 41.18 sec


In [10]:
mcmc.print_summary(prob=.95)
    


                         mean       std    median      2.5%     97.5%     n_eff     r_hat
       H_raw[0,0,0]      1.72      0.53      1.69      0.82      2.82    231.95      1.01
       H_raw[0,1,0]      0.95      0.30      0.92      0.42      1.51    142.62      1.02
       H_raw[0,2,0]      1.26      0.35      1.26      0.59      1.90     42.56      1.05
       H_raw[0,3,0]      0.86      0.32      0.80      0.31      1.50    123.11      1.04
       H_raw[0,4,0]      1.15      0.22      1.15      0.74      1.58     27.84      1.05
       H_raw[0,5,0]      1.73      0.51      1.67      0.83      2.80    206.82      1.03
       H_raw[0,6,0]      1.14      0.39      1.07      0.48      1.88    124.11      1.04
       H_raw[0,7,0]      1.29      0.45      1.21      0.56      2.23    176.53      1.03
       H_raw[1,0,0]      1.34      0.52      1.27      0.50      2.37    228.35      1.01
       H_raw[1,1,0]      1.20      0.58      1.10      0.27      2.34    304.66      1.01
       H_

In [9]:
numpyro_data = az.from_numpyro(mcmc)

""" Model evaluation """
logger.info("Evaluating model ...")

score = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")

score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")


2023-10-25 15:17:33,019 - __main__ - INFO - Evaluating model ...
2023-10-25 15:17:39,651 - __main__ - INFO - ELPD LOO (Log): 6366.52
See http://arxiv.org/abs/1507.04544 for details
2023-10-25 15:17:40,070 - __main__ - INFO - ELPD WAIC (Log): 6433.86


In [11]:
import pickle

dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, "wb") as f:
    pickle.dump((model, mcmc, posterior_samples), f)


In [12]:
dest = os.path.join(model.build_dir, "numpyro_data.nc")
az.to_netcdf(numpyro_data, dest)


'/home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/numpyro_data.nc'

In [None]:
# import pickle

# with open(dest, "rb") as g:
#     model_, mcmc_, posterior_samples_ = pickle.load(g)


In [16]:
len(posterior_samples.keys())

46