diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index 42c0c8e42252..aaaea147f7ab 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -34,6 +34,11 @@ jobs:
             runner: docker-cpu
             image: diffusers/diffusers-pytorch-cpu
             report: torch_cpu_models_schedulers
+          - name: LoRA
+            framework: lora
+            runner: docker-cpu
+            image: diffusers/diffusers-pytorch-cpu
+            report: torch_cpu_lora
           - name: Fast Flax CPU tests
             framework: flax
             runner: docker-cpu
@@ -89,6 +94,14 @@ jobs:
             --make-reports=tests_${{ matrix.config.report }} \
             tests/models tests/schedulers tests/others

+      - name: Run fast PyTorch LoRA CPU tests
+        if: ${{ matrix.config.framework == 'lora' }}
+        run: |
+          python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+            -s -v -k "not Flax and not Onnx and not Dependency" \
+            --make-reports=tests_${{ matrix.config.report }} \
+            tests/lora
+
       - name: Run fast Flax TPU tests
         if: ${{ matrix.config.framework == 'flax' }}
         run: |
@@ -170,4 +183,4 @@ jobs:
         uses: actions/upload-artifact@v2
         with:
           name: pr_${{ matrix.config.report }}_test_reports
-          path: reports
\ No newline at end of file
+          path: reports
diff --git a/tests/models/test_lora_layers.py b/tests/lora/test_lora_layers.py
similarity index 69%
rename from tests/models/test_lora_layers.py
rename to tests/lora/test_lora_layers.py
index ef6ade9af5c1..e54caeb9f0c2 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/lora/test_lora_layers.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import copy
 import os
+import random
 import tempfile
 import time
 import unittest
@@ -23,16 +24,22 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from huggingface_hub.repocard import RepoCard
+from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from diffusers import (
     AutoencoderKL,
+    ControlNetModel,
     DDIMScheduler,
     DiffusionPipeline,
     EulerDiscreteScheduler,
+    PNDMScheduler,
+    StableDiffusionInpaintPipeline,
     StableDiffusionPipeline,
+    StableDiffusionXLControlNetPipeline,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
+    UNet3DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
 from diffusers.models.attention_processor import (
@@ -41,9 +48,38 @@
     AttnProcessor2_0,
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
-from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, slow, torch_device
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, slow, torch_device
+
+
+def create_lora_layers(model, mock_weights: bool = True):
+    lora_attn_procs = {}
+    for name in model.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = model.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = model.config.block_out_channels[block_id]
+
+        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+        if mock_weights:
+            # add 1 to weights
to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 + + return lora_attn_procs def create_unet_lora_layers(unet: nn.Module): @@ -91,6 +127,39 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module): return text_encoder_lora_layers +def create_lora_3d_layers(model, mock_weights: bool = True): + lora_attn_procs = {} + for name in model.attn_processors.keys(): + has_cross_attention = name.endswith("attn2.processor") and not ( + name.startswith("transformer_in") or "temp_attentions" in name.split(".") + ) + cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + elif name.startswith("transformer_in"): + # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 + hidden_size = 8 * model.config.attention_head_dim + + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + lora_attn_procs[name] = lora_attn_procs[name].to(model.device) + + if mock_weights: + # add 1 to weights to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 + + return lora_attn_procs + + def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0): with torch.no_grad(): for parameter in lora_attn_parameters: @@ -215,6 +284,91 @@ def create_lora_weight_file(self, tmpdirname): ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + @unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda") + def test_stable_diffusion_attn_processors(self): + # disable_full_determinism() + device = "cuda" # ensure determinism for the device-dependent torch.Generator + components, _ = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs() + + # run normal sd pipe + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run xformers attention + sd_pipe.enable_xformers_memory_efficient_attention() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run attention slicing + sd_pipe.enable_attention_slicing() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run vae attention slicing + sd_pipe.enable_vae_slicing() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run lora attention + attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) + attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} + sd_pipe.unet.set_attn_processor(attn_processors) + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run lora xformers 
attention + attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) + attn_processors = { + k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim) + for k, v in attn_processors.items() + } + attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} + sd_pipe.unet.set_attn_processor(attn_processors) + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # enable_full_determinism() + + def test_stable_diffusion_lora(self): + components, _ = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward 1 + _, _, inputs = self.get_dummy_inputs() + + output = sd_pipe(**inputs) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + # set lora layers + lora_attn_procs = create_lora_layers(sd_pipe.unet) + sd_pipe.unet.set_attn_processor(lora_attn_procs) + sd_pipe = sd_pipe.to(torch_device) + + # forward 2 + _, _, inputs = self.get_dummy_inputs() + + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) + image = output.images + image_slice_1 = image[0, -3:, -3:, -1] + + # forward 3 + _, _, inputs = self.get_dummy_inputs() + + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) + image = output.images + image_slice_2 = image[0, -3:, -3:, -1] + + assert np.abs(image_slice - image_slice_1).max() < 1e-2 + assert np.abs(image_slice - image_slice_2).max() > 1e-2 + def test_lora_save_load(self): pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionPipeline(**pipeline_components) @@ -499,6 +653,126 @@ def test_lora_save_load_with_xformers(self): self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) +class SDXInpaintLoraMixinTests(unittest.TestCase): + def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True): + # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched + if output_pil: + # Get random floats in [0, 1] as image + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + mask_image = torch.ones_like(image) + # Convert image and mask_image to [0, 255] + image = 255 * image + mask_image = 255 * mask_image + # Convert to PIL image + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((img_res, img_res)) + mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB").resize((img_res, img_res)) + else: + # Get random floats in [0, 1] as image with spatial size (img_res, img_res) + image = floats_tensor((1, 3, img_res, img_res), rng=random.Random(seed)).to(device) + # Convert image to [-1, 1] + init_image = 2.0 * image - 1.0 + mask_image = torch.ones((1, 1, img_res, img_res), device=device) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": init_image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + return inputs + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + 
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = PNDMScheduler(skip_prk_steps=True) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def test_stable_diffusion_inpaint_lora(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components() + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward 1 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + # set lora layers + lora_attn_procs = create_lora_layers(sd_pipe.unet) + sd_pipe.unet.set_attn_processor(lora_attn_procs) + sd_pipe = sd_pipe.to(torch_device) + + # forward 2 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) + image = output.images + image_slice_1 = image[0, -3:, -3:, -1] + + # forward 3 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) + image = output.images + image_slice_2 = image[0, -3:, -3:, -1] + + assert np.abs(image_slice - image_slice_1).max() < 1e-2 + assert np.abs(image_slice - image_slice_2).max() > 1e-2 + + class SDXLLoraLoaderMixinTests(unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) @@ -1051,6 +1325,495 @@ def test_save_load_fused_lora_modules(self): ), "The pipeline was serialized with LoRA parameters fused inside of the respected modules. The loaded pipeline should yield proper outputs, henceforth." 
+class UNet2DConditionLoRAModelTests(unittest.TestCase): + model_class = UNet2DConditionModel + main_input_name = "sample" + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64), + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_lora_processors(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + + # make sure we can set a list of attention processors + model.set_attn_processor(lora_attn_procs) + model.to(torch_device) + + # test that attn processors can be set to itself + model.set_attn_processor(model.attn_processors) + + with torch.no_grad(): + sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample1 - sample2).abs().max() < 3e-3 + assert (sample3 - sample4).abs().max() < 3e-3 + + # sample 2 and sample 3 should be different + assert (sample2 - sample3).abs().max() > 1e-4 + + def test_lora_save_load(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname, safe_serialization=False) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname) + + with torch.no_grad(): + new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample - new_sample).abs().max() < 5e-4 + + # LoRA and no LoRA should NOT be the same + assert (sample - old_sample).abs().max() > 5e-4 + + def test_lora_save_load_safetensors(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + 
init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname, safe_serialization=True) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname) + + with torch.no_grad(): + new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample - new_sample).abs().max() < 1e-4 + + # LoRA and no LoRA should NOT be the same + assert (sample - old_sample).abs().max() > 1e-4 + + def test_lora_save_safetensors_load_torch(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + lora_attn_procs = create_lora_layers(model, mock_weights=False) + model.set_attn_processor(lora_attn_procs) + # Saving as torch, properly reloads with directly filename + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname, safe_serialization=True) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors") + + def test_lora_save_torch_force_load_safetensors_error(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + lora_attn_procs = create_lora_layers(model, mock_weights=False) + model.set_attn_processor(lora_attn_procs) + # Saving as torch, properly reloads with directly filename + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname, safe_serialization=False) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + with self.assertRaises(IOError) as e: + new_model.load_attn_procs(tmpdirname, use_safetensors=True) + self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception)) + + def test_lora_on_off(self, expected_max_diff=1e-3): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + + model.set_default_attn_processor() + + with torch.no_grad(): + 
new_sample = model(**inputs_dict).sample + + max_diff_new_sample = (sample - new_sample).abs().max() + max_diff_old_sample = (sample - old_sample).abs().max() + + assert max_diff_new_sample < expected_max_diff + assert max_diff_old_sample < expected_max_diff + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_lora_xformers_on_off(self, expected_max_diff=1e-3): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + # default + with torch.no_grad(): + sample = model(**inputs_dict).sample + + model.enable_xformers_memory_efficient_attention() + on_sample = model(**inputs_dict).sample + + model.disable_xformers_memory_efficient_attention() + off_sample = model(**inputs_dict).sample + + max_diff_on_sample = (sample - on_sample).abs().max() + max_diff_off_sample = (sample - off_sample).abs().max() + + assert max_diff_on_sample < expected_max_diff + assert max_diff_off_sample < expected_max_diff + + +class UNet3DConditionModelTests(unittest.TestCase): + model_class = UNet3DConditionModel + main_input_name = "sample" + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + num_frames = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 4, 32, 32) + + @property + def output_shape(self): + return (4, 4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64), + "down_block_types": ( + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 1, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_lora_processors(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + lora_attn_procs = create_lora_3d_layers(model) + + # make sure we can set a list of attention processors + model.set_attn_processor(lora_attn_procs) + model.to(torch_device) + + # test that attn processors can be set to itself + model.set_attn_processor(model.attn_processors) + + with torch.no_grad(): + sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample1 - sample2).abs().max() < 3e-3 + assert (sample3 - sample4).abs().max() < 3e-3 + + # sample 2 and sample 3 should be different + assert (sample2 - sample3).abs().max() > 3e-3 + + def test_lora_save_load(self): + init_dict, inputs_dict = 
self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_3d_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname, safe_serialization=False) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname) + + with torch.no_grad(): + new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample - new_sample).abs().max() < 1e-3 + + # LoRA and no LoRA should NOT be the same + assert (sample - old_sample).abs().max() > 1e-4 + + def test_lora_save_load_safetensors(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_3d_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname, safe_serialization=True) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname) + + with torch.no_grad(): + new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample - new_sample).abs().max() < 3e-3 + + # LoRA and no LoRA should NOT be the same + assert (sample - old_sample).abs().max() > 1e-4 + + def test_lora_save_safetensors_load_torch(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + lora_attn_procs = create_lora_3d_layers(model, mock_weights=False) + model.set_attn_processor(lora_attn_procs) + # Saving as torch, properly reloads with directly filename + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors") + + def test_lora_save_torch_force_load_safetensors_error(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + lora_attn_procs = create_lora_3d_layers(model, mock_weights=False) + model.set_attn_processor(lora_attn_procs) + # Saving as torch, properly reloads with directly filename + with tempfile.TemporaryDirectory() as tmpdirname: + 
model.save_attn_procs(tmpdirname, safe_serialization=False) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + with self.assertRaises(IOError) as e: + new_model.load_attn_procs(tmpdirname, use_safetensors=True) + self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception)) + + def test_lora_on_off(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_3d_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + + model.set_attn_processor(AttnProcessor()) + + with torch.no_grad(): + new_sample = model(**inputs_dict).sample + + assert (sample - new_sample).abs().max() < 1e-4 + assert (sample - old_sample).abs().max() < 3e-3 + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_lora_xformers_on_off(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 4 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + lora_attn_procs = create_lora_3d_layers(model) + model.set_attn_processor(lora_attn_procs) + + # default + with torch.no_grad(): + sample = model(**inputs_dict).sample + + model.enable_xformers_memory_efficient_attention() + on_sample = model(**inputs_dict).sample + + model.disable_xformers_memory_efficient_attention() + off_sample = model(**inputs_dict).sample + + assert (sample - on_sample).abs().max() < 1e-4 + assert (sample - off_sample).abs().max() < 1e-4 + + @slow @require_torch_gpu class LoraIntegrationTests(unittest.TestCase): @@ -1498,6 +2261,29 @@ def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self): self.assertTrue(np.allclose(images, expected, atol=1e-3)) + def test_canny_lora(self): + controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0") + + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet + ) + pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors") + pipe.enable_sequential_cpu_offload() + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "corgi" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images + + assert images[0].shape == (768, 512, 3) + + original_image = images[0, -3:, -3:, -1].flatten() + expected_image = np.array([0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333]) + assert np.allclose(original_image, expected_image, atol=1e-04) + @nightly def test_sequential_fuse_unfuse(self): pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 
8aa2099154a1..0f16e6432728 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -24,7 +24,7 @@ from pytest import mark from diffusers import UNet2DConditionModel -from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, LoRAAttnProcessor +from diffusers.models.attention_processor import CustomDiffusionAttnProcessor from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -45,33 +45,6 @@ enable_full_determinism() -def create_lora_layers(model, mock_weights: bool = True): - lora_attn_procs = {} - for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - - if mock_weights: - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 - - return lora_attn_procs - - def create_custom_diffusion_layers(model, mock_weights: bool = True): train_kv = True train_q_out = True @@ -527,214 +500,6 @@ def test_model_xattn_padding(self): keeplast_out ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." 
- def test_lora_processors(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - sample1 = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - - # make sure we can set a list of attention processors - model.set_attn_processor(lora_attn_procs) - model.to(torch_device) - - # test that attn processors can be set to itself - model.set_attn_processor(model.attn_processors) - - with torch.no_grad(): - sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample1 - sample2).abs().max() < 3e-3 - assert (sample3 - sample4).abs().max() < 3e-3 - - # sample 2 and sample 3 should be different - assert (sample2 - sample3).abs().max() > 1e-4 - - def test_lora_save_load(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname) - - with torch.no_grad(): - new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample - new_sample).abs().max() < 5e-4 - - # LoRA and no LoRA should NOT be the same - assert (sample - old_sample).abs().max() > 5e-4 - - def test_lora_save_load_safetensors(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=True) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname) - - with torch.no_grad(): - new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample - new_sample).abs().max() < 1e-4 - - # LoRA and no LoRA should NOT be the same - assert (sample - old_sample).abs().max() > 1e-4 - - def test_lora_save_safetensors_load_torch(self): - # enable deterministic behavior for gradient 
checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - lora_attn_procs = create_lora_layers(model, mock_weights=False) - model.set_attn_processor(lora_attn_procs) - # Saving as torch, properly reloads with directly filename - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=True) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors") - - def test_lora_save_torch_force_load_safetensors_error(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - lora_attn_procs = create_lora_layers(model, mock_weights=False) - model.set_attn_processor(lora_attn_procs) - # Saving as torch, properly reloads with directly filename - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - with self.assertRaises(IOError) as e: - new_model.load_attn_procs(tmpdirname, use_safetensors=True) - self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception)) - - def test_lora_on_off(self, expected_max_diff=1e-3): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - - model.set_default_attn_processor() - - with torch.no_grad(): - new_sample = model(**inputs_dict).sample - - max_diff_new_sample = (sample - new_sample).abs().max() - max_diff_old_sample = (sample - old_sample).abs().max() - - assert max_diff_new_sample < expected_max_diff - assert max_diff_old_sample < expected_max_diff - - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_lora_xformers_on_off(self, expected_max_diff=1e-3): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - # default - with torch.no_grad(): - sample = model(**inputs_dict).sample - - model.enable_xformers_memory_efficient_attention() - on_sample = model(**inputs_dict).sample - - model.disable_xformers_memory_efficient_attention() - off_sample = 
model(**inputs_dict).sample - - max_diff_on_sample = (sample - on_sample).abs().max() - max_diff_off_sample = (sample - off_sample).abs().max() - - assert max_diff_on_sample < expected_max_diff - assert max_diff_off_sample < expected_max_diff - def test_custom_diffusion_processors(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index f0d6a8d72571..9efaea8d651b 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -13,15 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import tempfile import unittest import numpy as np import torch from diffusers.models import ModelMixin, UNet3DConditionModel -from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device @@ -34,39 +31,6 @@ logger = logging.get_logger(__name__) -def create_lora_layers(model, mock_weights: bool = True): - lora_attn_procs = {} - for name in model.attn_processors.keys(): - has_cross_attention = name.endswith("attn2.processor") and not ( - name.startswith("transformer_in") or "temp_attentions" in name.split(".") - ) - cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - elif name.startswith("transformer_in"): - # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 - hidden_size = 8 * model.config.attention_head_dim - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - - if mock_weights: - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 - - return lora_attn_procs - - @skip_mps class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet3DConditionModel @@ -197,203 +161,6 @@ def test_model_attention_slicing(self): output = model(**inputs_dict) assert output is not None - def test_lora_processors(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - sample1 = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - - # make sure we can set a list of attention processors - model.set_attn_processor(lora_attn_procs) - model.to(torch_device) - - # test that attn processors can be set to itself - 
model.set_attn_processor(model.attn_processors) - - with torch.no_grad(): - sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample1 - sample2).abs().max() < 3e-3 - assert (sample3 - sample4).abs().max() < 3e-3 - - # sample 2 and sample 3 should be different - assert (sample2 - sample3).abs().max() > 3e-3 - - def test_lora_save_load(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname) - - with torch.no_grad(): - new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample - new_sample).abs().max() < 1e-3 - - # LoRA and no LoRA should NOT be the same - assert (sample - old_sample).abs().max() > 1e-4 - - def test_lora_save_load_safetensors(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=True) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname) - - with torch.no_grad(): - new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample - new_sample).abs().max() < 3e-3 - - # LoRA and no LoRA should NOT be the same - assert (sample - old_sample).abs().max() > 1e-4 - - def test_lora_save_safetensors_load_torch(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - lora_attn_procs = create_lora_layers(model, mock_weights=False) - model.set_attn_processor(lora_attn_procs) - # Saving as torch, properly reloads with directly filename - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname, 
weight_name="pytorch_lora_weights.safetensors") - - def test_lora_save_torch_force_load_safetensors_error(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - lora_attn_procs = create_lora_layers(model, mock_weights=False) - model.set_attn_processor(lora_attn_procs) - # Saving as torch, properly reloads with directly filename - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - with self.assertRaises(IOError) as e: - new_model.load_attn_procs(tmpdirname, use_safetensors=True) - self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception)) - - def test_lora_on_off(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - - model.set_attn_processor(AttnProcessor()) - - with torch.no_grad(): - new_sample = model(**inputs_dict).sample - - assert (sample - new_sample).abs().max() < 1e-4 - assert (sample - old_sample).abs().max() < 3e-3 - - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_lora_xformers_on_off(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 4 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - # default - with torch.no_grad(): - sample = model(**inputs_dict).sample - - model.enable_xformers_memory_efficient_attention() - on_sample = model(**inputs_dict).sample - - model.disable_xformers_memory_efficient_attention() - off_sample = model(**inputs_dict).sample - - assert (sample - on_sample).abs().max() < 1e-4 - assert (sample - off_sample).abs().max() < 1e-4 - def test_feed_forward_chunking(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 32 @@ -411,6 +178,3 @@ def test_feed_forward_chunking(self): self.assertEqual(output.shape, output_2.shape, "Shape doesn't match") assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 - - -# (todo: sayakpaul) implement SLOW tests. 
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index c4b91f0eb79c..4fff88434bc3 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -775,26 +775,3 @@ def test_depth(self):
         original_image = images[0, -3:, -3:, -1].flatten()
         expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
         assert np.allclose(original_image, expected_image, atol=1e-04)
-
-    def test_canny_lora(self):
-        controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
-
-        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
-        )
-        pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
-        pipe.enable_sequential_cpu_offload()
-
-        generator = torch.Generator(device="cpu").manual_seed(0)
-        prompt = "corgi"
-        image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
-        )
-
-        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
-
-        assert images[0].shape == (768, 512, 3)
-
-        original_image = images[0, -3:, -3:, -1].flatten()
-        expected_image = np.array([0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333])
-        assert np.allclose(original_image, expected_image, atol=1e-04)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index e67bfd661cc1..95762e36423c 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -37,7 +37,7 @@
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor, LoRAXFormersAttnProcessor
+from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     enable_full_determinism,
@@ -51,8 +51,6 @@
     torch_device,
 )

-from ...models.test_lora_layers import create_unet_lora_layers
-from ...models.test_models_unet_2d_condition import create_lora_layers
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin

@@ -188,40 +186,6 @@ def test_stable_diffusion_ddim(self):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

-    def test_stable_diffusion_lora(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionPipeline(**components)
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # forward 1
-        inputs = self.get_dummy_inputs(device)
-        output = sd_pipe(**inputs)
-        image = output.images
-        image_slice = image[0, -3:, -3:, -1]
-
-        # set lora layers
-        lora_attn_procs = create_lora_layers(sd_pipe.unet)
-        sd_pipe.unet.set_attn_processor(lora_attn_procs)
-        sd_pipe = sd_pipe.to(torch_device)
-
-        # forward 2
-        inputs = self.get_dummy_inputs(device)
-        output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0})
-        image = output.images
-        image_slice_1 = image[0, -3:, -3:, -1]
-
-        # forward 3
-        inputs = self.get_dummy_inputs(device)
-        output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5})
-        image =
output.images - image_slice_2 = image[0, -3:, -3:, -1] - - assert np.abs(image_slice - image_slice_1).max() < 1e-2 - assert np.abs(image_slice - image_slice_2).max() > 1e-2 - def test_stable_diffusion_prompt_embeds(self): components = self.get_dummy_components() sd_pipe = StableDiffusionPipeline(**components) @@ -374,56 +338,6 @@ def test_stable_diffusion_pndm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - @unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda") - def test_stable_diffusion_attn_processors(self): - # disable_full_determinism() - device = "cuda" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - - # run normal sd pipe - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # run xformers attention - sd_pipe.enable_xformers_memory_efficient_attention() - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # run attention slicing - sd_pipe.enable_attention_slicing() - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # run vae attention slicing - sd_pipe.enable_vae_slicing() - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # run lora attention - attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) - attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} - sd_pipe.unet.set_attn_processor(attn_processors) - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # run lora xformers attention - attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) - attn_processors = { - k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim) - for k, v in attn_processors.items() - } - attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} - sd_pipe.unet.set_attn_processor(attn_processors) - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # enable_full_determinism() - def test_stable_diffusion_no_safety_checker(self): pipe = StableDiffusionPipeline.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index cd5ba4087ab1..c7731d97a878 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -49,7 +49,6 @@ torch_device, ) -from ...models.test_models_unet_2d_condition import create_lora_layers from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin @@ -221,40 +220,6 @@ def test_stable_diffusion_inpaint_image_tensor(self): assert out_pil.shape == (1, 64, 64, 3) assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2 - def test_stable_diffusion_inpaint_lora(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - - components = self.get_dummy_components() - sd_pipe = StableDiffusionInpaintPipeline(**components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - # 
forward 1 - inputs = self.get_dummy_inputs(device) - output = sd_pipe(**inputs) - image = output.images - image_slice = image[0, -3:, -3:, -1] - - # set lora layers - lora_attn_procs = create_lora_layers(sd_pipe.unet) - sd_pipe.unet.set_attn_processor(lora_attn_procs) - sd_pipe = sd_pipe.to(torch_device) - - # forward 2 - inputs = self.get_dummy_inputs(device) - output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) - image = output.images - image_slice_1 = image[0, -3:, -3:, -1] - - # forward 3 - inputs = self.get_dummy_inputs(device) - output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) - image = output.images - image_slice_2 = image[0, -3:, -3:, -1] - - assert np.abs(image_slice - image_slice_1).max() < 1e-2 - assert np.abs(image_slice - image_slice_2).max() > 1e-2 - def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) @@ -410,10 +375,6 @@ def test_stable_diffusion_inpaint(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - @unittest.skip("skipped here because area stays unchanged due to mask") - def test_stable_diffusion_inpaint_lora(self): - ... - def test_stable_diffusion_inpaint_2_images(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components()
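
Note: with the LoRA suite relocated to tests/lora, the new "Run fast PyTorch LoRA CPU tests" CI step above can be mirrored locally. A minimal sketch, assuming a development install of diffusers with pytest and pytest-xdist available and the command run from the repository root; the report name here simply mirrors the torch_cpu_lora entry from the CI matrix:

    # mirror the new pr_tests.yml LoRA job on a local CPU machine
    python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
      -s -v -k "not Flax and not Onnx and not Dependency" \
      --make-reports=tests_torch_cpu_lora \
      tests/lora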