From 32b5a0342d0af33acf29e53feb72276d1ddaa4fd Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 20 Apr 2025 07:56:14 +0000 Subject: [PATCH 1/3] add failing tests for TimeSeriesTransformer The deserialization of the `TimeSeriesTransformer` is broken, but it went undetected as this network is not included in the tests yet. This commit adds the network, but does not resolve the deserialization problems yet. Regards #423 --- tests/test_networks/conftest.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index b4ad8df99..a3eb10f7a 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -119,6 +119,13 @@ def time_series_network(summary_dim): return TimeSeriesNetwork(summary_dim=summary_dim) +@pytest.fixture(scope="function") +def time_series_transformer(summary_dim): + from bayesflow.networks import TimeSeriesTransformer + + return TimeSeriesTransformer(summary_dim=summary_dim) + + @pytest.fixture(scope="function") def set_transformer(summary_dim): from bayesflow.networks import SetTransformer @@ -133,7 +140,9 @@ def deep_set(summary_dim): return DeepSet(summary_dim=summary_dim) -@pytest.fixture(params=[None, "time_series_network", "set_transformer", "deep_set"], scope="function") +@pytest.fixture( + params=[None, "time_series_network", "time_series_transformer", "set_transformer", "deep_set"], scope="function" +) def summary_network(request, summary_dim): if request.param is None: return None From 5c2f3904072cf3a7c9f68d6fa000d0675990b449 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 20 Apr 2025 08:34:34 +0000 Subject: [PATCH 2/3] Add (de)serialization code for TimeSeriesTransformer --- .../transformers/time_series_transformer.py | 46 +++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/transformers/time_series_transformer.py b/bayesflow/networks/transformers/time_series_transformer.py index dac0e52b7..4e5599ae6 100644 --- a/bayesflow/networks/transformers/time_series_transformer.py +++ b/bayesflow/networks/transformers/time_series_transformer.py @@ -1,8 +1,8 @@ import keras from bayesflow.types import Tensor -from bayesflow.utils import check_lengths_same -from bayesflow.utils.serialization import serializable +from bayesflow.utils import check_lengths_same, model_kwargs +from bayesflow.utils.serialization import deserialize, serializable, serialize from ..embeddings import Time2Vec, RecurrentEmbedding from ..summary_network import SummaryNetwork @@ -103,9 +103,22 @@ def __init__( # Pooling will be applied as a final step to the abstract representations obtained from set attention self.pooling = keras.layers.GlobalAvgPool1D() self.output_projector = keras.layers.Dense(summary_dim) - self.summary_dim = summary_dim + # store variables for serialization + self.summary_dim = summary_dim + self.embed_dims = embed_dims + self.num_heads = num_heads + self.mlp_depths = mlp_depths + self.mlp_widths = mlp_widths + self.dropout = dropout + self.mlp_activation = mlp_activation + self.kernel_initializer = kernel_initializer + self.use_bias = use_bias + self.layer_norm = layer_norm + self._time_embedding_arg = time_embedding + self.time_embed_dim = time_embed_dim self.time_axis = time_axis + self._kwargs = kwargs def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor: """Compresses the input sequence into a summary vector of size `summary_dim`. @@ -147,3 +160,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens summary = self.pooling(inp) summary = self.output_projector(summary) return summary + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**deserialize(config, custom_objects=custom_objects)) + + def get_config(self): + base_config = super().get_config() + base_config = model_kwargs(base_config) + + config = { + "summary_dim": self.summary_dim, + "embed_dims": self.embed_dims, + "num_heads": self.num_heads, + "mlp_depths": self.mlp_depths, + "mlp_widths": self.mlp_widths, + "dropout": self.dropout, + "mlp_activation": self.mlp_activation, + "kernel_initializer": self.kernel_initializer, + "use_bias": self.use_bias, + "layer_norm": self.layer_norm, + "time_embedding": self._time_embedding_arg, + "time_embed_dim": self.time_embed_dim, + "time_axis": self.time_axis, + **self._kwargs, + } + + return base_config | serialize(config) From a814aa53722f7e032f575c0a9011cce2f81c9487 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 20 Apr 2025 08:52:42 +0000 Subject: [PATCH 3/3] Add serialization and tests for FusionTransformer --- .../transformers/fusion_transformer.py | 44 ++++++++++++++++++- tests/test_networks/conftest.py | 17 ++++++- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/transformers/fusion_transformer.py b/bayesflow/networks/transformers/fusion_transformer.py index 1821c25d2..cc5f14e65 100644 --- a/bayesflow/networks/transformers/fusion_transformer.py +++ b/bayesflow/networks/transformers/fusion_transformer.py @@ -2,8 +2,8 @@ from keras import layers from bayesflow.types import Tensor -from bayesflow.utils import check_lengths_same -from bayesflow.utils.serialization import serializable +from bayesflow.utils import check_lengths_same, model_kwargs +from bayesflow.utils.serialization import deserialize, serializable, serialize from ..summary_network import SummaryNetwork @@ -121,6 +121,19 @@ def __init__( self.output_projector = keras.layers.Dense(summary_dim) self.summary_dim = summary_dim + self.embed_dims = embed_dims + self.num_heads = num_heads + self.mlp_depths = mlp_depths + self.mlp_widths = mlp_widths + self.dropout = dropout + self.mlp_activation = mlp_activation + self.kernel_initializer = kernel_initializer + self.use_bias = use_bias + self.layer_norm = layer_norm + self.template_type = template_type + self.bidirectional = bidirectional + self.template_dim = template_dim + self._kwargs = kwargs def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor: """Compresses the input sequence into a summary vector of size `summary_dim`. @@ -151,3 +164,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), rep, training=training, **kwargs) summary = self.output_projector(keras.ops.squeeze(summary, axis=1)) return summary + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**deserialize(config, custom_objects=custom_objects)) + + def get_config(self): + base_config = super().get_config() + base_config = model_kwargs(base_config) + + config = { + "summary_dim": self.summary_dim, + "embed_dims": self.embed_dims, + "num_heads": self.num_heads, + "mlp_depths": self.mlp_depths, + "mlp_widths": self.mlp_widths, + "dropout": self.dropout, + "mlp_activation": self.mlp_activation, + "kernel_initializer": self.kernel_initializer, + "use_bias": self.use_bias, + "layer_norm": self.layer_norm, + "template_type": self.template_type, + "bidirectional": self.bidirectional, + "template_dim": self.template_dim, + **self._kwargs, + } + + return base_config | serialize(config) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index a3eb10f7a..84c011812 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -126,6 +126,13 @@ def time_series_transformer(summary_dim): return TimeSeriesTransformer(summary_dim=summary_dim) +@pytest.fixture(scope="function") +def fusion_transformer(summary_dim): + from bayesflow.networks import FusionTransformer + + return FusionTransformer(summary_dim=summary_dim) + + @pytest.fixture(scope="function") def set_transformer(summary_dim): from bayesflow.networks import SetTransformer @@ -141,7 +148,15 @@ def deep_set(summary_dim): @pytest.fixture( - params=[None, "time_series_network", "time_series_transformer", "set_transformer", "deep_set"], scope="function" + params=[ + None, + "time_series_network", + "time_series_transformer", + "fusion_transformer", + "set_transformer", + "deep_set", + ], + scope="function", ) def summary_network(request, summary_dim): if request.param is None: