In [1]:
%reload_ext autoreload
%autoreload 2

import os
import json
from pathlib import Path
import multiprocessing

import jax
import jax.numpy as jnp
import pandas as pd
import numpy as np
import numpyro
import numpyro.distributions as dist

from hbmep.config import Config
from hbmep.model import Baseline
from hbmep.model.utils import Site as site
from hbmep.utils.constants import RECTIFIED_LOGISTIC

# Run JAX / NumPyro on the CPU backend.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose several virtual host devices so MCMC chains can run in parallel.
# Two cores are left free for the rest of the system, but at least one device
# is always requested: on machines with <= 2 cores `cpu_count() - 2` would be
# 0 or negative, which is not a valid host device count.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
# Use 64-bit floats in JAX for the whole session.
numpyro.enable_x64()


In [2]:
# src = "/home/vishu/data/hbmep-processed/J_RCML_000/ground"

# df = pd.read_csv(os.path.join(src, "data.csv"))
# mat = np.load(os.path.join(src, "mat.npy"))
# auc_window = np.load(os.path.join(src, "auc_window.npy"))

# f = open(os.path.join(src, "muscles_map.json"))
# muscles_map = json.load(f)

# muscles_map

In [3]:
# """ Filter """
# sub = [1, 2]
# mat = mat[..., np.array(sub) - 1]

# muscles_map = {u: v for u, v in muscles_map.items() if int(u.split("_")[-1]) in sub}

# """ Save """
# dst = "/home/vishu/data/hbmep-processed/J_RCML_000/ground/1_2"

# df.to_csv(os.path.join(dst, "data.csv"), index=False)

# np.save(os.path.join(dst, "mat.npy"), mat)
# np.save(os.path.join(dst, "auc_window.npy"), np.array(auc_window))



In [4]:
# f = open(os.path.join(dst, "muscles_map.json"), "w")
# f.write(json.dumps(muscles_map))
# f.close;

In [5]:
# Config file path, relative to the directory three levels above the
# notebook's working directory.
CONFIG_RELPATH = "configs/ground_ADM_Biceps.toml"
# NOTE(review): machine-specific absolute output directory; adjust per machine.
BUILD_DIR = "/home/vishu/repos/hbmep-paper/reports/J_RCML_000/ground/1_2/nb"

root_path = Path(os.getcwd()).parent.parent.parent.absolute()
toml_path = os.path.join(root_path, CONFIG_RELPATH)

config = Config(toml_path=toml_path)
config.BUILD_DIR = BUILD_DIR


2023-07-26 09:16:33,697 - hbmep.config - INFO - Verifying configuration ...
2023-07-26 09:16:33,698 - hbmep.config - INFO - Success!


