In [None]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run JAX and numpyro on the CPU backend.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several virtual host devices so MCMC chains can run in parallel,
# leaving two cores free. Clamp to at least one device: on machines with two
# or fewer cores `cpu_count() - 2` would otherwise be 0 or negative.
# numpyro documents that this only takes effect if called before JAX
# initialises its backend, so keep this cell at the top of the notebook.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)

# 64-bit floats in JAX, and argument validation for numpyro distributions.
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [None]:
import numpyro.distributions as dist
from hbmep.model import BaseModel


class MixtureModel(BaseModel):
    """Hierarchical recruitment-curve model with an outlier-mixture likelihood.

    Every (subject, feature0, response) combination gets its own curve
    parameters for ``fn``. The threshold ``a`` is a TruncatedNormal around
    feature0-level hyper-priors. All other parameters are non-centred: a
    ``*_raw`` HalfNormal/HalfCauchy draw is multiplied by a feature0-level
    scale, and that scale is itself a ``*_raw`` draw times a response-level
    global scale.

    When observations are given, the likelihood is a two-component mixture: a
    Gamma centred on the curve plus a HalfNormal outlier component. Without
    observations, only the Gamma component is sampled.
    """
    NAME = "mixture_model"

    def __init__(self, config: Config):
        super(MixtureModel, self).__init__(config=config)
        # Subject column followed by the configured feature columns; presumably
        # used by BaseModel to group rows into individual curves (verify).
        self.combination_columns = [self.subject] + self.features

    def fn(self, x, a, b, v, L, ell, H):
        """Rectified generalized-logistic recruitment curve (elementwise).

        Returns ``L`` where ``x < a``, and otherwise::

            L - ell + (H + ell)
                / (1 + (((H + ell) / ell) ** v - 1) * exp(-b * (x - a))) ** (1 / v)

        The second expression equals ``L`` at ``x == a`` and tends to
        ``L + H`` as ``x`` grows (for ``b > 0``). All arguments broadcast
        against each other.

        Args:
            x: Intensity.
            a: Threshold where the curve leaves ``L``.
            b: Rate multiplying ``(x - a)`` in the exponent.
            v: Shape exponent of the generalized logistic.
            L: Value below threshold (offset).
            ell: Must be > 0; it divides ``H + ell``.
            H: Height of the plateau above ``L``.
        """
        below = jnp.less(x, a)
        # Evaluate the above-threshold branch at x = a wherever x < a. Those
        # entries are discarded by the final `where` anyway. Evaluating them at
        # the raw x lets exp(b * (a - x)) overflow to inf, and the masked branch
        # then still injects 0 * inf = NaN into gradients (a known
        # `jnp.where` pitfall). Forward values are unchanged.
        x_safe = jnp.where(below, a, x)

        ratio = jnp.true_divide(H + ell, ell)
        denominator = jnp.power(
            1 + jnp.multiply(
                -1 + jnp.power(ratio, v),
                jnp.exp(jnp.multiply(-b, x_safe - a))
            ),
            jnp.true_divide(1, v)
        )
        above = -ell + jnp.true_divide(H + ell, denominator)
        return L + jnp.where(below, 0., above)

    def _model(self, subject, features, intensity, response_obs=None):
        """numpyro model: hierarchical priors plus the observation likelihood.

        Args:
            subject: ``(subject_idx, n_subject)`` pair. ``subject_idx`` indexes
                the leading axis of the subject-level parameters for each row.
            features: ``(feature_arrays, n_features)`` pair. Only the first
                feature (``features[0]``, ``n_features[0]``) is used.
            intensity: ``(intensity_values, n_data)`` pair. The values are
                reshaped to a column and tiled across ``self.n_response``.
            response_obs: Observations for the ``(n_data, self.n_response)``
                observation plates, or ``None`` to sample from the Gamma-only
                predictive.

        Note:
            The order of sample sites is deliberate. Under numpyro's ``seed``
            handler each sample site consumes its own split key in execution
            order, so reordering sites changes the draws for a given seed.
        """
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        # -> (N, 1) -> (N, n_response): every response column sees the same
        # intensities.
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        # Nested plates take dims right-to-left: response -> -1,
        # feature0 -> -2, subject -> -3. Subject-level sites therefore have
        # shape (n_subject, n_feature0, self.n_response).
        with numpyro.plate(site.n_response, self.n_response):
            # Global priors: one scale per response.
            b_scale_global_scale = numpyro.sample("b_scale_global_scale", dist.HalfNormal(100))
            v_scale_global_scale = numpyro.sample("v_scale_global_scale", dist.HalfNormal(100))

            L_scale_global_scale = numpyro.sample("L_scale_global_scale", dist.HalfNormal(1))
            ell_scale_global_scale = numpyro.sample("ell_scale_global_scale", dist.HalfNormal(100))
            H_scale_global_scale = numpyro.sample("H_scale_global_scale", dist.HalfNormal(5))

            g_1_scale_global_scale = numpyro.sample("g_1_scale_global_scale", dist.HalfNormal(100))
            g_2_scale_global_scale = numpyro.sample("g_2_scale_global_scale", dist.HalfNormal(100))

            with numpyro.plate("n_feature0", n_feature0):
                # Hyper-priors: per (feature0, response); scales are non-centred.
                a_mean = numpyro.sample("a_mean", dist.TruncatedNormal(150., 100., low=0))
                a_scale = numpyro.sample("a_scale", dist.HalfNormal(100.))

                b_scale_raw = numpyro.sample("b_scale_raw", dist.HalfNormal(scale=1))
                b_scale = numpyro.deterministic("b_scale", jnp.multiply(b_scale_global_scale, b_scale_raw))

                v_scale_raw = numpyro.sample("v_scale_raw", dist.HalfNormal(scale=1))
                v_scale = numpyro.deterministic("v_scale", jnp.multiply(v_scale_global_scale, v_scale_raw))

                L_scale_raw = numpyro.sample("L_scale_raw", dist.HalfNormal(scale=1))
                L_scale = numpyro.deterministic("L_scale", jnp.multiply(L_scale_global_scale, L_scale_raw))

                ell_scale_raw = numpyro.sample("ell_scale_raw", dist.HalfNormal(scale=1))
                ell_scale = numpyro.deterministic("ell_scale", jnp.multiply(ell_scale_global_scale, ell_scale_raw))

                H_scale_raw = numpyro.sample("H_scale_raw", dist.HalfNormal(scale=1))
                H_scale = numpyro.deterministic("H_scale", jnp.multiply(H_scale_global_scale, H_scale_raw))

                g_1_scale_raw = numpyro.sample("g_1_scale_raw", dist.HalfNormal(scale=1))
                g_1_scale = numpyro.deterministic("g_1_scale", jnp.multiply(g_1_scale_global_scale, g_1_scale_raw))

                g_2_scale_raw = numpyro.sample("g_2_scale_raw", dist.HalfNormal(scale=1))
                g_2_scale = numpyro.deterministic("g_2_scale", jnp.multiply(g_2_scale_global_scale, g_2_scale_raw))

                with numpyro.plate(site.n_subject, n_subject):
                    # Priors: per (subject, feature0, response) curve parameters.
                    a = numpyro.sample(
                        site.a, dist.TruncatedNormal(a_mean, a_scale, low=0)
                    )

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, jnp.multiply(b_scale, b_raw))

                    v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    v = numpyro.deterministic(site.v, jnp.multiply(v_scale, v_raw))

                    L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    L = numpyro.deterministic(site.L, jnp.multiply(L_scale, L_raw))

                    ell_raw = numpyro.sample("ell_raw", dist.HalfNormal(scale=1))
                    ell = numpyro.deterministic("ell", jnp.multiply(ell_scale, ell_raw))

                    H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                    H = numpyro.deterministic(site.H, jnp.multiply(H_scale, H_raw))

                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, jnp.multiply(g_1_scale, g_1_raw))

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, jnp.multiply(g_2_scale, g_2_raw))

        if response_obs is not None:
            # Outlier component: mixing weight constrained to [0, 0.01] and a
            # HalfNormal scale (scalars shared by all observations).
            outlier_prob = numpyro.sample("outlier_prob", dist.Uniform(0., .01))
            outlier_scale = numpyro.sample("outlier_scale", dist.HalfNormal(10))

        with numpyro.plate(site.n_response, self.n_response):
            with numpyro.plate(site.n_data, n_data):
                # Curve value for each observation row and response.
                mu = numpyro.deterministic(
                    site.mu,
                    self.fn(
                        x=intensity,
                        a=a[subject, feature0],
                        b=b[subject, feature0],
                        v=v[subject, feature0],
                        L=L[subject, feature0],
                        ell=ell[subject, feature0],
                        H=H[subject, feature0]
                    )
                )
                # Gamma rate. With concentration mu * beta the Gamma mean is
                # exactly mu, and the variance is mu / beta.
                beta = numpyro.deterministic(
                    site.beta,
                    g_1[subject, feature0] + jnp.true_divide(g_2[subject, feature0], mu)
                )
                gamma = dist.Gamma(concentration=jnp.multiply(mu, beta), rate=beta)

                if response_obs is not None:
                    q = numpyro.deterministic("q", outlier_prob * jnp.ones((n_data, self.n_response)))
                    bg_scale = numpyro.deterministic("bg_scale", outlier_scale * jnp.ones((n_data, self.n_response)))

                    # Component 0 (weight 1 - q): Gamma around the curve.
                    # Component 1 (weight q): HalfNormal outliers.
                    mixing_distribution = dist.Categorical(
                        probs=jnp.stack([1 - q, q], axis=-1)
                    )
                    component_distributions = [
                        gamma,
                        dist.HalfNormal(scale=bg_scale)
                    ]
                    mixture = dist.MixtureGeneral(
                        mixing_distribution=mixing_distribution,
                        component_distributions=component_distributions
                    )

                    numpyro.sample(site.obs, mixture, obs=response_obs)
                else:
                    # Prediction: sample from the Gamma component only.
                    numpyro.sample(site.obs, gamma)


