From 264f7d7195062852a7b0c9a9ab3914154d3853b5 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 24 Jul 2023 18:42:47 +0530 Subject: [PATCH 1/3] Fix an error when DPM Solver Single Step contains duplicate timesteps, based on https://github.com/huggingface/diffusers/pull/2969 --- .../schedulers/scheduling_dpmsolver_singlestep.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 93975a27fc6e..096e10387617 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -282,7 +282,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.sigmas = torch.from_numpy(sigmas) + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + self.model_outputs = [None] * self.config.solver_order self.sample = None From 20188c7fc9b7b42880c3d6eab675122c63bf67be Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 24 Jul 2023 18:50:46 +0530 Subject: [PATCH 2/3] Use the updated inference steps --- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 096e10387617..c99eacfdc49b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -294,13 +294,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.model_outputs = [None] * self.config.solver_order self.sample = None - if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: + if not self.config.lower_order_final and self.num_inference_steps % self.config.solver_order != 0: logger.warn( "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`." ) self.register_to_config(lower_order_final=True) - self.order_list = self.get_order_list(num_inference_steps) + self.order_list = self.get_order_list(self.num_inference_steps) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: From 53bdf0e086279fc64271472d284ad268a2eaaf68 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 24 Jul 2023 19:20:40 +0530 Subject: [PATCH 3/3] Add a test for unique timesteps in DPM Single --- tests/schedulers/test_scheduler_dpm_single.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 66be3d5d00ad..350bf698c4d5 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -248,3 +248,11 @@ def test_fp16_support(self): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 + + def test_unique_timesteps(self, **config): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps