In [1]:
%reload_ext autoreload
%autoreload 2

import os
import logging
import multiprocessing
from pathlib import Path

import numpyro
from hbmep.config import MepConfig
from hbmep.dataset import MepDataset
from hbmep.models import Model

# Run JAX on the CPU
numpyro.set_platform("cpu")

# Expose all but two CPU cores as JAX host devices (presumably so MCMC chains
# can run in parallel). Clamp to at least 1: on machines with <= 2 cores
# `cpu_count() - 2` would be 0 or negative, which is not a valid device count.
# This only takes effect if it runs before JAX initialises its backend.
cpu_count = max(1, multiprocessing.cpu_count() - 2)
numpyro.set_host_device_count(cpu_count)

# Enable 64-bit (double precision) floats in JAX
numpyro.enable_x64()

FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

#### Load config


In [2]:
# config.toml lives one directory above the notebook's working directory
root_path = Path.cwd().parent.absolute()
toml_path = os.path.join(root_path, "config.toml")
logger.info(f"Toml path - {toml_path}")

# Parse the TOML file; MepConfig validates it on construction
config = MepConfig(toml_path=toml_path)


2023-07-10 14:32:13,656 - __main__ - INFO - Toml path - /home/vishu/repos/hbmep/config.toml
2023-07-10 14:32:13,656 - hbmep.config - INFO - Verifying configuration ...
2023-07-10 14:32:13,657 - hbmep.config - INFO - Success!


#### Load data and preprocess

In [3]:
# Initialize the dataset handler from the validated config
data = MepDataset(config=config)

# Preprocess data: unpacks the processed DataFrame, a dict of encoders keyed by
# column name (e.g. 'participant'), and a third value this notebook discards
df, encoder_dict, _ = data.build()

2023-07-10 14:32:13,765 - hbmep.dataset.core - INFO - Initialized /home/vishu/repos/hbmep/reports/test_run_01 for storing artefacts
2023-07-10 14:32:13,766 - hbmep.dataset.core - INFO - Copied config to /home/vishu/repos/hbmep/reports/test_run_01
2023-07-10 14:32:13,766 - hbmep.dataset.core - INFO - Reading data from /home/vishu/data/mock.csv ...
2023-07-10 14:32:13,769 - hbmep.dataset.core - INFO - Processing data ...
2023-07-10 14:32:13,773 - hbmep.utils.utils - INFO - func:preprocess took: 0.00 sec
2023-07-10 14:32:13,773 - hbmep.utils.utils - INFO - func:build took: 0.01 sec


In [4]:
# Inspect which columns have an encoder
encoder_dict.keys()

dict_keys(['participant', 'compound_position'])

#### Visualize dataset

In [5]:
# Plot the processed dataset; the figure is saved as a PDF in the run's report directory
data.plot(
    df=df,
    encoder_dict=encoder_dict
)


2023-07-10 14:32:14,839 - hbmep.dataset.core - INFO - Saved to /home/vishu/repos/hbmep/reports/test_run_01/dataset.pdf
2023-07-10 14:32:14,839 - hbmep.utils.utils - INFO - func:plot took: 0.84 sec


#### Initialize model

In [6]:
import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from hbmep.config import MepConfig
from hbmep.models import Baseline
from hbmep.models.utils import Site as site
from hbmep.utils.constants import RECTIFIED_LOGISTIC

