From af707dc43e186989bfe445db133a58b1e02cc0bf Mon Sep 17 00:00:00 2001 From: LarsKue Date: Wed, 12 Mar 2025 14:45:16 +0100 Subject: [PATCH 01/10] update requirement --- environment.yaml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yaml b/environment.yaml index 6afae599b..c194d8002 100644 --- a/environment.yaml +++ b/environment.yaml @@ -4,7 +4,7 @@ channels: dependencies: - jupyter - jupyterlab - - keras ~= 3.7.0 + - keras ~= 3.9.0 - numpy ~= 1.26 - matplotlib - pre-commit diff --git a/pyproject.toml b/pyproject.toml index c69de6a7d..d46df283d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ license = { file = "LICENSE" } requires-python = ">= 3.10, < 3.12" dependencies = [ - "keras ~= 3.7.0", + "keras ~= 3.9.0", "numpy ~= 1.26.4", "scipy ~= 1.14.1", "matplotlib", From a1354ddc1a1629aa7e959959f9c63c118f1a884b Mon Sep 17 00:00:00 2001 From: LarsKue Date: Wed, 12 Mar 2025 14:45:28 +0100 Subject: [PATCH 02/10] fix adapter serialization --- bayesflow/adapters/transforms/as_set.py | 2 ++ bayesflow/adapters/transforms/expand_dims.py | 2 ++ bayesflow/adapters/transforms/log.py | 2 ++ bayesflow/adapters/transforms/sqrt.py | 2 ++ 4 files changed, 8 insertions(+) diff --git a/bayesflow/adapters/transforms/as_set.py b/bayesflow/adapters/transforms/as_set.py index 2e0fe86e1..2eeeb2bd1 100644 --- a/bayesflow/adapters/transforms/as_set.py +++ b/bayesflow/adapters/transforms/as_set.py @@ -1,8 +1,10 @@ +from keras.saving import register_keras_serializable as serializable import numpy as np from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class AsSet(ElementwiseTransform): """The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets. diff --git a/bayesflow/adapters/transforms/expand_dims.py b/bayesflow/adapters/transforms/expand_dims.py index 6a9519d8e..12796b05a 100644 --- a/bayesflow/adapters/transforms/expand_dims.py +++ b/bayesflow/adapters/transforms/expand_dims.py @@ -1,3 +1,4 @@ +from keras.saving import register_keras_serializable as serializable import numpy as np from keras.saving import ( @@ -8,6 +9,7 @@ from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class ExpandDims(ElementwiseTransform): """ Expand the shape of an array. diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index cefe468b2..8b8b61c17 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -1,3 +1,4 @@ +from keras.saving import register_keras_serializable as serializable import numpy as np from keras.saving import ( @@ -8,6 +9,7 @@ from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class Log(ElementwiseTransform): """Log transforms a variable. diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index 88bb81a08..8f3b069b9 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -1,8 +1,10 @@ +from keras.saving import register_keras_serializable as serializable import numpy as np from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class Sqrt(ElementwiseTransform): """Square-root transform a variable. From 6b03bf1f9df8641f9b03e325bcbe1dddbae776c6 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 18 Mar 2025 12:10:05 +0100 Subject: [PATCH 03/10] run linter and clean up --- bayesflow/adapters/transforms/as_set.py | 1 - bayesflow/adapters/transforms/expand_dims.py | 1 - bayesflow/adapters/transforms/log.py | 1 - bayesflow/adapters/transforms/sqrt.py | 1 - 4 files changed, 4 deletions(-) diff --git a/bayesflow/adapters/transforms/as_set.py b/bayesflow/adapters/transforms/as_set.py index b67637149..2eeeb2bd1 100644 --- a/bayesflow/adapters/transforms/as_set.py +++ b/bayesflow/adapters/transforms/as_set.py @@ -1,6 +1,5 @@ from keras.saving import register_keras_serializable as serializable import numpy as np -from keras.saving import register_keras_serializable as serializable from .elementwise_transform import ElementwiseTransform diff --git a/bayesflow/adapters/transforms/expand_dims.py b/bayesflow/adapters/transforms/expand_dims.py index b18f34be6..eb4d712f4 100644 --- a/bayesflow/adapters/transforms/expand_dims.py +++ b/bayesflow/adapters/transforms/expand_dims.py @@ -1,4 +1,3 @@ -from keras.saving import register_keras_serializable as serializable import numpy as np from keras.saving import ( deserialize_keras_object as deserialize, diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index 81d4130c9..e264fccfa 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -1,4 +1,3 @@ -from keras.saving import register_keras_serializable as serializable import numpy as np from keras.saving import ( deserialize_keras_object as deserialize, diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index fb064d35b..8f3b069b9 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -1,6 +1,5 @@ from keras.saving import register_keras_serializable as serializable import numpy as np -from keras.saving import register_keras_serializable as serializable from .elementwise_transform import ElementwiseTransform From a9a760aa5b46e2e5c86d475e3b832c53f697a1a3 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 18 Mar 2025 12:12:48 +0100 Subject: [PATCH 04/10] make scores serializable --- bayesflow/scores/mean_score.py | 3 +++ bayesflow/scores/median_score.py | 3 +++ bayesflow/scores/multivariate_normal_score.py | 2 ++ bayesflow/scores/normed_difference_score.py | 2 ++ bayesflow/scores/parametric_distribution_score.py | 3 +++ bayesflow/scores/quantile_score.py | 2 ++ bayesflow/scores/scoring_rule.py | 2 ++ 7 files changed, 17 insertions(+) diff --git a/bayesflow/scores/mean_score.py b/bayesflow/scores/mean_score.py index b4f0d265c..553a7c3af 100644 --- a/bayesflow/scores/mean_score.py +++ b/bayesflow/scores/mean_score.py @@ -1,6 +1,9 @@ +from keras.saving import register_keras_serializable as serializable + from .normed_difference_score import NormedDifferenceScore +@serializable(package="bayesflow.scores") class MeanScore(NormedDifferenceScore): r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |^2` diff --git a/bayesflow/scores/median_score.py b/bayesflow/scores/median_score.py index a7ce56845..10c8809c3 100644 --- a/bayesflow/scores/median_score.py +++ b/bayesflow/scores/median_score.py @@ -1,6 +1,9 @@ +from keras.saving import register_keras_serializable as serializable + from .normed_difference_score import NormedDifferenceScore +@serializable(package="bayesflow.scores") class MedianScore(NormedDifferenceScore): r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |` diff --git a/bayesflow/scores/multivariate_normal_score.py b/bayesflow/scores/multivariate_normal_score.py index 5b8de9a88..96657986e 100644 --- a/bayesflow/scores/multivariate_normal_score.py +++ b/bayesflow/scores/multivariate_normal_score.py @@ -1,6 +1,7 @@ import math import keras +from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.links import PositiveSemiDefinite @@ -9,6 +10,7 @@ from .parametric_distribution_score import ParametricDistributionScore +@serializable(package="bayesflow.scores") class MultivariateNormalScore(ParametricDistributionScore): r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = \log( \mathcal N (\theta; \mu, \Sigma))` diff --git a/bayesflow/scores/normed_difference_score.py b/bayesflow/scores/normed_difference_score.py index 9c7116446..6b9dbe163 100644 --- a/bayesflow/scores/normed_difference_score.py +++ b/bayesflow/scores/normed_difference_score.py @@ -1,10 +1,12 @@ import keras +from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from .scoring_rule import ScoringRule +@serializable(package="bayesflow.scores") class NormedDifferenceScore(ScoringRule): r""":math:`S(\hat \theta, \theta; k) = | \hat \theta - \theta |^k` diff --git a/bayesflow/scores/parametric_distribution_score.py b/bayesflow/scores/parametric_distribution_score.py index e5aaacc65..51cef1776 100644 --- a/bayesflow/scores/parametric_distribution_score.py +++ b/bayesflow/scores/parametric_distribution_score.py @@ -1,8 +1,11 @@ +from keras.saving import register_keras_serializable as serializable + from bayesflow.types import Tensor from .scoring_rule import ScoringRule +@serializable(package="bayesflow.scores") class ParametricDistributionScore(ScoringRule): r""":math:`S(\hat p_\phi, \theta; k) = \log(\hat p_\phi(\theta))` diff --git a/bayesflow/scores/quantile_score.py b/bayesflow/scores/quantile_score.py index 811f80d9d..2e3ec54ef 100644 --- a/bayesflow/scores/quantile_score.py +++ b/bayesflow/scores/quantile_score.py @@ -1,6 +1,7 @@ from typing import Sequence import keras +from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.utils import logging @@ -9,6 +10,7 @@ from .scoring_rule import ScoringRule +@serializable(package="bayesflow.scores") class QuantileScore(ScoringRule): r""":math:`S(\hat \theta_i, \theta; \tau_i) = (\hat \theta_i - \theta)(\mathbf{1}_{\hat \theta - \theta > 0} - \tau_i)` diff --git a/bayesflow/scores/scoring_rule.py b/bayesflow/scores/scoring_rule.py index e7f7385a5..ef0645cc1 100644 --- a/bayesflow/scores/scoring_rule.py +++ b/bayesflow/scores/scoring_rule.py @@ -1,11 +1,13 @@ import math import keras +from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.utils import find_network, serialize_value_or_type, deserialize_value_or_type +@serializable(package="bayesflow.scores") class ScoringRule: """Base class for scoring rules. From 559b98d3dd7faddaf3910fcd4827e7ef3aded568 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 18 Mar 2025 14:22:03 +0100 Subject: [PATCH 05/10] fix unserializable link --- bayesflow/links/positive_semi_definite.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bayesflow/links/positive_semi_definite.py b/bayesflow/links/positive_semi_definite.py index b447de57b..a2215b6c6 100644 --- a/bayesflow/links/positive_semi_definite.py +++ b/bayesflow/links/positive_semi_definite.py @@ -1,8 +1,10 @@ import keras +from keras.saving import register_keras_serializable as serializable from bayesflow.utils import keras_kwargs +@serializable(package="bayesflow.links") class PositiveSemiDefinite(keras.Layer): """Activation function to link from any square matrix to a positive semidefinite matrix.""" From 5c53b2eb56a233d042a816b77a48487cc5317fb2 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 18 Mar 2025 14:22:11 +0100 Subject: [PATCH 06/10] use runtimeerror instead of assert --- bayesflow/links/ordered.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bayesflow/links/ordered.py b/bayesflow/links/ordered.py index c683be08c..420555833 100644 --- a/bayesflow/links/ordered.py +++ b/bayesflow/links/ordered.py @@ -12,6 +12,7 @@ def __init__(self, axis: int, anchor_index: int, **kwargs): super().__init__(**keras_kwargs(kwargs)) self.axis = axis self.anchor_index = anchor_index + self.group_indices = None self.config = {"axis": axis, "anchor_index": anchor_index, **kwargs} @@ -22,9 +23,9 @@ def get_config(self): def build(self, input_shape): super().build(input_shape) - assert self.anchor_index % input_shape[self.axis] != 0 and self.anchor_index != -1, ( - "anchor should not be first or last index." - ) + if self.anchor_index % input_shape[self.axis] != 0 and self.anchor_index != -1: + raise RuntimeError("Anchor should not be first or last index.") + self.group_indices = dict( below=list(range(0, self.anchor_index)), above=list(range(self.anchor_index + 1, input_shape[self.axis])), From 818f517b74aa996391bd2f6d63931205f15413a9 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 18 Mar 2025 14:23:18 +0100 Subject: [PATCH 07/10] use raise instead of assert --- bayesflow/utils/_docs/_populate_all.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bayesflow/utils/_docs/_populate_all.py b/bayesflow/utils/_docs/_populate_all.py index 28cce4265..50da8048b 100644 --- a/bayesflow/utils/_docs/_populate_all.py +++ b/bayesflow/utils/_docs/_populate_all.py @@ -3,7 +3,9 @@ def _add_imports_to_all(include_modules: bool | list[str] = False, exclude: list[str] | None = None): """Add all global variables to __all__""" - assert type(include_modules) in [bool, list] + if not isinstance(include_modules, (bool, list)): + raise ValueError("include_modules must be a boolean or a list of strings") + exclude = exclude or [] calling_module = inspect.stack()[1] local_stack = calling_module[0] From a807aee82d45c6b034f133c1d6e695cffb3ffc5f Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 18 Mar 2025 14:31:30 +0100 Subject: [PATCH 08/10] clean up asserts -> raise also make some code backend agnostic --- bayesflow/links/ordered_quantiles.py | 21 ++++++++++--------- .../consistency_models/consistency_model.py | 5 ++++- .../networks/embeddings/fourier_embedding.py | 5 ++++- bayesflow/scores/multivariate_normal_score.py | 11 +++++++--- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/bayesflow/links/ordered_quantiles.py b/bayesflow/links/ordered_quantiles.py index 944daf7bb..a487a6a01 100644 --- a/bayesflow/links/ordered_quantiles.py +++ b/bayesflow/links/ordered_quantiles.py @@ -44,16 +44,17 @@ def build(self, input_shape): else: # choose quantile level closest to median as anchor index self.anchor_index = keras.ops.argmin(keras.ops.abs(keras.ops.convert_to_tensor(self.q) - 0.5)) - msg = ( - "Length of `q` does not coincide with input shape: " - f"len(q)={len(self.q)}, position {self.axis} of shape={input_shape}" - ) - assert num_quantile_levels == len(self.q), msg - msg = ( - "The link function `OrderedQuantiles` expects at least 3 quantile levels," - f" but only {num_quantile_levels} were given." - ) - assert self.anchor_index not in (0, -1, num_quantile_levels - 1), msg + if len(self.q) != num_quantile_levels: + raise RuntimeError( + f"Length of `q` does not coincide with input shape: len(q)={len(self.q)}, " + f"position {self.axis} of shape={input_shape}" + ) + + if self.anchor_index not in [0, -1, num_quantile_levels - 1]: + raise RuntimeError( + f"The link function `OrderedQuantiles` expects at least 3 quantile levels, " + f"but only {num_quantile_levels} were given." + ) super().build(input_shape) diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index af3b8aa2e..9057ca07b 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -177,7 +177,10 @@ def build(self, xz_shape, conditions_shape=None): # First, we calculate all unique numbers of discretization steps n # in a loop, as self.total_steps might be large self.max_n = int(self._schedule_discretization(self.total_steps)) - assert self.max_n == self.s1 + 1 + + if self.max_n != self.s1 + 1: + raise ValueError("The maximum number of discretization steps must be equal to s1 + 1.") + unique_n = set() for step in range(int(self.total_steps)): unique_n.add(int(self._schedule_discretization(step))) diff --git a/bayesflow/networks/embeddings/fourier_embedding.py b/bayesflow/networks/embeddings/fourier_embedding.py index 5db7b0f6d..b43253995 100644 --- a/bayesflow/networks/embeddings/fourier_embedding.py +++ b/bayesflow/networks/embeddings/fourier_embedding.py @@ -39,7 +39,10 @@ def __init__( """ super().__init__(**kwargs) - assert embed_dim % 2 == 0, f"Embedding dimension must be even, but is {embed_dim}." + + if embed_dim % 2 != 0: + raise ValueError(f"Embedding dimension must be even, but is {embed_dim}.") + self.w = self.add_weight(initializer=initializer, shape=(embed_dim // 2,), trainable=trainable) self.scale = scale self.embed_dim = embed_dim diff --git a/bayesflow/scores/multivariate_normal_score.py b/bayesflow/scores/multivariate_normal_score.py index 96657986e..66153fd34 100644 --- a/bayesflow/scores/multivariate_normal_score.py +++ b/bayesflow/scores/multivariate_normal_score.py @@ -98,9 +98,14 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor A tensor of shape (batch_size, num_samples, D) containing the generated samples. """ batch_size, num_samples = batch_shape - dim = mean.shape[-1] - assert mean.shape == (batch_size, dim), "mean must have shape (batch_size, D)" - assert covariance.shape == (batch_size, dim, dim), "covariance must have shape (batch_size, D, D)" + dim = keras.ops.shape(mean)[-1] + if keras.ops.shape(mean) != (batch_size, dim): + raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}") + + if keras.ops.shape(covariance) != (batch_size, dim, dim): + raise ValueError( + f"covariance must have shape (batch_size, {dim}, {dim}), but got {keras.ops.shape(covariance)}" + ) # Use Cholesky decomposition to generate samples cholesky_factor = keras.ops.cholesky(covariance) From 3afd6cd776d32094b39e111a60ef10ff86bac1c9 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 18 Mar 2025 14:40:43 +0100 Subject: [PATCH 09/10] fix incorrect error check --- bayesflow/links/ordered_quantiles.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/links/ordered_quantiles.py b/bayesflow/links/ordered_quantiles.py index a487a6a01..ce86309cc 100644 --- a/bayesflow/links/ordered_quantiles.py +++ b/bayesflow/links/ordered_quantiles.py @@ -51,7 +51,7 @@ def build(self, input_shape): f"position {self.axis} of shape={input_shape}" ) - if self.anchor_index not in [0, -1, num_quantile_levels - 1]: + if self.anchor_index in [0, -1, num_quantile_levels - 1]: raise RuntimeError( f"The link function `OrderedQuantiles` expects at least 3 quantile levels, " f"but only {num_quantile_levels} were given." From a390ba84ea21c5791d5708fac40587ab04e6ee4a Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 18 Mar 2025 14:54:06 +0100 Subject: [PATCH 10/10] fix another incorrect error check --- bayesflow/links/ordered.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/links/ordered.py b/bayesflow/links/ordered.py index 420555833..0d7dafba4 100644 --- a/bayesflow/links/ordered.py +++ b/bayesflow/links/ordered.py @@ -23,7 +23,7 @@ def get_config(self): def build(self, input_shape): super().build(input_shape) - if self.anchor_index % input_shape[self.axis] != 0 and self.anchor_index != -1: + if self.anchor_index % input_shape[self.axis] == 0 or self.anchor_index == -1: raise RuntimeError("Anchor should not be first or last index.") self.group_indices = dict(