@@ -343,6 +343,21 @@ def test_stable_diffusion_attn_processors(self):
343
343
image = sd_pipe (** inputs ).images
344
344
assert image .shape == (1 , 64 , 64 , 3 )
345
345
346
+ @unittest .skipIf (not torch .cuda .is_available () or not is_xformers_available (), reason = "xformers requires cuda" )
347
+ def test_stable_diffusion_set_xformers_attn_processors (self ):
348
+ # disable_full_determinism()
349
+ device = "cuda" # ensure determinism for the device-dependent torch.Generator
350
+ components , _ = self .get_dummy_components ()
351
+ sd_pipe = StableDiffusionPipeline (** components )
352
+ sd_pipe = sd_pipe .to (device )
353
+ sd_pipe .set_progress_bar_config (disable = None )
354
+
355
+ _ , _ , inputs = self .get_dummy_inputs ()
356
+
357
+ # run normal sd pipe
358
+ image = sd_pipe (** inputs ).images
359
+ assert image .shape == (1 , 64 , 64 , 3 )
360
+
346
361
# run lora xformers attention
347
362
attn_processors , _ = create_unet_lora_layers (sd_pipe .unet )
348
363
attn_processors = {
@@ -607,7 +622,7 @@ def test_unload_lora_sd(self):
607
622
orig_image_slice , orig_image_slice_two , atol = 1e-3
608
623
), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
609
624
610
- @unittest .skipIf (torch_device != "cuda" , "This test is supposed to run on GPU" )
625
+ @unittest .skipIf (torch_device != "cuda" or not is_xformers_available () , "This test is supposed to run on GPU" )
611
626
def test_lora_unet_attn_processors_with_xformers (self ):
612
627
with tempfile .TemporaryDirectory () as tmpdirname :
613
628
self .create_lora_weight_file (tmpdirname )
@@ -644,7 +659,7 @@ def test_lora_unet_attn_processors_with_xformers(self):
644
659
if isinstance (module , Attention ):
645
660
self .assertIsInstance (module .processor , XFormersAttnProcessor )
646
661
647
- @unittest .skipIf (torch_device != "cuda" , "This test is supposed to run on GPU" )
662
+ @unittest .skipIf (torch_device != "cuda" or not is_xformers_available () , "This test is supposed to run on GPU" )
648
663
def test_lora_save_load_with_xformers (self ):
649
664
pipeline_components , lora_components = self .get_dummy_components ()
650
665
sd_pipe = StableDiffusionPipeline (** pipeline_components )
0 commit comments