In [1]:
%reload_ext autoreload
%autoreload 2

import os
import logging
from pathlib import Path
import multiprocessing

import jax.numpy as jnp
import pandas as pd
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.mixtures import MixtureGeneral

from hbmep.config import Config
from hbmep.dataset import Dataset
from hbmep.model import Baseline
from hbmep.model.utils import Site as site
from hbmep.utils.constants import RECTIFIED_LOGISTIC
# from hbmep_paper.models.rats.utils import load_data

# Run on CPU and expose several host devices so that chains can run in parallel.
numpyro.set_platform("cpu")
# Leave two cores free for the rest of the system, but never request fewer
# than one device: cpu_count() - 2 is <= 0 on 1- and 2-core machines.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()

# Timestamped log lines that include the emitting logger's name and level.
FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


#### Load config

In [2]:
# The repository root is taken to be three directory levels above the
# current working directory; the TOML config is resolved relative to it.
root_path = Path.cwd().parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/config.toml")

# Build the configuration object from the TOML file.
config = Config(toml_path=toml_path)

2023-07-14 15:01:21,854 - hbmep.config - INFO - Verifying configuration ...
2023-07-14 15:01:21,854 - hbmep.config - INFO - Success!


#### Load data and preprocess

In [3]:
# NOTE(review): absolute, machine-specific location — consider moving it into
# the config file so the notebook runs on other machines.
# Named `data_dir` rather than `dir` so the builtin `dir()` is not shadowed
# for the rest of the kernel session.
data_dir = "/home/vishu/data/hbmep-processed/J_RCML_000"

# Optional auxiliary arrays stored next to the CSV (currently unused):
# auc_window_path = os.path.join(data_dir, "auc_window.npy")
# time_path = os.path.join(data_dir, "time.npy")

# mat_path = os.path.join(data_dir, "mat.npy")
# mat = np.load(mat_path)

# auc_window = np.load(auc_window_path)
# time = np.load(time_path)

df_path = os.path.join(data_dir, "data.csv")
df = pd.read_csv(df_path)

data = Dataset(config)

# Filter: keep only the rows whose values in `data.columns` match one of the
# combinations listed in `sub`.
columns = data.columns
sub = [
    # ("amap06", "-C7L"),
    # ("amap06", "-C8L"),
    # ("amap08", "C6M-C7L"),
    ("amap08", "C7M-C7L")
]
ind = df[columns].apply(tuple, axis=1).isin(sub)

# Boolean indexing followed by reset_index already yields a new frame, so no
# explicit copy or in-place mutation is needed.
df = df[ind].reset_index(drop=True)
if df.empty:
    # Fail here with a clear message instead of deep inside build/inference.
    raise ValueError(f"No rows in {df_path} match the requested combinations: {sub}")

# mat = mat[ind, ...]

df, encoder_dict = data.build(df=df)


2023-07-14 15:01:22,137 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/reports/paper/run01
2023-07-14 15:01:22,137 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/reports/paper/run01
2023-07-14 15:01:22,137 - hbmep.dataset.core - INFO - Processing data ...
2023-07-14 15:01:22,139 - hbmep.utils.utils - INFO - func:build took: 0.00 sec


