In [1]:
%reload_ext autoreload
%autoreload 2

import os
import logging
import multiprocessing
from pathlib import Path
from typing import Optional

import arviz as az
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns

import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import scipy.stats as stats
import numpyro
from numpyro.diagnostics import hpdi

from hbmep.config import Config
# from hbmep_paper.simulator import HierarchicalBayesianModel
from hbmep.model.utils import Site as site

# Run both JAX and NumPyro on the same backend.
PLATFORM = "cpu"
jax.config.update("jax_platforms", PLATFORM)
numpyro.set_platform(PLATFORM)

# Expose all but two cores as host devices (presumably to keep the machine
# responsive while sampling). Clamp to one device so that machines with one or
# two cores do not request zero or a negative number of devices.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)
numpyro.enable_x64()
numpyro.enable_validation()

logger = logging.getLogger(__name__)

In [2]:
def _plot(
    self,
    df: pd.DataFrame,
    dfs: list[pd.DataFrame],
    destination_path: str,
    threshold_dfs: Optional[pd.DataFrame] = None
):
    """
    Write a multi-page PDF comparing real data against simulated datasets.

    Each row of a page corresponds to one combination of
    ``self.combination_columns`` (taken from ``dfs[0]``). For every response
    in ``self.response`` the row holds one scatter plot of the real data
    followed by one scatter plot per simulated dataframe in ``dfs``.

    Parameters
    ----------
    self
        Model-like object providing ``_make_combinations``,
        ``combination_columns``, ``n_response``, ``response``,
        ``response_colors``, ``intensity``, ``subplot_cell_width`` and
        ``subplot_cell_height``.
    df
        Real data. Combinations that have no rows in ``df`` leave the
        "Real Data" panel empty.
    dfs
        Non-empty list of simulated dataframes; the first one defines which
        combinations are plotted.
    destination_path
        Path of the PDF to write. Missing parent directories are created.
    threshold_dfs
        Accepted for interface compatibility; currently not used.

    Returns
    -------
    None
    """
    # Layout: up to `n_fig_rows` combinations per page, and per response one
    # real-data column plus one column per simulated dataframe.
    combinations = self._make_combinations(df=dfs[0], columns=self.combination_columns)
    n_combinations = len(combinations)

    n_columns_per_response = 1 + len(dfs)

    n_fig_rows = 10
    n_fig_columns = n_columns_per_response * self.n_response

    n_pdf_pages = n_combinations // n_fig_rows
    if n_combinations % n_fig_rows:
        n_pdf_pages += 1
    logger.info("Rendering ...")

    # Row-wise combination keys do not depend on the page, combination or
    # response being drawn, so build them once instead of inside the loops.
    df_keys = df[self.combination_columns].apply(tuple, axis=1)
    sim_keys = [
        sim_df[self.combination_columns].apply(tuple, axis=1)
        for sim_df in dfs
    ]

    # Make sure the target directory exists before the PDF is opened.
    Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

    # The context manager closes the PDF even if plotting raises.
    with PdfPages(destination_path) as pdf:
        for page in range(n_pdf_pages):
            first_combination = page * n_fig_rows
            n_rows_current_page = min(
                n_fig_rows,
                n_combinations - first_combination
            )

            fig, axes = plt.subplots(
                n_rows_current_page,
                n_fig_columns,
                figsize=(
                    n_fig_columns * self.subplot_cell_width,
                    n_rows_current_page * self.subplot_cell_height
                ),
                constrained_layout=True,
                squeeze=False
            )

            try:
                for i in range(n_rows_current_page):
                    curr_combination = combinations[first_combination + i]

                    # Rows of the real and simulated data for this combination.
                    curr_df = df[df_keys.isin([curr_combination])].reset_index(drop=True)
                    curr_sims = [
                        sim_df[keys.isin([curr_combination])].reset_index(drop=True)
                        for sim_df, keys in zip(dfs, sim_keys)
                    ]

                    # `j` walks the columns of row `i` from left to right.
                    j = 0
                    for r, response in enumerate(self.response):
                        # Real data; the panel stays blank when there are no rows.
                        if curr_df.shape[0]:
                            ax = axes[i, j]
                            sns.scatterplot(
                                data=curr_df,
                                x=self.intensity,
                                y=response,
                                color=self.response_colors[r],
                                ax=ax
                            )
                            ax.set_title("Real Data - " + response + f" - {curr_combination}")
                        j += 1

                        # One panel per simulated dataframe.
                        for curr_sim in curr_sims:
                            ax = axes[i, j]
                            sns.scatterplot(
                                data=curr_sim,
                                x=self.intensity,
                                y=response,
                                color=self.response_colors[r],
                                ax=ax
                            )
                            ax.set_title(f"Simulated: {curr_combination}")
                            j += 1

                pdf.savefig(fig)
            finally:
                # Close this specific figure so pages do not accumulate in memory.
                plt.close(fig)

    plt.show()

    logger.info(f"Saved to {destination_path}")
    return




