Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bayesflow/networks/deep_set/deep_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

from .equivariant_module import EquivariantModule
from .invariant_module import InvariantModule
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self.output_projector = keras.layers.Dense(summary_dim, activation="linear")
self.summary_dim = summary_dim

@sanitize_input_shape
def build(self, input_shape):
    """Build the network for the given input shape.

    Delegates to the parent ``build`` and then performs a dummy forward pass
    on an all-zeros tensor of ``input_shape`` — presumably to trigger lazy
    creation of the sublayers' weights (TODO confirm against Keras build
    semantics). The ``sanitize_input_shape`` decorator substitutes a dummy
    batch size when the leading dimension is ``None``, so ``zeros`` can
    allocate the tensor.
    """
    super().build(input_shape)
    self.call(keras.ops.zeros(input_shape))
Expand Down
2 changes: 2 additions & 0 deletions bayesflow/networks/deep_set/equivariant_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils.decorators import sanitize_input_shape
from .invariant_module import InvariantModule


Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(

self.layer_norm = layers.LayerNormalization() if layer_norm else None

@sanitize_input_shape
def build(self, input_shape):
    """Build the module by running a dummy forward pass.

    Calls ``self.call`` on an all-zeros tensor of ``input_shape`` —
    presumably so the inner layers create their weights lazily
    (TODO confirm). The decorator replaces a ``None`` leading (batch)
    dimension with a dummy size so ``zeros`` can allocate the tensor.
    """
    self.call(keras.ops.zeros(input_shape))

Expand Down
2 changes: 2 additions & 0 deletions bayesflow/networks/deep_set/invariant_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from bayesflow.types import Tensor
from bayesflow.utils import find_pooling
from bayesflow.utils.decorators import sanitize_input_shape


@serializable(package="bayesflow.networks")
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(

self.pooling_layer = find_pooling(pooling, **pooling_kwargs)

@sanitize_input_shape
def build(self, input_shape):
    """Build the module by running a dummy forward pass.

    Calls ``self.call`` on an all-zeros tensor of ``input_shape`` —
    presumably so the inner dense/pooling layers create their weights
    lazily (TODO confirm). The decorator replaces a ``None`` leading
    (batch) dimension with a dummy size so ``zeros`` can allocate the
    tensor.
    """
    self.call(keras.ops.zeros(input_shape))

Expand Down
2 changes: 2 additions & 0 deletions bayesflow/networks/lstnet/lstnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils.decorators import sanitize_input_shape
from .skip_recurrent import SkipRecurrentNet
from ..summary_network import SummaryNetwork

Expand Down Expand Up @@ -78,6 +79,7 @@ def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
x = self.output_projector(x)
return x

@sanitize_input_shape
def build(self, input_shape):
    """Build the network for the given input shape.

    Delegates to the parent ``build`` and then performs a dummy forward pass
    on an all-zeros tensor of ``input_shape`` — presumably to trigger lazy
    creation of the sublayers' weights (TODO confirm). The
    ``sanitize_input_shape`` decorator substitutes a dummy batch size when
    the leading dimension is ``None``, so ``zeros`` can allocate the tensor.
    """
    super().build(input_shape)
    self.call(keras.ops.zeros(input_shape))
2 changes: 2 additions & 0 deletions bayesflow/networks/lstnet/skip_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs, find_recurrent_net
from bayesflow.utils.decorators import sanitize_input_shape


@serializable(package="bayesflow.networks")
Expand Down Expand Up @@ -58,5 +59,6 @@ def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
skip_summary = self.skip_recurrent(self.skip_conv(time_series), training=training)
return keras.ops.concatenate((direct_summary, skip_summary), axis=-1)

@sanitize_input_shape
def build(self, input_shape):
    """Build the module by running a dummy forward pass.

    Calls ``self.call`` on an all-zeros tensor of ``input_shape`` —
    presumably so the recurrent and convolutional sublayers create their
    weights lazily (TODO confirm). The decorator replaces a ``None``
    leading (batch) dimension with a dummy size so ``zeros`` can allocate
    the tensor.
    """
    self.call(keras.ops.zeros(input_shape))
2 changes: 2 additions & 0 deletions bayesflow/networks/summary_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from bayesflow.metrics.functional import maximum_mean_discrepancy
from bayesflow.types import Tensor
from bayesflow.utils import find_distribution, keras_kwargs
from bayesflow.utils.decorators import sanitize_input_shape


class SummaryNetwork(keras.Layer):
def __init__(self, base_distribution: str = None, **kwargs):
super().__init__(**keras_kwargs(kwargs))
self.base_distribution = find_distribution(base_distribution)

@sanitize_input_shape
def build(self, input_shape):
if self.base_distribution is not None:
output_shape = keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
Expand Down
18 changes: 18 additions & 0 deletions bayesflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
import inspect
from typing import overload, TypeVar
from bayesflow.types import Shape

Fn = TypeVar("Fn", bound=Callable[..., any])

Expand Down Expand Up @@ -110,3 +111,20 @@ def callback(x):
fn = alias("batch_shape", "batch_size")(fn)

return fn


def sanitize_input_shape(fn: Callable):
    """Decorator that replaces ``None`` entries in ``input_shape`` with a dummy size.

    The Keras functional API passes input_shape = (None, second_dim, third_dim, ...),
    which causes problems when constructions like self.call(keras.ops.zeros(input_shape))
    are used in build: ``zeros`` cannot allocate a tensor with an unknown dimension.
    To alleviate those problems, this decorator replaces every ``None`` entry — not
    only the batch dimension, since variable set sizes / sequence lengths may also
    be reported as ``None`` — with an arbitrary positive size.

    Parameters
    ----------
    fn : Callable
        A ``build(self, input_shape)``-style method whose ``input_shape`` argument
        should be sanitized before the wrapped method runs.

    Returns
    -------
    Callable
        The wrapped method, with the ``input_shape`` argument sanitized via
        ``argument_callback``.
    """

    def callback(input_shape: Shape) -> Shape:
        # Fast path: shape is fully specified, pass it through unchanged.
        if all(dim is not None for dim in input_shape):
            return input_shape
        # Substitute a dummy size (32, arbitrary) for each unknown dimension.
        return tuple(32 if dim is None else dim for dim in input_shape)

    fn = argument_callback("input_shape", callback)(fn)
    return fn