3 changes: 2 additions & 1 deletion .github/workflows/tests.yaml
@@ -69,7 +69,7 @@ jobs:
- name: Install JAX
if: ${{ matrix.backend == 'jax' }}
run: |
pip install -U "jax[cpu]"
pip install -U jax
- name: Install NumPy
if: ${{ matrix.backend == 'numpy' }}
run: |
@@ -90,6 +90,7 @@ jobs:
conda config --show-sources
conda config --show
printenv | sort
conda env export

- name: Run Tests
run: |
46 changes: 18 additions & 28 deletions bayesflow/networks/deep_set/deep_set.py
@@ -1,12 +1,11 @@
import keras
from keras import layers
from keras.saving import register_keras_serializable as serializable, serialize_keras_object as serialize
from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from .invariant_module import InvariantModule
from .equivariant_module import EquivariantModule

from .invariant_module import InvariantModule
from ..summary_network import SummaryNetwork


@@ -27,12 +26,12 @@ def __init__(
self,
summary_dim: int = 16,
depth: int = 2,
inner_pooling: str | keras.Layer = "mean",
output_pooling: str | keras.Layer = "mean",
mlp_widths_equivariant: tuple = (128, 128),
mlp_widths_invariant_inner: tuple = (128, 128),
mlp_widths_invariant_outer: tuple = (128, 128),
mlp_widths_invariant_last: tuple = (128, 128),
inner_pooling: str = "mean",
output_pooling: str = "mean",
mlp_widths_equivariant: Sequence[int] = (128, 128),
mlp_widths_invariant_inner: Sequence[int] = (128, 128),
mlp_widths_invariant_outer: Sequence[int] = (128, 128),
mlp_widths_invariant_last: Sequence[int] = (128, 128),
activation: str = "gelu",
kernel_initializer: str = "he_normal",
dropout: int | float | None = 0.05,
@@ -46,7 +45,7 @@ def __init__(
super().__init__(**kwargs)

# Stack of equivariant modules for a many-to-many learnable transformation
self.equivariant_modules = keras.Sequential()
self.equivariant_modules = []
for _ in range(depth):
equivariant_module = EquivariantModule(
mlp_widths_equivariant=mlp_widths_equivariant,
@@ -59,7 +58,7 @@ def __init__(
pooling=inner_pooling,
**kwargs,
)
self.equivariant_modules.add(equivariant_module)
self.equivariant_modules.append(equivariant_module)

# Invariant module for a many-to-one transformation
self.invariant_module = InvariantModule(
@@ -74,32 +73,23 @@ def __init__(
)

# Output linear layer to project set representation down to "summary_dim" learned summary statistics
self.output_projector = layers.Dense(summary_dim, activation="linear")
self.output_projector = keras.layers.Dense(summary_dim, activation="linear")
self.summary_dim = summary_dim

def build(self, input_shape):
super().build(input_shape)
self.call(keras.ops.zeros(input_shape))

def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
"""Performs the forward pass of a learnable deep invariant transformation consisting of
a sequence of equivariant transforms followed by an invariant transform.

#TODO
"""
x = self.equivariant_modules(input_set, training=training)
x = self.invariant_module(x, training=training)

return self.output_projector(x)

def get_config(self):
base_config = super().get_config()
for em in self.equivariant_modules:
x = em(x, training=training)

config = {
"invariant_module": serialize(self.equivariant_modules),
"equivariant_fc": serialize(self.invariant_module),
"output_projector": serialize(self.output_projector),
"summary_dim": serialize(self.summary_dim),
}
x = self.invariant_module(x, training=training)

return base_config | config
return self.output_projector(x)
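
Note on the container change in this file: in Keras 3, sublayers kept in a plain Python list attribute are tracked and built just like layers added to a keras.Sequential, and the explicit loop in call() passes the training flag to each module directly. The hand-written get_config is dropped here as well; the serialization round-trip tests added at the bottom of this PR cover that instead. A minimal sketch of the pattern, using a hypothetical StackedBlocks layer rather than the real DeepSet:

import keras


class StackedBlocks(keras.Layer):
    # hypothetical stand-in for DeepSet's stack of equivariant modules
    def __init__(self, depth: int = 2, units: int = 128, **kwargs):
        super().__init__(**kwargs)
        # Keras tracks sublayers stored in list attributes, so no Sequential is needed
        self.blocks = [keras.layers.Dense(units, activation="gelu") for _ in range(depth)]

    def call(self, x, training=False):
        for block in self.blocks:
            x = block(x, training=training)
        return x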
10 changes: 6 additions & 4 deletions bayesflow/networks/deep_set/equivariant_module.py
@@ -1,3 +1,5 @@
from collections.abc import Sequence

import keras
from keras import ops, layers
from keras.saving import register_keras_serializable as serializable
@@ -19,10 +21,10 @@ class EquivariantModule(keras.Layer):

def __init__(
self,
mlp_widths_equivariant: tuple = (128, 128),
mlp_widths_invariant_inner: tuple = (128, 128),
mlp_widths_invariant_outer: tuple = (128, 128),
pooling: str | keras.Layer = "mean",
mlp_widths_equivariant: Sequence[int] = (128, 128),
mlp_widths_invariant_inner: Sequence[int] = (128, 128),
mlp_widths_invariant_outer: Sequence[int] = (128, 128),
pooling: str = "mean",
activation: str = "gelu",
kernel_initializer: str = "he_normal",
dropout: int | float | None = 0.05,
11 changes: 7 additions & 4 deletions bayesflow/networks/deep_set/invariant_module.py
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import keras
from keras import layers
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs
from bayesflow.utils import find_pooling
from bayesflow.utils import keras_kwargs


@serializable(package="bayesflow.networks")
@@ -19,12 +21,12 @@ class InvariantModule(keras.Layer):

def __init__(
self,
mlp_widths_inner: tuple = (128, 128),
mlp_widths_outer: tuple = (128, 128),
mlp_widths_inner: Sequence[int] = (128, 128),
mlp_widths_outer: Sequence[int] = (128, 128),
activation: str = "gelu",
kernel_initializer: str = "he_normal",
dropout: int | float | None = 0.05,
pooling: str | keras.Layer = "mean",
pooling: str = "mean",
spectral_normalization: bool = False,
**kwargs,
):
@@ -54,6 +56,7 @@ def __init__(
self.inner_fc.add(layer)

# Outer fully connected net for sum decomposition: inner( pooling( inner(set) ) )
# TODO: why does using Sequential work here, but not in DeepSet?
self.outer_fc = keras.Sequential()
for width in mlp_widths_outer:
if dropout is not None and dropout > 0:
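
The construction truncated above (a keras.Sequential of Dense layers with optional Dropout interleaved) reduces to this sketch; the helper name make_fc is hypothetical, and the defaults are copied from the signature in this diff:

import keras


def make_fc(widths=(128, 128), dropout=0.05, activation="gelu", kernel_initializer="he_normal"):
    # interleaves optional Dropout with the Dense stack, as in the diff above
    fc = keras.Sequential()
    for width in widths:
        if dropout is not None and dropout > 0:
            fc.add(keras.layers.Dropout(dropout))
        fc.add(keras.layers.Dense(width, activation=activation, kernel_initializer=kernel_initializer))
    return fc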
41 changes: 15 additions & 26 deletions bayesflow/networks/lstnet/lstnet.py
@@ -1,9 +1,7 @@
import keras
from keras import layers, Sequential
from keras.saving import register_keras_serializable as serializable, serialize_keras_object as serialize
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor

from .skip_recurrent import SkipRecurrentNet
from ..summary_network import SummaryNetwork

@@ -30,7 +28,7 @@ def __init__(
activation: str = "mish",
kernel_initializer: str = "glorot_uniform",
groups: int = 8,
recurrent_type: str | keras.Layer = "gru",
recurrent_type: str = "gru",
recurrent_dim: int = 128,
bidirectional: bool = True,
dropout: float = 0.05,
@@ -46,18 +44,19 @@
kernel_sizes = (kernel_sizes,)
if not isinstance(strides, (list, tuple)):
strides = (strides,)
self.conv_blocks = Sequential()
self.conv_blocks = []
for f, k, s in zip(filters, kernel_sizes, strides):
self.conv_blocks.add(
layers.Conv1D(
self.conv_blocks.append(
keras.layers.Conv1D(
filters=f,
kernel_size=k,
strides=s,
activation=activation,
kernel_initializer=kernel_initializer,
padding="same",
)
)
self.conv_blocks.add(layers.GroupNormalization(groups=groups))
self.conv_blocks.append(keras.layers.GroupNormalization(groups=groups))

# Recurrent and feedforward backbones
self.recurrent = SkipRecurrentNet(
@@ -68,27 +67,17 @@
skip_steps=skip_steps,
dropout=dropout,
)
self.output_projector = layers.Dense(summary_dim)
self.output_projector = keras.layers.Dense(summary_dim)
self.summary_dim = summary_dim

def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
summary = self.conv_blocks(time_series, training=training)
summary = self.recurrent(summary, training=training)
summary = self.output_projector(summary)
return summary
def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
for c in self.conv_blocks:
x = c(x, training=training)

x = self.recurrent(x, training=training)
x = self.output_projector(x)
return x

def build(self, input_shape):
super().build(input_shape)
self.call(keras.ops.zeros(input_shape))

def get_config(self):
base_config = super().get_config()

config = {
"conv_blocks": serialize(self.conv_blocks),
"recurrent": serialize(self.recurrent),
"output_projector": serialize(self.output_projector),
"summary_dim": serialize(self.summary_dim),
}

return base_config | config
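
Two things happen in this file: the conv stack becomes a plain list with an explicit call loop (the same pattern as in deep_set.py), and the hand-written get_config is removed in favor of the serialization round-trip tests added below. For orientation, a usage sketch; the input layout (batch, time steps, channels) is assumed, not stated in the diff:

import keras
from bayesflow.networks import LSTNet

net = LSTNet()
time_series = keras.ops.zeros((8, 100, 3))  # (batch, time steps, channels)
summary = net(time_series)  # -> (8, summary_dim) learned summary statistics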
7 changes: 5 additions & 2 deletions bayesflow/networks/lstnet/skip_recurrent.py
@@ -22,7 +22,7 @@ class SkipRecurrentNet(keras.Model):
def __init__(
self,
hidden_dim: int = 256,
recurrent_type: str | keras.Layer = "gru",
recurrent_type: str = "gru",
bidirectional: bool = True,
input_channels: int = 64,
skip_steps: int = 4,
@@ -32,7 +32,10 @@ def __init__(
super().__init__(**keras_kwargs(kwargs))

self.skip_conv = keras.layers.Conv1D(
filters=input_channels * skip_steps, kernel_size=skip_steps, strides=skip_steps
filters=input_channels * skip_steps,
kernel_size=skip_steps,
strides=skip_steps,
padding="same",
)

recurrent_constructor = find_recurrent_net(recurrent_type)
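
The functional change here is padding="same" on the skip convolution. Without it, Conv1D with strides=skip_steps silently drops trailing time steps whenever the sequence length is not a multiple of skip_steps. A quick sketch with this module's default filter count (input_channels * skip_steps = 64 * 4):

import keras

x = keras.ops.zeros((1, 10, 64))
valid = keras.layers.Conv1D(filters=256, kernel_size=4, strides=4, padding="valid")
same = keras.layers.Conv1D(filters=256, kernel_size=4, strides=4, padding="same")

print(valid(x).shape)  # (1, 2, 256): floor((10 - 4) / 4) + 1 = 2 steps
print(same(x).shape)   # (1, 3, 256): ceil(10 / 4) = 3 steps, nothing dropped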
23 changes: 19 additions & 4 deletions bayesflow/networks/mlp/mlp.py
@@ -1,10 +1,12 @@
from collections.abc import Sequence
from typing import Literal

import keras
from keras import layers
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs

from .hidden_block import ConfigurableHiddenBlock


@@ -19,11 +21,14 @@ class MLP(keras.Layer):

def __init__(
self,
widths: tuple = (512, 512),
*,
depth: int = None,
width: int = None,
widths: Sequence[int] = None,
activation: str = "mish",
kernel_initializer: str = "he_normal",
residual: bool = True,
dropout: float = 0.05,
dropout: Literal[0, None] | float = 0.05,
spectral_normalization: bool = False,
**kwargs,
):
@@ -45,9 +50,19 @@ def __init__(
dropout : float, optional, default: 0.05
Dropout rate for the hidden layers in the internal layers.
"""

super().__init__(**keras_kwargs(kwargs))

if widths is not None:
if depth is not None or width is not None:
raise ValueError("Either specify 'widths' or 'depth' and 'width', not both.")
else:
if depth is None or width is None:
# use the default
depth = 2
width = 512

widths = [width] * depth

self.res_blocks = []
projector = layers.Dense(
units=widths[0],
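
The constructor now takes keyword-only arguments and accepts either an explicit widths sequence or a depth/width pair, defaulting to two layers of width 512 when neither is given. A usage sketch, assuming MLP is importable from bayesflow.networks:

from bayesflow.networks import MLP

MLP()                           # default: widths = [512] * 2
MLP(depth=3, width=256)         # widths = [256] * 3
MLP(widths=(128, 64))           # explicit per-layer widths
MLP(widths=(128, 64), depth=3)  # raises ValueError: 'widths' and 'depth'/'width' are exclusive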
30 changes: 30 additions & 0 deletions tests/test_networks/conftest.py
@@ -0,0 +1,30 @@
import pytest


@pytest.fixture()
def deep_set():
from bayesflow.networks import DeepSet

return DeepSet()


@pytest.fixture()
def lst_net():
from bayesflow.networks import LSTNet

return LSTNet()


@pytest.fixture()
def set_transformer():
from bayesflow.networks import SetTransformer

return SetTransformer()


@pytest.fixture(params=[None, "deep_set", "lst_net", "set_transformer"])
def summary_network(request):
if request.param is None:
return None

return request.getfixturevalue(request.param)
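
The summary_network fixture uses the standard pytest idiom for parametrizing over other fixtures: each param is a fixture name, and request.getfixturevalue resolves it lazily. Every test that takes summary_network therefore runs four times (None plus the three networks). A hypothetical consumer:

def test_summary_network_is_unbuilt(summary_network):
    if summary_network is None:
        return  # covers the "no summary network" configuration
    assert not summary_network.built  # fixtures return fresh, unbuilt networks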
17 changes: 16 additions & 1 deletion tests/test_networks/test_inference_networks.py
@@ -1,6 +1,10 @@
import keras
import numpy as np
import pytest
from keras.saving import (
deserialize_keras_object as deserialize,
serialize_keras_object as serialize,
)

from tests.utils import allclose, assert_layers_equal

@@ -121,7 +125,18 @@ def f(x):
assert allclose(inverse_log_density, numerical_inverse_log_density, rtol=1e-4, atol=1e-5)


def test_serialize_deserialize(tmp_path, inference_network, random_samples, random_conditions):
def test_serialize_deserialize(inference_network, random_samples, random_conditions):
# to save, the model must be built
inference_network(random_samples, conditions=random_conditions)

serialized = serialize(inference_network)
deserialized = deserialize(serialized)
reserialized = serialize(deserialized)

assert serialized == reserialized


def test_save_and_load(tmp_path, inference_network, random_samples, random_conditions):
# to save, the model must be built
inference_network(random_samples, conditions=random_conditions)

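
The new test_serialize_deserialize checks config stability rather than file I/O: serialize, rebuild, serialize again, and compare the two configs, with saving to disk split out into test_save_and_load. The same round-trip pattern on a stock Keras layer, as a self-contained sketch:

import keras
from keras.saving import (
    deserialize_keras_object as deserialize,
    serialize_keras_object as serialize,
)

layer = keras.layers.Dense(8)
layer.build((None, 4))  # as in the test: build before serializing

serialized = serialize(layer)
roundtripped = serialize(deserialize(serialized))
assert serialized == roundtripped  # the config survives a full round trip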