In [4]:
class RectifiedLogistic(Baseline):
    """Rectified-logistic curve model with a two-component Gamma mixture likelihood.

    Prior hyper-parameters are read from ``config.PRIORS`` and the link name is
    set to ``RECTIFIED_LOGISTIC``. ``self.n_response`` is used by ``_model``
    but is not set in this class (presumably provided by ``Baseline`` —
    confirm).
    """

    # (concentration1, concentration0) of the Beta prior on ``q``, the mixing
    # weight of the first mixture component.
    Q_PRIOR = (1, 24)
    # Rate ``r`` of the second mixture component, ``Gamma(mu * r, r)``; this
    # component has mean ``mu`` and variance ``mu / r``.
    SECOND_COMPONENT_RATE = 10000

    def __init__(self, config: Config):
        """Store the link name and the prior hyper-parameters from ``config.PRIORS``."""
        super().__init__(config=config)
        self.link = RECTIFIED_LOGISTIC

        self.mu_a = config.PRIORS[site.mu_a]
        self.sigma_a = config.PRIORS[site.sigma_a]

        self.sigma_b = config.PRIORS[site.sigma_b]

        self.sigma_L = config.PRIORS[site.sigma_L]
        self.sigma_H = config.PRIORS[site.sigma_H]
        self.sigma_v = config.PRIORS[site.sigma_v]

        self.g_1 = config.PRIORS[site.g_1]
        self.g_2 = config.PRIORS[site.g_2]

        self.p = config.PRIORS[site.p]

    def _model(self, subject, features, intensity, response_obs=None):
        """NumPyro model definition.

        Hyper-priors are drawn per (subject, response); the curve parameters
        ``a, b, L, H, v`` and the noise parameters ``g_1, g_2, p`` are drawn
        per (feature0, subject, response). The observation site is a
        two-component mixture: ``Gamma(mu * beta, beta)`` with weight ``q``
        and ``Gamma(mu * r, r)`` with weight ``1 - q``, where ``r`` is
        ``SECOND_COMPONENT_RATE``. Using the first component alone as the
        likelihood gives the non-mixture variant of this model.

        Parameters
        ----------
        subject
            Integer index per observation, used to index the sampled
            parameters along the subject plate.
        features
            Sequence whose first element holds the integer feature index per
            observation (flattened before use).
        intensity
            Intensity value per observation; reshaped to a column and tiled
            ``self.n_response`` times.
        response_obs
            Observed responses, or ``None`` to sample from the model.

        Returns
        -------
        The value of the ``site.obs`` sample site.
        """
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                # Hyper-priors
                mu_a = numpyro.sample(
                    site.mu_a,
                    dist.TruncatedNormal(self.mu_a[0], self.mu_a[1], low=0)
                )
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(self.sigma_a))

                sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(self.sigma_b))

                sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(self.sigma_L))
                sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(self.sigma_H))
                sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(self.sigma_v))

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    # Priors
                    a = numpyro.sample(
                        site.a,
                        dist.TruncatedNormal(mu_a, sigma_a, low=0)
                    )
                    b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

                    L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                    H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                    v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                    g_1 = numpyro.sample(
                        site.g_1, dist.HalfCauchy(self.g_1)
                    )
                    g_2 = numpyro.sample(
                        site.g_2, dist.HalfCauchy(self.g_2)
                    )

                    p = numpyro.sample(site.p, dist.HalfNormal(self.p))

        # Model: select each observation's parameters once instead of
        # re-indexing inside the formula.
        a_obs = a[feature0, subject]
        b_obs = b[feature0, subject]
        L_obs = L[feature0, subject]
        H_obs = H[feature0, subject]
        v_obs = v[feature0, subject]

        # Rectified logistic: offset L plus a generalised logistic clipped at
        # zero; the clipped term is exactly 0 at intensity == a.
        mu = numpyro.deterministic(
            site.mu,
            L_obs + jnp.maximum(
                0,
                -1 + (H_obs + 1) / jnp.power(
                    1 + (jnp.power(1 + H_obs, v_obs) - 1)
                    * jnp.exp(-b_obs * (intensity - a_obs)),
                    1 / v_obs
                )
            )
        )
        # Gamma rate of the first component, a function of the mean:
        # g_1 + g_2 * (1 / mu) ** p.
        beta = numpyro.deterministic(
            site.beta,
            g_1[feature0, subject]
            + g_2[feature0, subject] * jnp.power(1 / mu, p[feature0, subject])
        )

        # Mixture
        q = numpyro.sample("q", dist.Beta(*self.Q_PRIOR))

        mixing_distribution = dist.Categorical(probs=jnp.array([q, 1 - q]))

        second_rate = self.SECOND_COMPONENT_RATE
        component_distributions = [
            dist.Gamma(mu * beta, beta).to_event(1),
            dist.Gamma(mu * second_rate, second_rate).to_event(1)
        ]

        mixture = MixtureGeneral(
            mixing_distribution=mixing_distribution,
            component_distributions=component_distributions
        )

        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                mixture,
                obs=response_obs
            )



In [5]:
# Instantiate the RectifiedLogistic model with the loaded configuration.
model = RectifiedLogistic(config=config)

2023-07-14 15:01:22,826 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-07-14 15:01:22,826 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-07-14 15:01:22,826 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
2023-07-14 15:01:22,826 - jax._src.xla_bridge - INFO - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.


In [6]:
# data.plot(df=df, encoder_dict=encoder_dict)

In [7]:
# Fit the model to the filtered data; the returned MCMC object and posterior
# samples are used by the summary and rendering cells.
mcmc, posterior_samples = model.run_inference(df=df)


2023-07-14 15:01:23,248 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-07-14 15:02:04,653 - hbmep.utils.utils - INFO - func:run_inference took: 41.41 sec


In [11]:
# Posterior summary with 95% intervals.
# NOTE(review): in the recorded run, H, a and b have r_hat well above 1
# (1.74, 1.21, 1.28) with n_eff below 10, i.e. the chains have not converged
# for these sites. This cell was also executed out of order (In [11]).
mcmc.print_summary(prob=0.95)



                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      1.56      0.12      1.50      1.46      1.82      3.94      1.74
  L[0,0,0]      0.01      0.00      0.01      0.01      0.02   3290.34      1.00
  a[0,0,0]     53.86      3.53     54.46     47.33     59.54      9.21      1.21
  b[0,0,0]      0.07      0.02      0.07      0.03      0.10      8.58      1.28
g_1[0,0,0]     10.92      5.43     10.76      0.00     20.26    203.70      1.02
g_2[0,0,0]      4.64      4.62      3.06      0.00     14.01    688.74      1.01
  p[0,0,0]      0.39      0.47      0.25      0.00      1.13    106.33      1.05
         q      0.45      0.06      0.45      0.33      0.57   1209.72      1.00
  v[0,0,0]      1.18      1.31      0.72      0.00      3.97    111.71      1.05
  µ_a[0,0]    126.74     20.67    126.75     87.46    166.25   2492.62      1.00
  σ_H[0,0]     16.72     20.04      8.77      0.53     59.01   5177.34      1.00
  σ_L[0,0]      0.03      0

In [9]:
# Render the recruitment curves from the posterior samples; the log line
# reports where the PDF is saved.
model.render_recruitment_curves(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-14 15:02:05,098 - hbmep.model.baseline - INFO - Rendering recruitment curves ...
2023-07-14 15:02:08,041 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/reports/paper/run01/recruitment_curves.pdf
2023-07-14 15:02:08,041 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 2.94 sec


In [10]:
# Render the posterior predictive check from the same posterior samples; the
# log line reports where the PDF is saved.
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples,
)


2023-07-14 15:02:08,306 - hbmep.model.baseline - INFO - Rendering Posterior Predictive Check ...
2023-07-14 15:02:11,278 - hbmep.model.baseline - INFO - Saved to /home/vishu/repos/reports/paper/run01/posterior_predictive_check.pdf
2023-07-14 15:02:11,280 - hbmep.utils.utils - INFO - func:render_predictive_check took: 2.98 sec
