From 7393927212f6b0f7c9e2cb8453469f0736a3e10a Mon Sep 17 00:00:00 2001 From: Valentin De Bortoli Date: Tue, 5 May 2026 01:45:24 -0700 Subject: [PATCH] Add tests for VelocityStep in gaussian_step_sampler_test.py PiperOrigin-RevId: 910521018 --- .../lib/sampling/gaussian_step_sampler.py | 19 +- .../sampling/gaussian_step_sampler_test.py | 248 +++++++++++++++++- 2 files changed, 260 insertions(+), 7 deletions(-) diff --git a/hackable_diffusion/lib/sampling/gaussian_step_sampler.py b/hackable_diffusion/lib/sampling/gaussian_step_sampler.py index 2ac1a42..416fb9b 100644 --- a/hackable_diffusion/lib/sampling/gaussian_step_sampler.py +++ b/hackable_diffusion/lib/sampling/gaussian_step_sampler.py @@ -354,12 +354,19 @@ def finalize( class VelocityStep(SamplerStep): """DDIM sampler from https://arxiv.org/abs/2010.02502. - epsilon controls the interpolation between DDIM and DDPM: - epsilon = 0.0 gives (deterministic) DDIM and epsilon = 1.0 gives DDPM. + stoch_coeff controls the interpolation between an ODE and an SDE sampler: + stoch_coeff = 0.0 gives the discretisation of an ODE (as in Flow Matching) and + stoch_coeff = 1.0 gives the discretisation of an SDE. + + Attributes: + corruption_process: The corruption process to use. + stoch_coeff: The interpolation parameter between the ODE and the SDE. + stochastic_last_step: Whether the last step is stochastic. 
""" corruption_process: GaussianProcess - epsilon: float + stoch_coeff: float + stochastic_last_step: bool = False @kt.typechecked def initialize( @@ -402,9 +409,9 @@ def update( score = prediction_dict["score"] z = jax.random.normal(key=next_step_info.rng, shape=xt.shape) - delta = -velocity + 0.5 * self.epsilon**2 * g**2 * score + delta = -velocity + 0.5 * self.stoch_coeff**2 * g**2 * score new_mean = xt + delta * dt - volatility = jnp.sqrt(dt) * g * self.epsilon + volatility = jnp.sqrt(dt) * g * self.stoch_coeff if stochastic: new_xt = new_mean + volatility * z @@ -428,7 +435,7 @@ def finalize( prediction, current_step, last_step_info, - stochastic=False, + stochastic=self.stochastic_last_step, ) diff --git a/hackable_diffusion/lib/sampling/gaussian_step_sampler_test.py b/hackable_diffusion/lib/sampling/gaussian_step_sampler_test.py index 6d1bf0b..0d783dc 100644 --- a/hackable_diffusion/lib/sampling/gaussian_step_sampler_test.py +++ b/hackable_diffusion/lib/sampling/gaussian_step_sampler_test.py @@ -29,7 +29,6 @@ from absl.testing import absltest from absl.testing import parameterized - ################################################################################ # MARK: Type Aliases ################################################################################ @@ -110,6 +109,30 @@ def _ddim_update( return new_mean, volatility +def _velocity_update( + xt: jnp.ndarray, + prediction: TargetInfo, + time: jnp.ndarray, + next_time: jnp.ndarray, + stoch_coeff: float, + process: GaussianProcess, +) -> tuple[jnp.ndarray, jnp.ndarray]: + """Helper function to compute the Velocity update.""" + g = process.schedule.g(time) + prediction_dict = process.convert_predictions( + prediction=prediction, + xt=xt, + time=time, + ) + velocity = prediction_dict['velocity'] + score = prediction_dict['score'] + dt = time - next_time + delta = -velocity + 0.5 * jnp.square(stoch_coeff) * jnp.square(g) * score + mean = xt + delta * dt + volatility = jnp.sqrt(dt) * g * stoch_coeff + 
return mean, volatility + + ################################################################################ # MARK: Test constants ################################################################################ @@ -699,6 +722,229 @@ def test_update_specific_parameters(self): ) +class VelocityStepTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + + self.schedule = schedules.RFSchedule() + self.process = gaussian.GaussianProcess(schedule=self.schedule) + self.initial_noise = jnp.expand_dims(jnp.eye(4), axis=0) + + @parameterized.parameters( + itertools.product(_STOCHASTICITY_LEVELS, _USE_STOCHASTIC_LAST_STEP) + ) + def test_initialize(self, stoch_coeff, use_stochastic_last_step): + + velocity_step = gaussian_step_sampler.VelocityStep( + corruption_process=self.process, + stoch_coeff=stoch_coeff, + stochastic_last_step=use_stochastic_last_step, + ) + + initial_step = velocity_step.initialize( + initial_noise=self.initial_noise, + initial_step_info=StepInfo( + step=0, + time=jnp.array([0.0]), + rng=jax.random.PRNGKey(0), + ), + ) + + chex.assert_trees_all_equal( + initial_step, + DiffusionStep( + xt=jnp.expand_dims(jnp.eye(4), axis=0), + step_info=StepInfo( + step=0, + time=jnp.array([0.0]), + rng=jax.random.PRNGKey(0), + ), + aux=dict(), + ), + ) + + @parameterized.parameters( + itertools.product(_STOCHASTICITY_LEVELS, _USE_STOCHASTIC_LAST_STEP) + ) + def test_update(self, stoch_coeff, use_stochastic_last_step): + + velocity_step = gaussian_step_sampler.VelocityStep( + corruption_process=self.process, + stoch_coeff=stoch_coeff, + stochastic_last_step=use_stochastic_last_step, + ) + + initial_step = velocity_step.initialize( + initial_noise=self.initial_noise, + initial_step_info=StepInfo( + step=0, + time=jnp.array([0.2]), + rng=jax.random.PRNGKey(0), + ), + ) + + prediction = dummy_inference_fn( + xt=initial_step.xt, + conditioning={}, + time=initial_step.step_info.time, + ) + + next_step = velocity_step.update( + prediction=prediction, + 
current_step=initial_step, + next_step_info=StepInfo( + step=1, time=jnp.array([0.1]), rng=jax.random.PRNGKey(1) + ), + ) + + z = jax.random.normal( + key=next_step.step_info.rng, shape=initial_step.xt.shape + ) + + mean, volatility = _velocity_update( + xt=initial_step.xt, + prediction=prediction, + time=initial_step.step_info.time, + next_time=next_step.step_info.time, + stoch_coeff=stoch_coeff, + process=self.process, + ) + expected_xt = mean + volatility * z + + chex.assert_trees_all_close( + next_step, + DiffusionStep( + xt=expected_xt, + step_info=StepInfo( + step=1, + time=jnp.array([0.1]), + rng=jax.random.PRNGKey(1), + ), + aux={}, + ), + atol=1e-6, + ) + + @parameterized.parameters( + itertools.product(_STOCHASTICITY_LEVELS, _USE_STOCHASTIC_LAST_STEP) + ) + def test_finalize(self, stoch_coeff, use_stochastic_last_step): + velocity_step = gaussian_step_sampler.VelocityStep( + corruption_process=self.process, + stoch_coeff=stoch_coeff, + stochastic_last_step=use_stochastic_last_step, + ) + + initial_step = velocity_step.initialize( + initial_noise=self.initial_noise, + initial_step_info=StepInfo( + step=0, + time=jnp.array([0.2]), + rng=jax.random.PRNGKey(0), + ), + ) + + prediction = dummy_inference_fn( + xt=initial_step.xt, + conditioning={}, + time=initial_step.step_info.time, + ) + + final_step = velocity_step.finalize( + prediction=prediction, + current_step=initial_step, + last_step_info=StepInfo( + step=1, time=jnp.array([0.1]), rng=jax.random.PRNGKey(1) + ), + ) + + z = jax.random.normal( + key=final_step.step_info.rng, shape=initial_step.xt.shape + ) + + mean, volatility = _velocity_update( + xt=initial_step.xt, + prediction=prediction, + time=initial_step.step_info.time, + next_time=final_step.step_info.time, + stoch_coeff=stoch_coeff, + process=self.process, + ) + + if use_stochastic_last_step: + expected_xt = mean + volatility * z + else: + expected_xt = mean + + chex.assert_trees_all_close( + final_step, + DiffusionStep( + xt=expected_xt, + 
step_info=StepInfo( + step=1, + time=jnp.array([0.1]), + rng=jax.random.PRNGKey(1), + ), + aux={}, + ), + atol=1e-6, + ) + + def test_update_specific_parameters(self): + + velocity_step = gaussian_step_sampler.VelocityStep( + corruption_process=self.process, + stoch_coeff=0.25, + stochastic_last_step=False, + ) + + initial_step = velocity_step.initialize( + initial_noise=self.initial_noise, + initial_step_info=StepInfo( + step=0, + time=jnp.array([0.2]), + rng=jax.random.PRNGKey(0), + ), + ) + + prediction = dummy_inference_fn( + xt=initial_step.xt, + conditioning={}, + time=initial_step.step_info.time, + ) + + next_step = velocity_step.update( + prediction=prediction, + current_step=initial_step, + next_step_info=StepInfo( + step=1, time=jnp.array([0.1]), rng=jax.random.PRNGKey(1) + ), + ) + + chex.assert_trees_all_close( + next_step, + DiffusionStep( + xt=jnp.array( + [[ + [0.983554, 0.004735, -0.007602, -0.008667], + [0.070809, 1.000478, 0.119717, 0.056051], + [-0.01623, 0.020032, 0.952613, -0.013727], + [0.049506, 0.043945, 0.049693, 1.022896], + ]], + dtype=jnp.float32, + ), + step_info=StepInfo( + step=1, + time=jnp.array([0.1]), + rng=jax.random.PRNGKey(1), + ), + aux={}, + ), + atol=1e-6, + ) + + class HeunStepTest(absltest.TestCase): def setUp(self):