Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,17 @@ def constrain(
lower: int | float | np.ndarray = None,
upper: int | float | np.ndarray = None,
method: str = "default",
inclusive: str = "both",
epsilon: float = 1e-15,
):
if isinstance(keys, str):
keys = [keys]

transform = MapTransform(
transform_map={key: Constrain(lower=lower, upper=upper, method=method) for key in keys}
transform_map={
key: Constrain(lower=lower, upper=upper, method=method, inclusive=inclusive, epsilon=epsilon)
for key in keys
}
)
self.transforms.append(transform)
return self
Expand Down
51 changes: 42 additions & 9 deletions bayesflow/adapters/transforms/constrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@serializable(package="bayesflow.adapters")
class Constrain(ElementwiseTransform):
"""
Constrains neural network predictions of a data variable to specificied bounds.
Constrains neural network predictions of a data variable to specified bounds.

Parameters:
String containing the name of the data variable to be transformed e.g. "sigma". See examples below.
Expand All @@ -28,14 +28,21 @@ class Constrain(ElementwiseTransform):
- Double bounded methods: sigmoid, expit, (default = sigmoid)
- Lower bound only methods: softplus, exp, (default = softplus)
- Upper bound only methods: softplus, exp, (default = softplus)

inclusive: Indicates which bounds are inclusive (or exclusive).
- "both" (default): Both lower and upper bounds are inclusive.
- "lower": Lower bound is inclusive, upper bound is exclusive.
- "upper": Lower bound is exclusive, upper bound is inclusive.
- "none": Both lower and upper bounds are exclusive.
epsilon: Small value to ensure inclusive bounds are not violated.
Current default is 1e-15 as this ensures finite outcomes
with the default transformations applied to data exactly at the boundaries.


Examples:
1) Let sigma be the standard deviation of a normal distribution,
then sigma should always be greater than zero.

Useage:
Usage:
adapter = (
bf.Adapter()
.constrain("sigma", lower=0)
Expand All @@ -45,14 +52,19 @@ class Constrain(ElementwiseTransform):
[0,1] then we would constrain the neural network to estimate p in the following way.

Usage:
adapter = (
bf.Adapter()
.constrain("p", lower=0, upper=1, method = "sigmoid")
)
>>> import bayesflow as bf
>>> adapter = bf.Adapter()
>>> adapter.constrain("p", lower=0, upper=1, method="sigmoid", inclusive="both")
"""

def __init__(
self, *, lower: int | float | np.ndarray = None, upper: int | float | np.ndarray = None, method: str = "default"
self,
*,
lower: int | float | np.ndarray = None,
upper: int | float | np.ndarray = None,
method: str = "default",
inclusive: str = "both",
epsilon: float = 1e-15,
):
super().__init__()

Expand Down Expand Up @@ -121,12 +133,31 @@ def unconstrain(x):

self.lower = lower
self.upper = upper

self.method = method
self.inclusive = inclusive
self.epsilon = epsilon

self.constrain = constrain
self.unconstrain = unconstrain

# do this last to avoid serialization issues
match inclusive:
case "lower":
if lower is not None:
lower = lower - epsilon
case "upper":
if upper is not None:
upper = upper + epsilon
case True | "both":
if lower is not None:
lower = lower - epsilon
if upper is not None:
upper = upper + epsilon
case False | None | "none":
pass
case other:
raise ValueError(f"Unsupported value for 'inclusive': {other!r}.")

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Constrain":
return cls(**config)
Expand All @@ -136,6 +167,8 @@ def get_config(self) -> dict:
"lower": self.lower,
"upper": self.upper,
"method": self.method,
"inclusive": self.inclusive,
"epsilon": self.epsilon,
}

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
Expand Down
56 changes: 56 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,59 @@ def test_serialize_deserialize(adapter, custom_objects, random_data):
deserialized_processed = deserialized(random_data)
for key, value in processed.items():
assert np.allclose(value, deserialized_processed[key])


def test_constrain():
# check if constraint-implied transforms are applied correctly
import numpy as np
import warnings
from bayesflow.adapters import Adapter

data = {
"x_lower_cont": np.random.exponential(1, size=(32, 1)),
"x_upper_cont": -np.random.exponential(1, size=(32, 1)),
"x_both_cont": np.random.beta(0.5, 0.5, size=(32, 1)),
"x_lower_disc1": np.zeros(shape=(32, 1)),
"x_lower_disc2": np.zeros(shape=(32, 1)),
"x_upper_disc1": np.ones(shape=(32, 1)),
"x_upper_disc2": np.ones(shape=(32, 1)),
"x_both_disc1": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
"x_both_disc2": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
}

adapter = (
Adapter()
.constrain("x_lower_cont", lower=0)
.constrain("x_upper_cont", upper=0)
.constrain("x_both_cont", lower=0, upper=1)
.constrain("x_lower_disc1", lower=0, inclusive="lower")
.constrain("x_lower_disc2", lower=0, inclusive="none")
.constrain("x_upper_disc1", upper=1, inclusive="upper")
.constrain("x_upper_disc2", upper=1, inclusive="none")
.constrain("x_both_disc1", lower=0, upper=1, inclusive="both")
.constrain("x_both_disc2", lower=0, upper=1, inclusive="none")
)

with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
result = adapter(data)

# continuous variables should not have boundary issues
assert result["x_lower_cont"].min() < 0.0
assert result["x_upper_cont"].max() > 0.0
assert result["x_both_cont"].min() < 0.0
assert result["x_both_cont"].max() > 1.0

# discrete variables at the boundaries should not have issues
# if inclusive is set properly
assert np.isfinite(result["x_lower_disc1"].min())
assert np.isfinite(result["x_upper_disc1"].max())
assert np.isfinite(result["x_both_disc1"].min())
assert np.isfinite(result["x_both_disc1"].max())

# discrete variables at the boundaries should have issues
# if inclusive is not set properly
assert np.isneginf(result["x_lower_disc2"][0])
assert np.isinf(result["x_upper_disc2"][0])
assert np.isneginf(result["x_both_disc2"][0])
assert np.isinf(result["x_both_disc2"][-1])