Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bayesflow/networks/transformers/fusion_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
raise ValueError("Argument `template_dim` should be in ['lstm', 'gru']")

self.output_projector = keras.layers.Dense(summary_dim)
self.summary_dim = summary_dim

def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
"""Compresses the input sequence into a summary vector of size `summary_dim`.
Expand Down
1 change: 1 addition & 0 deletions bayesflow/networks/transformers/set_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
)
self.pooling_by_attention = PoolingByMultiHeadAttention(**(global_attention_settings | pooling_settings))
self.output_projector = keras.layers.Dense(summary_dim)
self.summary_dim = summary_dim

def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
"""Compresses the input sequence into a summary vector of size `summary_dim`.
Expand Down
1 change: 1 addition & 0 deletions bayesflow/networks/transformers/time_series_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
# Pooling will be applied as a final step to the abstract representations obtained from set attention
self.pooling = keras.layers.GlobalAvgPool1D()
self.output_projector = keras.layers.Dense(summary_dim)
self.summary_dim = summary_dim

self.time_axis = time_axis

Expand Down
25 changes: 3 additions & 22 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,9 @@ def conditions_size(request):
return request.param


@pytest.fixture(scope="function")
def coupling_flow():
from bayesflow.networks import CouplingFlow

return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)))
@pytest.fixture(params=[1, 4], scope="session")
def summary_dim(request):
    """Parametrized output dimensionality for summary networks (runs with 1 and 4)."""
    dim = request.param
    return dim


@pytest.fixture(params=["two_moons"], scope="session")
Expand All @@ -49,16 +47,6 @@ def feature_size(request):
return request.param


@pytest.fixture(params=["coupling_flow"], scope="function")
def inference_network(request):
return request.getfixturevalue(request.param)


@pytest.fixture(params=["inference_network", "summary_network"], scope="function")
def network(request):
return request.getfixturevalue(request.param)


@pytest.fixture(scope="session")
def random_conditions(batch_size, conditions_size):
if conditions_size is None:
Expand Down Expand Up @@ -94,13 +82,6 @@ def simulator(request):
return request.getfixturevalue(request.param)


@pytest.fixture(params=[None], scope="function")
def summary_network(request):
if request.param is None:
return None
return request.getfixturevalue(request.param)


@pytest.fixture(scope="session")
def training_dataset(simulator, batch_size):
from bayesflow.datasets import OfflineDataset
Expand Down
31 changes: 15 additions & 16 deletions tests/test_networks/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
import pytest


@pytest.fixture()
def deep_set():
from bayesflow.networks import DeepSet

return DeepSet()


# For the serialization tests, we want to test passing str and type.
# For all other tests, this is not necessary and would double test time.
# Therefore, below we specify two variants of each network, one without and
Expand Down Expand Up @@ -79,23 +72,29 @@ def inference_network_subnet(request):
return request.getfixturevalue(request.param)


@pytest.fixture()
def lst_net():
@pytest.fixture(scope="function")
def lst_net(summary_dim):
from bayesflow.networks import LSTNet

return LSTNet()
return LSTNet(summary_dim=summary_dim)


@pytest.fixture()
def set_transformer():
@pytest.fixture(scope="function")
def set_transformer(summary_dim):
from bayesflow.networks import SetTransformer

return SetTransformer()
return SetTransformer(summary_dim=summary_dim)


@pytest.fixture(scope="function")
def deep_set(summary_dim):
    """A DeepSet summary network configured with the parametrized ``summary_dim``."""
    from bayesflow.networks import DeepSet

    network = DeepSet(summary_dim=summary_dim)
    return network


@pytest.fixture(params=[None, "deep_set", "lst_net", "set_transformer"])
def summary_network(request):
@pytest.fixture(params=[None, "lst_net", "set_transformer", "deep_set"], scope="function")
def summary_network(request, summary_dim):
if request.param is None:
return None

return request.getfixturevalue(request.param)
28 changes: 28 additions & 0 deletions tests/test_networks/test_summary_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,31 @@ def test_save_and_load(tmp_path, summary_network, random_set):
loaded = keras.saving.load_model(tmp_path / "model.keras")

assert_layers_equal(summary_network, loaded)


@pytest.mark.parametrize("stage", ["training", "validation"])
def test_compute_metrics(stage, summary_network, random_set):
    """Check ``compute_metrics``: output shape, optional scalar loss, and per-stage metrics."""
    if summary_network is None:
        pytest.skip()

    summary_network.build(keras.ops.shape(random_set))
    metrics = summary_network.compute_metrics(random_set, stage=stage)

    assert "outputs" in metrics

    output_shape = keras.ops.shape(metrics["outputs"])
    input_shape = keras.ops.shape(random_set)

    # the batch dimension must be preserved
    assert output_shape[0] == input_shape[0]

    # the last dimension must equal the configured summary dimension
    assert output_shape[-1] == summary_network.summary_dim

    # networks with a base distribution additionally report a scalar loss
    if summary_network.base_distribution is not None:
        assert "loss" in metrics
        assert keras.ops.shape(metrics["loss"]) == ()

    # outside of training, every tracked metric is reported as a scalar
    if stage != "training":
        for metric in summary_network.metrics:
            assert metric.name in metrics
            assert keras.ops.shape(metrics[metric.name]) == ()
7 changes: 7 additions & 0 deletions tests/test_two_moons/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import pytest


@pytest.fixture()
def inference_network():
    """A small CouplingFlow serving as the inference network for two-moons tests."""
    from bayesflow.networks import CouplingFlow

    flow = CouplingFlow(depth=2, subnet="mlp", subnet_kwargs={"widths": (32, 32)})
    return flow


@pytest.fixture()
def approximator(adapter, inference_network):
from bayesflow import ContinuousApproximator
Expand Down