From 46b574b0e4aff8e962f0fc4712bafbf9d0ec35fc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=81lvaro=20Somoza?=
Date: Tue, 7 May 2024 00:43:09 -0400
Subject: [PATCH 1/3] return layer weight if not found

---
 src/diffusers/utils/peft_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 8ea12e2e3b3f..f9f9fd382711 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -246,7 +246,7 @@ def get_module_weight(weight_for_adapter, module_name):
         for layer_name, weight_ in weight_for_adapter.items():
             if layer_name in module_name:
                 return weight_
-        raise RuntimeError(f"No LoRA weight found for module {module_name}.")
+        return weight_for_adapter[layer_name]
 
     # iterate over each adapter, make it active and set the corresponding scaling weight
     for adapter_name, weight in zip(adapter_names, weights):

From 8f819f33d611a66b21697d500e32e84ed6059c76 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=81lvaro=20Somoza?=
Date: Tue, 7 May 2024 02:15:03 -0400
Subject: [PATCH 2/3] better system and test

---
 src/diffusers/utils/peft_utils.py   |  7 ++++++-
 tests/lora/test_lora_layers_sdxl.py | 30 +++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index f9f9fd382711..79be751f9982 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -246,7 +246,12 @@ def get_module_weight(weight_for_adapter, module_name):
         for layer_name, weight_ in weight_for_adapter.items():
             if layer_name in module_name:
                 return weight_
-        return weight_for_adapter[layer_name]
+
+        parts = module_name.split(".")
+        key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
+        blocK_weight = weight_for_adapter.get(key, 1.0)
+
+        return blocK_weight
 
     # iterate over each adapter, make it active and set the corresponding scaling weight
     for adapter_name, weight in zip(adapter_names, weights):

diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py
index b46b887d10fb..a8b2d2759f41 100644
--- a/tests/lora/test_lora_layers_sdxl.py
+++ b/tests/lora/test_lora_layers_sdxl.py
@@ -202,6 +202,36 @@ def test_sdxl_1_0_lora(self):
         pipe.unload_lora_weights()
         release_memory(pipe)
 
+    def test_sdxl_1_0_blockwise_lora(self):
+        generator = torch.Generator("cpu").manual_seed(0)
+
+        pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipe.enable_model_cpu_offload()
+        lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
+        lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, adapter_name="offset")
+        scales = {
+            "unet": {
+                "down": {"block_1": [1.0, 1.0], "block_2": [1.0, 1.0]},
+                "mid": 1.0,
+                "up": {"block_0": [1.0, 1.0, 1.0], "block_1": [1.0, 1.0, 1.0]},
+            },
+        }
+        pipe.set_adapters(["offset"], [scales])
+
+        images = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+
+        images = images[0, -3:, -3:, -1].flatten()
+        expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
+
+        max_diff = numpy_cosine_similarity_distance(expected, images)
+        assert max_diff < 1e-4
+
+        pipe.unload_lora_weights()
+        release_memory(pipe)
+
     def test_sdxl_lcm_lora(self):
         pipe = StableDiffusionXLPipeline.from_pretrained(
             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16

From d5e5bc1e9c8c36cca27698ac06f498933b4f5ec7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=81lvaro=20Somoza?=
Date: Tue, 7 May 2024 03:10:28 -0400
Subject: [PATCH 3/3] key example and typo

---
 src/diffusers/utils/peft_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 79be751f9982..ca55192ff7ae 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -248,10 +248,11 @@ def get_module_weight(weight_for_adapter, module_name):
                 return weight_
 
         parts = module_name.split(".")
+        # e.g. key = "down_blocks.1.attentions.0"
         key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
-        blocK_weight = weight_for_adapter.get(key, 1.0)
+        block_weight = weight_for_adapter.get(key, 1.0)
 
-        return blocK_weight
+        return block_weight
 
     # iterate over each adapter, make it active and set the corresponding scaling weight
     for adapter_name, weight in zip(adapter_names, weights):
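
For reference, after 3/3 the lookup first tries an exact substring match against the
expanded per-layer keys and then falls back to a block-level key, defaulting to 1.0
instead of raising. A minimal standalone sketch of that matching logic, assuming a
pre-expanded weights dict (in diffusers, set_adapters expands the nested scales dict
before this lookup runs); the dict contents and module names below are hypothetical:

    def get_module_weight(weight_for_adapter, module_name):
        # First pass: any key that is a substring of the module name wins,
        # so a key like "down_blocks.1" scales every module in that block.
        for layer_name, weight_ in weight_for_adapter.items():
            if layer_name in module_name:
                return weight_

        # Fallback: build the block-level key, e.g. "down_blocks.1.attentions.0",
        # and default to 1.0 (no scaling) when the key is absent.
        parts = module_name.split(".")
        key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
        return weight_for_adapter.get(key, 1.0)

    # Hypothetical inputs, for illustration only:
    weights = {"down_blocks.1.attentions.0": 0.8}
    print(get_module_weight(weights, "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q"))  # 0.8
    print(get_module_weight(weights, "up_blocks.0.attentions.1.proj_in"))  # 1.0 (no entry -> default)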