From d6dce5ee1c9add6a9be9fae9d158c3c435609862 Mon Sep 17 00:00:00 2001
From: UmerHA <40663591+UmerHA@users.noreply.github.com>
Date: Thu, 14 Dec 2023 19:45:25 +0100
Subject: [PATCH 1/7] Initial commit
---
.../controlnet/pipeline_controlnet_sd_xl.py | 30 ++++++++
.../pipeline_controlnet_xs_sd_xl.py | 77 ++++++++++++++++---
2 files changed, 98 insertions(+), 9 deletions(-)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 0e7920708184..f3150d74bc53 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -907,6 +907,10 @@ def do_classifier_free_guidance(self):
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
@property
def num_timesteps(self):
return self._num_timesteps
@@ -921,6 +925,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -979,6 +984,13 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -1134,6 +1146,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1307,6 +1320,23 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
index 58f0f544a5ac..bd6c8ff33835 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -729,6 +729,39 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.guidance_scale
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.clip_skip
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.do_classifier_free_guidance
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.cross_attention_kwargs
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.denoising_end
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -739,6 +772,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -794,6 +828,13 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -908,6 +949,11 @@ def __call__(
control_guidance_end,
)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -917,10 +963,6 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
@@ -936,7 +978,7 @@ def __call__(
prompt_2,
device,
num_images_per_prompt,
- do_classifier_free_guidance,
+ self.do_classifier_free_guidance,
negative_prompt,
negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -957,7 +999,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=do_classifier_free_guidance,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
)
height, width = image.shape[-2:]
else:
@@ -1015,7 +1057,7 @@ def __call__(
else:
negative_add_time_ids = add_time_ids
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -1026,6 +1068,23 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
@@ -1036,7 +1095,7 @@ def __call__(
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
@@ -1068,7 +1127,7 @@ def __call__(
).sample
# perform guidance
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
From 89f11c86f016f0734dc1f6a205844c143e8652eb Mon Sep 17 00:00:00 2001
From: UmerHA <40663591+UmerHA@users.noreply.github.com>
Date: Thu, 14 Dec 2023 20:33:46 +0100
Subject: [PATCH 2/7] Removed copy hints, as in original SDXLControlNetPipeline
Removed the `# Copied from` hints, matching the original SDXLControlNetPipeline, because `make fix-copies` seems to have issues with the `@property` decorator.
---
.../pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
index bd6c8ff33835..df886a109588 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -729,12 +729,10 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.guidance_scale
@property
def guidance_scale(self):
return self._guidance_scale
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.clip_skip
@property
def clip_skip(self):
return self._clip_skip
@@ -742,22 +740,18 @@ def clip_skip(self):
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.do_classifier_free_guidance
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.cross_attention_kwargs
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.denoising_end
@property
def denoising_end(self):
return self._denoising_end
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
@property
def num_timesteps(self):
return self._num_timesteps
From 262ed814d3762cf640232f7ac072cf4439ee3ff0 Mon Sep 17 00:00:00 2001
From: UmerHA <40663591+UmerHA@users.noreply.github.com>
Date: Mon, 4 Mar 2024 23:02:40 +0100
Subject: [PATCH 3/7] Reverted changes to ControlNetXS
---
.../pipeline_controlnet_xs_sd_xl.py | 43 ++++---------------
1 file changed, 9 insertions(+), 34 deletions(-)
diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
index 17d1566a8404..fbfe84410316 100644
--- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
+++ b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
@@ -641,7 +641,6 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
- denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -697,13 +696,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- denoising_end (`float`, *optional*):
- When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
- completed before it is intentionally prematurely terminated. As a result, the returned sample will
- still retain a substantial amount of noise as determined by the discrete timesteps selected by the
- scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
- "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -818,10 +810,6 @@ def __call__(
control_guidance_end,
)
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
- self._denoising_end = denoising_end
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -832,7 +820,11 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
-
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
@@ -847,7 +839,7 @@ def __call__(
prompt_2,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -868,7 +860,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
)
height, width = image.shape[-2:]
else:
@@ -926,7 +918,7 @@ def __call__(
else:
negative_add_time_ids = add_time_ids
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -937,23 +929,6 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-
- # 8.1 Apply denoising_end
- if (
- self.denoising_end is not None
- and isinstance(self.denoising_end, float)
- and self.denoising_end > 0
- and self.denoising_end < 1
- ):
- discrete_timestep_cutoff = int(
- round(
- self.scheduler.config.num_train_timesteps
- - (self.denoising_end * self.scheduler.config.num_train_timesteps)
- )
- )
- num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
- timesteps = timesteps[:num_inference_steps]
-
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
@@ -964,7 +939,7 @@ def __call__(
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
From 04a9d8706d244780f708174f4dadd26d659039c0 Mon Sep 17 00:00:00 2001
From: UmerHA <40663591+UmerHA@users.noreply.github.com>
Date: Mon, 4 Mar 2024 23:04:53 +0100
Subject: [PATCH 4/7] Addendum to: Removed changes to ControlNetXS
---
.../controlnetxs/pipeline_controlnet_xs_sd_xl.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
index fbfe84410316..d0186573fa9c 100644
--- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
+++ b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
@@ -810,7 +810,6 @@ def __call__(
control_guidance_end,
)
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -824,7 +823,7 @@ def __call__(
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
-
+
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
@@ -971,7 +970,7 @@ def __call__(
).sample
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
From abecbd2f98b6fb1214439bf1f34e4a2dcbb993a4 Mon Sep 17 00:00:00 2001
From: UmerHA <40663591+UmerHA@users.noreply.github.com>
Date: Wed, 6 Mar 2024 15:14:10 +0100
Subject: [PATCH 5/7] Added test+docs for mixture of denoiser
---
docs/source/en/using-diffusers/controlnet.md | 21 ++++
.../controlnet/test_controlnet_sdxl.py | 107 ++++++++++++++++++
2 files changed, 128 insertions(+)
diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md
index 849bb838ac63..fd4ddf218796 100644
--- a/docs/source/en/using-diffusers/controlnet.md
+++ b/docs/source/en/using-diffusers/controlnet.md
@@ -429,6 +429,27 @@ image = pipe(
make_image_grid([original_image, canny_image, image], rows=1, cols=3)
```
+
+
+You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improve image quality, just like you can with a regular `StableDiffusionXLPipeline`.
+See [section `Refine image quality` on the `StableDiffusionXL` page](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality) for how to do it.
+Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`.
+
+```
+base = StableDiffusionXLControlNetPipeline(...)
+image = base(
+ prompt=prompt,
+ controlnet_conditioning_scale=0.5,
+ image=canny_image,
+ num_inference_steps=40,
+ denoising_end=0.8,
+ output_type="latent",
+).images
+# the rest is exactly the same as with StableDiffusionXLPipeline
+```
+
+
+
## MultiControlNet
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index b39147246a74..c82ce6c39cca 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import gc
import unittest
@@ -24,8 +25,10 @@
AutoencoderKL,
ControlNetModel,
EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
LCMScheduler,
StableDiffusionXLControlNetPipeline,
+ StableDiffusionXLImg2ImgPipeline,
UNet2DConditionModel,
)
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
@@ -364,6 +367,110 @@ def test_controlnet_sdxl_lcm(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ # copied from test_stable_diffusion_xl.py:test_stable_diffusion_two_xl_mixture_of_denoiser_fast
+ # with `StableDiffusionXLControlNetPipeline` instead of `StableDiffusionXLPipeline`
+ def test_controlnet_sdxl_two_mixture_of_denoiser_fast(self):
+ components = self.get_dummy_components()
+ pipe_1 = StableDiffusionXLControlNetPipeline(**components).to(torch_device)
+ pipe_1.unet.set_default_attn_processor()
+
+ components_without_controlnet = {k: v for k, v in components.items() if k != "controlnet"}
+ pipe_2 = StableDiffusionXLImg2ImgPipeline(**components_without_controlnet).to(torch_device)
+ pipe_2.unet.set_default_attn_processor()
+
+ def assert_run_mixture(
+ num_steps,
+ split,
+ scheduler_cls_orig,
+ expected_tss,
+ num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps,
+ ):
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["num_inference_steps"] = num_steps
+
+ class scheduler_cls(scheduler_cls_orig):
+ pass
+
+ pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
+ pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
+
+ # Let's retrieve the number of timesteps we want to use
+ pipe_1.scheduler.set_timesteps(num_steps)
+ expected_steps = pipe_1.scheduler.timesteps.tolist()
+
+ if pipe_1.scheduler.order == 2:
+ expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+ expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
+ expected_steps = expected_steps_1 + expected_steps_2
+ else:
+ expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+ expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
+
+ # now we monkey-patch the scheduler's `step` function so that every timestep
+ # passed to it is appended to the `done_steps` list for testing
+ done_steps = []
+ old_step = copy.copy(scheduler_cls.step)
+
+ def new_step(self, *args, **kwargs):
+ done_steps.append(args[1].cpu().item()) # args[1] is always the passed `t`
+ return old_step(self, *args, **kwargs)
+
+ scheduler_cls.step = new_step
+
+ inputs_1 = {
+ **inputs,
+ **{
+ "denoising_end": 1.0 - (split / num_train_timesteps),
+ "output_type": "latent",
+ },
+ }
+ latents = pipe_1(**inputs_1).images[0]
+
+ assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+ inputs_2 = {
+ **inputs,
+ **{
+ "denoising_start": 1.0 - (split / num_train_timesteps),
+ "image": latents,
+ },
+ }
+ pipe_2(**inputs_2).images[0]
+
+ assert expected_steps_2 == done_steps[len(expected_steps_1) :]
+ assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+ steps = 10
+ for split in [300, 700]:
+ for scheduler_cls_timesteps in [
+ (EulerDiscreteScheduler, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]),
+ (
+ HeunDiscreteScheduler,
+ [
+ 901.0,
+ 801.0,
+ 801.0,
+ 701.0,
+ 701.0,
+ 601.0,
+ 601.0,
+ 501.0,
+ 501.0,
+ 401.0,
+ 401.0,
+ 301.0,
+ 301.0,
+ 201.0,
+ 201.0,
+ 101.0,
+ 101.0,
+ 1.0,
+ 1.0,
+ ],
+ ),
+ ]:
+ assert_run_mixture(steps, split, scheduler_cls_timesteps[0], scheduler_cls_timesteps[1])
+
class StableDiffusionXLMultiControlNetPipelineFastTests(
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
From 1e92ee6c31e1a07d782935c9af46384dd602dce6 Mon Sep 17 00:00:00 2001
From: UmerHA <40663591+UmerHA@users.noreply.github.com>
Date: Wed, 6 Mar 2024 20:25:47 +0100
Subject: [PATCH 6/7] Update docs/source/en/using-diffusers/controlnet.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
---
docs/source/en/using-diffusers/controlnet.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md
index fd4ddf218796..bf92ac58200d 100644
--- a/docs/source/en/using-diffusers/controlnet.md
+++ b/docs/source/en/using-diffusers/controlnet.md
@@ -432,7 +432,7 @@ make_image_grid([original_image, canny_image, image], rows=1, cols=3)
You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improve image quality, just like you can with a regular `StableDiffusionXLPipeline`.
-See [section `Refine image quality` on the `StableDiffusionXL` page](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality) for how to do it.
+See the [Refine image quality](./sdxl#refine-image-quality) section to learn how to use the refiner model.
Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`.
```
From 1e033a77fc7d2ef1c64b3e8ee7e84d83d1e7327c Mon Sep 17 00:00:00 2001
From: UmerHA <40663591+UmerHA@users.noreply.github.com>
Date: Wed, 6 Mar 2024 20:25:54 +0100
Subject: [PATCH 7/7] Update docs/source/en/using-diffusers/controlnet.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
---
docs/source/en/using-diffusers/controlnet.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md
index bf92ac58200d..2a1295d14d04 100644
--- a/docs/source/en/using-diffusers/controlnet.md
+++ b/docs/source/en/using-diffusers/controlnet.md
@@ -435,7 +435,7 @@ You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improv
See the [Refine image quality](./sdxl#refine-image-quality) section to learn how to use the refiner model.
Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`.
-```
+```py
base = StableDiffusionXLControlNetPipeline(...)
image = base(
prompt=prompt,