# Calibration Loss for ANCOVA NPE Training

This notebook compares two training approaches for the ANCOVA 2-arm continuous outcome model:

1. **Baseline**: standard BayesFlow training with negative log-likelihood (NLL) loss only
2. **Calibrated**: NLL plus the differentiable calibration loss from [Falkiewicz et al. (NeurIPS 2023)](https://arxiv.org/abs/2310.13402)

Both models use the **same architecture and training settings**. The only difference is the
calibration loss term, which penalizes under-coverage during training.

The calibration loss package lives at https://github.com/matthiaskloft/bfcalloss. It is installed through this project's `calibration` extra, from the repository root:
```bash
pip install -e ".[calibration]"
```

In [None]:
import os

if not os.environ.get("KERAS_BACKEND"):
    os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import matplotlib.pyplot as plt
from itertools import product

import keras
import bayesflow as bf

# ANCOVA model
from rctbp_bf_training.models.ancova.model import (
    ANCOVAConfig,
    create_ancova_adapter,
    create_simulator,
    create_ancova_workflow_components,
    create_validation_grid,
    make_simulate_fn,
)
from rctbp_bf_training.core.infrastructure import (
    build_summary_network,
    build_inference_network,
)
from rctbp_bf_training.core.utils import MovingAverageEarlyStopping
from rctbp_bf_training.core.validation import (
    run_validation_pipeline,
    make_bayesflow_infer_fn,
)

# Calibration loss (from bfcalloss)
from bayesflow_calibration import (
    CalibratedContinuousApproximator,
    CalibrationMonitorCallback,
    GammaSchedule,
)

np.set_printoptions(suppress=True)
RNG = np.random.default_rng(2025)

print(f"BayesFlow {bf.__version__}")
print(f"Keras {keras.__version__} (backend: {keras.backend.backend()})")

## 1. Shared Configuration

Both models share the same ANCOVA config, simulator, adapter, and network architecture.
The architecture and training settings are taken from the config defaults, which the next cell prints.

In [None]:
config = ANCOVAConfig()

# Use the config defaults for network architecture
print("Summary network:", config.workflow.summary_network)
print("Inference network:", config.workflow.inference_network)
print("Training:", config.workflow.training)

In [None]:
# Create simulator and adapter (shared between both models)
simulator = create_simulator(config, RNG)
adapter = create_ancova_adapter()

# Quick test
sim_test = simulator.sample(10)
processed = adapter(sim_test)
print("summary_variables:", processed["summary_variables"].shape)
print("inference_variables:", processed["inference_variables"].shape)
print("inference_conditions:", processed["inference_conditions"].shape)

## 2. Shared Training Settings
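
The optimizer in the next cell uses a staircase exponential learning-rate decay. As a rough sketch of what that schedule computes, the plain-Python function below mirrors the documented Keras formula `initial_lr * decay_rate ** floor(step / decay_steps)`. The values `1e-3` and `0.95` are placeholders rather than the ANCOVA config defaults, and `step` counts optimizer updates (one per batch).

```python
def staircase_exponential_decay(step, initial_lr, decay_rate, decay_steps):
    """Learning rate after `step` optimizer updates (staircase variant)."""
    return initial_lr * decay_rate ** (step // decay_steps)


# Placeholder values, chosen only for illustration
initial_lr = 1e-3
decay_rate = 0.95
updates_per_epoch = 50  # one optimizer update per batch

for epoch in (0, 1, 2, 10):
    step = epoch * updates_per_epoch
    lr = staircase_exponential_decay(step, initial_lr, decay_rate, updates_per_epoch)
    print(f"epoch {epoch:>2}: lr = {lr:.6f}")
```

With `decay_steps` equal to the number of updates per epoch, the rate stays constant within an epoch and drops by the factor `decay_rate` at each epoch boundary.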

In [None]:
# Training hyperparameters (same for both models)
EPOCHS = 100
BATCH_SIZE = 256
BATCHES_PER_EPOCH = 50
VALIDATION_SIMS = 500

train_config = config.workflow.training


def make_optimizer():
    """Create a fresh optimizer (must be separate per model)."""
    # decay_steps counts optimizer updates (batches), not simulated samples
    steps_per_epoch = BATCHES_PER_EPOCH
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=train_config.initial_lr,
        decay_steps=steps_per_epoch,
        decay_rate=train_config.decay_rate,
        staircase=True,
    )
    return keras.optimizers.Adam(learning_rate=lr_schedule)


def make_early_stopping():
    """Create a fresh early stopping callback."""
    return MovingAverageEarlyStopping(
        window=train_config.early_stopping_window,
        patience=train_config.early_stopping_patience,
    )

## 3. Train Baseline Model (NLL only)

Standard BayesFlow training — the normalizing flow is trained to maximize the log-density
of the true parameters under the learned posterior.
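
Written out, with $q_\phi$ denoting the learned posterior, $\theta$ the parameters and $x$ the simulated data together with its conditioning variables (notation introduced here for illustration, not taken from the notebook), the objective is the expected negative log-density over simulated pairs:

$$
\mathcal{L}_{\mathrm{NLL}}(\phi)
  = -\,\mathbb{E}_{\theta \sim p(\theta),\; x \sim p(x \mid \theta)}
    \big[ \log q_\phi(\theta \mid x) \big]
$$

In practice the expectation is approximated by the average over each simulated batch.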

In [None]:
# Build fresh networks for the baseline
summary_net_base, inference_net_base, _ = create_ancova_workflow_components(config)

workflow_base = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=inference_net_base,
    summary_network=summary_net_base,
    optimizer=make_optimizer(),
    inference_conditions=["N", "p_alloc", "prior_df", "prior_scale"],
)
workflow_base.approximator.compile(optimizer=make_optimizer())

print("Baseline workflow created")

In [None]:
history_base = workflow_base.fit_online(
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    num_batches_per_epoch=BATCHES_PER_EPOCH,
    validation_data=VALIDATION_SIMS,
    callbacks=[make_early_stopping()],
)
print("Baseline training complete")

## 4. Train Calibrated Model (NLL + Calibration Loss)

The calibrated model uses `CalibratedContinuousApproximator` from the `bayesflow-calibration` package.
The class subclasses BayesFlow's `ContinuousApproximator` and adds a calibration loss
term that penalizes under-coverage during training.

Key settings:
- **`gamma_schedule`**: linear warmup from `gamma_min=0.0` (pure NLL) to `gamma_max=100.0`, ramping up calibration pressure over 20 epochs
- **`calibration_mode=0.0`**: conservativeness mode (penalize under-coverage only)
- **`n_rank_samples=100`**: prior samples for rank computation
- **`subsample_size=80`**: subsample batch for calibration loss (reduces overhead)
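
To make the warmup concrete, here is a minimal sketch of a linear warmup schedule and of how such a weight might combine the two loss terms. The function names and the exact interpolation are assumptions for illustration; the `GammaSchedule` class in the package may compute or clip the weight differently.

```python
def linear_warmup_gamma(epoch, gamma_max=100.0, warmup_epochs=20, gamma_min=0.0):
    """Hypothetical linear warmup: gamma_min at epoch 0, gamma_max from warmup_epochs on."""
    if warmup_epochs <= 0:
        return gamma_max
    progress = min(max(epoch / warmup_epochs, 0.0), 1.0)
    return gamma_min + (gamma_max - gamma_min) * progress


def total_loss(nll, calibration_loss, gamma):
    """Assumed combination: NLL plus a gamma-weighted calibration penalty."""
    return nll + gamma * calibration_loss


for epoch in (0, 5, 10, 20, 50):
    print(f"epoch {epoch:>2}: gamma = {linear_warmup_gamma(epoch):.1f}")
```

Under this sketch the first epoch is trained on pure NLL, and the calibration term reaches its full weight once the warmup ends.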

In [None]:
# Build fresh networks for the calibrated model
summary_net_cal, inference_net_cal, _ = create_ancova_workflow_components(config)


def ancova_prior_fn(n_samples):
    """Sample b_group from the ANCOVA marginal prior.

    Integrates over meta-parameters (prior_df, prior_scale) to get
    the marginal prior distribution of b_group.
    Must return np.ndarray of shape (n_samples, param_dim).
    """
    # Import once, outside the sampling loop
    from rctbp_bf_training.core.utils import sample_t_or_normal

    rng = np.random.default_rng()
    samples = np.zeros((n_samples, 1))
    for j in range(n_samples):
        # Sample meta-parameters from their priors
        # (rng.integers excludes the upper bound, hence the +1)
        prior_df = int(rng.integers(0, config.meta.prior_df_max + 1))
        prior_scale = rng.gamma(
            shape=config.meta.prior_scale_gamma_shape,
            scale=config.meta.prior_scale_gamma_scale,
        )
        # Sample b_group from the conditional prior
        samples[j, 0] = sample_t_or_normal(
            df=prior_df, scale=prior_scale, rng=rng
        )
    return samples.astype(np.float32)


# Quick check: sample from the marginal prior
test_prior = ancova_prior_fn(1000)
print(f"Marginal prior: mean={test_prior.mean():.3f}, std={test_prior.std():.3f}, shape={test_prior.shape}")

In [None]:
# Create the calibrated approximator
gamma_schedule = GammaSchedule(
    schedule_type="linear_warmup",
    gamma_max=100.0,
    warmup_epochs=20,
    gamma_min=0.0,
)

approximator_cal = CalibratedContinuousApproximator(
    # BayesFlow ContinuousApproximator args
    inference_network=inference_net_cal,
    summary_network=summary_net_cal,
    # Calibration-specific args
    prior_fn=ancova_prior_fn,
    gamma_schedule=gamma_schedule,
    calibration_mode=0.0,       # conservativeness: penalize under-coverage only
    n_rank_samples=100,         # prior samples for rank computation
    subsample_size=80,          # subsample batch for efficiency
)

approximator_cal.compile(optimizer=make_optimizer())
print("Calibrated approximator created")
print(f"  gamma schedule: {gamma_schedule.schedule_type}, max={gamma_schedule.gamma_max}, warmup={gamma_schedule.warmup_epochs}")
print(f"  calibration_mode: {approximator_cal.calibration_mode}")
print(f"  n_rank_samples: {approximator_cal.n_rank_samples}")
print(f"  subsample_size: {approximator_cal.subsample_size}")

In [None]:
# BasicWorkflow builds a plain ContinuousApproximator internally, so it cannot
# add the calibration loss on its own. We still create a BasicWorkflow for its
# data pipeline and online training loop, then swap in the calibrated approximator.
workflow_cal = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=inference_net_cal,
    summary_network=summary_net_cal,
    optimizer=make_optimizer(),
    inference_conditions=["N", "p_alloc", "prior_df", "prior_scale"],
)
# Replace the approximator with our calibrated one
workflow_cal.approximator = approximator_cal

# Train with CalibrationMonitorCallback (REQUIRED for gamma scheduling)
history_cal = workflow_cal.fit_online(
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    num_batches_per_epoch=BATCHES_PER_EPOCH,
    validation_data=VALIDATION_SIMS,
    callbacks=[make_early_stopping(), CalibrationMonitorCallback()],
)
print("Calibrated training complete")

## 5. Training Loss Comparison

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Training loss
ax = axes[0]
ax.plot(history_base.history["loss"], label="Baseline (NLL)", alpha=0.8)
ax.plot(history_cal.history["loss"], label="Calibrated (NLL + cal)", alpha=0.8)
ax.set_xlabel("Epoch")
ax.set_ylabel("Training Loss")
ax.set_title("Training Loss")
ax.legend()
ax.grid(True, alpha=0.3)

# Validation loss
ax = axes[1]
ax.plot(history_base.history.get("val_loss", []), label="Baseline", alpha=0.8)
ax.plot(history_cal.history.get("val_loss", []), label="Calibrated", alpha=0.8)
ax.set_xlabel("Epoch")
ax.set_ylabel("Validation Loss")
ax.set_title("Validation Loss")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. BayesFlow Built-in Diagnostics

Quick comparison using BayesFlow's default diagnostic metrics.

In [None]:
# Generate shared validation data
val_sims = simulator.sample(1000)

# Compute default diagnostics for both models
metrics_base = workflow_base.compute_default_diagnostics(test_data=val_sims)
metrics_cal = workflow_cal.compute_default_diagnostics(test_data=val_sims)

import pandas as pd

comparison = pd.DataFrame({
    "Baseline": metrics_base["b_group"],
    "Calibrated": metrics_cal["b_group"],
})
print("BayesFlow default diagnostics (b_group):")
display(comparison)

## 7. Validation on Conditions Grid

Systematic comparison across a grid of ANCOVA conditions (N, prior_df, prior_scale, b_group).
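
A conditions grid of this kind is a Cartesian product over per-factor value lists. The sketch below shows the idea with made-up levels; the actual levels and the layout returned by `create_validation_grid` are not shown in this notebook and may differ.

```python
from itertools import product

# Hypothetical factor levels, for illustration only
grid_values = {
    "N": [50, 200],
    "prior_df": [0, 3],
    "prior_scale": [0.5, 2.0],
    "b_group": [0.0, 0.5],
}

# One dict per combination of factor levels
example_conditions = [
    dict(zip(grid_values.keys(), combo))
    for combo in product(*grid_values.values())
]

print(len(example_conditions))  # 2 * 2 * 2 * 2 combinations
print(example_conditions[0])
```

The grid size is the product of the level counts, so each extra factor or level multiplies the number of validation runs.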

In [None]:
# Use the reduced validation grid (faster)
conditions = create_validation_grid(extended=False)
print(f"Validation grid: {len(conditions)} conditions")
print(f"Example: {conditions[0]}")

In [None]:
# Run validation for baseline model
simulate_fn = make_simulate_fn(rng=RNG)

infer_fn_base = make_bayesflow_infer_fn(
    workflow_base.approximator,
    param_key="b_group",
    data_keys=["outcome", "covariate", "group"],
    context_keys={"N": int, "p_alloc": float, "prior_df": float, "prior_scale": float},
)

print("=== Baseline Model ===")
results_base = run_validation_pipeline(
    conditions_list=conditions,
    n_sims=500,
    n_post_draws=500,
    simulate_fn=simulate_fn,
    infer_fn=infer_fn_base,
    true_param_key="b_arm_treat",
    verbose=True,
)

In [None]:
# Run validation for calibrated model
infer_fn_cal = make_bayesflow_infer_fn(
    workflow_cal.approximator,
    param_key="b_group",
    data_keys=["outcome", "covariate", "group"],
    context_keys={"N": int, "p_alloc": float, "prior_df": float, "prior_scale": float},
)

print("=== Calibrated Model ===")
results_cal = run_validation_pipeline(
    conditions_list=conditions,
    n_sims=500,
    n_post_draws=500,
    simulate_fn=simulate_fn,
    infer_fn=infer_fn_cal,
    true_param_key="b_arm_treat",
    verbose=True,
)

## 8. Summary Comparison

In [None]:
s_base = results_base["metrics"]["summary"]
s_cal = results_cal["metrics"]["summary"]

summary_keys = [
    "recovery_corr", "recovery_r2", "overall_nrmse", "overall_bias",
    "mean_contraction", "mean_cal_error",
    "coverage_50", "coverage_80", "coverage_90", "coverage_95",
    "sbc_ks_pvalue", "sbc_c2st_accuracy",
]

comparison_df = pd.DataFrame({
    "Metric": summary_keys,
    "Baseline": [s_base.get(k, float("nan")) for k in summary_keys],
    "Calibrated": [s_cal.get(k, float("nan")) for k in summary_keys],
}).set_index("Metric")

# Signed difference: positive values mean the calibrated model scores higher
comparison_df["Difference"] = comparison_df["Calibrated"] - comparison_df["Baseline"]

print("=" * 65)
print("         Baseline  vs  Calibrated  (NLL + calibration loss)")
print("=" * 65)
display(comparison_df.round(4))

print("\nKey:")
print("  mean_cal_error: lower is better (0 = perfect calibration)")
print("  coverage_XX: closer to XX/100 is better")
print("  sbc_ks_pvalue: > 0.05 means uniform SBC ranks are not rejected")
print("  sbc_c2st_accuracy: closer to 0.5 is better")

## 9. Coverage Profile Comparison

The coverage profile shows empirical vs. nominal coverage at every level from 1% to 99%.
A well-calibrated model follows the diagonal.
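
As a minimal sketch of how such a profile can be computed, the example below uses central (equal-tailed) credible intervals on a toy conjugate normal model whose exact posterior is known. The helper name `coverage_profile`, the toy model and the choice of central intervals are assumptions for illustration; the validation pipeline may compute coverage differently.

```python
import numpy as np


def coverage_profile(post_draws, true_values, levels):
    """Empirical coverage of central credible intervals.

    post_draws: array (n_sims, n_draws); true_values: array (n_sims,).
    Returns {nominal level: fraction of intervals containing the truth}.
    """
    profile = {}
    for level in levels:
        tail = (1.0 - level) / 2.0
        lower = np.quantile(post_draws, tail, axis=1)
        upper = np.quantile(post_draws, 1.0 - tail, axis=1)
        inside = (true_values >= lower) & (true_values <= upper)
        profile[float(level)] = float(inside.mean())
    return profile


# Toy conjugate model: theta ~ N(0, 1), x | theta ~ N(theta, 1),
# so the exact posterior is theta | x ~ N(x / 2, 1 / 2).
rng = np.random.default_rng(0)
n_sims, n_draws = 2000, 500
theta = rng.normal(size=n_sims)
x = theta + rng.normal(size=n_sims)
noise = rng.normal(size=(n_sims, n_draws))
exact_draws = x[:, None] / 2 + np.sqrt(0.5) * noise
# Overconfident approximation: same mean, half the posterior SD
narrow_draws = x[:, None] / 2 + 0.5 * np.sqrt(0.5) * noise

levels = [0.5, 0.8, 0.9, 0.95]
exact_profile = coverage_profile(exact_draws, theta, levels)
narrow_profile = coverage_profile(narrow_draws, theta, levels)
for level in levels:
    print(f"{level:.2f}: exact={exact_profile[level]:.3f}, "
          f"narrow={narrow_profile[level]:.3f}")
```

The exact posterior should track the nominal levels, while the too-narrow one falls below the diagonal. That shortfall is the under-coverage the calibration loss is meant to penalize.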

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for ax, (label, results) in zip(axes, [("Baseline", results_base), ("Calibrated", results_cal)]):
    profile = results["metrics"]["summary"]["coverage_profile"]
    nominal = sorted(profile.keys())
    empirical = [profile[n] for n in nominal]

    ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Perfect calibration")
    ax.plot(nominal, empirical, "b-", lw=2, label="Empirical")
    ax.fill_between(nominal, empirical, nominal, alpha=0.2, color="red")
    ax.set_xlabel("Nominal Coverage")
    ax.set_ylabel("Empirical Coverage")
    ax.set_title(f"{label} — Coverage Profile")
    ax.legend(loc="upper left")
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)

    cal_err = results["metrics"]["summary"]["mean_cal_error"]
    ax.text(0.95, 0.05, f"Cal. Error: {cal_err:.4f}",
            transform=ax.transAxes, ha="right", va="bottom",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8))

plt.tight_layout()
plt.show()

## 10. Per-Condition Calibration Error Comparison

In [None]:
cond_base = results_base["metrics"]["condition_metrics"]
cond_cal = results_cal["metrics"]["condition_metrics"]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Mean calibration error per condition
ax = axes[0]
x = np.arange(len(cond_base))
w = 0.35
ax.bar(x - w/2, cond_base["mean_cal_error"], w, label="Baseline", alpha=0.8)
ax.bar(x + w/2, cond_cal["mean_cal_error"], w, label="Calibrated", alpha=0.8)
ax.set_xlabel("Condition")
ax.set_ylabel("Mean Calibration Error")
ax.set_title("Calibration Error by Condition")
ax.legend()
ax.grid(True, alpha=0.3, axis="y")

# 95% coverage per condition
ax = axes[1]
ax.bar(x - w/2, cond_base["coverage_95"], w, label="Baseline", alpha=0.8)
ax.bar(x + w/2, cond_cal["coverage_95"], w, label="Calibrated", alpha=0.8)
ax.axhline(0.95, color="red", ls="--", alpha=0.7, label="Nominal (95%)")
ax.set_xlabel("Condition")
ax.set_ylabel("95% Coverage")
ax.set_title("95% Coverage by Condition")
ax.legend()
ax.grid(True, alpha=0.3, axis="y")

# NRMSE per condition (should not degrade much)
ax = axes[2]
ax.bar(x - w/2, cond_base["nrmse"], w, label="Baseline", alpha=0.8)
ax.bar(x + w/2, cond_cal["nrmse"], w, label="Calibrated", alpha=0.8)
ax.set_xlabel("Condition")
ax.set_ylabel("NRMSE")
ax.set_title("NRMSE by Condition (lower is better)")
ax.legend()
ax.grid(True, alpha=0.3, axis="y")

plt.tight_layout()
plt.show()

## 11. Condition-Level Summary Tables

In [None]:
print("Baseline — condition summary:")
display(results_base["metrics"]["condition_summary"].round(4))

print("\nCalibrated — condition summary:")
display(results_cal["metrics"]["condition_summary"].round(4))

## 12. Conclusion

**Expected observations:**

- The calibrated model should have **lower calibration error** and empirical coverage
  closer to nominal levels (especially at 90% and 95%).
- In conservativeness mode (`calibration_mode=0.0`), the calibrated model may produce
  slightly **wider** credible intervals (higher posterior SD) — this is by design.
- The NRMSE and recovery correlation should remain similar, showing the calibration loss
  does not significantly harm point estimation accuracy.
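- The SBC KS p-value should tend to be higher for the calibrated model (closer to uniform rank
  distribution), and the C2ST accuracy closer to 0.5. Strong over-coverage also distorts the
  ranks, so a very conservative posterior can fail this check too.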
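
The rank-uniformity check behind the KS p-value can be sketched with NumPy alone. The toy model, the jittered rank normalization and the helper names below are assumptions for illustration, not the validation pipeline's implementation. The example includes an overly wide (conservative) posterior to show how its ranks compare with those of the exact one.

```python
import numpy as np


def sbc_uniform_ranks(post_draws, true_values, rng):
    """Normalized SBC ranks in (0, 1); uniform if the posteriors are calibrated."""
    n_draws = post_draws.shape[1]
    ranks = (post_draws < true_values[:, None]).sum(axis=1)
    # Uniform jitter removes the discreteness of the integer ranks
    return (ranks + rng.uniform(size=ranks.shape)) / (n_draws + 1)


def ks_distance_to_uniform(u):
    """Kolmogorov-Smirnov distance between the ECDF of u and Uniform(0, 1)."""
    u = np.sort(u)
    n = u.size
    above = np.arange(1, n + 1) / n - u
    below = u - np.arange(0, n) / n
    return float(max(above.max(), below.max()))


# Toy conjugate model: theta ~ N(0, 1), x | theta ~ N(theta, 1),
# so the exact posterior is theta | x ~ N(x / 2, 1 / 2).
rng = np.random.default_rng(1)
n_sims, n_draws = 2000, 500
theta = rng.normal(size=n_sims)
x = theta + rng.normal(size=n_sims)
noise = rng.normal(size=(n_sims, n_draws))
exact_draws = x[:, None] / 2 + np.sqrt(0.5) * noise
# Conservative approximation: same mean, twice the posterior SD
wide_draws = x[:, None] / 2 + 2.0 * np.sqrt(0.5) * noise

d_exact = ks_distance_to_uniform(sbc_uniform_ranks(exact_draws, theta, rng))
d_wide = ks_distance_to_uniform(sbc_uniform_ranks(wide_draws, theta, rng))
# Asymptotic 5% critical value of the one-sample KS test
critical = 1.36 / np.sqrt(n_sims)
print(f"exact: D={d_exact:.3f}, wide: D={d_wide:.3f}, 5% critical value={critical:.3f}")
```

A distance above the critical value corresponds to a KS p-value below 0.05. Here the conservative posterior exceeds it because its ranks pile up around the middle.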

**Notes:**
- The calibration loss adds training overhead (roughly 2-6x depending on `n_rank_samples`
  and `subsample_size`). Adjust these for your compute budget.
- `gamma_schedule` with linear warmup is recommended — starting with pure NLL lets the
  network learn a reasonable posterior before calibration pressure is applied.