In [None]:
# Fit configuration for the J_RCML_000 rat dataset.
toml_path = "/home/vishu/repos/hbmep-paper/configs/paper/rats/J_RCML_000/config.toml"
config = Config(toml_path=toml_path)

# Put this model's artefacts in a dedicated sub-directory of the build dir.
config.BUILD_DIR = os.path.join(
    config.BUILD_DIR,
    "model-comparison",
    "mixture-model",
)

# MCMC budget, presumably overriding the TOML values (verify).
# NOTE(review): the saved progress bars below show 2000 iterations, which does
# not match num_warmup + num_samples here. The outputs may be stale, so re-run
# before relying on them.
mcmc_params = config.MCMC_PARAMS
mcmc_params["num_warmup"] = 10000
mcmc_params["num_samples"] = 1000

model = MixtureModel(config=config)


2023-11-30 11:48:26,207 - hbmep.model.baseline - INFO - Initialized mixture_model


In [None]:
# Processed dataset for J_RCML_000.
src = "/home/vishu/data/hbmep-processed/J_RCML_000/data.csv"
df = pd.read_csv(src)

# Earlier exploration (disabled): restrict to specific feature0 levels and plot
# the matching rows of the MEP matrix.
# mat = "/home/vishu/data/hbmep-processed/J_RCML_000/mat.npy"
# mat = np.load(mat)

# subset = ['C5L-C6L', 'C6L-C7L', 'C7L-C8L']
# subset = ['-C5L', '-C6L', '-C7L', '-C8L']

# ind = df[model.features[0]].isin(subset)
# df = df[ind].reset_index(drop=True).copy()
# mat = mat[ind, :]

# model.plot(df=df, mep_matrix=mat, destination_path="/home/vishu/temp.pdf")

# Keep only the listed subjects (one for now; widen to fit more).
# subset = ["amap01", "amap02", "amap03", "amap04"]
subset = ["amap01"]
ind = df[model.subject].isin(subset)
df = df.loc[ind].reset_index(drop=True).copy()

# `model.load` returns the processed frame plus `encoder_dict`, which is later
# passed to the renderers. Per its log output, it also sets up the artefact
# directory.
df, encoder_dict = model.load(df=df)


2023-11-30 11:48:26,659 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/model-comparison/mixture-model
2023-11-30 11:48:26,660 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/model-comparison/mixture-model
2023-11-30 11:48:26,663 - hbmep.dataset.core - INFO - Processing data ...
2023-11-30 11:48:26,665 - hbmep.utils.utils - INFO - func:load took: 0.01 sec


In [None]:
# Expensive: took ~25 min on CPU in the saved run (see log below).
# `mcmc` is used later with `print_summary` and `az.from_numpyro`;
# `posterior_samples` is a mapping, presumably keyed by site name (verify).
mcmc, posterior_samples = model.run_inference(df=df)


2023-11-30 11:48:26,807 - hbmep.model.baseline - INFO - Running inference with mixture_model ...


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

2023-11-30 12:13:58,805 - hbmep.utils.utils - INFO - func:run_inference took: 25 min and 32.00 sec


In [6]:
# Posterior predictive over the dataset built by `make_prediction_dataset`.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

# Both renderers write PDFs into the model's build directory (see log below).
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)


2023-11-30 12:24:30,514 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec
2023-11-30 12:24:46,014 - hbmep.utils.utils - INFO - func:predict took: 15.50 sec
2023-11-30 12:24:46,015 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-11-30 12:25:11,620 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/model-comparison/mixture-model/recruitment_curves.pdf
2023-11-30 12:25:11,621 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 25.61 sec
2023-11-30 12:25:11,621 - hbmep.model.baseline - INFO - Rendering posterior predictive checks ...
2023-11-30 12:25:37,986 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/model-comparison/mixture-model/posterior_predictive_check.pdf
2023-11-30 12:25:37,987 - hbmep.utils.utils - INFO - func:predictive_checks_renderer took: 26.37 sec
2023-11-30 12:25:37,987 - hbmep.utils.utils - INFO - func:render_pr

In [None]:
# Posterior summary with 95% intervals. Check r_hat (should be close to 1.0)
# and n_eff before trusting the fit. The saved output shows r_hat up to ~1.70
# and n_eff as low as ~4 for H_raw, which suggests the chains had not mixed.
mcmc.print_summary(prob=.95)



                               mean       std    median      2.5%     97.5%     n_eff     r_hat
             H_raw[0,0,0]      1.15      0.57      1.04      0.33      2.39     27.18      1.15
             H_raw[0,0,1]      1.17      0.62      0.97      0.32      2.47     11.95      1.43
             H_raw[0,0,2]      1.26      0.63      1.18      0.33      2.29      4.01      1.70
             H_raw[0,0,3]      1.10      0.39      0.98      0.53      2.00     23.18      1.14
             H_raw[0,0,4]      0.81      0.46      0.68      0.18      1.74     28.03      1.16
             H_raw[0,0,5]      0.96      0.44      0.86      0.32      1.87     18.65      1.25
             H_raw[0,1,0]      0.90      0.47      0.79      0.21      1.82      7.26      1.29
             H_raw[0,1,1]      0.72      0.48      0.67      0.11      1.62      4.39      1.53
             H_raw[0,1,2]      1.34      0.54      1.29      0.29      2.31     32.67      1.16
             H_raw[0,1,3]      1.12    

In [None]:
# Convert the fitted MCMC run to ArviZ InferenceData; it is also saved below.
numpyro_data = az.from_numpyro(mcmc)

# Model evaluation: PSIS-LOO and WAIC estimates of ELPD.
logger.info("Evaluating model ...")

loo_result = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {loo_result.elpd_loo:.2f}")

# az.waic may emit a reliability warning (its tail is visible in the saved
# output); if it does, prefer the LOO estimate.
waic_result = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {waic_result.elpd_waic:.2f}")


2023-10-25 15:17:33,019 - __main__ - INFO - Evaluating model ...
2023-10-25 15:17:39,651 - __main__ - INFO - ELPD LOO (Log): 6366.52
See http://arxiv.org/abs/1507.04544 for details
2023-10-25 15:17:40,070 - __main__ - INFO - ELPD WAIC (Log): 6433.86


In [None]:
# Persist the model, MCMC object and posterior draws together.
# `pickle` is already imported in the setup cell.
dest = os.path.join(model.build_dir, "inference.pkl")
artefacts = (model, mcmc, posterior_samples)
with open(dest, "wb") as fh:
    pickle.dump(artefacts, fh)


In [None]:
# Save the ArviZ InferenceData as NetCDF; the cell output is the written path.
# NOTE(review): the saved output below points at link-comparison/rectified_logistic,
# not this model's build dir, so it is stale. Re-run to refresh it.
netcdf_name = "numpyro_data.nc"
dest = os.path.join(model.build_dir, netcdf_name)
az.to_netcdf(numpyro_data, dest)


'/home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/rectified_logistic/numpyro_data.nc'

In [None]:
# Reload a previously saved inference run (disabled).
# NOTE(review): in a linear run, `dest` has been reassigned to the NetCDF path
# by the cell above, so rebuild the pickle path explicitly before loading.
# Only unpickle files you produced yourself: pickle.load can execute code.
# import pickle

# dest = os.path.join(model.build_dir, "inference.pkl")
# with open(dest, "rb") as g:
#     model_, mcmc_, posterior_samples_ = pickle.load(g)


In [None]:
# Number of entries (sites) recorded in the posterior samples.
len(posterior_samples)

46