In [3]:
import numpyro.distributions as dist
from hbmep.model import Baseline
from hbmep_paper.utils.constants import HBM


class HierarchicalBayesianModel(Baseline):
    """
    Hierarchical Bayesian model of response as a function of intensity.

    Hyper-priors are drawn per (feature0 level, response); the curve
    parameters ``a, b, L, H, v`` and the noise parameters ``g_1, g_2`` are
    drawn per (subject, feature0 level, response). Observations follow a
    Gamma distribution whose mean is the curve value ``mu``.
    """
    LINK = HBM

    def __init__(self, config: Config):
        """Initialise the base model and record grouping columns and prior sites."""
        super(HierarchicalBayesianModel, self).__init__(config=config)
        # Columns that jointly identify one curve: all features plus the subject.
        self.combination_columns = self.features + [self.subject]
        # Names of the sites sampled inside the subject plate of `_model`.
        self.priors = {
            site.a, site.b, site.L, site.H, site.v, site.g_1, site.g_2
        }

    def _model(self, subject, features, intensity, response_obs=None):
        """
        NumPyro model definition.

        Parameters
        ----------
        subject
            Integer subject index per observation; used to index the first
            axis of the subject-level parameters.
        features
            Sequence of feature index arrays; only ``features[0]`` is used.
        intensity
            Stimulation intensity per observation. It is reshaped to a column
            and repeated ``self.n_response`` times along the last axis.
        response_obs
            Observed responses passed as ``obs`` to the likelihood, or
            ``None`` to sample from it.

        Returns
        -------
        The value of the observation sample site.
        """
        # One intensity column per response: shape (n_data, n_response).
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate("n_feature0", n_feature0, dim=-2):
                # Hyper-priors
                mu_a = numpyro.sample(
                    site.mu_a,
                    dist.TruncatedNormal(50, 50, low=0)
                )
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(50))

                sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(0.1))

                sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.5))
                sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(5))
                sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(10))

                with numpyro.plate(site.n_subject, n_subject, dim=-3):
                    # Priors
                    a = numpyro.sample(
                        site.a,
                        dist.TruncatedNormal(mu_a, sigma_a, low=0, high=100)
                    )
                    b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

                    L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                    H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                    v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                    g_1 = numpyro.sample(site.g_1, dist.Exponential(0.01))
                    g_2 = numpyro.sample(site.g_2, dist.Exponential(0.01))

        # Model: offset L plus a generalised logistic rectified at zero.
        # At intensity == a the bracketed term is exactly zero, so the curve
        # equals L there and cannot drop below L anywhere.
        mu = numpyro.deterministic(
            site.mu,
            L[subject, feature0]
            + jnp.maximum(
                0,
                -1
                + (H[subject, feature0] + 1)
                / jnp.power(
                    1
                    + (jnp.power(1 + H[subject, feature0], v[subject, feature0]) - 1)
                    * jnp.exp(-b[subject, feature0] * (intensity - a[subject, feature0])),
                    1 / v[subject, feature0]
                )
            )
        )
        # Gamma rate; with concentration = mu * beta the distribution mean is mu.
        beta = numpyro.deterministic(
            site.beta,
            g_1[subject, feature0] + g_2[subject, feature0] * (1 / mu) ** 2
        )

        # Observation
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.Gamma(concentration=mu * beta, rate=beta).to_event(1),
                obs=response_obs
            )

