diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 8c25cdff8a07..5e6869cb0f0c 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -181,9 +181,6 @@ def __init__(
         self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
         self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
 
-        # standard deviation of the initial noise distribution
-        self.init_noise_sigma = 1.0
-
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
             if algorithm_type == "deis":
@@ -200,9 +197,26 @@ def __init__(
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+
         self.timesteps = torch.from_numpy(timesteps)
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
+        self._step_index = None
+
+    @property
+    def init_noise_sigma(self):
+        # standard deviation of the initial noise distribution
+        if self.config.timestep_spacing in ["linspace", "trailing"]:
+            return self.sigmas.max()
+
+        return (self.sigmas.max() ** 2 + 1) ** 0.5
+
+    @property
+    def step_index(self):
+        """
+        Index of the current timestep in the schedule; `None` until initialized, then incremented after each `step`.
+        """
+        return self._step_index
 
     def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         """
@@ -221,20 +235,18 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
 
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
-            timesteps = (
-                np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
-            )
+            timesteps = np.linspace(0, last_timestep - 1, num_inference_steps)[::-1].copy().astype(np.float32)
         elif self.config.timestep_spacing == "leading":
-            step_ratio = last_timestep // (num_inference_steps + 1)
+            step_ratio = last_timestep // num_inference_steps
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
             timesteps += self.config.steps_offset
         elif self.config.timestep_spacing == "trailing":
             step_ratio = self.config.num_train_timesteps / num_inference_steps
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
+            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.float32)
             timesteps -= 1
         else:
             raise ValueError(
@@ -242,18 +254,15 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
             )
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
+        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
         if self.config.use_karras_sigmas:
-            log_sigmas = np.log(sigmas)
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            timesteps = np.flip(timesteps).copy().astype(np.int64)
-
-        self.sigmas = torch.from_numpy(sigmas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
-        # when num_inference_steps == num_train_timesteps, we can end up with
-        # duplicates in timesteps.
-        _, unique_indices = np.unique(timesteps, return_index=True)
-        timesteps = timesteps[np.sort(unique_indices)]
+        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+        self.sigmas = torch.from_numpy(sigmas).to(device=device)
 
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
@@ -264,6 +273,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
         ] * self.config.solver_order
         self.lower_order_nums = 0
 
+        # add an index counter for schedulers that allow duplicated timesteps
+        self._step_index = None
+
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         """
@@ -371,13 +383,13 @@ def convert_model_output(
                 # DPM-Solver and DPM-Solver++ only need the "mean" output.
                 if self.config.variance_type in ["learned", "learned_range"]:
                     model_output = model_output[:, :3]
-                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
-                x0_pred = (sample - sigma_t * model_output) / alpha_t
+                sigma = self.sigmas[self.step_index]
+                x0_pred = sample - sigma * model_output
             elif self.config.prediction_type == "sample":
                 x0_pred = model_output
             elif self.config.prediction_type == "v_prediction":
-                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
-                x0_pred = alpha_t * sample - sigma_t * model_output
+                sigma = self.sigmas[self.step_index]
+                x0_pred = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
             else:
                 raise ValueError(
                     f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
@@ -442,19 +454,24 @@ def dpm_solver_first_order_update(
             `torch.FloatTensor`:
                 The sample tensor at the previous timestep.
         """
-        lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
-        alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
-        sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
-        h = lambda_t - lambda_s
+
+        def t_fn(_sigma):
+            return -torch.log(_sigma)
+
+        # YiYi notes: kept for reference while refactoring; the "dpmsolver" branch below still reads alpha_t/alpha_s
+        # alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
+
+        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+        h = t_fn(sigma_t) - t_fn(sigma_s)
         if self.config.algorithm_type == "dpmsolver++":
-            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
+            x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output
         elif self.config.algorithm_type == "dpmsolver":
             x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
         elif self.config.algorithm_type == "sde-dpmsolver++":
             assert noise is not None
             x_t = (
                 (sigma_t / sigma_s * torch.exp(-h)) * sample
-                + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+                + (1 - torch.exp(-2.0 * h)) * model_output
                 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
             )
         elif self.config.algorithm_type == "sde-dpmsolver":
@@ -491,27 +508,34 @@ def multistep_dpm_solver_second_order_update(
             `torch.FloatTensor`:
                 The sample tensor at the previous timestep.
         """
-        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+
+        def t_fn(_sigma):
+            return -torch.log(_sigma)
+
+        # YiYi notes: kept for reference while refactoring, not needed once fully refactored
+        # t, s0 = prev_timestep, timestep_list[-1]
+        # alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+
+        sigma_t, sigma_s0, sigma_s1 = (
+            self.sigmas[self.step_index + 1],
+            self.sigmas[self.step_index],
+            self.sigmas[self.step_index - 1],
+        )
         m0, m1 = model_output_list[-1], model_output_list[-2]
-        lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
-        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
-        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
-        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+
+        h, h_0 = t_fn(sigma_t) - t_fn(sigma_s0), t_fn(sigma_s0) - t_fn(sigma_s1)
         r0 = h_0 / h
         D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+
         if self.config.algorithm_type == "dpmsolver++":
             # See https://arxiv.org/abs/2211.01095 for detailed derivations
             if self.config.solver_type == "midpoint":
-                x_t = (
-                    (sigma_t / sigma_s0) * sample
-                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
-                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
-                )
+                x_t = (sigma_t / sigma_s0) * sample - (torch.exp(-h) - 1.0) * D0 - 0.5 * (torch.exp(-h) - 1.0) * D1
             elif self.config.solver_type == "heun":
                 x_t = (
                     (sigma_t / sigma_s0) * sample
-                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
-                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+                    - (torch.exp(-h) - 1.0) * D0
+                    + ((torch.exp(-h) - 1.0) / h + 1.0) * D1
                 )
         elif self.config.algorithm_type == "dpmsolver":
             # See https://arxiv.org/abs/2206.00927 for detailed derivations
@@ -532,15 +556,15 @@ def multistep_dpm_solver_second_order_update(
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (sigma_t / sigma_s0 * torch.exp(-h)) * sample
-                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
-                    + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+                    + (1 - torch.exp(-2.0 * h)) * D0
+                    + 0.5 * (1 - torch.exp(-2.0 * h)) * D1
                     + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                 )
             elif self.config.solver_type == "heun":
                 x_t = (
                     (sigma_t / sigma_s0 * torch.exp(-h)) * sample
-                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
-                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+                    + (1 - torch.exp(-2.0 * h)) * D0
+                    + ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1
                     + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                 )
         elif self.config.algorithm_type == "sde-dpmsolver":
@@ -619,6 +643,26 @@ def multistep_dpm_solver_third_order_update(
             )
         return x_t
 
+    def _init_step_index(self, timestep):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+
+        index_candidates = (self.timesteps == timestep).nonzero()
+
+        if len(index_candidates) == 0:
+            # `timestep` is not part of the current schedule: fall back to the last index
+            step_index = len(self.timesteps) - 1
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        elif len(index_candidates) > 1:
+            step_index = index_candidates[1].item()
+        else:
+            step_index = index_candidates[0].item()
+
+        self._step_index = step_index
+
     def step(
         self,
         model_output: torch.FloatTensor,
@@ -654,19 +698,13 @@ def step(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
 
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-        step_index = (self.timesteps == timestep).nonzero()
-        if len(step_index) == 0:
-            step_index = len(self.timesteps) - 1
-        else:
-            step_index = step_index.item()
-        prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
-        lower_order_final = (
-            (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
-        )
+        if self.step_index is None:
+            self._init_step_index(timestep)
+
+        prev_timestep = 0 if self.step_index == len(self.timesteps) - 1 else self.timesteps[self.step_index + 1]
+        lower_order_final = self.step_index == len(self.timesteps) - 1
         lower_order_second = (
-            (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
         )
 
         model_output = self.convert_model_output(model_output, timestep, sample)
@@ -686,12 +724,12 @@ def step(
                 model_output, timestep, prev_timestep, sample, noise=noise
             )
         elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
-            timestep_list = [self.timesteps[step_index - 1], timestep]
+            timestep_list = [self.timesteps[self.step_index - 1], timestep]
             prev_sample = self.multistep_dpm_solver_second_order_update(
                 self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
             )
         else:
-            timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
+            timestep_list = [self.timesteps[self.step_index - 2], self.timesteps[self.step_index - 1], timestep]
             prev_sample = self.multistep_dpm_solver_third_order_update(
                 self.model_outputs, timestep_list, prev_timestep, sample
             )
@@ -699,24 +737,37 @@ def step(
         if self.lower_order_nums < self.config.solver_order:
             self.lower_order_nums += 1
 
+        # upon completion increase step index by one
+        self._step_index += 1
+
         if not return_dict:
             return (prev_sample,)
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input
+    def scale_model_input(
+        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+    ) -> torch.FloatTensor:
         """
-        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
-        current timestep.
+        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
 
         Args:
-            sample (`torch.FloatTensor`):
-                The input sample.
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
 
         Returns:
             `torch.FloatTensor`:
                 A scaled input sample.
         """
+        if self.step_index is None:
+            self._init_step_index(timestep)
+
+        sigma = self.sigmas[self.step_index]
+
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+
+        self.is_scale_input_called = True
         return sample
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index c9935780b983..86b24af24095 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -264,10 +264,10 @@ def test_fp16_support(self):
 
         assert sample.dtype == torch.float16
 
-    def test_unique_timesteps(self, **config):
+    def test_duplicated_timesteps(self, **config):
         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config(**config)
             scheduler = scheduler_class(**scheduler_config)
 
             scheduler.set_timesteps(scheduler.config.num_train_timesteps)
-            assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
+            assert len(scheduler.timesteps) == scheduler.num_inference_steps