diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md index 849bb838ac63..2a1295d14d04 100644 --- a/docs/source/en/using-diffusers/controlnet.md +++ b/docs/source/en/using-diffusers/controlnet.md @@ -429,6 +429,27 @@ image = pipe( make_image_grid([original_image, canny_image, image], rows=1, cols=3) ``` + + +You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improve image quality, just like you can with a regular `StableDiffusionXLPipeline`. +See the [Refine image quality](./sdxl#refine-image-quality) section to learn how to use the refiner model. +Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`. + +```py +base = StableDiffusionXLControlNetPipeline(...) +image = base( + prompt=prompt, + controlnet_conditioning_scale=0.5, + image=canny_image, + num_inference_steps=40, + denoising_end=0.8, + output_type="latent", +).images +# rest exactly as with StableDiffusionXLPipeline +``` + + + ## MultiControlNet diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 0b611350a6f1..6bb7f5b6fdac 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -916,6 +916,10 @@ def do_classifier_free_guidance(self): def cross_attention_kwargs(self): return self._cross_attention_kwargs + @property + def denoising_end(self): + return self._denoising_end + @property def num_timesteps(self): return self._num_timesteps @@ -930,6 +934,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, @@ -989,6 +994,13 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -1151,6 +1163,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1325,6 +1338,23 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index b39147246a74..c82ce6c39cca 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import gc import unittest @@ -24,8 +25,10 @@ AutoencoderKL, ControlNetModel, EulerDiscreteScheduler, + HeunDiscreteScheduler, LCMScheduler, StableDiffusionXLControlNetPipeline, + StableDiffusionXLImg2ImgPipeline, UNet2DConditionModel, ) from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D @@ -364,6 +367,110 @@ def test_controlnet_sdxl_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + # copied from test_stable_diffusion_xl.py:test_stable_diffusion_two_xl_mixture_of_denoiser_fast + # with `StableDiffusionXLControlNetPipeline` instead of `StableDiffusionXLPipeline` + def test_controlnet_sdxl_two_mixture_of_denoiser_fast(self): + components = self.get_dummy_components() + pipe_1 = StableDiffusionXLControlNetPipeline(**components).to(torch_device) + pipe_1.unet.set_default_attn_processor() + + components_without_controlnet = {k: v for k, v in components.items() if k != "controlnet"} + pipe_2 = StableDiffusionXLImg2ImgPipeline(**components_without_controlnet).to(torch_device) + pipe_2.unet.set_default_attn_processor() + + def assert_run_mixture( + num_steps, + split, + scheduler_cls_orig, + expected_tss, + num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps, + ): + inputs = self.get_dummy_inputs(torch_device) + inputs["num_inference_steps"] = num_steps + + class scheduler_cls(scheduler_cls_orig): + pass + + pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config) + pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config) + + # Let's retrieve the number of timesteps we want to use + pipe_1.scheduler.set_timesteps(num_steps) + expected_steps = pipe_1.scheduler.timesteps.tolist() + + if pipe_1.scheduler.order == 2: + expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss)) + expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss)) + expected_steps = expected_steps_1 + expected_steps_2 + else: + expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss)) + expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss)) + + # now we monkey patch step `done_steps` + # list into the step function for testing + done_steps = [] + old_step = copy.copy(scheduler_cls.step) + + def new_step(self, *args, **kwargs): + done_steps.append(args[1].cpu().item()) # args[1] is always the passed `t` + return old_step(self, *args, **kwargs) + + scheduler_cls.step = new_step + + inputs_1 = { + **inputs, + **{ + "denoising_end": 1.0 - (split / num_train_timesteps), + "output_type": "latent", + }, + } + latents = pipe_1(**inputs_1).images[0] + + assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}" + + inputs_2 = { + **inputs, + **{ + "denoising_start": 1.0 - (split / num_train_timesteps), + "image": latents, + }, + } + pipe_2(**inputs_2).images[0] + + assert expected_steps_2 == done_steps[len(expected_steps_1) :] + assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}" + + steps = 10 + for split in [300, 700]: + for scheduler_cls_timesteps in [ + (EulerDiscreteScheduler, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]), + ( + HeunDiscreteScheduler, + [ + 901.0, + 801.0, + 801.0, + 701.0, + 701.0, + 601.0, + 601.0, + 501.0, + 501.0, + 401.0, + 401.0, + 301.0, + 301.0, + 201.0, + 201.0, + 101.0, + 101.0, + 1.0, + 1.0, + ], + ), + ]: + assert_run_mixture(steps, split, scheduler_cls_timesteps[0], scheduler_cls_timesteps[1]) + class StableDiffusionXLMultiControlNetPipelineFastTests( PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase