In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

# Run on the CPU backend; the platform is set in both JAX and NumPyro so the
# two libraries agree.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose all but two cores as host devices. On machines with one or two cores
# `cpu_count() - 2` would be zero or negative, so clamp to at least one device.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)


In [2]:
import numpyro.distributions as dist
from hbmep.model import BaseModel


class MixtureModel(BaseModel):
    """Hierarchical recruitment-curve model with an outlier mixture likelihood.

    Curve parameters are sampled inside nested plates over response, the first
    feature and subject. When observations are supplied, ``obs`` follows a
    two-component mixture of a Gamma distribution (mean ``mu``) and a
    HalfNormal outlier component; without observations only the Gamma
    distribution is used.
    """

    NAME = "mixture_model"

    def __init__(self, config: Config):
        """Initialise the base model and set ``combination_columns`` to the
        subject column followed by the feature columns."""
        super().__init__(config=config)
        self.combination_columns = [self.subject] + self.features

    def fn(self, x, a, b, v, L, ell, H):
        """Evaluate the recruitment curve at intensity ``x``.

        Returns ``L`` wherever ``x < a``. Elsewhere it returns::

            L - ell + (H + ell)
                / (1 + (((H + ell) / ell) ** v - 1) * exp(-b * (x - a))) ** (1 / v)
        """
        ratio_pow_v = jnp.power(jnp.true_divide(H + ell, ell), v)
        decay = jnp.exp(jnp.multiply(-b, x - a))
        denominator = jnp.power(
            1 + jnp.multiply(-1 + ratio_pow_v, decay),
            jnp.true_divide(1, v)
        )
        above_threshold = -ell + jnp.true_divide(H + ell, denominator)
        return L + jnp.where(jnp.less(x, a), 0., above_threshold)

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model: hierarchical priors, curve mean and likelihood.

        ``subject``, ``features`` and ``intensity`` are each unpacked as a
        ``(values, size)`` pair; only the first feature is used. If
        ``response_obs`` is given, ``obs`` is observed under the
        Gamma/HalfNormal mixture. If it is ``None``, ``obs`` is sampled from
        the Gamma distribution alone and no outlier sites are created.
        """
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        # Repeat the intensity column once per response.
        intensity = np.tile(intensity.reshape(-1, 1), (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        def unit_half_normal():
            # Fresh unit-scale HalfNormal prior for a "raw" site.
            return dist.HalfNormal(scale=1)

        def scaled(name, raw_name, scale, raw_prior):
            # Sample `raw_name` from `raw_prior`, then record `scale * raw`
            # as the deterministic site `name`.
            raw = numpyro.sample(raw_name, raw_prior)
            return numpyro.deterministic(name, jnp.multiply(scale, raw))

        with numpyro.plate(site.n_response, self.n_response):
            # Global priors
            b_scale_global_scale = numpyro.sample(
                "b_scale_global_scale", dist.HalfNormal(5)
            )
            v_scale_global_scale = numpyro.sample(
                "v_scale_global_scale", dist.HalfNormal(5)
            )

            L_scale_global_scale = numpyro.sample(
                "L_scale_global_scale", dist.HalfNormal(.5)
            )
            ell_scale_global_scale = numpyro.sample(
                "ell_scale_global_scale", dist.HalfNormal(10)
            )
            H_scale_global_scale = numpyro.sample(
                "H_scale_global_scale", dist.HalfNormal(5)
            )

            g_1_scale_global_scale = numpyro.sample(
                "g_1_scale_global_scale", dist.HalfNormal(5)
            )
            g_2_scale_global_scale = numpyro.sample(
                "g_2_scale_global_scale", dist.HalfNormal(5)
            )

            with numpyro.plate("n_feature0", n_feature0):
                # Hyper-priors
                a_mean = numpyro.sample(
                    "a_mean", dist.TruncatedNormal(50., 20., low=0)
                )
                a_scale = numpyro.sample("a_scale", dist.HalfNormal(30.))

                b_scale = scaled(
                    "b_scale", "b_scale_raw",
                    b_scale_global_scale, unit_half_normal()
                )
                v_scale = scaled(
                    "v_scale", "v_scale_raw",
                    v_scale_global_scale, unit_half_normal()
                )
                L_scale = scaled(
                    "L_scale", "L_scale_raw",
                    L_scale_global_scale, unit_half_normal()
                )
                ell_scale = scaled(
                    "ell_scale", "ell_scale_raw",
                    ell_scale_global_scale, unit_half_normal()
                )
                H_scale = scaled(
                    "H_scale", "H_scale_raw",
                    H_scale_global_scale, unit_half_normal()
                )
                g_1_scale = scaled(
                    "g_1_scale", "g_1_scale_raw",
                    g_1_scale_global_scale, unit_half_normal()
                )
                g_2_scale = scaled(
                    "g_2_scale", "g_2_scale_raw",
                    g_2_scale_global_scale, unit_half_normal()
                )

                with numpyro.plate(site.n_subject, n_subject):
                    # Priors
                    a = numpyro.sample(
                        "a", dist.TruncatedNormal(a_mean, a_scale, low=0)
                    )

                    b = scaled(site.b, "b_raw", b_scale, unit_half_normal())
                    v = scaled(site.v, "v_raw", v_scale, unit_half_normal())
                    L = scaled(site.L, "L_raw", L_scale, unit_half_normal())
                    ell = scaled("ell", "ell_raw", ell_scale, unit_half_normal())
                    H = scaled(site.H, "H_raw", H_scale, unit_half_normal())

                    # The g_1 / g_2 raw sites use a HalfCauchy prior.
                    g_1 = scaled(
                        site.g_1, "g_1_raw", g_1_scale, dist.HalfCauchy(scale=1)
                    )
                    g_2 = scaled(
                        site.g_2, "g_2_raw", g_2_scale, dist.HalfCauchy(scale=1)
                    )

        observed = response_obs is not None

        if observed:
            # Outlier distribution: sites exist only when fitting to data.
            outlier_prob = numpyro.sample("outlier_prob", dist.Uniform(0., .01))
            outlier_scale = numpyro.sample("outlier_scale", dist.HalfNormal(10))

        with numpyro.plate(site.n_response, self.n_response), \
                numpyro.plate(site.n_data, n_data):
            # Model: pick each data row's parameters by (subject, feature0).
            row = (subject, feature0)
            mu = numpyro.deterministic(
                site.mu,
                self.fn(
                    x=intensity,
                    a=a[row],
                    b=b[row],
                    v=v[row],
                    L=L[row],
                    ell=ell[row],
                    H=H[row]
                )
            )
            beta = numpyro.deterministic(
                site.beta,
                g_1[row] + jnp.true_divide(g_2[row], mu)
            )

            if observed:
                ones = (n_data, self.n_response)
                q = numpyro.deterministic(
                    "q", outlier_prob * jnp.ones(ones)
                )
                bg_scale = numpyro.deterministic(
                    "bg_scale", outlier_scale * jnp.ones(ones)
                )

                # Mixture: component 0 is the Gamma curve likelihood with
                # weight 1 - q, component 1 the HalfNormal outlier with weight q.
                mixture = dist.MixtureGeneral(
                    mixing_distribution=dist.Categorical(
                        probs=jnp.stack([1 - q, q], axis=-1)
                    ),
                    component_distributions=[
                        dist.Gamma(concentration=jnp.multiply(mu, beta), rate=beta),
                        dist.HalfNormal(scale=bg_scale)
                    ]
                )

                # Observation
                numpyro.sample(site.obs, mixture, obs=response_obs)
            else:
                # Observation: no data, so sample from the Gamma alone.
                numpyro.sample(
                    site.obs,
                    dist.Gamma(concentration=jnp.multiply(mu, beta), rate=beta),
                    obs=response_obs
                )


In [3]:
# Build the experiment configuration from a TOML file.
# NOTE(review): absolute, machine-specific path — the notebook cannot be re-run
# on another machine without editing it.
toml_path = "/home/vishu/repos/hbmep-paper/configs/paper/tms/config.toml"
config = Config(toml_path=toml_path)
# Redirect this experiment's output into its own sub-directory of the configured build dir.
config.BUILD_DIR = os.path.join(config.BUILD_DIR, "model-comparison", "testing", "mixture-model")
# Override the sampler settings taken from the TOML file.
config.MCMC_PARAMS["num_warmup"] = 4000
config.MCMC_PARAMS["num_samples"] = 1000

model = MixtureModel(config=config)


2023-12-01 12:06:51,429 - hbmep.model.baseline - INFO - Initialized mixture_model


In [4]:
# Read the preprocessed dataset.
# NOTE(review): absolute, machine-specific path.
src = "/home/vishu/data/hbmep-processed/human/tms/proc_2023-11-28.csv"
df = pd.read_csv(src)

# Keep only the rows of three participants and renumber the index.
subset = ["SCA01", "SCA02", "SCS01"]
ind = df[model.subject].isin(subset)
df = df[ind].reset_index(drop=True).copy()

# `model.load` returns the frame to use downstream plus `encoder_dict`
# (presumably the label encoders for the categorical columns — confirm).
df, encoder_dict = model.load(df=df)


2023-12-01 12:06:56,077 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model
2023-12-01 12:06:56,078 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model
2023-12-01 12:06:56,081 - hbmep.dataset.core - INFO - Processing data ...
2023-12-01 12:06:56,082 - hbmep.utils.utils - INFO - func:load took: 0.01 sec


In [5]:
# Fit the model; keeps both the MCMC object (used later for LOO/WAIC) and the
# posterior samples (used for prediction and plotting).
mcmc, posterior_samples = model.run_inference(df=df)


2023-12-01 12:06:59,465 - hbmep.model.baseline - INFO - Running inference with mixture_model ...


  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

In [7]:
# Build the prediction dataset and draw posterior predictions on it.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(df=prediction_df, posterior_samples=posterior_samples)

# Ordering key for the rendered combinations: second element first, then the first element.
orderby = lambda x: (x[1], x[0])
# Both calls save a PDF into the build directory (see the log output below).
model.render_recruitment_curves(df=df, encoder_dict=encoder_dict, posterior_samples=posterior_samples, prediction_df=prediction_df, posterior_predictive=posterior_predictive, orderby=orderby)
model.render_predictive_check(df=df, encoder_dict=encoder_dict, prediction_df=prediction_df, posterior_predictive=posterior_predictive, orderby=orderby)


2023-11-30 10:58:25,496 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.00 sec
2023-11-30 10:58:26,959 - hbmep.utils.utils - INFO - func:predict took: 1.46 sec
2023-11-30 10:58:26,960 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-11-30 10:58:28,118 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model/recruitment_curves.pdf
2023-11-30 10:58:28,118 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 1.16 sec
2023-11-30 10:58:28,118 - hbmep.model.baseline - INFO - Rendering posterior predictive checks ...
2023-11-30 10:58:29,381 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model/posterior_predictive_check.pdf
2023-11-30 10:58:29,382 - hbmep.utils.utils - INFO - func:predictive_checks_renderer took: 1.26 sec
2023-11-30 10:58:29,382 - hbmep.utils.utils - INFO - func:render_predictive_ch

In [8]:
# Copy the posterior samples and set every `outlier_prob` draw to zero; the
# original `posterior_samples` entry is left untouched because the product is a new array.
# NOTE(review): `MixtureModel._model` only samples and uses `outlier_prob` when
# `response_obs` is not None. If `model.predict` runs the model without
# observations, this zeroing has no effect on the predictions — confirm.
_posterior_samples = posterior_samples.copy()
_posterior_samples["outlier_prob"] = _posterior_samples["outlier_prob"] * 0

prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(df=prediction_df, posterior_samples=_posterior_samples)

# Ordering key for the rendered combinations: second element first, then the first element.
orderby = lambda x: (x[1], x[0])
# These calls overwrite the PDFs written by the earlier rendering cell (same output paths in the log).
model.render_recruitment_curves(df=df, encoder_dict=encoder_dict, posterior_samples=_posterior_samples, prediction_df=prediction_df, posterior_predictive=posterior_predictive, orderby=orderby)
model.render_predictive_check(df=df, encoder_dict=encoder_dict, prediction_df=prediction_df, posterior_predictive=posterior_predictive, orderby=orderby)

2023-12-01 12:21:09,176 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.00 sec
2023-12-01 12:21:13,073 - hbmep.utils.utils - INFO - func:predict took: 3.90 sec
2023-12-01 12:21:13,075 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-12-01 12:21:18,616 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model/recruitment_curves.pdf
2023-12-01 12:21:18,617 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 5.54 sec
2023-12-01 12:21:18,618 - hbmep.model.baseline - INFO - Rendering posterior predictive checks ...
2023-12-01 12:21:23,632 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model/posterior_predictive_check.pdf
2023-12-01 12:21:23,633 - hbmep.utils.utils - INFO - func:predictive_checks_renderer took: 5.02 sec
2023-12-01 12:21:23,633 - hbmep.utils.utils - INFO - func:render_predictive_ch

In [17]:
# Inspect the combinations produced when an explicit ordering key is supplied.
# NOTE(review): the recorded output is not in ascending order of this key;
# check how `_make_combinations` applies `orderby` (and whether this cell was
# edited after it was run).
orderby = lambda x: (x[1], x[0])
combinations = model._make_combinations(df=df, columns=model.combination_columns, orderby=orderby)
combinations


[(1, 1), (0, 1), (2, 0)]

In [11]:
# Same inspection without an ordering key, to see the default order.
combinations = model._make_combinations(df=df, columns=model.combination_columns)
combinations


[(0, 1), (1, 1), (2, 0)]

In [10]:
# Show which columns define a combination (set in MixtureModel.__init__).
model.combination_columns


['participant', 'participant_condition']

In [18]:
# Re-render with a different ordering key: second element first, then the negated first element.
# Relies on `_posterior_samples`, `prediction_df` and `posterior_predictive` from an earlier cell.
orderby = lambda x: (x[1], -x[0])
model.render_recruitment_curves(df=df, encoder_dict=encoder_dict, posterior_samples=_posterior_samples, prediction_df=prediction_df, posterior_predictive=posterior_predictive, orderby=orderby)
model.render_predictive_check(df=df, encoder_dict=encoder_dict, prediction_df=prediction_df, posterior_predictive=posterior_predictive, orderby=orderby)


2023-12-01 12:24:40,836 - hbmep.model.baseline - INFO - Rendering recruitment curves ...


2023-12-01 12:24:45,901 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model/recruitment_curves.pdf
2023-12-01 12:24:45,902 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 5.07 sec
2023-12-01 12:24:45,902 - hbmep.model.baseline - INFO - Rendering posterior predictive checks ...
2023-12-01 12:24:50,935 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model/posterior_predictive_check.pdf
2023-12-01 12:24:50,937 - hbmep.utils.utils - INFO - func:predictive_checks_renderer took: 5.03 sec
2023-12-01 12:24:50,937 - hbmep.utils.utils - INFO - func:render_predictive_check took: 5.03 sec


In [7]:
# NOTE(review): near-duplicate of the earlier zeroed-`outlier_prob` cell, here
# without an `orderby` key; consider keeping only one of them.
# Copy the posterior samples and set every `outlier_prob` draw to zero.
# NOTE(review): `MixtureModel._model` only uses `outlier_prob` when
# `response_obs` is not None, so this may not change the predictions — confirm.
_posterior_samples = posterior_samples.copy()
_posterior_samples["outlier_prob"] = _posterior_samples["outlier_prob"] * 0

prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(df=prediction_df, posterior_samples=_posterior_samples)

# Render with the default combination order; overwrites the PDFs in the build directory.
model.render_recruitment_curves(df=df, encoder_dict=encoder_dict, posterior_samples=_posterior_samples, prediction_df=prediction_df, posterior_predictive=posterior_predictive)
model.render_predictive_check(df=df, encoder_dict=encoder_dict, prediction_df=prediction_df, posterior_predictive=posterior_predictive)


2023-12-01 12:19:24,086 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec


2023-12-01 12:19:27,997 - hbmep.utils.utils - INFO - func:predict took: 3.91 sec
2023-12-01 12:19:27,997 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-12-01 12:19:32,341 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model/recruitment_curves.pdf
2023-12-01 12:19:32,342 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 4.34 sec
2023-12-01 12:19:32,342 - hbmep.model.baseline - INFO - Rendering posterior predictive checks ...
2023-12-01 12:19:36,764 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/tms/model-comparison/testing/mixture-model/posterior_predictive_check.pdf
2023-12-01 12:19:36,765 - hbmep.utils.utils - INFO - func:predictive_checks_renderer took: 4.42 sec
2023-12-01 12:19:36,765 - hbmep.utils.utils - INFO - func:render_predictive_check took: 4.42 sec


In [8]:
# Convert the fitted MCMC object into ArviZ inference data.
numpyro_data = az.from_numpyro(mcmc)

# The bare string below is only a section label; it has no effect at runtime.
""" Model evaluation """
logger.info("Evaluating model ...")

# Leave-one-out cross-validation estimate, computed on the `obs` site.
score = az.loo(numpyro_data, var_name=site.obs)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")

# WAIC on the same site; `score` is reused, so the LOO result is no longer accessible afterwards.
score = az.waic(numpyro_data, var_name=site.obs)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")


2023-11-30 10:58:57,190 - __main__ - INFO - Evaluating model ...
2023-11-30 10:58:58,662 - __main__ - INFO - ELPD LOO (Log): 1520.44
See http://arxiv.org/abs/1507.04544 for details
2023-11-30 10:58:58,686 - __main__ - INFO - ELPD WAIC (Log): 1499.86


In [7]:
# NOTE(review): exact duplicate of the previous evaluation cell (same code,
# same recorded scores); one of the two can be removed.
# Convert the fitted MCMC object into ArviZ inference data.
numpyro_data = az.from_numpyro(mcmc)

# The bare string below is only a section label; it has no effect at runtime.
""" Model evaluation """
logger.info("Evaluating model ...")

# Leave-one-out cross-validation estimate, computed on the `obs` site.
score = az.loo(numpyro_data, var_name=site.obs)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")

# WAIC on the same site; `score` is reused for it.
score = az.waic(numpyro_data, var_name=site.obs)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")


2023-11-30 10:55:07,526 - __main__ - INFO - Evaluating model ...
2023-11-30 10:55:09,042 - __main__ - INFO - ELPD LOO (Log): 1520.44
See http://arxiv.org/abs/1507.04544 for details
2023-11-30 10:55:09,067 - __main__ - INFO - ELPD WAIC (Log): 1499.86