class RectifiedLogistic(Baseline):
    """Hierarchical rectified-logistic recruitment-curve model.

    For each (feature0 level, subject, response) the mean response at
    stimulation intensity ``x`` is

        mu = L + max(0, -1 + (H + 1) / (1 + ((1 + H)^v - 1) * exp(-b * (x - a)))^(1 / v))

    With ``b > 0`` this gives ``mu == L`` for ``x <= a``, so ``a`` acts as
    the threshold. As ``x`` grows, ``mu`` approaches ``L + H``.

    Observations follow ``Gamma(concentration=mu * beta, rate=beta)``, whose
    mean is ``mu``, with ``beta = g_1 + g_2 * (1 / mu)^p``.

    NOTE(review): all priors are hard-coded in ``_model``. An earlier draft
    intended to read them from ``config.PRIORS`` instead.
    """

    def __init__(self, config: MepConfig):
        super().__init__(config=config)
        self.name = RECTIFIED_LOGISTIC

    def _model(self, subject, features, intensity, response_obs=None):
        """Numpyro model for the rectified-logistic recruitment curve.

        Args:
            subject: integer-encoded subject index per observation.
            features: sequence of encoded feature arrays; only ``features[0]``
                is used.
            intensity: stimulation intensity per observation (flattened to a
                column and repeated for every response).
            response_obs: observed responses, presumably shaped
                ``(n_data, n_response)``; ``None`` to sample from the model.

        Returns:
            The value of the ``site.obs`` sample site.
        """
        # Repeat intensity for every response -> (n_data, n_response)
        intensity = intensity.reshape(-1, 1)
        intensity = np.tile(intensity, (1, self.n_response))

        feature0 = features[0].reshape(-1,)

        n_data = intensity.shape[0]
        # NOTE(review): plate sizes are the number of unique codes, so the
        # indexing below assumes codes are contiguous 0..n-1 — confirm.
        n_subject = np.unique(subject).shape[0]
        n_feature0 = np.unique(feature0).shape[0]

        # Plates on dims -3/-2/-1 lay parameters out as
        # (n_feature0, n_subject, n_response).
        with numpyro.plate(site.n_response, self.n_response, dim=-1):
            with numpyro.plate(site.n_subject, n_subject, dim=-2):
                # Hyper-priors (one per subject and response)
                mu_a = numpyro.sample(
                    site.mu_a,
                    dist.TruncatedNormal(150, 20, low=0)
                )
                sigma_a = numpyro.sample(site.sigma_a, dist.HalfNormal(20))

                sigma_b = numpyro.sample(site.sigma_b, dist.HalfNormal(.5))

                sigma_L = numpyro.sample(site.sigma_L, dist.HalfNormal(0.05))
                sigma_H = numpyro.sample(site.sigma_H, dist.HalfNormal(50))
                sigma_v = numpyro.sample(site.sigma_v, dist.HalfNormal(5))

                with numpyro.plate("n_feature0", n_feature0, dim=-3):
                    # Priors (one per feature0 level, subject and response)
                    a = numpyro.sample(
                        site.a,
                        dist.TruncatedNormal(mu_a, sigma_a, low=0)
                    )
                    b = numpyro.sample(site.b, dist.HalfNormal(sigma_b))

                    L = numpyro.sample(site.L, dist.HalfNormal(sigma_L))
                    H = numpyro.sample(site.H, dist.HalfNormal(sigma_H))
                    v = numpyro.sample(site.v, dist.HalfNormal(sigma_v))

                    g_1 = numpyro.sample(
                        site.g_1, dist.HalfCauchy(20)
                    )
                    g_2 = numpyro.sample(
                        site.g_2, dist.HalfCauchy(20)
                    )
                    # NOTE(review): literal site name, unlike the `site.*`
                    # constants used everywhere else.
                    p = numpyro.sample("p", dist.HalfNormal(10))

        # Gather each observation's parameters once; every *_i array has shape
        # (n_data, n_response), matching `intensity`.
        a_i = a[feature0, subject]
        b_i = b[feature0, subject]
        L_i = L[feature0, subject]
        H_i = H[feature0, subject]
        v_i = v[feature0, subject]
        g_1_i = g_1[feature0, subject]
        g_2_i = g_2[feature0, subject]
        p_i = p[feature0, subject]

        # Model: rectified logistic mean response
        mu = numpyro.deterministic(
            site.mu,
            L_i
            + jnp.maximum(
                0,
                -1
                + (H_i + 1)
                / jnp.power(
                    1 + (jnp.power(1 + H_i, v_i) - 1) * jnp.exp(-b_i * (intensity - a_i)),
                    1 / v_i
                )
            )
        )
        # Gamma rate; the noise level depends on mu through g_1, g_2 and p
        beta = numpyro.deterministic(
            site.beta,
            g_1_i + g_2_i * jnp.power(1 / mu, p_i)
        )

        # Gamma(concentration=mu * beta, rate=beta) has mean mu. to_event(1)
        # treats the response dimension as a single event per observation.
        with numpyro.plate(site.data, n_data):
            return numpyro.sample(
                site.obs,
                dist.Gamma(mu * beta, beta).to_event(1),
                obs=response_obs
            )


In [7]:
# Instantiate the custom rectified-logistic model defined above, rather than the
# generic `Model` imported in the first cell
model = RectifiedLogistic(config=config)


2023-07-10 14:32:15,170 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-07-10 14:32:15,170 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-07-10 14:32:15,171 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
2023-07-10 14:32:15,171 - jax._src.xla_bridge - INFO - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.


Prior predictive check: before fitting, we can draw samples from the model to see whether it correctly encodes our prior knowledge.

In [9]:
# Without posterior samples this renders the *prior* predictive check
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict
)

