# CVAE diagnostics: calibration, dependence curves, and posterior predictive checks

This notebook demonstrates how to use the diagnostics utilities in
`multioutcome_cvae` for a fitted CVAE model:

- Global and per-outcome **calibration** plots (Bernoulli example)
- Scalar calibration summaries: **ECE** and **MCE**
- SHAP-style **dependence curves** for a single feature and outcome
- **Posterior predictive checks** (PPCs) for Gaussian and Poisson outcome families

We will:

1. Simulate data and fit a Bernoulli CVAE
2. Evaluate calibration and dependence
3. Simulate Gaussian and Poisson outcomes and run posterior predictive checks


In [0]:
# Optional setup (Databricks): uncomment the lines below to force a clean
# reinstall of the package.

# --- 1. Uninstall any existing copy ---
# %pip uninstall -y multioutcome-cvae multioutcome_cvae

# --- 2. Remove cached wheels / build artifacts ---
# import shutil, os, site
#
# # pip cache
# shutil.rmtree(os.path.expanduser("~/.cache/pip"), ignore_errors=True)
#
# # Python egg & wheel caches
# for path in [
#     "/databricks/driver/multioutcome_cvae.egg-info",
#     "/databricks/driver/multioutcome-cvae.egg-info",
# ]:
#     shutil.rmtree(path, ignore_errors=True)
#
# # Old installs in site-packages
# for sp in site.getsitepackages():
#     for pkg in ["multioutcome_cvae", "multioutcome-cvae"]:
#         shutil.rmtree(os.path.join(sp, pkg), ignore_errors=True)

# --- 3. Reinstall from GitHub and restart the Python process ---
# %pip install git+https://github.com/jarrod-dalton/multioutcome-cvae.git
# %restart_python

In [0]:
import numpy as np
import matplotlib.pyplot as plt

from multioutcome_cvae import (
    CVAETrainer,
    simulate_cvae_data,
    # diagnostics:
    plot_global_calibration,
    plot_per_outcome_calibration_grid,
    expected_calibration_error,
    maximum_calibration_error,
    plot_dependence_curve,
    posterior_predictive_check_gaussian,
    posterior_predictive_check_poisson,
    conditional_ppc_by_feature_decile,
)

## Simulate data and fit a Bernoulli CVAE

In [0]:
# 1. Simulate Bernoulli outcomes
X_b, Y_b, _ = simulate_cvae_data(
    n_samples=5000,
    n_features=5,
    n_outcomes=8,
    latent_dim=2,
    outcome_type="bernoulli",
    seed=1234,
)

# 2. Fit a Bernoulli CVAE
trainer_b = CVAETrainer(
    x_dim=X_b.shape[1],
    y_dim=Y_b.shape[1],
    latent_dim=4,
    outcome_type="bernoulli",
    hidden_dim=32,
    n_hidden_layers=2,
)

history_b = trainer_b.fit(
    X_train=X_b,
    Y_train=Y_b,
    num_epochs=20,
    verbose=False,
    seed=123,
)

# 3. Predicted probabilities (E[Y | X])
p_b = trainer_b.predict_mean(X_b, n_mc=30)

## Bernoulli calibration

We use the diagnostics helpers to examine:

- Global calibration (all outcomes flattened into a single vector)
- Per-outcome calibration (one subplot per `Y` dimension)
- Scalar metrics:
  - ECE (Expected Calibration Error)
  - MCE (Maximum Calibration Error)

For Bernoulli outcomes, `Y` is 0/1 and the model outputs probabilities
via `predict_mean(...)`; the outcome family is set by
`outcome_type="bernoulli"` when the `CVAETrainer` is constructed.
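
As a rough illustration of what the two scalar metrics measure, the sketch below computes ECE and MCE with equal-width probability bins in plain NumPy. The helper name `calibration_errors` is hypothetical and the binning rule is an assumption; `expected_calibration_error` and `maximum_calibration_error` in the package may bin or weight differently.

```python
import numpy as np

def calibration_errors(y_true, y_pred, n_bins=10):
    """Return (ECE, MCE) using equal-width bins on [0, 1].

    Hypothetical helper for illustration only; not the package implementation.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitizing against the interior edges yields bin ids 0 .. n_bins - 1
    bin_ids = np.digitize(y_pred, edges[1:-1])

    ece, mce = 0.0, 0.0
    for b in range(n_bins):
        mask = bin_ids == b
        if not mask.any():
            continue  # empty bins contribute nothing
        # Gap between observed event rate and mean predicted probability
        gap = abs(y_true[mask].mean() - y_pred[mask].mean())
        ece += mask.mean() * gap   # weight by the share of points in the bin
        mce = max(mce, gap)        # worst bin
    return ece, mce

# Tiny worked example: two occupied bins, each off by 0.25
y_demo = np.array([0, 0, 1, 1])
p_demo = np.array([0.25, 0.25, 0.75, 0.75])
ece_demo, mce_demo = calibration_errors(y_demo, p_demo, n_bins=10)
print(f"ECE={ece_demo:.3f}, MCE={mce_demo:.3f}")  # prints ECE=0.250, MCE=0.250
```

ECE averages the per-bin gaps weighted by bin occupancy, so it reflects typical miscalibration; MCE reports the single worst bin and is more sensitive to sparsely populated bins.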


In [0]:
# Flatten all outcomes for global calibration
y_flat = Y_b.ravel()
p_flat = p_b.ravel()

# Global calibration plot
plt.figure(figsize=(6, 6))
plot_global_calibration(
    y_true=y_flat,
    y_pred=p_flat,
    outcome_type="bernoulli",
    n_bins=10,
    alpha=0.05,
)
plt.show()

# Per-outcome calibration grid
fig = plot_per_outcome_calibration_grid(
    Y_true=Y_b,
    Y_pred=p_b,
    outcome_type="bernoulli",
    n_bins=10,
    alpha=0.05,
    max_cols=3,
)
plt.show()

# Scalar summaries: ECE & MCE
ece_b = expected_calibration_error(
    y_true=y_flat,
    y_pred=p_flat,
    n_bins=10,
    outcome_type="bernoulli",
)
mce_b = maximum_calibration_error(
    y_true=y_flat,
    y_pred=p_flat,
    n_bins=10,
    outcome_type="bernoulli",
)

print(f"Bernoulli ECE: {ece_b:.4f}")
print(f"Bernoulli MCE: {mce_b:.4f}")

## SHAP-style dependence curves

To visualize how a single feature in `X` influences one component of the
multivariate outcome `Y`, we use a simple partial-dependence style curve:

- Choose a feature index `k` in `X`
- Choose an outcome index `j` in `Y`
- Sweep `X[:, k]` over a grid of values (within a quantile range)
- For each grid value, replace that column in `X`, compute `E[Y | X]`,
  and average the chosen outcome dimension across observations

This produces a 1D curve showing how `E[Y_j | X]` changes as `X_k` varies.
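
The steps above can be sketched in a few lines of NumPy. This is a minimal illustration, assuming a generic `predict_fn` that maps an `(n, p)` feature matrix to an `(n, q)` matrix of predicted means; `dependence_curve`, `toy_predict`, and the toy weights are hypothetical names introduced here, and the internals of `plot_dependence_curve` may differ.

```python
import numpy as np

def dependence_curve(predict_fn, X, feature_index, outcome_index,
                     n_grid=50, quantile_range=(0.05, 0.95)):
    """Average prediction for one outcome as one feature sweeps a grid."""
    X = np.asarray(X, dtype=float)
    lo, hi = np.quantile(X[:, feature_index], quantile_range)
    grid = np.linspace(lo, hi, n_grid)
    curve = np.empty(n_grid)
    for g, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature_index] = value          # overwrite the chosen column
        curve[g] = predict_fn(X_mod)[:, outcome_index].mean()
    return grid, curve

# Toy stand-in for a fitted model: logistic means with fixed weights
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(500, 3))
W_toy = np.array([[1.5, -0.5],
                  [0.0,  1.0],
                  [0.3,  0.0]])

def toy_predict(X):
    return 1.0 / (1.0 + np.exp(-(X @ W_toy)))

grid, curve = dependence_curve(toy_predict, X_toy,
                               feature_index=0, outcome_index=0, n_grid=20)
```

Because the other columns keep their observed values, the curve averages over the empirical distribution of the remaining features rather than fixing them at a reference point.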


In [0]:
# Pick a feature and outcome index
feature_index = 0   # which column of X_b to vary
outcome_index = 0   # which column of Y_b to track

ax = plot_dependence_curve(
    trainer=trainer_b,
    X=X_b,
    feature_index=feature_index,
    outcome_index=outcome_index,
    n_grid=50,
    n_mc=30,
    quantile_range=(0.05, 0.95),
    feature_name=f"X[{feature_index}]",
    outcome_name=f"Y[{outcome_index}]",
)
plt.show()



## Posterior predictive checks (Gaussian)

We now demonstrate posterior predictive checks (PPCs) for Gaussian outcomes.

1. Simulate data via `simulate_cvae_data(...)`.
2. Fit a CVAE with the appropriate outcome type.
3. Call `posterior_predictive_check_gaussian(...)` to simulate replicated
   datasets and compare their summary statistics (mean and variance) with
   those of the observed data.

Histograms show where each observed summary falls within the posterior
predictive distribution.
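
The logic of such a check can be sketched independently of the package. The example below is a minimal illustration under stated assumptions: `ppc_mean_var` is a hypothetical helper, the "fitted model" is a toy Gaussian whose means are known exactly, and the tail-area values `p_mean` and `p_var` are additions for illustration rather than documented outputs of `posterior_predictive_check_gaussian`.

```python
import numpy as np

def ppc_mean_var(Y_obs, draw_replicate, n_rep=100):
    """Compare the global mean/variance of Y_obs with replicated datasets."""
    mean_rep = np.empty(n_rep)
    var_rep = np.empty(n_rep)
    for r in range(n_rep):
        Y_rep = draw_replicate()          # one simulated dataset from the model
        mean_rep[r] = Y_rep.mean()
        var_rep[r] = Y_rep.var()
    mean_obs, var_obs = Y_obs.mean(), Y_obs.var()
    return {
        "mean_obs": mean_obs, "var_obs": var_obs,
        "mean_rep": mean_rep, "var_rep": var_rep,
        # Share of replicates at or above the observed summary
        "p_mean": float(np.mean(mean_rep >= mean_obs)),
        "p_var": float(np.mean(var_rep >= var_obs)),
    }

# Toy data: Gaussian outcomes with known conditional means and unit noise
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(2000, 2))
W_toy = np.array([[1.0, -0.5,  0.2],
                  [0.3,  0.8, -1.0]])
mu_toy = X_toy @ W_toy
Y_toy = mu_toy + rng.normal(size=mu_toy.shape)

ppc_toy = ppc_mean_var(
    Y_toy,
    draw_replicate=lambda: mu_toy + rng.normal(size=mu_toy.shape),
    n_rep=200,
)
print(f"p_mean={ppc_toy['p_mean']:.2f}, p_var={ppc_toy['p_var']:.2f}")
```

Tail areas near 0 or 1 indicate that the observed summary sits in the extreme of the replicated distribution, which is the same signal the histograms convey visually.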


In [0]:
# 1. Simulate Gaussian data
X_g, Y_g, _ = simulate_cvae_data(
    n_samples=20000,
    n_features=5,
    n_outcomes=6,
    latent_dim=2,
    outcome_type="gaussian",
    seed=5432,
    noise_sd=1.0,
)

# 2. Fit Gaussian CVAE
trainer_g = CVAETrainer(
    x_dim=X_g.shape[1],
    y_dim=Y_g.shape[1],
    latent_dim=6,
    outcome_type="gaussian",
    hidden_dim=32,
    n_hidden_layers=2,
)

history_g = trainer_g.fit(
    X_train=X_g,
    Y_train=Y_g,
    num_epochs=20,
    verbose=False,
    seed=234,
)

# 3. Run Gaussian PPC (will plot histograms)
ppc_g = posterior_predictive_check_gaussian(
    trainer=trainer_g,
    X=X_g,
    Y=Y_g,
    n_rep=100,
    n_mc_params=20,
    plot=True,
)



### Count outcomes and the Poisson CVAE

Many multivariate outcomes in health-services and population research are **counts**:

- number of hospitalizations,
- number of ED visits,
- number of prescriptions filled,
- number of complications, etc.

A standard starting point is the **Poisson model**. For a single count outcome Y with covariates X,

- we model the **conditional mean** via a log link  
  log( E[Y | X] ) = η(X),
- and the Poisson assumption implies  
  Var(Y | X) = E[Y | X].

In practice, two things often hold simultaneously:

1. We care a lot about **E[Y | X]** (e.g., for prediction or counterfactuals).
2. The Poisson assumption for **Var(Y | X)** is usually **too optimistic** (overdispersion is common).

In this vignette, the **primary goal of the CVAE** is to learn a *flexible conditional mean* E[Y | X] and a realistic **joint dependence structure** across multiple count outcomes. The Poisson outcome family in `multioutcome_cvae` does exactly this:

- It models the *log-rate* log λ_ij(X_i, Z_i) for each outcome j, with a shared latent variable Z that couples the outcomes.
- It uses Monte Carlo over Z to produce **predictive means** and **predictive variances**:
  - E[Y_ij | X_i] ≈ E_Z[ λ_ij(X_i, Z) ]
  - Var(Y_ij | X_i) ≈ E_Z[ λ_ij(X_i, Z) ] + Var_Z[ λ_ij(X_i, Z) ].
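
The variance expression is the law of total variance applied to a Poisson whose rate depends on Z. The sketch below checks it numerically for a single observation and outcome; the log-rate `0.5 + 0.6 * z` and the standard-normal Z are arbitrary assumptions for illustration, not the package's decoder.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws = 200_000

# Latent draws and a toy log-linear rate (assumed form, illustration only)
z = rng.normal(size=n_draws)
lam = np.exp(0.5 + 0.6 * z)

# One Poisson count per latent draw: a sample from the mixture over Z
y = rng.poisson(lam)

# Monte Carlo predictive moments computed from the rates alone
pred_mean = lam.mean()              # E_Z[lambda]
pred_var = lam.mean() + lam.var()   # E_Z[lambda] + Var_Z[lambda]

print(f"simulated mean={y.mean():.3f}  predicted mean={pred_mean:.3f}")
print(f"simulated var ={y.var():.3f}  predicted var ={pred_var:.3f}")
```

Because Var_Z[λ] is non-negative, the marginal variance is at least as large as the mean, so integrating over Z lets a Poisson decoder represent some overdispersion even though each conditional distribution is equidispersed.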

Below, we:

1. Simulate multivariate Poisson data with a known latent structure.
2. Fit a Poisson CVAE to this data.
3. Use **posterior predictive checks (PPCs)** to compare:
   - the distribution of **global means** and **variances** under the fitted model,
   - versus the same summaries computed from the observed data.

The goal is not to get a “perfect” model (this is a small, synthetic example), but to demonstrate that:

- the Poisson CVAE recovers a reasonable conditional mean surface E[Y | X],
- and that posterior predictive checks give a transparent way to assess goodness-of-fit for counts.


In [0]:
# Poisson example: simulate data, fit CVAE, run posterior predictive checks

import numpy as np
import matplotlib.pyplot as plt

from multioutcome_cvae import (
    CVAETrainer,
    simulate_cvae_data,
    posterior_predictive_check_poisson,
)

# ------------------------------------------------------------
# 1. Simulate multivariate Poisson data
# ------------------------------------------------------------
# We use a latent structure so that outcomes are correlated, and
# a base_rate / rate_scale that lead to moderate counts (not all 0s).

X_p, Y_p, params_p = simulate_cvae_data(
    n_samples=20000,
    n_features=5,
    n_outcomes=6,
    latent_dim=2,
    outcome_type="poisson",
    seed=7203,
    base_rate=1.0,
    rate_scale=0.6,
)

mean_obs = Y_p.mean()
var_obs = Y_p.var()
print(f"Observed Poisson data: mean={mean_obs:.3f}, variance={var_obs:.3f}, var/mean={var_obs/mean_obs:.2f}")

# Per-outcome means/variances (helpful for diagnostics)
mean_obs_per_outcome = Y_p.mean(axis=0)
var_obs_per_outcome = Y_p.var(axis=0)
print("Per-outcome observed means:   ", np.round(mean_obs_per_outcome, 3))
print("Per-outcome observed variances:", np.round(var_obs_per_outcome, 3))

# ------------------------------------------------------------
# 2. Fit Poisson CVAE
# ------------------------------------------------------------
trainer_p = CVAETrainer(
    x_dim=X_p.shape[1],
    y_dim=Y_p.shape[1],
    latent_dim=6,
    outcome_type="poisson",
    hidden_dim=64,
    n_hidden_layers=2,
)

history_p = trainer_p.fit(
    X_train=X_p,
    Y_train=Y_p,
    num_epochs=30,      # modest training for vignette speed
    lr=1e-3,
    verbose=False,
    seed=345,
)

print(f"Final training loss (Poisson CVAE): {history_p['train_loss'][-1]:.3f}")

# ------------------------------------------------------------
# 3. Predictive means and simple scatter diagnostics
# ------------------------------------------------------------
params_pred = trainer_p.predict_params(X_p, n_mc=30)
lambda_pred = params_pred["rate"]      # shape (n_samples, n_outcomes)
var_y_pred = params_pred["var_y"]      # approximated Var(Y | X) under the mixture

mean_pred = lambda_pred.mean()
var_pred = var_y_pred.mean()
print(f"Average predicted mean:   {mean_pred:.3f}")
print(f"Average predicted Var(Y): {var_pred:.3f}")

# Per-outcome predicted means/variances
mean_pred_per_outcome = lambda_pred.mean(axis=0)
var_pred_per_outcome = var_y_pred.mean(axis=0)
print("Per-outcome predicted means:   ", np.round(mean_pred_per_outcome, 3))
print("Per-outcome predicted Var(Y):  ", np.round(var_pred_per_outcome, 3))

# Scatter plot: observed vs predicted per-outcome means
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

ax0 = axes[0]
ax0.scatter(mean_obs_per_outcome, mean_pred_per_outcome)
lim_min = min(mean_obs_per_outcome.min(), mean_pred_per_outcome.min())
lim_max = max(mean_obs_per_outcome.max(), mean_pred_per_outcome.max())
ax0.plot([lim_min, lim_max], [lim_min, lim_max], linestyle="--")
ax0.set_xlabel("Observed mean per outcome")
ax0.set_ylabel("Predicted mean per outcome")
ax0.set_title("Poisson CVAE: E[Y | X]\nobserved vs predicted")

# Scatter plot: observed vs predicted per-outcome variances
ax1 = axes[1]
ax1.scatter(var_obs_per_outcome, var_pred_per_outcome)
lim_min_v = min(var_obs_per_outcome.min(), var_pred_per_outcome.min())
lim_max_v = max(var_obs_per_outcome.max(), var_pred_per_outcome.max())
ax1.plot([lim_min_v, lim_max_v], [lim_min_v, lim_max_v], linestyle="--")
ax1.set_xlabel("Observed Var(Y) per outcome")
ax1.set_ylabel("Predicted Var(Y) per outcome")
ax1.set_title("Poisson CVAE: Var[Y | X]\nobserved vs predicted")

fig.tight_layout()

# ------------------------------------------------------------
# 4. Posterior predictive checks (global mean & variance)
# ------------------------------------------------------------
# We reuse the helper from utils_diagnostics. It:
#   - repeatedly draws 'replicated' datasets from the fitted model,
#   - computes summary statistics (mean, variance) for each replicate,
#   - and returns them (and optionally plots histograms).

ppc_p = posterior_predictive_check_poisson(
    trainer=trainer_p,
    X=X_p,
    Y=Y_p,
    n_rep=100,
    n_mc_params=20,
    plot=True,   # this will create histograms for mean and variance with the observed value overlaid
)

print("PPC (Poisson) summaries:")
print("  mean_obs:   ", ppc_p["mean_obs"])
print("  mean_rep μ: ", np.mean(ppc_p["mean_rep"]))
print("  mean_rep σ: ", np.std(ppc_p["mean_rep"]))
print("  var_obs:    ", ppc_p["var_obs"])
print("  var_rep μ:  ", np.mean(ppc_p["var_rep"]))
print("  var_rep σ:  ", np.std(ppc_p["var_rep"]))



### A note on Negative Binomial counts and overdispersion

In many real datasets, counts are **overdispersed** relative to the Poisson assumption:

- The Poisson model implies Var(Y | X) = E[Y | X].
- Empirically, we often see Var(Y | X) considerably larger than E[Y | X], especially in utilization data (hospitalizations, ED visits, etc.).

A common fix in generalized linear models is the **Negative Binomial (NB)** family, which adds a dispersion parameter. One convenient parameterization is:

- E[Y | X] = μ(X)
- Var(Y | X) = μ(X) + μ(X)^2 / r

where **r > 0** is a dispersion (or “size”) parameter: smaller r means more overdispersion. When r is large, Var(Y | X) ≈ μ(X) and the model behaves like a Poisson; when r is small, the variance can be much larger than the mean.
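
To make the role of r concrete, the sketch below draws Negative Binomial counts with NumPy and compares the sample variance with μ + μ²/r. The helper `sample_negative_binomial` is a hypothetical name; it relies on the fact that NumPy's `negative_binomial(n, p)` with n = r and p = r / (r + μ) has mean μ.

```python
import numpy as np

def sample_negative_binomial(mu, r, size, rng):
    """Draw NB counts with mean mu and variance mu + mu**2 / r."""
    p = r / (r + mu)  # success probability in NumPy's (n, p) parameterization
    return rng.negative_binomial(r, p, size=size)

rng = np.random.default_rng(1)
mu = 3.0
nb_summary = {}
for r in (0.5, 5.0, 500.0):
    y = sample_negative_binomial(mu, r, size=200_000, rng=rng)
    theory_var = mu + mu**2 / r
    nb_summary[r] = (y.mean(), y.var(), theory_var)
    print(f"r={r:>6}: mean={y.mean():.3f}, var={y.var():.3f}, "
          f"theory var={theory_var:.3f}")
```

With μ fixed at 3, the theoretical variance is 21 for r = 0.5, 4.8 for r = 5, and about 3.02 for r = 500, which shows how the distribution approaches the Poisson as r grows.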

In a simple GLM, there are two big advantages:

1. The model is relatively low-dimensional: you estimate regression coefficients for μ(X), plus a small number of dispersion parameters.
2. The log-likelihood is reasonably well-behaved, and standard optimizers can usually find a good solution.

#### Why NB is challenging in a CVAE

In a CVAE, the situation is very different:

- The **mean μ(X, Z)** is represented by a deep neural network with latent variables Z.
- If we also let the **dispersion vary with X and Z**, the model effectively tries to learn *two* flexible surfaces: one for μ and one for the overdispersion.
- The KL term and the reconstruction term interact in a complicated way, especially when counts are low or highly variable.

In practice, when we tried to:

- give the decoder separate NB parameters for mean and dispersion, and
- train with a standard VAE objective,

we saw **poor and unstable behavior**:

- The model sometimes matched the **conditional mean** E[Y | X] reasonably well.
- But it tended to produce **biased and noisy estimates of the conditional variance**, even when the data were generated from a known NB process.
- Posterior predictive checks for variance (global or conditional on covariates) frequently showed clear lack of fit.

We experimented with several variants:

- different parameterizations of the dispersion (per-outcome, per-observation, global),
- different regularization schemes (penalties on raw dispersion parameters),
- and different training schedules (learning-rate tweaks, beta-KL schedules),

and still found that NB in the CVAE setting was **fragile**, especially compared to the much more stable Bernoulli, Gaussian, and Poisson families.
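
For readers unfamiliar with beta-KL schedules, the sketch below shows one common form, a linear warm-up of the KL weight. It is a generic illustration with hypothetical names (`kl_weight`, `weighted_elbo_loss`) and does not reproduce the schedules or training loop used in our experiments.

```python
def kl_weight(epoch, warmup_epochs=10, beta_max=1.0):
    """Linear KL warm-up: beta rises from 0 to beta_max over warmup_epochs."""
    if warmup_epochs <= 0:
        return beta_max
    return beta_max * min(1.0, epoch / warmup_epochs)

def weighted_elbo_loss(recon_nll, kl, epoch, warmup_epochs=10):
    """Negative ELBO with the KL term scaled by the current beta."""
    return recon_nll + kl_weight(epoch, warmup_epochs) * kl

schedule = [kl_weight(e, warmup_epochs=4) for e in range(6)]
print(schedule)  # prints [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

Down-weighting the KL term early lets the decoder fit the reconstruction term before the latent code is pulled toward the prior; whether this helps depends on the outcome family, and for NB it did not resolve the instability described above.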

#### Current design choice

Given those observations, the current package takes a conservative stance:

- We **support Poisson** for count outcomes, focusing on getting E[Y | X] and cross-outcome dependence right.
- We **do not expose a Negative Binomial outcome family** in the public API at this time.
- For applications where detailed modeling of overdispersion is crucial, we recommend:
  - using a **Negative Binomial GLM** (or similar) as a baseline,
  - and treating the CVAE primarily as a flexible model for the **conditional mean** and joint structure, rather than a fully calibrated generative model for all aspects of the count distribution.

Future work could revisit NB or other overdispersed count models in the CVAE, possibly with:

- more structured parameterizations of dispersion,
- alternative objectives (e.g., score-based losses or proper scoring rules),
- or hierarchical priors to stabilize the overdispersion parameters.

For now, this vignette focuses on Poisson counts, where the CVAE behaves more predictably and the diagnostics are easier to interpret.
