Add free-form flows as inference networks #251
New file (1 line), hunk `@@ -0,0 +1 @@`:

```python
from .free_form_flow import FreeFormFlow
```
New file (183 lines), hunk `@@ -0,0 +1,183 @@`:
```python
import keras
from keras import ops
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp

from ..inference_network import InferenceNetwork


@serializable(package="networks.free_form_flow")
class FreeFormFlow(InferenceNetwork):
    """Implements a dimensionality-preserving Free-form Flow.
    Incorporates ideas from [1-2].

    [1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).
    Free-form Flows: Make Any Architecture a Normalizing Flow.
    In International Conference on Artificial Intelligence and Statistics.

    [2] Sorrenson, P., Draxler, F., Rousselot, A., Hummerich, S., Zimmermann, L., &
    Köthe, U. (2024). Lifting Architectural Constraints of Injective Flows.
    In International Conference on Learning Representations.
    """

    def __init__(
        self,
        beta: float = 50.0,
        encoder_subnet: str | type = "mlp",
        decoder_subnet: str | type = "mlp",
        base_distribution: str = "normal",
        hutchinson_sampling: str = "qr",
        **kwargs,
    ):
        """Creates an instance of a Free-form Flow.

        Parameters
        ----------
        beta : float, optional, default: 50.0
            Weight of the reconstruction loss relative to the maximum likelihood loss.
        encoder_subnet : str or type, optional, default: "mlp"
            A neural network type for the flow; will be instantiated using
            encoder_subnet_kwargs. Will be equipped with a projector to ensure
            the correct output dimension and a global skip connection.
        decoder_subnet : str or type, optional, default: "mlp"
            A neural network type for the flow; will be instantiated using
            decoder_subnet_kwargs. Will be equipped with a projector to ensure
            the correct output dimension and a global skip connection.
        base_distribution : str, optional, default: "normal"
            The latent distribution.
        hutchinson_sampling : str, optional, default: "qr"
            One of `["sphere", "qr"]`. Selects the sampling scheme for the
            vectors of the Hutchinson trace estimator.
        **kwargs : dict, optional, default: {}
            Additional keyword arguments.
        """
        super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
        self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs", {}))
        self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
        self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs", {}))
        self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")

        self.hutchinson_sampling = hutchinson_sampling
        self.beta = beta

        self.seed_generator = keras.random.SeedGenerator()
```
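For orientation, a minimal construction sketch (not part of the diff; the argument values are just the defaults written out). Because both projectors are zero-initialized and `encode`/`decode` below add a global skip connection, the flow starts out as the identity map (the projector output is zero, so `encode(x) = x` at initialization).

```python
# Minimal sketch, assuming the default "mlp" subnets resolve via find_network.
flow = FreeFormFlow(
    beta=50.0,                 # weight of the reconstruction term in the total loss
    encoder_subnet="mlp",
    decoder_subnet="mlp",
    base_distribution="normal",
    hutchinson_sampling="qr",  # or "sphere"
)
```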
```python
    # noinspection PyMethodOverriding
    def build(self, xz_shape, conditions_shape=None):
        super().build(xz_shape)
        self.encoder_projector.units = xz_shape[-1]
        self.decoder_projector.units = xz_shape[-1]

        # construct input shape for subnet and subnet projector
        input_shape = list(xz_shape)

        if conditions_shape is not None:
            input_shape[-1] += conditions_shape[-1]

        input_shape = tuple(input_shape)

        self.encoder_subnet.build(input_shape)
        self.decoder_subnet.build(input_shape)

        input_shape = self.encoder_subnet.compute_output_shape(input_shape)
        self.encoder_projector.build(input_shape)

        input_shape = self.decoder_subnet.compute_output_shape(input_shape)
        self.decoder_projector.build(input_shape)
```
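Continuing the sketch above, `build` widens the subnet input by the condition dimension and sets both projector widths to the inference dimension. The shapes here are hypothetical, purely for illustration:

```python
# Hypothetical shapes: 4 inference variables, 6 condition features.
flow.build(xz_shape=(None, 4), conditions_shape=(None, 6))
# Both subnets are built on inputs of shape (None, 4 + 6) = (None, 10);
# each projector maps the subnet output back to 4 dimensions.
```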
```python
    def _forward(
        self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        if density:
            if conditions is None:
                # None cannot be batched, so supply as keyword argument
                z, log_det = log_jacobian_determinant(x, self.encode, conditions=None, training=training, **kwargs)
            else:
                # conditions should be batched, supply as positional argument
                z, log_det = log_jacobian_determinant(x, self.encode, conditions, training=training, **kwargs)

            log_density = self.base_distribution.log_prob(z) + log_det
            return z, log_density

        z = self.encode(x, conditions, training=training, **kwargs)
        return z

    def _inverse(
        self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        if density:
            if conditions is None:
                # None cannot be batched, so supply as keyword argument
                x, log_det = log_jacobian_determinant(z, self.decode, conditions=None, training=training, **kwargs)
            else:
                # conditions should be batched, supply as positional argument
                x, log_det = log_jacobian_determinant(z, self.decode, conditions, training=training, **kwargs)

            log_density = self.base_distribution.log_prob(z) - log_det
            return x, log_density

        x = self.decode(z, conditions, training=training, **kwargs)
        return x
```
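The density branches above are the standard change-of-variables identity for a flow with encoder $z = f(x)$ and decoder $x = g(z)$; this adds nothing beyond the code, it just states what `_forward` and `_inverse` compute when `density=True`:

$$
\log p_X(x) = \log p_Z\bigl(f(x)\bigr) + \log\bigl|\det J_f(x)\bigr|,
\qquad
\log p_X\bigl(g(z)\bigr) = \log p_Z(z) - \log\bigl|\det J_g(z)\bigr| .
$$

In both cases `log_det` comes from `log_jacobian_determinant` (defined in the utility file further below), which builds the full Jacobian and takes its log-determinant.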
```python
    def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
        if conditions is None:
            inp = x
        else:
            inp = concatenate(x, conditions, axis=-1)
        network_out = self.encoder_projector(
            self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
        )
        return network_out + x
```

Contributor (on the nested projector/subnet call): I would prefer this non-nested for better errors, but again, not a major issue.
```python
    def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
        if conditions is None:
            inp = z
        else:
            inp = concatenate(z, conditions, axis=-1)
        network_out = self.decoder_projector(
            self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
        )
        return network_out + z
```
```python
    def _sample_v(self, x):
        batch_size = ops.shape(x)[0]
        total_dim = ops.shape(x)[-1]
        match self.hutchinson_sampling:
            case "qr":
                # Use QR decomposition as described in [2]
                v_raw = keras.random.normal((batch_size, total_dim, 1), dtype=ops.dtype(x), seed=self.seed_generator)
                q = ops.reshape(ops.qr(v_raw)[0], ops.shape(x))
                v = q * ops.sqrt(total_dim)
            case "sphere":
                # Sample from sphere with radius sqrt(total_dim), as implemented in [1]
                v_raw = keras.random.normal((batch_size, total_dim), dtype=ops.dtype(x), seed=self.seed_generator)
                v = v_raw * ops.sqrt(total_dim) / ops.sqrt(ops.sum(v_raw**2, axis=-1, keepdims=True))
            case _:
                raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
        return v
```
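Both sampling schemes return a uniformly random direction scaled to norm $\sqrt{D}$, so $\mathbb{E}\bigl[v v^\top\bigr] = I_D$, which is exactly the condition the Hutchinson trace estimator needs. The identity, restated here only for context:

$$
\operatorname{tr}(A) = \mathbb{E}_{v}\bigl[v^\top A v\bigr]
\quad\text{whenever}\quad
\mathbb{E}\bigl[v v^\top\bigr] = I .
$$

`compute_metrics` below uses a single such vector per sample, i.e. one Monte Carlo draw per batch element.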
```python
    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
        # sample random vector
        v = self._sample_v(x)

        def encode(x):
            return self.encode(x, conditions, training=stage == "training")

        def decode(z):
            return self.decode(z, conditions, training=stage == "training")

        # VJP computation
        z, vjp_fn = vjp(encode, x)
        v1 = vjp_fn(v)[0]
        # JVP computation
        x_pred, v2 = jvp(decode, (z,), (v,))

        # equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0]
        surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1)
        nll = -self.base_distribution.log_prob(z)
        maximum_likelihood_loss = nll - surrogate
        reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
        loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)

        return base_metrics | {"loss": loss}
```
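For context on the `surrogate` term (an informal restatement of the gradient estimator from [1], not additional code): with encoder $f_\theta$, decoder $g_\phi$, and $z = f_\theta(x)$, `v1` is the vector-Jacobian product $J_{f_\theta}(x)^\top v$ and `v2` is the Jacobian-vector product $J_{g_\phi}(z)\, v$, so

$$
\texttt{surrogate} = v^\top J_{f_\theta}(x)\,\operatorname{sg}\!\bigl(J_{g_\phi}(z)\, v\bigr),
$$

with $\operatorname{sg}$ the stop-gradient. Taking the expectation over $v$ and using $g_\phi \approx f_\theta^{-1}$ (which the $\beta$-weighted reconstruction term enforces),

$$
\nabla_\theta\, \mathbb{E}_v\bigl[\texttt{surrogate}\bigr] \approx \nabla_\theta \log\bigl|\det J_{f_\theta}(x)\bigr| ,
$$

so the $\theta$-gradient of `maximum_likelihood_loss` matches, in expectation over $v$, the gradient of the exact negative log-likelihood.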
New file (129 lines), hunk `@@ -0,0 +1,129 @@`:
```python
from collections.abc import Callable
import keras
from keras import ops
from bayesflow.types import Tensor

from functools import partial, wraps


def compute_jacobian(
    x_in: Tensor,
    fn: Callable,
    *func_args: any,
    grad_type: str = "backward",
    **func_kwargs: any,
) -> tuple[Tensor, Tensor]:
    """Computes the Jacobian of a function with respect to its input.

    :param x_in: The input tensor to compute the Jacobian at.
        Shape: (batch_size, in_dim).
    :param fn: The function to compute the Jacobian of, which transforms
        `x` to `fn(x)` of shape (batch_size, out_dim).
    :param func_args: The positional arguments to pass to the function.
        func_args are batched over the first dimension.
    :param grad_type: The type of gradient to use. Either 'backward' or
        'forward'.
    :param func_kwargs: The keyword arguments to pass to the function.
        func_kwargs are not batched.
    :return: The output of the function `fn(x)` and the Jacobian
        of the function with respect to its input `x`, of shape
        (batch_size, out_dim, in_dim)."""

    def batch_wrap(fn: Callable) -> Callable:
        """Add a batch dimension to each tensor argument.

        :param fn: function to wrap
        :return: wrapped function"""

        def deep_unsqueeze(arg):
            if ops.is_tensor(arg):
                return arg[None, ...]
            elif isinstance(arg, dict):
                return {key: deep_unsqueeze(value) for key, value in arg.items()}
            elif isinstance(arg, (list, tuple)):
                return [deep_unsqueeze(value) for value in arg]
            raise ValueError(f"Argument cannot be batched: {arg}")

        @wraps(fn)
        def wrapper(*args, **kwargs):
            args = deep_unsqueeze(args)
            return fn(*args, **kwargs)[0]

        return wrapper

    def double_output(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            out = fn(*args, **kwargs)
            return out, out

        return wrapper

    match keras.backend.backend():
        case "torch":
            import torch
            from torch.func import jacrev, jacfwd, vmap

            jacfn = jacrev if grad_type == "backward" else jacfwd
            with torch.inference_mode(False):
                with torch.no_grad():
                    fn_kwargs_prefilled = partial(fn, **func_kwargs)
                    fn_batch_expanded = batch_wrap(fn_kwargs_prefilled)
                    fn_return_val = double_output(fn_batch_expanded)
                    fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True))
                    jac, x_out = fn_jac_batched(x_in, *func_args)
        case "jax":
            from jax import jacrev, jacfwd, vmap

            jacfn = jacrev if grad_type == "backward" else jacfwd
            fn_kwargs_prefilled = partial(fn, **func_kwargs)
            fn_batch_expanded = batch_wrap(fn_kwargs_prefilled)
            fn_return_val = double_output(fn_batch_expanded)
            fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True))
            jac, x_out = fn_jac_batched(x_in, *func_args)
        case "tensorflow":
            if grad_type == "forward":
                raise NotImplementedError("For TensorFlow, only backward mode Jacobian computation is available.")
            import tensorflow as tf

            with tf.GradientTape() as tape:
                tape.watch(x_in)
                x_out = fn(x_in, *func_args, **func_kwargs)
            jac = tape.batch_jacobian(x_out, x_in)

        case _:
            raise NotImplementedError(f"compute_jacobian not implemented for {keras.backend.backend()}.")
    return x_out, jac
```
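A hedged usage sketch for `compute_jacobian` (the toy function and shapes are illustrative and not part of the PR):

```python
import keras

def affine(x):
    # toy batched map R^3 -> R^3 with Jacobian 2 * I
    return 2.0 * x + 1.0

x = keras.random.normal((8, 3))
x_out, jac = compute_jacobian(x, affine, grad_type="backward")
# x_out has shape (8, 3); jac has shape (8, 3, 3) and is 2 * I for every batch element
```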
```python
def log_jacobian_determinant(
    x_in: Tensor,
    fn: Callable,
    *func_args: any,
    grad_type: str = "backward",
    **func_kwargs: any,
) -> tuple[Tensor, Tensor]:
    """Computes the log Jacobian determinant of a function
    with respect to its input.

    :param x_in: The input tensor to compute the Jacobian at.
        Shape: (batch_size, in_dim).
    :param fn: The function to compute the Jacobian of, which transforms
        `x` to `fn(x)` of shape (batch_size, out_dim).
    :param func_args: The positional arguments to pass to the function.
        func_args are batched over the first dimension.
    :param grad_type: The type of gradient to use. Either 'backward' or
        'forward'.
    :param func_kwargs: The keyword arguments to pass to the function.
        func_kwargs are not batched.
    :return: The output of the function `fn(x)` and the log Jacobian determinant
        of the function with respect to its input `x`, of shape (batch_size,)."""

    x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs)
    jac = ops.reshape(
        jac, (ops.shape(x_in)[0], ops.prod(list(ops.shape(x_out)[1:])), ops.prod(list(ops.shape(x_in)[1:])))
    )
```

Contributor (on the `ops.reshape(...)` arguments): Would prefer this in multiple lines.
```python
    log_det = ops.slogdet(jac)[1]

    return x_out, log_det
```
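And the corresponding sketch for `log_jacobian_determinant`, which is what `FreeFormFlow._forward`/`_inverse` call when `density=True` (same toy function as above):

```python
x = keras.random.normal((8, 3))
x_out, log_det = log_jacobian_determinant(x, affine)
# log_det has shape (8,); for the toy map it equals 3 * log(2) everywhere
```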
A further file was deleted in this PR (its contents are not shown).

Comment: `type` is not serializable out of the box, so we would need a `from_config` method here. But we can add this later.