2023-07-10 14:32:43,422 - hbmep.models.baseline - INFO - Saved to /home/vishu/repos/hbmep/reports/test_run_01/prior_predictive_check.pdf
2023-07-10 14:32:43,422 - hbmep.utils.utils - INFO - func:render_predictive_check took: 3.87 sec


#### Run MCMC inference

In [10]:
# Fit the model with MCMC; keep the sampler for diagnostics and the draws for plotting
mcmc, posterior_samples = model.run_inference(
    df=df
)

2023-07-10 14:32:43,596 - hbmep.models.baseline - INFO - Running inference with baseline ...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-07-10 14:37:17,270 - hbmep.utils.utils - INFO - func:run_inference took: 4 min and 33.68 sec


#### Diagnostics

In [11]:
# Per-site posterior summary: mean, std, median, 95% interval, n_eff and r_hat
mcmc.print_summary(prob=0.95)


                mean       std    median      2.5%     97.5%     n_eff     r_hat
  H[0,0,0]      6.01      0.10      6.01      5.82      6.21  27135.74      1.00
  H[0,0,1]      3.78      0.06      3.78      3.66      3.89  35528.31      1.00
  H[0,1,0]     15.38     24.24      6.37      0.00     61.71  15822.76      1.00
  H[0,1,1]      2.51      7.24      0.85      0.00      9.28  13582.87      1.00
  H[1,0,0]      4.28      0.19      4.27      3.93      4.68  16206.16      1.00
  H[1,0,1]      1.49      0.04      1.49      1.42      1.56  27128.11      1.00
  H[1,1,0]     15.59     24.64      6.36      0.00     63.18  15866.57      1.00
  H[1,1,1]      2.44      6.42      0.85      0.00      9.22  12944.19      1.00
  H[2,0,0]     11.04     15.44      6.23      0.00     37.90  17663.81      1.00
  H[2,0,1]      7.37     11.79      3.72      0.00     26.94  16387.25      1.00
  H[2,1,0]      0.35      0.02      0.35      0.31      0.39  18217.75      1.00
  H[2,1,1]      0.56      0

#### Plot recruitment curves

In [33]:
import os
import logging
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import jax
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive

from hbmep.config import MepConfig
from hbmep.dataset import MepDataset
from hbmep.models.utils import Site as site
from hbmep.utils import (
    timing,
    floor,
    ceil,
    evaluate_posterior_mean,
    evaluate_hpdi_interval
)
from hbmep.utils.constants import (
    BASELINE,
    RECRUITMENT_CURVES,
    PRIOR_PREDICTIVE,
    POSTERIOR_PREDICTIVE
)

# Same logger object as the one created in the first cell: getLogger returns
# the same instance for the same name, so this rebinding is redundant.
logger = logging.getLogger(__name__)


def _plot_mep_panel(
    ax,
    traces: np.ndarray,
    intensities: np.ndarray,
    time: np.ndarray,
    auc_window: list[float],
    x_ticks: np.ndarray,
    min_intensity: float,
    max_intensity: float,
    intensity_label: str
):
    """Draw raw MEP traces for one response on ``ax``, each shifted along x by its intensity.

    Row ``k`` of ``traces`` is plotted against ``time`` at horizontal offset
    ``intensities[k]``. The two AUC window bounds are marked with dashed red
    horizontal lines.
    """
    for trace, intensity_value in zip(traces, intensities):
        # Hard-coded horizontal scale factor of 60 for the traces.
        # NOTE(review): presumably chosen so traces fit between intensity ticks — confirm.
        ax.plot(trace / 60 + intensity_value, time, color="green", alpha=.4)

    ax.axhline(
        y=auc_window[0],
        color="red",
        linestyle='--',
        alpha=.4,
        label=f"AUC Window {auc_window}"
    )
    ax.axhline(
        y=auc_window[1],
        color="red",
        linestyle='--',
        alpha=.4
    )

    ax.set_xticks(ticks=x_ticks)
    ax.tick_params(axis="x", rotation=90)
    ax.set_xlim(left=min_intensity, right=max_intensity)
    ax.set_ylim(bottom=-0.001, top=auc_window[1] + .005)

    ax.set_xlabel(f"{intensity_label}")
    ax.set_ylabel("Time")
    ax.legend(loc="upper right")
    ax.set_title("Motor Evoked Potential")


def _plot_combination_row(
    self,
    row_axes,
    combination,
    df: pd.DataFrame,
    row_keys: pd.Series,
    encoder_dict: dict,
    posterior_samples: dict,
    mat: Optional[np.ndarray],
    time: Optional[np.ndarray],
    auc_window: Optional[list[float]],
    n_columns_per_response: int
):
    """Fill one subplot row (``row_axes``, 1-D array of Axes) with all panels for ``combination``."""
    # Filter dataframe to the rows belonging to this combination
    ind = row_keys.isin([combination])
    temp_df = df[ind].reset_index(drop=True).copy()

    # Tickmarks: floor/ceil presumably round to a multiple of self.base
    min_intensity = floor(temp_df[self.intensity].min(), base=self.base)
    max_intensity = ceil(temp_df[self.intensity].max(), base=self.base)

    n_points = min(2000, ceil((max_intensity - min_intensity) / 5, base=100))
    x_space = np.linspace(min_intensity, max_intensity, n_points)
    x_ticks = np.arange(min_intensity, max_intensity, self.base)

    # Posterior-mean recruitment curve over the intensity grid
    predictions = self._predict(
        intensity=x_space,
        combination=combination,
        posterior_samples=posterior_samples
    )
    mu_posterior_mean = evaluate_posterior_mean(predictions[site.mu])

    # Threshold estimate
    threshold, threshold_posterior, hpdi_interval = \
        self._estimate_threshold(combination, posterior_samples)

    # The decoded combination is the same for every response, so compute it once
    combination_inverse = self._invert_combination(
        combination=combination,
        columns=self.columns,
        encoder_dict=encoder_dict
    )

    intensities = temp_df[self.intensity].values
    # Positional boolean mask: rows of `mat` are aligned with rows of `df`
    mask = ind.to_numpy()

    for r, response in enumerate(self.response):
        j = n_columns_per_response * r

        # Optional raw MEP panel occupies the first column for this response
        if mat is not None:
            _plot_mep_panel(
                row_axes[j],
                mat[mask, :, r],
                intensities,
                time,
                auc_window,
                x_ticks,
                min_intensity,
                max_intensity,
                self.intensity
            )
            j += 1

        data_ax, fit_ax, threshold_ax = row_axes[j], row_axes[j + 1], row_axes[j + 2]
        threshold_samples = threshold_posterior[:, r]
        hpdi_low, hpdi_high = hpdi_interval[0, r], hpdi_interval[1, r]

        # Observed data
        sns.scatterplot(
            data=temp_df,
            x=self.intensity,
            y=response,
            ax=data_ax
        )
        sns.scatterplot(
            data=temp_df,
            x=self.intensity,
            y=response,
            alpha=.4,
            ax=fit_ax
        )

        # Threshold KDE
        sns.kdeplot(
            x=threshold_samples,
            color="b",
            ax=fit_ax,
            alpha=.4
        )
        sns.kdeplot(
            x=threshold_samples,
            color="b",
            ax=threshold_ax
        )

        # Recruitment curve
        sns.lineplot(
            x=x_space,
            y=mu_posterior_mean[:, r],
            label="Mean Recruitment Curve",
            color="r",
            alpha=0.4,
            ax=fit_ax
        )

        # Threshold estimate and its interval
        threshold_ax.axvline(
            threshold[r],
            linestyle="--",
            color="r",
            label="Mean Posterior"
        )
        threshold_ax.axvline(
            hpdi_low,
            linestyle="--",
            color="g",
            label="95% HPDI"
        )
        threshold_ax.axvline(
            hpdi_high,
            linestyle="--",
            color="g"
        )

        # Labels
        title = f"{response} - {tuple(self.columns)}\nencoded: {combination}"
        title += f"\ndecoded: {tuple(combination_inverse)}"
        data_ax.set_title(title)
        fit_ax.set_title("Model Fit")

        skew = stats.skew(a=threshold_samples)
        kurt = stats.kurtosis(a=threshold_samples)

        title = f"TH: {threshold[r]:.2f}"
        title += f", CI: ({hpdi_low:.1f}, {hpdi_high:.1f})"
        title += f", LEN: {hpdi_high - hpdi_low:.1f}"
        title += r', $\overline{\mu_3}$'
        title += f": {skew:.1f}"
        title += f", K: {kurt:.1f}"
        threshold_ax.set_title(title)

        # Ticks
        for ax in (data_ax, fit_ax):
            ax.set_xticks(ticks=x_ticks)
            ax.tick_params(axis="x", rotation=90)
            ax.set_xlim(left=min_intensity, right=max_intensity)

        # Legends
        for ax in (fit_ax, threshold_ax):
            ax.legend(loc="upper left")


def render_recruitment_curves(
    self,
    df: pd.DataFrame,
    encoder_dict: dict,
    posterior_samples: dict,
    mat: Optional[np.ndarray] = None,
    time: Optional[np.ndarray] = None,
    auc_window: Optional[list[float]] = None
):
    """Render recruitment curves for every combination in ``df`` to a multi-page PDF.

    Written as a free function that takes the model as ``self``. It reads
    ``self.columns``, ``self.response``, ``self.intensity``,
    ``self.n_response``, ``self.base``, ``self.subplot_cell_width``,
    ``self.subplot_cell_height`` and ``self.recruitment_curves_path``. It
    calls the model's ``_make_combinations``, ``_predict``,
    ``_estimate_threshold`` and ``_invert_combination`` methods.

    Each combination of ``self.columns`` gets one row of subplots, with at
    most 10 rows per PDF page. For every response the row holds, in order:

    * only when ``mat`` is given: the raw MEP traces with the AUC window marked;
    * the observed data;
    * the observed data with the posterior-mean recruitment curve and the
      threshold-posterior KDE overlaid ("Model Fit");
    * the threshold-posterior KDE with its mean and HPDI bounds. Its title
      shows the threshold, interval, interval length, skewness and kurtosis.

    Args:
        self: fitted model providing the attributes/methods listed above.
        df: processed dataset.
        encoder_dict: encoders used to decode combinations for subplot titles.
        posterior_samples: posterior draws passed to ``_predict`` and
            ``_estimate_threshold``.
        mat: optional raw MEP traces, indexed as ``mat[rows, :, response]``;
            presumably ``(len(df), n_time, n_response)`` — confirm.
        time: time axis for ``mat``; required when ``mat`` is given.
        auc_window: ``[start, end]`` of the AUC window; required when ``mat``
            is given.

    Raises:
        AssertionError: if ``mat`` is given without ``time`` or ``auc_window``.
    """
    if mat is not None:
        assert time is not None, "`time` is required when `mat` is given"
        assert auc_window is not None, "`auc_window` is required when `mat` is given"

    # Setup pdf layout
    combinations = self._make_combinations(df=df, columns=self.columns)
    n_combinations = len(combinations)

    # Three panels per response, plus one for the raw MEP traces when given
    n_columns_per_response = 3 if mat is None else 4

    n_fig_rows = 10
    n_fig_columns = n_columns_per_response * self.n_response
    n_pdf_pages = -(-n_combinations // n_fig_rows)  # ceiling division

    # Per-row combination keys, computed once rather than once per combination
    row_keys = df[self.columns].apply(tuple, axis=1)

    # The context manager finalizes and closes the PDF even if plotting raises
    with PdfPages(self.recruitment_curves_path) as pdf:
        for page in range(n_pdf_pages):
            page_start = page * n_fig_rows
            n_rows_current_page = min(n_fig_rows, n_combinations - page_start)

            fig, axes = plt.subplots(
                n_rows_current_page,
                n_fig_columns,
                figsize=(
                    n_fig_columns * self.subplot_cell_width,
                    n_rows_current_page * self.subplot_cell_height
                ),
                constrained_layout=True,
                squeeze=False
            )

            for i in range(n_rows_current_page):
                _plot_combination_row(
                    self,
                    axes[i],
                    combinations[page_start + i],
                    df,
                    row_keys,
                    encoder_dict,
                    posterior_samples,
                    mat,
                    time,
                    auc_window,
                    n_columns_per_response
                )

            pdf.savefig(fig)
            # Close this page's figure explicitly so figures do not accumulate
            plt.close(fig)

    logger.info(f"Saved to {self.recruitment_curves_path}")


In [34]:
# Call the notebook-local plotting function, passing the fitted model as `self`
render_recruitment_curves(
    self=model,
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples
)

2023-07-10 14:51:00,164 - __main__ - INFO - Saved to /home/vishu/repos/hbmep/reports/test_run_01/recruitment_curves.pdf


#### Posterior Predictive Check

We can now pass the posterior samples to the `render_predictive_check` method to inspect how well our model explains the data.

In [31]:
# With posterior samples supplied this renders the *posterior* predictive check
model.render_predictive_check(
    df=df,
    encoder_dict=encoder_dict,
    posterior_samples=posterior_samples
)

2023-07-10 14:49:06,459 - hbmep.models.baseline - INFO - Saved to /home/vishu/repos/hbmep/reports/test_run_01/posterior_predictive_check.pdf
2023-07-10 14:49:06,462 - hbmep.utils.utils - INFO - func:render_predictive_check took: 12.61 sec
