Skip to content

Commit

Permalink
Fix controlnet guess mode euler (#3571)
Browse files Browse the repository at this point in the history
* Fix guess mode controlnet for euler-like schedulers

* make style

* Co-authored-by: Chanchana Sornsoontorn <off.chanchana@gmail.com>

* Add co author Co-authored-by: Chanchana Sornsoontorn <off.chanchana@gmail.com>

* 2nd try
Co-authored-by: Chanchana Sornsoontorn <off.chanchana@gmail.com>
  • Loading branch information
patrickvonplaten committed May 26, 2023
1 parent 66356e7 commit bf16a97
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,14 +956,15 @@ def __call__(
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
controlnet_latent_model_input = latents
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
controlnet_latent_model_input = latent_model_input
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds

down_block_res_samples, mid_block_res_sample = self.controlnet(
controlnet_latent_model_input,
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1034,14 +1034,15 @@ def __call__(
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
controlnet_latent_model_input = latents
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
controlnet_latent_model_input = latent_model_input
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds

down_block_res_samples, mid_block_res_sample = self.controlnet(
controlnet_latent_model_input,
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1248,16 +1248,18 @@ def __call__(
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
controlnet_latent_model_input = latents
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
controlnet_latent_model_input = latent_model_input
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds

down_block_res_samples, mid_block_res_sample = self.controlnet(
controlnet_latent_model_input,
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image,
Expand Down
34 changes: 34 additions & 0 deletions tests/pipelines/controlnet/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AutoencoderKL,
ControlNetModel,
DDIMScheduler,
EulerDiscreteScheduler,
StableDiffusionControlNetPipeline,
UNet2DConditionModel,
)
Expand Down Expand Up @@ -644,6 +645,39 @@ def test_canny_guess_mode(self):
expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_canny_guess_mode_euler(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device="cpu").manual_seed(0)
prompt = ""
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

output = pipe(
prompt,
image,
generator=generator,
output_type="np",
num_inference_steps=3,
guidance_scale=3.0,
guess_mode=True,
)

image = output.images[0]
assert image.shape == (768, 512, 3)

image_slice = image[-3:, -3:, -1]
expected_slice = np.array([0.1655, 0.1721, 0.1623, 0.1685, 0.1711, 0.1646, 0.1651, 0.1631, 0.1494])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@require_torch_2
def test_stable_diffusion_compile(self):
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
Expand Down

0 comments on commit bf16a97

Please sign in to comment.