From 2ba93d43f7baa3a68cd78cba6a55c866f1442ad3 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Thu, 10 Apr 2025 07:24:27 +0000 Subject: [PATCH 1/2] Add class for custom transforms to adapter. This commit reintroduces the features that were present in `LambdaTransform`, but only allowing registered functions. While being stricter, that allows for closer scaffolding and raising errors early on, so that users cannot provide functions that will not be (de)serializable later on. As there are a few failure modes, the focus is on providing detailed error messages to enable users to solve problems without external help. --- bayesflow/adapters/adapter.py | 85 +++++++- bayesflow/adapters/transforms/__init__.py | 1 + .../serializable_custom_transform.py | 184 ++++++++++++++++++ tests/test_adapters/conftest.py | 8 + tests/test_adapters/test_adapters.py | 69 +++++++ 5 files changed, 346 insertions(+), 1 deletion(-) create mode 100644 bayesflow/adapters/transforms/serializable_custom_transform.py diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index d0711062c..f3eb7386e 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -1,4 +1,4 @@ -from collections.abc import MutableSequence, Sequence, Mapping +from collections.abc import Callable, MutableSequence, Sequence, Mapping import numpy as np @@ -24,6 +24,7 @@ NumpyTransform, OneHot, Rename, + SerializableCustomTransform, Sqrt, Standardize, ToArray, @@ -283,6 +284,88 @@ def apply( self.transforms.append(transform) return self + def apply_serializable( + self, + include: str | Sequence[str] = None, + *, + serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray], + serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray], + predicate: Predicate = None, + exclude: str | Sequence[str] = None, + **kwargs, + ): + """Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter. + + Parameters + ---------- + serializable_forward_fn : function, no lambda + Registered serializable function to transform the data in the forward pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + serializable_inverse_fn : function, no lambda + Registered serializable function to transform the data in the inverse pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + predicate : Predicate, optional + Function that indicates which variables should be transformed. + include : str or Sequence of str, optional + Names of variables to include in the transform. + exclude : str or Sequence of str, optional + Names of variables to exclude from the transform. + **kwargs : dict + Additional keyword arguments passed to the transform. + + Raises + ------ + ValueError + When the provided functions are not registered serializable functions. + + Notes + ----- + Important: The forward and inverse functions have to be registered with Keras. + To do so, use the `@keras.saving.register_keras_serializable` decorator. + They must also be registered (and identical) when loading the adapter + at a later point in time. + + Examples + -------- + + The example below shows how to use the + `keras.saving.register_keras_serializable` decorator to + register functions with Keras. Note that for this simple + example, one usually would use the simpler :py:meth:`apply` + method. + + >>> import keras + >>> + >>> @keras.saving.register_keras_serializable("custom") + >>> def forward_fn(x): + >>> return x**2 + >>> + >>> @keras.saving.register_keras_serializable("custom") + >>> def inverse_fn(x): + >>> return x**0.5 + >>> + >>> adapter = bf.Adapter().apply_serializable( + >>> "x", + >>> serializable_forward_fn=forward_fn, + >>> serializable_inverse_fn=inverse_fn, + >>> ) + """ + transform = FilterTransform( + transform_constructor=SerializableCustomTransform, + predicate=predicate, + include=include, + exclude=exclude, + serializable_forward_fn=serializable_forward_fn, + serializable_inverse_fn=serializable_inverse_fn, + **kwargs, + ) + self.transforms.append(transform) + return self + def as_set(self, keys: str | Sequence[str]): """Append an :py:class:`~transforms.AsSet` transform to the adapter. diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index b3e95a494..81e9f665f 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -15,6 +15,7 @@ from .one_hot import OneHot from .rename import Rename from .scale import Scale +from .serializable_custom_transform import SerializableCustomTransform from .shift import Shift from .sqrt import Sqrt from .standardize import Standardize diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py new file mode 100644 index 000000000..6337d1f12 --- /dev/null +++ b/bayesflow/adapters/transforms/serializable_custom_transform.py @@ -0,0 +1,184 @@ +from collections.abc import Callable +import numpy as np +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, + get_registered_name, + get_registered_object, +) +from .elementwise_transform import ElementwiseTransform +from ...utils import filter_kwargs +import inspect + + +@serializable(package="bayesflow.adapters") +class SerializableCustomTransform(ElementwiseTransform): + """ + Transforms a parameter using a pair of registered serializable forward and inverse functions. + + Parameters + ---------- + serializable_forward_fn : function, no lambda + Registered serializable function to transform the data in the forward pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + serializable_inverse_fn : function, no lambda + Function to transform the data in the inverse pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + + Raises + ------ + ValueError + When the provided functions are not registered serializable functions. + + Notes + ----- + Important: The forward and inverse functions have to be registered with Keras. + To do so, use the `@keras.saving.register_keras_serializable` decorator. + They must also be registered (and identical) when loading the adapter + at a later point in time. + + """ + + def __init__( + self, + *, + serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray], + serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray], + ): + super().__init__() + + self._check_serializable(serializable_forward_fn, label="serializable_forward_fn") + self._check_serializable(serializable_inverse_fn, label="serializable_inverse_fn") + self._forward = serializable_forward_fn + self._inverse = serializable_inverse_fn + + @classmethod + def _check_serializable(cls, function, label=""): + GENERAL_EXAMPLE_CODE = f"""The example code below shows the structure of a correctly decorated function: + +``` +import keras + +@keras.saving.register_keras_serializable('custom') +def my_{label}(...): + [your code goes here...] +``` +""" + if function is None: + raise TypeError( + f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}" + ) + registered_name = get_registered_name(function) + # check if function is a lambda function + if registered_name == "": + raise ValueError( + f"The provided function for '{label}' is a lambda function, " + "which cannot be serialized. " + "Please provide a registered serializable function by using the " + "@keras.saving.register_keras_serializable decorator." + f"\n{GENERAL_EXAMPLE_CODE}" + ) + if inspect.ismethod(function): + raise ValueError( + f"The provided value for '{label}' is a method, not a function. " + "Methods cannot be serialized separately from their classes. " + "Please provide a registered serializable function instead by " + "moving the functionality to a function (i.e., outside of the class) and " + "using the @keras.saving.register_keras_serializable decorator." + f"\n{GENERAL_EXAMPLE_CODE}" + ) + registered_object_for_name = get_registered_object(registered_name) + if registered_object_for_name is None: + try: + source_max_lines = 5 + function_source_code = inspect.getsource(function).split("\n") + if len(function_source_code) > source_max_lines: + function_source_code = function_source_code[:source_max_lines] + [" [...]"] + + example_code = "For your provided function, this would look like this:\n\n" + example_code += "\n".join( + ["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"] + + function_source_code + + ["```"] + ) + except OSError: + example_code = GENERAL_EXAMPLE_CODE + raise ValueError( + f"The provided function for '{label}' is not registered with Keras.\n" + "Please register the function using the " + "@keras.saving.register_keras_serializable decorator.\n" + f"{example_code}" + ) + if registered_object_for_name is not function: + raise ValueError( + f"The provided function for '{label}' does not match the function " + f"registered under its name '{registered_name}'. " + f"(registered function: {registered_object_for_name}, provided function: {function}). " + ) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform": + if get_registered_object(config["forward"]["config"], custom_objects) is None: + provided_function_msg = "" + if config["_forward_source_code"]: + provided_function_msg = ( + f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```" + ) + raise TypeError( + "\n\nPLEASE READ HERE:\n" + "-----------------\n" + "The forward function that was provided as `serializable_forward_fn` " + "is not registered with Keras, making deserialization impossible. " + f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original " + "function before loading your model." + f"{provided_function_msg}" + ) + if get_registered_object(config["inverse"]["config"], custom_objects) is None: + provided_function_msg = "" + if config["_inverse_source_code"]: + provided_function_msg = ( + f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```" + ) + raise TypeError( + "\n\nPLEASE READ HERE:\n" + "-----------------\n" + "The inverse function that was provided as `serializable_inverse_fn` " + "is not registered with Keras, making deserialization impossible. " + f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original " + "function before loading your model." + f"{provided_function_msg}" + ) + forward = deserialize(config["forward"], custom_objects) + inverse = deserialize(config["inverse"], custom_objects) + return cls( + serializable_forward_fn=forward, + serializable_inverse_fn=inverse, + ) + + def get_config(self) -> dict: + forward_source_code = inverse_source_code = None + try: + forward_source_code = inspect.getsource(self._forward) + inverse_source_code = inspect.getsource(self._inverse) + except OSError: + pass + return { + "forward": serialize(self._forward), + "inverse": serialize(self._inverse), + "_forward_source_code": forward_source_code, + "_inverse_source_code": inverse_source_code, + } + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + # filter kwargs so that other transform args like batch_size, strict, ... are not passed through + kwargs = filter_kwargs(kwargs, self._forward) + return self._forward(data, **kwargs) + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + kwargs = filter_kwargs(kwargs, self._inverse) + return self._inverse(data, **kwargs) diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index abd5797bd..b8365e94a 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -5,6 +5,11 @@ @pytest.fixture() def adapter(): from bayesflow.adapters import Adapter + import keras + + @keras.saving.register_keras_serializable("custom") + def serializable_fn(x): + return x d = ( Adapter() @@ -20,6 +25,9 @@ def adapter(): .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log") .apply(include="p2", forward="log1p") + .apply_serializable( + include="x", serializable_forward_fn=serializable_fn, serializable_inverse_fn=serializable_fn + ) .scale("x", by=[-1, 2]) .shift("x", by=2) .standardize(exclude=["t1", "t2", "o1"]) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 3dea0baf4..69edb6e34 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -3,6 +3,7 @@ serialize_keras_object as serialize, ) import numpy as np +import pytest def test_cycle_consistency(adapter, random_data): @@ -110,3 +111,71 @@ def test_simple_transforms(random_data): assert np.allclose(inverse["t1"], random_data["t1"]) assert np.allclose(inverse["p1"], random_data["p1"]) + + +def test_custom_transform(): + # test that transform raises errors in all relevant cases + import keras + from bayesflow.adapters.transforms import SerializableCustomTransform + from copy import deepcopy + + class A: + @classmethod + def fn(cls, x): + return x + + def not_registered_fn(x): + return x + + @keras.saving.register_keras_serializable("custom") + def registered_fn(x): + return x + + @keras.saving.register_keras_serializable("custom") + def registered_but_changed(x): + return x + + def registered_but_changed(x): # noqa: F811 + return 2 * x + + # method instead of function provided + with pytest.raises(ValueError): + SerializableCustomTransform(serializable_forward_fn=A.fn, serializable_inverse_fn=registered_fn) + SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=A.fn) + + # lambda function provided + with pytest.raises(ValueError): + SerializableCustomTransform(serializable_forward_fn=lambda x: x, serializable_inverse_fn=registered_fn) + SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=lambda x: x) + + # unregistered function provided + with pytest.raises(ValueError): + SerializableCustomTransform(serializable_forward_fn=not_registered_fn, serializable_inverse_fn=registered_fn) + SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=not_registered_fn) + + # function does not match registered function + with pytest.raises(ValueError): + SerializableCustomTransform( + serializable_forward_fn=registered_but_changed, serializable_inverse_fn=registered_fn + ) + SerializableCustomTransform( + serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_but_changed + ) + + transform = SerializableCustomTransform( + serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_fn + ) + serialized_transform = keras.saving.serialize_keras_object(transform) + keras.saving.deserialize_keras_object(serialized_transform) + + # modify name of the forward function so that it cannot be found + corrupt_serialized_transform = deepcopy(serialized_transform) + corrupt_serialized_transform["config"]["forward"]["config"] = "nonexistent" + with pytest.raises(TypeError): + keras.saving.deserialize_keras_object(corrupt_serialized_transform) + + # modify name of the inverse transform so that it cannot be found + corrupt_serialized_transform = deepcopy(serialized_transform) + corrupt_serialized_transform["config"]["inverse"]["config"] = "nonexistent" + with pytest.raises(TypeError): + keras.saving.deserialize_keras_object(corrupt_serialized_transform) From 8fdfde9bc435995939e422c0b067e091c65dfec4 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 11 Apr 2025 16:17:59 +0000 Subject: [PATCH 2/2] custom transform: less verbose naming, fix tests --- bayesflow/adapters/adapter.py | 16 +++---- .../serializable_custom_transform.py | 43 +++++++++---------- tests/test_adapters/conftest.py | 4 +- tests/test_adapters/test_adapters.py | 30 ++++++------- 4 files changed, 44 insertions(+), 49 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index f3eb7386e..6476ffcb4 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -288,8 +288,8 @@ def apply_serializable( self, include: str | Sequence[str] = None, *, - serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray], - serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray], + forward: Callable[[np.ndarray, ...], np.ndarray], + inverse: Callable[[np.ndarray, ...], np.ndarray], predicate: Predicate = None, exclude: str | Sequence[str] = None, **kwargs, @@ -298,12 +298,12 @@ def apply_serializable( Parameters ---------- - serializable_forward_fn : function, no lambda + forward : function, no lambda Registered serializable function to transform the data in the forward pass. For the adapter to be serializable, this function has to be serializable as well (see Notes). Therefore, only proper functions and no lambda functions can be used here. - serializable_inverse_fn : function, no lambda + inverse : function, no lambda Registered serializable function to transform the data in the inverse pass. For the adapter to be serializable, this function has to be serializable as well (see Notes). Therefore, only proper functions and no lambda @@ -350,8 +350,8 @@ def apply_serializable( >>> >>> adapter = bf.Adapter().apply_serializable( >>> "x", - >>> serializable_forward_fn=forward_fn, - >>> serializable_inverse_fn=inverse_fn, + >>> forward=forward_fn, + >>> inverse=inverse_fn, >>> ) """ transform = FilterTransform( @@ -359,8 +359,8 @@ def apply_serializable( predicate=predicate, include=include, exclude=exclude, - serializable_forward_fn=serializable_forward_fn, - serializable_inverse_fn=serializable_inverse_fn, + forward=forward, + inverse=inverse, **kwargs, ) self.transforms.append(transform) diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py index 6337d1f12..75d588afd 100644 --- a/bayesflow/adapters/transforms/serializable_custom_transform.py +++ b/bayesflow/adapters/transforms/serializable_custom_transform.py @@ -19,12 +19,12 @@ class SerializableCustomTransform(ElementwiseTransform): Parameters ---------- - serializable_forward_fn : function, no lambda + forward : function, no lambda Registered serializable function to transform the data in the forward pass. For the adapter to be serializable, this function has to be serializable as well (see Notes). Therefore, only proper functions and no lambda functions can be used here. - serializable_inverse_fn : function, no lambda + inverse : function, no lambda Function to transform the data in the inverse pass. For the adapter to be serializable, this function has to be serializable as well (see Notes). Therefore, only proper functions and no lambda @@ -47,28 +47,27 @@ class SerializableCustomTransform(ElementwiseTransform): def __init__( self, *, - serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray], - serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray], + forward: Callable[[np.ndarray, ...], np.ndarray], + inverse: Callable[[np.ndarray, ...], np.ndarray], ): super().__init__() - self._check_serializable(serializable_forward_fn, label="serializable_forward_fn") - self._check_serializable(serializable_inverse_fn, label="serializable_inverse_fn") - self._forward = serializable_forward_fn - self._inverse = serializable_inverse_fn + self._check_serializable(forward, label="forward") + self._check_serializable(inverse, label="inverse") + self._forward = forward + self._inverse = inverse @classmethod def _check_serializable(cls, function, label=""): - GENERAL_EXAMPLE_CODE = f"""The example code below shows the structure of a correctly decorated function: - -``` -import keras - -@keras.saving.register_keras_serializable('custom') -def my_{label}(...): - [your code goes here...] -``` -""" + GENERAL_EXAMPLE_CODE = ( + "The example code below shows the structure of a correctly decorated function:\n\n" + "```\n" + "import keras\n\n" + "@keras.saving.register_keras_serializable('custom')\n" + f"def my_{label}(...):\n" + " [your code goes here...]\n" + "```\n" + ) if function is None: raise TypeError( f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}" @@ -132,7 +131,7 @@ def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTr raise TypeError( "\n\nPLEASE READ HERE:\n" "-----------------\n" - "The forward function that was provided as `serializable_forward_fn` " + "The forward function that was provided as `forward` " "is not registered with Keras, making deserialization impossible. " f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original " "function before loading your model." @@ -147,7 +146,7 @@ def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTr raise TypeError( "\n\nPLEASE READ HERE:\n" "-----------------\n" - "The inverse function that was provided as `serializable_inverse_fn` " + "The inverse function that was provided as `inverse` " "is not registered with Keras, making deserialization impossible. " f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original " "function before loading your model." @@ -156,8 +155,8 @@ def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTr forward = deserialize(config["forward"], custom_objects) inverse = deserialize(config["inverse"], custom_objects) return cls( - serializable_forward_fn=forward, - serializable_inverse_fn=inverse, + forward=forward, + inverse=inverse, ) def get_config(self) -> dict: diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index b8365e94a..d379d7cc4 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -25,9 +25,7 @@ def serializable_fn(x): .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log") .apply(include="p2", forward="log1p") - .apply_serializable( - include="x", serializable_forward_fn=serializable_fn, serializable_inverse_fn=serializable_fn - ) + .apply_serializable(include="x", forward=serializable_fn, inverse=serializable_fn) .scale("x", by=[-1, 2]) .shift("x", by=2) .standardize(exclude=["t1", "t2", "o1"]) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 69edb6e34..0d17c419f 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -140,31 +140,29 @@ def registered_but_changed(x): # noqa: F811 # method instead of function provided with pytest.raises(ValueError): - SerializableCustomTransform(serializable_forward_fn=A.fn, serializable_inverse_fn=registered_fn) - SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=A.fn) + SerializableCustomTransform(forward=A.fn, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=A.fn) # lambda function provided with pytest.raises(ValueError): - SerializableCustomTransform(serializable_forward_fn=lambda x: x, serializable_inverse_fn=registered_fn) - SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=lambda x: x) + SerializableCustomTransform(forward=lambda x: x, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=lambda x: x) # unregistered function provided with pytest.raises(ValueError): - SerializableCustomTransform(serializable_forward_fn=not_registered_fn, serializable_inverse_fn=registered_fn) - SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=not_registered_fn) + SerializableCustomTransform(forward=not_registered_fn, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=not_registered_fn) # function does not match registered function with pytest.raises(ValueError): - SerializableCustomTransform( - serializable_forward_fn=registered_but_changed, serializable_inverse_fn=registered_fn - ) - SerializableCustomTransform( - serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_but_changed - ) - - transform = SerializableCustomTransform( - serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_fn - ) + SerializableCustomTransform(forward=registered_but_changed, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=registered_but_changed) + + transform = SerializableCustomTransform(forward=registered_fn, inverse=registered_fn) serialized_transform = keras.saving.serialize_keras_object(transform) keras.saving.deserialize_keras_object(serialized_transform)