From 8f78025482e9f1fdd4b6678ee9f6b0cb50d70420 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 21 Jul 2023 03:52:09 +0000 Subject: [PATCH 01/37] add index_counter --- .../scheduling_dpmsolver_multistep.py | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d7516fa601e1..7dbbe35f3d55 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -15,6 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver import math +from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -274,11 +275,6 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc self.sigmas = torch.from_numpy(sigmas) - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] - self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -288,6 +284,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc ] * self.config.solver_order self.lower_order_nums = 0 + # add an index counter for schedulers that allow duplicated timesteps + self._index_counter = defaultdict(int) + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -660,11 +659,25 @@ def step( if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: + indices = (self.timesteps == timestep).nonzero() + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + + if len(indices) == 0: step_index = len(self.timesteps) - 1 else: - step_index = step_index.item() + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 + else: + pos = self._index_counter[timestep_int] + step_index = indices[pos].item() + + # advance index counter by 1 + self._index_counter[timestep_int] += 1 + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] lower_order_final = ( (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 From 3b886af21bb070d48e984199778097221036aee5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 21 Jul 2023 04:29:16 +0000 Subject: [PATCH 02/37] update test --- tests/schedulers/test_scheduler_dpm_multi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index c9935780b983..86b24af24095 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -264,10 +264,10 @@ def test_fp16_support(self): assert sample.dtype == torch.float16 - def test_unique_timesteps(self, **config): + def test_duplicated_timesteps(self, **config): for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(scheduler.config.num_train_timesteps) - assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps + assert len(scheduler.timesteps) == scheduler.num_inference_steps From a05a13a9ab0a796d970c03e7b3e592948cf7b48a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 24 Jul 2023 19:11:10 +0000 Subject: [PATCH 03/37] update add print lines add print lines and change --- src/diffusers/models/unet_2d_condition.py | 1 + .../pipeline_stable_diffusion.py | 9 +- .../pipeline_stable_diffusion_k_diffusion.py | 118 ++++++++++- .../scheduling_dpmsolver_multistep.py | 183 ++++++++++++------ .../schedulers/scheduling_euler_discrete.py | 8 + 5 files changed, 257 insertions(+), 62 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index d7756ab5edb3..1bc493df638b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -986,6 +986,7 @@ def forward( sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) + print(f" - unet out (sample): {sample.shape},{sample[0,0,:3,:3]}") if not return_dict: return (sample,) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 54927049571c..276ddae727b4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -515,7 +515,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler + print(f" inside prepare_latents:") + print(f" - latents: {latents.shape},{latents[0,0,:3,:3]}") latents = latents * self.scheduler.init_noise_sigma + print(f" - latents * init_noise_sigma: {latents.shape},{latents[0,0,:3,:3]}") return latents @torch.no_grad() @@ -679,7 +682,9 @@ def __call__( for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + print(f" - latent_model_input: {latent_model_input.shape},{latent_model_input[0,0,:3,:3]}") latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + print(f" - latent_model_input(scaled): {latent_model_input.shape},{latent_model_input[0,0,:3,:3]}") # predict the noise residual noise_pred = self.unet( @@ -689,12 +694,12 @@ def __call__( cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] - + print(f" - noise_pred: {noise_pred.shape},{noise_pred[0,0,:3,:3]}") # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - + print(f" - noise_pred (cfg): {noise_pred.shape},{noise_pred[0,0,:3,:3]}") if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 29a57470a341..4291072b6e94 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -18,7 +18,8 @@ from typing import Callable, List, Optional, Union import torch -from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser +from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser, DiscreteSchedule +from k_diffusion import utils from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras from ...image_processor import VaeImageProcessor @@ -32,6 +33,97 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# yiyi testing +from tqdm.auto import trange +@torch.no_grad() +def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): + """DPM-Solver++(2M).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + old_denoised = None + + for i in trange(len(sigmas) - 1, disable=disable): + print(f" - i :{i}, sigma: {sigmas[i]}") + denoised = model(x, sigmas[i] * s_in, **extra_args) + print(f" - denoised: {denoised.shape}, {denoised[0,0,:3,:3]}") + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + print(f" - sigma_t: {sigmas[i+1]}, sigma_s: {sigmas[i]}") + print(f" - t, t_next: {t},{t_next}") + h = t_next - t + print(f" - h: {h}") + if old_denoised is None or sigmas[i + 1] == 0: + print(f" first order") + print(f" - x/sample/latents: {x.shape},{x[0,0,:3,:3]}") + print(f" - sigma_fns(t_next): {sigma_fn(t_next)}, sigma_fn(t): {sigma_fn(t)}") + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised + print(f" -> x: {x[0,0,:3,:3]}") + else: + print(" second order") + print(f" yiyi testing") + print(f" - sigmas: {sigmas[i]}, {sigmas[i+1]}") + print(f" - sigma_fns: {sigma_fn(t)}, {sigma_fn(t_next)}") + h_last = t - t_fn(sigmas[i - 1]) + r = h_last / h + denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d + print(f" -> x: {x[0,0,:3,:3]}") + old_denoised = denoised + return x + + +class DiscreteEpsDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output eps (the predicted + noise).""" + + def __init__(self, model, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.inner_model = model + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_out, c_in + + def get_eps(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def loss(self, input, noise, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + return (eps - noise).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + print(f" arriving CompVisDenoiser.foward") + print(f" - input: {input.shape}, {input[0,0,:3,:3]}") + print(f" - c_in: {c_in.shape}, {c_in}") + print(f" - c_out:{c_out.shape}, {c_out}") + print(f" - sigma: {sigma}") + print(f" - t: {self.sigma_to_t(sigma)}") + print(f" - input * c_in : {(input * c_in).shape}, {(input * c_in)[0,0,:3,:3]}") + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + print(f" - eps: {eps.shape}, {eps[0,0,:3,:3]}") + print(f" - eps * c_out : {(eps * c_out).shape}, {(eps * c_out)[0,0,:3,:3]}") + print(f" - input + eps * c_out: {(input + eps * c_out).shape}, {(input + eps * c_out)[0,0,:3,:3]}") + print(f" leaving CompVisDenoiser.foward") + return input + eps * c_out + + +class CompVisDenoiser(DiscreteEpsDDPMDenoiser): + """A wrapper for CompVis diffusion models.""" + + def __init__(self, model, quantize=False, device='cpu'): + super().__init__(model, model.alphas_cumprod, quantize=quantize) + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) + class ModelWrapper: def __init__(self, model, alphas_cumprod): self.model = model @@ -123,9 +215,12 @@ def __init__( self.k_diffusion_model = CompVisDenoiser(model) def set_scheduler(self, scheduler_type: str): - library = importlib.import_module("k_diffusion") - sampling = getattr(library, "sampling") - self.sampler = getattr(sampling, scheduler_type) + #library = importlib.import_module("k_diffusion") + #sampling = getattr(library, "sampling") + #self.sampler = getattr(sampling, scheduler_type) + if scheduler_type == "sample_dpmpp_2m": + self.sampler = sample_dpmpp_2m + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): @@ -530,9 +625,12 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) # 5. Prepare sigmas + if use_karras_sigmas: + print(f" - k_diffusion_model.sigmas :{self.k_diffusion_model.sigmas}") sigma_min: float = self.k_diffusion_model.sigmas[0].item() sigma_max: float = self.k_diffusion_model.sigmas[-1].item() + print(f" -sigma_max: {sigma_max}, sigma_min: {sigma_min}") sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) sigmas = sigmas.to(device) else: @@ -551,19 +649,28 @@ def __call__( generator, latents, ) + print(f" - prepare_latents -> {latents.shape},{latents[0,0,:3,:3]}") latents = latents * sigmas[0] + print(f" - latents * initial noise sigma: {latents.shape},{latents[0,0,:3,:3]}") self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) # 7. Define model function def model_fn(x, t): + print(" ") + print(f" arriving model_fn") latent_model_input = torch.cat([x] * 2) + print(f" - latent_model_input: {latent_model_input.shape}, {latent_model_input[0,0,:3,:3]}") t = torch.cat([t] * 2) noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds) + print(f" - noise_pred: {noise_pred.shape}, {noise_pred[0,0,:3,:3]}") noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + print(f" -> cfg -> {noise_pred.shape},{noise_pred[0,0,:3,:3]}") + print(" leaving model_fn") + print(" ") return noise_pred # 8. Run k-diffusion solver @@ -573,7 +680,8 @@ def model_fn(x, t): min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) sampler_kwargs["noise_sampler"] = noise_sampler - + + print(f" sigmas: {sigmas}") latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) if not output_type == "latent": diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 7dbbe35f3d55..d39417859da8 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -206,9 +206,6 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - # settings for DPM-Solver if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": @@ -225,9 +222,27 @@ def __init__( # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 + self._step_index = None + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + @property + def step_index(self): + """ + TODO: Nice docstring + """ + return self._step_index + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ @@ -243,23 +258,25 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # This is critical for cosine (squaredcos_cap_v2) noise schedule. clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + print(f" - last_timestep: {last_timestep}") # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( - np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + np.linspace(0, last_timestep - 1, num_inference_steps, dtype=float).round()[::-1].copy() ) + print(f" - timesteps: {len(timesteps)},{timesteps[0]}:{timesteps[-1]}") elif self.config.timestep_spacing == "leading": - step_ratio = last_timestep // (num_inference_steps + 1) + step_ratio = last_timestep // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(float) timesteps -= 1 else: raise ValueError( @@ -267,13 +284,21 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + print(f" sigmas: {sigmas[0]}: {sigmas[-1]}") + if self.config.use_karras_sigmas: - log_sigmas = np.log(sigmas) + print(f" sigmas (k): {sigmas[0]}, {sigmas[-1]}") sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = np.flip(timesteps).copy().astype(np.int64) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) - self.sigmas = torch.from_numpy(sigmas) + print(" set_timesteps") + print(f" - sigmas: {sigmas[0]}: {sigmas[-1]}") + print(sigmas) self.timesteps = torch.from_numpy(timesteps).to(device) @@ -283,9 +308,12 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc None, ] * self.config.solver_order self.lower_order_nums = 0 + print(f" - timesteps: {timesteps}") + print(f" - model_outputs: {self.model_outputs}") + print(f" - lower_order_nums: {self.lower_order_nums}") # add an index counter for schedulers that allow duplicated timesteps - self._index_counter = defaultdict(int) + self._step_index = None # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: @@ -384,13 +412,17 @@ def convert_model_output( """ # DPM-Solver++ needs to solve an integral of the data prediction model. + print(f" -algo_type: {self.config.algorithm_type}, pred_type: {self.config.prediction_type}, variance_type: {self.config.variance_type}") if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = (sample - sigma_t * model_output) / alpha_t + # yiyi update/testing here: + sigma = self.sigmas[self.step_index] + x0_pred = sample - sigma * model_output + print(f" - sigma: {sigma}") + print(f" - x0_pred: {x0_pred.shape},{x0_pred[0,0,:3,:3]}") elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": @@ -403,6 +435,7 @@ def convert_model_output( ) if self.config.thresholding: + print(f" threhold? ") x0_pred = self._threshold_sample(x0_pred) return x0_pred @@ -458,12 +491,18 @@ def dpm_solver_first_order_update( Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] - h = lambda_t - lambda_s + def t_fn(_sigma): + return -torch.log(_sigma) + + print(f" - sample: {sample[0,0,:3,:3]}") + + sigma_t, sigma_s = self.sigmas[self.step_index +1], self.sigmas[self.step_index] + print(f" - sigma_t: {sigma_t}, sigma_s: {sigma_s}") + h = t_fn(sigma_t) - t_fn(sigma_s) + print(f" - h: {h}") if self.config.algorithm_type == "dpmsolver++": - x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output + print(f" -> prev_sample: {x_t[0,0,:3,:3]}") elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output elif self.config.algorithm_type == "sde-dpmsolver++": @@ -504,22 +543,32 @@ def multistep_dpm_solver_second_order_update( Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + + def t_fn(_sigma): + return -torch.log(_sigma) + + sigma_t, sigma_s0, sigma_s1 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1] m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] - h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + + h, h_0 = t_fn(sigma_t) - t_fn(sigma_s0), t_fn(sigma_s0) - t_fn(sigma_s1) r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) + print(f" sigma_t (current), sigma_s0 (prev), sigma_s1 (next): {sigma_t}, {sigma_s0}, {sigma_s1}") + # print(f" sigma_t, sigma_s0, {sigma_t},{sigma_s0}") + # print(f" sigma_s1: {self.sigma_t[s1]}") + # print(f" alpha_bar = {self.alphas_cumprod[s1]}") + print(f" sample: {sample[0,0,:3,:3]}") + # print(" -----") if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": + print(f" dpmsolver++ /midpoint ") x_t = ( (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + - (torch.exp(-h) - 1.0) * D0 + - 0.5 * (torch.exp(-h) - 1.0) * D1 ) + print(f" -> prev_sample: {x_t.shape}, {x_t[0,0,:3,:3]}") elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample @@ -629,6 +678,23 @@ def multistep_dpm_solver_third_order_update( ) return x_t + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(index_candidates) > 1: + step_index = index_candidates[1] + else: + step_index = index_candidates[0] + + self._step_index = step_index.item() + def step( self, model_output: torch.FloatTensor, @@ -657,37 +723,26 @@ def step( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - indices = (self.timesteps == timestep).nonzero() - timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + if self.step_index is None: + self._init_step_index(timestep) + print(" ") + print(" ** inside step ***** ") - if len(indices) == 0: - step_index = len(self.timesteps) - 1 - else: - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(self._index_counter) == 0: - pos = 1 if len(indices) > 1 else 0 - else: - pos = self._index_counter[timestep_int] - step_index = indices[pos].item() - - # advance index counter by 1 - self._index_counter[timestep_int] += 1 + print(f" - timestep, step_index: {timestep}, {self.step_index}") + print(f" - sample/latents: {sample.shape},{sample[0,0,:3,:3]}") + print(f" - model_output (noise): {model_output.shape}, {model_output[0,0,:3,:3]}") - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - lower_order_final = ( - (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 - ) + prev_timestep = 0 if self.step_index == len(self.timesteps) - 1 else self.timesteps[self.step_index + 1] + lower_order_final = (self.step_index == len(self.timesteps) - 1) lower_order_second = ( - (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) + print(f" - prev_timestep: {prev_timestep}") model_output = self.convert_model_output(model_output, timestep, sample) + print(f" - model_output (x0): {model_output.shape}, {model_output[0,0,:3,:3]}") for i in range(self.config.solver_order - 1): + print(f" move outputs {i+1} -> {i}") self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output @@ -699,15 +754,18 @@ def step( noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + print(f" 1st order update: {self.config.solver_order}, {self.lower_order_nums}, {lower_order_final}") prev_sample = self.dpm_solver_first_order_update( model_output, timestep, prev_timestep, sample, noise=noise ) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[step_index - 1], timestep] + print(f" 2nd order update") + timestep_list = [self.timesteps[self.step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update( self.model_outputs, timestep_list, prev_timestep, sample, noise=noise ) else: + print(f" 3rd order update") timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_third_order_update( self.model_outputs, timestep_list, prev_timestep, sample @@ -715,23 +773,38 @@ def step( if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 + print(f" lower_order_nums +1 -> {self.lower_order_nums}") + + # upon completion increase step index by one + self._step_index += 1 if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. Args: sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain Returns: `torch.FloatTensor`: scaled input sample """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + sample = sample / ((sigma**2 + 1) ** 0.5) + + self.is_scale_input_called = True return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index cb126d4b953c..9dbe26e124f6 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -243,9 +243,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + print(f" -sigmas: {sigmas[0]}, {sigmas[-1]}") log_sigmas = np.log(sigmas) if self.config.interpolation_type == "linear": + print(f" - timesteps: {len(timesteps)}") + print(timesteps) + print(f" - sigmas: {len(sigmas)}") sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) elif self.config.interpolation_type == "log_linear": sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp() @@ -256,8 +260,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) if self.use_karras_sigmas: + print(" use k_sigmas") + print(f" -sigmas: {sigmas[0]}, {sigmas[-1]}") sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + print(f" -sigmas: {sigmas}") + print(f" -timesteps: {timesteps}") sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) From 670c782cb269ada256d0de4e5c5afb07858efb1b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 21 Aug 2023 01:51:09 +0000 Subject: [PATCH 04/37] fix --- src/diffusers/models/unet_2d_condition.py | 1 - .../pipeline_stable_diffusion.py | 7 - .../pipeline_stable_diffusion_k_diffusion.py | 120 +----------------- .../scheduling_dpmsolver_multistep.py | 43 +------ .../schedulers/scheduling_euler_discrete.py | 8 -- 5 files changed, 7 insertions(+), 172 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 1bc493df638b..d7756ab5edb3 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -986,7 +986,6 @@ def forward( sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) - print(f" - unet out (sample): {sample.shape},{sample[0,0,:3,:3]}") if not return_dict: return (sample,) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 276ddae727b4..5ea7fd0d500a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -515,10 +515,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler - print(f" inside prepare_latents:") - print(f" - latents: {latents.shape},{latents[0,0,:3,:3]}") latents = latents * self.scheduler.init_noise_sigma - print(f" - latents * init_noise_sigma: {latents.shape},{latents[0,0,:3,:3]}") return latents @torch.no_grad() @@ -682,9 +679,7 @@ def __call__( for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - print(f" - latent_model_input: {latent_model_input.shape},{latent_model_input[0,0,:3,:3]}") latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - print(f" - latent_model_input(scaled): {latent_model_input.shape},{latent_model_input[0,0,:3,:3]}") # predict the noise residual noise_pred = self.unet( @@ -694,12 +689,10 @@ def __call__( cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] - print(f" - noise_pred: {noise_pred.shape},{noise_pred[0,0,:3,:3]}") # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - print(f" - noise_pred (cfg): {noise_pred.shape},{noise_pred[0,0,:3,:3]}") if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 4291072b6e94..e272e9e3a505 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -18,8 +18,7 @@ from typing import Callable, List, Optional, Union import torch -from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser, DiscreteSchedule -from k_diffusion import utils +from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras from ...image_processor import VaeImageProcessor @@ -33,97 +32,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# yiyi testing -from tqdm.auto import trange -@torch.no_grad() -def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): - """DPM-Solver++(2M).""" - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - sigma_fn = lambda t: t.neg().exp() - t_fn = lambda sigma: sigma.log().neg() - old_denoised = None - - for i in trange(len(sigmas) - 1, disable=disable): - print(f" - i :{i}, sigma: {sigmas[i]}") - denoised = model(x, sigmas[i] * s_in, **extra_args) - print(f" - denoised: {denoised.shape}, {denoised[0,0,:3,:3]}") - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) - print(f" - sigma_t: {sigmas[i+1]}, sigma_s: {sigmas[i]}") - print(f" - t, t_next: {t},{t_next}") - h = t_next - t - print(f" - h: {h}") - if old_denoised is None or sigmas[i + 1] == 0: - print(f" first order") - print(f" - x/sample/latents: {x.shape},{x[0,0,:3,:3]}") - print(f" - sigma_fns(t_next): {sigma_fn(t_next)}, sigma_fn(t): {sigma_fn(t)}") - x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised - print(f" -> x: {x[0,0,:3,:3]}") - else: - print(" second order") - print(f" yiyi testing") - print(f" - sigmas: {sigmas[i]}, {sigmas[i+1]}") - print(f" - sigma_fns: {sigma_fn(t)}, {sigma_fn(t_next)}") - h_last = t - t_fn(sigmas[i - 1]) - r = h_last / h - denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised - x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d - print(f" -> x: {x[0,0,:3,:3]}") - old_denoised = denoised - return x - - -class DiscreteEpsDDPMDenoiser(DiscreteSchedule): - """A wrapper for discrete schedule DDPM models that output eps (the predicted - noise).""" - - def __init__(self, model, alphas_cumprod, quantize): - super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) - self.inner_model = model - self.sigma_data = 1. - - def get_scalings(self, sigma): - c_out = -sigma - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - return c_out, c_in - - def get_eps(self, *args, **kwargs): - return self.inner_model(*args, **kwargs) - - def loss(self, input, noise, sigma, **kwargs): - c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - noised_input = input + noise * utils.append_dims(sigma, input.ndim) - eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) - return (eps - noise).pow(2).flatten(1).mean(1) - - def forward(self, input, sigma, **kwargs): - c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - print(f" arriving CompVisDenoiser.foward") - print(f" - input: {input.shape}, {input[0,0,:3,:3]}") - print(f" - c_in: {c_in.shape}, {c_in}") - print(f" - c_out:{c_out.shape}, {c_out}") - print(f" - sigma: {sigma}") - print(f" - t: {self.sigma_to_t(sigma)}") - print(f" - input * c_in : {(input * c_in).shape}, {(input * c_in)[0,0,:3,:3]}") - eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) - print(f" - eps: {eps.shape}, {eps[0,0,:3,:3]}") - print(f" - eps * c_out : {(eps * c_out).shape}, {(eps * c_out)[0,0,:3,:3]}") - print(f" - input + eps * c_out: {(input + eps * c_out).shape}, {(input + eps * c_out)[0,0,:3,:3]}") - print(f" leaving CompVisDenoiser.foward") - return input + eps * c_out - - -class CompVisDenoiser(DiscreteEpsDDPMDenoiser): - """A wrapper for CompVis diffusion models.""" - - def __init__(self, model, quantize=False, device='cpu'): - super().__init__(model, model.alphas_cumprod, quantize=quantize) - - def get_eps(self, *args, **kwargs): - return self.inner_model.apply_model(*args, **kwargs) - class ModelWrapper: def __init__(self, model, alphas_cumprod): self.model = model @@ -215,13 +123,10 @@ def __init__( self.k_diffusion_model = CompVisDenoiser(model) def set_scheduler(self, scheduler_type: str): - #library = importlib.import_module("k_diffusion") - #sampling = getattr(library, "sampling") - #self.sampler = getattr(sampling, scheduler_type) - if scheduler_type == "sample_dpmpp_2m": - self.sampler = sample_dpmpp_2m - - + library = importlib.import_module("k_diffusion") + sampling = getattr(library, "sampling") + self.sampler = getattr(sampling, scheduler_type) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" @@ -625,12 +530,9 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) # 5. Prepare sigmas - if use_karras_sigmas: - print(f" - k_diffusion_model.sigmas :{self.k_diffusion_model.sigmas}") sigma_min: float = self.k_diffusion_model.sigmas[0].item() sigma_max: float = self.k_diffusion_model.sigmas[-1].item() - print(f" -sigma_max: {sigma_max}, sigma_min: {sigma_min}") sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) sigmas = sigmas.to(device) else: @@ -649,28 +551,19 @@ def __call__( generator, latents, ) - print(f" - prepare_latents -> {latents.shape},{latents[0,0,:3,:3]}") latents = latents * sigmas[0] - print(f" - latents * initial noise sigma: {latents.shape},{latents[0,0,:3,:3]}") self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) # 7. Define model function def model_fn(x, t): - print(" ") - print(f" arriving model_fn") latent_model_input = torch.cat([x] * 2) - print(f" - latent_model_input: {latent_model_input.shape}, {latent_model_input[0,0,:3,:3]}") t = torch.cat([t] * 2) noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds) - print(f" - noise_pred: {noise_pred.shape}, {noise_pred[0,0,:3,:3]}") noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - print(f" -> cfg -> {noise_pred.shape},{noise_pred[0,0,:3,:3]}") - print(" leaving model_fn") - print(" ") return noise_pred # 8. Run k-diffusion solver @@ -680,8 +573,7 @@ def model_fn(x, t): min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) sampler_kwargs["noise_sampler"] = noise_sampler - - print(f" sigmas: {sigmas}") + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) if not output_type == "latent": diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d39417859da8..d1e66f8404af 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -258,14 +258,12 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # This is critical for cosine (squaredcos_cap_v2) noise schedule. clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() - print(f" - last_timestep: {last_timestep}") # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, last_timestep - 1, num_inference_steps, dtype=float).round()[::-1].copy() ) - print(f" - timesteps: {len(timesteps)},{timesteps[0]}:{timesteps[-1]}") elif self.config.timestep_spacing == "leading": step_ratio = last_timestep // self.num_inference_steps # creates integer timesteps by multiplying by ratio @@ -286,20 +284,14 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - print(f" sigmas: {sigmas[0]}: {sigmas[-1]}") if self.config.use_karras_sigmas: - print(f" sigmas (k): {sigmas[0]}, {sigmas[-1]}") sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) - print(" set_timesteps") - print(f" - sigmas: {sigmas[0]}: {sigmas[-1]}") - print(sigmas) - self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -308,9 +300,6 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc None, ] * self.config.solver_order self.lower_order_nums = 0 - print(f" - timesteps: {timesteps}") - print(f" - model_outputs: {self.model_outputs}") - print(f" - lower_order_nums: {self.lower_order_nums}") # add an index counter for schedulers that allow duplicated timesteps self._step_index = None @@ -412,17 +401,13 @@ def convert_model_output( """ # DPM-Solver++ needs to solve an integral of the data prediction model. - print(f" -algo_type: {self.config.algorithm_type}, pred_type: {self.config.prediction_type}, variance_type: {self.config.variance_type}") if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] - # yiyi update/testing here: sigma = self.sigmas[self.step_index] x0_pred = sample - sigma * model_output - print(f" - sigma: {sigma}") - print(f" - x0_pred: {x0_pred.shape},{x0_pred[0,0,:3,:3]}") elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": @@ -435,7 +420,6 @@ def convert_model_output( ) if self.config.thresholding: - print(f" threhold? ") x0_pred = self._threshold_sample(x0_pred) return x0_pred @@ -494,15 +478,10 @@ def dpm_solver_first_order_update( def t_fn(_sigma): return -torch.log(_sigma) - print(f" - sample: {sample[0,0,:3,:3]}") - sigma_t, sigma_s = self.sigmas[self.step_index +1], self.sigmas[self.step_index] - print(f" - sigma_t: {sigma_t}, sigma_s: {sigma_s}") h = t_fn(sigma_t) - t_fn(sigma_s) - print(f" - h: {h}") if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output - print(f" -> prev_sample: {x_t[0,0,:3,:3]}") elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output elif self.config.algorithm_type == "sde-dpmsolver++": @@ -553,22 +532,15 @@ def t_fn(_sigma): h, h_0 = t_fn(sigma_t) - t_fn(sigma_s0), t_fn(sigma_s0) - t_fn(sigma_s1) r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) - print(f" sigma_t (current), sigma_s0 (prev), sigma_s1 (next): {sigma_t}, {sigma_s0}, {sigma_s1}") - # print(f" sigma_t, sigma_s0, {sigma_t},{sigma_s0}") - # print(f" sigma_s1: {self.sigma_t[s1]}") - # print(f" alpha_bar = {self.alphas_cumprod[s1]}") - print(f" sample: {sample[0,0,:3,:3]}") - # print(" -----") + if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": - print(f" dpmsolver++ /midpoint ") x_t = ( (sigma_t / sigma_s0) * sample - (torch.exp(-h) - 1.0) * D0 - 0.5 * (torch.exp(-h) - 1.0) * D1 ) - print(f" -> prev_sample: {x_t.shape}, {x_t[0,0,:3,:3]}") elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample @@ -725,24 +697,15 @@ def step( if self.step_index is None: self._init_step_index(timestep) - print(" ") - print(" ** inside step ***** ") - - print(f" - timestep, step_index: {timestep}, {self.step_index}") - print(f" - sample/latents: {sample.shape},{sample[0,0,:3,:3]}") - print(f" - model_output (noise): {model_output.shape}, {model_output[0,0,:3,:3]}") prev_timestep = 0 if self.step_index == len(self.timesteps) - 1 else self.timesteps[self.step_index + 1] lower_order_final = (self.step_index == len(self.timesteps) - 1) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) - print(f" - prev_timestep: {prev_timestep}") model_output = self.convert_model_output(model_output, timestep, sample) - print(f" - model_output (x0): {model_output.shape}, {model_output[0,0,:3,:3]}") for i in range(self.config.solver_order - 1): - print(f" move outputs {i+1} -> {i}") self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output @@ -754,18 +717,15 @@ def step( noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - print(f" 1st order update: {self.config.solver_order}, {self.lower_order_nums}, {lower_order_final}") prev_sample = self.dpm_solver_first_order_update( model_output, timestep, prev_timestep, sample, noise=noise ) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - print(f" 2nd order update") timestep_list = [self.timesteps[self.step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update( self.model_outputs, timestep_list, prev_timestep, sample, noise=noise ) else: - print(f" 3rd order update") timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_third_order_update( self.model_outputs, timestep_list, prev_timestep, sample @@ -773,7 +733,6 @@ def step( if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 - print(f" lower_order_nums +1 -> {self.lower_order_nums}") # upon completion increase step index by one self._step_index += 1 diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 9dbe26e124f6..cb126d4b953c 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -243,13 +243,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - print(f" -sigmas: {sigmas[0]}, {sigmas[-1]}") log_sigmas = np.log(sigmas) if self.config.interpolation_type == "linear": - print(f" - timesteps: {len(timesteps)}") - print(timesteps) - print(f" - sigmas: {len(sigmas)}") sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) elif self.config.interpolation_type == "log_linear": sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp() @@ -260,12 +256,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) if self.use_karras_sigmas: - print(" use k_sigmas") - print(f" -sigmas: {sigmas[0]}, {sigmas[-1]}") sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - print(f" -sigmas: {sigmas}") - print(f" -timesteps: {timesteps}") sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) From c95b545113b9b7db2f26a51f0f9ca8213f157b24 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 21 Aug 2023 03:01:34 +0000 Subject: [PATCH 05/37] style --- .../pipeline_stable_diffusion_k_diffusion.py | 2 +- .../scheduling_dpmsolver_multistep.py | 51 ++++++++++--------- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index e272e9e3a505..29a57470a341 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -126,7 +126,7 @@ def set_scheduler(self, scheduler_type: str): library = importlib.import_module("k_diffusion") sampling = getattr(library, "sampling") self.sampler = getattr(sampling, scheduler_type) - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d1e66f8404af..82b77353223a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -15,7 +15,6 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver import math -from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -222,12 +221,12 @@ def __init__( # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() - + self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 self._step_index = None - + @property def init_noise_sigma(self): # standard deviation of the initial noise distribution @@ -235,7 +234,7 @@ def init_noise_sigma(self): return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 - + @property def step_index(self): """ @@ -243,7 +242,6 @@ def step_index(self): """ return self._step_index - def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -261,20 +259,18 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, last_timestep - 1, num_inference_steps, dtype=float).round()[::-1].copy() - ) + timesteps = np.linspace(0, last_timestep - 1, num_inference_steps).round()[::-1].copy().astype(np.float32) elif self.config.timestep_spacing == "leading": step_ratio = last_timestep // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(float) + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.float32) timesteps -= 1 else: raise ValueError( @@ -288,10 +284,11 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) + timesteps = timesteps.astype(np.int32) self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -475,10 +472,14 @@ def dpm_solver_first_order_update( Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. """ + def t_fn(_sigma): return -torch.log(_sigma) - sigma_t, sigma_s = self.sigmas[self.step_index +1], self.sigmas[self.step_index] + # YiYi notes: keep these for now so don't get an error, don't need once fully refactored + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] h = t_fn(sigma_t) - t_fn(sigma_s) if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output @@ -525,8 +526,16 @@ def multistep_dpm_solver_second_order_update( def t_fn(_sigma): return -torch.log(_sigma) - - sigma_t, sigma_s0, sigma_s1 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1] + + # YiYi notes: keep these for now so don't get an error, not needed once fully refactored + t, s0 = prev_timestep, timestep_list[-1] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) m0, m1 = model_output_list[-1], model_output_list[-2] h, h_0 = t_fn(sigma_t) - t_fn(sigma_s0), t_fn(sigma_s0) - t_fn(sigma_s1) @@ -536,11 +545,7 @@ def t_fn(_sigma): if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0) * sample - - (torch.exp(-h) - 1.0) * D0 - - 0.5 * (torch.exp(-h) - 1.0) * D1 - ) + x_t = (sigma_t / sigma_s0) * sample - (torch.exp(-h) - 1.0) * D0 - 0.5 * (torch.exp(-h) - 1.0) * D1 elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample @@ -699,7 +704,7 @@ def step( self._init_step_index(timestep) prev_timestep = 0 if self.step_index == len(self.timesteps) - 1 else self.timesteps[self.step_index + 1] - lower_order_final = (self.step_index == len(self.timesteps) - 1) + lower_order_final = self.step_index == len(self.timesteps) - 1 lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) @@ -726,14 +731,14 @@ def step( self.model_outputs, timestep_list, prev_timestep, sample, noise=noise ) else: - timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] + timestep_list = [self.timesteps[self.step_index - 2], self.timesteps[self.step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_third_order_update( self.model_outputs, timestep_list, prev_timestep, sample ) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 - + # upon completion increase step index by one self._step_index += 1 @@ -741,7 +746,7 @@ def step( return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) - + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] From 515c1050409c5a383174a5c58ee691b36d9b20c5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 22 Aug 2023 00:29:38 +0000 Subject: [PATCH 06/37] sde-dpmsolver+++ --- .../pipeline_stable_diffusion.py | 2 ++ .../scheduling_dpmsolver_multistep.py | 21 +++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index fa2a6715dea2..9bc2ad57fdcc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -681,10 +681,12 @@ def __call__( cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] + # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 1c42dfd8c381..cd244956f68c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -264,7 +264,6 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) - timesteps = timesteps.astype(np.int32) self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -460,7 +459,7 @@ def t_fn(_sigma): return -torch.log(_sigma) # YiYi notes: keep these for now so don't get an error, don't need once fully refactored - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + #alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] h = t_fn(sigma_t) - t_fn(sigma_s) @@ -472,7 +471,7 @@ def t_fn(_sigma): assert noise is not None x_t = ( (sigma_t / sigma_s * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + (1 - torch.exp(-2.0 * h)) * model_output + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": @@ -514,8 +513,8 @@ def t_fn(_sigma): return -torch.log(_sigma) # YiYi notes: keep these for now so don't get an error, not needed once fully refactored - t, s0 = prev_timestep, timestep_list[-1] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + #t, s0 = prev_timestep, timestep_list[-1] + #alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0, sigma_s1 = ( self.sigmas[self.step_index + 1], @@ -535,8 +534,8 @@ def t_fn(_sigma): elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (torch.exp(-h) - 1.0) * D0 + + ((torch.exp(-h) - 1.0) / h + 1.0) * D1 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations @@ -557,15 +556,15 @@ def t_fn(_sigma): if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + (1 - torch.exp(-2.0 * h)) * D0 + + 0.5 * (1 - torch.exp(-2.0 * h)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + (1 - torch.exp(-2.0 * h)) * D0 + + ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": From a85a18c0ce5c08e56ae6f1f7ab642ccab8c78676 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 22 Aug 2023 02:14:50 +0000 Subject: [PATCH 07/37] v_prediction --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index cd244956f68c..008331fab621 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -388,8 +388,8 @@ def convert_model_output( elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = alpha_t * sample - sigma_t * model_output + sigma = self.sigmas[self.step_index] + x0_pred = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" From f238e0d86280a0db5c4ae29ac6dfd49f5a81ef83 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 11 Sep 2023 01:36:54 +0000 Subject: [PATCH 08/37] remove round --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 008331fab621..5e6869cb0f0c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -235,7 +235,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, last_timestep - 1, num_inference_steps).round()[::-1].copy().astype(np.float32) + timesteps = np.linspace(0, last_timestep - 1, num_inference_steps)[::-1].copy().astype(np.float32) elif self.config.timestep_spacing == "leading": step_ratio = last_timestep // self.num_inference_steps # creates integer timesteps by multiplying by ratio From 7198998f03ce7613f1c2d93fa8333907784a6354 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 11 Sep 2023 22:10:20 +0000 Subject: [PATCH 09/37] try it differently --- .../scheduling_dpmsolver_multistep.py | 149 ++++++++++-------- 1 file changed, 82 insertions(+), 67 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 5e6869cb0f0c..4450e11504e6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -181,6 +181,9 @@ def __init__( self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + # settings for DPM-Solver if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": @@ -203,14 +206,6 @@ def __init__( self.lower_order_nums = 0 self._step_index = None - @property - def init_noise_sigma(self): - # standard deviation of the initial noise distribution - if self.config.timestep_spacing in ["linspace", "trailing"]: - return self.sigmas.max() - - return (self.sigmas.max() ** 2 + 1) ** 0.5 - @property def step_index(self): """ @@ -235,12 +230,14 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, last_timestep - 1, num_inference_steps)[::-1].copy().astype(np.float32) + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int32) + ) elif self.config.timestep_spacing == "leading": - step_ratio = last_timestep // self.num_inference_steps + step_ratio = last_timestep // (num_inference_steps + 1) # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int32) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps @@ -334,6 +331,13 @@ def _sigma_to_t(self, sigma, log_sigmas): t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t + + def _sigma_to_alpha_sigma_t(self, sigma): + + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: @@ -384,12 +388,14 @@ def convert_model_output( if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] - x0_pred = sample - sigma * model_output + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": sigma = self.sigmas[self.step_index] - x0_pred = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" @@ -410,10 +416,12 @@ def convert_model_output( else: epsilon = model_output elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = (sample - alpha_t * model_output) / sigma_t elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = alpha_t * model_output + sigma_t * sample else: raise ValueError( @@ -422,7 +430,8 @@ def convert_model_output( ) if self.config.thresholding: - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * epsilon) / alpha_t x0_pred = self._threshold_sample(x0_pred) epsilon = (sample - alpha_t * x0_pred) / sigma_t @@ -454,24 +463,23 @@ def dpm_solver_first_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - - def t_fn(_sigma): - return -torch.log(_sigma) - - # YiYi notes: keep these for now so don't get an error, don't need once fully refactored - #alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] - h = t_fn(sigma_t) - t_fn(sigma_s) + sigma_t, alpha_t = self._sigma_to_alpha_sigma_t(sigma_t) + sigma_s, alpha_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": - x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None x_t = ( (sigma_t / sigma_s * torch.exp(-h)) * sample - + (1 - torch.exp(-2.0 * h)) * model_output + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": @@ -509,33 +517,38 @@ def multistep_dpm_solver_second_order_update( The sample tensor at the previous timestep. """ - def t_fn(_sigma): - return -torch.log(_sigma) - - # YiYi notes: keep these for now so don't get an error, not needed once fully refactored - #t, s0 = prev_timestep, timestep_list[-1] - #alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0, sigma_s1 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1], ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + m0, m1 = model_output_list[-1], model_output_list[-2] - h, h_0 = t_fn(sigma_t) - t_fn(sigma_s0), t_fn(sigma_s0) - t_fn(sigma_s1) + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": - x_t = (sigma_t / sigma_s0) * sample - (torch.exp(-h) - 1.0) * D0 - 0.5 * (torch.exp(-h) - 1.0) * D1 + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - - (torch.exp(-h) - 1.0) * D0 - + ((torch.exp(-h) - 1.0) / h + 1.0) * D1 + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations @@ -556,15 +569,15 @@ def t_fn(_sigma): if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (1 - torch.exp(-2.0 * h)) * D0 - + 0.5 * (1 - torch.exp(-2.0 * h)) * D1 + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (1 - torch.exp(-2.0 * h)) * D0 - + ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1 + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": @@ -609,16 +622,26 @@ def multistep_dpm_solver_third_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - self.lambda_t[t], - self.lambda_t[s0], - self.lambda_t[s1], - self.lambda_t[s2], + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index +1], + self.sigmas[self.step_index], + self.sigmas[self.step_index -1], + self.sigmas[self.step_index -2] ) - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 @@ -699,7 +722,9 @@ def step( self._init_step_index(timestep) prev_timestep = 0 if self.step_index == len(self.timesteps) - 1 else self.timesteps[self.step_index + 1] - lower_order_final = self.step_index == len(self.timesteps) - 1 + lower_order_final = ( + (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) @@ -742,29 +767,19 @@ def step( return SchedulerOutput(prev_sample=prev_sample) - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input - def scale_model_input( - self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] - ) -> torch.FloatTensor: + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: """ - Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Args: - sample (`torch.FloatTensor`): input sample - timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + sample (`torch.FloatTensor`): + The input sample. Returns: `torch.FloatTensor`: A scaled input sample. """ - if self.step_index is None: - self._init_step_index(timestep) - - sigma = self.sigmas[self.step_index] - - sample = sample / ((sigma**2 + 1) ** 0.5) - - self.is_scale_input_called = True return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise From 67ef0e32f3db352c21e71320e12b78a8eac9b353 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 11 Sep 2023 23:54:31 +0000 Subject: [PATCH 10/37] add print lines --- .../scheduling_dpmsolver_multistep.py | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 4450e11504e6..af8abe742253 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -243,7 +243,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.float32) + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int32) timesteps -= 1 else: raise ValueError( @@ -251,15 +251,26 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + print(f" testing here") + print(f" - timesteps: {timesteps}") + print(f" expected sigma: {[sigmas[t] for t in timesteps]}") + log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + print(f" - sigmas: {sigmas}") if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] self.timesteps = torch.from_numpy(timesteps).to(device) @@ -389,6 +400,8 @@ def convert_model_output( model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + print(f" - original: timestep: {timestep}, alpha_t, {self.alpha_t[timestep]}, sigma_t: {self.sigma_t[timestep]}") + print(f" - current: sigma: {sigma}, alpha_t: {alpha_t}, sigma_t: {sigma_t}") x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output @@ -463,16 +476,24 @@ def dpm_solver_first_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - + print(f" ") + print(f" - 1st order update") + print(f" - timestep: {timestep}") + print(f" - step_index: {self.step_index}") sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + print(f" - sigma_t, sigma_s: {sigma_t}, {sigma_s}") sigma_t, alpha_t = self._sigma_to_alpha_sigma_t(sigma_t) sigma_s, alpha_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + print(f" - sigma_t, sigma_s: {sigma_t}, {sigma_s}") + print(f" - alpha_t, alpha_s: {alpha_t}, {alpha_s}") + print(f" - lambda_t, lambda_s: {lambda_t}, {lambda_s}") h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + print(f" x_t: {x_t[0,0,:3,:3]}") elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output elif self.config.algorithm_type == "sde-dpmsolver++": @@ -530,6 +551,11 @@ def multistep_dpm_solver_second_order_update( lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + print(f" ") + print(f" 2nd order update") + print(f" - sigma_t, sigma_s0, sigma_s1: {sigma_t}, {sigma_s0}, {sigma_s1}") + print(f" - alpha_t, alpha_s0, alpha_s1: {alpha_t},{alpha_s0},{alpha_s1}") + print(f" - lambda_t, lambda_s0, lambda_s1: {lambda_t},{lambda_s0},{lambda_s1}") m0, m1 = model_output_list[-1], model_output_list[-2] @@ -544,6 +570,7 @@ def multistep_dpm_solver_second_order_update( - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) + print(f" - x_t: {x_t[0,0,:3,:3]}") elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample From 86601bc18264a4eaf3ff7a98375a226acd9e113d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 12 Sep 2023 01:34:32 +0000 Subject: [PATCH 11/37] add --- .../schedulers/scheduling_dpmsolver_multistep.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index af8abe742253..0aaa0b27a156 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -263,8 +263,10 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) - - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + else: + sigma_last = (1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0] ** 0.5 + sigmas = np.concatenate([sigmas,[sigma_last]]) self.sigmas = torch.from_numpy(sigmas).to(device=device) # when num_inference_steps == num_train_timesteps, we can end up with @@ -482,8 +484,8 @@ def dpm_solver_first_order_update( print(f" - step_index: {self.step_index}") sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] print(f" - sigma_t, sigma_s: {sigma_t}, {sigma_s}") - sigma_t, alpha_t = self._sigma_to_alpha_sigma_t(sigma_t) - sigma_s, alpha_s = self._sigma_to_alpha_sigma_t(sigma_s) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) print(f" - sigma_t, sigma_s: {sigma_t}, {sigma_s}") From 0445412c2af0dbb164b82fa6bfe79ca2c79539b1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 12 Sep 2023 02:14:53 +0000 Subject: [PATCH 12/37] add --- .../scheduling_dpmsolver_multistep.py | 28 ++++--------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 0aaa0b27a156..21c8c0d8dc7c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -251,22 +251,20 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - print(f" testing here") - print(f" - timesteps: {timesteps}") - print(f" expected sigma: {[sigmas[t] for t in timesteps]}") log_sigmas = np.log(sigmas) - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - print(f" - sigmas: {sigmas}") if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = np.flip(sigmas).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = (1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0] ** 0.5 sigmas = np.concatenate([sigmas,[sigma_last]]) + self.sigmas = torch.from_numpy(sigmas).to(device=device) # when num_inference_steps == num_train_timesteps, we can end up with @@ -402,8 +400,6 @@ def convert_model_output( model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - print(f" - original: timestep: {timestep}, alpha_t, {self.alpha_t[timestep]}, sigma_t: {self.sigma_t[timestep]}") - print(f" - current: sigma: {sigma}, alpha_t: {alpha_t}, sigma_t: {sigma_t}") x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output @@ -478,24 +474,16 @@ def dpm_solver_first_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - print(f" ") - print(f" - 1st order update") - print(f" - timestep: {timestep}") - print(f" - step_index: {self.step_index}") + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] - print(f" - sigma_t, sigma_s: {sigma_t}, {sigma_s}") alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) - print(f" - sigma_t, sigma_s: {sigma_t}, {sigma_s}") - print(f" - alpha_t, alpha_s: {alpha_t}, {alpha_s}") - print(f" - lambda_t, lambda_s: {lambda_t}, {lambda_s}") h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output - print(f" x_t: {x_t[0,0,:3,:3]}") elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output elif self.config.algorithm_type == "sde-dpmsolver++": @@ -553,11 +541,6 @@ def multistep_dpm_solver_second_order_update( lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - print(f" ") - print(f" 2nd order update") - print(f" - sigma_t, sigma_s0, sigma_s1: {sigma_t}, {sigma_s0}, {sigma_s1}") - print(f" - alpha_t, alpha_s0, alpha_s1: {alpha_t},{alpha_s0},{alpha_s1}") - print(f" - lambda_t, lambda_s0, lambda_s1: {lambda_t},{lambda_s0},{lambda_s1}") m0, m1 = model_output_list[-1], model_output_list[-2] @@ -572,7 +555,6 @@ def multistep_dpm_solver_second_order_update( - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) - print(f" - x_t: {x_t[0,0,:3,:3]}") elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample From f687c7dcef1e5377385081c94de2c66a9d8de175 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 12 Sep 2023 02:16:09 +0000 Subject: [PATCH 13/37] style --- .../schedulers/scheduling_dpmsolver_multistep.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 21c8c0d8dc7c..df4d6ecf1558 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -263,10 +263,10 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = (1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0] ** 0.5 - sigmas = np.concatenate([sigmas,[sigma_last]]) - + sigmas = np.concatenate([sigmas, [sigma_last]]) + self.sigmas = torch.from_numpy(sigmas).to(device=device) - + # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. _, unique_indices = np.unique(timesteps, return_index=True) @@ -342,9 +342,8 @@ def _sigma_to_t(self, sigma, log_sigmas): t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t - - def _sigma_to_alpha_sigma_t(self, sigma): + def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = 1 / ((sigma**2 + 1) ** 0.5) sigma_t = sigma * alpha_t @@ -635,10 +634,10 @@ def multistep_dpm_solver_third_order_update( """ sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( - self.sigmas[self.step_index +1], + self.sigmas[self.step_index + 1], self.sigmas[self.step_index], - self.sigmas[self.step_index -1], - self.sigmas[self.step_index -2] + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) From 48a9b1e5d4f2d767ce779ca788a0a65022524f8a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 12 Sep 2023 08:44:41 +0000 Subject: [PATCH 14/37] add --- .../schedulers/scheduling_dpmsolver_multistep.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index df4d6ecf1558..a3e1c40ebf46 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -231,19 +231,19 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( - np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int32) + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = last_timestep // (num_inference_steps + 1) # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int32) + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int32) + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) timesteps -= 1 else: raise ValueError( @@ -679,6 +679,8 @@ def multistep_dpm_solver_third_order_update( def _init_step_index(self, timestep): if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) + print(f" timestep: {timestep}") + print(f" self.timesteps: {self.timesteps}") index_candidates = (self.timesteps == timestep).nonzero() @@ -800,6 +802,7 @@ def add_noise( timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + print(f" - timesteps (add_noise): {timesteps}") alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) From b3bf64424d9c193239a323930e9093d58c8cd614 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 12 Sep 2023 15:02:04 +0000 Subject: [PATCH 15/37] fix --- .../schedulers/scheduling_dpmsolver_multistep.py | 12 ++++++------ tests/schedulers/test_scheduler_dpm_multi.py | 2 ++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index a3e1c40ebf46..d679c76cf040 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -679,21 +679,21 @@ def multistep_dpm_solver_third_order_update( def _init_step_index(self, timestep): if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) - print(f" timestep: {timestep}") - print(f" self.timesteps: {self.timesteps}") index_candidates = (self.timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() else: - step_index = index_candidates[0] + step_index = index_candidates[0].item() - self._step_index = step_index.item() + self._step_index = step_index def step( self, diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 86b24af24095..6f3c818457fa 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -59,6 +59,7 @@ def check_over_configs(self, time_step=0, **config): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = new_scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample @@ -91,6 +92,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs): # copy over dummy past residual (must be after setting timesteps) new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] + time_step = new_scheduler.timesteps[time_step] output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample From c5a6cdb36efbfa38c68e71f8f72bfe38d7678af1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 12 Sep 2023 18:43:31 +0000 Subject: [PATCH 16/37] fix calculation --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d679c76cf040..9b08902f8fc4 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -262,7 +262,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = (1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0] ** 0.5 + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigmas = np.concatenate([sigmas, [sigma_last]]) self.sigmas = torch.from_numpy(sigmas).to(device=device) From 3932182205ba350a015088248c98d585aec26f49 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 12 Sep 2023 21:13:21 +0000 Subject: [PATCH 17/37] update add_noise --- .../scheduling_dpmsolver_multistep.py | 47 ++++++++----------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 9b08902f8fc4..3d11faa91f64 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -200,7 +200,6 @@ def __init__( # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 @@ -251,7 +250,6 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) if self.config.use_karras_sigmas: @@ -263,15 +261,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - sigmas = np.concatenate([sigmas, [sigma_last]]) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) - - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] - self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -794,29 +786,30 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - print(f" - timesteps (add_noise): {timesteps}") - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): From 2f785fbb88a5e5963867afbd8404d2b9d6d26053 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 12 Sep 2023 22:00:45 +0000 Subject: [PATCH 18/37] remove the unused timestep args --- .../scheduling_dpmsolver_multistep.py | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 3d11faa91f64..ee2aaac0d9c6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -356,7 +356,7 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) return sigmas def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + self, model_output: torch.FloatTensor, sample: torch.FloatTensor ) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is @@ -373,8 +373,6 @@ def convert_model_output( Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -443,8 +441,6 @@ def convert_model_output( def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, sample: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: @@ -496,8 +492,6 @@ def dpm_solver_first_order_update( def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, sample: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: @@ -603,8 +597,6 @@ def multistep_dpm_solver_second_order_update( def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, sample: torch.FloatTensor, ) -> torch.FloatTensor: """ @@ -725,7 +717,6 @@ def step( if self.step_index is None: self._init_step_index(timestep) - prev_timestep = 0 if self.step_index == len(self.timesteps) - 1 else self.timesteps[self.step_index + 1] lower_order_final = ( (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) @@ -733,7 +724,7 @@ def step( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output @@ -747,17 +738,15 @@ def step( if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: prev_sample = self.dpm_solver_first_order_update( - model_output, timestep, prev_timestep, sample, noise=noise + model_output, sample, noise=noise ) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[self.step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample, noise=noise + self.model_outputs, sample, noise=noise ) else: - timestep_list = [self.timesteps[self.step_index - 2], self.timesteps[self.step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_third_order_update( - self.model_outputs, timestep_list, prev_timestep, sample + self.model_outputs, sample ) if self.lower_order_nums < self.config.solver_order: From f39cc962ba52e1f4a3ccacb905a4932af659a2f8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 12 Sep 2023 23:36:48 +0000 Subject: [PATCH 19/37] refactor dpm inverse --- .../scheduling_dpmsolver_multistep.py | 16 +- .../scheduling_dpmsolver_multistep_inverse.py | 264 ++++++++++++------ 2 files changed, 185 insertions(+), 95 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index ee2aaac0d9c6..7aebf7577361 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -355,9 +355,7 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - def convert_model_output( - self, model_output: torch.FloatTensor, sample: torch.FloatTensor - ) -> torch.FloatTensor: + def convert_model_output(self, model_output: torch.FloatTensor, sample: torch.FloatTensor) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an @@ -737,17 +735,11 @@ def step( noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update( - model_output, sample, noise=noise - ) + prev_sample = self.dpm_solver_first_order_update(model_output, sample, noise=noise) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - prev_sample = self.multistep_dpm_solver_second_order_update( - self.model_outputs, sample, noise=noise - ) + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample, noise=noise) else: - prev_sample = self.multistep_dpm_solver_third_order_update( - self.model_outputs, sample - ) + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 34639d38a6a2..d139a38349f3 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -203,8 +203,16 @@ def __init__( self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 + self._step_index = None self.use_karras_sigmas = use_karras_sigmas + @property + def step_index(self): + """ + TODO: Nice docstring + """ + return self._step_index + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -244,19 +252,20 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + if self.config.use_karras_sigmas: - log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) + sigmas = np.concatenate([sigmas[:1], sigmas]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_first = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([[sigma_first], sigmas]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] - self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -266,6 +275,11 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc ] * self.config.solver_order self.lower_order_nums = 0 + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + print(f" - timesteps: {self.timesteps}") + print(f" -sigmas: {self.sigmas}") + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -325,6 +339,12 @@ def _sigma_to_t(self, sigma, log_sigmas): t = t.reshape(sigma.shape) return t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" @@ -340,9 +360,7 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) return sigmas # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output - def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor - ) -> torch.FloatTensor: + def convert_model_output(self, model_output: torch.FloatTensor, sample: torch.FloatTensor) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an @@ -358,8 +376,6 @@ def convert_model_output( Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -374,12 +390,14 @@ def convert_model_output( # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -401,10 +419,12 @@ def convert_model_output( else: epsilon = model_output elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = (sample - alpha_t * model_output) / sigma_t elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = alpha_t * model_output + sigma_t * sample else: raise ValueError( @@ -413,18 +433,18 @@ def convert_model_output( ) if self.config.thresholding: - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * epsilon) / alpha_t x0_pred = self._threshold_sample(x0_pred) epsilon = (sample - alpha_t * x0_pred) / sigma_t return epsilon + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, sample: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: @@ -445,25 +465,40 @@ def dpm_solver_first_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - elif "sde" in self.config.algorithm_type: - raise NotImplementedError( - f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + print(" 1st order update") + print(f" x_t: {x_t[0,0,:3,:3]}") return x_t + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, sample: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: @@ -484,11 +519,23 @@ def multistep_dpm_solver_second_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) @@ -520,18 +567,46 @@ def multistep_dpm_solver_second_order_update( - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) - elif "sde" in self.config.algorithm_type: - raise NotImplementedError( - f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." - ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + print(f" 2nd order update") + print(f" - x_t :{x_t[0,0,:3,:3]}") return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, sample: torch.FloatTensor, ) -> torch.FloatTensor: """ @@ -551,16 +626,26 @@ def multistep_dpm_solver_third_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - self.lambda_t[t], - self.lambda_t[s0], - self.lambda_t[s1], - self.lambda_t[s2], + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], ) - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 @@ -585,6 +670,27 @@ def multistep_dpm_solver_third_order_update( ) return x_t + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step def step( self, model_output: torch.FloatTensor, @@ -604,6 +710,8 @@ def step( The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. @@ -618,24 +726,17 @@ def step( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = ( - self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - ) + if self.step_index is None: + self._init_step_index(timestep) + lower_order_final = ( - (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( - (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output @@ -648,23 +749,18 @@ def step( noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update( - model_output, timestep, prev_timestep, sample, noise=noise - ) + prev_sample = self.dpm_solver_first_order_update(model_output, sample, noise=noise) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample, noise=noise - ) + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample, noise=noise) else: - timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_third_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 + # upon completion increase step index by one + self._step_index += 1 + if not return_dict: return (prev_sample,) @@ -686,28 +782,30 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): From b28320e56d52c9f90bf0e8780e99be5cf06b7a3d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Sep 2023 01:57:55 +0000 Subject: [PATCH 20/37] fix last sigma + unique timesteps --- .../scheduling_dpmsolver_multistep_inverse.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index d139a38349f3..4073b1a68af8 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -258,14 +258,17 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) - sigmas = np.concatenate([sigmas[:1], sigmas]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_first = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - sigmas = np.concatenate([[sigma_first], sigmas]).astype(np.float32) + sigmas = np.concatenate([sigmas,sigmas[-1:]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -467,11 +470,16 @@ def dpm_solver_first_order_update( """ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + print(f" 1st order update") + print(f" index: {self.step_index +1, self.step_index}") + print(f" - sigmas: {sigma_t},{sigma_s}") alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) - + print(f" - sigmas: {sigma_t},{sigma_s}") + print(f" - alphas: {alpha_t},{alpha_s}") + print(f" - lambdas: {lambda_t},{lambda_s}") h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output From 8431b9ad7d4d65a2452e68499e91411e6c9a3568 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Sep 2023 04:52:55 +0000 Subject: [PATCH 21/37] add sigma_max --- .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 4073b1a68af8..3ed4351b814f 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -260,7 +260,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc timesteps = timesteps.copy().astype(np.int64) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigmas = np.concatenate([sigmas,sigmas[-1:]]).astype(np.float32) + + sigma_max = ((1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) @@ -280,8 +282,6 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # add an index counter for schedulers that allow duplicated timesteps self._step_index = None - print(f" - timesteps: {self.timesteps}") - print(f" -sigmas: {self.sigmas}") # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: From 3e0826b4c730812a70b7aedf3036fd42395bdefb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Sep 2023 05:00:29 +0000 Subject: [PATCH 22/37] fix test t -> timesteps[t] --- tests/schedulers/test_scheduler_dpm_multi_inverse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/schedulers/test_scheduler_dpm_multi_inverse.py b/tests/schedulers/test_scheduler_dpm_multi_inverse.py index 61a1d82e0f51..2b7f48efe184 100644 --- a/tests/schedulers/test_scheduler_dpm_multi_inverse.py +++ b/tests/schedulers/test_scheduler_dpm_multi_inverse.py @@ -54,6 +54,7 @@ def check_over_configs(self, time_step=0, **config): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample From 3d31065daee4f29c5a93415118cd082394bb425b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Sep 2023 05:51:48 +0000 Subject: [PATCH 23/37] inverse --- .../scheduling_dpmsolver_multistep_inverse.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 3ed4351b814f..dce85cc74a66 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -260,8 +260,10 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc timesteps = timesteps.copy().astype(np.int64) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - - sigma_max = ((1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]) ** 0.5 + + sigma_max = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) @@ -470,16 +472,10 @@ def dpm_solver_first_order_update( """ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] - print(f" 1st order update") - print(f" index: {self.step_index +1, self.step_index}") - print(f" - sigmas: {sigma_t},{sigma_s}") alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) - print(f" - sigmas: {sigma_t},{sigma_s}") - print(f" - alphas: {alpha_t},{alpha_s}") - print(f" - lambdas: {lambda_t},{lambda_s}") h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output @@ -499,8 +495,6 @@ def dpm_solver_first_order_update( - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) - print(" 1st order update") - print(f" x_t: {x_t[0,0,:3,:3]}") return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update @@ -607,8 +601,6 @@ def multistep_dpm_solver_second_order_update( - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) - print(f" 2nd order update") - print(f" - x_t :{x_t[0,0,:3,:3]}") return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update From dd8ec06c040ee2af7bbf9d6819ae05013ff9ec92 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Sep 2023 06:06:03 +0000 Subject: [PATCH 24/37] add --- .../scheduling_dpmsolver_multistep_inverse.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index dce85cc74a66..8d6c134932e9 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -258,13 +258,13 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - - sigma_max = ( - (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] - ) ** 0.5 - sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) + sigma_max = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) From 709095ad91aef6fd4476d72f71b34fb5c1f4d734 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Sep 2023 06:12:35 +0000 Subject: [PATCH 25/37] doc string for step index --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 1e6c64f2cdbb..dbc6eb8054a5 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -208,7 +208,7 @@ def __init__( @property def step_index(self): """ - TODO: Nice docstring + The index counter for current timestep. It will increae 1 after each scheduler step. """ return self._step_index diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index f3cbdd00feef..97131c1916c2 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -209,7 +209,7 @@ def __init__( @property def step_index(self): """ - TODO: Nice docstring + The index counter for current timestep. It will increae 1 after each scheduler step. """ return self._step_index From af5fcd6451b4a8074a48fe3815ff7acb2608bf1c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Sep 2023 06:19:40 +0000 Subject: [PATCH 26/37] fix --- .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 1 + tests/schedulers/test_scheduler_dpm_multi.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 97131c1916c2..bd498a7badad 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -476,6 +476,7 @@ def dpm_solver_first_order_update( alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 6f3c818457fa..9c32a7203edf 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -229,7 +229,7 @@ def test_full_loop_with_karras_and_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2096) < 1e-3 + assert abs(result_mean.item() - 0.2096) < 1e-2 def test_switch(self): # make sure that iterating over schedulers with same config names gives same results From 91283366736ae483e36db129da7c307ee71b66e4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Sep 2023 06:22:48 +0000 Subject: [PATCH 27/37] 2e-3 --- tests/schedulers/test_scheduler_dpm_multi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 9c32a7203edf..65c374370f56 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -229,7 +229,7 @@ def test_full_loop_with_karras_and_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2096) < 1e-2 + assert abs(result_mean.item() - 0.2096) < 2e-3 def test_switch(self): # make sure that iterating over schedulers with same config names gives same results From 7ff1d84e5e2ba6ee6567671e40e90f55cd232766 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Sep 2023 06:32:09 +0000 Subject: [PATCH 28/37] oppos make change in the wrong file --- tests/schedulers/test_scheduler_dpm_multi.py | 2 +- tests/schedulers/test_scheduler_dpm_multi_inverse.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 65c374370f56..6f3c818457fa 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -229,7 +229,7 @@ def test_full_loop_with_karras_and_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2096) < 2e-3 + assert abs(result_mean.item() - 0.2096) < 1e-3 def test_switch(self): # make sure that iterating over schedulers with same config names gives same results diff --git a/tests/schedulers/test_scheduler_dpm_multi_inverse.py b/tests/schedulers/test_scheduler_dpm_multi_inverse.py index 2b7f48efe184..014c901680e3 100644 --- a/tests/schedulers/test_scheduler_dpm_multi_inverse.py +++ b/tests/schedulers/test_scheduler_dpm_multi_inverse.py @@ -223,7 +223,7 @@ def test_full_loop_with_karras_and_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 1.7833) < 1e-3 + assert abs(result_mean.item() - 1.7833) < 2e-3 def test_switch(self): # make sure that iterating over schedulers with same config names gives same results From 160216effaa61e47c28b9464c8b6a27ac66196b1 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 15 Sep 2023 21:28:41 -1000 Subject: [PATCH 29/37] Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index dbc6eb8054a5..ece368748624 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -355,7 +355,10 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - def convert_model_output(self, model_output: torch.FloatTensor, sample: torch.FloatTensor) -> torch.FloatTensor: + def convert_model_output(self, model_output: torch.FloatTensor, *args, sample: torch.FloatTensor = None, **kwargs) -> torch.FloatTensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if timestep is not None: + deprecate("Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`") """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an From 25c04323f26270a31fad6ef79e68e43d6249a94f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 17 Sep 2023 06:08:12 +0000 Subject: [PATCH 30/37] deprecate --- .../scheduling_dpmsolver_multistep.py | 117 ++++++++++++++---- .../scheduling_dpmsolver_multistep_inverse.py | 115 ++++++++++++++--- 2 files changed, 189 insertions(+), 43 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index ece368748624..470a562c36b7 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -21,6 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -355,10 +356,13 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - def convert_model_output(self, model_output: torch.FloatTensor, *args, sample: torch.FloatTensor = None, **kwargs) -> torch.FloatTensor: - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - if timestep is not None: - deprecate("Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`") + def convert_model_output( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an @@ -381,6 +385,18 @@ def convert_model_output(self, model_output: torch.FloatTensor, *args, sample: t `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: @@ -442,8 +458,10 @@ def convert_model_output(self, model_output: torch.FloatTensor, *args, sample: t def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, noise: Optional[torch.FloatTensor] = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the first-order DPMSolver (equivalent to DDIM). @@ -451,10 +469,6 @@ def dpm_solver_first_order_update( Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -462,6 +476,26 @@ def dpm_solver_first_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) @@ -493,8 +527,10 @@ def dpm_solver_first_order_update( def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, noise: Optional[torch.FloatTensor] = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the second-order multistep DPMSolver. @@ -502,10 +538,6 @@ def multistep_dpm_solver_second_order_update( Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -513,6 +545,26 @@ def multistep_dpm_solver_second_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) sigma_t, sigma_s0, sigma_s1 = ( self.sigmas[self.step_index + 1], @@ -598,7 +650,9 @@ def multistep_dpm_solver_second_order_update( def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the third-order multistep DPMSolver. @@ -606,10 +660,6 @@ def multistep_dpm_solver_third_order_update( Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by diffusion process. @@ -618,6 +668,27 @@ def multistep_dpm_solver_third_order_update( The sample tensor at the previous timestep. """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], @@ -725,7 +796,7 @@ def step( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) - model_output = self.convert_model_output(model_output, sample) + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output @@ -738,11 +809,11 @@ def step( noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update(model_output, sample, noise=noise) + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample, noise=noise) + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) else: - prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample) + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index bd498a7badad..7c740234fa40 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -21,6 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -344,6 +345,7 @@ def _sigma_to_t(self, sigma, log_sigmas): t = t.reshape(sigma.shape) return t + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = 1 / ((sigma**2 + 1) ** 0.5) sigma_t = sigma * alpha_t @@ -365,7 +367,13 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) return sigmas # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output - def convert_model_output(self, model_output: torch.FloatTensor, sample: torch.FloatTensor) -> torch.FloatTensor: + def convert_model_output( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an @@ -388,6 +396,18 @@ def convert_model_output(self, model_output: torch.FloatTensor, sample: torch.Fl `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: @@ -450,8 +470,10 @@ def convert_model_output(self, model_output: torch.FloatTensor, sample: torch.Fl def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, noise: Optional[torch.FloatTensor] = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the first-order DPMSolver (equivalent to DDIM). @@ -459,10 +481,6 @@ def dpm_solver_first_order_update( Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -470,6 +488,26 @@ def dpm_solver_first_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) @@ -502,8 +540,10 @@ def dpm_solver_first_order_update( def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, noise: Optional[torch.FloatTensor] = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the second-order multistep DPMSolver. @@ -511,10 +551,6 @@ def multistep_dpm_solver_second_order_update( Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -522,6 +558,26 @@ def multistep_dpm_solver_second_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) sigma_t, sigma_s0, sigma_s1 = ( self.sigmas[self.step_index + 1], @@ -608,7 +664,9 @@ def multistep_dpm_solver_second_order_update( def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the third-order multistep DPMSolver. @@ -616,10 +674,6 @@ def multistep_dpm_solver_third_order_update( Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by diffusion process. @@ -628,6 +682,27 @@ def multistep_dpm_solver_third_order_update( The sample tensor at the previous timestep. """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], @@ -737,7 +812,7 @@ def step( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) - model_output = self.convert_model_output(model_output, sample) + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output @@ -750,11 +825,11 @@ def step( noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update(model_output, sample, noise=noise) + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample, noise=noise) + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) else: - prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample) + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 From 911ca6ad83ba5b6f1732c85d841b7999fdacbef3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 18 Sep 2023 00:46:31 +0000 Subject: [PATCH 31/37] update dpm singlestep --- .../scheduling_dpmsolver_singlestep.py | 289 ++++++++++++++---- 1 file changed, 222 insertions(+), 67 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 060ec363e842..47f747453210 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import logging +from ..utils import logging, deprecate from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -197,6 +197,7 @@ def __init__( self.model_outputs = [None] * solver_order self.sample = None self.order_list = self.get_order_list(num_train_timesteps) + self._step_index = None def get_order_list(self, num_inference_steps: int) -> List[int]: """ @@ -231,6 +232,13 @@ def get_order_list(self, num_inference_steps: int) -> List[int]: elif order == 1: orders = [1] * steps return orders + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ @@ -259,8 +267,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) + sigmas = np.flip(sigmas).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) + self.sigmas = torch.from_numpy(sigmas).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device) self.model_outputs = [None] * self.config.solver_order @@ -273,6 +287,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.register_to_config(lower_order_final=True) self.order_list = self.get_order_list(num_inference_steps) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: @@ -332,6 +349,13 @@ def _sigma_to_t(self, sigma, log_sigmas): t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t + + # Copied from diffusers.schedulers.scheduler_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: @@ -348,7 +372,11 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) return sigmas def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is @@ -365,8 +393,6 @@ def convert_model_output( Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -374,18 +400,32 @@ def convert_model_output( `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned_range"]: model_output = model_output[:, :3] - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -405,11 +445,13 @@ def convert_model_output( model_output = model_output[:, :3] return model_output elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = (sample - alpha_t * model_output) / sigma_t return epsilon elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = alpha_t * model_output + sigma_t * sample return epsilon else: @@ -421,9 +463,9 @@ def convert_model_output( def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the first-order DPMSolver (equivalent to DDIM). @@ -442,9 +484,18 @@ def dpm_solver_first_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output @@ -455,9 +506,9 @@ def dpm_solver_first_order_update( def singlestep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the @@ -477,11 +528,42 @@ def singlestep_dpm_solver_second_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] - alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1] - sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1] + h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m1, (1.0 / r0) * (m0 - m1) @@ -518,9 +600,9 @@ def singlestep_dpm_solver_second_order_update( def singlestep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the @@ -540,16 +622,48 @@ def singlestep_dpm_solver_third_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - self.lambda_t[t], - self.lambda_t[s0], - self.lambda_t[s1], - self.lambda_t[s2], + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], ) - alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2] - sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2] + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m2 @@ -591,10 +705,11 @@ def singlestep_dpm_solver_third_order_update( def singlestep_dpm_solver_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - order: int, + *args, + sample: torch.FloatTensor = None, + + order: int = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the singlestep DPMSolver. @@ -615,19 +730,64 @@ def singlestep_dpm_solver_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing `order` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + if order == 1: - return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample) + return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample) elif order == 2: return self.singlestep_dpm_solver_second_order_update( - model_output_list, timestep_list, prev_timestep, sample + model_output_list, sample=sample ) elif order == 3: return self.singlestep_dpm_solver_third_order_update( - model_output_list, timestep_list, prev_timestep, sample + model_output_list, sample=sample ) else: raise ValueError(f"Order must be 1, 2, 3, got {order}") + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + def step( self, model_output: torch.FloatTensor, @@ -660,21 +820,15 @@ def step( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + if self.step_index is None: + self._init_step_index(timestep) - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output - order = self.order_list[step_index] + order = self.order_list[self.step_index] # For img2img denoising might start with order>1 which is not possible # In this case make sure that the first two steps are both order=1 @@ -685,9 +839,8 @@ def step( if order == 1: self.sample = sample - timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep] prev_sample = self.singlestep_dpm_solver_update( - self.model_outputs, timestep_list, prev_timestep, self.sample, order + self.model_outputs, sample= self.sample, order=order ) if not return_dict: @@ -710,28 +863,30 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise +# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): From 65fe7c30d4e9b0f3fcf12c5481290068cd7473b5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 18 Sep 2023 03:37:37 +0000 Subject: [PATCH 32/37] fix --- .../scheduling_dpmsolver_singlestep.py | 50 +++++++++++-------- tests/schedulers/test_scheduler_dpm_single.py | 31 ++++++++++++ 2 files changed, 60 insertions(+), 21 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 47f747453210..ebf543043aa5 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import logging, deprecate +from ..utils import deprecate, logging from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -232,7 +232,7 @@ def get_order_list(self, num_inference_steps: int) -> List[int]: elif order == 1: orders = [1] * steps return orders - + @property def step_index(self): """ @@ -287,7 +287,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.register_to_config(lower_order_final=True) self.order_list = self.get_order_list(num_inference_steps) - + # add an index counter for schedulers that allow duplicated timesteps self._step_index = None @@ -349,8 +349,8 @@ def _sigma_to_t(self, sigma, log_sigmas): t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t - - # Copied from diffusers.schedulers.scheduler_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = 1 / ((sigma**2 + 1) ** 0.5) sigma_t = sigma * alpha_t @@ -372,9 +372,9 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) return sigmas def convert_model_output( - self, - model_output: torch.FloatTensor, - *args, + self, + model_output: torch.FloatTensor, + *args, sample: torch.FloatTensor = None, **kwargs, ) -> torch.FloatTensor: @@ -491,6 +491,19 @@ def dpm_solver_first_order_update( sample = args[2] else: raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) @@ -661,7 +674,6 @@ def singlestep_dpm_solver_third_order_update( lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 @@ -707,7 +719,6 @@ def singlestep_dpm_solver_update( model_output_list: List[torch.FloatTensor], *args, sample: torch.FloatTensor = None, - order: int = None, **kwargs, ) -> torch.FloatTensor: @@ -755,17 +766,13 @@ def singlestep_dpm_solver_update( "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) - + if order == 1: return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample) elif order == 2: - return self.singlestep_dpm_solver_second_order_update( - model_output_list, sample=sample - ) + return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample) elif order == 3: - return self.singlestep_dpm_solver_third_order_update( - model_output_list, sample=sample - ) + return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) else: raise ValueError(f"Order must be 1, 2, 3, got {order}") @@ -839,9 +846,10 @@ def step( if order == 1: self.sample = sample - prev_sample = self.singlestep_dpm_solver_update( - self.model_outputs, sample= self.sample, order=order - ) + prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order) + + # upon completion increase step index by one + self._step_index += 1 if not return_dict: return (prev_sample,) @@ -863,7 +871,7 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample -# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 66be3d5d00ad..169839e776b1 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -58,6 +58,7 @@ def check_over_configs(self, time_step=0, **config): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample @@ -248,3 +249,33 @@ def test_fp16_support(self): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + time_step_0 = scheduler.timesteps[0] + time_step_1 = scheduler.timesteps[1] + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) From d53056cd923fef542269b53496664d842a85b7c1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 18 Sep 2023 03:40:50 +0000 Subject: [PATCH 33/37] style --- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index ebf543043aa5..a9d7575a161f 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -503,7 +503,7 @@ def dpm_solver_first_order_update( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) + ) sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) From 13729a9a1539b7031643577752c3b07ef323c5dc Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 18 Sep 2023 05:25:18 +0000 Subject: [PATCH 34/37] deis --- .../schedulers/scheduling_deis_multistep.py | 306 +++++++++++++----- tests/schedulers/test_scheduler_deis.py | 1 + 2 files changed, 235 insertions(+), 72 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 95d809575dc4..ebb969823b61 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -22,6 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -186,6 +187,14 @@ def __init__( self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 + self._step_index = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ @@ -228,14 +237,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) + sigmas = np.flip(sigmas).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) - - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] - + self.sigmas = torch.from_numpy(sigmas).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -245,6 +254,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ] * self.config.solver_order self.lower_order_nums = 0 + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -280,8 +292,57 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: return sample + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DEIS algorithm needs. @@ -298,13 +359,26 @@ def convert_model_output( `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -316,7 +390,6 @@ def convert_model_output( x0_pred = self._threshold_sample(x0_pred) if self.config.algorithm_type == "deis": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] return (sample - alpha_t * x0_pred) / sigma_t else: raise NotImplementedError("only support log-rho multistep deis now") @@ -324,9 +397,9 @@ def convert_model_output( def deis_first_order_update( self, model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the first-order DEIS (equivalent to DDIM). @@ -345,9 +418,33 @@ def deis_first_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep] + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s if self.config.algorithm_type == "deis": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output @@ -358,9 +455,9 @@ def deis_first_order_update( def multistep_deis_second_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the second-order multistep DEIS. @@ -368,10 +465,6 @@ def multistep_deis_second_order_update( Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -379,10 +472,38 @@ def multistep_deis_second_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + m0, m1 = model_output_list[-1], model_output_list[-2] - alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1] - sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1] rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 @@ -403,9 +524,9 @@ def ind_fn(t, b, c): def multistep_deis_third_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the third-order multistep DEIS. @@ -413,10 +534,6 @@ def multistep_deis_third_order_update( Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by diffusion process. @@ -424,15 +541,47 @@ def multistep_deis_third_order_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2] - sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2] + rho_t, rho_s0, rho_s1, rho_s2 = ( sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1, - simga_s2 / alpha_s2, + sigma_s2 / alpha_s2, ) if self.config.algorithm_type == "deis": @@ -460,6 +609,25 @@ def ind_fn(t, b, c, d): else: raise NotImplementedError("only support log-rho multistep deis now") + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + def step( self, model_output: torch.FloatTensor, @@ -492,42 +660,34 @@ def step( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + if self.step_index is None: + self._init_step_index(timestep) + lower_order_final = ( - (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( - (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample) + prev_sample = self.deis_first_order_update(model_output, sample=sample) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_deis_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) + prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample) else: - timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_deis_third_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) + prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 + # upon completion increase step index by one + self._step_index += 1 + if not return_dict: return (prev_sample,) @@ -548,28 +708,30 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/tests/schedulers/test_scheduler_deis.py b/tests/schedulers/test_scheduler_deis.py index 8b14601bc982..277aaf26e4f2 100644 --- a/tests/schedulers/test_scheduler_deis.py +++ b/tests/schedulers/test_scheduler_deis.py @@ -51,6 +51,7 @@ def check_over_configs(self, time_step=0, **config): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample From b7223a80e0efeb43d811272b9d852b4ed111eac3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 18 Sep 2023 21:20:32 +0000 Subject: [PATCH 35/37] add unipc --- .../schedulers/scheduling_unipc_multistep.py | 246 +++++++++++++----- 1 file changed, 180 insertions(+), 66 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 1525845db35c..7d394c6ebf96 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -22,10 +22,16 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -38,22 +44,34 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) + class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. @@ -181,6 +199,14 @@ def __init__( self.disable_corrector = disable_corrector self.solver_p = solver_p self.last_sample = None + self._step_index = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ @@ -223,14 +249,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) + sigmas = np.flip(sigmas).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) - - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] - + self.sigmas = torch.from_numpy(sigmas).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -243,6 +269,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic if self.solver_p: self.solver_p.set_timesteps(self.num_inference_steps, device=device) + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -302,6 +331,13 @@ def _sigma_to_t(self, sigma, log_sigmas): t = t.reshape(sigma.shape) return t + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" @@ -317,7 +353,11 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) return sigmas def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: r""" Convert the model output to the corresponding type the UniPC algorithm needs. @@ -334,14 +374,28 @@ def convert_model_output( `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + if self.predict_x0: if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -357,11 +411,9 @@ def convert_model_output( if self.config.prediction_type == "epsilon": return model_output elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t return epsilon elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = alpha_t * model_output + sigma_t * sample return epsilon else: @@ -373,9 +425,10 @@ def convert_model_output( def multistep_uni_p_bh_update( self, model_output: torch.FloatTensor, - prev_timestep: int, - sample: torch.FloatTensor, - order: int, + *args, + sample: torch.FloatTensor = None, + order: int = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. @@ -394,10 +447,27 @@ def multistep_uni_p_bh_update( `torch.FloatTensor`: The sample tensor at the previous timestep. """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) timestep_list = self.timestep_list model_output_list = self.model_outputs - s0, t = self.timestep_list[-1], prev_timestep + s0 = self.timestep_list[-1] m0 = model_output_list[-1] x = sample @@ -405,9 +475,12 @@ def multistep_uni_p_bh_update( x_t = self.solver_p.step(model_output, s0, x).prev_sample return x_t - lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) h = lambda_t - lambda_s0 device = sample.device @@ -415,9 +488,10 @@ def multistep_uni_p_bh_update( rks = [] D1s = [] for i in range(1, order): - si = timestep_list[-(i + 1)] + si = self.step_index - i mi = model_output_list[-(i + 1)] - lambda_si = self.lambda_t[si] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) @@ -481,10 +555,11 @@ def multistep_uni_p_bh_update( def multistep_uni_c_bh_update( self, this_model_output: torch.FloatTensor, - this_timestep: int, - last_sample: torch.FloatTensor, - this_sample: torch.FloatTensor, - order: int, + *args, + last_sample: torch.FloatTensor = None, + this_sample: torch.FloatTensor = None, + order: int = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the UniC (B(h) version). @@ -505,18 +580,43 @@ def multistep_uni_c_bh_update( `torch.FloatTensor`: The corrected sample tensor at the current timestep. """ - timestep_list = self.timestep_list + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + sample = args[3] + else: + raise ValueError(" missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs - s0, t = timestep_list[-1], this_timestep m0 = model_output_list[-1] x = last_sample x_t = this_sample model_t = this_model_output - lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) h = lambda_t - lambda_s0 device = this_sample.device @@ -524,9 +624,10 @@ def multistep_uni_c_bh_update( rks = [] D1s = [] for i in range(1, order): - si = timestep_list[-(i + 1)] + si = self.step_index - i mi = model_output_list[-(i + 1)] - lambda_si = self.lambda_t[si] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) @@ -589,6 +690,25 @@ def multistep_uni_c_bh_update( x_t = x_t.to(x.dtype) return x_t + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + def step( self, model_output: torch.FloatTensor, @@ -616,37 +736,27 @@ def step( tuple is returned where the first element is the sample tensor. """ - if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() + if self.step_index is None: + self._init_step_index(timestep) use_corrector = ( - step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None ) - model_output_convert = self.convert_model_output(model_output, timestep, sample) + model_output_convert = self.convert_model_output(model_output, sample=sample) if use_corrector: sample = self.multistep_uni_c_bh_update( this_model_output=model_output_convert, - this_timestep=timestep, last_sample=self.last_sample, this_sample=sample, order=self.this_order, ) - # now prepare to run the predictor - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.timestep_list[i] = self.timestep_list[i + 1] @@ -655,7 +765,7 @@ def step( self.timestep_list[-1] = timestep if self.config.lower_order_final: - this_order = min(self.config.solver_order, len(self.timesteps) - step_index) + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) else: this_order = self.config.solver_order @@ -665,7 +775,6 @@ def step( self.last_sample = sample prev_sample = self.multistep_uni_p_bh_update( model_output=model_output, # pass the original non-converted model output, in case solver-p is used - prev_timestep=prev_timestep, sample=sample, order=self.this_order, ) @@ -673,6 +782,9 @@ def step( if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 + # upon completion increase step index by one + self._step_index += 1 + if not return_dict: return (prev_sample,) @@ -693,28 +805,30 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): From 02d07d4211a1637b1e930c040c5d877a17ceb78a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 19 Sep 2023 01:41:43 +0000 Subject: [PATCH 36/37] fix --- src/diffusers/schedulers/scheduling_unipc_multistep.py | 9 +++------ tests/schedulers/test_scheduler_unipc.py | 9 +-------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 7d394c6ebf96..ee0d00811fb2 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -71,7 +71,6 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) - class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. @@ -464,7 +463,6 @@ def multistep_uni_p_bh_update( "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) - timestep_list = self.timestep_list model_output_list = self.model_outputs s0 = self.timestep_list[-1] @@ -593,7 +591,7 @@ def multistep_uni_c_bh_update( raise ValueError(" missing`this_sample` as a required keyward argument") if order is None: if len(args) > 3: - sample = args[3] + order = args[3] else: raise ValueError(" missing`order` as a required keyward argument") if this_timestep is not None: @@ -603,7 +601,6 @@ def multistep_uni_c_bh_update( "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) - model_output_list = self.model_outputs m0 = model_output_list[-1] @@ -611,7 +608,7 @@ def multistep_uni_c_bh_update( x_t = this_sample model_t = this_model_output - sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) @@ -624,7 +621,7 @@ def multistep_uni_c_bh_update( rks = [] D1s = [] for i in range(1, order): - si = self.step_index - i + si = self.step_index - (i + 1) mi = model_output_list[-(i + 1)] alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) lambda_si = torch.log(alpha_si) - torch.log(sigma_si) diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 0495f423e5b9..08482fd06b62 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -52,6 +52,7 @@ def check_over_configs(self, time_step=0, **config): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample @@ -241,11 +242,3 @@ def test_fp16_support(self): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 - - def test_unique_timesteps(self, **config): - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - - scheduler.set_timesteps(scheduler.config.num_train_timesteps) - assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps From 466cb5304037fc344a41da8e802d9276cf3ffb48 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 19 Sep 2023 02:05:54 +0000 Subject: [PATCH 37/37] flip sigmas --- src/diffusers/schedulers/scheduling_deis_multistep.py | 3 +-- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 3 +-- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 3 +-- src/diffusers/schedulers/scheduling_unipc_multistep.py | 3 +-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index ebb969823b61..c7a94bce88eb 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -234,10 +234,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = np.flip(timesteps).copy().astype(np.int64) - sigmas = np.flip(sigmas).copy() sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 470a562c36b7..264ee268ae17 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -254,10 +254,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc log_sigmas = np.log(sigmas) if self.config.use_karras_sigmas: + sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = np.flip(timesteps).copy().astype(np.int64) - sigmas = np.flip(sigmas).copy() sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index a9d7575a161f..10f7ab34e0a4 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -264,10 +264,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = np.flip(timesteps).copy().astype(np.int64) - sigmas = np.flip(sigmas).copy() sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index ee0d00811fb2..2dcca2ecaece 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -245,10 +245,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = np.flip(timesteps).copy().astype(np.int64) - sigmas = np.flip(sigmas).copy() sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)