From 0548f4667ad69d3ba9e35df5d723b46e59823971 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 27 Apr 2025 13:27:32 +0000 Subject: [PATCH 1/5] document and expose serialization module * add functools.wraps call to allow_kwargs decorator, as before it was breaking the autodoc functionality --- bayesflow/utils/__init__.py | 3 +- bayesflow/utils/decorators.py | 1 + bayesflow/utils/serialization.py | 87 ++++++++++++++++++++++++++++---- 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 737c533ce..47ab771ff 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -6,6 +6,7 @@ keras_utils, logging, numpy_utils, + serialization, ) from .callbacks import detailed_loss_callback @@ -104,4 +105,4 @@ from ._docs import _add_imports_to_all -_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils"]) +_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization"]) diff --git a/bayesflow/utils/decorators.py b/bayesflow/utils/decorators.py index 7fd32edc9..1283fe66a 100644 --- a/bayesflow/utils/decorators.py +++ b/bayesflow/utils/decorators.py @@ -17,6 +17,7 @@ def allow_args(fn: Decorator) -> Decorator: def wrapper(f: Fn) -> Fn: ... @overload def wrapper(*fargs: any, **fkwargs: any) -> Fn: ... + @wraps(fn) def wrapper(*fargs: any, **fkwargs: any) -> Fn: if len(fargs) == 1 and not fkwargs and callable(fargs[0]): # called without arguments diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py index 500264f05..fb9320a7e 100644 --- a/bayesflow/utils/serialization.py +++ b/bayesflow/utils/serialization.py @@ -92,34 +92,83 @@ def deserialize_value_or_type(config, name): return updated_config -def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs): +def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs): + """Deserialize an object serialized with :py:func:`serialize`. 
+ + Wrapper function around `keras.saving.deserialize_keras_object` to enable deserialization of + classes. + + Parameters + ---------- + config : dict + Python dict describing the object. + custom_objects : dict, optional + Python dict containing a mapping between custom object names and the corresponding + classes or functions. Forwarded to `keras.saving.deserialize_keras_object`. + safe_mode : bool, optional + Boolean, whether to disallow unsafe lambda deserialization. When safe_mode=False, + loading an object has the potential to trigger arbitrary code execution. This argument + is only applicable to the Keras v3 model format. Defaults to True. + Forwarded to `keras.saving.deserialize_keras_object`. + + Returns + ------- + obj : + The object described by the config dictionary. + + Raises + ------ + ValueError + If a type in the config can not be deserialized. + + See Also + -------- + serialize + """ with monkey_patch(deserialize_keras_object, deserialize) as original_deserialize: - if isinstance(obj, str) and obj.startswith(_type_prefix): + if isinstance(config, str) and config.startswith(_type_prefix): # we marked this as a type during serialization - obj = obj[len(_type_prefix) :] + config = config[len(_type_prefix) :] tp = keras.saving.get_registered_object( # TODO: can we pass module objects without overwriting numpy's dict with builtins? - obj, + config, custom_objects=custom_objects, module_objects=np.__dict__ | builtins.__dict__, ) if tp is None: raise ValueError( - f"Could not deserialize type {obj!r}. Make sure it is registered with " + f"Could not deserialize type {config!r}. Make sure it is registered with " f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`." 
) return tp - if inspect.isclass(obj): + if inspect.isclass(config): # add this base case since keras does not cover it - return obj + return config - obj = original_deserialize(obj, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs) + obj = original_deserialize(config, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs) return obj @allow_args -def serializable(cls, package=None, name=None): +def serializable(cls, package: str | None = None, name: str | None = None): + """Register class as Keras serializable. + + Wrapper function around `keras.saving.register_keras_serializable` to automatically + set the `package` and `name` arguments. + + Parameters + ---------- + cls : type + The class to register. + package : str, optional + `package` argument forwarded to `keras.saving.register_keras_serializable`. + If None is provided, the package is automatically inferred using the __name__ + attribute of the module the class resides in. + name : str, optional + `name` argument forwarded to `keras.saving.register_keras_serializable`. + If None is provided, the classe's __name__ attribute is used. + """ if package is None: frame = sys._getframe(1) g = frame.f_globals @@ -133,6 +182,26 @@ def serializable(cls, package=None, name=None): def serialize(obj): + """Serialize an object using Keras. + + Wrapper function around `keras.saving.serialize_keras_object`, which adds the + ability to serialize classes. + + Parameters + ---------- + obj : Keras serializable object, or class + The object to serialize + + Returns + ------- + config : dict + A python dict that represents the object. The python dict can be deserialized via + :py:func:`deserialize`. 
+ + See Also + -------- + deserialize + """ if isinstance(obj, (tuple, list, dict)): return keras.tree.map_structure(serialize, obj) elif inspect.isclass(obj): From 6a92a3183644f92c89c01db2aab48218d54782ae Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 27 Apr 2025 13:34:27 +0000 Subject: [PATCH 2/5] restructure and update developer docs * move content to separate pages * update section on serialization --- docsrc/source/development/index.md | 90 ++++------------------ docsrc/source/development/introduction.md | 12 +++ docsrc/source/development/pitfalls.md | 13 ++++ docsrc/source/development/serialization.md | 28 +++++++ docsrc/source/development/stages.md | 8 ++ 5 files changed, 74 insertions(+), 77 deletions(-) create mode 100644 docsrc/source/development/introduction.md create mode 100644 docsrc/source/development/pitfalls.md create mode 100644 docsrc/source/development/serialization.md create mode 100644 docsrc/source/development/stages.md diff --git a/docsrc/source/development/index.md b/docsrc/source/development/index.md index c62971532..adbadf21f 100644 --- a/docsrc/source/development/index.md +++ b/docsrc/source/development/index.md @@ -1,87 +1,23 @@ -# Patterns & Caveats +# Developer Documentation -**Note**: This document is part of BayesFlow's developer documentation, and +**Attention:** You are looking at BayesFlow's developer documentation, which is aimed at people who want to extend or improve BayesFlow. For user documentation, -please refer to the examples and the public API documentation. +please refer to the {doc}`../examples` and the {doc}`../api/bayesflow`. -## Introduction - -From version 2 on, BayesFlow is built on [Keras](https://keras.io/) v3, which -allows writing machine learning pipelines that run in JAX, TensorFlow and PyTorch. -By using functionality provided by Keras, and extending it with backend-specific -code where necessary, we aim to build BayesFlow in a backend-agnostic fashion as -well. 
- -As Keras is built upon three different backend, each with different functionality -and design decisions, it has its own quirks and compromises. This documents -outlines some of them, along with the design decisions and programming patterns -we use to counter them. - -This document is work in progress, so if you read through the code base and +This section is work in progress, so if you read through the code base and encounter something that looks odd, but shows up in multiple places, please open an issue so that we can add it here. Also, if you introduce a new pattern that others will have to use in the future as well, please document it here, along with some background information on why it is necessary and how to use it in practice. -## Privileged `training` argument in the `call()` method cannot be passed via `kwargs` - -For layers that have different behavior at training and inference time (e.g., -dropout or batch normalization layers), a boolean `training` argument can be -exposed, see [this section of the Keras documentation](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#privileged-training-argument-in-the-call-method). -If we want to pass this manually, we have to do so explicitly and not as part -of a set of keyword arguments via `**kwargs`. - -@Lars: Maybe you can add more details on what is going on behind the scenes. - -## Serialization - -Serialization deals with the problem of storing objects to disk, and loading -them at a later point in time. This is straight-forward for data structures like -numpy arrays, but for classes with custom behavior, like approximators or neural -network layers, it is somewhat more complex. - -Please refer to the Keras guide [Save, serialize, and export models](https://keras.io/guides/serialization_and_saving/) -for an introduction, and [Customizing Saving and Serialization](https://keras.io/guides/customizing_saving_and_serialization/) -for advanced concepts. 
- -The basic idea is: by storing the arguments of the constructor of a class -(i.e., the arguments of the `__init__` function), we can later construct an -object identical to the one we have stored, except for the weights. -As the structure is identical, we can then map the stored weights to the newly -constructed object. The caveat is that all arguments have to be either basic -Python objects (like int, float, string, bool, ...) or themselves serializable. -If they are not, we have to manually specify how to serialize them, and how to -load them later on. - -### Registering classes as serializable - -TODO - -### Serialization of custom types - -In BayesFlow, we often encounter situations where we do not want to pass a -specific object (e.g., an MPL of a certain size), but we want to pass its type -(MLP) and the arguments to construct it. With the type and the arguments, we can -then construct multiple instances of the network in different places, for example -as the network inside a coupling block. - -Unfortunately, `type` is not Keras serializable, so we have to serialize those -arguments manually. To complicate matters further, we also allow passing a string -instead of a type, which is then used to select the correct type. - -To make it more concrete, we look at the `CouplingFlow` class, which takes the -argument `subnet` that provide the type of the subnet. It is either a -string (e.g., `"mlp"`) or a class (e.g., `bayesflow.networks.MLP`). In the first -case, we can just store the value and load it, in the latter case, we first have -to convert the type to a string that we can later convert back into a type. - -We provide two helper functions that can deal with both cases: -`bayesflow.utils.serialize_value_or_type(config, name, obj)` and -`bayesflow.utils.deserialize_value_or_type(config, name)`. -In `get_config`, we use the first to store the object, whereas we use the -latter in `from_config` to load it again. 
+```{toctree} +:maxdepth: 1 +:titlesonly: +:numbered: -As we need all arguments to `__init__` in `get_config`, it can make sense to -build a `config` dictionary in `__init__` already, which can then be stored when -`get_config` is called. Take a look at `CouplingFlow` for an example of that. +introduction +pitfalls +stages +serialization +``` diff --git a/docsrc/source/development/introduction.md b/docsrc/source/development/introduction.md new file mode 100644 index 000000000..a60830c2a --- /dev/null +++ b/docsrc/source/development/introduction.md @@ -0,0 +1,12 @@ +# Introduction + +From version 2 on, BayesFlow is built on [Keras3](https://keras.io/), which +allows writing machine learning pipelines that run in JAX, TensorFlow and PyTorch. +By using functionality provided by Keras, and extending it with backend-specific +code where necessary, we aim to build BayesFlow in a backend-agnostic fashion as +well. + +As Keras is built upon three different backends, each with different functionality +and design decisions, it comes with its own quirks and compromises. The following documents +outline some of them, along with the design decisions and programming patterns +we use to counter them. diff --git a/docsrc/source/development/pitfalls.md b/docsrc/source/development/pitfalls.md new file mode 100644 index 000000000..69d183ec1 --- /dev/null +++ b/docsrc/source/development/pitfalls.md @@ -0,0 +1,13 @@ +# Potential Pitfalls + +This document covers things we have learned during development that might cause problems or hard to find bugs. + +## Privileged `training` argument in the `call()` method cannot be passed via `kwargs` + +For layers that have different behavior at training and inference time (e.g., +dropout or batch normalization layers), a boolean `training` argument can be +exposed, see [this section of the Keras documentation](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#privileged-training-argument-in-the-call-method). 
+If we want to pass this manually, we have to do so explicitly and not as part +of a set of keyword arguments via `**kwargs`. + +@Lars: Maybe you can add more details on what is going on behind the scenes. diff --git a/docsrc/source/development/serialization.md b/docsrc/source/development/serialization.md new file mode 100644 index 000000000..ddcf1074d --- /dev/null +++ b/docsrc/source/development/serialization.md @@ -0,0 +1,28 @@ +# Serialization: Enable Model Saving & Loading + +Serialization deals with the problem of storing objects to disk, and loading them at a later point in time. +This is straight-forward for data structures like numpy arrays, but for classes with custom behavior it is somewhat more complex. + +Please refer to the Keras guide [Save, serialize, and export models](https://keras.io/guides/serialization_and_saving/) for an introduction, and [Customizing Saving and Serialization](https://keras.io/guides/customizing_saving_and_serialization/) for advanced concepts. + +The basic idea is: by storing the arguments of the constructor of a class (i.e., the arguments of the `__init__` function), we can later construct an object similar to the one we have stored, except for the weights and other stateful content. +As the structure is identical, we can then map the stored weights to the newly constructed object. +The caveat is that all arguments have to be either basic Python objects (like int, float, string, bool, ...) or themselves serializable. +If they are not, we have to manually specify how to serialize them, and how to load them later on. +One important example is that types are not serializable. +As we want/need to pass them in some places, we have to resort to some custom behavior, which is described below. + +## Serialization Utilities + +BayesFlow's serialization utilities can be found in the {py:mod}`~bayesflow.utils.serialization` module. 
+We mainly provide three convenience functions: + +- The {py:func}`~bayesflow.utils.serialization.serializable` decorator wraps the `keras.saving.register_keras_serializable` function to provide automatic `package` and `name` arguments. +- The {py:func}`~bayesflow.utils.serialization.serialize` function, which adds support for serializing classes. +- Its counterpart {py:func}`~bayesflow.utils.serialization.deserialize`, adds support to deserialize classes. + +_Note: The `(de)serialize_value_or_type` functions are made obsolete by the functions given above and will probably be deprecated soon._ + +## Usage + +To use the adapted serialization functions, you have to use them in the `get_config` and `from_config` method. Please refer to existing classes in the library for usage examples. diff --git a/docsrc/source/development/stages.md b/docsrc/source/development/stages.md new file mode 100644 index 000000000..e9aa4ad8f --- /dev/null +++ b/docsrc/source/development/stages.md @@ -0,0 +1,8 @@ +# Stages + +To keep track of the phase each functionality is called in, we provide a `stage` parameter. +There are three stages: + +- `training`: The stage to train the approximator (and related stateful objects, like the adapter) +- `validation`: Identical setting to `training`, but calls in this stage should _not_ change the approximator +- `inference`: Calls in this stage should not change the approximator. In addition, the input structure might be different compared to the training phase. For example for sampling, we only provide `summary_conditions` and `inference_conditions`, but not the `inference_variables`, which we want to infer. 
From f1f6329afbd28a6915762798937a3f5eb583d830 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 27 Apr 2025 13:44:43 +0000 Subject: [PATCH 3/5] ci: update pip via python -m pip pip install -U pip setuptools wheel leads to an error: https://github.com/bayesflow-org/bayesflow/actions/runs/14692655483/job/41230057180?pr=449 --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index ab3d03078..eee02895c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -43,7 +43,7 @@ jobs: - name: Install Dependencies run: | - pip install -U pip setuptools wheel + python -m pip install -U pip setuptools wheel pip install .[test] - name: Install JAX From a1b4d19d53474a7d2dfa8d6c2875550b1e1bfea1 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 27 Apr 2025 14:23:30 +0000 Subject: [PATCH 4/5] serializable: increase depth in sys._getframe The functools.wraps decorator adds a frame object to the call stack --- bayesflow/utils/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py index fb9320a7e..01b1cb3ec 100644 --- a/bayesflow/utils/serialization.py +++ b/bayesflow/utils/serialization.py @@ -170,7 +170,7 @@ def serializable(cls, package: str | None = None, name: str | None = None): If None is provided, the classe's __name__ attribute is used. 
""" if package is None: - frame = sys._getframe(1) + frame = sys._getframe(2) g = frame.f_globals package = g.get("__name__", "bayesflow") From 06e3352388226db1fdff043675834be3c06c3f78 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 27 Apr 2025 20:32:12 +0000 Subject: [PATCH 5/5] deprecate (de)serialize_value_or_type - add deprecation warning, remove functionality - replace all occurences with the corresponding new functions --- bayesflow/networks/point_inference_network.py | 14 ++-- bayesflow/scores/scoring_rule.py | 19 ++--- bayesflow/utils/serialization.py | 81 +++---------------- docsrc/source/development/serialization.md | 2 - 4 files changed, 24 insertions(+), 92 deletions(-) diff --git a/bayesflow/networks/point_inference_network.py b/bayesflow/networks/point_inference_network.py index 3b1699e5a..63094a2a8 100644 --- a/bayesflow/networks/point_inference_network.py +++ b/bayesflow/networks/point_inference_network.py @@ -1,11 +1,7 @@ import keras -from keras.saving import ( - deserialize_keras_object as deserialize, - serialize_keras_object as serialize, - register_keras_serializable as serializable, -) -from bayesflow.utils import model_kwargs, find_network, serialize_value_or_type, deserialize_value_or_type +from bayesflow.utils import model_kwargs, find_network +from bayesflow.utils.serialization import deserialize, serializable, serialize from bayesflow.types import Shape, Tensor from bayesflow.scores import ScoringRule, ParametricDistributionScore from bayesflow.utils.decorators import allow_batch_size @@ -30,10 +26,10 @@ def __init__( self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {})) self.config = { + "subnet": serialize(subnet), + "scores": serialize(scores), **kwargs, } - self.config = serialize_value_or_type(self.config, "subnet", subnet) - self.config["scores"] = serialize(self.scores) def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: """Builds all network components based on shapes of conditions 
and targets. @@ -119,7 +115,7 @@ def get_config(self): def from_config(cls, config): config = config.copy() config["scores"] = deserialize(config["scores"]) - config = deserialize_value_or_type(config, "subnet") + config["subnet"] = deserialize(config["subnet"]) return cls(**config) def call( diff --git a/bayesflow/scores/scoring_rule.py b/bayesflow/scores/scoring_rule.py index a1a3f5717..0144de458 100644 --- a/bayesflow/scores/scoring_rule.py +++ b/bayesflow/scores/scoring_rule.py @@ -1,10 +1,10 @@ import math import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor -from bayesflow.utils import find_network, serialize_value_or_type, deserialize_value_or_type +from bayesflow.utils import find_network +from bayesflow.utils.serialization import deserialize, serializable, serialize @serializable(package="bayesflow.scores") @@ -51,23 +51,16 @@ def __init__( self.config = {"subnets_kwargs": self.subnets_kwargs} def get_config(self): - self.config["subnets"] = { - key: serialize_value_or_type({}, "subnet", subnet) for key, subnet in self.subnets.items() - } - self.config["links"] = {key: serialize_value_or_type({}, "link", link) for key, link in self.links.items()} + self.config["subnets"] = {key: serialize(subnet) for key, subnet in self.subnets.items()} + self.config["links"] = {key: serialize(link) for key, link in self.links.items()} return self.config @classmethod def from_config(cls, config): config = config.copy() - config["subnets"] = { - key: deserialize_value_or_type(subnet_dict, "subnet")["subnet"] - for key, subnet_dict in config["subnets"].items() - } - config["links"] = { - key: deserialize_value_or_type(link_dict, "link")["link"] for key, link_dict in config["links"].items() - } + config["subnets"] = {key: deserialize(subnet_dict) for key, subnet_dict in config["subnets"].items()} + config["links"] = {key: deserialize(link_dict) for key, link_dict in config["links"].items()} return 
cls(**config) diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py index 01b1cb3ec..bb55aee41 100644 --- a/bayesflow/utils/serialization.py +++ b/bayesflow/utils/serialization.py @@ -5,6 +5,7 @@ import keras import numpy as np import sys +from warnings import warn # this import needs to be exactly like this to work with monkey patching from keras.saving import deserialize_keras_object @@ -19,77 +20,21 @@ def serialize_value_or_type(config, name, obj): - """Serialize an object that can be either a value or a type - and add it to a copy of the supplied dictionary. - - Parameters - ---------- - config : dict - Dictionary to add the serialized object to. This function does not - modify the dictionary in place, but returns a modified copy. - name : str - Name of the obj that should be stored. Required for later deserialization. - obj : object or type - The object to serialize. If `obj` is of type `type`, we use - `keras.saving.get_registered_name` to obtain the registered type name. - If it is not a type, we try to serialize it as a Keras object. - - Returns - ------- - updated_config : dict - Updated dictionary with a new key `"_bayesflow__type"` or - `"_bayesflow__val"`. The prefix is used to avoid name collisions, - the suffix indicates how the stored value has to be deserialized. - - Notes - ----- - We allow strings or `type` parameters at several places to instantiate objects - of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot - be serialized, we have to distinguish the two cases for serialization and - deserialization. This function is a helper function to standardize and - simplify this. 
- """ - updated_config = config.copy() - if isinstance(obj, type): - updated_config[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj) - else: - updated_config[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj) - return updated_config + """This function is deprecated.""" + warn( + "This method is deprecated. It was replaced by bayesflow.utils.serialization.serialize.", + DeprecationWarning, + stacklevel=2, + ) def deserialize_value_or_type(config, name): - """Deserialize an object that can be either a value or a type and add - it to the supplied dictionary. - - Parameters - ---------- - config : dict - Dictionary containing the object to deserialize. If a type was - serialized, it should contain the key `"_bayesflow__type"`. - If an object was serialized, it should contain the key - `"_bayesflow__val"`. In a copy of this dictionary, - the item will be replaced with the key `name`. - name : str - Name of the object to deserialize. - - Returns - ------- - updated_config : dict - Updated dictionary with a new key `name`, with a value that is either - a type or an object. - - See Also - -------- - serialize_value_or_type - """ - updated_config = config.copy() - if f"{PREFIX}{name}_type" in config: - updated_config[name] = keras.saving.get_registered_object(config[f"{PREFIX}{name}_type"]) - del updated_config[f"{PREFIX}{name}_type"] - elif f"{PREFIX}{name}_val" in config: - updated_config[name] = keras.saving.deserialize_keras_object(config[f"{PREFIX}{name}_val"]) - del updated_config[f"{PREFIX}{name}_val"] - return updated_config + """This function is deprecated.""" + warn( + "This method is deprecated. 
It was replaced by bayesflow.utils.serialization.deserialize.", + DeprecationWarning, + stacklevel=2, + ) def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs): diff --git a/docsrc/source/development/serialization.md b/docsrc/source/development/serialization.md index ddcf1074d..15c812686 100644 --- a/docsrc/source/development/serialization.md +++ b/docsrc/source/development/serialization.md @@ -21,8 +21,6 @@ We mainly provide three convenience functions: - The {py:func}`~bayesflow.utils.serialization.serialize` function, which adds support for serializing classes. - Its counterpart {py:func}`~bayesflow.utils.serialization.deserialize`, adds support to deserialize classes. -_Note: The `(de)serialize_value_or_type` functions are made obsolete by the functions given above and will probably be deprecated soon._ - ## Usage To use the adapted serialization functions, you have to use them in the `get_config` and `from_config` method. Please refer to existing classes in the library for usage examples.