@@ -398,6 +398,179 @@ def test_save_pretrained_raise_not_implemented_exception(self):
                 pass


+class StableDiffusionMultiControlNetOneModelPipelineFastTests(
+    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+):
+    pipeline_class = StableDiffusionControlNetPipeline
+    params = TEXT_TO_IMAGE_PARAMS
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TODO: add image_params once refactored VaeImageProcessor.preprocess
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+        torch.manual_seed(0)
+
+        def init_weights(m):
+            if isinstance(m, torch.nn.Conv2d):
+                # use the in-place initializer; torch.nn.init.normal is deprecated
+                torch.nn.init.normal_(m.weight)
+                m.bias.data.fill_(1.0)
+
+        controlnet = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            cross_attention_dim=32,
+            conditioning_embedding_out_channels=(16, 32),
+        )
+        controlnet.controlnet_down_blocks.apply(init_weights)
+
+        torch.manual_seed(0)
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        controlnet = MultiControlNetModel([controlnet])
+
+        components = {
+            "unet": unet,
+            "controlnet": controlnet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "safety_checker": None,
+            "feature_extractor": None,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+
+        controlnet_embedder_scale_factor = 2
+
+        images = [
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
+            ),
+        ]
+
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+            "image": images,
+        }
+
+        return inputs
+
+    def test_control_guidance_switch(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+
+        scale = 10.0
+        steps = 4
+
+        # baseline: ControlNet guidance active over the full denoising schedule
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_1 = pipe(**inputs)[0]
+
+        # scalar start/end restrict guidance to a slice of the schedule
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
+
+        # list form: one start/end value per ControlNet
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_3 = pipe(
+            **inputs,
+            control_guidance_start=[0.1],
+            control_guidance_end=[0.2],
+        )[0]
+
+        # mixed scalar/list form must also be accepted
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5])[0]
+
+        # make sure that all outputs are different
+        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_3)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_4)) > 1e-3
+
+    def test_attention_slicing_forward_pass(self):
+        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+    def test_save_pretrained_raise_not_implemented_exception(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # save_pretrained is not implemented for Multi-ControlNet, so it must raise;
+            # assertRaises (unlike try/except/pass) also fails if nothing is raised
+            with self.assertRaises(NotImplementedError):
+                pipe.save_pretrained(tmpdir)
+
+
 @slow
 @require_torch_gpu
 class ControlNetPipelineSlowTests(unittest.TestCase):
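For quick local iteration on this change, a minimal sketch of running only the new fast-test class; the file path below is an assumption based on the usual diffusers repo layout and may differ:

# Sketch: run just the new Multi-ControlNet fast tests from the repo root.
import pytest

pytest.main(
    [
        "tests/pipelines/controlnet/test_controlnet.py",  # path assumed, adjust as needed
        "-k", "StableDiffusionMultiControlNetOneModelPipelineFastTests",
    ]
)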