In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run everything on CPU; both JAX and NumPyro are told explicitly.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Leave two cores free for the rest of the system, but never request fewer
# than one host device: on a machine with <= 2 cores the plain subtraction
# would ask NumPyro for 0 or a negative number of devices.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)

In [2]:
import numpyro.distributions as dist
from hbmep.model import BaseModel


class MixtureModel(BaseModel):
    """Hierarchical recruitment-curve model with a Gamma outlier mixture.

    Curve parameters are sampled per (subject, feature0, response) with
    non-centered scales shared through feature-level and response-level
    hyper-priors. Each observation is drawn from a two-component mixture:
    a Gamma centred on the recruitment curve and a free Gamma "outlier"
    component.
    """

    # NOTE(review): the logs recorded in this notebook print "base_model" as
    # the model name, so the base class presumably reads a different
    # attribute for its display name -- confirm whether LINK is still used.
    LINK = "mixture_model"

    def __init__(self, config: Config):
        """Initialise the base model and set the columns that identify a curve."""
        super().__init__(config=config)
        self.combination_columns = self.features + [self.subject]

    def fn(self, x, a, b, v, L, l, H):
        """Evaluate the rectified-logistic recruitment curve.

        Returns ``L`` for ``x < a`` and, for ``x >= a``::

            L - l + (H + l) / (1 + (((H + l) / l) ** v - 1) * exp(-b * (x - a))) ** (1 / v)

        which equals ``L`` at ``x == a`` and approaches ``L + H`` as ``x`` grows.

        Args:
            x: Stimulation intensity.
            a: Threshold below which the curve is flat at ``L``.
            b: Growth-rate parameter inside the exponential.
            v: Shape exponent of the generalised logistic.
            L: Baseline (offset) of the curve.
            l: Offset that shifts the logistic so it is zero at ``x == a``.
            H: Saturation height above the baseline.

        All arguments are broadcast against each other.
        """
        below_threshold = jnp.less(x, a)

        # "Double where": feed the logistic branch a harmless input (0) wherever
        # it will be discarded. Without this, exp(-b * (x - a)) can overflow to
        # inf for x << a; the forward value is still masked correctly, but the
        # reverse-mode gradient becomes 0 * inf = NaN and poisons NUTS.
        safe_offset = jnp.where(below_threshold, 0., x - a)

        logistic_denominator = jnp.power(
            1
            + jnp.multiply(
                -1 + jnp.power(jnp.true_divide(H + l, l), v),
                jnp.exp(jnp.multiply(-b, safe_offset))
            ),
            jnp.true_divide(1, v)
        )
        return (
            L
            + jnp.where(
                below_threshold,
                0.,
                -l + jnp.true_divide(H + l, logistic_denominator)
            )
        )

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model: hierarchical priors, curve, and mixture likelihood.

        Args:
            subject: Pair ``(subject_index_array, n_subject)``.
            features: Pair ``(list_of_feature_index_arrays, list_of_counts)``;
                only the first feature is used.
            intensity: Pair ``(intensity_array, n_data)``.
            response_obs: Observed responses, or ``None`` when predicting. When
                ``None`` the outlier probability is fixed to 0 so predictions
                come from the recruitment-curve component only.
        """
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        # One intensity column per response so it lines up with the
        # (n_data, n_response) plates used for the likelihood.
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # Global priors: one scale per response for every hyper-parameter.
            a_mean_global_scale = numpyro.sample("a_mean_global_scale", dist.HalfNormal(100))
            a_shape_global_scale = numpyro.sample("a_shape_global_scale", dist.HalfNormal(100))

            b_scale_global_scale = numpyro.sample("b_scale_global_scale", dist.HalfNormal(100))
            v_scale_global_scale = numpyro.sample("v_scale_global_scale", dist.HalfNormal(100))

            L_scale_global_scale = numpyro.sample("L_scale_global_scale", dist.HalfNormal(10))
            l_scale_global_scale = numpyro.sample("l_scale_global_scale", dist.HalfNormal(100))
            H_scale_global_scale = numpyro.sample("H_scale_global_scale", dist.HalfNormal(10))

            g_1_scale_global_scale = numpyro.sample("g_1_scale_global_scale", dist.HalfNormal(100))
            g_2_scale_global_scale = numpyro.sample("g_2_scale_global_scale", dist.HalfNormal(100))

            # p_scale_global_scale = numpyro.sample("p_scale_global_scale", dist.HalfNormal(100))

            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                # Hyper-priors (non-centered): unit HalfNormal draws rescaled
                # by the response-level global scales above.
                # a_mean = numpyro.sample("a_mean", dist.HalfNormal(scale=100))
                # a_scale = numpyro.sample("a_scale", dist.HalfNormal(scale=100))

                # a_mean = numpyro.sample("a_mean", dist.HalfNormal(scale=50))
                # a_shape = numpyro.sample("a_shape", dist.HalfNormal(scale=100))

                a_mean_raw = numpyro.sample("a_mean_raw", dist.HalfNormal(scale=1))
                a_mean = numpyro.deterministic("a_mean", jnp.multiply(a_mean_global_scale, a_mean_raw))

                a_shape_raw = numpyro.sample("a_shape_raw", dist.HalfNormal(scale=1))
                a_shape = numpyro.deterministic("a_shape", jnp.multiply(a_shape_global_scale, a_shape_raw))

                b_scale_raw = numpyro.sample("b_scale_raw", dist.HalfNormal(scale=1))
                b_scale = numpyro.deterministic("b_scale", jnp.multiply(b_scale_global_scale, b_scale_raw))

                v_scale_raw = numpyro.sample("v_scale_raw", dist.HalfNormal(scale=1))
                v_scale = numpyro.deterministic("v_scale", jnp.multiply(v_scale_global_scale, v_scale_raw))

                L_scale_raw = numpyro.sample("L_scale_raw", dist.HalfNormal(scale=1))
                L_scale = numpyro.deterministic("L_scale", jnp.multiply(L_scale_global_scale, L_scale_raw))

                l_scale_raw = numpyro.sample("l_scale_raw", dist.HalfNormal(scale=1))
                # NOTE(review): this site is recorded as "sigma_l" while every
                # sibling uses the "<param>_scale" pattern. The name is kept
                # unchanged because it is the key under which posterior
                # samples are stored; rename only together with any consumers.
                l_scale = numpyro.deterministic("sigma_l", jnp.multiply(l_scale_global_scale, l_scale_raw))

                H_scale_raw = numpyro.sample("H_scale_raw", dist.HalfNormal(scale=1))
                H_scale = numpyro.deterministic("H_scale", jnp.multiply(H_scale_global_scale, H_scale_raw))

                g_1_scale_raw = numpyro.sample("g_1_scale_raw", dist.HalfNormal(scale=1))
                g_1_scale = numpyro.deterministic("g_1_scale", jnp.multiply(g_1_scale_global_scale, g_1_scale_raw))

                g_2_scale_raw = numpyro.sample("g_2_scale_raw", dist.HalfNormal(scale=1))
                g_2_scale = numpyro.deterministic("g_2_scale", jnp.multiply(g_2_scale_global_scale, g_2_scale_raw))

                # p_scale_raw = numpyro.sample("p_scale_raw", dist.HalfNormal(scale=1))
                # p_scale = numpyro.deterministic("p_scale", p_scale_global_scale * p_scale_raw)

                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    # Priors: per (subject, feature0, response) curve parameters.
                    # a_raw ~ Gamma(a_shape, 1) rescaled by a_mean / a_shape,
                    # so `a` is Gamma-distributed with mean a_mean.
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=a_shape, rate=1))
                    a = numpyro.deterministic(site.a, jnp.true_divide(jnp.multiply(a_raw, a_mean), a_shape))
                    # a = numpyro.sample(site.a, dist.TruncatedNormal(loc=a_mean, scale=a_scale, low=0))

                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, jnp.multiply(b_scale, b_raw))

                    v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    v = numpyro.deterministic(site.v, jnp.multiply(v_scale, v_raw))

                    L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    L = numpyro.deterministic(site.L, jnp.multiply(L_scale, L_raw))

                    l_raw = numpyro.sample("l_raw", dist.HalfNormal(scale=1))
                    l = numpyro.deterministic("l", jnp.multiply(l_scale, l_raw))

                    H_raw = numpyro.sample("H_raw", dist.HalfNormal(scale=1))
                    H = numpyro.deterministic(site.H, jnp.multiply(H_scale, H_raw))

                    # Noise parameters use heavier-tailed HalfCauchy draws.
                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, jnp.multiply(g_1_scale, g_1_raw))

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, jnp.multiply(g_2_scale, g_2_raw))

                    # p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                    # p = numpyro.deterministic("p", p_scale * p_raw)

        # Outlier distribution: a single Gamma shared by all observations.
        outlier_dist_shape = numpyro.sample("outlier_dist_shape", dist.HalfNormal(5))
        outlier_dist_rate = numpyro.sample("outlier_dist_rate", dist.HalfNormal(1))

        # Mixture weight: at most 20% of observations may be outliers.
        if response_obs is not None:
            outlier_prob = numpyro.sample("outlier_prob", dist.Uniform(0., .2))
        else: # Turn off mixture when predicting
            outlier_prob = numpyro.deterministic("outlier_prob", 0.)

        mixing_distribution = dist.Categorical(probs=jnp.array([1 - outlier_prob, outlier_prob]))

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.n_data, n_data, dim=-2):
                # Model: look up each observation's parameters by its
                # (subject, feature0) indices and evaluate the curve.
                mu = numpyro.deterministic(
                    site.mu,
                    self.fn(
                        x=intensity,
                        a=a[subject, feature0],
                        b=b[subject, feature0],
                        v=v[subject, feature0],
                        L=L[subject, feature0],
                        l=l[subject, feature0],
                        H=H[subject, feature0]
                    )
                )
                # Gamma rate: g_1 + g_2 / mu.
                beta = numpyro.deterministic(
                    site.beta,
                    g_1[subject, feature0] + jnp.true_divide(g_2[subject, feature0], mu)
                )
                # beta = numpyro.deterministic(
                #     site.beta,
                #     g_1[subject, feature0] + g_2[subject, feature0] * jnp.power(1 / (mu + 1), p[subject, feature0])
                # )
                # beta = numpyro.deterministic(
                #     site.beta,
                #     g_1[subject, feature0]
                # )

                # Mixture: Gamma(mu * beta, beta) has mean mu; the second
                # component is the shared outlier Gamma.
                component_distributions = [
                    dist.Gamma(concentration=jnp.multiply(mu, beta), rate=beta),
                    dist.Gamma(concentration=outlier_dist_shape, rate=outlier_dist_rate)
                ]
                # component_distributions = [
                #     dist.Gamma(concentration=beta, rate=jnp.true_divide(beta, mu)),
                #     dist.Gamma(concentration=outlier_dist_shape, rate=outlier_dist_rate)
                # ]
                mixture = dist.MixtureGeneral(
                    mixing_distribution=mixing_distribution,
                    component_distributions=component_distributions
                )

                # Observation
                numpyro.sample(
                    site.obs,
                    mixture,
                    obs=response_obs
                )


In [3]:
toml_path = "/home/vishu/repos/hbmep-paper/configs/paper/tms/config.toml"
config = Config(toml_path=toml_path)

# Nest this experiment's artefacts below the build directory from the TOML file.
build_subdirs = (
    "proc-2023-11-15",
    "model-comparison/mixture-model",
    "global-prior-on-mean",
)
config.BUILD_DIR = os.path.join(config.BUILD_DIR, *build_subdirs)
# config.RESPONSE = ['PKPK_ADM', 'PKPK_APB', 'PKPK_Biceps', 'PKPK_ECR', 'PKPK_FCR', 'PKPK_Triceps']
# config.RESPONSE = ['PKPK_ECR']
# config.RESPONSE = ['PKPK_ADM']

# Override the sampler budget: long warmup, short sampling phase.
config.MCMC_PARAMS["num_samples"] = 1000
config.MCMC_PARAMS["num_warmup"] = 10000

model = MixtureModel(config=config)


2023-11-15 09:25:33,925 - hbmep.config - INFO - Verifying configuration ...
2023-11-15 09:25:33,925 - hbmep.config - INFO - Success!
2023-11-15 09:25:33,940 - hbmep.model.baseline - INFO - Initialized base_model


In [4]:
src = "/home/vishu/data/hbmep-processed/human/tms/proc-2023-11-15.csv"
df = pd.read_csv(src)

# subset = ["SCS04", "SCA07", "SCA01"]
# ind = df[model.subject].isin(subset)
# df = df[ind].reset_index(drop=True).copy()

# ind = df[model.intensity] > 0
# df = df[ind].reset_index(drop=True).copy()

# Prefix the condition labels with an ordinal so that they sort as
# Uninjured first, SCI second.
condition_column = model.features[0]
condition_labels = {
    "Uninjured": "01_Uninjured",
    "SCI": "02_SCI",
}
df[condition_column] = df[condition_column].replace(condition_labels)

df, encoder_dict = model.load(df=df)


2023-11-15 09:25:34,006 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/tms/proc-2023-11-15/model-comparison/mixture-model/global-prior-on-mean
2023-11-15 09:25:34,006 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/tms/proc-2023-11-15/model-comparison/mixture-model/global-prior-on-mean
2023-11-15 09:25:34,007 - hbmep.dataset.core - INFO - Processing data ...
2023-11-15 09:25:34,008 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# Fit the model with MCMC. This is the expensive step: the run recorded below
# took 6 hr 21 min, with each progress bar covering 11000 iterations
# (10000 warmup + 1000 samples, as set on the config above).
mcmc, posterior_samples = model.run_inference(df=df)


2023-11-15 09:25:34,104 - hbmep.model.baseline - INFO - Running inference with base_model ...


  0%|          | 0/11000 [00:00<?, ?it/s]

  0%|          | 0/11000 [00:00<?, ?it/s]

  0%|          | 0/11000 [00:00<?, ?it/s]

  0%|          | 0/11000 [00:00<?, ?it/s]

2023-11-15 15:46:44,936 - hbmep.utils.utils - INFO - func:run_inference took: 6 hr and 21 min


In [None]:
# Build the prediction grid and draw posterior-predictive samples on it.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

# Both renderers take the same data/prediction inputs; only the
# recruitment-curve plot additionally needs the posterior samples.
shared_render_kwargs = dict(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_recruitment_curves(
    posterior_samples=posterior_samples,
    **shared_render_kwargs,
)
model.render_predictive_check(**shared_render_kwargs)


2023-11-15 09:14:58,076 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.00 sec


2023-11-15 09:15:31,894 - hbmep.utils.utils - INFO - func:predict took: 33.82 sec
2023-11-15 09:15:31,894 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-11-15 09:15:49,737 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/proc-2023-11-15/model-comparison/mixture-model/global-prior-on-mean/recruitment_curves.pdf
2023-11-15 09:15:49,737 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 17.84 sec
2023-11-15 09:15:49,738 - hbmep.model.baseline - INFO - Rendering posterior predictive checks ...
2023-11-15 09:16:10,878 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/proc-2023-11-15/model-comparison/mixture-model/global-prior-on-mean/posterior_predictive_check.pdf
2023-11-15 09:16:10,879 - hbmep.utils.utils - INFO - func:predictive_checks_renderer took: 21.14 sec
2023-11-15 09:16:10,879 - hbmep.utils.utils - INFO - func:render_predictive_check took: 21.14 sec


In [None]:
# Convergence diagnostics with 95% intervals.
# NOTE(review): in the output captured below, several H_raw sites show r_hat
# far above 1 (up to 5.20) and single-digit n_eff, i.e. the chains had not
# mixed for that run; treat downstream numbers from it with caution.
mcmc.print_summary(prob=.95)



                               mean       std    median      2.5%     97.5%     n_eff     r_hat
             H_raw[0,0,0]      0.88      0.37      0.82      0.34      1.71      2.83      3.22
             H_raw[0,0,1]      0.29      0.03      0.28      0.25      0.35      3.03      2.84
             H_raw[0,0,2]      0.94      0.40      0.82      0.26      1.88     13.19      1.49
             H_raw[0,0,3]      0.43      0.08      0.40      0.33      0.53      2.11      5.20
             H_raw[0,0,4]      0.82      0.48      0.89      0.15      1.61      2.50      2.47
             H_raw[0,0,5]      0.78      0.42      0.84      0.24      1.53      4.84      1.95
             H_raw[0,1,0]      1.21      0.74      1.23      0.07      2.25      2.28      3.43
             H_raw[0,1,1]      1.00      0.53      0.92      0.14      1.82     19.40      1.19
             H_raw[0,1,2]      0.52      0.44      0.56      0.00      1.46      4.91      1.71
             H_raw[0,1,3]      0.42    

In [11]:
import scipy.stats as stats

# Posterior draws of the threshold `a`, averaged over the sample axis and
# restricted to the first response (last-axis index 0). Despite the name,
# `a_map` holds posterior means rather than MAP estimates.
a = np.array(posterior_samples["a"])
a_map = a.mean(axis=0)[..., 0]

# NOTE(review): the split at subject index 10 and the feature indices 1 / 0
# are hard-coded; presumably the first 10 encoded subjects belong to feature
# level 1 and the remainder to level 0 -- verify against encoder_dict.
first_group = a_map[:10, 1]
second_group = a_map[10:, 0]

# One-sided independent t-test: is the first group's mean threshold greater?
stats.ttest_ind(a=first_group, b=second_group, alternative="greater")

TtestResult(statistic=0.6538370230861886, pvalue=0.26190594190573013, df=14.0)

In [8]:
# Convert the fitted MCMC object to ArviZ InferenceData for model comparison.
numpyro_data = az.from_numpyro(mcmc)

# Model evaluation
logger.info("Evaluating model ...")

# Leave-one-out cross-validation estimate of the expected log predictive density.
score = az.loo(numpyro_data)
logger.info("ELPD LOO (Log): %.2f", score.elpd_loo)

# WAIC estimate of the same quantity.
score = az.waic(numpyro_data)
logger.info("ELPD WAIC (Log): %.2f", score.elpd_waic)


2023-11-14 16:36:02,514 - __main__ - INFO - Evaluating model ...
  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
2023-11-14 16:36:04,824 - __main__ - INFO - ELPD LOO (Log): 2839.17
See http://arxiv.org/abs/1507.04544 for details
2023-11-14 16:36:04,890 - __main__ - INFO - ELPD WAIC (Log): 2843.34


In [None]:
import pickle

# Persist the fitted objects together so the run can be reloaded later.
# NOTE(review): MixtureModel is defined inside this notebook, so unpickling
# presumably requires the same class definition to be in scope -- confirm
# before relying on this file from another process.
dest = os.path.join(model.build_dir, "inference.pkl")
inference_bundle = (model, mcmc, posterior_samples)
with open(dest, "wb") as f:
    pickle.dump(inference_bundle, f)


In [None]:
# Save the ArviZ InferenceData next to the other artefacts of this run.
# NOTE(review): the output captured below shows a path under
# link-comparison/rectified_logistic, which is not this notebook's build
# directory, so it is presumably left over from an earlier run -- re-run this
# cell to refresh it.
dest = os.path.join(model.build_dir, "numpyro_data.nc")
az.to_netcdf(numpyro_data, dest)


'/home/vishu/repos/hbmep-paper/reports/paper/tms/link-comparison/rectified_logistic/numpyro_data.nc'