Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:

- name: Install Dependencies
run: |
pip install -U pip setuptools wheel
python -m pip install -U pip setuptools wheel
pip install .[test]

- name: Install JAX
Expand Down
14 changes: 5 additions & 9 deletions bayesflow/networks/point_inference_network.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import keras
from keras.saving import (
deserialize_keras_object as deserialize,
serialize_keras_object as serialize,
register_keras_serializable as serializable,
)

from bayesflow.utils import model_kwargs, find_network, serialize_value_or_type, deserialize_value_or_type
from bayesflow.utils import model_kwargs, find_network
from bayesflow.utils.serialization import deserialize, serializable, serialize
from bayesflow.types import Shape, Tensor
from bayesflow.scores import ScoringRule, ParametricDistributionScore
from bayesflow.utils.decorators import allow_batch_size
Expand All @@ -30,10 +26,10 @@ def __init__(
self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))

self.config = {
"subnet": serialize(subnet),
"scores": serialize(scores),
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)
self.config["scores"] = serialize(self.scores)

def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
"""Builds all network components based on shapes of conditions and targets.
Expand Down Expand Up @@ -119,7 +115,7 @@ def get_config(self):
def from_config(cls, config):
config = config.copy()
config["scores"] = deserialize(config["scores"])
config = deserialize_value_or_type(config, "subnet")
config["subnet"] = deserialize(config["subnet"])
return cls(**config)

def call(
Expand Down
19 changes: 6 additions & 13 deletions bayesflow/scores/scoring_rule.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import math

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.utils import find_network, serialize_value_or_type, deserialize_value_or_type
from bayesflow.utils import find_network
from bayesflow.utils.serialization import deserialize, serializable, serialize


@serializable(package="bayesflow.scores")
Expand Down Expand Up @@ -51,23 +51,16 @@ def __init__(
self.config = {"subnets_kwargs": self.subnets_kwargs}

def get_config(self):
self.config["subnets"] = {
key: serialize_value_or_type({}, "subnet", subnet) for key, subnet in self.subnets.items()
}
self.config["links"] = {key: serialize_value_or_type({}, "link", link) for key, link in self.links.items()}
self.config["subnets"] = {key: serialize(subnet) for key, subnet in self.subnets.items()}
self.config["links"] = {key: serialize(link) for key, link in self.links.items()}

return self.config

@classmethod
def from_config(cls, config):
config = config.copy()
config["subnets"] = {
key: deserialize_value_or_type(subnet_dict, "subnet")["subnet"]
for key, subnet_dict in config["subnets"].items()
}
config["links"] = {
key: deserialize_value_or_type(link_dict, "link")["link"] for key, link_dict in config["links"].items()
}
config["subnets"] = {key: deserialize(subnet_dict) for key, subnet_dict in config["subnets"].items()}
config["links"] = {key: deserialize(link_dict) for key, link_dict in config["links"].items()}

return cls(**config)

Expand Down
3 changes: 2 additions & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
keras_utils,
logging,
numpy_utils,
serialization,
)

from .callbacks import detailed_loss_callback
Expand Down Expand Up @@ -104,4 +105,4 @@

from ._docs import _add_imports_to_all

_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils"])
_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization"])
1 change: 1 addition & 0 deletions bayesflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def allow_args(fn: Decorator) -> Decorator:
def wrapper(f: Fn) -> Fn: ...
@overload
def wrapper(*fargs: any, **fkwargs: any) -> Fn: ...
@wraps(fn)
def wrapper(*fargs: any, **fkwargs: any) -> Fn:
if len(fargs) == 1 and not fkwargs and callable(fargs[0]):
# called without arguments
Expand Down
152 changes: 83 additions & 69 deletions bayesflow/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import keras
import numpy as np
import sys
from warnings import warn

# this import needs to be exactly like this to work with monkey patching
from keras.saving import deserialize_keras_object
Expand All @@ -19,109 +20,102 @@


def serialize_value_or_type(config, name, obj):
"""Serialize an object that can be either a value or a type
and add it to a copy of the supplied dictionary.
"""This function is deprecated."""
warn(

Check warning on line 24 in bayesflow/utils/serialization.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/serialization.py#L24

Added line #L24 was not covered by tests
"This method is deprecated. It was replaced by bayesflow.utils.serialization.serialize.",
DeprecationWarning,
stacklevel=2,
)

Parameters
----------
config : dict
Dictionary to add the serialized object to. This function does not
modify the dictionary in place, but returns a modified copy.
name : str
Name of the obj that should be stored. Required for later deserialization.
obj : object or type
The object to serialize. If `obj` is of type `type`, we use
`keras.saving.get_registered_name` to obtain the registered type name.
If it is not a type, we try to serialize it as a Keras object.

Returns
-------
updated_config : dict
Updated dictionary with a new key `"_bayesflow_<name>_type"` or
`"_bayesflow_<name>_val"`. The prefix is used to avoid name collisions,
the suffix indicates how the stored value has to be deserialized.

