From 7803187be5ac0f8c6d55f3298c2edfe4d34903f8 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 18 Aug 2023 09:12:12 +0530
Subject: [PATCH 1/3] enable lora for sdxl controlnets too.

---
 .../controlnet/pipeline_controlnet_sd_xl.py | 74 +++++++++++++++++++
 1 file changed, 74 insertions(+)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index c4a8777c5ed0..9866875425f7 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -14,6 +14,7 @@
 
 
 import inspect
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -1169,3 +1170,76 @@ def __call__(
             return (image,)
 
         return StableDiffusionXLPipelineOutput(images=image)
+
+    # Override to properly handle the loading and unloading of the additional text encoder.
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        # We could have accessed the unet config from `lora_state_dict()` too. We pass
+        # it here explicitly to be able to tell that it's coming from an SDXL
+        # pipeline.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+            )
+
+    @classmethod
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
+    def save_lora_weights(
+        self,
+        save_directory: Union[str, os.PathLike],
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        state_dict = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
+        state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+        self.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
+    def _remove_text_encoder_monkey_patch(self):
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)

From 6f4c452676bca6c7c5218d76b8fbc90370c5c128 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 18 Aug 2023 10:23:48 +0530
Subject: [PATCH 2/3] add: tests

---
 .../controlnet/test_controlnet_sdxl.py | 83 ++++++++++++++++++-
 1 file changed, 81 insertions(+), 2 deletions(-)

diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index 0c7ee4972f01..6f423776e66b 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import unittest
 
 import numpy as np
@@ -27,9 +28,9 @@
     UNet2DConditionModel,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
-from diffusers.utils import randn_tensor, torch_device
+from diffusers.utils import load_image, randn_tensor, torch_device
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow
 
 from ..pipeline_params import (
     IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -678,3 +679,81 @@ def test_xformers_attention_forwardGenerator_pass(self):
 
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+
+@slow
+@require_torch_gpu
+class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_canny(self):
+        controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
+
+        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
+        )
+        pipe.enable_sequential_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "bird"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+
+        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+        assert images[0].shape == (768, 512, 3)
+
+        original_image = images[0, -3:, -3:, -1].flatten()
+        expected_image = np.array([0.4185, 0.4127, 0.4089, 0.4046, 0.4115, 0.4096, 0.4081, 0.4112, 0.3913])
+        assert np.allclose(original_image, expected_image, atol=1e-04)
+
+    def test_depth(self):
+        controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0")
+
+        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
+        )
+        pipe.enable_sequential_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "Stormtrooper's lecture"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
+        )
+
+        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+        assert images[0].shape == (512, 512, 3)
+
+        original_image = images[0, -3:, -3:, -1].flatten()
+        expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
+        assert np.allclose(original_image, expected_image, atol=1e-04)
+
+    def test_canny_lora(self):
+        controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
+
+        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
+        )
+        pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
+        pipe.enable_sequential_cpu_offload()
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "corgi"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+
+        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+        assert images[0].shape == (768, 512, 3)
+
+        original_image = images[0, -3:, -3:, -1].flatten()
+        expected_image = np.array([0.3806, 0.3895, 0.4046, 0.3799, 0.3872, 0.4052, 0.3852, 0.3965, 0.4057])
+        assert np.allclose(original_image, expected_image, atol=1e-04)

From 55d9d8f157f470168f66368c5042f3887b3f0e0b Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 18 Aug 2023 10:39:20 +0530
Subject: [PATCH 3/3] fix: assertion values.

---
 tests/pipelines/controlnet/test_controlnet_sdxl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index 6f423776e66b..7906467e3918 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -755,5 +755,5 @@ def test_canny_lora(self):
         assert images[0].shape == (768, 512, 3)
 
         original_image = images[0, -3:, -3:, -1].flatten()
-        expected_image = np.array([0.3806, 0.3895, 0.4046, 0.3799, 0.3872, 0.4052, 0.3852, 0.3965, 0.4057])
+        expected_image = np.array([0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333])
         assert np.allclose(original_image, expected_image, atol=1e-04)
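
For quick reference, here is a minimal usage sketch of what this patch series enables: loading LoRA weights into StableDiffusionXLControlNetPipeline. It is not part of the patches themselves; the checkpoint IDs, the LoRA repository, and the conditioning image URL are the ones used in the slow tests above, while the step count and the output filename are illustrative assumptions.

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
)
# The override added in PATCH 1/3: loads the LoRA layers into the UNet and, when the
# checkpoint contains them, into both SDXL text encoders.
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
pipe.enable_sequential_cpu_offload()

canny_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
generator = torch.Generator(device="cpu").manual_seed(0)
# num_inference_steps=30 is an illustrative choice; the tests use 3 steps only to keep CI fast.
image = pipe("corgi", image=canny_image, generator=generator, num_inference_steps=30).images[0]
image.save("corgi_pixel_art.png")  # hypothetical output path

As in the tests, enable_sequential_cpu_offload() trades speed for a much smaller VRAM footprint; with enough GPU memory, pipe.to("cuda") is the faster alternative.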