In [1]:
%reload_ext autoreload
%autoreload 2

import os
import pickle
import logging
import multiprocessing
from pathlib import Path

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp

import arviz as az
import numpyro

from hbmep.config import Config
from hbmep.model.utils import Site as site

PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

cpu_count = multiprocessing.cpu_count() - 2
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)

In [2]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class SaturatedReLU(Baseline):
    LINK = "saturated_relu"

    def __init__(self, config: Config):
        super(SaturatedReLU, self).__init__(config=config)
        # self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
            # global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))
            global_g_shape = numpyro.sample("global_g_shape", dist.HalfNormal(100))

            global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(.1))

            global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                """ Hyper-priors """
                mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=10))
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(scale=10 / 150))

                # sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.1))
                sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                sigma_b = numpyro.deterministic(site.sigma_b, global_sigma_b * sigma_b_raw)

                g_shape_raw = numpyro.sample("g_shape_raw", dist.HalfNormal(scale=1))
                g_shape = numpyro.deterministic("g_shape", global_g_shape * g_shape_raw)
                # sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
                # sigma_v = numpyro.deterministic(site.sigma_v, global_sigma_v * sigma_v_raw)

                # sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
                sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                sigma_L = numpyro.deterministic(site.sigma_L, global_sigma_L * sigma_L_raw)

                # sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                # sigma_H = numpyro.deterministic(site.sigma_H, global_sigma_H * sigma_H_raw)

                sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                sigma_g_1 = numpyro.deterministic("sigma_g_1", global_sigma_g_1 * sigma_g_1_raw)

                sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                sigma_g_2 = numpyro.deterministic("sigma_g_2", global_sigma_g_2 * sigma_g_2_raw)

                sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
                sigma_p = numpyro.deterministic("sigma_p", global_sigma_p * sigma_p_raw)

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    """ Priors """
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=mu_a, rate=1))
                    a = numpyro.deterministic(site.a, (1 / sigma_a) * a_raw)

                    # b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))
                    b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    b = numpyro.deterministic(site.b, sigma_b * b_raw)

                    g = numpyro.sample("g", dist.Beta(1, g_shape))
                    # v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    # v = numpyro.deterministic(site.v, sigma_v * v_raw)

                    # L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                    L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    L = numpyro.deterministic(site.L, sigma_L * L_raw)

                    # g_1 = numpyro.sample(
                    #     site.g_1, dist.HalfCauchy(2.5)
                    # )
                    # g_2 = numpyro.sample(
                    #     site.g_2, dist.HalfCauchy(2.5)
                    # )

                    g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    g_1 = numpyro.deterministic(site.g_1, sigma_g_1 * g_1_raw)

                    g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    g_2 = numpyro.deterministic(site.g_2, sigma_g_2 * g_2_raw)

                    p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                    p = numpyro.deterministic("p", sigma_p * p_raw)

                    # p = numpyro.sample("p", dist.HalfNormal(scale=5))


        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.data, n_data, dim=-2):
                """ Model """
                mu = numpyro.deterministic(
                    site.mu,
                    L[feature0, subject] - \
                    jnp.log(jnp.maximum(
                        g[feature0, subject],
                        jnp.exp(-jax.nn.relu(
                            b[feature0, subject] * (intensity - a[feature0, subject])
                        ))
                    ))
                )
                beta = numpyro.deterministic(
                    site.beta,
                    g_1[feature0, subject] + g_2[feature0, subject] * jnp.power(1 / (mu + 1), p[feature0, subject])
                )
                # beta = numpyro.deterministic(
                #     site.beta,
                #     g_1[feature0, subject] + g_2[feature0, subject] * jnp.power(1 / (mu + 1), 1)
                # )

                """ Observation """
                numpyro.sample(
                    site.obs,
                    dist.Gamma(concentration=mu * beta, rate=beta),
                    obs=response_obs
                )


In [5]:
import numpyro.distributions as dist
from hbmep.model import Baseline