In [4]:
import pickle


# Three directory levels above the notebook's working directory
# (presumably the repository root -- confirm when moving the notebook).
root_path = Path(os.getcwd()).parent.parent.parent.absolute()
toml_path = os.path.join(root_path, "configs/human/tms/hbm-chains.toml")

CONFIG = Config(toml_path=toml_path)
MODEL = HierarchicalBayesianModel(config=CONFIG)

# The dataset lives outside the repository. Set HBMEP_DATA_CSV to point at a
# different copy; the default is the original machine-specific location.
src = os.environ.get(
    "HBMEP_DATA_CSV",
    "/home/vishu/data/hbmep-processed/human/tms/data.csv"
)
DF = pd.read_csv(src)

DF, ENCODER_DICT = MODEL.load(df=DF)

# NOTE(security): pickle.load can execute arbitrary code. Only load an
# inference.pkl that was produced locally by a trusted run.
dest = os.path.join(MODEL.build_dir, "inference.pkl")
with open(dest, "rb") as g:
    _, MCMC, POSTERIOR_SAMPLES = pickle.load(g)

2023-09-26 10:45:01,052 - hbmep.config - INFO - Verifying configuration ...
2023-09-26 10:45:01,052 - hbmep.config - INFO - Success!
2023-09-26 10:45:01,066 - hbmep.model.baseline - INFO - Initialized model with hierarchical_bayesian link
2023-09-26 10:45:01,071 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/human/tms/hbm-chains
2023-09-26 10:45:01,071 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep-paper/reports/human/tms/hbm-chains
2023-09-26 10:45:01,072 - hbmep.dataset.core - INFO - Processing data ...
2023-09-26 10:45:01,073 - hbmep.utils.utils - INFO - func:load took: 0.00 sec


In [5]:
# prediction_df = \
#     pd.DataFrame(np.arange(0, 10, 1), columns=[MODEL.subject]) \
#     .merge(
#         pd.DataFrame(np.arange(0, 1, 1), columns=MODEL.features),
#         how="cross"
#     ) \
#     .merge(
#         pd.DataFrame([0, 100], columns=[MODEL.intensity]),
#         how="cross"
#     )

# prediction_df = MODEL.make_prediction_dataset(df=prediction_df)

# prediction_df = MODEL.make_prediction_dataset(df=prediction_df)

# posterior_predictive = MODEL.predict(df=prediction_df, posterior_samples=POSTERIOR_SAMPLES)

# obs = posterior_predictive[site.obs]

# prediction_df[MODEL.response] = obs[0, ...]


In [5]:
# Duplicate the loaded data and shift the copy's subject codes by 5, then stack
# it under the original rows.
# NOTE(review): the offset 5 presumably equals the number of encoded subjects
# in DF, so that the shifted codes do not collide -- confirm.
shifted_subjects_df = DF.copy()
shifted_subjects_df[MODEL.subject] = 5 + shifted_subjects_df[MODEL.subject]
prediction_df = pd.concat([DF, shifted_subjects_df], ignore_index=True)

# Pass on only the posterior draws of sites that are not listed in MODEL.priors.
retained_posterior = {
    name: draws
    for name, draws in POSTERIOR_SAMPLES.items()
    if name not in MODEL.priors
}
posterior_predictive = MODEL.predict(
    df=prediction_df,
    posterior_samples=retained_posterior
)
obs = posterior_predictive[site.obs]

# Keep the first predictive draw as the simulated responses.
prediction_df[MODEL.response] = obs[0, ...]




2023-09-26 10:45:57,191 - hbmep.utils.utils - INFO - func:predict took: 9.51 sec


In [6]:
# Render the real data next to the simulated dataset into a PDF stored under
# the model's build directory.
pdf_location = ("power-analysis", "simulated_data.pdf")
dest = os.path.join(MODEL.build_dir, *pdf_location)
_plot(MODEL, DF, [prediction_df], dest)

2023-09-26 10:45:57,262 - __main__ - INFO - Rendering ...


2023-09-26 10:46:03,100 - __main__ - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/human/tms/hbm-chains/power-analysis/simulated_data.pdf
