diff --git a/bayesflow/networks/transformers/fusion_transformer.py b/bayesflow/networks/transformers/fusion_transformer.py
index 2c5624c93..89f51100d 100644
--- a/bayesflow/networks/transformers/fusion_transformer.py
+++ b/bayesflow/networks/transformers/fusion_transformer.py
@@ -120,6 +120,7 @@ def __init__(
             raise ValueError("Argument `template_dim` should be in ['lstm', 'gru']")

         self.output_projector = keras.layers.Dense(summary_dim)
+        self.summary_dim = summary_dim

     def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
         """Compresses the input sequence into a summary vector of size `summary_dim`.
diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py
index bd3333a20..c301f1537 100644
--- a/bayesflow/networks/transformers/set_transformer.py
+++ b/bayesflow/networks/transformers/set_transformer.py
@@ -125,6 +125,7 @@ def __init__(
         )
         self.pooling_by_attention = PoolingByMultiHeadAttention(**(global_attention_settings | pooling_settings))
         self.output_projector = keras.layers.Dense(summary_dim)
+        self.summary_dim = summary_dim

     def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         """Compresses the input sequence into a summary vector of size `summary_dim`.
diff --git a/bayesflow/networks/transformers/time_series_transformer.py b/bayesflow/networks/transformers/time_series_transformer.py
index 9ec7cc9bd..6429743a6 100644
--- a/bayesflow/networks/transformers/time_series_transformer.py
+++ b/bayesflow/networks/transformers/time_series_transformer.py
@@ -103,6 +103,7 @@
         # Pooling will be applied as a final step to the abstract representations obtained from set attention
         self.pooling = keras.layers.GlobalAvgPool1D()
         self.output_projector = keras.layers.Dense(summary_dim)
+        self.summary_dim = summary_dim

         self.time_axis = time_axis

diff --git a/tests/conftest.py b/tests/conftest.py
index 5e6598ba8..315efd531 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -32,11 +32,9 @@ def conditions_size(request):
     return request.param


-@pytest.fixture(scope="function")
-def coupling_flow():
-    from bayesflow.networks import CouplingFlow
-
-    return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)))
+@pytest.fixture(params=[1, 4], scope="session")
+def summary_dim(request):
+    return request.param


 @pytest.fixture(params=["two_moons"], scope="session")
@@ -49,16 +47,6 @@ def feature_size(request):
     return request.param


-@pytest.fixture(params=["coupling_flow"], scope="function")
-def inference_network(request):
-    return request.getfixturevalue(request.param)
-
-
-@pytest.fixture(params=["inference_network", "summary_network"], scope="function")
-def network(request):
-    return request.getfixturevalue(request.param)
-
-
 @pytest.fixture(scope="session")
 def random_conditions(batch_size, conditions_size):
     if conditions_size is None:
@@ -94,13 +82,6 @@ def simulator(request):
     return request.getfixturevalue(request.param)


-@pytest.fixture(params=[None], scope="function")
-def summary_network(request):
-    if request.param is None:
-        return None
-    return request.getfixturevalue(request.param)
-
-
 @pytest.fixture(scope="session")
 def training_dataset(simulator, batch_size):
     from bayesflow.datasets import OfflineDataset
diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py
index a42289f34..2310fbe33 100644
--- a/tests/test_networks/conftest.py
+++ b/tests/test_networks/conftest.py
@@ -1,13 +1,6 @@
 import pytest


-@pytest.fixture()
-def deep_set():
-    from bayesflow.networks import DeepSet
-
-    return DeepSet()
-
-
 # For the serialization tests, we want to test passing str and type.
 # For all other tests, this is not necessary and would double test time.
 # Therefore, below we specify two variants of each network, one without and
@@ -79,23 +72,29 @@ def inference_network_subnet(request):
     return request.getfixturevalue(request.param)


-@pytest.fixture()
-def lst_net():
+@pytest.fixture(scope="function")
+def lst_net(summary_dim):
     from bayesflow.networks import LSTNet

-    return LSTNet()
+    return LSTNet(summary_dim=summary_dim)


-@pytest.fixture()
-def set_transformer():
+@pytest.fixture(scope="function")
+def set_transformer(summary_dim):
     from bayesflow.networks import SetTransformer

-    return SetTransformer()
+    return SetTransformer(summary_dim=summary_dim)
+
+
+@pytest.fixture(scope="function")
+def deep_set(summary_dim):
+    from bayesflow.networks import DeepSet
+
+    return DeepSet(summary_dim=summary_dim)


-@pytest.fixture(params=[None, "deep_set", "lst_net", "set_transformer"])
-def summary_network(request):
+@pytest.fixture(params=[None, "lst_net", "set_transformer", "deep_set"], scope="function")
+def summary_network(request, summary_dim):
     if request.param is None:
         return None
-
     return request.getfixturevalue(request.param)
diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py
index c2abcb0d6..44e3dfee1 100644
--- a/tests/test_networks/test_summary_networks.py
+++ b/tests/test_networks/test_summary_networks.py
@@ -79,3 +79,31 @@ def test_save_and_load(tmp_path, summary_network, random_set):
     loaded = keras.saving.load_model(tmp_path / "model.keras")

     assert_layers_equal(summary_network, loaded)
+
+
+@pytest.mark.parametrize("stage", ["training", "validation"])
+def test_compute_metrics(stage, summary_network, random_set):
+    if summary_network is None:
+        pytest.skip()
+
+    summary_network.build(keras.ops.shape(random_set))
+
+    metrics = summary_network.compute_metrics(random_set, stage=stage)
+
+    assert "outputs" in metrics
+
+    # check that the batch dimension is preserved
+    assert keras.ops.shape(metrics["outputs"])[0] == keras.ops.shape(random_set)[0]
+
+    # check summary dimension
+    summary_dim = summary_network.summary_dim
+    assert keras.ops.shape(metrics["outputs"])[-1] == summary_dim
+
+    if summary_network.base_distribution is not None:
+        assert "loss" in metrics
+        assert keras.ops.shape(metrics["loss"]) == ()
+
+    if stage != "training":
+        for metric in summary_network.metrics:
+            assert metric.name in metrics
+            assert keras.ops.shape(metrics[metric.name]) == ()
diff --git a/tests/test_two_moons/conftest.py b/tests/test_two_moons/conftest.py
index 7a5bb870f..6e3db2674 100644
--- a/tests/test_two_moons/conftest.py
+++ b/tests/test_two_moons/conftest.py
@@ -1,6 +1,13 @@
 import pytest


+@pytest.fixture()
+def inference_network():
+    from bayesflow.networks import CouplingFlow
+
+    return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)))
+
+
 @pytest.fixture()
 def approximator(adapter, inference_network):
     from bayesflow import ContinuousApproximator