Commit 2551b73
Fix bug in ControlNetPipelines with MultiControlNetModel of length 1 (#4032)
* Fix bug in ControlNetPipelines with MultiControlNetModel of length 1
* Add tests for varying number of ControlNet models
* Fix missing indexing for control_guidance_start and control_guidance_end
* Fix code quality
* Separate test for MultiControlNet with one model
* Revert formatting of earlier test
1 parent 930c8fd commit 2551b73
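
The change itself is one line per pipeline: when `controlnet_keep` is built, a length check is replaced with a type check. Under the old condition, a `MultiControlNetModel` wrapping a single ControlNet produced a length-1 `keeps` list that was collapsed to a bare float, while the multi-net code path expects one entry per wrapped net. A minimal sketch of the two conditions, using lightweight stand-in classes instead of the real diffusers models:

```python
# Stand-ins for diffusers' ControlNetModel / MultiControlNetModel, used only
# to illustrate the changed condition without instantiating real models.
class ControlNetModel:
    pass


class MultiControlNetModel:
    def __init__(self, nets):
        self.nets = nets


controlnet = MultiControlNetModel([ControlNetModel()])  # multi-net of length 1
keeps = [1.0]  # one keep flag per ControlNet at the current timestep

# Old condition: any length-1 list collapses to a scalar.
old_entry = keeps[0] if len(keeps) == 1 else keeps
# New condition: only a bare ControlNetModel gets the scalar form.
new_entry = keeps[0] if isinstance(controlnet, ControlNetModel) else keeps

print(old_entry)  # 1.0   (scalar reaches the multi-net path and breaks it)
print(new_entry)  # [1.0] (list shape preserved, one entry per net)
```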

File tree

4 files changed: 176 additions & 3 deletions


src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 1 addition & 1 deletion

@@ -914,7 +914,7 @@ def __call__(
                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
-            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
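
For context, the unchanged comprehension in the hunk turns each `(control_guidance_start, control_guidance_end)` pair, given as fractions of the denoising schedule, into a per-timestep keep flag: 1.0 while step `i` lies inside the window, 0.0 outside it. A self-contained sketch with toy values (no scheduler involved):

```python
# Keep-flag computation from the hunk above, run with made-up numbers.
timesteps = list(range(10))  # stand-in for the scheduler's timesteps
control_guidance_start, control_guidance_end = [0.2], [0.7]

controlnet_keep = []
for i in range(len(timesteps)):
    keeps = [
        1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps)  # list form, as the fix now preserves for multi-nets

print([k[0] for k in controlnet_keep])
# [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
```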

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 1 addition & 1 deletion

@@ -1007,7 +1007,7 @@ def __call__(
                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
-            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 1 addition & 1 deletion

@@ -1242,7 +1242,7 @@ def __call__(
                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
-            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
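
The same one-line fix lands in the text-to-image, img2img, and inpaint ControlNet pipelines. In user code, the configuration it repairs looks roughly like this (a hedged sketch: the checkpoint names are illustrative, the conditioning image is a dummy, and the pipeline wraps a list of ControlNets in a `MultiControlNetModel` at construction time):

```python
import numpy as np
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Illustrative checkpoints; any compatible SD 1.x + ControlNet pair works.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=[controlnet]
)
# The one-element list becomes MultiControlNetModel([controlnet]) internally,
# which is exactly the length-1 case this commit fixes.

cond = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))  # dummy edge map
image = pipe(
    "A painting of a squirrel eating a burger",
    image=[cond],                  # one conditioning image per net
    control_guidance_start=[0.0],  # list form, one entry per net
    control_guidance_end=[1.0],
    num_inference_steps=20,
).images[0]
```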

tests/pipelines/controlnet/test_controlnet.py

Lines changed: 173 additions & 0 deletions

@@ -398,6 +398,179 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         pass


+class StableDiffusionMultiControlNetOneModelPipelineFastTests(
+    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+):
+    pipeline_class = StableDiffusionControlNetPipeline
+    params = TEXT_TO_IMAGE_PARAMS
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+        torch.manual_seed(0)
+
+        def init_weights(m):
+            if isinstance(m, torch.nn.Conv2d):
+                torch.nn.init.normal(m.weight)
+                m.bias.data.fill_(1.0)
+
+        controlnet = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            cross_attention_dim=32,
+            conditioning_embedding_out_channels=(16, 32),
+        )
+        controlnet.controlnet_down_blocks.apply(init_weights)
+
+        torch.manual_seed(0)
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        controlnet = MultiControlNetModel([controlnet])
+
+        components = {
+            "unet": unet,
+            "controlnet": controlnet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "safety_checker": None,
+            "feature_extractor": None,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+
+        controlnet_embedder_scale_factor = 2
+
+        images = [
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
+            ),
+        ]
+
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+            "image": images,
+        }
+
+        return inputs
+
+    def test_control_guidance_switch(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+
+        scale = 10.0
+        steps = 4
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_1 = pipe(**inputs)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_3 = pipe(
+            **inputs,
+            control_guidance_start=[0.1],
+            control_guidance_end=[0.2],
+        )[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5])[0]
+
+        # make sure that all outputs are different
+        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_3)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_4)) > 1e-3
+
+    def test_attention_slicing_forward_pass(self):
+        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+    def test_save_pretrained_raise_not_implemented_exception(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            try:
+                # save_pretrained is not implemented for Multi-ControlNet
+                pipe.save_pretrained(tmpdir)
+            except NotImplementedError:
+                pass
+
+
 @slow
 @require_torch_gpu
 class ControlNetPipelineSlowTests(unittest.TestCase):
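
The new test class mirrors the existing multi-ControlNet fast tests but wraps exactly one model, and `test_control_guidance_switch` exercises both scalar and list forms of `control_guidance_start`/`control_guidance_end`. Why the list shape matters downstream is easy to see schematically: the multi-net path pairs one conditioning scale with each wrapped net, and a bare float cannot be zipped the way a list can. A self-contained illustration (stand-in values, not the pipeline's literal code):

```python
nets = ["net_0"]            # stand-in for MultiControlNetModel.nets
conditioning_scale = [1.0]  # per-net conditioning scales
keep = [1.0]                # per-net keep flags (the post-fix shape)

# List shape: one scale per net, as the multi-net forward pass expects.
cond_scale = [c * k for c, k in zip(conditioning_scale, keep)]
for net, scale in zip(nets, cond_scale):
    print(net, scale)  # pairs up cleanly

# Scalar shape (the pre-fix result for a length-1 multi-net) cannot be zipped.
try:
    list(zip(nets, 1.0))
except TypeError as err:
    print("scalar shape fails:", err)
```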
