diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md
index 849bb838ac63..2a1295d14d04 100644
--- a/docs/source/en/using-diffusers/controlnet.md
+++ b/docs/source/en/using-diffusers/controlnet.md
@@ -429,6 +429,27 @@ image = pipe(
make_image_grid([original_image, canny_image, image], rows=1, cols=3)
```
+
+
+You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improve image quality, just like you can with a regular `StableDiffusionXLPipeline`.
+See the [Refine image quality](./sdxl#refine-image-quality) section to learn how to use the refiner model.
+Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`.
+
+```py
+base = StableDiffusionXLControlNetPipeline(...)
+image = base(
+ prompt=prompt,
+ controlnet_conditioning_scale=0.5,
+ image=canny_image,
+ num_inference_steps=40,
+ denoising_end=0.8,
+ output_type="latent",
+).images
+# rest exactly as with StableDiffusionXLPipeline
+```
+
+
+
## MultiControlNet
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 0b611350a6f1..6bb7f5b6fdac 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -916,6 +916,10 @@ def do_classifier_free_guidance(self):
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
@property
def num_timesteps(self):
return self._num_timesteps
@@ -930,6 +934,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -989,6 +994,13 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -1151,6 +1163,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1325,6 +1338,23 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index b39147246a74..c82ce6c39cca 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import gc
import unittest
@@ -24,8 +25,10 @@
AutoencoderKL,
ControlNetModel,
EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
LCMScheduler,
StableDiffusionXLControlNetPipeline,
+ StableDiffusionXLImg2ImgPipeline,
UNet2DConditionModel,
)
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
@@ -364,6 +367,110 @@ def test_controlnet_sdxl_lcm(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ # copied from test_stable_diffusion_xl.py:test_stable_diffusion_two_xl_mixture_of_denoiser_fast
+ # with `StableDiffusionXLControlNetPipeline` instead of `StableDiffusionXLPipeline`
+ def test_controlnet_sdxl_two_mixture_of_denoiser_fast(self):
+ components = self.get_dummy_components()
+ pipe_1 = StableDiffusionXLControlNetPipeline(**components).to(torch_device)
+ pipe_1.unet.set_default_attn_processor()
+
+ components_without_controlnet = {k: v for k, v in components.items() if k != "controlnet"}
+ pipe_2 = StableDiffusionXLImg2ImgPipeline(**components_without_controlnet).to(torch_device)
+ pipe_2.unet.set_default_attn_processor()
+
+ def assert_run_mixture(
+ num_steps,
+ split,
+ scheduler_cls_orig,
+ expected_tss,
+ num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps,
+ ):
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["num_inference_steps"] = num_steps
+
+ class scheduler_cls(scheduler_cls_orig):
+ pass
+
+ pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
+ pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
+
+ # Let's retrieve the number of timesteps we want to use
+ pipe_1.scheduler.set_timesteps(num_steps)
+ expected_steps = pipe_1.scheduler.timesteps.tolist()
+
+ if pipe_1.scheduler.order == 2:
+ expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+ expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
+ expected_steps = expected_steps_1 + expected_steps_2
+ else:
+ expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+ expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
+
+ # now we monkey patch step `done_steps`
+ # list into the step function for testing
+ done_steps = []
+ old_step = copy.copy(scheduler_cls.step)
+
+ def new_step(self, *args, **kwargs):
+ done_steps.append(args[1].cpu().item()) # args[1] is always the passed `t`
+ return old_step(self, *args, **kwargs)
+
+ scheduler_cls.step = new_step
+
+ inputs_1 = {
+ **inputs,
+ **{
+ "denoising_end": 1.0 - (split / num_train_timesteps),
+ "output_type": "latent",
+ },
+ }
+ latents = pipe_1(**inputs_1).images[0]
+
+ assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+ inputs_2 = {
+ **inputs,
+ **{
+ "denoising_start": 1.0 - (split / num_train_timesteps),
+ "image": latents,
+ },
+ }
+ pipe_2(**inputs_2).images[0]
+
+ assert expected_steps_2 == done_steps[len(expected_steps_1) :]
+ assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+ steps = 10
+ for split in [300, 700]:
+ for scheduler_cls_timesteps in [
+ (EulerDiscreteScheduler, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]),
+ (
+ HeunDiscreteScheduler,
+ [
+ 901.0,
+ 801.0,
+ 801.0,
+ 701.0,
+ 701.0,
+ 601.0,
+ 601.0,
+ 501.0,
+ 501.0,
+ 401.0,
+ 401.0,
+ 301.0,
+ 301.0,
+ 201.0,
+ 201.0,
+ 101.0,
+ 101.0,
+ 1.0,
+ 1.0,
+ ],
+ ),
+ ]:
+ assert_run_mixture(steps, split, scheduler_cls_timesteps[0], scheduler_cls_timesteps[1])
+
class StableDiffusionXLMultiControlNetPipelineFastTests(
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase