Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 45 additions & 9 deletions bayesflow/experimental/diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
layer_kwargs,
weighted_mean,
integrate,
integrate_stochastic,
)


Expand Down Expand Up @@ -373,7 +374,7 @@ class DiffusionModel(InferenceNetwork):
}

INTEGRATE_DEFAULT_CONFIG = {
"method": "euler",
"method": "euler", # or euler_maruyama
"steps": 100,
}

Expand Down Expand Up @@ -529,6 +530,7 @@ def velocity(
time: float | Tensor,
conditions: Tensor = None,
training: bool = False,
stochastic_solver: bool = False,
clip_x: bool = False,
) -> Tensor:
# calculate the current noise level and transform into correct shape
Expand All @@ -548,13 +550,30 @@ def velocity(
# convert x to score
score = (alpha_t * x_pred - xz) / ops.square(sigma_t)

# compute velocity for the ODE depending on the noise schedule
# compute velocity f, g of the SDE or ODE
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
out = f - 0.5 * g_squared * score

# todo: for the SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
if stochastic_solver:
# for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW
out = f - g_squared * score
else:
# for the ODE: d(z) = [f(z, t) - 0.5 * g(t) ^ 2 * score(z, lambda )] dt
out = f - 0.5 * g_squared * score

return out

def compute_diffusion_term(
    self,
    xz: Tensor,
    time: float | Tensor,
    training: bool = False,
) -> Tensor:
    """
    Evaluate the diffusion coefficient g(t) of the SDE at the given time.

    Args:
        xz: Current state; only its shape is used here.
        time: Time(s) at which the noise schedule is queried.
        training: Forwarded to the noise schedule's log-SNR lookup.

    Returns:
        The square root of the value returned by the noise schedule's
        ``get_drift_diffusion``. Its shape is that of ``xz`` with the last
        axis replaced by a singleton, so it broadcasts against ``xz``.
    """
    # Look up the log-SNR for `time` and right-pad it so it lines up with xz.
    log_snr = self.noise_schedule.get_log_snr(t=time, training=training)
    log_snr = expand_right_as(log_snr, xz)

    # One noise level per leading index of xz, with a trailing singleton axis.
    per_sample_shape = keras.ops.shape(xz)[:-1] + (1,)
    log_snr = keras.ops.broadcast_to(log_snr, per_sample_shape)

    # NOTE(review): presumably `get_drift_diffusion` returns only g(t)^2 when
    # `x` is omitted (with `x` it is unpacked as `(f, g_squared)` in
    # `velocity`) - confirm against the noise schedule implementation.
    squared_diffusion = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr)
    return ops.sqrt(squared_diffusion)

def _velocity_trace(
self,
xz: Tensor,
Expand Down Expand Up @@ -586,6 +605,9 @@ def _forward(
| self.integrate_kwargs
| kwargs
)
if integrate_kwargs["method"] == "euler_maruyama":
raise ValueError("Stoachastic methods are not supported for forward integration.")

if density:

def deltas(time, xz):
Expand Down Expand Up @@ -636,6 +658,8 @@ def _inverse(
| kwargs
)
if density:
if integrate_kwargs["method"] == "euler_maruyama":
raise ValueError("Stoachastic methods are not supported for density computation.")

def deltas(time, xz):
v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
Expand All @@ -656,11 +680,23 @@ def deltas(time, xz):
return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

state = {"xz": z}
state = integrate(
deltas,
state,
**integrate_kwargs,
)
if integrate_kwargs["method"] == "euler_maruyama":

def diffusion(time, xz):
return {"xz": self.compute_diffusion_term(xz, time=time, training=training)}

state = integrate_stochastic(
deltas,
diffusion,
state,
**integrate_kwargs,
)
else:
state = integrate(
deltas,
state,
**integrate_kwargs,
)

x = state["xz"]
return x
Expand Down
4 changes: 1 addition & 3 deletions bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
repo_url,
)
from .hparam_utils import find_batch_size, find_memory_budget
from .integrate import (
integrate,
)
from .integrate import integrate, integrate_stochastic
from .io import (
pickle_load,
format_bytes,
Expand Down
158 changes: 157 additions & 1 deletion bayesflow/utils/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import keras

import numpy as np
from typing import Literal
from typing import Literal, Union, List

from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs
Expand Down Expand Up @@ -293,3 +293,159 @@ def integrate(
return integrate_scheduled(fn, state, steps, method, **kwargs)
else:
raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")


def euler_maruyama_step(
    drift_fn: Callable,
    diffusion_fn: Callable,
    state: dict[str, ArrayLike],
    time: ArrayLike,
    step_size: ArrayLike,
    noise: dict[str, ArrayLike] = None,
    tolerance: ArrayLike = 1e-6,
    min_step_size: ArrayLike = -float("inf"),
    max_step_size: ArrayLike = float("inf"),
    use_adaptive_step_size: bool = False,
) -> (dict[str, ArrayLike], ArrayLike, ArrayLike):
    """
    Performs a single Euler-Maruyama step for stochastic differential equations.

    The update applied to every variable returned by ``drift_fn`` is
    ``x_new = x + step_size * drift + diffusion * noise``; variables without a
    diffusion term receive the deterministic Euler update only. State entries
    that ``drift_fn`` does not return are carried over unchanged.

    Args:
        drift_fn: Function that computes the drift term. Called as
            ``drift_fn(time, **state_subset)`` and must return a dict.
        diffusion_fn: Function that computes the diffusion term. Called the
            same way and must return a dict.
        state: Dictionary containing the current state.
        time: Current time.
        step_size: Size of the integration step (may be negative).
        noise: Dictionary of noise terms for each state variable. Entries are
            expected to be already scaled by ``sqrt(|step_size|)``. If None,
            scaled standard-normal noise is drawn here.
        tolerance: Error tolerance for adaptive step size.
        min_step_size: Minimum allowed step size (default: no bound).
        max_step_size: Maximum allowed step size (default: no bound).
        use_adaptive_step_size: Whether to propose a new step size from the
            change in drift across the step.

    Returns:
        Tuple of (new_state, new_time, new_step_size). ``new_state`` and
        ``new_time`` are always computed with the *input* ``step_size``;
        ``new_step_size`` is only a proposal for the next step.

    Raises:
        ValueError: If the keys of the diffusion terms and the noise differ.
    """
    # Evaluate drift and diffusion at the current state; each function only
    # receives the state entries it accepts as keyword arguments.
    drift = drift_fn(time, **filter_kwargs(state, drift_fn))
    diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))

    # Generate noise if not provided
    if noise is None:
        noise_scale = keras.ops.sqrt(keras.ops.abs(step_size))
        noise = {}
        for key in diffusion:
            # Draw the noise with the shape of the state variable, not of the
            # diffusion term: the diffusion coefficient may only be
            # broadcastable to the state (e.g. a trailing singleton axis), and
            # every state component needs its own independent increment.
            reference = state[key] if key in state else diffusion[key]
            noise[key] = keras.random.normal(keras.ops.shape(reference)) * noise_scale

    # Check if diffusion and noise have the same keys
    if set(diffusion.keys()) != set(noise.keys()):
        raise ValueError("Keys of diffusion terms and noise do not match.")

    # Apply updates using Euler-Maruyama formula: dx = f(x)dt + g(x)dW
    new_state = state.copy()
    for key in drift:
        updated = state[key] + (step_size * drift[key])
        if key in diffusion:
            updated = updated + (diffusion[key] * noise[key])
        # otherwise: no diffusion term for this variable, deterministic update
        new_state[key] = updated

    if use_adaptive_step_size:
        # Re-evaluate the drift at the end of the (full) step. Reusing
        # `new_state` keeps this consistent with the update above, including
        # for variables that have no diffusion term.
        end_drift = drift_fn(time + step_size, **filter_kwargs(new_state, drift_fn))

        # Error estimate: L2 norm (over the last axis) of the drift change.
        error_terms = [keras.ops.norm(end_drift[key] - drift[key], ord=2, axis=-1) for key in drift]
        intermediate_error = keras.ops.stack(error_terms)

        # Candidate step sizes; the small constant avoids division by zero.
        new_step_size = step_size * tolerance / (intermediate_error + 1e-9)

        # Apply constraints to step size
        new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)

        # Consolidate: keep the candidate with the smallest magnitude.
        new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size)))
    else:
        new_step_size = step_size

    new_time = time + step_size

    return new_state, new_time, new_step_size


def integrate_stochastic(
drift_fn: Callable,
diffusion_fn: Callable,
state: dict[str, ArrayLike],
start_time: ArrayLike,
stop_time: ArrayLike,
steps: int,
method: str = "euler_maruyama",
seed: int = None,
**kwargs,
) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, List[ArrayLike]]]]:
"""
Integrates a stochastic differential equation from start_time to stop_time.

Args:
drift_fn: Function that computes the drift term.
diffusion_fn: Function that computes the diffusion term.
state: Dictionary containing the initial state.
start_time: Starting time for integration.
stop_time: Ending time for integration.
steps: Number of integration steps.
method: Integration method to use ('euler_maruyama').
seed: Random seed for noise generation.
**kwargs: Additional arguments to pass to the step function.

Returns:
If return_noise is False, returns the final state dictionary.
If return_noise is True, returns a tuple of (final_state, noise_history).
"""
if steps <= 0:
raise ValueError("Number of steps must be positive.")

# Set random seed if provided
if seed is not None:
keras.random.set_seed(seed)

# Select step function based on method
match method:
case "euler_maruyama":
step_fn = euler_maruyama_step
case str() as name:
raise ValueError(f"Unknown integration method name: {name!r}")
case other:
raise TypeError(f"Invalid integration method: {other!r}")

# Prepare step function with partial application
step_fn = partial(step_fn, drift_fn, diffusion_fn, **kwargs)
step_size = (stop_time - start_time) / steps

time = start_time

def body(_loop_var, _loop_state):
_state, _time = _loop_state

# Generate noise for this step
_noise = {}
for key in _state.keys():
shape = keras.ops.shape(_state[key])
_noise[key] = keras.random.normal(shape) * keras.ops.sqrt(keras.ops.abs(step_size))

# Perform integration step
_state, _time, _ = step_fn(_state, _time, step_size, noise=_noise)

return _state, _time

state, time = keras.ops.fori_loop(0, steps, body, (state, time))
return state