1 change: 1 addition & 0 deletions bayesflow/networks/__init__.py
@@ -3,6 +3,7 @@
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .flow_matching import FlowMatching
from .free_form_flow import FreeFormFlow
from .inference_network import InferenceNetwork
from .mlp import MLP
from .lstnet import LSTNet
1 change: 1 addition & 0 deletions bayesflow/networks/free_form_flow/__init__.py
@@ -0,0 +1 @@
from .free_form_flow import FreeFormFlow
183 changes: 183 additions & 0 deletions bayesflow/networks/free_form_flow/free_form_flow.py
@@ -0,0 +1,183 @@
import keras
from keras import ops
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp

from ..inference_network import InferenceNetwork


@serializable(package="networks.free_form_flow")
class FreeFormFlow(InferenceNetwork):
"""Implements a dimensionality-preserving Free-form Flow.
Incorporates ideas from [1-2].

[1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).
Free-form Flows: Make Any Architecture a Normalizing Flow.
In International Conference on Artificial Intelligence and Statistics.

[2] Sorrenson, P., Draxler, F., Rousselot, A., Hummerich, S., Zimmermann, L., &
Köthe, U. (2024). Lifting Architectural Constraints of Injective Flows.
In International Conference on Learning Representations.
"""

def __init__(
self,
beta: float = 50.0,
encoder_subnet: str | type = "mlp",
Contributor comment:
type is not serializable out of the box, so we would need a from_config method here. But we can add this later.
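A hypothetical sketch of such serialization support (illustrative only, not part of this PR; it assumes find_network also accepts already-instantiated layers):

    def get_config(self):
        # serialize the subnets explicitly so that non-string `type` arguments
        # survive a save/load round trip
        base_config = super().get_config()
        config = {
            "beta": self.beta,
            "encoder_subnet": keras.saving.serialize_keras_object(self.encoder_subnet),
            "decoder_subnet": keras.saving.serialize_keras_object(self.decoder_subnet),
            "hutchinson_sampling": self.hutchinson_sampling,
        }
        return base_config | config

    @classmethod
    def from_config(cls, config):
        # rebuild the subnets before passing them back to __init__
        config["encoder_subnet"] = keras.saving.deserialize_keras_object(config["encoder_subnet"])
        config["decoder_subnet"] = keras.saving.deserialize_keras_object(config["decoder_subnet"])
        return cls(**config)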

decoder_subnet: str | type = "mlp",
base_distribution: str = "normal",
hutchinson_sampling: str = "qr",
**kwargs,
):
"""Creates an instance of a Free-form Flow.

Parameters
----------
beta : float, optional, default: 50.0
Weight of the reconstruction loss relative to the maximum likelihood term
(see `compute_metrics`).
encoder_subnet : str or type, optional, default: "mlp"
A neural network type for the encoder; it will be instantiated using
encoder_subnet_kwargs and equipped with a projector to ensure the correct
output dimension as well as a global skip connection.
decoder_subnet : str or type, optional, default: "mlp"
A neural network type for the decoder; it will be instantiated using
decoder_subnet_kwargs and equipped with a projector to ensure the correct
output dimension as well as a global skip connection.
base_distribution : str, optional, default: "normal"
The latent distribution
hutchinson_sampling : str, optional, default: "qr"
One of `["sphere", "qr"]`. Select the sampling scheme for the
vectors of the Hutchinson trace estimator.
**kwargs : dict, optional, default: {}
Additional keyword arguments
"""
super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs", {}))
self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs", {}))
self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")

self.hutchinson_sampling = hutchinson_sampling
self.beta = beta

self.seed_generator = keras.random.SeedGenerator()

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
super().build(xz_shape)
self.encoder_projector.units = xz_shape[-1]
self.decoder_projector.units = xz_shape[-1]

# construct input shape for subnet and subnet projector
input_shape = list(xz_shape)

if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]

input_shape = tuple(input_shape)

self.encoder_subnet.build(input_shape)
self.decoder_subnet.build(input_shape)

input_shape = self.encoder_subnet.compute_output_shape(input_shape)
self.encoder_projector.build(input_shape)

input_shape = self.decoder_subnet.compute_output_shape(input_shape)
self.decoder_projector.build(input_shape)

def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
if density:
if conditions is None:
# None cannot be batched, so supply as keyword argument
z, log_det = log_jacobian_determinant(x, self.encode, conditions=None, training=training, **kwargs)
else:
# conditions should be batched, supply as positional argument
z, log_det = log_jacobian_determinant(x, self.encode, conditions, training=training, **kwargs)

log_density = self.base_distribution.log_prob(z) + log_det
return z, log_density

z = self.encode(x, conditions, training=training, **kwargs)
return z

def _inverse(
self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
if density:
if conditions is None:
# None cannot be batched, so supply as keyword argument
x, log_det = log_jacobian_determinant(z, self.decode, conditions=None, training=training, **kwargs)
else:
# conditions should be batched, supply as positional argument
x, log_det = log_jacobian_determinant(z, self.decode, conditions, training=training, **kwargs)
log_density = self.base_distribution.log_prob(z) - log_det
return x, log_density

x = self.decode(z, conditions, training=training, **kwargs)
return x

def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
if conditions is None:
inp = x
else:
inp = concatenate(x, conditions, axis=-1)
network_out = self.encoder_projector(
self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
Contributor comment:
I would prefer this non-nested for better errors, but again, not a major issue.
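For reference, a non-nested variant of the call above could look like this (behavior-preserving sketch):

    subnet_out = self.encoder_subnet(inp, training=training, **kwargs)
    network_out = self.encoder_projector(subnet_out, training=training, **kwargs)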

)
return network_out + x

def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
if conditions is None:
inp = z
else:
inp = concatenate(z, conditions, axis=-1)
network_out = self.decoder_projector(
self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
)
return network_out + z

def _sample_v(self, x):
batch_size = ops.shape(x)[0]
total_dim = ops.shape(x)[-1]
match self.hutchinson_sampling:
case "qr":
# Use QR decomposition as described in [2]
v_raw = keras.random.normal((batch_size, total_dim, 1), dtype=ops.dtype(x), seed=self.seed_generator)
q = ops.reshape(ops.qr(v_raw)[0], ops.shape(x))
v = q * ops.sqrt(total_dim)
case "sphere":
# Sample from sphere with radius sqrt(total_dim), as implemented in [1]
v_raw = keras.random.normal((batch_size, total_dim), dtype=ops.dtype(x), seed=self.seed_generator)
v = v_raw * ops.sqrt(total_dim) / ops.sqrt(ops.sum(v_raw**2, axis=-1, keepdims=True))
case _:
raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
return v

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
# sample random vector
v = self._sample_v(x)

def encode(x):
return self.encode(x, conditions, training=stage == "training")

def decode(z):
return self.decode(z, conditions, training=stage == "training")

# VJP computation
z, vjp_fn = vjp(encode, x)
v1 = vjp_fn(v)[0]
# JVP computation
x_pred, v2 = jvp(decode, (z,), (v,))

# equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0]
surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1)
nll = -self.base_distribution.log_prob(z)
maximum_likelihood_loss = nll - surrogate
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)

return base_metrics | {"loss": loss}
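For reviewers who want to try the new network, a minimal usage sketch based on the signatures above (shapes are illustrative, and the explicit build call is an assumption; normally a BayesFlow approximator handles building and training):

    import keras
    from bayesflow.networks import FreeFormFlow

    flow = FreeFormFlow(beta=50.0, encoder_subnet="mlp", decoder_subnet="mlp")
    flow.build(xz_shape=(128, 4), conditions_shape=(128, 2))

    x = keras.random.normal((128, 4))           # data batch
    conditions = keras.random.normal((128, 2))  # conditioning variables

    z = flow.encode(x, conditions)      # latent representation
    x_rec = flow.decode(z, conditions)  # reconstruction
    metrics = flow.compute_metrics(x, conditions=conditions, stage="training")
    print(metrics["loss"])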
2 changes: 2 additions & 0 deletions bayesflow/utils/__init__.py
@@ -25,7 +25,9 @@
parse_bytes,
)
from .jacobian_trace import jacobian_trace
from .jacobian import compute_jacobian, log_jacobian_determinant
from .jvp import jvp
from .vjp import vjp
from .optimal_transport import optimal_transport
from .tensor_utils import (
expand_left,
129 changes: 129 additions & 0 deletions bayesflow/utils/jacobian.py
@@ -0,0 +1,129 @@
from collections.abc import Callable
import keras
from keras import ops
from bayesflow.types import Tensor

from functools import partial, wraps


def compute_jacobian(
x_in: Tensor,
fn: Callable,
*func_args: any,
grad_type: str = "backward",
**func_kwargs: any,
) -> tuple[Tensor, Tensor]:
"""Computes the Jacobian of a function with respect to its input.

:param x_in: The input tensor to compute the jacobian at.
Shape: (batch_size, in_dim).
:param fn: The function to compute the jacobian of, which transforms
`x` to `fn(x)` of shape (batch_size, out_dim).
:param func_args: The positional arguments to pass to the function.
func_args are batched over the first dimension.
:param grad_type: The type of gradient to use. Either 'backward' or
'forward'.
:param func_kwargs: The keyword arguments to pass to the function.
func_kwargs are not batched.
:return: The output of the function `fn(x)` and the jacobian
of the function with respect to its input `x` of shape
(batch_size, out_dim, in_dim)."""

def batch_wrap(fn: Callable) -> Callable:
"""Add a batch dimension to each tensor argument.

:param fn:
:return: wrapped function"""

def deep_unsqueeze(arg):
if ops.is_tensor(arg):
return arg[None, ...]
elif isinstance(arg, dict):
return {key: deep_unsqueeze(value) for key, value in arg.items()}
elif isinstance(arg, (list, tuple)):
return [deep_unsqueeze(value) for value in arg]
raise ValueError(f"Argument cannot be batched: {arg}")

@wraps(fn)
def wrapper(*args, **kwargs):
args = deep_unsqueeze(args)
return fn(*args, **kwargs)[0]

return wrapper

def double_output(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
out = fn(*args, **kwargs)
return out, out

return wrapper

match keras.backend.backend():
case "torch":
import torch
from torch.func import jacrev, jacfwd, vmap

jacfn = jacrev if grad_type == "backward" else jacfwd
with torch.inference_mode(False):
with torch.no_grad():
fn_kwargs_prefilled = partial(fn, **func_kwargs)
fn_batch_expanded = batch_wrap(fn_kwargs_prefilled)
fn_return_val = double_output(fn_batch_expanded)
fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True))
jac, x_out = fn_jac_batched(x_in, *func_args)
case "jax":
from jax import jacrev, jacfwd, vmap

jacfn = jacrev if grad_type == "backward" else jacfwd
fn_kwargs_prefilled = partial(fn, **func_kwargs)
fn_batch_expanded = batch_wrap(fn_kwargs_prefilled)
fn_return_val = double_output(fn_batch_expanded)
fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True))
jac, x_out = fn_jac_batched(x_in, *func_args)
case "tensorflow":
if grad_type == "forward":
raise NotImplementedError("For TensorFlow, only backward mode Jacobian computation is available.")
import tensorflow as tf

with tf.GradientTape() as tape:
tape.watch(x_in)
x_out = fn(x_in, *func_args, **func_kwargs)
jac = tape.batch_jacobian(x_out, x_in)

case _:
raise NotImplementedError(f"compute_jacobian not implemented for {keras.backend.backend()}.")
return x_out, jac


def log_jacobian_determinant(
x_in: Tensor,
fn: Callable,
*func_args: any,
grad_type: str = "backward",
**func_kwargs: any,
) -> tuple[Tensor, Tensor]:
"""Computes the log Jacobian determinant of a function
with respect to its input.

:param x_in: The input tensor to compute the jacobian at.
Shape: (batch_size, in_dim).
:param fn: The function to compute the jacobian of, which transforms
`x` to `fn(x)` of shape (batch_size, out_dim).
:param func_args: The positional arguments to pass to the function.
func_args are batched over the first dimension.
:param grad_type: The type of gradient to use. Either 'backward' or
'forward'.
:param func_kwargs: The keyword arguments to pass to the function.
func_kwargs are not batched.
:return: The output of the function `fn(x)` and the log Jacobian determinant
of the function with respect to its input `x`, of shape
(batch_size,)."""

x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs)
jac = ops.reshape(
jac, (ops.shape(x_in)[0], ops.prod(list(ops.shape(x_out)[1:])), ops.prod(list(ops.shape(x_in)[1:])))
Contributor comment:
Would prefer this in multiple lines
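For example, a split-out version (equivalent sketch):

    batch_size = ops.shape(x_in)[0]
    out_dim = ops.prod(list(ops.shape(x_out)[1:]))
    in_dim = ops.prod(list(ops.shape(x_in)[1:]))
    jac = ops.reshape(jac, (batch_size, out_dim, in_dim))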

)
log_det = ops.slogdet(jac)[1]

return x_out, log_det
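A minimal usage sketch for the two helpers (the mapping f is illustrative only):

    import keras
    from keras import ops
    from bayesflow.utils import compute_jacobian, log_jacobian_determinant

    def f(x):
        # simple element-wise map with a tractable Jacobian
        return 2.0 * x + ops.tanh(x)

    x = keras.random.normal((16, 3))

    # full Jacobian of shape (16, 3, 3), together with f(x)
    y, jac = compute_jacobian(x, f, grad_type="backward")

    # log |det J| of shape (16,)
    y, log_det = log_jacobian_determinant(x, f, grad_type="backward")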
38 changes: 0 additions & 38 deletions bayesflow/utils/jacobian_trace/_vjp.py

This file was deleted.