In [6]:
class RectifiedLogistic(Baseline):
    """Hierarchical rectified-logistic recruitment-curve model.

    Every curve parameter has one value per (feature0 level, subject,
    response). The scales of those parameters are hyper-priors with one value
    per (subject, response), shared across feature0 levels. Observed
    responses are Gamma-distributed with mean equal to the curve.
    """
    LINK = RECTIFIED_LOGISTIC

    def __init__(self, config: Config):
        super().__init__(config=config)

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model for the rectified-logistic link.

        Args:
            subject: Integer subject code per observation. Used both to size
                the subject plate (number of unique codes) and as a gather
                index, so codes are assumed to be 0..n_subject-1 — TODO
                confirm against the encoder used in ``load``.
            features: Per-observation feature codes; only ``features[0]`` is
                used, with the same 0..n-1 assumption as ``subject``.
            intensity: Stimulation intensity per observation.
            response_obs: Observed responses, presumably shaped
                (n_data, n_response); ``None`` to sample from the model.

        Returns:
            The value of the ``site.obs`` sample site.
        """
        # One intensity column per response so the curves broadcast over
        # responses: shape (n_data, n_response).
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        # Flatten both index arrays the same way: a column-shaped (n, 1)
        # `subject` would otherwise broadcast against the 1-D `feature0`
        # during fancy indexing and yield an (n, n, ...) result.
        feature0 = features[0].reshape(-1,)
        subject = np.asarray(subject).reshape(-1,)

        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        # NOTE: sample-site order below is significant for RNG key usage;
        # keep it stable.
        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                # Hyper-priors: one value per (subject, response).
                # mu_a / sigma_a locate and scale the threshold `a`
                # (presumably in the same units as `intensity`).
                mu_a = numpyro.sample(
                    site.mu_a,
                    dist.TruncatedNormal(150, 50, low=0)
                )
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(50))

                sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.1))

                sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
                sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
                sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(10))

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    # Priors: one value per (feature0 level, subject, response).
                    a = numpyro.sample(
                        site.a,
                        dist.TruncatedNormal(mu_a, sigma_a, low=0)
                    )
                    b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

                    L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                    H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                    v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                    # Noise-model parameters used to build the Gamma rate.
                    g_1 = numpyro.sample("g_1", dist.Exponential(1 / 100))
                    g_2 = numpyro.sample("g_2", dist.Exponential(1 / 100))

        # Gather each curve parameter once per observation -> (n_data, n_response).
        # NOTE(review): JAX clamps out-of-range gather indices instead of
        # raising, so non-contiguous codes would silently reuse parameters.
        a_obs = a[feature0, subject]
        b_obs = b[feature0, subject]
        L_obs = L[feature0, subject]
        H_obs = H[feature0, subject]
        v_obs = v[feature0, subject]

        # Rectified logistic: the inner term is exactly 0 at intensity == a and
        # negative below it (clipped to 0 by the maximum, since b >= 0), so mu
        # stays at the floor L up to the threshold. It then rises towards
        # L + H as intensity grows.
        mu = numpyro.deterministic(
            site.mu,
            L_obs
            + jnp.maximum(
                0,
                -1
                + (H_obs + 1)
                / jnp.power(
                    1
                    + (jnp.power(1 + H_obs, v_obs) - 1)
                    * jnp.exp(-b_obs * (intensity - a_obs)),
                    1 / v_obs
                )
            )
        )
        # A Gamma with concentration mu * beta and rate beta has mean mu and
        # variance mu / beta = mu**2 / (g_1 * mu + g_2).
        beta = numpyro.deterministic(
            site.beta,
            g_1[feature0, subject] + g_2[feature0, subject] * (1 / mu)
        )

        with numpyro.plate(site.data, n_data):
            # to_event(1): the responses of one observation form a single event.
            return numpyro.sample(
                site.obs,
                dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
                obs=response_obs
            )


model = RectifiedLogistic(config=config)

In [7]:
# df = pd.read_csv(model.csv_path)

# """ Filter """
# sub = ["amap01"]
# # sub = ["amap01", "amap02"]

# sub = ["amap01", "amap02", "amap03", "amap04"]
# # sub += ["amap05", "amap06", "amap07", "amap08"]

# ind = df.participant.isin(sub)

# df = df[ind].copy()
# df.reset_index(drop=True, inplace=True)
# print(df.shape)

# df, encoder_dict = model.load(df=df)


In [8]:
# Read and process the dataset (the log below shows the CSV path that was used).
# `encoder_dict` presumably holds the encoders applied to categorical columns —
# TODO confirm against hbmep's `load`.
df, encoder_dict = model.load()

2023-07-26 09:16:33,757 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/J_RCML_000/ground/1_2/nb
2023-07-26 09:16:33,757 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/J_RCML_000/ground/1_2/nb
2023-07-26 09:16:33,758 - hbmep.dataset.core - INFO - Reading data from /home/vishu/data/hbmep-processed/J_RCML_000/ground/1_2/data.csv ...
2023-07-26 09:16:33,765 - hbmep.dataset.core - INFO - Processing data ...
2023-07-26 09:16:33,766 - hbmep.utils.utils - INFO - func:load took: 0.01 sec


In [9]:
# Fit the model with MCMC. In the recorded run this took ~2 h 41 min on CPU.
# NOTE(review): no RNG key/seed is passed here; confirm run_inference seeds
# deterministically so that reruns reproduce the same posterior samples.
mcmc, posterior_samples = model.run_inference(df=df)


2023-07-26 09:16:33,775 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-07-26 11:57:38,639 - hbmep.utils.utils - INFO - func:run_inference took: 2 hr and 41 min


In [10]:
# Posterior summary with 95% intervals (the 2.5% / 97.5% columns), plus n_eff and r_hat.
mcmc.print_summary(prob=0.95)



                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      0.90      0.71      0.76      0.56      1.53    753.91      1.00
  H[0,0,1]      3.25      1.27      2.96      1.60      5.55   1682.45      1.00
  H[0,1,0]      3.03      3.88      1.61      0.05     10.80   1423.94      1.00
  H[0,1,1]      4.76      0.51      4.67      3.89      5.80   1684.52      1.00
  H[0,2,0]      2.41      0.08      2.40      2.25      2.57   1746.00      1.00
  H[0,2,1]      3.03      0.06      3.03      2.92      3.16   2494.95      1.00
  H[0,3,0]      2.77      2.49      2.04      0.19      7.60   2479.75      1.00
  H[0,3,1]      5.32      1.64      4.96      3.05      8.45   1072.96      1.00
  H[0,4,0]      1.71      1.79      1.10      0.24      5.16   1644.46      1.00
  H[0,4,1]      2.67      0.18      2.65      2.38      3.04    434.82      1.01
  H[0,5,0]      0.09      0.01      0.09      0.07      0.12   2211.97      1.00
  H[0,5,1]     13.70      5

In [11]:
# Plot the fitted recruitment curves; the log below shows where the PDF is saved.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-26 11:58:54,491 - hbmep.model.baseline - INFO - Rendering recruitment curves ...


2023-07-26 12:01:37,344 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/J_RCML_000/ground/1_2/nb/recruitment_curves.pdf
2023-07-26 12:01:37,344 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 2 min and 42.86 sec


In [12]:
# Posterior predictive check; the log below shows where the PDF is saved.
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-26 12:01:37,957 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...


2023-07-26 12:04:44,124 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/J_RCML_000/ground/1_2/nb/posterior_predictive_check.pdf
2023-07-26 12:04:44,126 - hbmep.utils.utils - INFO - func:render_predictive_check took: 3 min and 6.17 sec


In [13]:
# Persist the inference data (mcmc.nc). Per the log below, this also writes
# convergence diagnostics (diagnostics.csv) and reports ELPD LOO / WAIC.
model.save(mcmc=mcmc)

2023-07-26 12:04:44,274 - hbmep.model.baseline - INFO - Saving inference data ...


2023-07-26 12:05:47,232 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/J_RCML_000/ground/1_2/nb/mcmc.nc
2023-07-26 12:05:47,232 - hbmep.model.baseline - INFO - Rendering convergence diagnostics ...
2023-07-26 12:07:48,046 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/J_RCML_000/ground/1_2/nb/diagnostics.csv
2023-07-26 12:07:48,046 - hbmep.model.baseline - INFO - Evaluating model ...
  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
  return umr_sum(a, axis, dtype, out, keepdims, initial, where)
2023-07-26 12:08:03,860 - hbmep.model.baseline - INFO - ELPD LOO (Log): 8433.27
See http://arxiv.org/abs/1507.04544 for details
2023-07-26 12:08:04,827 - hbmep.model.baseline - INFO - ELPD WAIC (Log): 8519.66
2023-07-26 12:08:04,836 - hbmep.utils.utils - INFO - func:save took: 3 min and 20.56 sec
