@@ -320,7 +320,7 @@ def test_text_to_image_sdxl(self):
320
320
feature_extractor = feature_extractor ,
321
321
torch_dtype = self .dtype ,
322
322
)
323
- pipeline .to ( torch_device )
323
+ pipeline .enable_model_cpu_offload ( )
324
324
pipeline .load_ip_adapter ("h94/IP-Adapter" , subfolder = "sdxl_models" , weight_name = "ip-adapter_sdxl.bin" )
325
325
326
326
inputs = self .get_dummy_inputs ()
@@ -380,7 +380,7 @@ def test_image_to_image_sdxl(self):
380
380
feature_extractor = feature_extractor ,
381
381
torch_dtype = self .dtype ,
382
382
)
383
- pipeline .to ( torch_device )
383
+ pipeline .enable_model_cpu_offload ( )
384
384
pipeline .load_ip_adapter ("h94/IP-Adapter" , subfolder = "sdxl_models" , weight_name = "ip-adapter_sdxl.bin" )
385
385
386
386
inputs = self .get_dummy_inputs (for_image_to_image = True )
@@ -449,7 +449,7 @@ def test_inpainting_sdxl(self):
449
449
feature_extractor = feature_extractor ,
450
450
torch_dtype = self .dtype ,
451
451
)
452
- pipeline .to ( torch_device )
452
+ pipeline .enable_model_cpu_offload ( )
453
453
pipeline .load_ip_adapter ("h94/IP-Adapter" , subfolder = "sdxl_models" , weight_name = "ip-adapter_sdxl.bin" )
454
454
455
455
inputs = self .get_dummy_inputs (for_inpainting = True )
@@ -497,7 +497,7 @@ def test_ip_adapter_single_mask(self):
497
497
image_encoder = image_encoder ,
498
498
torch_dtype = self .dtype ,
499
499
)
500
- pipeline .to ( torch_device )
500
+ pipeline .enable_model_cpu_offload ( )
501
501
pipeline .load_ip_adapter (
502
502
"h94/IP-Adapter" , subfolder = "sdxl_models" , weight_name = "ip-adapter-plus-face_sdxl_vit-h.safetensors"
503
503
)
@@ -525,7 +525,7 @@ def test_ip_adapter_multiple_masks(self):
525
525
image_encoder = image_encoder ,
526
526
torch_dtype = self .dtype ,
527
527
)
528
- pipeline .to ( torch_device )
528
+ pipeline .enable_model_cpu_offload ( )
529
529
pipeline .load_ip_adapter (
530
530
"h94/IP-Adapter" , subfolder = "sdxl_models" , weight_name = ["ip-adapter-plus-face_sdxl_vit-h.safetensors" ] * 2
531
531
)
0 commit comments