From 0939aacce93653dc000b5ffb63219cab6f8f8874 Mon Sep 17 00:00:00 2001 From: Pritom Gogoi Date: Wed, 9 Oct 2024 00:39:12 +0200 Subject: [PATCH 01/11] add fixtures and code for summary networks * add fixtures for some summary network types * add a test for compute_metrics method of summary networks --- tests/conftest.py | 24 ++++++++++++++++++-- tests/test_networks/test_summary_networks.py | 24 ++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b55c7a494..2dfb299a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ import keras import pytest - BACKENDS = ["jax", "numpy", "tensorflow", "torch"] @@ -98,7 +97,28 @@ def simulator(request): return request.getfixturevalue(request.param) -@pytest.fixture(params=[None], scope="function") +@pytest.fixture(scope="function") +def lst_net(): + from bayesflow.networks import LSTNet + + return LSTNet() + + +@pytest.fixture(scope="function") +def set_transformer(): + from bayesflow.networks import SetTransformer + + return SetTransformer() + + +@pytest.fixture(scope="function") +def deep_set(): + from bayesflow.networks import DeepSet + + return DeepSet() + + +@pytest.fixture(params=[None, "lst_net", "set_transformer", "deep_set"], scope="function") def summary_network(request): if request.param is None: return None diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index 8c371fa8c..ff4534315 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -62,3 +62,27 @@ def test_serialize_deserialize(tmp_path, summary_network, random_set): loaded = keras.saving.load_model(tmp_path / "model.keras") assert_layers_equal(summary_network, loaded) + + +@pytest.mark.parametrize("stage", ["training", "validation"]) +def test_compute_metrics(stage, summary_network, random_set): + if summary_network is None: + pytest.skip() + + summary_network.build(keras.ops.shape(random_set)) + + metrics = summary_network.compute_metrics(random_set, stage=stage) + + assert "outputs" in metrics + + # check that the batch dimension is preserved + assert metrics["outputs"].shape[0] == keras.ops.shape(random_set)[0] + + if summary_network.base_distribution is not None: + assert "loss" in metrics + assert metrics["loss"].shape == () + + if stage != "training": + for metric in summary_network.metrics: + assert metric.name in metrics + assert metrics[metric.name].shape == () \ No newline at end of file From 0a555106143768fe0828a5a63691495d51ea3f22 Mon Sep 17 00:00:00 2001 From: Pritom Gogoi Date: Wed, 9 Oct 2024 01:33:53 +0200 Subject: [PATCH 02/11] add fixtures for testing various summary_dim values --- tests/conftest.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2dfb299a2..fe67cc0d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -97,29 +97,34 @@ def simulator(request): return request.getfixturevalue(request.param) +@pytest.fixture(params=[1, 2, 16], scope="session") +def summary_dim(request): + return request.param + + @pytest.fixture(scope="function") -def lst_net(): +def lst_net(summary_dim): from bayesflow.networks import LSTNet - return LSTNet() + return LSTNet(summary_dim=summary_dim) @pytest.fixture(scope="function") -def set_transformer(): +def set_transformer(summary_dim): from bayesflow.networks import SetTransformer - return SetTransformer() + return 
SetTransformer(summary_dim=summary_dim) @pytest.fixture(scope="function") -def deep_set(): +def deep_set(summary_dim): from bayesflow.networks import DeepSet - return DeepSet() + return DeepSet(summary_dim=summary_dim) @pytest.fixture(params=[None, "lst_net", "set_transformer", "deep_set"], scope="function") -def summary_network(request): +def summary_network(request, summary_dim): if request.param is None: return None return request.getfixturevalue(request.param) From 80762c8996b01abbb068cc96cab4b86fb559b3ae Mon Sep 17 00:00:00 2001 From: Pritom Gogoi Date: Wed, 9 Oct 2024 01:50:32 +0200 Subject: [PATCH 03/11] verify the correct summary dim in output tensor --- bayesflow/networks/lstnet/lstnet.py | 1 + bayesflow/networks/transformers/set_transformer.py | 1 + tests/test_networks/test_summary_networks.py | 9 ++++++++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/bayesflow/networks/lstnet/lstnet.py b/bayesflow/networks/lstnet/lstnet.py index 82f5af6bb..42dcfbe07 100644 --- a/bayesflow/networks/lstnet/lstnet.py +++ b/bayesflow/networks/lstnet/lstnet.py @@ -38,6 +38,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + self.summary_dim = summary_dim # Convolutional backbone -> can be extended with inception-like structure if not isinstance(filters, (list, tuple)): diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index 38a10a602..f4eabfd45 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -48,6 +48,7 @@ def __init__( """ super().__init__(**kwargs) + self.summary_dim = summary_dim # Construct a series of set-attention blocks self.attention_blocks = keras.Sequential() diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index ff4534315..de87406c6 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -78,6 +78,12 @@ def test_compute_metrics(stage, summary_network, random_set): # check that the batch dimension is preserved assert metrics["outputs"].shape[0] == keras.ops.shape(random_set)[0] + # check summary dimension + summary_dim = summary_network.summary_dim + assert metrics["outputs"].shape[-1] == summary_dim + + print(metrics["outputs"].shape) + if summary_network.base_distribution is not None: assert "loss" in metrics assert metrics["loss"].shape == () @@ -85,4 +91,5 @@ def test_compute_metrics(stage, summary_network, random_set): if stage != "training": for metric in summary_network.metrics: assert metric.name in metrics - assert metrics[metric.name].shape == () \ No newline at end of file + assert metrics[metric.name].shape == () + From 9b1523a62e4d18436014138976fae3d2d7c7a846 Mon Sep 17 00:00:00 2001 From: Pritom Gogoi Date: Thu, 10 Oct 2024 00:05:38 +0200 Subject: [PATCH 04/11] move summary network fixtures into `test_networks/conftest.py` --- tests/conftest.py | 28 +-------------------------- tests/test_networks/conftest.py | 34 +++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index fe67cc0d1..aa50e6986 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -97,33 +97,7 @@ def simulator(request): return request.getfixturevalue(request.param) -@pytest.fixture(params=[1, 2, 16], scope="session") -def summary_dim(request): - return request.param - - -@pytest.fixture(scope="function") -def lst_net(summary_dim): - from 
bayesflow.networks import LSTNet - - return LSTNet(summary_dim=summary_dim) - - -@pytest.fixture(scope="function") -def set_transformer(summary_dim): - from bayesflow.networks import SetTransformer - - return SetTransformer(summary_dim=summary_dim) - - -@pytest.fixture(scope="function") -def deep_set(summary_dim): - from bayesflow.networks import DeepSet - - return DeepSet(summary_dim=summary_dim) - - -@pytest.fixture(params=[None, "lst_net", "set_transformer", "deep_set"], scope="function") +@pytest.fixture(params=[None], scope="function") def summary_network(request, summary_dim): if request.param is None: return None diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index e69de29bb..f7c29c33b 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -0,0 +1,34 @@ +import pytest + + +@pytest.fixture(params=[1, 2, 16], scope="session") +def summary_dim(request): + return request.param + + +@pytest.fixture(scope="function") +def lst_net(summary_dim): + from bayesflow.networks import LSTNet + + return LSTNet(summary_dim=summary_dim) + + +@pytest.fixture(scope="function") +def set_transformer(summary_dim): + from bayesflow.networks import SetTransformer + + return SetTransformer(summary_dim=summary_dim) + + +@pytest.fixture(scope="function") +def deep_set(summary_dim): + from bayesflow.networks import DeepSet + + return DeepSet(summary_dim=summary_dim) + + +@pytest.fixture(params=[None, "lst_net", "set_transformer", "deep_set"], scope="function") +def summary_network(request, summary_dim): + if request.param is None: + return None + return request.getfixturevalue(request.param) \ No newline at end of file From 817509c2b892e55c39afcf6eeb615bc99d07f738 Mon Sep 17 00:00:00 2001 From: Pritom Gogoi Date: Thu, 10 Oct 2024 00:09:51 +0200 Subject: [PATCH 05/11] replace use of `tensor.shape` with `keras.ops.shape` --- tests/test_networks/test_summary_networks.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index de87406c6..5652feb6b 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -76,20 +76,18 @@ def test_compute_metrics(stage, summary_network, random_set): assert "outputs" in metrics # check that the batch dimension is preserved - assert metrics["outputs"].shape[0] == keras.ops.shape(random_set)[0] + assert keras.ops.shape(metrics["outputs"])[0] == keras.ops.shape(random_set)[0] # check summary dimension summary_dim = summary_network.summary_dim - assert metrics["outputs"].shape[-1] == summary_dim - - print(metrics["outputs"].shape) + assert keras.ops.shape(metrics["outputs"])[-1] == summary_dim if summary_network.base_distribution is not None: assert "loss" in metrics - assert metrics["loss"].shape == () + assert keras.ops.shape(metrics["loss"]) == () if stage != "training": for metric in summary_network.metrics: assert metric.name in metrics - assert metrics[metric.name].shape == () + assert keras.ops.shape(metrics[metric.name]) == () From 83b2f2de31414b76f0d5f1bc34b12976993c9f2f Mon Sep 17 00:00:00 2001 From: Pritom Gogoi Date: Thu, 10 Oct 2024 01:00:27 +0200 Subject: [PATCH 06/11] add test for key_dim of set_transformer summary network --- .../networks/transformers/set_transformer.py | 2 ++ tests/test_networks/conftest.py | 12 ++++++++++++ tests/test_networks/test_summary_networks.py | 16 ++++++++++++++++ 3 files changed, 30 insertions(+) diff --git 
a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index f4eabfd45..a58afa8c4 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -49,6 +49,8 @@ def __init__( super().__init__(**kwargs) self.summary_dim = summary_dim + self.key_dim = key_dim + self.num_attention_blocks = num_attention_blocks # Construct a series of set-attention blocks self.attention_blocks = keras.Sequential() diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index f7c29c33b..2d61a73b6 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -6,6 +6,11 @@ def summary_dim(request): return request.param +@pytest.fixture(params=[1, 2, 16], scope="session") +def key_dim(request): + return request.param + + @pytest.fixture(scope="function") def lst_net(summary_dim): from bayesflow.networks import LSTNet @@ -20,6 +25,13 @@ def set_transformer(summary_dim): return SetTransformer(summary_dim=summary_dim) +@pytest.fixture(scope="function") +def set_transformer_key_dim_variation(summary_dim, key_dim): + from bayesflow.networks import SetTransformer + + return SetTransformer(summary_dim=summary_dim, key_dim=key_dim) + + @pytest.fixture(scope="function") def deep_set(summary_dim): from bayesflow.networks import DeepSet diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index 5652feb6b..57627ebf7 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -91,3 +91,19 @@ def test_compute_metrics(stage, summary_network, random_set): assert metric.name in metrics assert keras.ops.shape(metrics[metric.name]) == () + +def test_set_transformer_with_key_dim(set_transformer_key_dim_variation, random_set): + + set_transformer_key_dim_variation.build(keras.ops.shape(random_set)) + _ = set_transformer_key_dim_variation(random_set) + + att_layers = set_transformer_key_dim_variation.attention_blocks.layers + + # check that the number of attention blocks is as per the specified key_dim + assert len(att_layers) == set_transformer_key_dim_variation.num_attention_blocks + + # check that the key_dim is set correctly per attention block + for i, layer in enumerate(att_layers): + assert keras.ops.shape(layer.output)[-1] == set_transformer_key_dim_variation.key_dim + if i != 0: + assert keras.ops.shape(layer.input)[-1] == set_transformer_key_dim_variation.key_dim From 15ad65974a0182d989a2ceb68b64b474574f027f Mon Sep 17 00:00:00 2001 From: Pritom Gogoi Date: Tue, 22 Oct 2024 01:04:40 +0200 Subject: [PATCH 07/11] reformat code using ruff --- tests/test_networks/conftest.py | 2 +- tests/test_networks/test_summary_networks.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 2d61a73b6..f917d448c 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -43,4 +43,4 @@ def deep_set(summary_dim): def summary_network(request, summary_dim): if request.param is None: return None - return request.getfixturevalue(request.param) \ No newline at end of file + return request.getfixturevalue(request.param) diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index 57627ebf7..d36061cd4 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -93,7 +93,6 @@ def 
test_compute_metrics(stage, summary_network, random_set): def test_set_transformer_with_key_dim(set_transformer_key_dim_variation, random_set): - set_transformer_key_dim_variation.build(keras.ops.shape(random_set)) _ = set_transformer_key_dim_variation(random_set) From 117f8808c209576eb7a9b8cca036b8265a04956c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 27 Feb 2025 15:44:47 +0100 Subject: [PATCH 08/11] clean up tests of summary networks --- .../networks/transformers/set_transformer.py | 2 -- tests/test_networks/conftest.py | 13 +------------ tests/test_networks/test_summary_networks.py | 16 ---------------- 3 files changed, 1 insertion(+), 30 deletions(-) diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index 2efda980f..251bb69b3 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -81,8 +81,6 @@ def __init__( super().__init__(**kwargs) self.summary_dim = summary_dim - self.key_dim = key_dim - self.num_attention_blocks = num_attention_blocks check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 8467c1ba0..8f6a731fa 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -1,15 +1,11 @@ import pytest -@pytest.fixture(params=[1, 2, 16], scope="session") +@pytest.fixture(params=[1, 4], scope="session") def summary_dim(request): return request.param -@pytest.fixture(params=[1, 2, 16], scope="session") -def key_dim(request): - return request.param - # For the serialization tests, we want to test passing str and type. # For all other tests, this is not necessary and would double test time. 
# Therefore, below we specify two variants of each network, one without and @@ -94,13 +90,6 @@ def set_transformer(summary_dim): return SetTransformer(summary_dim=summary_dim) -@pytest.fixture(scope="function") -def set_transformer_key_dim_variation(summary_dim, key_dim): - from bayesflow.networks import SetTransformer - - return SetTransformer(summary_dim=summary_dim, key_dim=key_dim) - - @pytest.fixture(scope="function") def deep_set(summary_dim): from bayesflow.networks import DeepSet diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index ebf050c6d..44e3dfee1 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -107,19 +107,3 @@ def test_compute_metrics(stage, summary_network, random_set): for metric in summary_network.metrics: assert metric.name in metrics assert keras.ops.shape(metrics[metric.name]) == () - - -def test_set_transformer_with_key_dim(set_transformer_key_dim_variation, random_set): - set_transformer_key_dim_variation.build(keras.ops.shape(random_set)) - _ = set_transformer_key_dim_variation(random_set) - - att_layers = set_transformer_key_dim_variation.attention_blocks.layers - - # check that the number of attention blocks is as per the specified key_dim - assert len(att_layers) == set_transformer_key_dim_variation.num_attention_blocks - - # check that the key_dim is set correctly per attention block - for i, layer in enumerate(att_layers): - assert keras.ops.shape(layer.output)[-1] == set_transformer_key_dim_variation.key_dim - if i != 0: - assert keras.ops.shape(layer.input)[-1] == set_transformer_key_dim_variation.key_dim From 3d53e397d10783bd01e6b94e1d472b7288346246 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 27 Feb 2025 15:46:38 +0100 Subject: [PATCH 09/11] run linter again --- tests/test_networks/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 8f6a731fa..ec98d7485 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -83,6 +83,7 @@ def lst_net(summary_dim): return LSTNet(summary_dim=summary_dim) + @pytest.fixture(scope="function") def set_transformer(summary_dim): from bayesflow.networks import SetTransformer From b8d870f5fdafe206b66204343338fe6673736951 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 27 Feb 2025 15:55:45 +0100 Subject: [PATCH 10/11] store summary_dim in summary networks init consistently after output_projector --- bayesflow/networks/lstnet/lstnet.py | 1 - bayesflow/networks/transformers/fusion_transformer.py | 1 + bayesflow/networks/transformers/set_transformer.py | 2 +- bayesflow/networks/transformers/time_series_transformer.py | 1 + 4 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bayesflow/networks/lstnet/lstnet.py b/bayesflow/networks/lstnet/lstnet.py index fb14f884f..695907ef8 100644 --- a/bayesflow/networks/lstnet/lstnet.py +++ b/bayesflow/networks/lstnet/lstnet.py @@ -37,7 +37,6 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.summary_dim = summary_dim # Convolutional backbone -> can be extended with inception-like structure if not isinstance(filters, (list, tuple)): diff --git a/bayesflow/networks/transformers/fusion_transformer.py b/bayesflow/networks/transformers/fusion_transformer.py index 2c5624c93..89f51100d 100644 --- a/bayesflow/networks/transformers/fusion_transformer.py +++ 
b/bayesflow/networks/transformers/fusion_transformer.py @@ -120,6 +120,7 @@ def __init__( raise ValueError("Argument `template_dim` should be in ['lstm', 'gru']") self.output_projector = keras.layers.Dense(summary_dim) + self.summary_dim = summary_dim def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor: """Compresses the input sequence into a summary vector of size `summary_dim`. diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index 251bb69b3..c301f1537 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -80,7 +80,6 @@ def __init__( """ super().__init__(**kwargs) - self.summary_dim = summary_dim check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths) @@ -126,6 +125,7 @@ def __init__( ) self.pooling_by_attention = PoolingByMultiHeadAttention(**(global_attention_settings | pooling_settings)) self.output_projector = keras.layers.Dense(summary_dim) + self.summary_dim = summary_dim def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: """Compresses the input sequence into a summary vector of size `summary_dim`. diff --git a/bayesflow/networks/transformers/time_series_transformer.py b/bayesflow/networks/transformers/time_series_transformer.py index 9ec7cc9bd..6429743a6 100644 --- a/bayesflow/networks/transformers/time_series_transformer.py +++ b/bayesflow/networks/transformers/time_series_transformer.py @@ -103,6 +103,7 @@ def __init__( # Pooling will be applied as a final step to the abstract representations obtained from set attention self.pooling = keras.layers.GlobalAvgPool1D() self.output_projector = keras.layers.Dense(summary_dim) + self.summary_dim = summary_dim self.time_axis = time_axis From 17ece6d311be34d10c1dbf17461ab97871944541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 27 Feb 2025 16:30:39 +0100 Subject: [PATCH 11/11] clean up fixture definitions of inference and summary networks --- tests/conftest.py | 25 +++---------------------- tests/test_networks/conftest.py | 5 ----- tests/test_two_moons/conftest.py | 7 +++++++ 3 files changed, 10 insertions(+), 27 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e1aa61b1c..315efd531 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,11 +32,9 @@ def conditions_size(request): return request.param -@pytest.fixture(scope="function") -def coupling_flow(): - from bayesflow.networks import CouplingFlow - - return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32))) +@pytest.fixture(params=[1, 4], scope="session") +def summary_dim(request): + return request.param @pytest.fixture(params=["two_moons"], scope="session") @@ -49,16 +47,6 @@ def feature_size(request): return request.param -@pytest.fixture(params=["coupling_flow"], scope="function") -def inference_network(request): - return request.getfixturevalue(request.param) - - -@pytest.fixture(params=["inference_network", "summary_network"], scope="function") -def network(request): - return request.getfixturevalue(request.param) - - @pytest.fixture(scope="session") def random_conditions(batch_size, conditions_size): if conditions_size is None: @@ -94,13 +82,6 @@ def simulator(request): return request.getfixturevalue(request.param) -@pytest.fixture(params=[None], scope="function") -def summary_network(request, summary_dim): - if request.param is None: - return None - return 
request.getfixturevalue(request.param) - - @pytest.fixture(scope="session") def training_dataset(simulator, batch_size): from bayesflow.datasets import OfflineDataset diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index ec98d7485..2310fbe33 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -1,11 +1,6 @@ import pytest -@pytest.fixture(params=[1, 4], scope="session") -def summary_dim(request): - return request.param - - # For the serialization tests, we want to test passing str and type. # For all other tests, this is not necessary and would double test time. # Therefore, below we specify two variants of each network, one without and diff --git a/tests/test_two_moons/conftest.py b/tests/test_two_moons/conftest.py index 7a5bb870f..6e3db2674 100644 --- a/tests/test_two_moons/conftest.py +++ b/tests/test_two_moons/conftest.py @@ -1,6 +1,13 @@ import pytest +@pytest.fixture() +def inference_network(): + from bayesflow.networks import CouplingFlow + + return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32))) + + @pytest.fixture() def approximator(adapter, inference_network): from bayesflow import ContinuousApproximator
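
For reference, the interface exercised by the new tests can also be driven by hand. The sketch below is illustrative and not part of the patch series: the batch/set/feature sizes and the choice of stage are arbitrary assumptions, while SetTransformer(summary_dim=...), build(), compute_metrics(..., stage=...), the "outputs" key, and the summary_dim attribute follow the diffs above.

import keras
from bayesflow.networks import SetTransformer

# Illustrative shapes: a batch of 8 sets, each with 16 exchangeable elements of 3 features.
batch_size, set_size, num_features = 8, 16, 3
random_set = keras.random.normal((batch_size, set_size, num_features))

# summary_dim sets the size of the learned summary vector, mirroring the parametrized fixture.
net = SetTransformer(summary_dim=4)
net.build(keras.ops.shape(random_set))

# compute_metrics returns a dict; the summary vectors live under "outputs".
metrics = net.compute_metrics(random_set, stage="validation")
outputs = metrics["outputs"]

# Each set is compressed to a fixed-size vector: (batch, set, features) -> (batch, summary_dim).
assert keras.ops.shape(outputs)[0] == batch_size
assert keras.ops.shape(outputs)[-1] == net.summary_dim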