Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions hackable_diffusion/lib/sampling/gaussian_step_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,19 @@ def finalize(
class VelocityStep(SamplerStep):
"""DDIM sampler from https://arxiv.org/abs/2010.02502.

epsilon controls the interpolation between DDIM and DDPM:
epsilon = 0.0 gives (deterministic) DDIM and epsilon = 1.0 gives DDPM.
stoch_coeff controls the interpolation between DDIM and DDPM:
stoch_coeff = 0.0 gives the discretisation of an ODE (as in Flow Matching) and
stoch_coeff = 1.0 gives the discretisation of an SDE.

Attributes:
corruption_process: The corruption process to use.
stoch_coeff: The interpolation parameter between DDIM and DDPM.
stochastic_last_step: Whether the last step is stochastic.
"""

corruption_process: GaussianProcess
epsilon: float
stoch_coeff: float
stochastic_last_step: bool = False

@kt.typechecked
def initialize(
Expand Down Expand Up @@ -402,9 +409,9 @@ def update(
score = prediction_dict["score"]
z = jax.random.normal(key=next_step_info.rng, shape=xt.shape)

delta = -velocity + 0.5 * self.epsilon**2 * g**2 * score
delta = -velocity + 0.5 * self.stoch_coeff**2 * g**2 * score
new_mean = xt + delta * dt
volatility = jnp.sqrt(dt) * g * self.epsilon
volatility = jnp.sqrt(dt) * g * self.stoch_coeff

if stochastic:
new_xt = new_mean + volatility * z
Expand All @@ -428,7 +435,7 @@ def finalize(
prediction,
current_step,
last_step_info,
stochastic=False,
stochastic=self.stochastic_last_step,
)


Expand Down
248 changes: 247 additions & 1 deletion hackable_diffusion/lib/sampling/gaussian_step_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from absl.testing import absltest
from absl.testing import parameterized


################################################################################
# MARK: Type Aliases
################################################################################
Expand Down Expand Up @@ -110,6 +109,30 @@ def _ddim_update(
return new_mean, volatility


def _velocity_update(
xt: jnp.ndarray,
prediction: TargetInfo,
time: jnp.ndarray,
next_time: jnp.ndarray,
stoch_coeff: float,
process: GaussianProcess,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Helper function to compute the Velocity update."""
g = process.schedule.g(time)
prediction_dict = process.convert_predictions(
prediction=prediction,
xt=xt,
time=time,
)
velocity = prediction_dict['velocity']
score = prediction_dict['score']
dt = time - next_time
delta = -velocity + 0.5 * jnp.square(stoch_coeff) * jnp.square(g) * score
mean = xt + delta * dt
volatility = jnp.sqrt(dt) * g * stoch_coeff
return mean, volatility


################################################################################
# MARK: Test constants
################################################################################
Expand Down Expand Up @@ -699,6 +722,229 @@ def test_update_specific_parameters(self):
)


class VelocityStepTest(parameterized.TestCase):

def setUp(self):
super().setUp()

self.schedule = schedules.RFSchedule()
self.process = gaussian.GaussianProcess(schedule=self.schedule)
self.initial_noise = jnp.expand_dims(jnp.eye(4), axis=0)

@parameterized.parameters(
itertools.product(_STOCHASTICITY_LEVELS, _USE_STOCHASTIC_LAST_STEP)
)
def test_initialize(self, stoch_coeff, use_stochastic_last_step):

velocity_step = gaussian_step_sampler.VelocityStep(
corruption_process=self.process,
stoch_coeff=stoch_coeff,
stochastic_last_step=use_stochastic_last_step,
)

initial_step = velocity_step.initialize(
initial_noise=self.initial_noise,
initial_step_info=StepInfo(
step=0,
time=jnp.array([0.0]),
rng=jax.random.PRNGKey(0),
),
)

chex.assert_trees_all_equal(
initial_step,
DiffusionStep(
xt=jnp.expand_dims(jnp.eye(4), axis=0),
step_info=StepInfo(
step=0,
time=jnp.array([0.0]),
rng=jax.random.PRNGKey(0),
),
aux=dict(),
),
)

@parameterized.parameters(
itertools.product(_STOCHASTICITY_LEVELS, _USE_STOCHASTIC_LAST_STEP)
)
def test_update(self, stoch_coeff, use_stochastic_last_step):

velocity_step = gaussian_step_sampler.VelocityStep(
corruption_process=self.process,
stoch_coeff=stoch_coeff,
stochastic_last_step=use_stochastic_last_step,
)

initial_step = velocity_step.initialize(
initial_noise=self.initial_noise,
initial_step_info=StepInfo(
step=0,
time=jnp.array([0.2]),
rng=jax.random.PRNGKey(0),
),
)

prediction = dummy_inference_fn(
xt=initial_step.xt,
conditioning={},
time=initial_step.step_info.time,
)

next_step = velocity_step.update(
prediction=prediction,
current_step=initial_step,
next_step_info=StepInfo(
step=1, time=jnp.array([0.1]), rng=jax.random.PRNGKey(1)
),
)

z = jax.random.normal(
key=next_step.step_info.rng, shape=initial_step.xt.shape
)

mean, volatility = _velocity_update(
xt=initial_step.xt,
prediction=prediction,
time=initial_step.step_info.time,
next_time=next_step.step_info.time,
stoch_coeff=stoch_coeff,
process=self.process,
)
expected_xt = mean + volatility * z

chex.assert_trees_all_close(
next_step,
DiffusionStep(
xt=expected_xt,
step_info=StepInfo(
step=1,
time=jnp.array([0.1]),
rng=jax.random.PRNGKey(1),
),
aux={},
),
atol=1e-6,
)

@parameterized.parameters(
itertools.product(_STOCHASTICITY_LEVELS, _USE_STOCHASTIC_LAST_STEP)
)
def test_finalize(self, stoch_coeff, use_stochastic_last_step):
velocity_step = gaussian_step_sampler.VelocityStep(
corruption_process=self.process,
stoch_coeff=stoch_coeff,
stochastic_last_step=use_stochastic_last_step,
)

initial_step = velocity_step.initialize(
initial_noise=self.initial_noise,
initial_step_info=StepInfo(
step=0,
time=jnp.array([0.2]),
rng=jax.random.PRNGKey(0),
),
)

prediction = dummy_inference_fn(
xt=initial_step.xt,
conditioning={},
time=initial_step.step_info.time,
)

final_step = velocity_step.finalize(
prediction=prediction,
current_step=initial_step,
last_step_info=StepInfo(
step=1, time=jnp.array([0.1]), rng=jax.random.PRNGKey(1)
),
)

z = jax.random.normal(
key=final_step.step_info.rng, shape=initial_step.xt.shape
)

mean, volatility = _velocity_update(
xt=initial_step.xt,
prediction=prediction,
time=initial_step.step_info.time,
next_time=final_step.step_info.time,
stoch_coeff=stoch_coeff,
process=self.process,
)

if use_stochastic_last_step:
expected_xt = mean + volatility * z
else:
expected_xt = mean

chex.assert_trees_all_close(
final_step,
DiffusionStep(
xt=expected_xt,
step_info=StepInfo(
step=1,
time=jnp.array([0.1]),
rng=jax.random.PRNGKey(1),
),
aux={},
),
atol=1e-6,
)

def test_update_specific_parameters(self):

velocity_step = gaussian_step_sampler.VelocityStep(
corruption_process=self.process,
stoch_coeff=0.25,
stochastic_last_step=False,
)

initial_step = velocity_step.initialize(
initial_noise=self.initial_noise,
initial_step_info=StepInfo(
step=0,
time=jnp.array([0.2]),
rng=jax.random.PRNGKey(0),
),
)

prediction = dummy_inference_fn(
xt=initial_step.xt,
conditioning={},
time=initial_step.step_info.time,
)

next_step = velocity_step.update(
prediction=prediction,
current_step=initial_step,
next_step_info=StepInfo(
step=1, time=jnp.array([0.1]), rng=jax.random.PRNGKey(1)
),
)

chex.assert_trees_all_close(
next_step,
DiffusionStep(
xt=jnp.array(
[[
[0.983554, 0.004735, -0.007602, -0.008667],
[0.070809, 1.000478, 0.119717, 0.056051],
[-0.01623, 0.020032, 0.952613, -0.013727],
[0.049506, 0.043945, 0.049693, 1.022896],
]],
dtype=jnp.float32,
),
step_info=StepInfo(
step=1,
time=jnp.array([0.1]),
rng=jax.random.PRNGKey(1),
),
aux={},
),
atol=1e-6,
)


class HeunStepTest(absltest.TestCase):

def setUp(self):
Expand Down
Loading