From 7f5ece9f094058a91f5921a0d51ac834c53c6bcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Thu, 29 Aug 2024 04:50:13 -0400 Subject: [PATCH 1/2] initial proposal --- src/diffusers/callbacks.py | 59 ++++++++++++++++++- .../controlnet/pipeline_controlnet_sd_xl.py | 2 + 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py index 38542407e31f..10d8c44d85af 100644 --- a/src/diffusers/callbacks.py +++ b/src/diffusers/callbacks.py @@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s class SDXLCFGCutoffCallback(PipelineCallback): """ - Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or - `cutoff_step_index`), this callback will disable the CFG. + Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` + or `cutoff_step_index`), this callback will disable the CFG. Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. """ - tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"] + tensor_inputs = [ + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + ] def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: cutoff_step_ratio = self.config.cutoff_step_ratio @@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s callback_kwargs[self.tensor_inputs[0]] = prompt_embeds callback_kwargs[self.tensor_inputs[1]] = add_text_embeds callback_kwargs[self.tensor_inputs[2]] = add_time_ids + + return callback_kwargs + + +class SDXLControlnetCFGCutoffCallback(PipelineCallback): + """ + Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` + or `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = [ + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + "image", + ] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + add_text_embeds = callback_kwargs[self.tensor_inputs[1]] + add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens + + add_time_ids = callback_kwargs[self.tensor_inputs[2]] + add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector + + # For Controlnet + image = callback_kwargs[self.tensor_inputs[3]] + image = image[-1:] + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + callback_kwargs[self.tensor_inputs[1]] = add_text_embeds + callback_kwargs[self.tensor_inputs[2]] = add_time_ids + callback_kwargs[self.tensor_inputs[3]] = image + return callback_kwargs diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index fdebcdf83641..e2a49a4311d6 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline( "add_time_ids", "negative_pooled_prompt_embeds", "negative_add_time_ids", + "image", ] def __init__( @@ -1532,6 +1533,7 @@ def __call__( ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + image = callback_outputs.pop("image", image) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): From adeb1cc87b5c00ef795e445828ba1f790f8cc579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Thu, 29 Aug 2024 05:11:12 -0400 Subject: [PATCH 2/2] style --- src/diffusers/callbacks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py index 10d8c44d85af..4b8b15368c47 100644 --- a/src/diffusers/callbacks.py +++ b/src/diffusers/callbacks.py @@ -97,8 +97,8 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s class SDXLCFGCutoffCallback(PipelineCallback): """ - Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` - or `cutoff_step_index`), this callback will disable the CFG. + Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by + `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG. Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. """ @@ -139,8 +139,8 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s class SDXLControlnetCFGCutoffCallback(PipelineCallback): """ - Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` - or `cutoff_step_index`), this callback will disable the CFG. + Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by + `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG. Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. """