Notes
-----
We allow strings or `type` parameters at several places to instantiate objects
of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot
be serialized, we have to distinguish the two cases for serialization and
deserialization. This function is a helper function to standardize and
simplify this.
"""
updated_config = config.copy()
if isinstance(obj, type):
updated_config[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj)
else:
updated_config[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj)
return updated_config
def deserialize_value_or_type(config, name):
"""This function is deprecated."""
warn(

Check warning on line 33 in bayesflow/utils/serialization.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/serialization.py#L33

Added line #L33 was not covered by tests
"This method is deprecated. It was replaced by bayesflow.utils.serialization.deserialize.",
DeprecationWarning,
stacklevel=2,
)


def deserialize_value_or_type(config, name):
"""Deserialize an object that can be either a value or a type and add
it to the supplied dictionary.
def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs):
"""Deserialize an object serialized with :py:func:`serialize`.

Wrapper function around `keras.saving.deserialize_keras_object` to enable deserialization of
classes.

Parameters
----------
config : dict
Dictionary containing the object to deserialize. If a type was
serialized, it should contain the key `"_bayesflow_<name>_type"`.
If an object was serialized, it should contain the key
`"_bayesflow_<name>_val"`. In a copy of this dictionary,
the item will be replaced with the key `name`.
name : str
Name of the object to deserialize.
config : dict
Python dict describing the object.
custom_objects : dict, optional
Python dict containing a mapping between custom object names and the corresponding
classes or functions. Forwarded to `keras.saving.deserialize_keras_object`.
safe_mode : bool, optional
Boolean, whether to disallow unsafe lambda deserialization. When safe_mode=False,
loading an object has the potential to trigger arbitrary code execution. This argument
is only applicable to the Keras v3 model format. Defaults to True.
Forwarded to `keras.saving.deserialize_keras_object`.

Returns
-------
updated_config : dict
Updated dictionary with a new key `name`, with a value that is either
a type or an object.
obj :
The object described by the config dictionary.

Raises
------
ValueError
If a type in the config can not be deserialized.

See Also
--------
serialize_value_or_type
serialize
"""
updated_config = config.copy()
if f"{PREFIX}{name}_type" in config:
updated_config[name] = keras.saving.get_registered_object(config[f"{PREFIX}{name}_type"])
del updated_config[f"{PREFIX}{name}_type"]
elif f"{PREFIX}{name}_val" in config:
updated_config[name] = keras.saving.deserialize_keras_object(config[f"{PREFIX}{name}_val"])
del updated_config[f"{PREFIX}{name}_val"]
return updated_config


def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
with monkey_patch(deserialize_keras_object, deserialize) as original_deserialize:
if isinstance(obj, str) and obj.startswith(_type_prefix):
if isinstance(config, str) and config.startswith(_type_prefix):
# we marked this as a type during serialization
obj = obj[len(_type_prefix) :]
config = config[len(_type_prefix) :]
tp = keras.saving.get_registered_object(
# TODO: can we pass module objects without overwriting numpy's dict with builtins?
obj,
config,
custom_objects=custom_objects,
module_objects=np.__dict__ | builtins.__dict__,
)
if tp is None:
raise ValueError(
f"Could not deserialize type {obj!r}. Make sure it is registered with "
f"Could not deserialize type {config!r}. Make sure it is registered with "
f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`."
)
return tp
if inspect.isclass(obj):
if inspect.isclass(config):
# add this base case since keras does not cover it
return obj
return config

Check warning on line 91 in bayesflow/utils/serialization.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/serialization.py#L91

Added line #L91 was not covered by tests

obj = original_deserialize(obj, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)
obj = original_deserialize(config, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)

return obj


@allow_args
def serializable(cls, package=None, name=None):
def serializable(cls, package: str | None = None, name: str | None = None):
"""Register class as Keras serialize.

Wrapper function around `keras.saving.register_keras_serializable` to automatically
set the `package` and `name` arguments.

Parameters
----------
cls : type
The class to register.
package : str, optional
`package` argument forwarded to `keras.saving.register_keras_serializable`.
If None is provided, the package is automatically inferred using the __name__
attribute of the module the class resides in.
name : str, optional
`name` argument forwarded to `keras.saving.register_keras_serializable`.
If None is provided, the classe's __name__ attribute is used.
"""
if package is None:
frame = sys._getframe(1)
frame = sys._getframe(2)
g = frame.f_globals
package = g.get("__name__", "bayesflow")

Expand All @@ -133,6 +127,26 @@


def serialize(obj):
"""Serialize an object using Keras.

Wrapper function around `keras.saving.serialize_keras_object`, which adds the
ability to serialize classes.

Parameters
----------
object : Keras serializable object, or class
The object to serialize

Returns
-------
config : dict
A python dict that represents the object. The python dict can be deserialized via
:py:func:`deserialize`.

See Also
--------
deserialize
"""
if isinstance(obj, (tuple, list, dict)):
return keras.tree.map_structure(serialize, obj)
elif inspect.isclass(obj):
Expand Down
Loading