2 changes: 2 additions & 0 deletions bayesflow/__init__.py
@@ -7,9 +7,11 @@
distributions,
networks,
simulators,
workflows,
utils,
)

from .workflows import BasicWorkflow
from .approximators import ContinuousApproximator
from .adapters import Adapter
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
9 changes: 9 additions & 0 deletions bayesflow/adapters/adapter.py
@@ -9,6 +9,7 @@

from .transforms import (
AsSet,
AsTimeSeries,
Broadcast,
Concatenate,
Constrain,
@@ -112,6 +113,14 @@ def as_set(self, keys: str | Sequence[str]):
self.transforms.append(transform)
return self

def as_time_series(self, keys: str | Sequence[str]):
if isinstance(keys, str):
keys = [keys]

transform = MapTransform({key: AsTimeSeries() for key in keys})
self.transforms.append(transform)
return self

def broadcast(
self, keys: str | Sequence[str], *, to: str, expand: str | int | tuple = "left", exclude: int | tuple = -1
):
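For context, a minimal usage sketch for the new adapter method, mirroring the docstring example further below; the variable names "x", "y", and "z" are placeholders rather than part of this PR:

import bayesflow as bf

adapter = (
    bf.Adapter()
    .as_time_series(["x", "y"])  # promotes the named arrays to at least 3D (batch, time, data)
    .as_set("z")                 # analogous treatment for exchangeable (set-valued) variables
)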
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
@@ -1,4 +1,5 @@
from .as_set import AsSet
from .as_time_series import AsTimeSeries
from .broadcast import Broadcast
from .concatenate import Concatenate
from .constrain import Constrain
6 changes: 6 additions & 0 deletions bayesflow/adapters/transforms/as_set.py
@@ -11,6 +11,12 @@ class AsSet(ElementwiseTransform):
This is useful, for example, in a linear regression context where we can index
the observations in arbitrary order and always get the same regression line.

Currently, all this transform does is ensure that the variable
arrays are at least 3D. The second dimension is treated as the
set dimension and the third dimension as the data dimension.
In the future, the transform will have more advanced behavior
to better ensure the correct treatment of sets.

Usage:

adapter = (
32 changes: 32 additions & 0 deletions bayesflow/adapters/transforms/as_time_series.py
@@ -0,0 +1,32 @@
import numpy as np

from .elementwise_transform import ElementwiseTransform


class AsTimeSeries(ElementwiseTransform):
"""
The `.as_time_series` transform can be used to indicate that
variables shall be treated as time series.

Currently, all this transform does is ensure that the variable
arrays are at least 3D. The second dimension is treated as the
time series dimension and the third dimension as the data dimension.
In the future, the transform will have more advanced behavior
to better ensure the correct treatment of time series data.

Usage:

adapter = (
bf.Adapter()
.as_time_series(["x", "y"])
)
"""

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.atleast_3d(data)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
if data.shape[2] == 1:
return np.squeeze(data, axis=2)

return data
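To make the shape contract concrete, a small round-trip sketch (not part of the PR; `AsSet` behaves the same way apart from the semantic label):

import numpy as np
from bayesflow.adapters.transforms import AsTimeSeries

transform = AsTimeSeries()
x = np.zeros((32, 100))        # (num_datasets, num_time_steps)
y = transform.forward(x)       # promoted to shape (32, 100, 1)
x_back = transform.inverse(y)  # trailing singleton axis is squeezed again
assert x_back.shape == (32, 100)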
25 changes: 14 additions & 11 deletions bayesflow/diagnostics/__init__.py
@@ -1,11 +1,14 @@
from .plots import calibration_ecdf
from .plots import calibration_histogram
from .plots import loss
from .plots import mc_calibration
from .plots import mc_confusion_matrix
from .plots import mmd_hypothesis_test
from .plots import pairs_posterior
from .plots import pairs_prior
from .plots import pairs_samples
from .plots import recovery
from .plots import z_score_contraction
from .metrics import root_mean_squared_error, calibration_error, posterior_contraction

from .plots import (
calibration_ecdf,
calibration_histogram,
loss,
mc_calibration,
mc_confusion_matrix,
mmd_hypothesis_test,
pairs_posterior,
pairs_samples,
recovery,
z_score_contraction,
)
3 changes: 3 additions & 0 deletions bayesflow/diagnostics/metrics/__init__.py
@@ -0,0 +1,3 @@
from .calibration_error import calibration_error
from .posterior_contraction import posterior_contraction
from .root_mean_squared_error import root_mean_squared_error
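With this module in place and the re-export in `bayesflow/diagnostics/__init__.py` above, the new metrics should be importable from either level:

from bayesflow.diagnostics import calibration_error, posterior_contraction, root_mean_squared_error
from bayesflow.diagnostics.metrics import calibration_error  # equivalent, more explicit path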
82 changes: 82 additions & 0 deletions bayesflow/diagnostics/metrics/calibration_error.py
@@ -0,0 +1,82 @@
from typing import Sequence, Any, Mapping, Callable

import numpy as np

from ...utils.dict_utils import dicts_to_arrays


def calibration_error(
targets: Mapping[str, np.ndarray] | np.ndarray,
references: Mapping[str, np.ndarray] | np.ndarray,
resolution: int = 20,
aggregation: Callable = np.median,
min_quantile: float = 0.005,
max_quantile: float = 0.995,
variable_names: Sequence[str] = None,
) -> Mapping[str, Any]:
"""Computes an aggregate score for the marginal calibration error over an ensemble of approximate
posteriors. The calibration error is given as the aggregate (e.g., median) of the absolute deviation
between an alpha-CI and the relative number of inliers from ``prior_samples`` over multiple alphas in
(0, 1).

Parameters
----------
targets : np.ndarray of shape (num_datasets, num_draws, num_variables)
The random draws from the approximate posteriors over ``num_datasets``
references : np.ndarray of shape (num_datasets, num_variables)
The corresponding ground-truth values sampled from the prior
resolution : int, optional, default: 20
The number of credibility intervals (CIs) to consider
aggregation : callable or None, optional, default: np.median
The function used to aggregate the marginal calibration errors.
If ``None`` is provided, the per-alpha calibration errors will be returned.
min_quantile : float in (0, 1), optional, default: 0.005
The minimum posterior quantile to consider.
max_quantile : float in (0, 1), optional, default: 0.995
The maximum posterior quantile to consider.
variable_names : Sequence[str], optional (default = None)
Optional variable names to select from the available variables.

Returns
-------
result : dict
Dictionary containing:
- "values" : float or np.ndarray
The aggregated calibration error per variable
- "metric_name" : str
The name of the metric ("Calibration Error").
- "variable_names" : str
The (inferred) variable names.
"""

samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)

# Define alpha values and the corresponding quantile bounds
alphas = np.linspace(start=min_quantile, stop=max_quantile, num=resolution)
regions = 1 - alphas
lowers = regions / 2
uppers = 1 - lowers

# Compute quantiles for each alpha, for each dataset and parameter
quantiles = np.quantile(samples["targets"], [lowers, uppers], axis=1)

# Shape: (2, resolution, num_datasets, num_params)
lower_bounds, upper_bounds = quantiles[0], quantiles[1]

# Compute masks for inliers
lower_mask = lower_bounds <= samples["references"][None, ...]
upper_mask = upper_bounds >= samples["references"][None, ...]

# Logical AND to identify inliers for each alpha
inlier_id = np.logical_and(lower_mask, upper_mask)

# Compute the relative number of inliers for each alpha
alpha_pred = np.mean(inlier_id, axis=1)

# Calculate absolute error between predicted inliers and alpha
absolute_errors = np.abs(alpha_pred - alphas[:, None])

# Aggregate errors across alpha (or return the per-alpha errors if no aggregation is given)
error = absolute_errors if aggregation is None else aggregation(absolute_errors, axis=0)

return {"values": error, "metric_name": "Calibration Error", "variable_names": variable_names}
52 changes: 52 additions & 0 deletions bayesflow/diagnostics/metrics/posterior_contraction.py
@@ -0,0 +1,52 @@
from typing import Sequence, Any, Mapping, Callable

import numpy as np

from ...utils.dict_utils import dicts_to_arrays


def posterior_contraction(
targets: Mapping[str, np.ndarray] | np.ndarray,
references: Mapping[str, np.ndarray] | np.ndarray,
aggregation: Callable = np.median,
variable_names: Sequence[str] = None,
) -> Mapping[str, Any]:
"""Computes the posterior contraction (PC) from prior to posterior for the given samples.

Parameters
----------
targets : np.ndarray of shape (num_datasets, num_draws_post, num_variables)
Posterior samples, comprising `num_draws_post` random draws from the posterior distribution
for each data set from `num_datasets`.
references : np.ndarray of shape (num_datasets, num_variables)
Prior samples, comprising `num_datasets` ground truths.
aggregation : callable, optional (default = np.median)
Function to aggregate the PC across datasets. Typically `np.mean` or `np.median`.
variable_names : Sequence[str], optional (default = None)
Optional variable names to select from the available variables.

Returns
-------
result : dict
Dictionary containing:
- "values" : float or np.ndarray
The aggregated posterior contraction per variable
- "metric_name" : str
The name of the metric ("Posterior Contraction").
- "variable_names" : str
The (inferred) variable names.

Notes
-----
Posterior contraction measures the reduction in uncertainty from the prior to the posterior.
Values close to 1 indicate strong contraction (high reduction in uncertainty), while values close to 0
indicate low contraction.
"""

samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)

post_vars = samples["targets"].var(axis=1, ddof=1)
prior_vars = samples["references"].var(axis=0, keepdims=True, ddof=1)
contraction = 1 - (post_vars / prior_vars)
contraction = aggregation(contraction, axis=0)
return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": samples["variable_names"]}
59 changes: 59 additions & 0 deletions bayesflow/diagnostics/metrics/root_mean_squared_error.py
@@ -0,0 +1,59 @@
from typing import Sequence, Any, Mapping, Callable

import numpy as np

from ...utils.dict_utils import dicts_to_arrays


def root_mean_squared_error(
targets: Mapping[str, np.ndarray] | np.ndarray,
references: Mapping[str, np.ndarray] | np.ndarray,
normalize: bool = True,
aggregation: Callable = np.median,
variable_names: Sequence[str] = None,
) -> Mapping[str, Any]:
"""Computes the (Normalized) Root Mean Squared Error (RMSE/NRMSE) for the given posterior and prior samples.

Parameters
----------
targets : np.ndarray of shape (num_datasets, num_draws_post, num_variables)
Posterior samples, comprising `num_draws_post` random draws from the posterior distribution
for each data set from `num_datasets`.
references : np.ndarray of shape (num_datasets, num_variables)
Prior samples, comprising `num_datasets` ground truths.
normalize : bool, optional (default = True)
Whether to normalize the RMSE using the range of the prior samples.
aggregation : callable, optional (default = np.median)
Function to aggregate the RMSE across draws. Typically `np.mean` or `np.median`.
variable_names : Sequence[str], optional (default = None)
Optional variable names to select from the available variables.

Notes
-----
Aggregation is performed after computing the RMSE for each posterior draw, instead of first aggregating
the posterior draws and then computing the RMSE between aggregates and ground truths.

Returns
-------
result : dict
Dictionary containing:
- "values" : np.ndarray
The aggregated (N)RMSE for each variable.
- "metric_name" : str
The name of the metric ("RMSE" or "NRMSE").
- "variable_names" : str
The (inferred) variable names.
"""

samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)

rmse = np.sqrt(np.mean((samples["targets"] - samples["references"][:, None, :]) ** 2, axis=0))

if normalize:
rmse /= (samples["references"].max(axis=0) - samples["references"].min(axis=0))[None, :]
metric_name = "NRMSE"
else:
metric_name = "RMSE"

rmse = aggregation(rmse, axis=0)
return {"values": rmse, "metric_name": metric_name, "variable_names": samples["variable_names"]}
1 change: 0 additions & 1 deletion bayesflow/diagnostics/plots/__init__.py
@@ -5,7 +5,6 @@
from .mc_confusion_matrix import mc_confusion_matrix
from .mmd_hypothesis_test import mmd_hypothesis_test
from .pairs_posterior import pairs_posterior
from .pairs_prior import pairs_prior
from .pairs_samples import pairs_samples
from .recovery import recovery
from .z_score_contraction import z_score_contraction