diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index e7ba0ba1f30e..5f11eb9ac71b 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -24,7 +24,8 @@
 from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
-
+from ..utils import logging
+logger = logging.get_logger(__name__)
 
 if is_scipy_available():
     import scipy.stats
@@ -411,29 +412,34 @@ def set_timesteps(
         if self.config.use_karras_sigmas:
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-            if self.config.beta_schedule != "squaredcos_cap_v2":
-                timesteps = timesteps.round()
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
+
         elif self.config.use_lu_lambdas:
             lambdas = np.flip(log_sigmas.copy())
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-            if self.config.beta_schedule != "squaredcos_cap_v2":
-                timesteps = timesteps.round()
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
+
         elif self.config.use_exponential_sigmas:
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
+
         elif self.config.use_beta_sigmas:
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
+
         elif self.config.use_flow_sigmas:
             alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
             sigmas = 1.0 - alphas
             sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
             timesteps = (sigmas * self.config.num_train_timesteps).copy()
+
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -544,6 +550,38 @@ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         t = t.reshape(sigma.shape)
         return t
 
+    def _ensure_unique_timesteps(self, timesteps, sigmas, num_inference_steps):
+        """
+        Ensure timesteps are unique and handle duplicates while preserving the correspondence with sigmas.
+
+        Args:
+            timesteps (`np.ndarray`):
+                The timestep values that may contain duplicates.
+            sigmas (`np.ndarray`):
+                The sigma values corresponding to the timesteps.
+            num_inference_steps (`int`):
+                The number of inference steps originally requested.
+
+        Returns:
+            `Tuple[np.ndarray, np.ndarray]`:
+                A tuple of (timesteps, sigmas) where timesteps are unique and sigmas are filtered accordingly.
+        """
+        unique_timesteps, unique_indices = np.unique(timesteps, return_index=True)
+
+        if len(unique_timesteps) < len(timesteps):
+            # Sort by original indices to maintain order
+            unique_indices_sorted = np.sort(unique_indices)
+            timesteps = timesteps[unique_indices_sorted]
+            sigmas = sigmas[unique_indices_sorted]
+
+        if len(timesteps) < num_inference_steps:
+            logger.warning(
+                f"Due to the current scheduler configuration, only {len(timesteps)} unique timesteps "
+                f"could be generated instead of the requested {num_inference_steps}."
+            )
+
+        return timesteps, sigmas
+
     def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Convert sigma values to alpha_t and sigma_t values.
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 28c354709dc9..3f1a60f43fc1 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -366,3 +366,31 @@ def test_beta_sigmas(self):
 
     def test_exponential_sigmas(self):
         self.check_over_configs(use_exponential_sigmas=True)
+
+    def test_no_duplicate_timesteps_with_sigma_methods(self):
+        sigma_configs = [
+            {"use_karras_sigmas": True},
+            {"use_lu_lambdas": True},
+            {"use_exponential_sigmas": True},
+            {"use_beta_sigmas": True},
+        ]
+
+        for config in sigma_configs:
+            scheduler = DPMSolverMultistepScheduler(
+                num_train_timesteps=1000,
+                beta_schedule="squaredcos_cap_v2",
+                **config,
+            )
+            scheduler.set_timesteps(20)
+
+            sample = torch.randn(4, 3, 32, 32)
+
+            try:
+                for t in scheduler.timesteps:
+                    model_output = torch.randn_like(sample)
+                    output = scheduler.step(model_output, t, sample)
+                    sample = output.prev_sample
+            except IndexError as e:
+                self.fail(f"Index error occurred with config {config}: {e}")
+            except Exception as e:
+                self.fail(f"Unexpected error with config {config}: {e}")
\ No newline at end of file
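
Note on the failure mode this patch addresses. With beta_schedule="squaredcos_cap_v2", rounding the continuous values returned by _sigma_to_t can map several consecutive sigmas to the same integer timestep; the multistep solver then walks its step index past the end of the sigma array and raises the IndexError the test above guards against. The helper deduplicates via np.unique(..., return_index=True), which returns first-occurrence indices; re-sorting those indices restores the original descending order. Below is a minimal NumPy-only sketch of that idiom, independent of the scheduler; the timestep and sigma values are made up for illustration and are not output of the actual code.

import numpy as np

# Hypothetical rounded timesteps, descending, with 998.0 and 21.0 each
# appearing twice after rounding (the duplicate case the helper handles).
timesteps = np.array([998.0, 998.0, 750.0, 500.0, 250.0, 21.0, 21.0, 3.0])
sigmas = np.linspace(12.0, 0.03, len(timesteps))  # stand-in sigmas, one per timestep

# np.unique returns the sorted unique values plus the index of each value's
# first occurrence; sorting those indices recovers the original order.
unique_timesteps, unique_indices = np.unique(timesteps, return_index=True)
if len(unique_timesteps) < len(timesteps):
    keep = np.sort(unique_indices)
    timesteps, sigmas = timesteps[keep], sigmas[keep]

print(timesteps)      # [998. 750. 500. 250.  21.   3.]
print(sigmas.shape)   # (6,): the sigma/timestep pairing survives the filter

The same first-occurrence-then-sort step is what _ensure_unique_timesteps applies before warning that fewer than num_inference_steps unique timesteps remain.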