class SaturatedReLU(Baseline):
    LINK = "saturated_relu"

    def __init__(self, config: Config):
        super(SaturatedReLU, self).__init__(config=config)
        # self.combination_columns = self.features + [self.subject]

    def _model(self, subject, features, intensity, response_obs=None):
        subject, n_subject = subject
        features, n_features = features
        intensity, n_data = intensity

        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)
        n_feature0 = n_features[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            # global_sigma_g_1 = numpyro.sample("global_sigma_g_1", dist.HalfNormal(100))
            # global_sigma_g_2 = numpyro.sample("global_sigma_g_2", dist.HalfNormal(100))

            # global_sigma_b = numpyro.sample("global_sigma_b", dist.HalfNormal(100))
            # global_sigma_v = numpyro.sample("global_sigma_v", dist.HalfNormal(100))
            # global_g_shape = numpyro.sample("global_g_shape", dist.HalfNormal(100))

            # global_sigma_L = numpyro.sample("global_sigma_L", dist.HalfNormal(1))

            global_sigma_p = numpyro.sample("global_sigma_p", dist.HalfNormal(100))

            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                """ Hyper-priors """
                mu_a = numpyro.sample(site.mu_a, dist.HalfNormal(scale=10))
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(scale=10 / 150))

                sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.1))
                # sigma_b_raw = numpyro.sample("sigma_b_raw", dist.HalfNormal(scale=1))
                # sigma_b = numpyro.deterministic(site.sigma_b, global_sigma_b * sigma_b_raw)

                g_shape = numpyro.sample("g_shape", dist.HalfNormal(5.0))
                # g_shape_raw = numpyro.sample("g_shape_raw", dist.HalfNormal(scale=1))
                # g_shape = numpyro.deterministic("g_shape", global_g_shape * g_shape_raw)
                # sigma_v_raw = numpyro.sample("sigma_v_raw", dist.HalfNormal(scale=1))
                # sigma_v = numpyro.deterministic(site.sigma_v, global_sigma_v * sigma_v_raw)

                sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
                # sigma_L_raw = numpyro.sample("sigma_L_raw", dist.HalfNormal(scale=1))
                # sigma_L = numpyro.deterministic(site.sigma_L, global_sigma_L * sigma_L_raw)

                # sigma_H_raw = numpyro.sample("sigma_H_raw", dist.HalfNormal(scale=1))
                # sigma_H = numpyro.deterministic(site.sigma_H, global_sigma_H * sigma_H_raw)

                # sigma_g_1_raw = numpyro.sample("sigma_g_1_raw", dist.HalfNormal(scale=1))
                # sigma_g_1 = numpyro.deterministic("sigma_g_1", global_sigma_g_1 * sigma_g_1_raw)

                # sigma_g_2_raw = numpyro.sample("sigma_g_2_raw", dist.HalfNormal(scale=1))
                # sigma_g_2 = numpyro.deterministic("sigma_g_2", global_sigma_g_2 * sigma_g_2_raw)

                sigma_p_raw = numpyro.sample("sigma_p_raw", dist.HalfNormal(scale=1))
                sigma_p = numpyro.deterministic("sigma_p", global_sigma_p * sigma_p_raw)

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    """ Priors """
                    a_raw = numpyro.sample("a_raw", dist.Gamma(concentration=mu_a, rate=1))
                    a = numpyro.deterministic(site.a, (1 / sigma_a) * a_raw)

                    b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))
                    # b_raw = numpyro.sample("b_raw", dist.HalfNormal(scale=1))
                    # b = numpyro.deterministic(site.b, sigma_b * b_raw)

                    g = numpyro.sample("g", dist.Beta(1, g_shape))
                    # v_raw = numpyro.sample("v_raw", dist.HalfNormal(scale=1))
                    # v = numpyro.deterministic(site.v, sigma_v * v_raw)

                    L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                    # L_raw = numpyro.sample("L_raw", dist.HalfNormal(scale=1))
                    # L = numpyro.deterministic(site.L, sigma_L * L_raw)

                    g_1 = numpyro.sample(
                        site.g_1, dist.HalfCauchy(2.5)
                    )
                    g_2 = numpyro.sample(
                        site.g_2, dist.HalfCauchy(2.5)
                    )

                    # g_1_raw = numpyro.sample("g_1_raw", dist.HalfCauchy(scale=1))
                    # g_1 = numpyro.deterministic(site.g_1, sigma_g_1 * g_1_raw)

                    # g_2_raw = numpyro.sample("g_2_raw", dist.HalfCauchy(scale=1))
                    # g_2 = numpyro.deterministic(site.g_2, sigma_g_2 * g_2_raw)

                    p_raw = numpyro.sample("p_raw", dist.HalfNormal(scale=1))
                    p = numpyro.deterministic("p", sigma_p * p_raw)

                    # p = numpyro.sample("p", dist.HalfNormal(scale=5))

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.data, n_data, dim=-2):
                """ Model """
                mu = numpyro.deterministic(
                    site.mu,
                    L[feature0, subject] - \
                    jnp.log(jnp.maximum(
                        g[feature0, subject],
                        jnp.exp(-jax.nn.relu(
                            b[feature0, subject] * (intensity - a[feature0, subject])
                        ))
                    ))
                )
                beta = numpyro.deterministic(
                    site.beta,
                    g_1[feature0, subject] + g_2[feature0, subject] * jnp.power(1 / (mu + 1), p[feature0, subject])
                )
                # beta = numpyro.deterministic(
                #     site.beta,
                #     g_1[feature0, subject] + g_2[feature0, subject] * jnp.power(1 / (mu + 1), 1)
                # )

                """ Observation """
                numpyro.sample(
                    site.obs,
                    dist.Gamma(concentration=mu * beta, rate=beta),
                    obs=response_obs
                )

In [3]:
toml_path = "/home/vishu/repos/hbmep-paper/configs/paper/rats/J_RCML_000/link-comparison/saturated_relu.toml"

config = Config(toml_path=toml_path)
config.BUILD_DIR = os.path.join(config.BUILD_DIR, "previous")
# config.MCMC_PARAMS["num_warmup"] = 4000
# config.MCMC_PARAMS["num_samples"] = 6000
# config.MCMC_PARAMS["thinning"] = 4

model = SaturatedReLU(config=config)


2023-10-27 01:30:10,141 - hbmep.config - INFO - Verifying configuration ...
2023-10-27 01:30:10,142 - hbmep.config - INFO - Success!
2023-10-27 01:30:10,156 - hbmep.model.baseline - INFO - Initialized model with saturated_relu link


In [4]:
src = "/home/vishu/data/hbmep-processed/J_RCML_000/data.csv"
df = pd.read_csv(src)

# model.mep_matrix_path = None
subset = ["amap01"]
ind = df[model.subject].isin(subset)

# subset = [("-C5L", "amap01"), ("-C6L", "amap01")]
# subset = [("-C6L", "amap01")]
# subset = [("-C7L", "amap01")]
# subset = [("amap01", "-C6L"), ("amap01", "-C7L")]
# ind = df[model.combination_columns].apply(tuple, axis=1).isin(subset)

df = df[ind].reset_index(drop=True).copy()

df, encoder_dict = model.load(df=df)


2023-10-27 01:30:10,227 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/saturated_relu/previous
2023-10-27 01:30:10,227 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/saturated_relu/previous
2023-10-27 01:30:10,228 - hbmep.dataset.core - INFO - Processing data ...
2023-10-27 01:30:10,229 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [6]:
mcmc, posterior_samples = model.run_inference(df=df)


2023-10-27 01:30:10,293 - hbmep.model.baseline - INFO - Running inference with saturated_relu ...


  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

2023-10-27 01:40:04,268 - hbmep.utils.utils - INFO - func:run_inference took: 9 min and 53.97 sec


In [8]:
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(df=prediction_df, posterior_samples=posterior_samples)

model.render_recruitment_curves(df=df, encoder_dict=encoder_dict, posterior_samples=posterior_samples, prediction_df=prediction_df, posterior_predictive=posterior_predictive)
model.render_predictive_check(df=df, encoder_dict=encoder_dict, prediction_df=prediction_df, posterior_predictive=posterior_predictive)


2023-10-27 01:52:47,967 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.01 sec


2023-10-27 01:52:50,805 - hbmep.utils.utils - INFO - func:predict took: 2.84 sec
2023-10-27 01:52:50,814 - hbmep.model.baseline - INFO - Rendering ...
2023-10-27 01:52:54,123 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/saturated_relu/previous/recruitment_curves.pdf
2023-10-27 01:52:54,123 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 3.32 sec
2023-10-27 01:52:54,138 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-10-27 01:52:57,941 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/paper/rats/J_RCML_000/link-comparison/saturated_relu/previous/posterior_predictive_check.pdf
2023-10-27 01:52:57,942 - hbmep.utils.utils - INFO - func:_render_predictive_check took: 3.82 sec
2023-10-27 01:52:57,942 - hbmep.utils.utils - INFO - func:render_predictive_check took: 3.82 sec


In [7]:
mcmc.print_summary(prob=.95)



                         mean       std    median      2.5%     97.5%     n_eff     r_hat
       L_raw[0,0,0]      1.09      0.29      1.07      0.56      1.71      7.82      1.21
       L_raw[1,0,0]      0.79      0.34      0.85      0.26      1.27      2.63      2.07
       L_raw[2,0,0]      0.72      0.21      0.69      0.45      1.14      6.53      1.24
       L_raw[3,0,0]      0.81      0.27      0.83      0.38      1.25      3.28      1.65
       L_raw[4,0,0]      0.90      0.33      0.96      0.42      1.39      2.78      1.92
       L_raw[5,0,0]      0.76      0.19      0.74      0.36      1.12      6.56      1.32
       L_raw[6,0,0]      0.78      0.35      0.86      0.24      1.27      2.56      2.15
       L_raw[7,0,0]      0.50      0.33      0.53      0.04      1.03      2.92      1.77
       L_raw[8,0,0]      0.94      0.16      0.92      0.62      1.27    170.65      1.03
       L_raw[9,0,0]      0.87      0.25      0.83      0.55      1.39      6.29      1.25
      L_r

In [12]:
numpyro_data = az.from_numpyro(mcmc)

""" Model evaluation """
logger.info("Evaluating model ...")

score = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")

score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")

2023-10-26 23:47:11,657 - __main__ - INFO - Evaluating model ...
  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
2023-10-26 23:47:14,768 - __main__ - INFO - ELPD LOO (Log): 1078.16
See http://arxiv.org/abs/1507.04544 for details
2023-10-26 23:47:14,811 - __main__ - INFO - ELPD WAIC (Log): 1061.09


In [None]:
import pickle

dest = os.path.join(model.build_dir, "inference.pkl")
with open(dest, "wb") as f:
    pickle.dump((model, mcmc, posterior_samples), f)


In [None]:
dest = os.path.join(model.build_dir, "numpyro_data.nc")
az.to_netcdf(numpyro_data, dest)


BlockingIOError: [Errno 11] Unable to create file (unable to lock file, errno = 11, error message = 'Resource temporarily unavailable')

In [None]:
# import pickle

# with open(dest, "rb") as g:
#     model_, mcmc_, posterior_samples_ = pickle.load(g)
