diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py
index d39f03fa8..c964033c7 100644
--- a/bayesflow/adapters/adapter.py
+++ b/bayesflow/adapters/adapter.py
@@ -233,12 +233,17 @@ def constrain(
         lower: int | float | np.ndarray = None,
         upper: int | float | np.ndarray = None,
         method: str = "default",
+        inclusive: str = "both",
+        epsilon: float = 1e-15,
     ):
         if isinstance(keys, str):
             keys = [keys]
 
         transform = MapTransform(
-            transform_map={key: Constrain(lower=lower, upper=upper, method=method) for key in keys}
+            transform_map={
+                key: Constrain(lower=lower, upper=upper, method=method, inclusive=inclusive, epsilon=epsilon)
+                for key in keys
+            }
         )
         self.transforms.append(transform)
         return self
diff --git a/bayesflow/adapters/transforms/constrain.py b/bayesflow/adapters/transforms/constrain.py
index 7de854c6d..1451f5618 100644
--- a/bayesflow/adapters/transforms/constrain.py
+++ b/bayesflow/adapters/transforms/constrain.py
@@ -16,7 +16,7 @@
 @serializable(package="bayesflow.adapters")
 class Constrain(ElementwiseTransform):
     """
-    Constrains neural network predictions of a data variable to specificied bounds.
+    Constrains neural network predictions of a data variable to specified bounds.
 
     Parameters:
     String containing the name of the data variable to be transformed e.g. "sigma". See examples below.
@@ -28,14 +28,21 @@ class Constrain(ElementwiseTransform):
     - Double bounded methods: sigmoid, expit, (default = sigmoid)
     - Lower bound only methods: softplus, exp, (default = softplus)
     - Upper bound only methods: softplus, exp, (default = softplus)
-
+    inclusive: Indicates which bounds are inclusive (or exclusive).
+        - "both" (default): Both lower and upper bounds are inclusive.
+        - "lower": Lower bound is inclusive, upper bound is exclusive.
+        - "upper": Lower bound is exclusive, upper bound is inclusive.
+        - "none": Both lower and upper bounds are exclusive.
+    epsilon: Small value to ensure inclusive bounds are not violated.
+        Current default is 1e-15 as this ensures finite outcomes
+        with the default transformations applied to data exactly at the boundaries.
 
     Examples:
     1) Let sigma be the standard deviation of a normal distribution, then sigma should always be greater than zero.
 
-    Useage:
+    Usage:
 
     adapter = (
         bf.Adapter()
         .constrain("sigma", lower=0)
     )
@@ -45,14 +52,19 @@ class Constrain(ElementwiseTransform):
     [0,1] then we would constrain the neural network to estimate p in the following way.
 
     Usage:
-    adapter = (
-        bf.Adapter()
-        .constrain("p", lower=0, upper=1, method = "sigmoid")
-    )
+    >>> import bayesflow as bf
+    >>> adapter = bf.Adapter()
+    >>> adapter.constrain("p", lower=0, upper=1, method="sigmoid", inclusive="both")
     """
 
     def __init__(
-        self, *, lower: int | float | np.ndarray = None, upper: int | float | np.ndarray = None, method: str = "default"
+        self,
+        *,
+        lower: int | float | np.ndarray = None,
+        upper: int | float | np.ndarray = None,
+        method: str = "default",
+        inclusive: str = "both",
+        epsilon: float = 1e-15,
     ):
         super().__init__()
@@ -121,12 +133,31 @@ def unconstrain(x):
 
         self.lower = lower
         self.upper = upper
 
-        self.method = method
+        self.inclusive = inclusive
+        self.epsilon = epsilon
 
         self.constrain = constrain
         self.unconstrain = unconstrain
 
+        # do this last to avoid serialization issues
+        match inclusive:
+            case "lower":
+                if lower is not None:
+                    lower = lower - epsilon
+            case "upper":
+                if upper is not None:
+                    upper = upper + epsilon
+            case True | "both":
+                if lower is not None:
+                    lower = lower - epsilon
+                if upper is not None:
+                    upper = upper + epsilon
+            case False | None | "none":
+                pass
+            case other:
+                raise ValueError(f"Unsupported value for 'inclusive': {other!r}.")
+
     @classmethod
     def from_config(cls, config: dict, custom_objects=None) -> "Constrain":
         return cls(**config)
@@ -136,6 +167,8 @@ def get_config(self) -> dict:
             "lower": self.lower,
             "upper": self.upper,
             "method": self.method,
+            "inclusive": self.inclusive,
+            "epsilon": self.epsilon,
         }
 
     def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py
index 41e8c2bb3..7247869d7 100644
--- a/tests/test_adapters/test_adapters.py
+++ b/tests/test_adapters/test_adapters.py
@@ -31,3 +31,59 @@ def test_serialize_deserialize(adapter, custom_objects, random_data):
     deserialized_processed = deserialized(random_data)
     for key, value in processed.items():
         assert np.allclose(value, deserialized_processed[key])
+
+
+def test_constrain():
+    # check if constraint-implied transforms are applied correctly
+    import numpy as np
+    import warnings
+    from bayesflow.adapters import Adapter
+
+    data = {
+        "x_lower_cont": np.random.exponential(1, size=(32, 1)),
+        "x_upper_cont": -np.random.exponential(1, size=(32, 1)),
+        "x_both_cont": np.random.beta(0.5, 0.5, size=(32, 1)),
+        "x_lower_disc1": np.zeros(shape=(32, 1)),
+        "x_lower_disc2": np.zeros(shape=(32, 1)),
+        "x_upper_disc1": np.ones(shape=(32, 1)),
+        "x_upper_disc2": np.ones(shape=(32, 1)),
+        "x_both_disc1": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
+        "x_both_disc2": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
+    }
+
+    adapter = (
+        Adapter()
+        .constrain("x_lower_cont", lower=0)
+        .constrain("x_upper_cont", upper=0)
+        .constrain("x_both_cont", lower=0, upper=1)
+        .constrain("x_lower_disc1", lower=0, inclusive="lower")
+        .constrain("x_lower_disc2", lower=0, inclusive="none")
+        .constrain("x_upper_disc1", upper=1, inclusive="upper")
+        .constrain("x_upper_disc2", upper=1, inclusive="none")
+        .constrain("x_both_disc1", lower=0, upper=1, inclusive="both")
+        .constrain("x_both_disc2", lower=0, upper=1, inclusive="none")
+    )
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", RuntimeWarning)
+        result = adapter(data)
+
+    # continuous variables should not have boundary issues
+    assert result["x_lower_cont"].min() < 0.0
+    assert result["x_upper_cont"].max() > 0.0
+    assert result["x_both_cont"].min() < 0.0
+    assert result["x_both_cont"].max() > 1.0
+
+    # discrete variables at the boundaries should not have issues
+    # if inclusive is set properly
+    assert np.isfinite(result["x_lower_disc1"].min())
+    assert np.isfinite(result["x_upper_disc1"].max())
+    assert np.isfinite(result["x_both_disc1"].min())
+    assert np.isfinite(result["x_both_disc1"].max())
+
+    # discrete variables at the boundaries should have issues
+    # if inclusive is not set properly
+    assert np.isneginf(result["x_lower_disc2"][0])
+    assert np.isinf(result["x_upper_disc2"][0])
+    assert np.isneginf(result["x_both_disc2"][0])
+    assert np.isinf(result["x_both_disc2"][-1])