diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py
index 38542407e31f..4b8b15368c47 100644
--- a/src/diffusers/callbacks.py
+++ b/src/diffusers/callbacks.py
@@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
 
 
 class SDXLCFGCutoffCallback(PipelineCallback):
     """
-    Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
-    `cutoff_step_index`), this callback will disable the CFG.
+    Callback function for the base Stable Diffusion XL Pipelines. After a certain number of steps (set by
+    `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
 
     Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
     """
 
-    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
+    tensor_inputs = [
+        "prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+    ]
 
     def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
         cutoff_step_ratio = self.config.cutoff_step_ratio
@@ -129,6 +133,63 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
             callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
             callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
             callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+
+        return callback_kwargs
+
+
+class SDXLControlnetCFGCutoffCallback(PipelineCallback):
+    """
+    Callback function for the ControlNet Stable Diffusion XL Pipelines. After a certain number of steps (set by
+    `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG. At the cutoff step it keeps
+    only the last batch entry of the text/time embeddings and of the ControlNet conditioning `image`.
+
+    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+    """
+
+    tensor_inputs = [
+        "prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "image",
+    ]
+
+    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+        """
+        Disable CFG once `step_index` reaches the cutoff step and return the updated `callback_kwargs`.
+        """
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
+        if step_index == cutoff_step:
+            prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+            prompt_embeds = prompt_embeds[-1:]  # "-1" denotes the embeddings for conditional text tokens.
+
+            add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
+            add_text_embeds = add_text_embeds[-1:]  # "-1" denotes the embeddings for conditional pooled text tokens
+
+            add_time_ids = callback_kwargs[self.tensor_inputs[2]]
+            add_time_ids = add_time_ids[-1:]  # "-1" denotes the embeddings for conditional added time vector
+
+            # For ControlNet: `image` may be a list of conditioning tensors (one per ControlNet; assumed from
+            # multi-ControlNet usage), so slice the batch of every entry instead of slicing the list itself.
+            image = callback_kwargs[self.tensor_inputs[3]]
+            if isinstance(image, (list, tuple)):
+                image = [cond[-1:] for cond in image]
+            else:
+                image = image[-1:]
+
+            pipeline._guidance_scale = 0.0
+
+            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+            callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+            callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+            callback_kwargs[self.tensor_inputs[3]] = image
+
         return callback_kwargs
 
 
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 0f3a15172843..7a9433e1d357 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline(
         "add_time_ids",
         "negative_pooled_prompt_embeds",
         "negative_add_time_ids",
+        "image",
     ]
 
     def __init__(
@@ -1540,6 +1541,7 @@ def __call__(
                     )
                     add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                     negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+                    image = callback_outputs.pop("image", image)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):