diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 0c584777affc..50b6448ecdca 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -661,8 +661,20 @@ def set_adapters( adapter_names: Union[List[str], str], adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None, ): - adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + if isinstance(adapter_weights, dict): + components_passed = set(adapter_weights.keys()) + lora_components = set(self._lora_loadable_modules) + + invalid_components = sorted(components_passed - lora_components) + if invalid_components: + logger.warning( + f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. " + f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging " + "to the invalid components will be removed and ignored." + ) + adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components} + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names adapter_weights = copy.deepcopy(adapter_weights) # Expand weights into a list, one entry per adapter @@ -697,12 +709,6 @@ def set_adapters( for adapter_name, weights in zip(adapter_names, adapter_weights): if isinstance(weights, dict): component_adapter_weights = weights.pop(component, None) - - if component_adapter_weights is not None and not hasattr(self, component): - logger.warning( - f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}." 
- ) - if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]: logger.warning( ( diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index f176de4e3651..dc2695452c2f 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -155,3 +155,7 @@ def test_simple_inference_with_text_lora_fused(self): @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_text_lora_save_load(self): pass + + @unittest.skip("Not supported in CogVideoX.") + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): + pass diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 0a9c4166fe87..06bbcc62a0d5 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -262,6 +262,10 @@ def test_lora_expansion_works_for_extra_keys(self): "LoRA should lead to different results.", ) + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass @@ -270,6 +274,10 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): + pass + class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = FluxControlPipeline @@ -783,6 +791,10 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) self.assertTrue(pipe.transformer.config.in_channels == in_features * 2) + @unittest.skip("Not supported in Flux.") + def 
test_simple_inference_with_text_denoiser_block_scale(self): + pass + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass @@ -791,6 +803,10 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): + pass + @slow @nightly diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 2c350582050d..671f1277f99f 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -136,3 +136,7 @@ def test_simple_inference_with_text_lora_fused(self): @unittest.skip("Text encoder LoRA is not supported in Mochi.") def test_simple_inference_with_text_lora_save_load(self): pass + + @unittest.skip("Not supported in Mochi.") + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): + pass diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index a789221e79a0..a04285465951 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -30,6 +30,7 @@ from diffusers.utils import load_image from diffusers.utils.import_utils import is_accelerate_available from diffusers.utils.testing_utils import ( + is_flaky, nightly, numpy_cosine_similarity_distance, require_big_gpu_with_torch_cuda, @@ -128,6 +129,10 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass + @is_flaky + def test_multiple_wrong_adapter_name_raises_error(self): + super().test_multiple_wrong_adapter_name_raises_error() + @nightly @require_torch_gpu diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 30238c74873b..76d6dc48602b 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ 
b/tests/lora/test_lora_layers_sdxl.py @@ -37,6 +37,7 @@ from diffusers.utils.import_utils import is_accelerate_available from diffusers.utils.testing_utils import ( CaptureLogger, + is_flaky, load_image, nightly, numpy_cosine_similarity_distance, @@ -111,6 +112,10 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() + @is_flaky + def test_multiple_wrong_adapter_name_raises_error(self): + super().test_multiple_wrong_adapter_name_raises_error() + @slow @nightly diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b56d72920748..a94198efaa64 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1135,6 +1135,43 @@ def test_wrong_adapter_name_raises_error(self): pipe.set_adapters("adapter-1") _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + def test_multiple_wrong_adapter_name_raises_error(self): + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + scale_with_wrong_components = {"foo": 0.0, "bar": 
0.0, "tik": 0.0} + logger = logging.get_logger("diffusers.loaders.lora_base") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components) + + wrong_components = sorted(set(scale_with_wrong_components.keys())) + msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. " + self.assertTrue(msg in str(cap_logger.out)) + + # test this works. + pipe.set_adapters("adapter-1") + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + def test_simple_inference_with_text_denoiser_block_scale(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches