Description
Our integration suite currently lacks an adaptive step-size method that lets us specify error tolerances. Control over the integration error is especially important when the results are used in comparisons (e.g., log-probabilities computed with diffusion models and flow matching).
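To make the point about error control concrete, here is a small toy check that does not depend on BayesFlow. It assumes a hand-picked vector field, `f(x) = -x**3` applied to each coordinate, for which the log-determinant of the flow map has a closed form. As is usual for continuous flows, the log-determinant is integrated alongside the state (here `d logdet/dt = -3 * x**2` per coordinate), and then `log p_T(x_T) = log p_0(x_0) - sum(logdet)`. Running SciPy's `solve_ivp` at loose and tight tolerances shows how the tolerance controls the error in this term, at the cost of more function evaluations.

```python
import numpy as np
from scipy.integrate import solve_ivp


# Toy vector field f(x) = -x**3, applied to each coordinate independently.
# The Jacobian is diagonal, so the divergence is the sum of -3 * x**2; we keep
# the per-coordinate terms so each coordinate's log-determinant can be checked.
def augmented_rhs(t, y):
    d = y.size // 2
    x = y[:d]
    dx_dt = -x**3
    dlogdet_dt = -3.0 * x**2
    return np.concatenate([dx_dt, dlogdet_dt])


x0 = np.array([0.5, 1.0, 1.5])
T = 1.0
y0 = np.concatenate([x0, np.zeros_like(x0)])

# Closed form for this field: log|dx(T)/dx(0)| = -1.5 * log(1 + 2 * x0**2 * T)
exact_logdet = -1.5 * np.log1p(2.0 * x0**2 * T)


def logdet_error(rtol, atol):
    res = solve_ivp(augmented_rhs, (0.0, T), y0, method="RK45", rtol=rtol, atol=atol)
    logdet = res.y[x0.size:, -1]  # log-determinant entries at time T
    return np.max(np.abs(logdet - exact_logdet)), res.nfev


loose_err, loose_nfev = logdet_error(rtol=1e-3, atol=1e-6)  # SciPy's default tolerances
tight_err, tight_nfev = logdet_error(rtol=1e-10, atol=1e-12)
print(f"loose: max error {loose_err:.1e} with {loose_nfev} evaluations")
print(f"tight: max error {tight_err:.1e} with {tight_nfev} evaluations")
```

The exact numbers depend on the SciPy version, but the tight setting should give a much smaller log-determinant error in exchange for more evaluations of the vector field.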
As such methods would mainly be used for evaluation, they would not necessarily need to be differentiable, which would also allow us to use third-party integrators. I think SciPy's `scipy.integrate.solve_ivp` would be the most obvious choice. An implementation could look approximately like this (with more careful flattening/reshaping, so that it works in a general setting):
```python
from collections.abc import Callable

import bayesflow as bf
import keras
import numpy as np
import scipy.integrate


def integrate_scipy(
    fn: Callable,
    state: dict,
    start_time: np.ndarray,
    stop_time: np.ndarray,
    method: str = "RK45",
    atol: float = 1e-6,
    rtol: float = 1e-3,
    **kwargs,  # other integrate_kwargs (e.g. "steps") are accepted but ignored
):
    # Concatenate all state variables into a single array "x" and cast to float64 for SciPy
    adapter = (
        bf.Adapter()
        .concatenate(list(state.keys()), into="x", axis=-1)
        .convert_dtype(np.float32, np.float64)
    )
    initial_state = adapter.forward(state)["x"]
    shape = initial_state.shape

    def scipy_wrapper_fn(time, x):
        # solve_ivp works on flat 1-D arrays: restore the dict, evaluate fn, flatten again
        state = adapter.inverse({"x": x.reshape(shape)})
        state = keras.tree.map_structure(keras.ops.convert_to_tensor, state)
        time = keras.ops.convert_to_tensor(time, dtype="float32")
        deltas = fn(time, **bf.utils.filter_kwargs(state, fn))
        deltas = keras.tree.map_structure(keras.ops.convert_to_numpy, deltas)
        return adapter.forward(deltas)["x"].reshape(-1)

    res = scipy.integrate.solve_ivp(
        scipy_wrapper_fn,
        (start_time, stop_time),
        initial_state.reshape(-1),
        method=method,
        atol=atol,
        rtol=rtol,
    )
    # Keep only the solution at stop_time and split it back into the original keys
    return adapter.inverse({"x": res.y[:, -1].reshape(shape)})
```

Regarding the interface, we could allow passing the method `"scipy"` and forward additional kwargs to `scipy.integrate.solve_ivp`.
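As a sketch of that interface, the hypothetical `integrate_dict` below (not an existing BayesFlow function) shows one way to do the flattening/reshaping generally with plain NumPy: it records each key's shape, concatenates the raveled arrays into the 1-D vector that `solve_ivp` expects, and splits them back afterwards. The `method="scipy"` switch and the `scipy_method` keyword are assumptions for illustration; all remaining keyword arguments are forwarded to `solve_ivp`.

```python
from collections.abc import Callable

import numpy as np
import scipy.integrate


def integrate_dict(
    fn: Callable[..., dict],
    state: dict,
    start_time: float,
    stop_time: float,
    method: str = "scipy",
    **kwargs,
) -> dict:
    """Hypothetical sketch: integrate d(state)/dt = fn(t, **state) from start_time to stop_time."""
    if method != "scipy":
        # Only the SciPy branch is sketched; built-in integrators would be dispatched here.
        raise NotImplementedError(f"method={method!r} is not covered by this sketch")

    keys = list(state)
    shapes = [np.shape(state[k]) for k in keys]
    sizes = [int(np.prod(s)) for s in shapes]
    split_points = np.cumsum(sizes)[:-1]

    def flatten(tree: dict) -> np.ndarray:
        # Concatenate all arrays (in a fixed key order) into one float64 vector.
        return np.concatenate([np.asarray(tree[k], dtype=np.float64).ravel() for k in keys])

    def unflatten(flat: np.ndarray) -> dict:
        # Split the vector back into arrays with their original shapes.
        parts = np.split(flat, split_points)
        return {k: p.reshape(s) for k, p, s in zip(keys, parts, shapes)}

    def rhs(t: float, flat: np.ndarray) -> np.ndarray:
        return flatten(fn(t, **unflatten(flat)))

    # Solver name for solve_ivp; everything else in kwargs (rtol, atol, ...) is forwarded.
    scipy_method = kwargs.pop("scipy_method", "RK45")
    res = scipy.integrate.solve_ivp(
        rhs, (start_time, stop_time), flatten(state), method=scipy_method, **kwargs
    )
    if not res.success:
        raise RuntimeError(f"solve_ivp failed: {res.message}")
    # res.y has shape (n, n_times); keep the state at stop_time.
    return unflatten(res.y[:, -1])


# Usage: x decays exponentially, z grows linearly at rate 2.
def velocity(t, x, z):
    return {"x": -x, "z": np.full_like(z, 2.0)}


out = integrate_dict(
    velocity, {"x": np.ones((2, 3)), "z": np.zeros(4)}, 0.0, 1.0, rtol=1e-8, atol=1e-10
)
print(out["x"])  # close to exp(-1) everywhere
print(out["z"])  # close to 2.0 everywhere
```

Passing the solver name under a separate keyword keeps `method` free for choosing between the built-in integrators and SciPy. Keys of `state` must be valid Python identifiers here, since they are passed to `fn` as keyword arguments.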
Tagging @LarsKue and @stefanradev93. Would you welcome such a change, and if so, what would be your preferred interface?
Edit: For a trained approximator that uses a diffusion model or flow matching, you can test it like this:

```python
approximator.inference_network.integrate_kwargs['steps'] = 100
approximator.inference_network.integrate_kwargs['atol'] = 1e-5
approximator.inference_network.integrate_kwargs['rtol'] = 1e-5
approximator.inference_network.integrate_kwargs['method'] = "RK45"
bf.networks.diffusion_model.diffusion_model.integrate = integrate_scipy
bf.networks.flow_matching.flow_matching.integrate = integrate_scipy
log_prob = approximator.log_prob(data=dataset)
```
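Assigning to the module attribute as above replaces `integrate` for the rest of the session. A scoped alternative is the standard library's `unittest.mock.patch.object`, which restores the original attribute when the `with` block exits. The sketch below runs on a stand-in namespace, so `fake_module` and `integrate_scipy_stub` are placeholders, and it assumes (as the patch above does) that the network looks up the module-level `integrate` at call time.

```python
import types
from unittest import mock

# Stand-in for a module like bf.networks.flow_matching.flow_matching (placeholder).
fake_module = types.SimpleNamespace(integrate=lambda *args, **kwargs: "built-in integrator")


def integrate_scipy_stub(*args, **kwargs):
    # Placeholder for integrate_scipy defined earlier in this issue.
    return "scipy integrator"


with mock.patch.object(fake_module, "integrate", integrate_scipy_stub):
    inside = fake_module.integrate()  # patched version is active here

outside = fake_module.integrate()  # original attribute restored after the block
print(inside, "/", outside)
```

With BayesFlow, the same pattern would wrap the evaluation, e.g. `with mock.patch.object(bf.networks.flow_matching.flow_matching, "integrate", integrate_scipy): log_prob = approximator.log_prob(data=dataset)`.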