Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the SDE variant of DPM-Solver and DPM-Solver++ #3344

Merged
merged 4 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
103 changes: 87 additions & 16 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


Expand Down Expand Up @@ -70,6 +71,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse
diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the
second-order `sde-dpmsolver++`.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
Expand Down Expand Up @@ -103,10 +108,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++`.
algorithm_type (`str`, default `dpmsolver++`):
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
sampling (e.g. stable-diffusion).
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or
`sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use
`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
solver_type (`str`, default `midpoint`):
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
Expand Down Expand Up @@ -180,7 +185,7 @@ def __init__(
self.init_noise_sigma = 1.0

# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
Expand Down Expand Up @@ -212,7 +217,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
"""
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
Expand Down Expand Up @@ -338,10 +343,10 @@ def convert_model_output(
"""

# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
Expand All @@ -360,33 +365,42 @@ def convert_model_output(
x0_pred = self._threshold_sample(x0_pred)

return x0_pred

# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
return model_output
if self.config.variance_type in ["learned", "learned_range"]:
epsilon = model_output[:, :3]
else:
epsilon = model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverMultistepScheduler."
)

if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t

return epsilon

def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
timestep: int,
prev_timestep: int,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
One step for the first-order DPM-Solver (equivalent to DDIM).
Expand All @@ -411,6 +425,20 @@ def dpm_solver_first_order_update(
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t

def multistep_dpm_solver_second_order_update(
Expand All @@ -419,6 +447,7 @@ def multistep_dpm_solver_second_order_update(
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
One step for the second-order multistep DPM-Solver.
Expand Down Expand Up @@ -470,6 +499,38 @@ def multistep_dpm_solver_second_order_update(
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * (torch.exp(h) - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t

def multistep_dpm_solver_third_order_update(
Expand Down Expand Up @@ -532,6 +593,7 @@ def step(
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Expand Down Expand Up @@ -574,12 +636,21 @@ def step(
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output

if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
else:
noise = None

if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
prev_sample = self.dpm_solver_first_order_update(
model_output, timestep, prev_timestep, sample, noise=noise
)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
timestep_list = [self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
)
else:
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
Expand Down
18 changes: 11 additions & 7 deletions tests/schedulers/test_scheduler_dpm_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,20 @@ def test_prediction_type(self):
self.check_over_configs(prediction_type=prediction_type)

def test_solver_order_and_type(self):
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
for solver_type in ["midpoint", "heun"]:
for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
if algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
if order == 3:
continue
else:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
sample = self.full_loop(
solver_order=order,
solver_type=solver_type,
Expand Down