From d6dce5ee1c9add6a9be9fae9d158c3c435609862 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 14 Dec 2023 19:45:25 +0100 Subject: [PATCH 1/7] Initial commit --- .../controlnet/pipeline_controlnet_sd_xl.py | 30 ++++++++ .../pipeline_controlnet_xs_sd_xl.py | 77 ++++++++++++++++--- 2 files changed, 98 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 0e7920708184..f3150d74bc53 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -907,6 +907,10 @@ def do_classifier_free_guidance(self): def cross_attention_kwargs(self): return self._cross_attention_kwargs + @property + def denoising_end(self): + return self._denoising_end + @property def num_timesteps(self): return self._num_timesteps @@ -921,6 +925,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, @@ -979,6 +984,13 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -1134,6 +1146,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1307,6 +1320,23 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 58f0f544a5ac..bd6c8ff33835 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -729,6 +729,39 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.guidance_scale + @property + def guidance_scale(self): + return self._guidance_scale + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.clip_skip + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.do_classifier_free_guidance + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.cross_attention_kwargs + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.denoising_end + @property + def denoising_end(self): + return self._denoising_end + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -739,6 +772,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, @@ -794,6 +828,13 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -908,6 +949,11 @@ def __call__( control_guidance_end, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -917,10 +963,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( @@ -936,7 +978,7 @@ def __call__( prompt_2, device, num_images_per_prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, negative_prompt, negative_prompt_2, prompt_embeds=prompt_embeds, @@ -957,7 +999,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, ) height, width = image.shape[-2:] else: @@ -1015,7 +1057,7 @@ def __call__( else: negative_add_time_ids = add_time_ids - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) @@ -1026,6 +1068,23 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") @@ -1036,7 +1095,7 @@ def __call__( if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} @@ -1068,7 +1127,7 @@ def __call__( ).sample # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) From 89f11c86f016f0734dc1f6a205844c143e8652eb Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 14 Dec 2023 20:33:46 +0100 Subject: [PATCH 2/7] Removed copy hints, as in original SDXLControlNetPipeline Removed copy hints, as in original SDXLControlNetPipeline, as the `make fix-copies` seems to have issues with the @property decorator. --- .../pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index bd6c8ff33835..df886a109588 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -729,12 +729,10 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.guidance_scale @property def guidance_scale(self): return self._guidance_scale - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.clip_skip @property def clip_skip(self): return self._clip_skip @@ -742,22 +740,18 @@ def clip_skip(self): # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.do_classifier_free_guidance @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.cross_attention_kwargs @property def cross_attention_kwargs(self): return self._cross_attention_kwargs - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.denoising_end @property def denoising_end(self): return self._denoising_end - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps @property def num_timesteps(self): return self._num_timesteps From 262ed814d3762cf640232f7ac072cf4439ee3ff0 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 4 Mar 2024 23:02:40 +0100 Subject: [PATCH 3/7] Reverted changes to ControlNetXS --- .../pipeline_controlnet_xs_sd_xl.py | 43 ++++--------------- 1 file changed, 9 insertions(+), 34 deletions(-) diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py index 17d1566a8404..fbfe84410316 100644 --- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py +++ b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py @@ -641,7 +641,6 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, - denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, @@ -697,13 +696,6 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - denoising_end (`float`, *optional*): - When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be - completed before it is intentionally prematurely terminated. As a result, the returned sample will - still retain a substantial amount of noise as determined by the discrete timesteps selected by the - scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a - "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image - Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -818,10 +810,6 @@ def __call__( control_guidance_end, ) - self._guidance_scale = guidance_scale - self._clip_skip = clip_skip - self._cross_attention_kwargs = cross_attention_kwargs - self._denoising_end = denoising_end # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -832,7 +820,11 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None @@ -847,7 +839,7 @@ def __call__( prompt_2, device, num_images_per_prompt, - self.do_classifier_free_guidance, + do_classifier_free_guidance, negative_prompt, negative_prompt_2, prompt_embeds=prompt_embeds, @@ -868,7 +860,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, + do_classifier_free_guidance=do_classifier_free_guidance, ) height, width = image.shape[-2:] else: @@ -926,7 +918,7 @@ def __call__( else: negative_add_time_ids = add_time_ids - if self.do_classifier_free_guidance: + if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) @@ -937,23 +929,6 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - - # 8.1 Apply denoising_end - if ( - self.denoising_end is not None - and isinstance(self.denoising_end, float) - and self.denoising_end > 0 - and self.denoising_end < 1 - ): - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) - ) - ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] - is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") @@ -964,7 +939,7 @@ def __call__( if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} From 04a9d8706d244780f708174f4dadd26d659039c0 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 4 Mar 2024 23:04:53 +0100 Subject: [PATCH 4/7] Addendum to: Removed changes to ControlNetXS --- .../controlnetxs/pipeline_controlnet_xs_sd_xl.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py index fbfe84410316..d0186573fa9c 100644 --- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py +++ b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py @@ -810,7 +810,6 @@ def __call__( control_guidance_end, ) - # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -824,7 +823,7 @@ def __call__( # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - + # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None @@ -971,7 +970,7 @@ def __call__( ).sample # perform guidance - if self.do_classifier_free_guidance: + if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) From abecbd2f98b6fb1214439bf1f34e4a2dcbb993a4 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:14:10 +0100 Subject: [PATCH 5/7] Added test+docs for mixture of denoiser --- docs/source/en/using-diffusers/controlnet.md | 21 ++++ .../controlnet/test_controlnet_sdxl.py | 107 ++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md index 849bb838ac63..fd4ddf218796 100644 --- a/docs/source/en/using-diffusers/controlnet.md +++ b/docs/source/en/using-diffusers/controlnet.md @@ -429,6 +429,27 @@ image = pipe( make_image_grid([original_image, canny_image, image], rows=1, cols=3) ``` + + +You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improve image quality, just like you can with a regular `StableDiffusionXLPipeline`. +See [section `Refine image quality` on the `StableDiffusionXL` page](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality) for how to do it. +Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`. + +``` +base = StableDiffusionXLControlNetPipeline(...) +image = base( + prompt=prompt, + controlnet_conditioning_scale=0.5, + image=canny_image, + num_inference_steps=40, + denoising_end=0.8, + output_type="latent", +).images +# rest exactly as with StableDiffusionXLPipeline +``` + + + ## MultiControlNet diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index b39147246a74..c82ce6c39cca 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import gc import unittest @@ -24,8 +25,10 @@ AutoencoderKL, ControlNetModel, EulerDiscreteScheduler, + HeunDiscreteScheduler, LCMScheduler, StableDiffusionXLControlNetPipeline, + StableDiffusionXLImg2ImgPipeline, UNet2DConditionModel, ) from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D @@ -364,6 +367,110 @@ def test_controlnet_sdxl_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + # copied from test_stable_diffusion_xl.py:test_stable_diffusion_two_xl_mixture_of_denoiser_fast + # with `StableDiffusionXLControlNetPipeline` instead of `StableDiffusionXLPipeline` + def test_controlnet_sdxl_two_mixture_of_denoiser_fast(self): + components = self.get_dummy_components() + pipe_1 = StableDiffusionXLControlNetPipeline(**components).to(torch_device) + pipe_1.unet.set_default_attn_processor() + + components_without_controlnet = {k: v for k, v in components.items() if k != "controlnet"} + pipe_2 = StableDiffusionXLImg2ImgPipeline(**components_without_controlnet).to(torch_device) + pipe_2.unet.set_default_attn_processor() + + def assert_run_mixture( + num_steps, + split, + scheduler_cls_orig, + expected_tss, + num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps, + ): + inputs = self.get_dummy_inputs(torch_device) + inputs["num_inference_steps"] = num_steps + + class scheduler_cls(scheduler_cls_orig): + pass + + pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config) + pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config) + + # Let's retrieve the number of timesteps we want to use + pipe_1.scheduler.set_timesteps(num_steps) + expected_steps = pipe_1.scheduler.timesteps.tolist() + + if pipe_1.scheduler.order == 2: + expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss)) + expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss)) + expected_steps = expected_steps_1 + expected_steps_2 + else: + expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss)) + expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss)) + + # now we monkey patch step `done_steps` + # list into the step function for testing + done_steps = [] + old_step = copy.copy(scheduler_cls.step) + + def new_step(self, *args, **kwargs): + done_steps.append(args[1].cpu().item()) # args[1] is always the passed `t` + return old_step(self, *args, **kwargs) + + scheduler_cls.step = new_step + + inputs_1 = { + **inputs, + **{ + "denoising_end": 1.0 - (split / num_train_timesteps), + "output_type": "latent", + }, + } + latents = pipe_1(**inputs_1).images[0] + + assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}" + + inputs_2 = { + **inputs, + **{ + "denoising_start": 1.0 - (split / num_train_timesteps), + "image": latents, + }, + } + pipe_2(**inputs_2).images[0] + + assert expected_steps_2 == done_steps[len(expected_steps_1) :] + assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}" + + steps = 10 + for split in [300, 700]: + for scheduler_cls_timesteps in [ + (EulerDiscreteScheduler, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]), + ( + HeunDiscreteScheduler, + [ + 901.0, + 801.0, + 801.0, + 701.0, + 701.0, + 601.0, + 601.0, + 501.0, + 501.0, + 401.0, + 401.0, + 301.0, + 301.0, + 201.0, + 201.0, + 101.0, + 101.0, + 1.0, + 1.0, + ], + ), + ]: + assert_run_mixture(steps, split, scheduler_cls_timesteps[0], scheduler_cls_timesteps[1]) + class StableDiffusionXLMultiControlNetPipelineFastTests( PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase From 1e92ee6c31e1a07d782935c9af46384dd602dce6 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 6 Mar 2024 20:25:47 +0100 Subject: [PATCH 6/7] Update docs/source/en/using-diffusers/controlnet.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/controlnet.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md index fd4ddf218796..bf92ac58200d 100644 --- a/docs/source/en/using-diffusers/controlnet.md +++ b/docs/source/en/using-diffusers/controlnet.md @@ -432,7 +432,7 @@ make_image_grid([original_image, canny_image, image], rows=1, cols=3) You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improve image quality, just like you can with a regular `StableDiffusionXLPipeline`. -See [section `Refine image quality` on the `StableDiffusionXL` page](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality) for how to do it. +See the [Refine image quality](./sdxl#refine-image-quality) section to learn how to use the refiner model. Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`. ``` From 1e033a77fc7d2ef1c64b3e8ee7e84d83d1e7327c Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 6 Mar 2024 20:25:54 +0100 Subject: [PATCH 7/7] Update docs/source/en/using-diffusers/controlnet.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/controlnet.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md index bf92ac58200d..2a1295d14d04 100644 --- a/docs/source/en/using-diffusers/controlnet.md +++ b/docs/source/en/using-diffusers/controlnet.md @@ -435,7 +435,7 @@ You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improv See the [Refine image quality](./sdxl#refine-image-quality) section to learn how to use the refiner model. Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`. -``` +```py base = StableDiffusionXLControlNetPipeline(...) image = base( prompt=prompt,