In [1]:
import os
import pickle
import logging

In [2]:
import numpy as np
import pandas as pd
from joblib import Parallel, delayed

In [3]:
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
from hbmep.config import Config
from hbmep.model.utils import Site as site

In [5]:
from hbmep_paper.utils import setup_logging, run_svi
from models import (
    # MixtureModel,
    RectifiedLogistic,
    Logistic5,
    Logistic4,
    ReLU
)

In [6]:
# Notebook-level logger, named after the running module.
logger = logging.getLogger(__name__)
# Log level handed to setup_logging when the log file is configured.
LEVEL = logging.INFO

In [7]:
# NOTE(review): the paths below are absolute and machine-specific, so the
# notebook only runs unmodified on the original author's machine.
# TOML file used to construct the hbmep Config.
TOML_PATH = "/home/vishu/repos/hbmep-paper/configs/rats/J_RCML_000.toml"
# CSV file read with pd.read_csv before being passed to model.load.
DATA_PATH = "/home/vishu/data/hbmep-processed/J_RCML_000/data.csv"
# Assigned to config.FEATURES. The nested list presumably makes hbmep combine
# the two columns into a single feature — TODO confirm against hbmep docs.
FEATURES = [["participant", "compound_position"]]
# Alternative kept for reference: the two columns as a flat list.
# FEATURES = ["participant", "compound_position"]
# Assigned to config.RESPONSE.
RESPONSE = ["LBiceps", "LECR"]
# Parent output directory; each model writes to BUILD_DIR/<Model.NAME>.
BUILD_DIR = "/home/vishu/repos/hbmep-paper/reports/tms/fn-comparison/testing"


In [8]:
# Select the single model class to fit. The other imported candidates
# (ReLU, Logistic4, Logistic5) can be swapped in here to compare them.
Model = RectifiedLogistic

# Build model: start from the TOML config, then override the features,
# responses, output directory and MCMC lengths for this debugging run.
config = Config(toml_path=TOML_PATH)
config.FEATURES = FEATURES
config.RESPONSE = RESPONSE
# Each model class gets its own sub-directory named after Model.NAME.
config.BUILD_DIR = os.path.join(BUILD_DIR, Model.NAME)
config.MCMC_PARAMS["num_warmup"] = 5000
config.MCMC_PARAMS["num_samples"] = 1000
model = Model(config=config)

# Setup logging: create the build directory first so the log file
# ("loo-debug") can be written inside it.
model._make_dir(config.BUILD_DIR)
setup_logging(
    dir=model.build_dir,
    fname="loo-debug",
    level=LEVEL
)

2024-02-19 11:22:03,326 - hbmep_paper.utils.utils - INFO - Logging to /home/vishu/repos/hbmep-paper/reports/tms/fn-comparison/testing/rectified_logistic/loo-debug.log


In [10]:
# Read the raw CSV and let the model preprocess/encode it.
df = pd.read_csv(DATA_PATH)
df, encoder_dict = model.load(df=df)

# Restrict to rows whose first feature column holds the values 0 or 1.
ind = df[model.features[0]].isin([0, 1])
df = df.loc[ind].reset_index(drop=True).copy()

# Fit the model with MCMC.
# (An SVI alternative was explored earlier:
#  svi_result, posterior_samples = run_svi(df=df, model=model, steps=20000, lr=1e-2)
#  followed by plotting np.array(svi_result.losses) to <build_dir>/losses.png.)
mcmc, posterior_samples = model.run_inference(df=df)

# Build the prediction grid and draw posterior predictive samples on it.
prediction_df = model.make_prediction_dataset(df=df)
posterior_predictive = model.predict(
    df=prediction_df,
    posterior_samples=posterior_samples,
)

# Keyword arguments common to both rendering calls.
plot_kwargs = dict(
    df=df,
    encoder_dict=encoder_dict,
    prediction_df=prediction_df,
    posterior_predictive=posterior_predictive,
)
model.render_recruitment_curves(posterior_samples=posterior_samples, **plot_kwargs)
model.render_predictive_check(**plot_kwargs)



2024-02-19 11:22:08,789 - hbmep.dataset.core - INFO - Artefacts will be stored here - /home/vishu/repos/hbmep-paper/reports/tms/fn-comparison/testing/rectified_logistic
2024-02-19 11:22:08,800 - hbmep.dataset.core - INFO - Processing data ...
2024-02-19 11:22:08,802 - hbmep.utils.utils - INFO - func:load took: 0.01 sec
2024-02-19 11:22:08,803 - hbmep.model.baseline - INFO - Running inference with rectified_logistic ...


  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

