From fdac5001def635c6bea01c2e2936467747b6f190 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 11 Feb 2025 13:35:28 +0000 Subject: [PATCH 1/3] fix: FM, base_distribution.sample has no argument 'seed' --- bayesflow/networks/flow_matching/flow_matching.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 19bd214d0..db710c005 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -191,7 +191,10 @@ def compute_metrics( else: # not pre-configured, resample x1 = x - x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator) + if not self.base_distribution.built: + # ensure that base distribution is built + self.base_distribution.build(keras.ops.shape(x1)) + x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1]) if self.use_optimal_transport: x1, x0, conditions = optimal_transport( From 1081bbd0f7ecae0d9dac23d689bab8cc1868f187 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 11 Feb 2025 14:19:56 +0000 Subject: [PATCH 2/3] FlowMatching: build self instead of base_distribution --- bayesflow/networks/flow_matching/flow_matching.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index db710c005..07731edbe 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -191,9 +191,10 @@ def compute_metrics( else: # not pre-configured, resample x1 = x - if not self.base_distribution.built: - # ensure that base distribution is built - self.base_distribution.build(keras.ops.shape(x1)) + if not self.built: + xz_shape = keras.ops.shape(x1) + conditions_shape = None if conditions is None else keras.ops.shape(conditions) + self.build(xz_shape, conditions_shape) x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1]) if self.use_optimal_transport: From 0a4921b3ab5d5ba08bdee2453bd1976d5cfaec45 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 11 Feb 2025 14:26:06 +0000 Subject: [PATCH 3/3] fix: make transforms AsSet and AsTimeSeries serializable closes #302 --- bayesflow/adapters/transforms/as_set.py | 7 +++++++ bayesflow/adapters/transforms/as_time_series.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/bayesflow/adapters/transforms/as_set.py b/bayesflow/adapters/transforms/as_set.py index 49f1b47d3..e61828952 100644 --- a/bayesflow/adapters/transforms/as_set.py +++ b/bayesflow/adapters/transforms/as_set.py @@ -33,3 +33,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return np.squeeze(data, axis=2) return data + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "AsSet": + return cls() + + def get_config(self) -> dict: + return {} diff --git a/bayesflow/adapters/transforms/as_time_series.py b/bayesflow/adapters/transforms/as_time_series.py index 15ce04ddf..f52111146 100644 --- a/bayesflow/adapters/transforms/as_time_series.py +++ b/bayesflow/adapters/transforms/as_time_series.py @@ -30,3 +30,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return np.squeeze(data, axis=2) return data + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "AsTimeSeries": + return cls() + + def get_config(self) -> dict: + return {}