diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 93975a27fc6e..c99eacfdc49b 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -282,17 +282,25 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         self.sigmas = torch.from_numpy(sigmas)
 
+        # when num_inference_steps == num_train_timesteps, we can end up with
+        # duplicates in timesteps.
+        _, unique_indices = np.unique(timesteps, return_index=True)
+        timesteps = timesteps[np.sort(unique_indices)]
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
+
+        self.num_inference_steps = len(timesteps)
+
         self.model_outputs = [None] * self.config.solver_order
         self.sample = None
 
-        if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
+        if not self.config.lower_order_final and self.num_inference_steps % self.config.solver_order != 0:
             logger.warn(
                 "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`."
             )
             self.register_to_config(lower_order_final=True)
 
-        self.order_list = self.get_order_list(num_inference_steps)
+        self.order_list = self.get_order_list(self.num_inference_steps)
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index 66be3d5d00ad..350bf698c4d5 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -248,3 +248,11 @@ def test_fp16_support(self):
         sample = scheduler.step(residual, t, sample).prev_sample
 
         assert sample.dtype == torch.float16
+
+    def test_unique_timesteps(self, **config):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            scheduler.set_timesteps(scheduler.config.num_train_timesteps)
+            assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
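
A minimal standalone sketch (not part of the patch) of the dedup idiom the fix relies on: `np.unique` returns its values in ascending order, which would reverse the schedule, so the patch instead keeps the `return_index` output and re-sorts those indices to restore the original descending timestep order. The sample `timesteps` array below is hypothetical.

```python
import numpy as np

# Hypothetical descending timesteps with duplicates, as can happen when
# num_inference_steps == num_train_timesteps and the linspace is rounded.
timesteps = np.array([999, 999, 998, 997, 997, 996])

# np.unique(return_index=True) gives the index of the first occurrence of
# each unique value; sorting those indices preserves the original order
# instead of np.unique's ascending value order.
_, unique_indices = np.unique(timesteps, return_index=True)
deduped = timesteps[np.sort(unique_indices)]

print(deduped)  # [999 998 997 996] -- duplicates dropped, order kept
assert len(deduped) == len(set(deduped.tolist()))
```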