From cfd2e4513bbbc2a2380350f4382b962f62104bff Mon Sep 17 00:00:00 2001
From: Valentin Pratz
Date: Wed, 20 Nov 2024 09:31:57 +0000
Subject: [PATCH 01/13] feat: add free-form flows as inference networks

* implements the fff loss
* still missing: calculation of the log probability
---
 bayesflow/networks/__init__.py                |   1 +
 bayesflow/networks/free_form_flow/__init__.py |   1 +
 .../networks/free_form_flow/free_form_flow.py | 197 ++++++++++++++++++
 3 files changed, 199 insertions(+)
 create mode 100644 bayesflow/networks/free_form_flow/__init__.py
 create mode 100644 bayesflow/networks/free_form_flow/free_form_flow.py

diff --git a/bayesflow/networks/__init__.py b/bayesflow/networks/__init__.py
index fce9b27fa..9a915572b 100644
--- a/bayesflow/networks/__init__.py
+++ b/bayesflow/networks/__init__.py
@@ -3,6 +3,7 @@
 from .coupling_flow import CouplingFlow
 from .deep_set import DeepSet
 from .flow_matching import FlowMatching
+from .free_form_flow import FreeFormFlow
 from .inference_network import InferenceNetwork
 from .mlp import MLP
 from .lstnet import LSTNet
diff --git a/bayesflow/networks/free_form_flow/__init__.py b/bayesflow/networks/free_form_flow/__init__.py
new file mode 100644
index 000000000..803280523
--- /dev/null
+++ b/bayesflow/networks/free_form_flow/__init__.py
@@ -0,0 +1 @@
+from .free_form_flow import FreeFormFlow
diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py
new file mode 100644
index 000000000..ca965ba88
--- /dev/null
+++ b/bayesflow/networks/free_form_flow/free_form_flow.py
@@ -0,0 +1,197 @@
+import keras
+from keras import ops
+from keras.saving import register_keras_serializable as serializable
+
+from bayesflow.types import Tensor
+from bayesflow.utils import find_network, keras_kwargs, concatenate
+
+from ..inference_network import InferenceNetwork
+
+
+@serializable(package="networks.free_form_flow")
+class FreeFormFlow(InferenceNetwork):
+    """Implements a dimensionality-preserving Free-form Flow.
+    Incorporates ideas from [1-2].
+
+    [1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).
+    Free-form Flows: Make Any Architecture a Normalizing Flow.
+    In International Conference on Artificial Intelligence and Statistics.
+
+    [2] Sorrenson, P., Draxler, F., Rousselot, A., Hummerich, S., Zimmermann, L., &
+    Köthe, U. (2024). Lifting Architectural Constraints of Injective Flows.
+    In International Conference on Learning Representations.
+    """
+
+    def __init__(
+        self,
+        beta: float = 50.0,
+        encoder_subnet: str | type = "mlp",
+        decoder_subnet: str | type = "mlp",
+        base_distribution: str = "normal",
+        hutchinson_sampling: str = "qr",
+        **kwargs,
+    ):
+        """Creates an instance of a Free-form Flow.
+
+        Parameters:
+        -----------
+        beta : float, optional, default: 50.0
+            Weight of the reconstruction loss relative to the maximum likelihood loss.
+        encoder_subnet : str or type, optional, default: "mlp"
+            A neural network type for the flow, will be instantiated using
+            encoder_subnet_kwargs. Will be equipped with a projector to ensure
+            the correct output dimension and a global skip connection.
+        decoder_subnet : str or type, optional, default: "mlp"
+            A neural network type for the flow, will be instantiated using
+            decoder_subnet_kwargs. Will be equipped with a projector to ensure
+            the correct output dimension and a global skip connection.
+        base_distribution : str, optional, default: "normal"
+            The latent distribution
+        hutchinson_sampling : str, optional, default: "qr"
+            One of `["sphere", "qr"]`.
Select the sampling scheme for the + vectors of the Hutchinson trace estimator. + **kwargs : dict, optional, default: {} + Additional keyword arguments + """ + super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs)) + self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs")) + self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") + self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs")) + self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") + + self.hutchinson_sampling = hutchinson_sampling + self.beta = beta + + self.seed_generator = keras.random.SeedGenerator() + + # noinspection PyMethodOverriding + def build(self, xz_shape, conditions_shape=None): + super().build(xz_shape) + self.encoder_projector.units = xz_shape[-1] + self.decoder_projector.units = xz_shape[-1] + + # construct input shape for subnet and subnet projector + input_shape = list(xz_shape) + + if conditions_shape is not None: + input_shape[-1] += conditions_shape[-1] + + input_shape = tuple(input_shape) + + self.encoder_subnet.build(input_shape) + self.decoder_subnet.build(input_shape) + + input_shape = self.encoder_subnet.compute_output_shape(input_shape) + self.encoder_projector.build(input_shape) + + input_shape = self.decoder_subnet.compute_output_shape(input_shape) + self.decoder_projector.build(input_shape) + + def _forward( + self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + ) -> Tensor | tuple[Tensor, Tensor]: + z = self.encode(x, conditions, training=training, **kwargs) + + if density: + raise NotImplementedError("density computation not implemented yet") + log_density = None + return z, log_density + + return z + + def _inverse( + self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + ) -> Tensor | tuple[Tensor, Tensor]: + x = self.decode(z, conditions, training=training, **kwargs) + + if density: + raise NotImplementedError("density computation not implemented yet") + log_density = None + return x, log_density + + return x + + def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor: + if conditions is None: + inp = x + else: + inp = concatenate(x, conditions, axis=-1) + network_out = self.encoder_projector( + self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs + ) + return network_out + x + + def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor: + if conditions is None: + inp = z + else: + inp = concatenate(z, conditions, axis=-1) + network_out = self.decoder_projector( + self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs + ) + return network_out + z + + def _sample_v(self, x): + batch_size = ops.shape(x)[0] + total_dim = ops.shape(x)[-1] + match self.hutchinson_sampling: + case "qr": + # Use QR decomposition as described in [2] + v_raw = keras.random.normal((batch_size, total_dim, 1), dtype=ops.dtype(x), seed=self.seed_generator) + q = ops.reshape(ops.qr(v_raw)[0], ops.shape(x)) + v = q * ops.sqrt(total_dim) + case "sphere": + # Sample from sphere with radius sqrt(total_dim), as implemented in [1] + v_raw = keras.random.normal((batch_size, total_dim), dtype=ops.dtype(x), seed=self.seed_generator) + v = v_raw * ops.sqrt(total_dim) / ops.sqrt(ops.sum(v_raw**2, axis=-1, 
keepdims=True)) + case _: + raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.") + return v + + def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]: + base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage) + # sample random vector + v = self._sample_v(x) + + def encode(x): + return self.encode(x, conditions, training=stage == "training") + + def decode(z): + return self.decode(z, conditions, training=stage == "training") + + # calculate VJP and JVP (backend-specific) + match keras.backend.backend(): + case "torch": + import torch + + z, vjp_fn = torch.func.vjp(encode, x) + (v1,) = vjp_fn(v) + x_pred, v2 = torch.func.jvp(decode, (z,), (v,)) + case "jax": + import jax + + z, vjp_fn = jax.vjp(encode, x) + (v1,) = vjp_fn(v) + x_pred, v2 = jax.jvp(decode, (z,), (v,)) + case "tensorflow": + import tensorflow as tf + + # VJP computation + with tf.GradientTape() as tape: + tape.watch(x) + z = encode(x) + v1 = tape.gradient(z, x, v) + # JVP computation + with tf.autodiff.ForwardAccumulator(primals=(z,), tangents=(v,)) as acc: + x_pred = decode(z) + v2 = acc.jvp(x_pred) + case _: + raise NotImplementedError(f"Loss function not implemented for backend {keras.backend.backend()}") + + # equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0] + surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1) + nll = -self.base_distribution.log_prob(z) + maximum_likelihood_loss = nll - surrogate + reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1) + loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss) + + return base_metrics | {"loss": loss} From d4a96448f524b55a1011bf7ce81e1a2cc5357623 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Thu, 21 Nov 2024 06:38:32 +0000 Subject: [PATCH 02/13] fff: add log jacobian determinant computation --- .../networks/free_form_flow/free_form_flow.py | 16 +-- bayesflow/utils/__init__.py | 1 + bayesflow/utils/jacobian.py | 130 ++++++++++++++++++ 3 files changed, 138 insertions(+), 9 deletions(-) create mode 100644 bayesflow/utils/jacobian.py diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py index ca965ba88..6c82a38fe 100644 --- a/bayesflow/networks/free_form_flow/free_form_flow.py +++ b/bayesflow/networks/free_form_flow/free_form_flow.py @@ -3,7 +3,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor -from bayesflow.utils import find_network, keras_kwargs, concatenate +from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant from ..inference_network import InferenceNetwork @@ -89,25 +89,23 @@ def build(self, xz_shape, conditions_shape=None): def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: - z = self.encode(x, conditions, training=training, **kwargs) - if density: - raise NotImplementedError("density computation not implemented yet") - log_density = None + z, log_det = log_jacobian_determinant(x, self.encode, conditions, training=training, **kwargs) + log_density = self.base_distribution.log_prob(z) + log_det return z, log_density + z = self.encode(x, conditions, training=training, **kwargs) return z def _inverse( self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | 
tuple[Tensor, Tensor]: - x = self.decode(z, conditions, training=training, **kwargs) - if density: - raise NotImplementedError("density computation not implemented yet") - log_density = None + x, log_det = log_jacobian_determinant(z, self.decode, conditions, training=training, **kwargs) + log_density = self.base_distribution.log_prob(z) - log_det return x, log_density + x = self.decode(z, conditions, training=training, **kwargs) return x def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor: diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 669ad33de..e1ad8fa70 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -25,6 +25,7 @@ parse_bytes, ) from .jacobian_trace import jacobian_trace +from .jacobian import compute_jacobian, log_jacobian_determinant from .jvp import jvp from .optimal_transport import optimal_transport from .tensor_utils import ( diff --git a/bayesflow/utils/jacobian.py b/bayesflow/utils/jacobian.py new file mode 100644 index 000000000..cf7e11c0c --- /dev/null +++ b/bayesflow/utils/jacobian.py @@ -0,0 +1,130 @@ +# adapted from https://github.com/vislearn/FFF +import keras +from keras import ops +from bayesflow.types import Tensor + +from functools import partial, wraps + + +def batch_wrap(fn: callable) -> callable: + """Add a batch dimension to each tensor argument. + + :param fn: + :return: wrapped function""" + + def deep_unsqueeze(arg): + if isinstance(arg, dict): + return {key: deep_unsqueeze(value) for key, value in arg.items()} + elif isinstance(arg, (list, tuple)): + return [deep_unsqueeze(value) for value in arg] + else: + return arg[None, ...] + + @wraps(fn) + def wrapper(*args, **kwargs): + args = deep_unsqueeze(args) + return fn(*args, **kwargs)[0] + + return wrapper + + +def double_output(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + out = fn(*args, **kwargs) + return out, out + + return wrapper + + +def compute_jacobian( + x_in: Tensor, + fn: callable, + *func_args: any, + grad_type: str = "backward", + **func_kwargs: any, +) -> tuple[Tensor, Tensor]: + """Computes the Jacobian of a function with respect to its input. + + :param x_in: The input tensor to compute the jacobian at. + Shape: (batch_size, in_dim). + :param fn: The function to compute the jacobian of, which transforms + `x` to `fn(x)` of shape (batch_size, out_dim). + :param func_args: The positional arguments to pass to the function. + func_args are batched over the first dimension. + :param grad_type: The type of gradient to use. Either 'backward' or + 'forward'. + :param func_kwargs: The keyword arguments to pass to the function. + func_kwargs are not batched. 
+ :return: The output of the function `fn(x)` and the jacobian + of the function with respect to its input `x` of shape + (batch_size, out_dim, in_dim).""" + + match keras.backend.backend(): + case "torch": + import torch + from torch.func import jacrev, jacfwd, vmap + + jacfn = jacrev if grad_type == "backward" else jacfwd + with torch.inference_mode(False): + with torch.no_grad(): + fn_kwargs_prefilled = partial(fn, **func_kwargs) + fn_batch_expanded = batch_wrap(fn_kwargs_prefilled) + fn_return_val = double_output(fn_batch_expanded) + fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True)) + jac, x_out = fn_jac_batched(x_in, *func_args) + case "jax": + from jax import jacrev, jacfwd, vmap + + jacfn = jacrev if grad_type == "backward" else jacfwd + fn_kwargs_prefilled = partial(fn, **func_kwargs) + fn_batch_expanded = batch_wrap(fn_kwargs_prefilled) + fn_return_val = double_output(fn_batch_expanded) + fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True)) + jac, x_out = fn_jac_batched(x_in, *func_args) + case "tensorflow": + if grad_type == "forward": + raise NotImplementedError("For TensorFlow, only backward mode Jacobian computation is available.") + import tensorflow as tf + + with tf.GradientTape() as tape: + tape.watch(x_in) + x_out = fn(x_in, *func_args, **func_kwargs) + jac = tape.batch_jacobian(x_out, x_in) + + case _: + raise NotImplementedError(f"compute_jacobian not implemented for {keras.backend.backend()}.") + return x_out, jac + + +def log_jacobian_determinant( + x_in: Tensor, + fn: callable, + *func_args: any, + grad_type: str = "backward", + **func_kwargs: any, +) -> tuple[Tensor, Tensor]: + """Computes the log Jacobian determinant of a function + with respect to its input. + + :param x_in: The input tensor to compute the jacobian at. + Shape: (batch_size, in_dim). + :param fn: The function to compute the jacobian of, which transforms + `x` to `fn(x)` of shape (batch_size, out_dim). + :param func_args: The positional arguments to pass to the function. + func_args are batched over the first dimension. + :param grad_type: The type of gradient to use. Either 'backward' or + 'forward'. + :param func_kwargs: The keyword arguments to pass to the function. + func_kwargs are not batched. + :return: The output of the function `fn(x)` and the log jacobian determinant + of the function with respect to its input `x` of shape + (batch_size, out_dim, in_dim).""" + + x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs) + jac = ops.reshape( + jac, (ops.shape(x_in)[0], ops.prod(list(ops.shape(x_out)[1:])), ops.prod(list(ops.shape(x_in)[1:]))) + ) + log_det = ops.slogdet(jac)[1] + + return x_out, log_det From 938a3b82fc81482ae8770b6b6e0dea89c0756644 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Thu, 21 Nov 2024 07:44:09 +0000 Subject: [PATCH 03/13] util: make vjp globally accessible Change `torch.autograd.functional.vjp` to `torch.func.vjp` as the former implementation broke gradient flow. It then also uses the same API as Jax, making the code easier to parse. 
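
For illustration, both backends now expose the same call pattern (a sketch only;
`fn`, `x`, and `cotangent` below are placeholders, not names from this change):

    fx, vjp_fn = torch.func.vjp(fn, x)   # torch
    fx, vjp_fn = jax.vjp(fn, x)          # jax
    grad_x = vjp_fn(cotangent)[0]        # gradient w.r.t. x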
--- bayesflow/utils/__init__.py | 1 + .../jacobian_trace/compute_jacobian_trace.py | 10 ++++------ .../jacobian_trace/estimate_jacobian_trace.py | 8 ++++---- .../utils/{jacobian_trace/_vjp.py => vjp.py} | 20 ++++++++----------- 4 files changed, 17 insertions(+), 22 deletions(-) rename bayesflow/utils/{jacobian_trace/_vjp.py => vjp.py} (72%) diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index e1ad8fa70..78f692c0a 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -27,6 +27,7 @@ from .jacobian_trace import jacobian_trace from .jacobian import compute_jacobian, log_jacobian_determinant from .jvp import jvp +from .vjp import vjp from .optimal_transport import optimal_transport from .tensor_utils import ( expand_left, diff --git a/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py b/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py index de03baa0a..7123f7a31 100644 --- a/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py +++ b/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py @@ -3,9 +3,7 @@ import numpy as np from bayesflow.types import Tensor - - -from ._vjp import _make_vjp_fn +from ..vjp import vjp def compute_jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor) -> (Tensor, Tensor): @@ -24,15 +22,15 @@ def compute_jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor) -> (Tensor, shape = keras.ops.shape(x) trace = keras.ops.zeros(shape[:-1]) - fx, vjp_fn = _make_vjp_fn(f, x) + fx, vjp_fn = vjp(f, x) for dim in range(shape[-1]): projector = np.zeros(shape, dtype="float32") projector[..., dim] = 1.0 projector = keras.ops.convert_to_tensor(projector) - vjp = vjp_fn(projector) + vjp_val = vjp_fn(projector) - trace += vjp[..., dim] + trace += vjp_val[..., dim] return fx, trace diff --git a/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py b/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py index c0a867d19..252577858 100644 --- a/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py +++ b/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py @@ -2,7 +2,7 @@ from bayesflow.types import Tensor -from ._vjp import _make_vjp_fn +from ..vjp import vjp def estimate_jacobian_trace(f: callable, x: Tensor, steps: int = 1) -> (Tensor, Tensor): @@ -25,13 +25,13 @@ def estimate_jacobian_trace(f: callable, x: Tensor, steps: int = 1) -> (Tensor, shape = keras.ops.shape(x) trace = keras.ops.zeros(shape[:-1]) - fx, vjp_fn = _make_vjp_fn(f, x) + fx, vjp_fn = vjp(f, x) for _ in range(steps): projector = keras.random.normal(shape) - vjp = vjp_fn(projector) + vjp_val = vjp_fn(projector) - trace += keras.ops.sum(vjp * projector, axis=-1) + trace += keras.ops.sum(vjp_val * projector, axis=-1) return fx, trace diff --git a/bayesflow/utils/jacobian_trace/_vjp.py b/bayesflow/utils/vjp.py similarity index 72% rename from bayesflow/utils/jacobian_trace/_vjp.py rename to bayesflow/utils/vjp.py index b7e71e494..279113e80 100644 --- a/bayesflow/utils/jacobian_trace/_vjp.py +++ b/bayesflow/utils/vjp.py @@ -3,13 +3,20 @@ from bayesflow.types import Tensor -def _make_vjp_fn(f: callable, x: Tensor) -> (Tensor, callable): +def vjp(f: callable, x: Tensor) -> (Tensor, callable): match keras.backend.backend(): case "jax": import jax fx, _vjp_fn = jax.vjp(f, x) + def vjp_fn(projector): + return _vjp_fn(projector)[0] + case "torch": + import torch + + fx, _vjp_fn = torch.func.vjp(f, x) + def vjp_fn(projector): return _vjp_fn(projector)[0] case "tensorflow": @@ -21,17 +28,6 @@ def vjp_fn(projector): def vjp_fn(projector): return 
tape.gradient(fx, x, projector) - case "torch": - import torch - - x = keras.ops.copy(x) - x.requires_grad_(True) - - with torch.enable_grad(): - fx = f(x) - - def vjp_fn(projector): - return torch.autograd.grad(fx, x, projector, retain_graph=True)[0] case other: raise NotImplementedError(f"Cannot build a vjp function for backend '{other}'.") From e5a6667b427b9cf98bc51616265c8c838d3df7e9 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Thu, 21 Nov 2024 07:49:10 +0000 Subject: [PATCH 04/13] utils: change autograd backend for torch jvp Change from `torch.autograd.functional.jvp` to `torch.func.jvp`, as recommended in the documentation. https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html Using autograd.functional seems to break the gradient flow, while `func` does not produce problems in this regard. --- bayesflow/utils/jvp.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/bayesflow/utils/jvp.py b/bayesflow/utils/jvp.py index d086cfdc1..084929937 100644 --- a/bayesflow/utils/jvp.py +++ b/bayesflow/utils/jvp.py @@ -8,24 +8,20 @@ def jvp(fn: callable, primals: tuple[Tensor] | Tensor, tangents: tuple[Tensor] | the input (primals) and vectors in tangents.""" match keras.backend.backend(): + case "jax": + import jax + + fn_output, _jvp = jax.jvp(fn, primals, tangents) case "torch": import torch - fn_output, _jvp = torch.autograd.functional.jvp(fn, primals, tangents) + fn_output, _jvp = torch.func.jvp(fn, primals, tangents) case "tensorflow": import tensorflow as tf with tf.autodiff.ForwardAccumulator(primals=primals, tangents=tangents) as acc: fn_output = fn(*primals) _jvp = acc.jvp(fn_output) - case "jax": - import jax - - fn_output, _jvp = jax.jvp( - fn, - primals, - tangents, - ) case _: raise NotImplementedError(f"JVP not implemented for backend {keras.backend.backend()}") return fn_output, _jvp From e75592d6efc6465203dbda176cc8d22522cc23d1 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Thu, 21 Nov 2024 07:51:32 +0000 Subject: [PATCH 05/13] fff: use vjp and jvp from utils --- .../networks/free_form_flow/free_form_flow.py | 35 ++++--------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py index 6c82a38fe..1ea321bb3 100644 --- a/bayesflow/networks/free_form_flow/free_form_flow.py +++ b/bayesflow/networks/free_form_flow/free_form_flow.py @@ -3,7 +3,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor -from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant +from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp from ..inference_network import InferenceNetwork @@ -156,34 +156,11 @@ def encode(x): def decode(z): return self.decode(z, conditions, training=stage == "training") - # calculate VJP and JVP (backend-specific) - match keras.backend.backend(): - case "torch": - import torch - - z, vjp_fn = torch.func.vjp(encode, x) - (v1,) = vjp_fn(v) - x_pred, v2 = torch.func.jvp(decode, (z,), (v,)) - case "jax": - import jax - - z, vjp_fn = jax.vjp(encode, x) - (v1,) = vjp_fn(v) - x_pred, v2 = jax.jvp(decode, (z,), (v,)) - case "tensorflow": - import tensorflow as tf - - # VJP computation - with tf.GradientTape() as tape: - tape.watch(x) - z = encode(x) - v1 = tape.gradient(z, x, v) - # JVP computation - with tf.autodiff.ForwardAccumulator(primals=(z,), tangents=(v,)) as acc: 
- x_pred = decode(z) - v2 = acc.jvp(x_pred) - case _: - raise NotImplementedError(f"Loss function not implemented for backend {keras.backend.backend()}") + # VJP computation + z, vjp_fn = vjp(encode, x) + v1 = vjp_fn(v) + # JVP computation + x_pred, v2 = jvp(decode, (z,), (v,)) # equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0] surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1) From 7bc52a89ced4878bd9bc84577558b54a3bc60645 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Thu, 21 Nov 2024 12:36:19 +0100 Subject: [PATCH 06/13] improve docs and type hints --- .../jacobian_trace/compute_jacobian_trace.py | 10 ++--- .../jacobian_trace/estimate_jacobian_trace.py | 3 +- bayesflow/utils/jvp.py | 33 ++++++++++----- bayesflow/utils/vjp.py | 40 +++++++++++-------- 4 files changed, 55 insertions(+), 31 deletions(-) diff --git a/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py b/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py index 7123f7a31..98f3d9697 100644 --- a/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py +++ b/bayesflow/utils/jacobian_trace/compute_jacobian_trace.py @@ -6,10 +6,10 @@ from ..vjp import vjp -def compute_jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor) -> (Tensor, Tensor): +def compute_jacobian_trace(fn: Callable[[Tensor], Tensor], x: Tensor) -> (Tensor, Tensor): """Compute the exact trace of the Jacobian matrix of f by projection on each axis. - :param f: The function to be differentiated. + :param fn: The function to be differentiated. :param x: Tensor of shape (n, ..., d) The input tensor to f. @@ -22,15 +22,15 @@ def compute_jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor) -> (Tensor, shape = keras.ops.shape(x) trace = keras.ops.zeros(shape[:-1]) - fx, vjp_fn = vjp(f, x) + fx, vjp_fn = vjp(fn, x) for dim in range(shape[-1]): projector = np.zeros(shape, dtype="float32") projector[..., dim] = 1.0 projector = keras.ops.convert_to_tensor(projector) - vjp_val = vjp_fn(projector) + vjp_value = vjp_fn(projector)[0] - trace += vjp_val[..., dim] + trace += vjp_value[..., dim] return fx, trace diff --git a/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py b/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py index 252577858..0f612f1e9 100644 --- a/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py +++ b/bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py @@ -1,3 +1,4 @@ +from collections.abc import Callable import keras from bayesflow.types import Tensor @@ -5,7 +6,7 @@ from ..vjp import vjp -def estimate_jacobian_trace(f: callable, x: Tensor, steps: int = 1) -> (Tensor, Tensor): +def estimate_jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor, steps: int = 1) -> (Tensor, Tensor): """Estimate the trace of the Jacobian matrix of f using Hutchinson's algorithm. :param f: The function to be differentiated. 
diff --git a/bayesflow/utils/jvp.py b/bayesflow/utils/jvp.py index 084929937..1f25fb3a2 100644 --- a/bayesflow/utils/jvp.py +++ b/bayesflow/utils/jvp.py @@ -1,27 +1,42 @@ +from collections.abc import Callable import keras from bayesflow.types import Tensor -def jvp(fn: callable, primals: tuple[Tensor] | Tensor, tangents: tuple[Tensor] | Tensor): - """Compute the dot product between the Jacobian of the given function at the point given by - the input (primals) and vectors in tangents.""" +def jvp(fn: Callable, primals: Tensor | tuple[Tensor, ...], tangents: Tensor | tuple[Tensor, ...]) -> (any, Tensor): + """ + Backend-agnostic version of the Jacobian-vector product (jvp). + Compute the Jacobian-vector product of the given function at the point given by the input (primals). + + :param fn: The function to differentiate. + Signature and return value must be compatible with the vjp method of the backend in use. + + :param primals: Input tensors to `fn`. + + :param tangents: Tangent vectors to differentiate `fn` with respect to. + + :return: The output of `fn(*primals)` and the Jacobian-vector product of `fn` evaluated at `primals` with respect to + `tangents`. + """ match keras.backend.backend(): case "jax": import jax - fn_output, _jvp = jax.jvp(fn, primals, tangents) + fx, _jvp = jax.jvp(fn, primals, tangents) case "torch": import torch - fn_output, _jvp = torch.func.jvp(fn, primals, tangents) + fx, _jvp = torch.func.jvp(fn, primals, tangents) case "tensorflow": import tensorflow as tf - with tf.autodiff.ForwardAccumulator(primals=primals, tangents=tangents) as acc: - fn_output = fn(*primals) - _jvp = acc.jvp(fn_output) + with tf.autodiff.ForwardAccumulator(primals, tangents) as acc: + fx = fn(*primals) + + _jvp = acc.jvp(fx) case _: raise NotImplementedError(f"JVP not implemented for backend {keras.backend.backend()}") - return fn_output, _jvp + + return fx, _jvp diff --git a/bayesflow/utils/vjp.py b/bayesflow/utils/vjp.py index 279113e80..435c46334 100644 --- a/bayesflow/utils/vjp.py +++ b/bayesflow/utils/vjp.py @@ -1,34 +1,42 @@ +from collections.abc import Callable import keras +from functools import partial from bayesflow.types import Tensor -def vjp(f: callable, x: Tensor) -> (Tensor, callable): +def vjp(fn: Callable, *primals: Tensor) -> (any, Callable[[Tensor], tuple[Tensor, ...]]): + """ + Backend-agnostic version of the vector-Jacobian product (vjp). + Computes the vector-Jacobian product of the given function at the point given by the input (primals). + + :param fn: The function to differentiate. + Signature and return value must be compatible with the vjp method of the backend in use. + + :param primals: Input tensors to `fn`. + + :return: The output of `fn(*primals)` and a vjp function. + The vjp function takes a single tensor argument, and returns the vector-Jacobian product of this argument with + `fn` as evaluated at `primals`. 
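+
+    A minimal illustrative sketch (`fn`, `x`, and `cotangent` are placeholders):
+
+        fx, vjp_fn = vjp(fn, x)
+        grad_x = vjp_fn(cotangent)[0]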
+ """ match keras.backend.backend(): case "jax": import jax - fx, _vjp_fn = jax.vjp(f, x) - - def vjp_fn(projector): - return _vjp_fn(projector)[0] + fx, vjp_fn = jax.vjp(fn, *primals) case "torch": import torch - fx, _vjp_fn = torch.func.vjp(f, x) - - def vjp_fn(projector): - return _vjp_fn(projector)[0] + fx, vjp_fn = torch.func.vjp(fn, *primals) case "tensorflow": import tensorflow as tf with tf.GradientTape(persistent=True) as tape: - tape.watch(x) - fx = f(x) - - def vjp_fn(projector): - return tape.gradient(fx, x, projector) - case other: - raise NotImplementedError(f"Cannot build a vjp function for backend '{other}'.") + for p in primals: + tape.watch(p) + fx = fn(*primals) + vjp_fn = partial(tape.gradient, fx, primals) + case _: + raise NotImplementedError(f"VJP not implemented for backend {keras.backend.backend()}") return fx, vjp_fn From 6d47ede9b39ea9281596d3f1a8c3bc4d698ccc96 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Thu, 21 Nov 2024 12:40:35 +0100 Subject: [PATCH 07/13] fix vjp call in fff --- bayesflow/networks/free_form_flow/free_form_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py index 1ea321bb3..9ea41c99d 100644 --- a/bayesflow/networks/free_form_flow/free_form_flow.py +++ b/bayesflow/networks/free_form_flow/free_form_flow.py @@ -158,7 +158,7 @@ def decode(z): # VJP computation z, vjp_fn = vjp(encode, x) - v1 = vjp_fn(v) + v1 = vjp_fn(v)[0] # JVP computation x_pred, v2 = jvp(decode, (z,), (v,)) From cdb61a45a1e5526ea466d36a84c10135e4c139ac Mon Sep 17 00:00:00 2001 From: LarsKue Date: Thu, 21 Nov 2024 12:40:49 +0100 Subject: [PATCH 08/13] add fff to tests, remove flow matching from global tests --- tests/conftest.py | 9 +-------- tests/test_networks/conftest.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 738fbd539..d7933545e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,14 +46,7 @@ def feature_size(request): return request.param -@pytest.fixture(scope="function") -def flow_matching(): - from bayesflow.networks import FlowMatching - - return FlowMatching(subnet="mlp", subnet_kwargs=dict(widths=(32, 32))) - - -@pytest.fixture(params=["coupling_flow", "flow_matching"], scope="function") +@pytest.fixture(params=["coupling_flow"], scope="function") def inference_network(request): return request.getfixturevalue(request.param) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 6eff398f5..62796f11b 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -8,6 +8,32 @@ def deep_set(): return DeepSet() +@pytest.fixture() +def coupling_flow(): + from bayesflow.networks import CouplingFlow + + return CouplingFlow() + + +@pytest.fixture() +def flow_matching(): + from bayesflow.networks import FlowMatching + + return FlowMatching() + + +@pytest.fixture() +def free_form_flow(): + from bayesflow.networks import FreeFormFlow + + return FreeFormFlow() + + +@pytest.fixture(params=["coupling_flow", "flow_matching", "free_form_flow"], scope="function") +def inference_network(request): + return request.getfixturevalue(request.param) + + @pytest.fixture() def lst_net(): from bayesflow.networks import LSTNet From ddc24f4c034dfffc27bdee148b806141280a3510 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Thu, 21 Nov 2024 12:41:01 +0100 Subject: [PATCH 09/13] fix default kwargs for fff subnets --- 
bayesflow/networks/free_form_flow/free_form_flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py index 9ea41c99d..0c02ff003 100644 --- a/bayesflow/networks/free_form_flow/free_form_flow.py +++ b/bayesflow/networks/free_form_flow/free_form_flow.py @@ -53,9 +53,9 @@ def __init__( Additional keyword arguments """ super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs)) - self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs")) + self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs", {})) self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") - self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs")) + self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs", {})) self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") self.hutchinson_sampling = hutchinson_sampling From a945341875e1b75bc63bd1a82a834b446bd45bac Mon Sep 17 00:00:00 2001 From: LarsKue Date: Thu, 21 Nov 2024 12:42:58 +0100 Subject: [PATCH 10/13] remove double source attribution --- bayesflow/utils/jacobian.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bayesflow/utils/jacobian.py b/bayesflow/utils/jacobian.py index cf7e11c0c..3dc87aae2 100644 --- a/bayesflow/utils/jacobian.py +++ b/bayesflow/utils/jacobian.py @@ -1,4 +1,3 @@ -# adapted from https://github.com/vislearn/FFF import keras from keras import ops from bayesflow.types import Tensor From ed1f0f9668e70b12780cb63385d1552adf1de160 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Thu, 21 Nov 2024 12:43:44 +0100 Subject: [PATCH 11/13] improve type hints --- bayesflow/utils/jacobian.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bayesflow/utils/jacobian.py b/bayesflow/utils/jacobian.py index 3dc87aae2..838ee79f6 100644 --- a/bayesflow/utils/jacobian.py +++ b/bayesflow/utils/jacobian.py @@ -1,3 +1,4 @@ +from collections.abc import Callable import keras from keras import ops from bayesflow.types import Tensor @@ -5,7 +6,7 @@ from functools import partial, wraps -def batch_wrap(fn: callable) -> callable: +def batch_wrap(fn: Callable) -> Callable: """Add a batch dimension to each tensor argument. :param fn: @@ -38,7 +39,7 @@ def wrapper(*args, **kwargs): def compute_jacobian( x_in: Tensor, - fn: callable, + fn: Callable, *func_args: any, grad_type: str = "backward", **func_kwargs: any, @@ -98,7 +99,7 @@ def compute_jacobian( def log_jacobian_determinant( x_in: Tensor, - fn: callable, + fn: Callable, *func_args: any, grad_type: str = "backward", **func_kwargs: any, From 6d6b0b58d77f92155f4af8bd2f2d54fb55122730 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Thu, 21 Nov 2024 12:16:33 +0000 Subject: [PATCH 12/13] adjust batch_wrap to handle non-iterable arguments --- bayesflow/utils/jacobian.py | 61 ++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/bayesflow/utils/jacobian.py b/bayesflow/utils/jacobian.py index 838ee79f6..b48da8602 100644 --- a/bayesflow/utils/jacobian.py +++ b/bayesflow/utils/jacobian.py @@ -6,37 +6,6 @@ from functools import partial, wraps -def batch_wrap(fn: Callable) -> Callable: - """Add a batch dimension to each tensor argument. 
- - :param fn: - :return: wrapped function""" - - def deep_unsqueeze(arg): - if isinstance(arg, dict): - return {key: deep_unsqueeze(value) for key, value in arg.items()} - elif isinstance(arg, (list, tuple)): - return [deep_unsqueeze(value) for value in arg] - else: - return arg[None, ...] - - @wraps(fn) - def wrapper(*args, **kwargs): - args = deep_unsqueeze(args) - return fn(*args, **kwargs)[0] - - return wrapper - - -def double_output(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - out = fn(*args, **kwargs) - return out, out - - return wrapper - - def compute_jacobian( x_in: Tensor, fn: Callable, @@ -60,6 +29,36 @@ def compute_jacobian( of the function with respect to its input `x` of shape (batch_size, out_dim, in_dim).""" + def batch_wrap(fn: Callable) -> Callable: + """Add a batch dimension to each tensor argument. + + :param fn: + :return: wrapped function""" + + def deep_unsqueeze(arg): + if ops.is_tensor(arg): + return arg[None, ...] + elif isinstance(arg, dict): + return {key: deep_unsqueeze(value) for key, value in arg.items()} + elif isinstance(arg, (list, tuple)): + return [deep_unsqueeze(value) for value in arg] + return arg + + @wraps(fn) + def wrapper(*args, **kwargs): + args = deep_unsqueeze(args) + return fn(*args, **kwargs)[0] + + return wrapper + + def double_output(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + out = fn(*args, **kwargs) + return out, out + + return wrapper + match keras.backend.backend(): case "torch": import torch From 6109d5262d43ec4efd6beee0edeea3301721be86 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Thu, 21 Nov 2024 13:46:56 +0000 Subject: [PATCH 13/13] fff: handle conditions=None --- .../networks/free_form_flow/free_form_flow.py | 15 +++++++++++++-- bayesflow/utils/jacobian.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py index 0c02ff003..c893d7df8 100644 --- a/bayesflow/networks/free_form_flow/free_form_flow.py +++ b/bayesflow/networks/free_form_flow/free_form_flow.py @@ -90,7 +90,13 @@ def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: if density: - z, log_det = log_jacobian_determinant(x, self.encode, conditions, training=training, **kwargs) + if conditions is None: + # None cannot be batched, so supply as keyword argument + z, log_det = log_jacobian_determinant(x, self.encode, conditions=None, training=training, **kwargs) + else: + # conditions should be batched, supply as positional argument + z, log_det = log_jacobian_determinant(x, self.encode, conditions, training=training, **kwargs) + log_density = self.base_distribution.log_prob(z) + log_det return z, log_density @@ -101,7 +107,12 @@ def _inverse( self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: if density: - x, log_det = log_jacobian_determinant(z, self.decode, conditions, training=training, **kwargs) + if conditions is None: + # None cannot be batched, so supply as keyword argument + x, log_det = log_jacobian_determinant(z, self.decode, conditions=None, training=training, **kwargs) + else: + # conditions should be batched, supply as positional argument + x, log_det = log_jacobian_determinant(z, self.decode, conditions, training=training, **kwargs) log_density = self.base_distribution.log_prob(z) - log_det return x, log_density diff --git 
a/bayesflow/utils/jacobian.py b/bayesflow/utils/jacobian.py index b48da8602..830ef6e01 100644 --- a/bayesflow/utils/jacobian.py +++ b/bayesflow/utils/jacobian.py @@ -42,7 +42,7 @@ def deep_unsqueeze(arg): return {key: deep_unsqueeze(value) for key, value in arg.items()} elif isinstance(arg, (list, tuple)): return [deep_unsqueeze(value) for value in arg] - return arg + raise ValueError(f"Argument cannot be batched: {arg}") @wraps(fn) def wrapper(*args, **kwargs):
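
For context, a minimal sketch of how the new network can be exercised once this patch
series is applied (shapes, sizes, and variable names below are illustrative placeholders,
not taken from the diffs):

    import keras
    from bayesflow.networks import FreeFormFlow

    flow = FreeFormFlow(beta=50.0)
    flow.build(xz_shape=(None, 2), conditions_shape=(None, 4))

    x = keras.random.normal((64, 2))
    conditions = keras.random.normal((64, 4))

    z = flow.encode(x, conditions)         # forward pass
    x_rec = flow.decode(z, conditions)     # inverse pass
    metrics = flow.compute_metrics(x, conditions=conditions, stage="training")  # contains "loss"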