From b422a7047e479f6383a463d97e44c7932ff6af1f Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 25 Feb 2024 20:14:49 -0800 Subject: [PATCH 1/5] DPMMultistep rescale_betas_zero_snr --- .../scheduling_dpmsolver_multistep.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 6a090c477290..1aa925ef3cc7 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -70,6 +70,42 @@ def alpha_bar_fn(t): betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ @@ -144,6 +180,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): An offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable Diffusion. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -173,6 +213,7 @@ def __init__( variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, ): if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" @@ -191,8 +232,17 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + # Currently we only support VP-type noise schedule self.alpha_t = torch.sqrt(self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) From 32350380f6d80d42d6e0d64e6220f86e06e9e288 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 25 Feb 2024 20:51:08 -0800 Subject: [PATCH 2/5] DPM upcast samples in step() --- .../schedulers/scheduling_dpmsolver_multistep.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 1aa925ef3cc7..f507efdfe8bb 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -930,6 +930,11 @@ def step( if self.step_index is None: self._init_step_index(timestep) + # store old dtype because model_output isn't always the same it seems + return_dtype = sample.dtype + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + # Improve numerical stability for small number of steps lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( self.config.euler_at_final @@ -947,7 +952,7 @@ def step( if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 ) else: noise = None @@ -962,6 +967,9 @@ def step( if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 + # Cast sample back to expected dtype + prev_sample = prev_sample.to(return_dtype) + # upon completion increase step index by one self._step_index += 1 From 70c1211f40510c4a5be816e006f844a92356c5c2 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 25 Feb 2024 21:09:53 -0800 Subject: [PATCH 3/5] DPM rescale_betas_zero_snr UT --- tests/schedulers/test_scheduler_dpm_multi.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 515cea5bc4ba..fcf8881bc820 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -213,6 +213,10 @@ def test_inference_steps(self): for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0) + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_full_loop_no_noise(self): sample = self.full_loop() result_mean = torch.mean(torch.abs(sample)) From b81e0102e884b7992b6418816d3304c0f929130a Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 26 Feb 2024 14:51:03 -0800 Subject: [PATCH 4/5] DPMSolverMulti move sample upcast after model convert Avoids having to re-use the dtype. --- .../schedulers/scheduling_dpmsolver_multistep.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index f507efdfe8bb..d8f969551593 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -930,11 +930,6 @@ def step( if self.step_index is None: self._init_step_index(timestep) - # store old dtype because model_output isn't always the same it seems - return_dtype = sample.dtype - # Upcast to avoid precision issues when computing prev_sample - sample = sample.to(torch.float32) - # Improve numerical stability for small number of steps lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( self.config.euler_at_final @@ -950,6 +945,9 @@ def step( self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 @@ -968,7 +966,7 @@ def step( self.lower_order_nums += 1 # Cast sample back to expected dtype - prev_sample = prev_sample.to(return_dtype) + prev_sample = prev_sample.to(model_output.dtype) # upon completion increase step index by one self._step_index += 1 From 555831e78f41245e857caf1a74ed42ddc49b0879 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Tue, 27 Feb 2024 00:35:53 -0800 Subject: [PATCH 5/5] Add a newline for Ruff --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d8f969551593..5c6d03fa229c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -70,6 +70,7 @@ def alpha_bar_fn(t): betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) + # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr def rescale_zero_terminal_snr(betas): """