2024-02-19 11:23:01,135 - hbmep.utils.utils - INFO - func:run_inference took: 52.33 sec
2024-02-19 11:23:01,139 - hbmep.utils.utils - INFO - func:make_prediction_dataset took: 0.00 sec
2024-02-19 11:23:02,185 - hbmep.utils.utils - INFO - func:predict took: 1.04 sec
2024-02-19 11:23:02,185 - hbmep.plotter.core - INFO - Rendering recruitment curves ...
2024-02-19 11:23:02,933 - hbmep.plotter.core - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/tms/fn-comparison/testing/rectified_logistic/recruitment_curves.pdf
2024-02-19 11:23:02,933 - hbmep.utils.utils - INFO - func:render_recruitment_curves took: 0.75 sec
2024-02-19 11:23:02,934 - hbmep.plotter.core - INFO - Rendering posterior predictive checks ...
2024-02-19 11:23:03,864 - hbmep.plotter.core - INFO - Saved to /home/vishu/repos/hbmep-paper/reports/tms/fn-comparison/testing/rectified_logistic/posterior_predictive_check.pdf
2024-02-19 11:23:03,864 - hbmep.utils.utils - INFO - func:predictive_checks_renderer took: 0.93 sec
2024-0

In [11]:
from numpyro.infer.util import log_density
# Keep only the first 20 entries along the leading axis of each model input.
first_rows = (slice(None, 20), Ellipsis)

intensity, features = model._get_regressors(df)
(response,) = model._get_response(df)  # exactly one response array is expected

intensity = intensity[first_rows]
features = features[first_rows]
response = response[first_rows]

# Log the shapes of the truncated arrays.
logger.info(f"intensity: {intensity.shape}")
logger.info(f"features: {features.shape}")
logger.info(f"response: {response.shape}")


2024-02-19 09:50:20,694 - __main__ - INFO - intensity: (20, 1)
2024-02-19 09:50:20,695 - __main__ - INFO - features: (20, 1)
2024-02-19 09:50:20,695 - __main__ - INFO - response: (20, 2)


In [64]:
from numpyro import handlers
from jax import vmap
from jax import random
from jax.scipy.special import logsumexp
import jax.numpy as jnp


def log_likelihood(rng_key, params, model, *args, **kwargs):
    """Return the log-probability of the observed values at the "obs" site.

    The model is conditioned on ``params``, traced once with ``*args`` and
    ``**kwargs``, and the distribution stored at the trace's ``"obs"`` node
    is evaluated at that node's value.

    Args:
        rng_key: accepted for signature compatibility; not used in the body.
        params: mapping of site names to values used to condition the model.
        model: model callable to condition and trace.
        *args, **kwargs: forwarded unchanged to the model when tracing.

    Returns:
        The result of ``log_prob`` on the "obs" site (not reduced here).
    """
    # handlers.substitute was tried as an alternative to handlers.condition.
    conditioned = handlers.condition(model, params)
    obs_site = handlers.trace(conditioned).get_trace(*args, **kwargs)["obs"]
    dist, observed = obs_site["fn"], obs_site["value"]
    return dist.log_prob(observed)


def log_pred_density(rng_key, params, model, *args, **kwargs):
    """Sum of log-mean-exp of ``log_likelihood`` over the leading axis of ``params``.

    ``log_likelihood`` is vmapped over the leading axis of every entry in
    ``params`` (paired with one split of ``rng_key`` each). The resulting
    values are combined with ``logsumexp`` along axis 0, ``log(n)`` is
    subtracted, and everything that remains is summed to a scalar.

    Args:
        rng_key: PRNG key, split into ``n`` sub-keys (one per leading index).
        params: mapping of site names to arrays; ``n`` is read from the
            leading dimension of the first value.
        model: model callable forwarded to ``log_likelihood``.
        *args, **kwargs: forwarded unchanged to ``log_likelihood``.
    """
    num_draws = tuple(params.values())[0].shape[0]

    def per_draw(key, draw):
        # One evaluation of the likelihood for a single slice of params.
        return log_likelihood(key, draw, model, *args, **kwargs)

    keys = random.split(rng_key, num_draws)
    per_draw_vals = vmap(per_draw)(keys, params)
    log_mean = logsumexp(per_draw_vals, axis=0) - jnp.log(num_draws)
    return log_mean.sum()

# Debug check: zero out three sites in a shallow copy of the posterior and
# recompute the log predictive density, to see whether the result reacts.
_posterior_samples = posterior_samples.copy()
for zeroed_site in (site.alpha, site.mu, site.beta):
    _posterior_samples[zeroed_site] = 0 * _posterior_samples[zeroed_site]

regressors = model._get_regressors(df)
responses = model._get_response(df)
log_pred_density(
    random.PRNGKey(2),
    _posterior_samples,
    model._model,
    *regressors,
    *responses,
)


Array(299.8246208, dtype=float64)

In [35]:
# Condition the model on index 0 along the leading axis of every posterior site.
first_draw = {name: samples[0, ...] for name, samples in posterior_samples.items()}
m = handlers.condition(model._model, first_draw)


In [38]:
# Run the conditioned model once on df's regressors and responses, recording its trace.
model_inputs = (*model._get_regressors(df), *model._get_response(df))
trace = handlers.trace(m).get_trace(*model_inputs)
type(trace)

collections.OrderedDict

In [51]:
# NOTE(review): this compares the traced value of site.a against the first
# posterior draw of site.b (different sites) — confirm whether site.a was
# intended on both sides or this is a deliberate negative control.
traced_a = trace[site.a]["value"]
first_draw_b = posterior_samples[site.b][0, ...]
traced_a == first_draw_b

array([[False, False],
       [False, False]])

In [19]:
# Log predictive density using the unmodified posterior samples on df.
regressors = model._get_regressors(df)
responses = model._get_response(df)
log_pred_density(
    random.PRNGKey(2), posterior_samples, model._model, *regressors, *responses
)



Array(299.8246208, dtype=float64)

In [15]:
# Keep only rows with intensity strictly below 62, then renumber the index.
ind = df[model.intensity].lt(62)
df = df.loc[ind].reset_index(drop=True).copy()

In [17]:
# Preview only the first rows instead of dumping the whole frame into the output.
df.head()

Unnamed: 0,pulse_amplitude,pulse_train_frequency,pulse_period,pulse_duration,pulse_count,train_delay,channel1_1,channel1_2,channel1_3,channel1_4,...,channel1_laterality,channel1_segment,channel2_laterality,channel2_segment,compound_position,compound_charge_params,participant,subdir_pattern,charge_param_error,participant___compound_position
0,50,0.5,1,0.4,1,2,0,0,0,0,...,,,M,C5,-C5M,50-0-50-100,amap01,*J_RCML_000*,,1
1,50,0.5,1,0.4,1,2,0,0,0,0,...,,,L,C5,-C5L,50-0-50-100,amap01,*J_RCML_000*,,0
2,56,0.5,1,0.4,1,2,0,0,0,0,...,,,L,C5,-C5L,50-0-50-100,amap01,*J_RCML_000*,,0
3,56,0.5,1,0.4,1,2,0,0,0,0,...,,,M,C5,-C5M,50-0-50-100,amap01,*J_RCML_000*,,1


In [None]:
from sklearn.model_selection import KFold


In [65]:
# Number of folds.
k = 5


def _shuffle_group(group):
    """Return the group's rows in a shuffled order (fixed seed 0)."""
    return group.sample(frac=1, random_state=0)


# Shuffle rows within each feature group, then drop the group index.
shuffled = (
    df.groupby(by=model.features)
    .apply(_shuffle_group)
    .reset_index(drop=True)
)

# Integer number of rows per fold for each feature group.
fold_sizes = shuffled.groupby(by=model.features).apply(lambda group: len(group) // k)
fold_sizes


participant___compound_position
0    10
1    10
dtype: int64

In [16]:
# Recompute the log predictive density with the current contents of df.
model_inputs = (*model._get_regressors(df), *model._get_response(df))
log_pred_density(random.PRNGKey(2), posterior_samples, model._model, *model_inputs)


Array(299.8246208, dtype=float64)

In [13]:
# Leave-one-out score computed by ArviZ directly from the MCMC object;
# pointwise=True additionally keeps the per-observation values.
score = az.loo(mcmc, pointwise=True)
score




Computed from 4000 posterior samples and 204 observations log-likelihood matrix.

         Estimate       SE
elpd_loo   277.21    23.02
p_loo       22.62        -

------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.5]   (good)      197   96.6%
 (0.5, 0.7]   (ok)          3    1.5%
   (0.7, 1]   (bad)         4    2.0%
   (1, Inf)   (very bad)    0    0.0%

In [20]:
# WAIC score computed by ArviZ directly from the MCMC object;
# pointwise=True additionally keeps the per-observation values.
# Note: this rebinds `score`, replacing the LOO result.
score = az.waic(mcmc, pointwise=True)
score


See http://arxiv.org/abs/1507.04544 for details


Computed from 4000 posterior samples and 204 observations log-likelihood matrix.

          Estimate       SE
elpd_waic   278.12    22.95
p_waic       21.71        -


In [15]:
# Convert the NumPyro MCMC object into an ArviZ container and check its type.
numpyro_data = az.from_numpyro(mcmc)
type(numpyro_data)


arviz.data.inference_data.InferenceData

In [None]:
# Model evaluation
numpyro_data = az.from_numpyro(mcmc)
logger.info("Evaluating model ...")
score = az.loo(numpyro_data)
logger.info(f"ELPD LOO (Log): {score.elpd_loo:.2f}")
score = az.waic(numpyro_data)
logger.info(f"ELPD WAIC (Log): {score.elpd_waic:.2f}")

# # Save posterior
# dest = os.path.join(model.build_dir, "inference.pkl")
# with open(dest, "wb") as f:
#     pickle.dump((model, mcmc, posterior_samples), f)
# logger.info(dest)
