From d13114edcde48808e66c5b3b2ee563c88b18a962 Mon Sep 17 00:00:00 2001 From: Lengyue Date: Wed, 27 Sep 2023 06:19:05 -0400 Subject: [PATCH 1/6] Update Unipc einsum to support 1D and 3D diffusion. --- src/diffusers/schedulers/scheduling_unipc_multistep.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 2b5bd4fd60db..18d95fe514cc 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -534,14 +534,14 @@ def multistep_uni_p_bh_update( if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: - pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - alpha_t * B_h * pred_res else: x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 if D1s is not None: - pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - sigma_t * B_h * pred_res @@ -670,7 +670,7 @@ def multistep_uni_c_bh_update( if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: - corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) else: corr_res = 0 D1_t = model_t - m0 @@ -678,7 +678,7 @@ def multistep_uni_c_bh_update( else: x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 if D1s is not None: - corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) else: corr_res = 0 D1_t = model_t - m0 From 8c22c8b454edf09d7e404199ebea739191eb0ca1 Mon Sep 17 00:00:00 2001 From: Lengyue Date: Wed, 27 Sep 2023 06:29:06 -0400 Subject: [PATCH 2/6] Add unittest --- tests/schedulers/test_scheduler_unipc.py | 26 ++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 08482fd06b62..dbff0ab911a2 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -242,3 +242,29 @@ def test_fp16_support(self): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 + + +class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest): + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + width = 8 + + sample = torch.rand((batch_size, num_channels, width)) + + return sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(3, 0, 1, 2) + + return sample From cd340b3d926a443650c9d7bf31d3300039e52d10 Mon Sep 17 00:00:00 2001 From: Lengyue Date: Wed, 27 Sep 2023 06:50:00 -0400 Subject: [PATCH 3/6] Update unittest & edge case --- .../schedulers/scheduling_unipc_multistep.py | 6 +-- tests/schedulers/test_scheduler_unipc.py | 45 ++++++++++++++++++- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 18d95fe514cc..d61341cee72d 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -282,13 +282,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype - batch_size, channels, height, width = sample.shape + batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * height * width) + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) abs_sample = sample.abs() # "a certain percentile absolute pixel value" @@ -300,7 +300,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - sample = sample.reshape(batch_size, channels, height, width) + sample = sample.reshape(batch_size, channels, *remaining_dims) sample = sample.to(dtype) return sample diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index dbff0ab911a2..9c847ee40741 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -265,6 +265,49 @@ def dummy_sample_deter(self): sample = torch.arange(num_elems) sample = sample.reshape(num_channels, width, batch_size) sample = sample / num_elems - sample = sample.permute(3, 0, 1, 2) + sample = sample.permute(2, 0, 1) return sample + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler = UniPCMultistepScheduler(**self.get_scheduler_config()) + sample = self.full_loop(scheduler=scheduler) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config) + scheduler = DEISMultistepScheduler.from_config(scheduler.config) + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + scheduler = UniPCMultistepScheduler.from_config(scheduler.config) + + sample = self.full_loop(scheduler=scheduler) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_with_karras(self): + sample = self.full_loop(use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2898) < 1e-3 + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.1014) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.1944) < 1e-3 From e61af7ad3dfa78b01d3c8aeb83c63c7ae723788b Mon Sep 17 00:00:00 2001 From: Lengyue Date: Thu, 28 Sep 2023 12:47:05 -0400 Subject: [PATCH 4/6] Fix unittest --- src/diffusers/utils/testing_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 2998f7dc429e..3fa1027cdf38 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -69,7 +69,8 @@ ) from e logger.info(f"torch_device overrode to {torch_device}") else: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" + # torch_device = "cuda" if torch.cuda.is_available() else "cpu" + torch_device = "cpu" is_torch_higher_equal_than_1_12 = version.parse( version.parse(torch.__version__).base_version ) >= version.parse("1.12") From 1024f14a7867e75b776e76ce585febd995802d6d Mon Sep 17 00:00:00 2001 From: Leng Yue Date: Mon, 2 Oct 2023 02:57:37 -0700 Subject: [PATCH 5/6] Fix testing_utils.py --- src/diffusers/utils/testing_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 3fa1027cdf38..2998f7dc429e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -69,8 +69,7 @@ ) from e logger.info(f"torch_device overrode to {torch_device}") else: - # torch_device = "cuda" if torch.cuda.is_available() else "cpu" - torch_device = "cpu" + torch_device = "cuda" if torch.cuda.is_available() else "cpu" is_torch_higher_equal_than_1_12 = version.parse( version.parse(torch.__version__).base_version ) >= version.parse("1.12") From f311608721d5249f0039b7914b9792b05dea3373 Mon Sep 17 00:00:00 2001 From: Lengyue Date: Mon, 2 Oct 2023 06:06:08 -0400 Subject: [PATCH 6/6] Fix unittest file --- tests/schedulers/test_scheduler_unipc.py | 42 ++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 76153d2c0d47..be41cea95b67 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -270,6 +270,7 @@ def test_full_loop_with_noise(self): assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" + class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest): @property def dummy_sample(self): @@ -281,6 +282,20 @@ def dummy_sample(self): return sample + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems).flip(-1) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + + return sample + @property def dummy_sample_deter(self): batch_size = 4 @@ -337,3 +352,30 @@ def test_full_loop_with_karras_and_v_prediction(self): result_mean = torch.mean(torch.abs(sample)) assert abs(result_mean.item() - 0.1944) < 1e-3 + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + t_start = 8 + + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + # add noise + noise = self.dummy_noise_deter + timesteps = scheduler.timesteps[t_start * scheduler.order :] + sample = scheduler.add_noise(sample, noise, timesteps[:1]) + + for i, t in enumerate(timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}" + assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}"