
Improve Summary Networks #243

@LarsKue

Description


In addition to #242, we could do with some general clean-up of the summary networks. This mostly concerns their constructor arguments, most of which are not meaningful to non-developers. I think it would also help to simplify the implementation, e.g. by dropping the multiple poolings used in DeepSet in favor of a single pooling.

This might make the networks slightly less configurable, but it could greatly improve their overall usability by lowering the entry barrier to configuration. For most users, I think this would be a benefit, and power users could still implement their own version to work around the reduced configuration options.

This is just an idea meant for discussion. I would be glad to hear your thoughts!

Here is an example implementation of a reduced DeepSet:

from collections.abc import Sequence
from typing import Literal

import keras

# MLP, SummaryNetwork, Tensor, and serializable come from BayesFlow internals;
# PoolingLayer is defined below

@serializable(package="bayesflow.networks")
class DeepSet(SummaryNetwork):
    def __init__(
            self,
            *,
            summary_dim: int = 16,
            widths: tuple[Sequence[int], Sequence[int]] = ((128, 128), (128, 128)),
            pooling: Literal["sum", "mean"] = "mean",
            activation: str = "gelu",
            dropout: float | None = 0.05,
            **kwargs,
    ):
        super().__init__(**kwargs)

        self.summary_dim = summary_dim
        # the `pooling` setter below creates self.pooling_layer
        self.pooling = pooling

        self.equivariant_mlp = MLP(
            widths=widths[0],
            activation=activation,
            dropout=dropout,
        )

        self.invariant_mlp = MLP(
            widths=widths[1],
            activation=activation,
            dropout=dropout,
        )

        self.output_projector = keras.layers.Dense(summary_dim, activation=None)

    @property
    def pooling(self):
        return self.pooling_layer.method

    @pooling.setter
    def pooling(self, pooling: Literal["sum", "mean"]):
        self.pooling_layer = PoolingLayer(pooling)

    def build(self, input_shape):
        super().build(input_shape)
        # substitute a concrete batch size for an unknown (None) batch dimension
        input_shape = (input_shape[0] or 1,) + tuple(input_shape[1:])
        self.call(keras.ops.zeros(input_shape))

    def call(self, x: Tensor, **kwargs) -> Tensor:
        x = self.equivariant_mlp(x)
        x = self.pooling_layer(x)
        x = self.invariant_mlp(x)
        x = self.output_projector(x)

        return x

which uses this PoolingLayer:

@serializable(package="bayesflow.networks")
class PoolingLayer(keras.Layer):
    def __init__(self, method: Literal["mean", "sum"] = "mean", axis: int = 1, keepdims: bool = False, **kwargs):
        super().__init__(**kwargs)

        self.method = method
        self.axis = axis
        self.keepdims = keepdims

        # stateless layers do not need to be built
        self.built = True

    def call(self, x: Tensor) -> Tensor:
        match self.method:
            case "mean":
                x = keras.ops.mean(x, axis=self.axis, keepdims=self.keepdims)
            case "sum":
                x = keras.ops.sum(x, axis=self.axis, keepdims=self.keepdims)
            case other:
                raise ValueError(f"Unknown pooling method: {other!r}")

        return x
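As a sanity check on the single-pooling design, the equivariant → pool → invariant flow can be sketched with plain NumPy. The random matrices below stand in for the MLPs; none of these names come from BayesFlow, this is only an illustration of why one pooling over the set axis suffices for permutation invariance:

```python
import numpy as np

rng = np.random.default_rng(0)

batch, set_size, features, hidden, summary_dim = 2, 5, 3, 16, 8
x = rng.normal(size=(batch, set_size, features))

w_equi = rng.normal(size=(features, hidden))    # "equivariant" stage: shared across set elements
w_inv = rng.normal(size=(hidden, summary_dim))  # "invariant" stage after pooling

h = np.tanh(x @ w_equi)             # (batch, set_size, hidden), applied element-wise
pooled = h.mean(axis=1)             # (batch, hidden): set order no longer matters
summary = np.tanh(pooled) @ w_inv   # (batch, summary_dim)

# permuting the set axis leaves the summary unchanged
perm = rng.permutation(set_size)
h_p = np.tanh(x[:, perm, :] @ w_equi)
summary_p = np.tanh(h_p.mean(axis=1)) @ w_inv
```

Here `summary` and `summary_p` agree up to floating-point rounding, which is the property the single pooling layer is responsible for.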

Metadata

Labels: discussion ("Discuss a topic or question not necessarily with a clear output in mind.")
Status: Done