From 8fa1cd0e24fec16524001150fdd4e5c4aa1f4fe6 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:29:31 +0100 Subject: [PATCH 01/19] add summary networks to tests --- tests/test_networks/conftest.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index e69de29bb..6eff398f5 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -0,0 +1,30 @@ +import pytest + + +@pytest.fixture() +def deep_set(): + from bayesflow.networks import DeepSet + + return DeepSet() + + +@pytest.fixture() +def lst_net(): + from bayesflow.networks import LSTNet + + return LSTNet() + + +@pytest.fixture() +def set_transformer(): + from bayesflow.networks import SetTransformer + + return SetTransformer() + + +@pytest.fixture(params=[None, "deep_set", "lst_net", "set_transformer"]) +def summary_network(request): + if request.param is None: + return None + + return request.getfixturevalue(request.param) From 3f0c6e251674a135809eafb85c2d1cb725f3aaa4 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:30:39 +0100 Subject: [PATCH 02/19] improve serialization tests by splitting config and full model loading --- .../test_networks/test_inference_networks.py | 24 ++++++++++++++++--- tests/test_networks/test_summary_networks.py | 19 ++++++++++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index 67b91c859..85180ca35 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -1,6 +1,10 @@ import keras import numpy as np import pytest +from keras.saving import ( + deserialize_keras_object as deserialize, + serialize_keras_object as serialize, +) from tests.utils import allclose, assert_layers_equal @@ -121,11 +125,25 @@ def f(x): assert allclose(inverse_log_density, numerical_inverse_log_density, rtol=1e-4, atol=1e-5) -def test_serialize_deserialize(tmp_path, inference_network, random_samples, random_conditions): +def test_serialize_deserialize(inference_network, random_samples, random_conditions): # to save, the model must be built inference_network(random_samples, conditions=random_conditions) - keras.saving.save_model(inference_network, tmp_path / "model.keras") + serialized = serialize(inference_network) + deserialized = deserialize(serialized) + reserialized = serialize(deserialized) + + assert serialized == reserialized + assert_layers_equal(inference_network, deserialized) + + +def test_save_and_load(tmp_path, summary_network, random_set): + if summary_network is None: + pytest.skip() + + summary_network.build(keras.ops.shape(random_set)) + + keras.saving.save_model(summary_network, tmp_path / "model.keras") loaded = keras.saving.load_model(tmp_path / "model.keras") - assert_layers_equal(inference_network, loaded) + assert_layers_equal(summary_network, loaded) diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index 8c371fa8c..c2abcb0d6 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -1,6 +1,10 @@ import keras import numpy as np import pytest +from keras.saving import ( + deserialize_keras_object as deserialize, + serialize_keras_object as serialize, +) from tests.utils import assert_layers_equal @@ -52,7 +56,20 @@ def test_variable_set_size(summary_network, random_set): summary_network(new_input) -def test_serialize_deserialize(tmp_path, summary_network, random_set): +def test_serialize_deserialize(summary_network, random_set): + if summary_network is None: + pytest.skip() + + summary_network.build(keras.ops.shape(random_set)) + + serialized = serialize(summary_network) + deserialized = deserialize(serialized) + reserialized = serialize(deserialized) + + assert serialized == reserialized + + +def test_save_and_load(tmp_path, summary_network, random_set): if summary_network is None: pytest.skip() From 13174a01262043b6490289f00c95e97ce52b156a Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:31:22 +0100 Subject: [PATCH 03/19] remove non-serializable type hints --- bayesflow/networks/deep_set/deep_set.py | 12 ++++++------ bayesflow/networks/lstnet/lstnet.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py index aeb42fa74..cdf9a8b89 100644 --- a/bayesflow/networks/deep_set/deep_set.py +++ b/bayesflow/networks/deep_set/deep_set.py @@ -27,12 +27,12 @@ def __init__( self, summary_dim: int = 16, depth: int = 2, - inner_pooling: str | keras.Layer = "mean", - output_pooling: str | keras.Layer = "mean", - mlp_widths_equivariant: tuple = (128, 128), - mlp_widths_invariant_inner: tuple = (128, 128), - mlp_widths_invariant_outer: tuple = (128, 128), - mlp_widths_invariant_last: tuple = (128, 128), + inner_pooling: str = "mean", + output_pooling: str = "mean", + mlp_widths_equivariant: Sequence[int] = (128, 128), + mlp_widths_invariant_inner: Sequence[int] = (128, 128), + mlp_widths_invariant_outer: Sequence[int] = (128, 128), + mlp_widths_invariant_last: Sequence[int] = (128, 128), activation: str = "gelu", kernel_initializer: str = "he_normal", dropout: int | float | None = 0.05, diff --git a/bayesflow/networks/lstnet/lstnet.py b/bayesflow/networks/lstnet/lstnet.py index 8680faec4..7eb525cac 100644 --- a/bayesflow/networks/lstnet/lstnet.py +++ b/bayesflow/networks/lstnet/lstnet.py @@ -30,7 +30,7 @@ def __init__( activation: str = "mish", kernel_initializer: str = "glorot_uniform", groups: int = 8, - recurrent_type: str | keras.Layer = "gru", + recurrent_type: str = "gru", recurrent_dim: int = 128, bidirectional: bool = True, dropout: float = 0.05, From cfed356a77c18a94131c1851463de2b180d9a832 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:32:26 +0100 Subject: [PATCH 04/19] fix deep set by removing inner model --- bayesflow/networks/deep_set/deep_set.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py index cdf9a8b89..50285a38d 100644 --- a/bayesflow/networks/deep_set/deep_set.py +++ b/bayesflow/networks/deep_set/deep_set.py @@ -46,7 +46,7 @@ def __init__( super().__init__(**kwargs) # Stack of equivariant modules for a many-to-many learnable transformation - self.equivariant_modules = keras.Sequential() + self.equivariant_modules = [] for _ in range(depth): equivariant_module = EquivariantModule( mlp_widths_equivariant=mlp_widths_equivariant, @@ -59,7 +59,7 @@ def __init__( pooling=inner_pooling, **kwargs, ) - self.equivariant_modules.add(equivariant_module) + self.equivariant_modules.append(equivariant_module) # Invariant module for a many-to-one transformation self.invariant_module = InvariantModule( @@ -81,13 +81,16 @@ def build(self, input_shape): super().build(input_shape) self.call(keras.ops.zeros(input_shape)) - def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: + def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: """Performs the forward pass of a learnable deep invariant transformation consisting of a sequence of equivariant transforms followed by an invariant transform. #TODO """ - x = self.equivariant_modules(input_set, training=training) + + for em in self.equivariant_modules: + x = em(x, training=training) + x = self.invariant_module(x, training=training) return self.output_projector(x) From a6527dc0415df3e7bb1de8f925ec68c23d6f094e Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:32:47 +0100 Subject: [PATCH 05/19] remove incorrect get_config method on deep set --- bayesflow/networks/deep_set/deep_set.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py index 50285a38d..07071be8c 100644 --- a/bayesflow/networks/deep_set/deep_set.py +++ b/bayesflow/networks/deep_set/deep_set.py @@ -94,15 +94,3 @@ def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: x = self.invariant_module(x, training=training) return self.output_projector(x) - - def get_config(self): - base_config = super().get_config() - - config = { - "invariant_module": serialize(self.equivariant_modules), - "equivariant_fc": serialize(self.invariant_module), - "output_projector": serialize(self.output_projector), - "summary_dim": serialize(self.summary_dim), - } - - return base_config | config From 00ce8f1cb0a0492295171b0a9614578fef9db46e Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:33:05 +0100 Subject: [PATCH 06/19] general clean-up --- bayesflow/networks/deep_set/deep_set.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py index 07071be8c..adc8c6d33 100644 --- a/bayesflow/networks/deep_set/deep_set.py +++ b/bayesflow/networks/deep_set/deep_set.py @@ -1,12 +1,11 @@ -import keras -from keras import layers -from keras.saving import register_keras_serializable as serializable, serialize_keras_object as serialize +from collections.abc import Sequence +import keras +from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor -from .invariant_module import InvariantModule from .equivariant_module import EquivariantModule - +from .invariant_module import InvariantModule from ..summary_network import SummaryNetwork @@ -74,7 +73,7 @@ def __init__( ) # Output linear layer to project set representation down to "summary_dim" learned summary statistics - self.output_projector = layers.Dense(summary_dim, activation="linear") + self.output_projector = keras.layers.Dense(summary_dim, activation="linear") self.summary_dim = summary_dim def build(self, input_shape): From 62bd4ba4cc8da908435d00d033388aab349fd6d5 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:33:31 +0100 Subject: [PATCH 07/19] improve type hints --- bayesflow/networks/deep_set/equivariant_module.py | 10 ++++++---- bayesflow/networks/deep_set/invariant_module.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/bayesflow/networks/deep_set/equivariant_module.py b/bayesflow/networks/deep_set/equivariant_module.py index 1b9205cb2..be6af31ab 100644 --- a/bayesflow/networks/deep_set/equivariant_module.py +++ b/bayesflow/networks/deep_set/equivariant_module.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + import keras from keras import ops, layers from keras.saving import register_keras_serializable as serializable @@ -19,10 +21,10 @@ class EquivariantModule(keras.Layer): def __init__( self, - mlp_widths_equivariant: tuple = (128, 128), - mlp_widths_invariant_inner: tuple = (128, 128), - mlp_widths_invariant_outer: tuple = (128, 128), - pooling: str | keras.Layer = "mean", + mlp_widths_equivariant: Sequence[int] = (128, 128), + mlp_widths_invariant_inner: Sequence[int] = (128, 128), + mlp_widths_invariant_outer: Sequence[int] = (128, 128), + pooling: str = "mean", activation: str = "gelu", kernel_initializer: str = "he_normal", dropout: int | float | None = 0.05, diff --git a/bayesflow/networks/deep_set/invariant_module.py b/bayesflow/networks/deep_set/invariant_module.py index 28444372a..3f806fb0e 100644 --- a/bayesflow/networks/deep_set/invariant_module.py +++ b/bayesflow/networks/deep_set/invariant_module.py @@ -1,10 +1,12 @@ +from collections.abc import Sequence + import keras from keras import layers from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor -from bayesflow.utils import keras_kwargs from bayesflow.utils import find_pooling +from bayesflow.utils import keras_kwargs @serializable(package="bayesflow.networks") @@ -19,12 +21,12 @@ class InvariantModule(keras.Layer): def __init__( self, - mlp_widths_inner: tuple = (128, 128), - mlp_widths_outer: tuple = (128, 128), + mlp_widths_inner: Sequence[int] = (128, 128), + mlp_widths_outer: Sequence[int] = (128, 128), activation: str = "gelu", kernel_initializer: str = "he_normal", dropout: int | float | None = 0.05, - pooling: str | keras.Layer = "mean", + pooling: str = "mean", spectral_normalization: bool = False, **kwargs, ): From 3f2c87854ff68e4b263e4593807e9653faa31234 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:34:12 +0100 Subject: [PATCH 08/19] fix LSTNet padding --- bayesflow/networks/lstnet/lstnet.py | 1 + bayesflow/networks/lstnet/skip_recurrent.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/bayesflow/networks/lstnet/lstnet.py b/bayesflow/networks/lstnet/lstnet.py index 7eb525cac..e69c59ccd 100644 --- a/bayesflow/networks/lstnet/lstnet.py +++ b/bayesflow/networks/lstnet/lstnet.py @@ -55,6 +55,7 @@ def __init__( strides=s, activation=activation, kernel_initializer=kernel_initializer, + padding="same", ) ) self.conv_blocks.add(layers.GroupNormalization(groups=groups)) diff --git a/bayesflow/networks/lstnet/skip_recurrent.py b/bayesflow/networks/lstnet/skip_recurrent.py index 6cf95ed7b..e6f78376e 100644 --- a/bayesflow/networks/lstnet/skip_recurrent.py +++ b/bayesflow/networks/lstnet/skip_recurrent.py @@ -32,7 +32,10 @@ def __init__( super().__init__(**keras_kwargs(kwargs)) self.skip_conv = keras.layers.Conv1D( - filters=input_channels * skip_steps, kernel_size=skip_steps, strides=skip_steps + filters=input_channels * skip_steps, + kernel_size=skip_steps, + strides=skip_steps, + padding="same", ) recurrent_constructor = find_recurrent_net(recurrent_type) From fd34ecaa58cd976c221a059ea4e38895b7e6de58 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:34:50 +0100 Subject: [PATCH 09/19] fix LSTNet by removing inner model --- bayesflow/networks/lstnet/lstnet.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/bayesflow/networks/lstnet/lstnet.py b/bayesflow/networks/lstnet/lstnet.py index e69c59ccd..a1d393bac 100644 --- a/bayesflow/networks/lstnet/lstnet.py +++ b/bayesflow/networks/lstnet/lstnet.py @@ -46,10 +46,10 @@ def __init__( kernel_sizes = (kernel_sizes,) if not isinstance(strides, (list, tuple)): strides = (strides,) - self.conv_blocks = Sequential() + self.conv_blocks = [] for f, k, s in zip(filters, kernel_sizes, strides): - self.conv_blocks.add( - layers.Conv1D( + self.conv_blocks.append( + keras.layers.Conv1D( filters=f, kernel_size=k, strides=s, @@ -58,7 +58,7 @@ def __init__( padding="same", ) ) - self.conv_blocks.add(layers.GroupNormalization(groups=groups)) + self.conv_blocks.append(keras.layers.GroupNormalization(groups=groups)) # Recurrent and feedforward backbones self.recurrent = SkipRecurrentNet( @@ -72,11 +72,13 @@ def __init__( self.output_projector = layers.Dense(summary_dim) self.summary_dim = summary_dim - def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor: - summary = self.conv_blocks(time_series, training=training) - summary = self.recurrent(summary, training=training) - summary = self.output_projector(summary) - return summary + def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: + for c in self.conv_blocks: + x = c(x, training=training) + + x = self.recurrent(x, training=training) + x = self.output_projector(x) + return x def build(self, input_shape): super().build(input_shape) From a9e2d8170f738f21ca250c81e71e5513c542e120 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:35:07 +0100 Subject: [PATCH 10/19] general cleanup --- bayesflow/networks/lstnet/lstnet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bayesflow/networks/lstnet/lstnet.py b/bayesflow/networks/lstnet/lstnet.py index a1d393bac..7ca0a4403 100644 --- a/bayesflow/networks/lstnet/lstnet.py +++ b/bayesflow/networks/lstnet/lstnet.py @@ -1,9 +1,7 @@ import keras -from keras import layers, Sequential from keras.saving import register_keras_serializable as serializable, serialize_keras_object as serialize from bayesflow.types import Tensor - from .skip_recurrent import SkipRecurrentNet from ..summary_network import SummaryNetwork @@ -69,7 +67,7 @@ def __init__( skip_steps=skip_steps, dropout=dropout, ) - self.output_projector = layers.Dense(summary_dim) + self.output_projector = keras.layers.Dense(summary_dim) self.summary_dim = summary_dim def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: From c95662f7068c4ef4f67ed86205bf996426ef6f1f Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:35:22 +0100 Subject: [PATCH 11/19] remove non-serializable type hints --- bayesflow/networks/lstnet/skip_recurrent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/lstnet/skip_recurrent.py b/bayesflow/networks/lstnet/skip_recurrent.py index e6f78376e..f25c8ba78 100644 --- a/bayesflow/networks/lstnet/skip_recurrent.py +++ b/bayesflow/networks/lstnet/skip_recurrent.py @@ -22,7 +22,7 @@ class SkipRecurrentNet(keras.Model): def __init__( self, hidden_dim: int = 256, - recurrent_type: str | keras.Layer = "gru", + recurrent_type: str = "gru", bidirectional: bool = True, input_channels: int = 64, skip_steps: int = 4, From 63f2f831f7cfcf9e2a037d4bc7a1a7d12fa241b7 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:38:27 +0100 Subject: [PATCH 12/19] confusion --- bayesflow/networks/deep_set/invariant_module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bayesflow/networks/deep_set/invariant_module.py b/bayesflow/networks/deep_set/invariant_module.py index 3f806fb0e..f4c6bcd8c 100644 --- a/bayesflow/networks/deep_set/invariant_module.py +++ b/bayesflow/networks/deep_set/invariant_module.py @@ -56,6 +56,7 @@ def __init__( self.inner_fc.add(layer) # Outer fully connected net for sum decomposition: inner( pooling( inner(set) ) ) + # TODO: why does using Sequential work here, but not in DeepSet? self.outer_fc = keras.Sequential() for width in mlp_widths_outer: if dropout is not None and dropout > 0: From d76ad79a595bebaf6230cc38356b960232851113 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:39:05 +0100 Subject: [PATCH 13/19] allow passing either depth and width or list of widths to MLP --- bayesflow/networks/mlp/mlp.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/bayesflow/networks/mlp/mlp.py b/bayesflow/networks/mlp/mlp.py index d0fc73015..0418b5aaf 100644 --- a/bayesflow/networks/mlp/mlp.py +++ b/bayesflow/networks/mlp/mlp.py @@ -1,10 +1,12 @@ +from collections.abc import Sequence +from typing import Literal + import keras from keras import layers from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor from bayesflow.utils import keras_kwargs - from .hidden_block import ConfigurableHiddenBlock @@ -19,11 +21,14 @@ class MLP(keras.Layer): def __init__( self, - widths: tuple = (512, 512), + *, + depth: int = None, + width: int = None, + widths: Sequence[int] = None, activation: str = "mish", kernel_initializer: str = "he_normal", residual: bool = True, - dropout: float = 0.05, + dropout: Literal[0, None] | float = 0.05, spectral_normalization: bool = False, **kwargs, ): @@ -45,9 +50,19 @@ def __init__( dropout : float, optional, default: 0.05 Dropout rate for the hidden layers in the internal layers. """ - super().__init__(**keras_kwargs(kwargs)) + if widths is not None: + if depth is not None or width is not None: + raise ValueError("Either specify 'widths' or 'depth' and 'width', not both.") + else: + if depth is None or width is None: + # use the default + depth = 2 + width = 512 + + widths = [width] * depth + self.res_blocks = [] projector = layers.Dense( units=widths[0], From 18eb56f2d5553fb39c675eed8338532c21c52a24 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:43:26 +0100 Subject: [PATCH 14/19] remove incorrect get_config method on LSTNet --- bayesflow/networks/lstnet/lstnet.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/bayesflow/networks/lstnet/lstnet.py b/bayesflow/networks/lstnet/lstnet.py index 7ca0a4403..bb2acb9d8 100644 --- a/bayesflow/networks/lstnet/lstnet.py +++ b/bayesflow/networks/lstnet/lstnet.py @@ -1,5 +1,5 @@ import keras -from keras.saving import register_keras_serializable as serializable, serialize_keras_object as serialize +from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor from .skip_recurrent import SkipRecurrentNet @@ -81,15 +81,3 @@ def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: def build(self, input_shape): super().build(input_shape) self.call(keras.ops.zeros(input_shape)) - - def get_config(self): - base_config = super().get_config() - - config = { - "conv_blocks": serialize(self.conv_blocks), - "recurrent": serialize(self.recurrent), - "output_projector": serialize(self.output_projector), - "summary_dim": serialize(self.summary_dim), - } - - return base_config | config From 9d91c0d26970c028125bc2c1db1144ff46a38199 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:45:13 +0100 Subject: [PATCH 15/19] fix config serialization test for inference networks --- tests/test_networks/test_inference_networks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index 85180ca35..bd45c47e4 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -134,7 +134,6 @@ def test_serialize_deserialize(inference_network, random_samples, random_conditi reserialized = serialize(deserialized) assert serialized == reserialized - assert_layers_equal(inference_network, deserialized) def test_save_and_load(tmp_path, summary_network, random_set): From cf8188d67dd04ed15ac0340ae177dff57d612873 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 6 Nov 2024 13:47:31 +0100 Subject: [PATCH 16/19] fix save and load test for inference networks --- tests/test_networks/test_inference_networks.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index bd45c47e4..cdc4e9e09 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -136,13 +136,11 @@ def test_serialize_deserialize(inference_network, random_samples, random_conditi assert serialized == reserialized -def test_save_and_load(tmp_path, summary_network, random_set): - if summary_network is None: - pytest.skip() - - summary_network.build(keras.ops.shape(random_set)) +def test_save_and_load(tmp_path, inference_network, random_samples, random_conditions): + # to save, the model must be built + inference_network(random_samples, conditions=random_conditions) - keras.saving.save_model(summary_network, tmp_path / "model.keras") + keras.saving.save_model(inference_network, tmp_path / "model.keras") loaded = keras.saving.load_model(tmp_path / "model.keras") - assert_layers_equal(summary_network, loaded) + assert_layers_equal(inference_network, loaded) From ca5e889dd33f29006729bdfececa48c9d64fc848 Mon Sep 17 00:00:00 2001 From: larskue Date: Mon, 11 Nov 2024 14:29:19 +0100 Subject: [PATCH 17/19] try to fix jax install workflow --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 04e1533b0..217d425b0 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -69,7 +69,7 @@ jobs: - name: Install JAX if: ${{ matrix.backend == 'jax' }} run: | - pip install -U "jax[cpu]" + pip install -U jax - name: Install NumPy if: ${{ matrix.backend == 'numpy' }} run: | From 87f61efe5edb7736665fdae9518fefd5a53519d4 Mon Sep 17 00:00:00 2001 From: larskue Date: Mon, 11 Nov 2024 15:04:25 +0100 Subject: [PATCH 18/19] make tox config consistent with github actions install --- tox.ini | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 1171a04ab..960ea8d32 100644 --- a/tox.ini +++ b/tox.ini @@ -13,13 +13,15 @@ deps = pytest-cov pytest-rerunfailures jax: - jax[cpu] + jax numpy: numpy tensorflow: tensorflow torch: torch + torchvision + torchaudio set_env = jax: From 95fc33825619f98738cd76cd7ec2095ade394ed8 Mon Sep 17 00:00:00 2001 From: larskue Date: Mon, 11 Nov 2024 15:23:33 +0100 Subject: [PATCH 19/19] Add conda env export to tests workflow for better diagnosis --- .github/workflows/tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 217d425b0..68f7e2268 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -90,6 +90,7 @@ jobs: conda config --show-sources conda config --show printenv | sort + conda env export - name: Run Tests run: |