From 6c6cade1a7ad313130093086a801a80422b6e4b1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 12:52:56 +0530 Subject: [PATCH 01/17] migrate lora pipeline tests to pytest --- fix_asserts_lora.py | 114 + tests/lora/test_lora_layers_auraflow.py | 2 +- tests/lora/test_lora_layers_cogvideox.py | 2 +- tests/lora/test_lora_layers_cogview4.py | 2 +- tests/lora/test_lora_layers_flux.py | 522 ++--- tests/lora/test_lora_layers_flux.py.bak | 1041 +++++++++ tests/lora/test_lora_layers_hunyuanvideo.py | 2 +- tests/lora/test_lora_layers_ltx_video.py | 2 +- tests/lora/test_lora_layers_lumina2.py | 2 +- tests/lora/test_lora_layers_mochi.py | 2 +- tests/lora/test_lora_layers_qwenimage.py | 2 +- tests/lora/test_lora_layers_sana.py | 2 +- tests/lora/test_lora_layers_sd.py | 6 +- tests/lora/test_lora_layers_sd3.py | 2 +- tests/lora/test_lora_layers_sdxl.py | 2 +- tests/lora/test_lora_layers_wan.py | 2 +- tests/lora/test_lora_layers_wanvace.py | 29 +- tests/lora/utils.py | 1399 ++++------- tests/lora/utils.py.bak | 2328 +++++++++++++++++++ 19 files changed, 4104 insertions(+), 1359 deletions(-) create mode 100644 fix_asserts_lora.py create mode 100644 tests/lora/test_lora_layers_flux.py.bak create mode 100644 tests/lora/utils.py.bak diff --git a/fix_asserts_lora.py b/fix_asserts_lora.py new file mode 100644 index 000000000000..32259574f3a7 --- /dev/null +++ b/fix_asserts_lora.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +""" +Fix F631-style asserts of the form: + assert (, "message") +…into: + assert , "message" + +Scans recursively under tests/lora/. + +Usage: + python fix_assert_tuple.py [--root tests/lora] [--dry-run] +""" + +import argparse +import ast +from pathlib import Path +from typing import Tuple, List, Optional + + +class AssertTupleFixer(ast.NodeTransformer): + """ + Transform `assert (, )` into `assert , `. + We only rewrite when the assert test is a Tuple with exactly 2 elements. + """ + def __init__(self): + super().__init__() + self.fixed_locs: List[Tuple[int, int]] = [] + + def visit_Assert(self, node: ast.Assert) -> ast.AST: + self.generic_visit(node) + if isinstance(node.test, ast.Tuple) and len(node.test.elts) == 2: + cond, msg = node.test.elts + # Convert only if this *looks* like a real assert-with-message tuple, + # i.e. keep anything as msg (string, f-string, name, call, etc.) + new_node = ast.Assert(test=cond, msg=msg) + ast.copy_location(new_node, node) + ast.fix_missing_locations(new_node) + self.fixed_locs.append((node.lineno, node.col_offset)) + return new_node + return node + + +def fix_file(path: Path, dry_run: bool = False) -> int: + """ + Returns number of fixes applied. + """ + try: + src = path.read_text(encoding="utf-8") + except Exception as e: + print(f"Could not read {path}: {e}") + return 0 + + try: + tree = ast.parse(src, filename=str(path)) + except SyntaxError: + # Skip files that don’t parse (partial edits, etc.) 
+ return 0 + + fixer = AssertTupleFixer() + new_tree = fixer.visit(tree) + fixes = len(fixer.fixed_locs) + if fixes == 0: + return 0 + + try: + new_src = ast.unparse(new_tree) # Python 3.9+ + except Exception as e: + print(f"Failed to unparse {path}: {e}") + return 0 + + if dry_run: + for (lineno, col) in fixer.fixed_locs: + print(f"[DRY-RUN] {path}:{lineno}:{col} -> fixed assert tuple") + return fixes + + # Backup and write + backup = path.with_suffix(path.suffix + ".bak") + try: + if not backup.exists(): + backup.write_text(src, encoding="utf-8") + path.write_text(new_src, encoding="utf-8") + for (lineno, col) in fixer.fixed_locs: + print(f"Fixed {path}:{lineno}:{col}") + except Exception as e: + print(f"Failed to write {path}: {e}") + return 0 + + return fixes + + +def main(): + ap = argparse.ArgumentParser(description="Fix F631-style tuple asserts.") + ap.add_argument("--root", default="tests/lora", help="Root directory to scan") + ap.add_argument("--dry-run", action="store_true", help="Report changes but don't write") + args = ap.parse_args() + + root = Path(args.root) + if not root.exists(): + print(f"{root} does not exist.") + return + + total_files = 0 + total_fixes = 0 + for pyfile in root.rglob("*.py"): + total_files += 1 + total_fixes += fix_file(pyfile, dry_run=args.dry_run) + + print(f"\nScanned {total_files} file(s). Applied {total_fixes} fix(es).") + if args.dry_run: + print("Run again without --dry-run to apply changes.") + + +if __name__ == "__main__": + main() diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 91f63c4b56c4..55d69b5bfa4f 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -40,7 +40,7 @@ @require_peft_backend -class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestAuraFlowLoRA(PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index fa57b4c9c2f9..4d407ad420ca 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -39,7 +39,7 @@ @require_peft_backend -class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestCogVideoXLoRA(PeftLoraLoaderMixinTests): pipeline_class = CogVideoXPipeline scheduler_cls = CogVideoXDPMScheduler scheduler_kwargs = {"timestep_spacing": "trailing"} diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 30eb8fbb6367..de732b85268b 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -47,7 +47,7 @@ def from_pretrained(*args, **kwargs): @require_peft_backend @skip_mps -class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestCogView4LoRA(PeftLoraLoaderMixinTests): pipeline_class = CogView4Pipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b840d7ac72ce..ff53983ecf52 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -1,17 +1,3 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import copy import gc import os @@ -20,6 +6,7 @@ import unittest import numpy as np +import pytest import safetensors.torch import torch from parameterized import parameterized @@ -46,14 +33,12 @@ if is_peft_available(): from peft.utils import get_peft_model_state_dict - sys.path.append(".") - -from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 +from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set @require_peft_backend -class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestFluxLoRA(PeftLoraLoaderMixinTests): pipeline_class = FluxPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -83,10 +68,10 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "scaling_factor": 1.5035, } has_two_text_encoders = True - tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" - tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" - text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" - text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + (tokenizer_cls, tokenizer_id) = (CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2") + (tokenizer_2_cls, tokenizer_2_id) = (AutoTokenizer, "hf-internal-testing/tiny-random-t5") + (text_encoder_cls, text_encoder_id) = (CLIPTextModel, "peft-internal-testing/tiny-clip-text-2") + (text_encoder_2_cls, text_encoder_2_id) = (T5EncoderModel, "hf-internal-testing/tiny-random-t5") @property def output_shape(self): @@ -97,11 +82,9 @@ def get_dummy_inputs(self, with_generator=True): sequence_length = 10 num_channels = 4 sizes = (32, 32) - generator = torch.manual_seed(0) noise = floats_tensor((batch_size, num_channels) + sizes) input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - pipeline_inputs = { "prompt": "A painting of a squirrel eating a burger", "num_inference_steps": 4, @@ -112,147 +95,110 @@ def get_dummy_inputs(self, with_generator=True): } if with_generator: pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs + return (noise, input_ids, pipeline_inputs) def test_with_alpha_in_state_dict(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) pipe.transformer.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - with tempfile.TemporaryDirectory() as tmpdirname: 
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - # modify the state dict to have alpha values following - # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors state_dict_with_alpha = safetensors.torch.load_file( os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") ) alpha_dict = {} for k, v in state_dict_with_alpha.items(): - # only do for `transformer` and for the k projections -- should be enough to test. - if "transformer" in k and "to_k" in k and "lora_A" in k: + if "transformer" in k and "to_k" in k and ("lora_A" in k): alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) state_dict_with_alpha.update(alpha_dict) - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" pipe.unload_lora_weights() pipe.load_lora_weights(state_dict_with_alpha) images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images - - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results." ) - self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) + assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001) - def test_lora_expansion_works_for_absent_keys(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): + (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - - # Modify the config to have a layer which won't be present in the second LoRA we will load. 
+ (_, _, inputs) = self.get_dummy_inputs(with_generator=False) modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config.target_modules.add("x_embedder") - pipe.transformer.add_adapter(modified_denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertFalse( - np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3), + assert not ( + np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), "LoRA should lead to different results.", ) - with tempfile.TemporaryDirectory() as tmpdirname: denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") - - # Modify the state dict to exclude "x_embedder" related LoRA params. lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} - + lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") pipe.set_adapters(["one", "two"]) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images - - self.assertFalse( - np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + assert not ( + np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), "Different LoRAs should lead to different results.", ) - self.assertFalse( - np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + assert not ( + np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), "LoRA should lead to different results.", ) - def test_lora_expansion_works_for_extra_keys(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + def test_lora_expansion_works_for_extra_keys(self, base_pipe_output): + (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - - # Modify the config to have a layer which won't be present in the first LoRA we will load. 
+ (_, _, inputs) = self.get_dummy_inputs(with_generator=False) modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config.target_modules.add("x_embedder") - pipe.transformer.add_adapter(modified_denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertFalse( - np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3), + assert not ( + np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), "LoRA should lead to different results.", ) - with tempfile.TemporaryDirectory() as tmpdirname: denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) pipe.unload_lora_weights() - # Modify the state dict to exclude "x_embedder" related LoRA params. lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} + lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") - - # Load state dict with `x_embedder`. pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") - pipe.set_adapters(["one", "two"]) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images - - self.assertFalse( - np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), + assert not ( + np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), "Different LoRAs should lead to different results.", ) - self.assertFalse( - np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), + assert not ( + np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), "LoRA should lead to different results.", ) @@ -273,7 +219,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass -class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestFluxControlLoRA(PeftLoraLoaderMixinTests): pipeline_class = FluxControlPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -304,10 +250,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "scaling_factor": 1.5035, } has_two_text_encoders = True - tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" - tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" - text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" - text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + (tokenizer_cls, tokenizer_id) = (CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2") + 
(tokenizer_2_cls, tokenizer_2_id) = (AutoTokenizer, "hf-internal-testing/tiny-random-t5") + (text_encoder_cls, text_encoder_id) = (CLIPTextModel, "peft-internal-testing/tiny-clip-text-2") + (text_encoder_2_cls, text_encoder_2_id) = (T5EncoderModel, "hf-internal-testing/tiny-random-t5") @property def output_shape(self): @@ -318,11 +264,9 @@ def get_dummy_inputs(self, with_generator=True): sequence_length = 10 num_channels = 4 sizes = (32, 32) - generator = torch.manual_seed(0) noise = floats_tensor((batch_size, num_channels) + sizes) input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - np.random.seed(0) pipeline_inputs = { "prompt": "A painting of a squirrel eating a burger", @@ -335,22 +279,17 @@ def get_dummy_inputs(self, with_generator=True): } if with_generator: pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs + return (noise, input_ids, pipeline_inputs) def test_with_norm_in_state_dict(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.INFO) - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]: norm_state_dict = {} for name, module in pipe.transformer.named_modules(): @@ -359,70 +298,54 @@ def test_with_norm_in_state_dict(self): norm_state_dict[f"transformer.{name}.weight"] = torch.randn( module.weight.shape, device=module.weight.device, dtype=module.weight.dtype ) - with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(norm_state_dict) lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( + assert ( "The provided state dict contains normalization layers in addition to LoRA layers" in cap_logger.out ) - self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0) - + assert len(pipe.transformer._transformer_norm_layers) > 0 pipe.unload_lora_weights() lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue(pipe.transformer._transformer_norm_layers is None) - self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5)) - self.assertFalse( - np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested" + assert pipe.transformer._transformer_norm_layers is None + assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05) + assert not ( + np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), + f"{norm_layer} is tested", ) - with CaptureLogger(logger) as cap_logger: for key in list(norm_state_dict.keys()): norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key) pipe.load_lora_weights(norm_state_dict) - - self.assertTrue( - "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out - ) + assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out def test_lora_parameter_expanded_shapes(self): - components, _, _ = 
self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - - # Change the transformer config to mimic a real use case. num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) - self.assertTrue( - transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + assert transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}" ) - original_transformer_state_dict = pipe.transformer.state_dict() x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight") incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False) - self.assertTrue( - "x_embedder.weight" in incompatible_keys.missing_keys, - "Could not find x_embedder.weight in the missing keys.", + assert "x_embedder.weight" in incompatible_keys.missing_keys, ( + "Could not find x_embedder.weight in the missing keys." ) transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control]) pipe.transformer = transformer - - out_features, in_features = pipe.transformer.x_embedder.weight.shape + (out_features, in_features) = pipe.transformer.x_embedder.weight.shape rank = 4 - dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -431,18 +354,13 @@ def test_lora_parameter_expanded_shapes(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - - # Testing opposite direction where the LoRA params are zero-padded. 
- components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") + (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -454,38 +372,27 @@ def test_lora_parameter_expanded_shapes(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) + assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out def test_normal_lora_with_expanded_lora_raises_error(self): - # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then - # load shape expanded LoRA (such as Control LoRA). - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - - # Change the transformer config to mimic a real use case. 
+ (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) components["transformer"] = transformer - pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - - out_features, in_features = pipe.transformer.x_embedder.weight.shape + (out_features, in_features) = pipe.transformer.x_embedder.weight.shape rank = 4 - shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -494,102 +401,68 @@ def test_normal_lora_with_expanded_lora_raises_error(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - self.assertTrue(pipe.get_active_adapters() == ["adapter-1"]) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + assert pipe.get_active_adapters() == ["adapter-1"] + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-2") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) - self.assertTrue(pipe.get_active_adapters() == ["adapter-2"]) - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out + assert pipe.get_active_adapters() == ["adapter-2"] lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) - - # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. - # This should raise a runtime error on input shapes being incompatible. - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - # Change the transformer config to mimic a real use case. 
+ assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001) + (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) components["transformer"] = transformer - pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - - out_features, in_features = pipe.transformer.x_embedder.weight.shape + (out_features, in_features) = pipe.transformer.x_embedder.weight.shape rank = 4 - lora_state_dict = { "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } pipe.load_lora_weights(lora_state_dict, "adapter-1") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) - self.assertTrue(pipe.transformer.config.in_channels == in_features) - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features + assert pipe.transformer.config.in_channels == in_features lora_state_dict = { "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, } - - # We should check for input shapes being incompatible here. But because above mentioned issue is - # not a supported use case, and because of the PEFT renaming, we will currently have a shape - # mismatch error. - self.assertRaisesRegex( - RuntimeError, - "size mismatch for x_embedder.lora_A.adapter-2.weight", - pipe.load_lora_weights, - lora_state_dict, - "adapter-2", - ) + with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"): + pipe.load_lora_weights(lora_state_dict, "adapter-2") def test_fuse_expanded_lora_with_regular_lora(self): - # This test checks if it works when a lora with expanded shapes (like control loras) but - # another lora with correct shapes is loaded. The opposite direction isn't supported and is - # tested with it. - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - - # Change the transformer config to mimic a real use case. 
+ (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) components["transformer"] = transformer - pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - - out_features, in_features = pipe.transformer.x_embedder.weight.shape + (out_features, in_features) = pipe.transformer.x_embedder.weight.shape rank = 4 - shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -597,98 +470,74 @@ def test_fuse_expanded_lora_with_regular_lora(self): "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, } pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - - _, _, inputs = self.get_dummy_inputs(with_generator=False) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - pipe.load_lora_weights(lora_state_dict, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0]) lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3)) - + assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001) + assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001) + assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001) pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"]) lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3)) + assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001) def test_load_regular_lora(self): - # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded - # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those - # transformers include Flux Fill, Flux Control, etc. 
- components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - out_features, in_features = pipe.transformer.x_embedder.weight.shape + (out_features, in_features) = pipe.transformer.x_embedder.weight.shape rank = 4 - in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA. + in_features = in_features // 2 normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.INFO) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) - self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3)) + assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 + assert not np.allclose(original_output, lora_output, atol=0.001, rtol=0.001) def test_lora_unload_with_parameter_expanded_shapes(self): - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - + (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - - # Change the transformer config to mimic a real use case. num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) - self.assertTrue( - transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + assert transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}" ) - - # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. 
components["transformer"] = transformer pipe = FluxPipeline(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) control_image = inputs.pop("control_image") original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - control_pipe = self.pipeline_class(**components) - out_features, in_features = control_pipe.transformer.x_embedder.weight.shape + (out_features, in_features) = control_pipe.transformer.x_embedder.weight.shape rank = 4 - dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -697,64 +546,49 @@ def test_lora_unload_with_parameter_expanded_shapes(self): } with CaptureLogger(logger) as cap_logger: control_pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" inputs["control_image"] = control_image lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - + assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") control_pipe.unload_lora_weights(reset_to_overwritten_params=True) - self.assertTrue( - control_pipe.transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + assert control_pipe.transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}" ) loaded_pipe = FluxPipeline.from_pipe(control_pipe) - self.assertTrue( - loaded_pipe.transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}", + assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}" ) inputs.pop("control_image") unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) - self.assertTrue(pipe.transformer.config.in_channels == in_features) + assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001) + assert 
np.allclose(unloaded_lora_out, original_out, atol=0.0001, rtol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features + assert pipe.transformer.config.in_channels == in_features def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - + (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - - # Change the transformer config to mimic a real use case. num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) - self.assertTrue( - transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + assert transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}" ) - - # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. components["transformer"] = transformer pipe = FluxPipeline(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) control_image = inputs.pop("control_image") original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - control_pipe = self.pipeline_class(**components) - out_features, in_features = control_pipe.transformer.x_embedder.weight.shape + (out_features, in_features) = control_pipe.transformer.x_embedder.weight.shape rank = 4 - dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -763,26 +597,21 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): } with CaptureLogger(logger) as cap_logger: control_pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" inputs["control_image"] = control_image lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - + assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") control_pipe.unload_lora_weights(reset_to_overwritten_params=False) - self.assertTrue( - control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + assert 
control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}" ) no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) - self.assertTrue(pipe.transformer.config.in_channels == in_features * 2) + assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 + assert pipe.transformer.config.in_channels == in_features * 2 @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale(self): @@ -818,15 +647,12 @@ class FluxLoRAIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() - gc.collect() backend_empty_cache(torch_device) - self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) def tearDown(self): super().tearDown() - del self.pipeline gc.collect() backend_empty_cache(torch_device) @@ -835,13 +661,8 @@ def test_flux_the_last_ben(self): self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() - # Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI - # run supports it. We have about 34GB RAM in the CI runner which kills the test when run with - # `enable_model_cpu_offload()`. We repeat this for the other tests, too. self.pipeline = self.pipeline.to(torch_device) - prompt = "jon snow eating pizza with ketchup" - out = self.pipeline( prompt, num_inference_steps=self.num_inference_steps, @@ -851,17 +672,14 @@ def test_flux_the_last_ben(self): ).images out_slice = out[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246]) - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 + assert max_diff < 0.001 def test_flux_kohya(self): self.pipeline.load_lora_weights("Norod78/brain-slug-flux") self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() self.pipeline = self.pipeline.to(torch_device) - prompt = "The cat with a brain slug earring" out = self.pipeline( prompt, @@ -870,20 +688,16 @@ def test_flux_kohya(self): output_type="np", generator=torch.manual_seed(self.seed), ).images - out_slice = out[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484]) - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 + assert max_diff < 0.001 def test_flux_kohya_with_text_encoder(self): self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() self.pipeline = self.pipeline.to(torch_device) - prompt = "optimus is cleaning the house with broomstick" out = self.pipeline( prompt, @@ -892,19 +706,15 @@ def test_flux_kohya_with_text_encoder(self): output_type="np", generator=torch.manual_seed(self.seed), ).images - out_slice = out[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 
0.4219]) - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 + assert max_diff < 0.001 def test_flux_kohya_embedders_conversion(self): """Test that embedders load without throwing errors""" self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") self.pipeline.unload_lora_weights() - assert True def test_flux_xlabs(self): @@ -912,9 +722,7 @@ def test_flux_xlabs(self): self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() self.pipeline = self.pipeline.to(torch_device) - prompt = "A blue jay standing on a large basket of rainbow macarons, disney style" - out = self.pipeline( prompt, num_inference_steps=self.num_inference_steps, @@ -923,11 +731,9 @@ def test_flux_xlabs(self): generator=torch.manual_seed(self.seed), ).images out_slice = out[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980]) - + expected_slice = np.array([0.3965, 0.418, 0.4434, 0.4082, 0.4375, 0.459, 0.4141, 0.4375, 0.498]) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 + assert max_diff < 0.001 def test_flux_xlabs_load_lora_with_single_blocks(self): self.pipeline.load_lora_weights( @@ -936,9 +742,7 @@ def test_flux_xlabs_load_lora_with_single_blocks(self): self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() self.pipeline.enable_model_cpu_offload() - prompt = "a wizard mouse playing chess" - out = self.pipeline( prompt, num_inference_steps=self.num_inference_steps, @@ -951,8 +755,7 @@ def test_flux_xlabs_load_lora_with_single_blocks(self): [0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625] ) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 + assert max_diff < 0.001 @nightly @@ -966,17 +769,14 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() - gc.collect() backend_empty_cache(torch_device) - self.pipeline = FluxControlPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 ).to(torch_device) def tearDown(self): super().tearDown() - gc.collect() backend_empty_cache(torch_device) @@ -985,7 +785,6 @@ def test_lora(self, lora_ckpt_id): self.pipeline.load_lora_weights(lora_ckpt_id) self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() - if "Canny" in lora_ckpt_id: control_image = load_image( "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" @@ -994,7 +793,6 @@ def test_lora(self, lora_ckpt_id): control_image = load_image( "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" ) - image = self.pipeline( prompt=self.prompt, control_image=control_image, @@ -1005,16 +803,13 @@ def test_lora(self, lora_ckpt_id): output_type="np", generator=torch.manual_seed(self.seed), ).images - out_slice = image[0, -3:, -3:, -1].flatten() if "Canny" in lora_ckpt_id: expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516]) else: - expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359]) - + expected_slice = np.array([0.8203, 0.832, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359]) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 + 
assert max_diff < 0.001 @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) def test_lora_with_turbo(self, lora_ckpt_id): @@ -1022,7 +817,6 @@ def test_lora_with_turbo(self, lora_ckpt_id): self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors") self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() - if "Canny" in lora_ckpt_id: control_image = load_image( "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" @@ -1031,7 +825,6 @@ def test_lora_with_turbo(self, lora_ckpt_id): control_image = load_image( "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" ) - image = self.pipeline( prompt=self.prompt, control_image=control_image, @@ -1042,13 +835,10 @@ def test_lora_with_turbo(self, lora_ckpt_id): output_type="np", generator=torch.manual_seed(self.seed), ).images - out_slice = image[0, -3:, -3:, -1].flatten() if "Canny" in lora_ckpt_id: expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484]) else: - expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562]) - + expected_slice = np.array([0.668, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562]) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 + assert max_diff < 0.001 diff --git a/tests/lora/test_lora_layers_flux.py.bak b/tests/lora/test_lora_layers_flux.py.bak new file mode 100644 index 000000000000..ee0235266307 --- /dev/null +++ b/tests/lora/test_lora_layers_flux.py.bak @@ -0,0 +1,1041 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import gc +import os +import sys +import tempfile +import unittest + +import numpy as np +import pytest +import safetensors.torch +import torch +from parameterized import parameterized +from PIL import Image +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel +from diffusers.utils import load_image, logging + +from ..testing_utils import ( + CaptureLogger, + backend_empty_cache, + floats_tensor, + is_peft_available, + nightly, + numpy_cosine_similarity_distance, + require_big_accelerator, + require_peft_backend, + require_torch_accelerator, + slow, + torch_device, +) + + +if is_peft_available(): + from peft.utils import get_peft_model_state_dict + +sys.path.append(".") + +from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 + + +@require_peft_backend +class TestFluxLoRA(PeftLoraLoaderMixinTests): + pipeline_class = FluxPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_kwargs = {} + transformer_kwargs = { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "axes_dims_rope": [4, 4, 8], + } + transformer_cls = FluxTransformer2DModel + vae_kwargs = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 1, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, + } + has_two_text_encoders = True + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + @property + def output_shape(self): + return (1, 8, 8, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def test_with_alpha_in_state_dict(self): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.transformer.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + + images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + + with tempfile.TemporaryDirectory() as tmpdirname: + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, 
transformer_lora_layers=denoiser_state_dict) + + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + # modify the state dict to have alpha values following + # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors + state_dict_with_alpha = safetensors.torch.load_file( + os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") + ) + alpha_dict = {} + for k, v in state_dict_with_alpha.items(): + # only do for `transformer` and for the k projections -- should be enough to test. + if "transformer" in k and "to_k" in k and "lora_A" in k: + alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) + state_dict_with_alpha.update(alpha_dict) + + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + + pipe.unload_lora_weights() + pipe.load_lora_weights(state_dict_with_alpha) + images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images + + assert np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), "Loading from saved checkpoints should give same results." + + assert not np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3) + + def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + # Modify the config to have a layer which won't be present in the second LoRA we will load. + modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) + modified_denoiser_lora_config.target_modules.add("x_embedder") + + pipe.transformer.add_adapter(modified_denoiser_lora_config) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + + images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + assert not( + np.allclose(images_lora, base_pipe_output, atol=1e-3, rtol=1e-3), + "LoRA should lead to different results.", + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") + + # Modify the state dict to exclude "x_embedder" related LoRA params. 
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} + + pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") + pipe.set_adapters(["one", "two"]) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images + + assert not( + np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + "Different LoRAs should lead to different results.", + ) + assert not( + np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + "LoRA should lead to different results.", + ) + + def test_lora_expansion_works_for_extra_keys(self, base_pipe_output): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + # Modify the config to have a layer which won't be present in the first LoRA we will load. + modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) + modified_denoiser_lora_config.target_modules.add("x_embedder") + + pipe.transformer.add_adapter(modified_denoiser_lora_config) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + + images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + assert not( + np.allclose(images_lora, base_pipe_output, atol=1e-3, rtol=1e-3), + "LoRA should lead to different results.", + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe.unload_lora_weights() + # Modify the state dict to exclude "x_embedder" related LoRA params. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} + pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") + + # Load state dict with `x_embedder`. 
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") + + pipe.set_adapters(["one", "two"]) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images + + assert not( + np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), + "Different LoRAs should lead to different results.", + ) + assert not( + np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), + "LoRA should lead to different results.", + ) + + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Flux.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): + pass + + +class TestFluxControlLoRA(PeftLoraLoaderMixinTests): + pipeline_class = FluxControlPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_kwargs = {} + transformer_kwargs = { + "patch_size": 1, + "in_channels": 8, + "out_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "axes_dims_rope": [4, 4, 8], + } + transformer_cls = FluxTransformer2DModel + vae_kwargs = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 1, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, + } + has_two_text_encoders = True + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + @property + def output_shape(self): + return (1, 8, 8, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + np.random.seed(0) + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")), + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def test_with_norm_in_state_dict(self): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + + 
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]: + norm_state_dict = {} + for name, module in pipe.transformer.named_modules(): + if norm_layer not in name or not hasattr(module, "weight") or module.weight is None: + continue + norm_state_dict[f"transformer.{name}.weight"] = torch.randn( + module.weight.shape, device=module.weight.device, dtype=module.weight.dtype + ) + + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(norm_state_dict) + lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert "The provided state dict contains normalization layers in addition to LoRA layers" in cap_logger.out + assert len(pipe.transformer._transformer_norm_layers) > 0 + + pipe.unload_lora_weights() + lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert pipe.transformer._transformer_norm_layers is None + assert np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5) + assert not( + np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested" + ) + + with CaptureLogger(logger) as cap_logger: + for key in list(norm_state_dict.keys()): + norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key) + pipe.load_lora_weights(norm_state_dict) + + assert ( + "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out + ) + + def test_lora_parameter_expanded_shapes(self): + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + assert transformer.config.in_channels == num_channels_without_control, f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}" + + + original_transformer_state_dict = pipe.transformer.state_dict() + x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight") + incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False) + assert ( + "x_embedder.weight" in incompatible_keys.missing_keys, + "Could not find x_embedder.weight in the missing keys." 
+ ) + + transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control]) + pipe.transformer = transformer + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") + + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + + lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") + + # Testing opposite direction where the LoRA params are zero-padded. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + dummy_lora_A = torch.nn.Linear(1, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") + + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + + lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out + + def test_normal_lora_with_expanded_lora_raises_error(self): + # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then + # load shape expanded LoRA (such as Control LoRA). + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + # Change the transformer config to mimic a real use case. 
+ num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + components["transformer"] = transformer + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + + shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, + "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") + + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + assert pipe.get_active_adapters() == ["adapter-1"] + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) + normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-2") + + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out + assert pipe.get_active_adapters() == ["adapter-2"] + + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3) + + # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. + # This should raise a runtime error on input shapes being incompatible. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + # Change the transformer config to mimic a real use case. 
+ num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + components["transformer"] = transformer + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + pipe.load_lora_weights(lora_state_dict, "adapter-1") + + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features + assert pipe.transformer.config.in_channels == in_features + + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, + "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, + } + + # We should check for input shapes being incompatible here. But because above mentioned issue is + # not a supported use case, and because of the PEFT renaming, we will currently have a shape + # mismatch error. + with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"): + pipe.load_lora_weights(lora_state_dict, "adapter-2") + + def test_fuse_expanded_lora_with_regular_lora(self): + # This test checks if it works when a lora with expanded shapes (like control loras) but + # another lora with correct shapes is loaded. The opposite direction isn't supported and is + # tested with it. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + # Change the transformer config to mimic a real use case. 
+ num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + components["transformer"] = transformer + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + + shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, + "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, + } + pipe.load_lora_weights(lora_state_dict, "adapter-1") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) + normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + + pipe.load_lora_weights(lora_state_dict, "adapter-2") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0]) + lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) + assert not(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3)) + assert not(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3)) + + pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"]) + lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3) + + def test_load_regular_lora(self): + # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded + # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those + # transformers include Flux Fill, Flux Control, etc. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA. 
+ normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) + normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 + assert not np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3) + + def test_lora_unload_with_parameter_expanded_shapes(self): + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + assert ( + transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + ) + + # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. + components["transformer"] = transformer + pipe = FluxPipeline(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + control_image = inputs.pop("control_image") + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + control_pipe = self.pipeline_class(**components) + out_features, in_features = control_pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + control_pipe.load_lora_weights(lora_state_dict, "adapter-1") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + + inputs["control_image"] = control_image + lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") + + control_pipe.unload_lora_weights(reset_to_overwritten_params=True) + assert( + control_pipe.transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + ) + loaded_pipe = FluxPipeline.from_pipe(control_pipe) + assert ( + loaded_pipe.transformer.config.in_channels == num_channels_without_control, + f"Expected 
{num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}", + ) + inputs.pop("control_image") + unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4) + assert np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4) + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features + assert pipe.transformer.config.in_channels == in_features + + def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + assert ( + transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + ) + + # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. + components["transformer"] = transformer + pipe = FluxPipeline(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + control_image = inputs.pop("control_image") + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + control_pipe = self.pipeline_class(**components) + out_features, in_features = control_pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + control_pipe.load_lora_weights(lora_state_dict, "adapter-1") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + + inputs["control_image"] = control_image + lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") + + control_pipe.unload_lora_weights(reset_to_overwritten_params=False) + assert( + control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + ) + no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4) + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 + assert pipe.transformer.config.in_channels == in_features * 2 + + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in Flux.") + def 
test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Flux.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): + pass + + +@slow +@nightly +@require_torch_accelerator +@require_peft_backend +@require_big_accelerator +class FluxLoRAIntegrationTests(unittest.TestCase): + """internal note: The integration slices were obtained on audace. + + torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the + assertions to pass. + """ + + num_inference_steps = 10 + seed = 0 + + def setUp(self): + super().setUp() + + gc.collect() + backend_empty_cache(torch_device) + + self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + + def tearDown(self): + super().tearDown() + + del self.pipeline + gc.collect() + backend_empty_cache(torch_device) + + def test_flux_the_last_ben(self): + self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + # Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI + # run supports it. We have about 34GB RAM in the CI runner which kills the test when run with + # `enable_model_cpu_offload()`. We repeat this for the other tests, too. + self.pipeline = self.pipeline.to(torch_device) + + prompt = "jon snow eating pizza with ketchup" + + out = self.pipeline( + prompt, + num_inference_steps=self.num_inference_steps, + guidance_scale=4.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + out_slice = out[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 + + def test_flux_kohya(self): + self.pipeline.load_lora_weights("Norod78/brain-slug-flux") + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + self.pipeline = self.pipeline.to(torch_device) + + prompt = "The cat with a brain slug earring" + out = self.pipeline( + prompt, + num_inference_steps=self.num_inference_steps, + guidance_scale=4.5, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = out[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 + + def test_flux_kohya_with_text_encoder(self): + self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + self.pipeline = self.pipeline.to(torch_device) + + prompt = "optimus is cleaning the house with broomstick" + out = self.pipeline( + prompt, + num_inference_steps=self.num_inference_steps, + guidance_scale=4.5, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = out[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 + + def test_flux_kohya_embedders_conversion(self): + """Test that embedders load without 
throwing errors""" + self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") + self.pipeline.unload_lora_weights() + + assert True + + def test_flux_xlabs(self): + self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + self.pipeline = self.pipeline.to(torch_device) + + prompt = "A blue jay standing on a large basket of rainbow macarons, disney style" + + out = self.pipeline( + prompt, + num_inference_steps=self.num_inference_steps, + guidance_scale=3.5, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + out_slice = out[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 + + def test_flux_xlabs_load_lora_with_single_blocks(self): + self.pipeline.load_lora_weights( + "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors" + ) + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + self.pipeline.enable_model_cpu_offload() + + prompt = "a wizard mouse playing chess" + + out = self.pipeline( + prompt, + num_inference_steps=self.num_inference_steps, + guidance_scale=3.5, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + out_slice = out[0, -3:, -3:, -1].flatten() + expected_slice = np.array( + [0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 + + +@nightly +@require_torch_accelerator +@require_peft_backend +@require_big_accelerator +class FluxControlLoRAIntegrationTests(unittest.TestCase): + num_inference_steps = 10 + seed = 0 + prompt = "A robot made of exotic candies and chocolates of different kinds." 
+ + def setUp(self): + super().setUp() + + gc.collect() + backend_empty_cache(torch_device) + + self.pipeline = FluxControlPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 + ).to(torch_device) + + def tearDown(self): + super().tearDown() + + gc.collect() + backend_empty_cache(torch_device) + + @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) + def test_lora(self, lora_ckpt_id): + self.pipeline.load_lora_weights(lora_ckpt_id) + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + + if "Canny" in lora_ckpt_id: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" + ) + else: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" + ) + + image = self.pipeline( + prompt=self.prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=self.num_inference_steps, + guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = image[0, -3:, -3:, -1].flatten() + if "Canny" in lora_ckpt_id: + expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516]) + else: + expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 + + @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) + def test_lora_with_turbo(self, lora_ckpt_id): + self.pipeline.load_lora_weights(lora_ckpt_id) + self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors") + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + + if "Canny" in lora_ckpt_id: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" + ) + else: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" + ) + + image = self.pipeline( + prompt=self.prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=self.num_inference_steps, + guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = image[0, -3:, -3:, -1].flatten() + if "Canny" in lora_ckpt_id: + expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484]) + else: + expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index cfd5d3146a91..0f31eaf57aa7 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -48,7 +48,7 @@ @require_peft_backend @skip_mps -class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests): 
pipeline_class = HunyuanVideoPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 6ab51a5e513f..b72479de5736 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -34,7 +34,7 @@ @require_peft_backend -class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestLTXVideoLoRA(PeftLoraLoaderMixinTests): pipeline_class = LTXPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index 0417b05b33a1..a4ddd5457d3c 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -36,7 +36,7 @@ @require_peft_backend -class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestLumina2LoRA(PeftLoraLoaderMixinTests): pipeline_class = Lumina2Pipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 7be81273db77..a34a615b257a 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -34,7 +34,7 @@ @require_peft_backend @skip_mps -class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestMochiLoRA(PeftLoraLoaderMixinTests): pipeline_class = MochiPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py index 51de2f8e20e1..167373211e90 100644 --- a/tests/lora/test_lora_layers_qwenimage.py +++ b/tests/lora/test_lora_layers_qwenimage.py @@ -34,7 +34,7 @@ @require_peft_backend -class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestQwenImageLoRA(PeftLoraLoaderMixinTests): pipeline_class = QwenImagePipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index 3cdb28de75fb..2323d66e39e2 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -29,7 +29,7 @@ @require_peft_backend -class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestSanaLoRA(PeftLoraLoaderMixinTests): pipeline_class = SanaPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {"shift": 7.0} diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index 933bf2336a59..76ac775a9f1c 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -55,7 +55,7 @@ from accelerate.utils import release_memory -class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): +class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests): pipeline_class = StableDiffusionPipeline scheduler_cls = DDIMScheduler scheduler_kwargs = { @@ -666,11 +666,11 @@ def test_load_unload_load_state_dict(self): previous_state_dict = lcm_lora.copy() pipe.load_lora_weights(lcm_lora, adapter_name="lcm") - self.assertDictEqual(lcm_lora, previous_state_dict) + assert lcm_lora == previous_state_dict pipe.unload_lora_weights() pipe.load_lora_weights(lcm_lora, adapter_name="lcm") - self.assertDictEqual(lcm_lora, previous_state_dict) + assert lcm_lora == previous_state_dict release_memory(pipe) diff --git a/tests/lora/test_lora_layers_sd3.py 
b/tests/lora/test_lora_layers_sd3.py index 228460eaad90..02602ddf6fc2 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -51,7 +51,7 @@ @require_peft_backend -class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestSD3LoRA(PeftLoraLoaderMixinTests): pipeline_class = StableDiffusion3Pipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index ac1d65abdaa7..405e97cd1b1f 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -59,7 +59,7 @@ from accelerate.utils import release_memory -class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): +class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests): has_two_text_encoders = True pipeline_class = StableDiffusionXLPipeline scheduler_cls = EulerDiscreteScheduler diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 5734509b410f..7066578dc749 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -39,7 +39,7 @@ @require_peft_backend @skip_mps -class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestWanLoRA(PeftLoraLoaderMixinTests): pipeline_class = WanPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index ab1f57bfc9da..60246ad2bcc7 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -47,7 +47,7 @@ @require_peft_backend @skip_mps @is_flaky(max_attempts=10, description="very flaky class") -class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestWanVACELoRA(PeftLoraLoaderMixinTests): pipeline_class = WanVACEPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -163,14 +163,13 @@ def test_layerwise_casting_inference_denoiser(self): super().test_layerwise_casting_inference_denoiser() @require_peft_version_greater("0.13.2") - def test_lora_exclude_modules_wanvace(self): + def test_lora_exclude_modules_wanvace(self, base_pipe_output): exclude_module_name = "vace_blocks.0.proj_out" components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components).to(torch_device) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - self.assertTrue(output_no_lora.shape == self.output_shape) + assert base_pipe_output.shape == self.output_shape # only supported for `denoiser` now denoiser_lora_config.target_modules = ["proj_out"] @@ -180,8 +179,8 @@ def test_lora_exclude_modules_wanvace(self): ) # The state dict shouldn't contain the modules to be excluded from LoRA. state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default") - self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model)) - self.assertTrue(any("proj_out" in k for k in state_dict_from_model)) + assert not any(exclude_module_name in k for k in state_dict_from_model) + assert any("proj_out" in k for k in state_dict_from_model) output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdir: @@ -192,23 +191,21 @@ def test_lora_exclude_modules_wanvace(self): # Check in the loaded state dict. 
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict)) - self.assertTrue(any("proj_out" in k for k in loaded_state_dict)) + assert not any(exclude_module_name in k for k in loaded_state_dict) + assert any("proj_out" in k for k in loaded_state_dict) # Check in the state dict obtained after loading LoRA. pipe.load_lora_weights(tmpdir) state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0") - self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model)) - self.assertTrue(any("proj_out" in k for k in state_dict_from_model)) + assert not any(exclude_module_name in k for k in state_dict_from_model) + assert any("proj_out" in k for k in state_dict_from_model) output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), - "LoRA should change outputs.", + assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), ( + "LoRA should change outputs." ) - self.assertTrue( - np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), - "Lora outputs should match.", + assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), ( + "Lora outputs should match." ) def test_simple_inference_with_text_denoiser_lora_and_scale(self): diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 3d4344bb86a9..2fe80b4c1bb2 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1,17 +1,3 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
import inspect import os import re @@ -24,10 +10,7 @@ import torch from parameterized import parameterized -from diffusers import ( - AutoencoderKL, - UNet2DConditionModel, -) +from diffusers import AutoencoderKL, UNet2DConditionModel from diffusers.utils import logging from diffusers.utils.import_utils import is_peft_available @@ -54,12 +37,11 @@ def state_dicts_almost_equal(sd1, sd2): sd1 = dict(sorted(sd1.items())) sd2 = dict(sorted(sd2.items())) models_are_equal = True for ten1, ten2 in zip(sd1.values(), sd2.values()): - if (ten1 - ten2).abs().max() > 1e-3: + if (ten1 - ten2).abs().max() > 0.001: models_are_equal = False - return models_are_equal @@ -75,15 +56,15 @@ def check_if_lora_correctly_set(model) -> bool: def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str): extracted = { - k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.") + k.removeprefix(f"{module_key}."): v for (k, v) in parsed_metadata.items() if k.startswith(f"{module_key}.") } check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"]) def initialize_dummy_state_dict(state_dict): - if not all(v.device.type == "meta" for _, v in state_dict.items()): + if not all((v.device.type == "meta" for (_, v) in state_dict.items())): raise ValueError("`state_dict` has non-meta values.") - return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()} + return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for (k, v) in state_dict.items()} POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"] @@ -91,8 +72,6 @@ def initialize_dummy_state_dict(state_dict): def determine_attention_kwargs_name(pipeline_class): call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys() - - # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES: if possible_attention_kwargs in call_signature_keys: attention_kwargs_name = possible_attention_kwargs @@ -104,61 +83,48 @@ def determine_attention_kwargs_name(pipeline_class): @require_peft_backend class PeftLoraLoaderMixinTests: pipeline_class = None - scheduler_cls = None scheduler_kwargs = None - has_two_text_encoders = False has_three_text_encoders = False - text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, "" - text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, "" - text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, "" - tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" - tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" - tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" - + (text_encoder_cls, text_encoder_id, text_encoder_subfolder) = (None, None, "") + (text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder) = (None, None, "") + (text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder) = (None, None, "") + (tokenizer_cls, tokenizer_id, tokenizer_subfolder) = (None, None, "") + (tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder) = (None, None, "") + (tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder) = (None, None, "") unet_kwargs = None transformer_cls = None transformer_kwargs = None vae_cls = AutoencoderKL vae_kwargs = None - text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - cached_non_lora_output = None - - def get_base_pipe_output(self): - if self.cached_non_lora_output is None: - self.cached_non_lora_output = self._compute_baseline_output() - return self.cached_non_lora_output + @pytest.fixture(scope="class") + def base_pipe_output(self): + return self._compute_baseline_output() def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") if self.has_two_text_encoders and self.has_three_text_encoders: raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.") - scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls rank = 4 lora_alpha = rank if lora_alpha is None else lora_alpha - torch.manual_seed(0) if self.unet_kwargs is not None: unet = UNet2DConditionModel(**self.unet_kwargs) else: transformer = self.transformer_cls(**self.transformer_kwargs) - scheduler = scheduler_cls(**self.scheduler_kwargs) - torch.manual_seed(0) vae = self.vae_cls(**self.vae_kwargs) - text_encoder = self.text_encoder_cls.from_pretrained( self.text_encoder_id, subfolder=self.text_encoder_subfolder ) tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder) - if self.text_encoder_2_cls is not None: text_encoder_2 = self.text_encoder_2_cls.from_pretrained( self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder @@ -166,7 +132,6 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No tokenizer_2 = self.tokenizer_2_cls.from_pretrained( self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder ) - if self.text_encoder_3_cls is not None: text_encoder_3 = self.text_encoder_3_cls.from_pretrained( self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder @@ -174,7 +139,6 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No tokenizer_3 = self.tokenizer_3_cls.from_pretrained( self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder ) - text_lora_config = LoraConfig( r=rank, lora_alpha=lora_alpha, @@ -182,7 +146,6 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No init_lora_weights=False, use_dora=use_dora, ) - denoiser_lora_config = LoraConfig( r=rank, lora_alpha=lora_alpha, @@ -190,26 +153,20 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No init_lora_weights=False, use_dora=use_dora, ) - pipeline_components = { "scheduler": scheduler, "vae": vae, "text_encoder": text_encoder, "tokenizer": tokenizer, } - # Denoiser if self.unet_kwargs is not None: pipeline_components.update({"unet": unet}) elif self.transformer_kwargs is not None: pipeline_components.update({"transformer": transformer}) - - # Remaining text encoders. 
if self.text_encoder_2_cls is not None: pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2}) if self.text_encoder_3_cls is not None: pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3}) - - # Remaining stuff init_params = inspect.signature(self.pipeline_class.__init__).parameters if "safety_checker" in init_params: pipeline_components.update({"safety_checker": None}) @@ -217,8 +174,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No pipeline_components.update({"feature_extractor": None}) if "image_encoder" in init_params: pipeline_components.update({"image_encoder": None}) - - return pipeline_components, text_lora_config, denoiser_lora_config + return (pipeline_components, text_lora_config, denoiser_lora_config) @property def output_shape(self): @@ -229,11 +185,9 @@ def get_dummy_inputs(self, with_generator=True): sequence_length = 10 num_channels = 4 sizes = (32, 32) - generator = torch.manual_seed(0) noise = floats_tensor((batch_size, num_channels) + sizes) input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - pipeline_inputs = { "prompt": "A painting of a squirrel eating a burger", "num_inference_steps": 5, @@ -242,18 +196,14 @@ def get_dummy_inputs(self, with_generator=True): } if with_generator: pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs + return (noise, input_ids, pipeline_inputs) def _compute_baseline_output(self): - components, _, _ = self.get_dummy_components(self.scheduler_cls) + (components, _, _) = self.get_dummy_components(self.scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - - # Always ensure the inputs are without the `generator`. Make sure to pass the `generator` - # explicitly. 
- _, _, inputs = self.get_dummy_inputs(with_generator=False) + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) return pipe(**inputs, generator=torch.manual_seed(0))[0] def _get_lora_state_dicts(self, modules_to_save): @@ -273,327 +223,243 @@ def _get_lora_adapter_metadata(self, modules_to_save): def _get_modules_to_save(self, pipe, has_denoiser=False): modules_to_save = {} lora_loadable_modules = self.pipeline_class._lora_loadable_modules - if ( "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder") - and getattr(pipe.text_encoder, "peft_config", None) is not None + and (getattr(pipe.text_encoder, "peft_config", None) is not None) ): modules_to_save["text_encoder"] = pipe.text_encoder - if ( "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2") - and getattr(pipe.text_encoder_2, "peft_config", None) is not None + and (getattr(pipe.text_encoder_2, "peft_config", None) is not None) ): modules_to_save["text_encoder_2"] = pipe.text_encoder_2 - if has_denoiser: if "unet" in lora_loadable_modules and hasattr(pipe, "unet"): modules_to_save["unet"] = pipe.unet - if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"): modules_to_save["transformer"] = pipe.transformer - return modules_to_save def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): if text_lora_config is not None: if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" if denoiser_lora_config is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
else: denoiser = None - if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - return pipe, denoiser + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + return (pipe, denoiser) - def test_simple_inference(self): + def test_simple_inference(self, base_pipe_output): """ Tests a simple inference and makes sure it works as expected """ - output_no_lora = self.get_base_pipe_output() - assert output_no_lora.shape == self.output_shape + assert base_pipe_output.shape == self.output_shape - def test_simple_inference_with_text_lora(self): + def test_simple_inference_with_text_lora(self, base_pipe_output): """ Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ - components, text_lora_config, _ = self.get_dummy_components() + (components, text_lora_config, _) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) + assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" @require_peft_version_greater("0.13.1") def test_low_cpu_mem_usage_with_injection(self): """Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.") - self.assertTrue( - "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, - "The LoRA params should be on 'meta' device.", + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder." + assert "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, ( + "The LoRA params should be on 'meta' device." ) - te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder)) set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) - self.assertTrue( - "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, - "No param should be on 'meta' device.", + assert "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, ( + "No param should be on 'meta' device." 
) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - self.assertTrue( - "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." - ) - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + assert "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser)) set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True) - self.assertTrue( - "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." - ) - + assert "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + assert "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, ( + "The LoRA params should be on 'meta' device." ) - self.assertTrue( - "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, - "The LoRA params should be on 'meta' device.", - ) - te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2)) set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True) - self.assertTrue( - "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, - "No param should be on 'meta' device.", + assert "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, ( + "No param should be on 'meta' device." 
) - - _, _, inputs = self.get_dummy_inputs() + (_, _, inputs) = self.get_dummy_inputs() output_lora = pipe(**inputs)[0] - self.assertTrue(output_lora.shape == self.output_shape) + assert output_lora.shape == self.output_shape @require_peft_version_greater("0.13.1") @require_transformers_version_greater("4.45.2") def test_low_cpu_mem_usage_with_loading(self): """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights( save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results." ) - - # Now, check for `low_cpu_mem_usage.` pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", - ) + assert np.allclose( + images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001 + ), "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results." 
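The converted tests above now receive `base_pipe_output` as an argument instead of calling `self.get_base_pipe_output()`, so a pytest fixture with that name is expected to be available on the test classes. The fixture definition is not shown in this part of the patch; the snippet below is only a minimal sketch of how it could be wired on top of the `_compute_baseline_output()` helper defined earlier, assuming `pytest` is imported in the module and that a per-test fixture is acceptable (the name `base_pipe_output` comes from the patch, everything else here is illustrative):

    @pytest.fixture
    def base_pipe_output(self):
        # Baseline output of the pipeline built from get_dummy_components(), run once
        # without any LoRA adapters so the tests can compare LoRA runs against it.
        return self._compute_baseline_output()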
- def test_simple_inference_with_text_lora_and_scale(self): + def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output): """ Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - components, text_lora_config, _ = self.get_dummy_components() + (components, text_lora_config, _) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - + assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", + assert not np.allclose(output_lora, output_lora_scale, atol=0.001, rtol=0.001), ( + "Lora + scale should change the output" ) - attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), - "Lora + 0 scale should lead to same result as no LoRA", + assert np.allclose(base_pipe_output, output_lora_0_scale, atol=0.001, rtol=0.001), ( + "Lora + 0 scale should lead to same result as no LoRA" ) - def test_simple_inference_with_text_lora_fused(self): + def test_simple_inference_with_text_lora_fused(self, base_pipe_output): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ - components, text_lora_config, _ = self.get_dummy_components() + (components, text_lora_config, _) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe.fuse_lora() - # Fusing should still keep the LoRA layers - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - + assert 
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" - ) + assert not np.allclose(ouput_fused, base_pipe_output, atol=0.001, rtol=0.001), "Fused lora should change the output" - def test_simple_inference_with_text_lora_unloaded(self): + def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output): """ Tests a simple inference with lora attached to text encoder, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, _ = self.get_dummy_components() + (components, text_lora_config, _) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe.unload_lora_weights() - # unloading should remove the LoRA layers - self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder") - + assert not check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", - ) - + assert not check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2" ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), - "Fused lora should change the output", + assert np.allclose(ouput_unloaded, base_pipe_output, atol=0.001, rtol=0.001), ( + "Unloading LoRA should give the same output as no LoRA" ) def test_simple_inference_with_text_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA.
""" - components, text_lora_config, _ = self.get_dummy_components() + (components, text_lora_config, _) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results." ) - def test_simple_inference_with_partial_text_lora(self): + def test_simple_inference_with_partial_text_lora(self, base_pipe_output): """ Tests a simple inference with lora attached on the text encoder with different ranks and some adapters removed and makes sure it works as expected """ - components, _, _ = self.get_dummy_components() - # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). + (components, _, _) = self.get_dummy_components() text_lora_config = LoraConfig( r=4, rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)}, @@ -605,519 +471,388 @@ def test_simple_inference_with_partial_text_lora(self): pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) state_dict = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: - # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` - # supports missing layers (PR#8324). 
state_dict = { f"text_encoder.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() + for (module_name, param) in get_peft_model_state_dict(pipe.text_encoder).items() if "text_model.encoder.layers.4" not in module_name } - if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: state_dict.update( { f"text_encoder_2.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() + for (module_name, param) in get_peft_model_state_dict(pipe.text_encoder_2).items() if "text_model.encoder.layers.4" not in module_name } ) - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - - # Unload lora and load it back using the pipe.load_lora_weights machinery + assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" pipe.unload_lora_weights() pipe.load_lora_weights(state_dict) - output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), - "Removing adapters should change the output", + assert not np.allclose(output_partial_lora, output_lora, atol=0.001, rtol=0.001), ( + "Removing adapters should change the output" ) def test_simple_inference_save_pretrained_with_text_lora(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ - components, text_lora_config, _ = self.get_dummy_components() + (components, text_lora_config, _) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: pipe.save_pretrained(tmpdirname) - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) pipe_from_pretrained.to(torch_device) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), - "Lora not correctly set in text encoder", + assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), ( + "Lora not correctly set in text encoder" ) - if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), - "Lora not correctly set in text encoder 2", + assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), ( + "Lora not correctly set in text encoder 2" ) - images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, images_lora_save_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints 
should give same results." ) def test_simple_inference_with_text_denoiser_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights( save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results." 
) - def test_simple_inference_with_text_denoiser_lora_and_scale(self): + def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_output): """ Tests a simple inference with lora attached on the text encoder + Unet + scale argument and makes sure it works as expected """ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - + assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", + assert not np.allclose(output_lora, output_lora_scale, atol=0.001, rtol=0.001), ( + "Lora + scale should change the output" ) - attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), - "Lora + 0 scale should lead to same result as no LoRA", + assert np.allclose(base_pipe_output, output_lora_0_scale, atol=0.001, rtol=0.001), ( + "Lora + 0 scale should lead to same result as no LoRA" ) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, - "The scaling parameter has not been correctly restored!", + assert pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, ( + "The scaling parameter has not been correctly restored!" 
) - def test_simple_inference_with_text_lora_denoiser_fused(self): + def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, denoiser) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) - - # Fusing should still keep the LoRA layers if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" - ) + assert not np.allclose(output_fused, base_pipe_output, atol=0.001, rtol=0.001), "Fused lora should change the output" - def test_simple_inference_with_text_denoiser_lora_unloaded(self): + def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, denoiser) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unload_lora_weights() - # unloading should remove the LoRA layers - self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder") - self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser") - + assert not check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" + assert not check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", - ) - + assert not check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2" output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), - "Fused lora should change the output", + assert np.allclose(output_unloaded, base_pipe_output, atol=0.001, rtol=0.001), ( + "Unloading LoRA should give the same output as no LoRA" ) def test_simple_inference_with_text_denoiser_lora_unfused( - self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 + self, expected_atol: float = 0.001, expected_rtol: float = 0.001 ): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, denoiser) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + assert pipe.num_fused_loras == 1, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + assert pipe.num_fused_loras == 0, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # unloading should remove the LoRA layers if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers") - - self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers" + assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" - ) - - # Fuse and unfuse should lead to the same results - self.assertTrue( - np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" + assert np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol), ( + "Fused lora should not change the output"
) - def test_simple_inference_with_text_denoiser_multi_adapter(self): + def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_output): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set them """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", - ) + assert not np.allclose(base_pipe_output, output_adapter_1, atol=0.001, rtol=0.001), "Adapter outputs should be different." - pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", - ) + assert not np.allclose(base_pipe_output, output_adapter_2, atol=0.001, rtol=0.001), "Adapter outputs should be different." - pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", - ) + assert not np.allclose(base_pipe_output, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter outputs should be different." - - # Fuse and unfuse should lead to the same results - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) + assert not np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), "Adapter 1 and 2 should give different results" - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 1 and mixed adapters should give different results" - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 2 and mixed adapters should give different results" - pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( + "output with no lora and output with lora disabled should give same results" ) def test_wrong_adapter_name_raises_error(self): adapter_name = "adapter-1" - - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) - - with self.assertRaises(ValueError) as err_context: + with pytest.raises(ValueError) as err_context: pipe.set_adapters("test") - - self.assertTrue("not in the list of present adapters" in str(err_context.exception)) - - # test this works. + assert "not in the list of present adapters" in str(err_context.value) pipe.set_adapters(adapter_name) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] def test_multiple_wrong_adapter_name_raises_error(self): adapter_name = "adapter-1" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) - scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} logger = logging.get_logger("diffusers.loaders.lora_base") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components) - wrong_components = sorted(set(scale_with_wrong_components.keys())) msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. " - self.assertTrue(msg in str(cap_logger.out)) - - # test this works. + assert msg in str(cap_logger.out) pipe.set_adapters(adapter_name) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_simple_inference_with_text_denoiser_block_scale(self): + def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output): """ Tests a simple inference with lora attached to text encoder and unet, attaches one adapter and set different weights for different blocks (i.e.
block lora) """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" weights_1 = {"text_encoder": 2, "unet": {"down": 5}} pipe.set_adapters("adapter-1", weights_1) output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - weights_2 = {"unet": {"up": 5}} pipe.set_adapters("adapter-1", weights_2) output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse( - np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), - "LoRA weights 1 and 2 should give different results", - ) + assert not np.allclose(output_weights_1, output_weights_2, atol=0.001, rtol=0.001), "LoRA weights 1 and 2 should give different results" - self.assertFalse( - np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3), - "No adapter and LoRA weights 1 should give different results", - ) + assert not np.allclose(base_pipe_output, output_weights_1, atol=0.001, rtol=0.001), "No adapter and LoRA weights 1 should give different results" - self.assertFalse( - np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3), - "No adapter and LoRA weights 2 should give different results", - ) + assert not np.allclose(base_pipe_output, output_weights_2, atol=0.001, rtol=0.001), "No adapter and LoRA weights 2 should give different results" - pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( + "output with no lora and output with lora disabled should give same results" ) - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base_pipe_output): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set different weights for different blocks (i.e.
block lora) """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" scales_1 = {"text_encoder": 2, "unet": {"down": 5}} scales_2 = {"unet": {"down": 5, "mid": 5}} - pipe.set_adapters("adapter-1", scales_1) output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters("adapter-2", scales_2) output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # Fuse and unfuse should lead to the same results - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) + assert not np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), "Adapter 1 and 2 should give different results" - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 1 and mixed adapters should give different results" - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 2 and mixed adapters should give different results" - pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( + "output with no lora and output with lora disabled should give same results" ) - - # a mismatching number of adapter_names and adapter_weights should raise an error - with self.assertRaises(ValueError): + with
pytest.raises(ValueError): pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1]) def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): @@ -1130,13 +865,11 @@ def updown_options(blocks_with_tf, layers_per_block, value): """ num_val = value list_val = [value] * layers_per_block - node_opts = [None, num_val, list_val] node_opts_foreach_block = [node_opts] * len(blocks_with_tf) - updown_opts = [num_val] for nodes in product(*node_opts_foreach_block): - if all(n is None for n in nodes): + if all((n is None for n in nodes)): continue opt = {} for b, n in zip(blocks_with_tf, nodes): @@ -1150,30 +883,24 @@ def all_possible_dict_opts(unet, value): Generate every possible combination for how a lora weight dict can be. E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ... """ - - down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")] - up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")] - + down_blocks_with_tf = [i for (i, d) in enumerate(unet.down_blocks) if hasattr(d, "attentions")] + up_blocks_with_tf = [i for (i, u) in enumerate(unet.up_blocks) if hasattr(u, "attentions")] layers_per_block = unet.config.layers_per_block - text_encoder_opts = [None, value] text_encoder_2_opts = [None, value] mid_opts = [None, value] down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value) up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value) - opts = [] - for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts): - if all(o is None for o in (t1, t2, d, m, u)): + if all((o is None for o in (t1, t2, d, m, u))): continue opt = {} if t1 is not None: opt["text_encoder"] = t1 if t2 is not None: opt["text_encoder_2"] = t2 - if all(o is None for o in (d, m, u)): - # no unet scaling + if all((o is None for o in (d, m, u))): continue opt["unet"] = {} if d is not None: @@ -1183,194 +910,143 @@ def all_possible_dict_opts(unet, value): if u is not None: opt["unet"]["up"] = u opts.append(opt) - return opts - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls) + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components(self.scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") - if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - for scale_dict in all_possible_dict_opts(pipe.unet, value=1234): - # test if lora block scales can be set with this scale_dict if not self.has_two_text_encoders and "text_encoder_2" in scale_dict: del scale_dict["text_encoder_2"] + pipe.set_adapters("adapter-1", scale_dict) - pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error - - def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): + def 
test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set/delete them """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) + assert not np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), "Adapter 1 and 2 should give different results" - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 1 and mixed adapters should give different results" - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 2 and mixed adapters should give different results" - pipe.delete_adapters("adapter-1") output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + assert np.allclose(output_deleted_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), ( + "Deleting adapter 1 should give the same result as using adapter 2 alone" ) - pipe.delete_adapters("adapter-2") output_deleted_adapters = pipe(**inputs,
generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_deleted_adapters, atol=0.001, rtol=0.001), ( + "output with no lora and output with lora disabled should give same results" ) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." pipe.set_adapters(["adapter-1", "adapter-2"]) pipe.delete_adapters(["adapter-1", "adapter-2"]) - output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_deleted_adapters, atol=0.001, rtol=0.001), ( + "output with no lora and output with lora disabled should give same results" ) - def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): + def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_pipe_output): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set them """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # Fuse and unfuse should lead to the same results - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) + assert not np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), "Adapter 1 and 2 should give different results" - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 1 and mixed adapters should give different results" - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 2 and mixed adapters should give different results" - pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6]) output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse( - np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Weighted adapter and mixed adapter should give different results", - ) + assert not np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=0.001, rtol=0.001), "Weighted adapter and mixed adapter should give different results" - pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( + "output with no lora and output with lora disabled should give same results" ) @skip_mps @@ -1380,28 +1056,24 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): strict=False, ) def test_lora_fuse_nan(self): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") -
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - # corrupt one LoRA weight with `inf` values + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." with torch.no_grad(): if self.unet_kwargs: pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float( "inf" ) else: - named_modules = [name for name, _ in pipe.transformer.named_modules()] + named_modules = [name for (name, _) in pipe.transformer.named_modules()] possible_tower_names = [ "transformer_blocks", "blocks", @@ -1416,279 +1088,215 @@ def test_lora_fuse_nan(self): raise ValueError(reason) for tower_name in filtered_tower_names: transformer_tower = getattr(pipe.transformer, tower_name) - has_attn1 = any("attn1" in name for name in named_modules) + has_attn1 = any(("attn1" in name for name in named_modules)) if has_attn1: transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") else: transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") - - # with `safe_fusing=True` we should see an Error - with self.assertRaises(ValueError): + with pytest.raises(ValueError): pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) - - # without we should not see an error, but every image will be black pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) out = pipe(**inputs)[0] - - self.assertTrue(np.isnan(out).all()) + assert np.isnan(out).all() def test_get_adapters(self): """ Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") - adapter_names = pipe.get_active_adapters() - self.assertListEqual(adapter_names, ["adapter-1"]) - + assert adapter_names == ["adapter-1"] pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - adapter_names = pipe.get_active_adapters() - self.assertListEqual(adapter_names, ["adapter-2"]) - + assert adapter_names == ["adapter-2"] pipe.set_adapters(["adapter-1", "adapter-2"]) - self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"]) + assert sorted(pipe.get_active_adapters()) == ["adapter-1", "adapter-2"] def test_get_list_adapters(self): """ Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - - # 1. 
dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") dicts_to_be_checked = {"text_encoder": ["adapter-1"]} - if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") dicts_to_be_checked.update({"unet": ["adapter-1"]}) else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") dicts_to_be_checked.update({"transformer": ["adapter-1"]}) - - self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) - - # 2. + assert pipe.get_list_adapters() == dicts_to_be_checked dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} - if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) - - self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) - - # 3. + assert pipe.get_list_adapters() == dicts_to_be_checked pipe.set_adapters(["adapter-1", "adapter-2"]) - dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} - if self.unet_kwargs is not None: dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) else: dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) - - self.assertDictEqual( - pipe.get_list_adapters(), - dicts_to_be_checked, - ) - - # 4. + assert pipe.get_list_adapters() == dicts_to_be_checked dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} - if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-3") dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]}) else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3") dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]}) - - self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) + assert pipe.get_list_adapters() == dicts_to_be_checked def test_simple_inference_with_text_lora_denoiser_fused_multi( - self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 + self, expected_atol: float = 0.001, expected_rtol: float = 0.001 ): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet and multi-adapter case """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" 
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." denoiser.add_adapter(denoiser_lora_config, "adapter-2") - if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - - # set them to multi-adapter inference mode pipe.set_adapters(["adapter-1", "adapter-2"]) outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters(["adapter-1"]) outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"]) - self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") - - # Fusing should still keep the LoRA layers so output should remain the same + assert pipe.num_fused_loras == 1, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", + assert np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), ( + "Fused lora should not change the output" ) - pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") - + assert pipe.num_fused_loras == 0, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers") - - self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers" + assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]) - self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") - - # Fusing should still keep the LoRA layers + assert pipe.num_fused_loras == 2, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) output_all_lora_fused = pipe(**inputs, 
generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", + assert np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol), ( + "Fused lora should not change the output" ) pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + assert pipe.num_fused_loras == 0, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) - def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3): + def test_lora_scale_kwargs_match_fusion( + self, base_pipe_output, expected_atol: float = 0.001, expected_rtol: float = 0.001 + ): attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - for lora_scale in [1.0, 0.8]: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
if self.has_two_text_encoders or self.has_three_text_encoders:
                 lora_loadable_components = self.pipeline_class._lora_loadable_modules
                 if "text_encoder_2" in lora_loadable_components:
                     pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder_2),
-                        "Lora not correctly set in text encoder 2",
-                    )
-
+                    assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             pipe.set_adapters(["adapter-1"])
             attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
             outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-
             pipe.fuse_lora(
                 components=self.pipeline_class._lora_loadable_modules,
                 adapter_names=["adapter-1"],
                 lora_scale=lora_scale,
             )
-            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
-
+            assert pipe.num_fused_loras == 1, (
+                f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}"
+            )
             outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            self.assertTrue(
-                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
-                "Fused lora should not change the output",
+            assert np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), (
+                "Fused lora should not change the output"
             )
-            self.assertFalse(
-                np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
-                "LoRA should change the output",
-            )
+            assert not np.allclose(base_pipe_output, outputs_lora_1, atol=expected_atol, rtol=expected_rtol), (
+                "LoRA should change the output"
+            )
 
     def test_simple_inference_with_dora(self):
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
+        (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components(use_dora=True)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
+        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
         output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_dora_lora.shape == self.output_shape)
-        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
-
+        assert output_no_dora_lora.shape == self.output_shape
+        (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
         output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        self.assertFalse(
-            np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
-            "DoRA lora should change the output",
-        )
+        assert not np.allclose(output_dora_lora, output_no_dora_lora, atol=0.001, rtol=0.001), (
+            "DoRA lora should change the output"
+        )
 
     def test_missing_keys_warning(self):
-        # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, denoiser_lora_config = self.get_dummy_components()
+        (components, _, denoiser_lora_config) = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
+        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
with tempfile.TemporaryDirectory() as tmpdirname:
             modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -1696,35 +1304,25 @@ def test_missing_keys_warning(self):
                 save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
             )
             pipe.unload_lora_weights()
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
             state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
-
-        # To make things dynamic since we cannot settle with a single key for all the models where we
-        # offer PEFT support.
         missing_key = [k for k in state_dict if "lora_A" in k][0]
         del state_dict[missing_key]
-
         logger = logging.get_logger("diffusers.utils.peft_utils")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
-
-        # Since the missing key won't contain the adapter name ("default_0").
-        # Also strip out the component prefix (such as "unet." from `missing_key`).
         component = list({k.split(".")[0] for k in state_dict})[0]
-        self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
+        assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "")

     def test_unexpected_keys_warning(self):
-        # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, denoiser_lora_config = self.get_dummy_components()
+        (components, _, denoiser_lora_config) = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
+        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -1732,18 +1330,15 @@ def test_unexpected_keys_warning(self): save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts ) pipe.unload_lora_weights() - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) - unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat" state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device) - logger = logging.get_logger("diffusers.utils.peft_utils") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(state_dict) - - self.assertTrue(".diffusers_cat" in cap_logger.out) + assert ".diffusers_cat" in cap_logger.out @unittest.skip("This is failing for now - need to investigate") def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): @@ -1751,20 +1346,16 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) - if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True) - - # Just makes sure it works. _ = pipe(**inputs, generator=torch.manual_seed(0))[0] def test_modify_padding_mode(self): @@ -1773,40 +1364,31 @@ def set_pad_mode(network, mode="circular"): if isinstance(module, torch.nn.Conv2d): module.padding_mode = mode - components, _, _ = self.get_dummy_components() + (components, _, _) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) _pad_mode = "circular" set_pad_mode(pipe.vae, _pad_mode) set_pad_mode(pipe.unet, _pad_mode) - - _, _, inputs = self.get_dummy_inputs() + (_, _, inputs) = self.get_dummy_inputs() _ = pipe(**inputs)[0] - def test_logs_info_when_no_lora_keys_found(self): - # Skip text encoder check for now as that is handled with `transformers`. 
-        components, _, _ = self.get_dummy_components()
+    def test_logs_info_when_no_lora_keys_found(self, base_pipe_output):
+        (components, _, _) = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        output_no_lora = self.get_base_pipe_output()
-
+        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
         no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
         logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(logging.WARNING)
-
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(no_op_state_dict)
             out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
         denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
-        self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
-        self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
-
-        # test only for text encoder
+        assert cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}")
+        assert np.allclose(base_pipe_output, out_after_lora_attempt, atol=1e-05, rtol=1e-05)
         for lora_module in self.pipeline_class._lora_loadable_modules:
             if "text_encoder" in lora_module:
                 text_encoder = getattr(pipe, lora_module)
@@ -1814,101 +1396,76 @@ def test_logs_info_when_no_lora_keys_found(self):
                     prefix = "text_encoder"
                 elif lora_module == "text_encoder_2":
                     prefix = "text_encoder_2"
-
                 logger = logging.get_logger("diffusers.loaders.lora_base")
                 logger.setLevel(logging.WARNING)
-
                 with CaptureLogger(logger) as cap_logger:
                     self.pipeline_class.load_lora_into_text_encoder(
                         no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
                     )
+                assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")

-                self.assertTrue(
-                    cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
-                )
-
-    def test_set_adapters_match_attention_kwargs(self):
+    def test_set_adapters_match_attention_kwargs(self, base_pipe_output):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+        (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        output_no_lora = self.get_base_pipe_output()
-        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
-
+        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
+        (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
         lora_scale = 0.5
         attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
         output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-        self.assertFalse(
-            np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
-            "Lora + scale should change the output",
-        )
+        assert not np.allclose(base_pipe_output, output_lora_scale, atol=0.001, rtol=0.001), (
+            "Lora + scale should change the output"
+        )
-
         pipe.set_adapters("default", lora_scale)
         output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(
-
not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", + assert not np.allclose(base_pipe_output, output_lora_scale_wo_kwargs, atol=0.001, rtol=0.001), ( + "Lora + scale should change the output" ) - self.assertTrue( - np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), - "Lora + scale should match the output of `set_adapters()`.", + assert np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=0.001, rtol=0.001), ( + "Lora + scale should match the output of `set_adapters()`." ) - with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights( save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - self.assertTrue( - not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", + assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Lora + scale should change the output" ) - self.assertTrue( - np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results as attention_kwargs.", + assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results as attention_kwargs." ) - self.assertTrue( - np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results as set_adapters().", + assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results as set_adapters()." ) @require_peft_version_greater("0.13.2") def test_lora_B_bias(self): - # Currently, this test is only relevant for Flux Control LoRA as we are not - # aware of any other LoRA checkpoint that has its `lora_B` biases trained. - components, _, denoiser_lora_config = self.get_dummy_components() + (components, _, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - - # keep track of the bias values of the base layers to perform checks later. 
bias_values = {} denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer for name, module in denoiser.named_modules(): - if any(k in name for k in self.denoiser_target_modules): + if any((k in name for k in self.denoiser_target_modules)): if module.bias is not None: bias_values[name] = module.bias.data.clone() - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - denoiser_lora_config.lora_bias = False if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") @@ -1916,84 +1473,66 @@ def test_lora_B_bias(self): pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0] pipe.delete_adapters("adapter-1") - denoiser_lora_config.lora_bias = True if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) + assert not np.allclose(original_output, lora_bias_false_output, atol=0.001, rtol=0.001) + assert not np.allclose(original_output, lora_bias_true_output, atol=0.001, rtol=0.001) + assert not np.allclose(lora_bias_false_output, lora_bias_true_output, atol=0.001, rtol=0.001) def test_correct_lora_configs_with_different_ranks(self): - components, _, denoiser_lora_config = self.get_dummy_components() + (components, _, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] - if self.unet_kwargs is not None: pipe.unet.delete_adapters("adapter-1") else: pipe.transformer.delete_adapters("adapter-1") - denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer for name, _ in denoiser.named_modules(): - if "to_k" in name and "attn" in name and "lora" not in name: + if "to_k" in name and "attn" in name and ("lora" not in name): module_name_to_rank_update = name.replace(".base_layer.", ".") break - - # change the rank_pattern updated_rank = denoiser_lora_config.r * 2 denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} - if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern - - self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank}) - + assert updated_rank_pattern == {module_name_to_rank_update: updated_rank} 
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - + assert not np.allclose(original_output, lora_output_same_rank, atol=0.001, rtol=0.001) + assert not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=0.001, rtol=0.001) if self.unet_kwargs is not None: pipe.unet.delete_adapters("adapter-1") else: pipe.transformer.delete_adapters("adapter-1") - - # similarly change the alpha_pattern updated_alpha = denoiser_lora_config.lora_alpha * 2 denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue( - pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} - ) + assert pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue( - pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} - ) - + assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == { + module_name_to_rank_update: updated_alpha + } lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) - self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + assert not np.allclose(original_output, lora_output_diff_alpha, atol=0.001, rtol=0.001) + assert not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=0.001, rtol=0.001) def test_layerwise_casting_inference_denoiser(self): from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS @@ -2007,7 +1546,7 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): continue dtype_to_check = storage_dtype - if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check): + if "lora" in name or any((re.search(pattern, name) for pattern in patterns_to_check)): dtype_to_check = compute_dtype if getattr(submodule, "weight", None) is not None: self.assertEqual(submodule.weight.dtype, dtype_to_check) @@ -2015,33 +1554,27 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): self.assertEqual(submodule.bias.dtype, dtype_to_check) def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device, dtype=compute_dtype) pipe.set_progress_bar_config(disable=None) - - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - + (pipe, denoiser) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) if storage_dtype is not None: denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) check_linear_dtype(denoiser, storage_dtype, compute_dtype) - return pipe - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) pipe_fp32 = 
initialize_pipeline(storage_dtype=None)
         pipe_fp32(**inputs, generator=torch.manual_seed(0))[0]
-
         pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
         pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]
-
         pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
         pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
 
     @require_peft_version_greater("0.14.0")
     def test_layerwise_casting_peft_input_autocast_denoiser(self):
-        r"""
+        """
         A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This
         is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise
         cast hooks on the PEFT layers (relevant logic in `models.modeling_utils.ModelMixin.enable_layerwise_casting`).
@@ -2054,7 +1587,6 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
 
         See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
         """
-
         from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
         from diffusers.hooks.layerwise_casting import (
             _PEFT_AUTOCAST_DISABLE_HOOK,
@@ -2066,58 +1598,48 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
         compute_dtype = torch.float32
 
         def check_module(denoiser):
-            # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
             for name, module in denoiser.named_modules():
                 if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                     continue
                 dtype_to_check = storage_dtype
-                if any(re.search(pattern, name) for pattern in patterns_to_check):
+                if any((re.search(pattern, name) for pattern in patterns_to_check)):
                     dtype_to_check = compute_dtype
                 if getattr(module, "weight", None) is not None:
                     self.assertEqual(module.weight.dtype, dtype_to_check)
                 if getattr(module, "bias", None) is not None:
                     self.assertEqual(module.bias.dtype, dtype_to_check)
                 if isinstance(module, BaseTunerLayer):
-                    self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
-                    self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
+                    assert getattr(module, "_diffusers_hook", None) is not None
+                    assert module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None
 
-        # 1. Test forward with add_adapter
-        components, _, denoiser_lora_config = self.get_dummy_components()
+        (components, _, denoiser_lora_config) = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device, dtype=compute_dtype)
         pipe.set_progress_bar_config(disable=None)
-
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
+        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
         patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
         if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
            patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)
-
         apply_layerwise_casting(
             denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
         )
         check_module(denoiser)
-
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
         pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        # 2.
Test forward with load_lora_weights with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights( save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - components, _, _ = self.get_dummy_components() + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + (components, _, _) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device, dtype=compute_dtype) pipe.set_progress_bar_config(disable=None) pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet apply_layerwise_casting( denoiser, @@ -2126,67 +1648,58 @@ def check_module(denoiser): skip_modules_pattern=patterns_to_check, ) check_module(denoiser) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) pipe(**inputs, generator=torch.manual_seed(0))[0] @parameterized.expand([4, 8, 16]) def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components) - - pipe, _ = self.add_adapters_to_pipeline( + (pipe, _) = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) - with tempfile.TemporaryDirectory() as tmpdir: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) pipe.unload_lora_weights() - out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True) if len(out) == 3: - _, _, parsed_metadata = out + (_, _, parsed_metadata) = out elif len(out) == 2: - _, parsed_metadata = out - + (_, parsed_metadata) = out denoiser_key = ( f"{self.pipeline_class.transformer_name}" if self.transformer_kwargs is not None else f"{self.pipeline_class.unet_name}" ) - self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) + assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) check_module_lora_metadata( parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key ) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: text_encoder_key = self.pipeline_class.text_encoder_name - self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) + assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) check_module_lora_metadata( parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key ) - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: text_encoder_2_key = "text_encoder_2" - self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) + assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) check_module_lora_metadata( parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, 
module_key=text_encoder_2_key ) @parameterized.expand([4, 8, 16]) def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components).to(torch_device) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdir: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -2194,109 +1707,84 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) pipe.unload_lora_weights() pipe.load_lora_weights(tmpdir) - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." + assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), ( + "Lora outputs should match." ) def test_lora_unload_add_adapter(self): """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components).to(torch_device) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + (pipe, _) = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # unload and then add. pipe.unload_lora_weights() - pipe, _ = self.add_adapters_to_pipeline( + (pipe, _) = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_inference_load_delete_load_adapters(self): - "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." 
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + def test_inference_load_delete_load_adapters(self, base_pipe_output): + """Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works.""" + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - - # First, delete adapter and compare. + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) pipe.delete_adapters(pipe.get_active_adapters()[0]) output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3)) - self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3)) - - # Then load adapter and compare. 
+ assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001) + assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001) pipe.load_lora_weights(tmpdirname) output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3)) + assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001) def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook onload_device = torch_device offload_device = torch.device("cpu") - - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights( save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - - components, _, _ = self.get_dummy_components() + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + (components, _, _) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) check_if_lora_correctly_set(denoiser) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - # Test group offloading with load_lora_weights + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) denoiser.enable_group_offload( onload_device=onload_device, offload_device=offload_device, @@ -2304,66 +1792,53 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): num_blocks_per_group=1, use_stream=use_stream, ) - # Place other model-level components on `torch_device`. 
for _, component in pipe.components.items(): if isinstance(component, torch.nn.Module): component.to(torch_device) group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) - self.assertTrue(group_offload_hook_1 is not None) + assert group_offload_hook_1 is not None output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # Test group offloading after removing the lora pipe.unload_lora_weights() group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) - self.assertTrue(group_offload_hook_2 is not None) + assert group_offload_hook_2 is not None output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 - - # Add the lora again and check if group offloading works pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) check_if_lora_correctly_set(denoiser) group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) - self.assertTrue(group_offload_hook_3 is not None) + assert group_offload_hook_3 is not None output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3)) + assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001) @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) @require_torch_accelerator def test_group_offloading_inference_denoiser(self, offload_type, use_stream): for cls in inspect.getmro(self.__class__): if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: - # Skip this test if it is overwritten by child class. We need to do this because parameterized - # materializes the test methods on invocation which cannot be overridden. return self._test_group_offloading_inference_denoiser(offload_type, use_stream) @require_torch_accelerator def test_lora_loading_model_cpu_offload(self): - components, _, denoiser_lora_config = self.get_dummy_components() - _, _, inputs = self.get_dummy_inputs(with_generator=False) + (components, _, denoiser_lora_config) = self.get_dummy_components() + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights( save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts ) - # reinitialize the pipeline to mimic the inference workflow. - components, _, denoiser_lora_config = self.get_dummy_components() + (components, _, denoiser_lora_config) = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.enable_model_cpu_offload(device=torch_device) pipe.load_lora_weights(tmpdirname) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3)) + assert np.allclose(output_lora, output_lora_loaded, atol=0.001, rtol=0.001) diff --git a/tests/lora/utils.py.bak b/tests/lora/utils.py.bak new file mode 100644 index 000000000000..077eb202c47c --- /dev/null +++ b/tests/lora/utils.py.bak @@ -0,0 +1,2328 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import os +import re +import tempfile +import unittest +from itertools import product + +import numpy as np +import pytest +import torch +from parameterized import parameterized + +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, +) +from diffusers.utils import logging +from diffusers.utils.import_utils import is_peft_available + +from ..testing_utils import ( + CaptureLogger, + check_if_dicts_are_equal, + floats_tensor, + is_torch_version, + require_peft_backend, + require_peft_version_greater, + require_torch_accelerator, + require_transformers_version_greater, + skip_mps, + torch_device, +) + + +if is_peft_available(): + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + from peft.tuners.tuners_utils import BaseTunerLayer + from peft.utils import get_peft_model_state_dict + + +def state_dicts_almost_equal(sd1, sd2): + sd1 = dict(sorted(sd1.items())) + sd2 = dict(sorted(sd2.items())) + + models_are_equal = True + for ten1, ten2 in zip(sd1.values(), sd2.values()): + if (ten1 - ten2).abs().max() > 1e-3: + models_are_equal = False + + return models_are_equal + + +def check_if_lora_correctly_set(model) -> bool: + """ + Checks if the LoRA layers are correctly set with peft + """ + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + return True + return False + + +def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str): + extracted = { + k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.") + } + check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"]) + + +def initialize_dummy_state_dict(state_dict): + if not all(v.device.type == "meta" for _, v in state_dict.items()): + raise ValueError("`state_dict` has non-meta values.") + return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()} + + +POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"] + + +def determine_attention_kwargs_name(pipeline_class): + call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys() + + # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release + for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES: + if possible_attention_kwargs in call_signature_keys: + attention_kwargs_name = possible_attention_kwargs + break + assert attention_kwargs_name is not None + 
return attention_kwargs_name + + +@require_peft_backend +class PeftLoraLoaderMixinTests: + pipeline_class = None + + scheduler_cls = None + scheduler_kwargs = None + + has_two_text_encoders = False + has_three_text_encoders = False + text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, "" + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, "" + text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, "" + tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" + tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" + + unet_kwargs = None + transformer_cls = None + transformer_kwargs = None + vae_cls = AutoencoderKL + vae_kwargs = None + + text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] + denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + + @pytest.fixture(scope="class") + def base_pipe_output(self): + return self._compute_baseline_output() + + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): + if self.unet_kwargs and self.transformer_kwargs: + raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") + if self.has_two_text_encoders and self.has_three_text_encoders: + raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.") + + scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls + rank = 4 + lora_alpha = rank if lora_alpha is None else lora_alpha + + torch.manual_seed(0) + if self.unet_kwargs is not None: + unet = UNet2DConditionModel(**self.unet_kwargs) + else: + transformer = self.transformer_cls(**self.transformer_kwargs) + + scheduler = scheduler_cls(**self.scheduler_kwargs) + + torch.manual_seed(0) + vae = self.vae_cls(**self.vae_kwargs) + + text_encoder = self.text_encoder_cls.from_pretrained( + self.text_encoder_id, subfolder=self.text_encoder_subfolder + ) + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder) + + if self.text_encoder_2_cls is not None: + text_encoder_2 = self.text_encoder_2_cls.from_pretrained( + self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder + ) + tokenizer_2 = self.tokenizer_2_cls.from_pretrained( + self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder + ) + + if self.text_encoder_3_cls is not None: + text_encoder_3 = self.text_encoder_3_cls.from_pretrained( + self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder + ) + tokenizer_3 = self.tokenizer_3_cls.from_pretrained( + self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder + ) + + text_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=self.text_encoder_target_modules, + init_lora_weights=False, + use_dora=use_dora, + ) + + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=self.denoiser_target_modules, + init_lora_weights=False, + use_dora=use_dora, + ) + + pipeline_components = { + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + # Denoiser + if self.unet_kwargs is not None: + pipeline_components.update({"unet": unet}) + elif self.transformer_kwargs is not None: + pipeline_components.update({"transformer": transformer}) + + # Remaining text encoders. 
+ if self.text_encoder_2_cls is not None: + pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2}) + if self.text_encoder_3_cls is not None: + pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3}) + + # Remaining stuff + init_params = inspect.signature(self.pipeline_class.__init__).parameters + if "safety_checker" in init_params: + pipeline_components.update({"safety_checker": None}) + if "feature_extractor" in init_params: + pipeline_components.update({"feature_extractor": None}) + if "image_encoder" in init_params: + pipeline_components.update({"image_encoder": None}) + + return pipeline_components, text_lora_config, denoiser_lora_config + + @property + def output_shape(self): + raise NotImplementedError + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 5, + "guidance_scale": 6.0, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def _compute_baseline_output(self): + components, _, _ = self.get_dummy_components(self.scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # Always ensure the inputs are without the `generator`. Make sure to pass the `generator` + # explicitly. + _, _, inputs = self.get_dummy_inputs(with_generator=False) + return pipe(**inputs, generator=torch.manual_seed(0))[0] + + def _get_lora_state_dicts(self, modules_to_save): + state_dicts = {} + for module_name, module in modules_to_save.items(): + if module is not None: + state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) + return state_dicts + + def _get_lora_adapter_metadata(self, modules_to_save): + metadatas = {} + for module_name, module in modules_to_save.items(): + if module is not None: + metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict() + return metadatas + + def _get_modules_to_save(self, pipe, has_denoiser=False): + modules_to_save = {} + lora_loadable_modules = self.pipeline_class._lora_loadable_modules + + if ( + "text_encoder" in lora_loadable_modules + and hasattr(pipe, "text_encoder") + and getattr(pipe.text_encoder, "peft_config", None) is not None + ): + modules_to_save["text_encoder"] = pipe.text_encoder + + if ( + "text_encoder_2" in lora_loadable_modules + and hasattr(pipe, "text_encoder_2") + and getattr(pipe.text_encoder_2, "peft_config", None) is not None + ): + modules_to_save["text_encoder_2"] = pipe.text_encoder_2 + + if has_denoiser: + if "unet" in lora_loadable_modules and hasattr(pipe, "unet"): + modules_to_save["unet"] = pipe.unet + + if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"): + modules_to_save["transformer"] = pipe.transformer + + return modules_to_save + + def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): + if text_lora_config is not None: + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) + assert 
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+
+        if denoiser_lora_config is not None:
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
+            assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
+        else:
+            denoiser = None
+
+        if text_lora_config is not None and (self.has_two_text_encoders or self.has_three_text_encoders):
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
+                assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+
+        return pipe, denoiser
+
+    def test_simple_inference(self, base_pipe_output):
+        """
+        Tests a simple inference and makes sure it works as expected
+        """
+        assert base_pipe_output.shape == self.output_shape
+
+    def test_simple_inference_with_text_lora(self, base_pipe_output):
+        """
+        Tests a simple inference with lora attached on the text encoder
+        and makes sure it works as expected
+        """
+        components, text_lora_config, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output"
+
+    @require_peft_version_greater("0.13.1")
+    def test_low_cpu_mem_usage_with_injection(self):
+        """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
+            assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
+            assert (
+                "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
+                "The LoRA params should be on 'meta' device.",
+            )
+
+            te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
+            set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
+            assert (
+                "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
+                "No param should be on 'meta' device.",
+            )
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
+        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
+        assert (
+            "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
+        )
+
+        denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
+        set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
+        assert (
+            "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
+ ) + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True) + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + assert ( + "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, + "The LoRA params should be on 'meta' device.", + ) + + te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2)) + set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True) + assert ( + "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, + "No param should be on 'meta' device.", + ) + + _, _, inputs = self.get_dummy_inputs() + output_lora = pipe(**inputs)[0] + assert output_lora.shape == self.output_shape + + @require_peft_version_greater("0.13.1") + @require_transformers_version_greater("4.45.2") + def test_low_cpu_mem_usage_with_loading(self): + """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts + ) + + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) + + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", + ) + + # Now, check for `low_cpu_mem_usage.` + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) + + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + + images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", + ) + + def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output): + """ + Tests a simple inference with lora attached on the text encoder + scale argument + and makes sure it works as expected + """ + attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) + components, text_lora_config, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, 
_, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" + ) + + attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} + output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + + assert ( + not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), + "Lora + scale should change the output", + ) + + attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} + output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + + assert ( + np.allclose(base_pipe_output, output_lora_0_scale, atol=1e-3, rtol=1e-3), + "Lora + 0 scale should lead to same result as no LoRA", + ) + + def test_simple_inference_with_text_lora_fused(self, base_pipe_output): + """ + Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model + and makes sure it works as expected + """ + components, text_lora_config, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + pipe.fuse_lora() + # Fusing should still keep the LoRA layers + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not( + np.allclose(ouput_fused, base_pipe_output, atol=1e-3, rtol=1e-3), "Fused lora should change the output" + ) + + def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output): + """ + Tests a simple inference with lora attached to text encoder, then unloads the lora weights + and makes sure it works as expected + """ + components, text_lora_config, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + pipe.unload_lora_weights() + # unloading should remove the LoRA layers + assert not(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder") + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + assert not( + check_if_lora_correctly_set(pipe.text_encoder_2), + "Lora not correctly unloaded in text encoder 2", + ) + + ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + np.allclose(ouput_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3), + "Fused lora should change the output", + ) + + def test_simple_inference_with_text_lora_save_load(self): + """ + Tests a simple usecase where users could use saving utilities for LoRA. 
+ """ + components, text_lora_config, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts + ) + + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", + ) + + def test_simple_inference_with_partial_text_lora(self, base_pipe_output): + """ + Tests a simple inference with lora attached on the text encoder + with different ranks and some adapters removed + and makes sure it works as expected + """ + components, _, _ = self.get_dummy_components() + # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). + text_lora_config = LoraConfig( + r=4, + rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)}, + lora_alpha=4, + target_modules=self.text_encoder_target_modules, + init_lora_weights=False, + use_dora=False, + ) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + state_dict = {} + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` + # supports missing layers (PR#8324). 
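+            # Keys from `get_peft_model_state_dict` are re-prefixed with the component name below,
+            # e.g. (illustrative) "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight".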
+ state_dict = { + f"text_encoder.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() + if "text_model.encoder.layers.4" not in module_name + } + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + state_dict.update( + { + f"text_encoder_2.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() + if "text_model.encoder.layers.4" not in module_name + } + ) + + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" + ) + + # Unload lora and load it back using the pipe.load_lora_weights machinery + pipe.unload_lora_weights() + pipe.load_lora_weights(state_dict) + + output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), + "Removing adapters should change the output", + ) + + def test_simple_inference_save_pretrained_with_text_lora(self): + """ + Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained + """ + components, text_lora_config, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + + pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) + pipe_from_pretrained.to(torch_device) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + assert ( + check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), + "Lora not correctly set in text encoder", + ) + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + assert ( + check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), + "Lora not correctly set in text encoder 2", + ) + + images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", + ) + + def test_simple_inference_with_text_denoiser_lora_save_load(self): + """ + Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts + ) + + assert os.path.isfile(os.path.join(tmpdirname, 
"pytorch_lora_weights.bin")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", + ) + + def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_output): + """ + Tests a simple inference with lora attached on the text encoder + Unet + scale argument + and makes sure it works as expected + """ + attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" + ) + + attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} + output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + + assert ( + not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), + "Lora + scale should change the output", + ) + + attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} + output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + + assert ( + np.allclose(base_pipe_output, output_lora_0_scale, atol=1e-3, rtol=1e-3), + "Lora + 0 scale should lead to same result as no LoRA", + ) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + assert ( + pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, + "The scaling parameter has not been correctly restored!", + ) + + def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output): + """ + Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model + and makes sure it works as expected - with unet + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) + + # Fusing should still keep the LoRA layers + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser" + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not( + np.allclose(output_fused, 
base_pipe_output, atol=1e-3, rtol=1e-3), "Fused lora should change the output" + ) + + def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output): + """ + Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights + and makes sure it works as expected + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + pipe.unload_lora_weights() + # unloading should remove the LoRA layers + assert not check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" + assert not check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser" + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + assert not( + check_if_lora_correctly_set(pipe.text_encoder_2), + "Lora not correctly unloaded in text encoder 2", + ) + + output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + np.allclose(output_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3), + "Fused lora should change the output", + ) + + def test_simple_inference_with_text_denoiser_lora_unfused( + self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 + ): + """ + Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights + and makes sure it works as expected + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) + assert pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" + output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) + assert pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" + output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + # unloading should remove the LoRA layers + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers" + + assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers" + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" + ) + + # Fuse and unfuse should lead to the same results + assert ( + np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol), + "Fused lora should not change the output", + ) + + def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_output): + """ + Tests a simple inference with lora attached to text encoder and unet, attaches + multiple adapters and set them + """ + components, text_lora_config, denoiser_lora_config = 
self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + pipe.set_adapters("adapter-1") + output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not( + np.allclose(base_pipe_output, output_adapter_1, atol=1e-3, rtol=1e-3), + "Adapter outputs should be different.", + ) + + pipe.set_adapters("adapter-2") + output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not( + np.allclose(base_pipe_output, output_adapter_2, atol=1e-3, rtol=1e-3), + "Adapter outputs should be different.", + ) + + pipe.set_adapters(["adapter-1", "adapter-2"]) + output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not( + np.allclose(base_pipe_output, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter outputs should be different.", + ) + + # Fuse and unfuse should lead to the same results + assert not( + np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), + "Adapter 1 and 2 should give different results", + ) + + assert not( + np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 1 and mixed adapters should give different results", + ) + + assert not( + np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 2 and mixed adapters should give different results", + ) + + pipe.disable_lora() + output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), + "output with no lora and output with lora disabled should give same results", + ) + + def test_wrong_adapter_name_raises_error(self): + adapter_name = "adapter-1" + + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name + ) + + with pytest.raises(ValueError) as err_context: + pipe.set_adapters("test") + + assert "not in the list of present adapters" in str(err_context.value) + + # test this works. 
+ pipe.set_adapters(adapter_name) + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + + def test_multiple_wrong_adapter_name_raises_error(self): + adapter_name = "adapter-1" + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name + ) + + scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} + logger = logging.get_logger("diffusers.loaders.lora_base") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components) + + wrong_components = sorted(set(scale_with_wrong_components.keys())) + msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. " + assert msg in str(cap_logger.out) + + # test this works. + pipe.set_adapters(adapter_name) + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + + def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output): + """ + Tests a simple inference with lora attached to text encoder and unet, attaches + one adapter and set different weights for different blocks (i.e. block lora) + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
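+        # The scale dicts used further below target individual denoiser blocks, e.g. {"unet": {"down": 5}}
+        # scales only the down blocks while the remaining blocks keep their default scale.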
+ + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + weights_1 = {"text_encoder": 2, "unet": {"down": 5}} + pipe.set_adapters("adapter-1", weights_1) + output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + weights_2 = {"unet": {"up": 5}} + pipe.set_adapters("adapter-1", weights_2) + output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not( + np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), + "LoRA weights 1 and 2 should give different results", + ) + assert not( + np.allclose(base_pipe_output, output_weights_1, atol=1e-3, rtol=1e-3), + "No adapter and LoRA weights 1 should give different results", + ) + assert not( + np.allclose(base_pipe_output, output_weights_2, atol=1e-3, rtol=1e-3), + "No adapter and LoRA weights 2 should give different results", + ) + + pipe.disable_lora() + output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), + "output with no lora and output with lora disabled should give same results", + ) + + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base_pipe_output): + """ + Tests a simple inference with lora attached to text encoder and unet, attaches + multiple adapters and set different weights for different blocks (i.e. block lora) + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
+ + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + scales_1 = {"text_encoder": 2, "unet": {"down": 5}} + scales_2 = {"unet": {"down": 5, "mid": 5}} + + pipe.set_adapters("adapter-1", scales_1) + output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters("adapter-2", scales_2) + output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2]) + output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] + + # Fuse and unfuse should lead to the same results + assert not( + np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), + "Adapter 1 and 2 should give different results", + ) + + assert not( + np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 1 and mixed adapters should give different results", + ) + + assert not( + np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 2 and mixed adapters should give different results", + ) + + pipe.disable_lora() + output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), + "output with no lora and output with lora disabled should give same results", + ) + + # a mismatching number of adapter_names and adapter_weights should raise an error + with pytest.raises(ValueError): + pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1]) + + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" + + def updown_options(blocks_with_tf, layers_per_block, value): + """ + Generate every possible combination for how a lora weight dict for the up/down part can be. + E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ... + """ + num_val = value + list_val = [value] * layers_per_block + + node_opts = [None, num_val, list_val] + node_opts_foreach_block = [node_opts] * len(blocks_with_tf) + + updown_opts = [num_val] + for nodes in product(*node_opts_foreach_block): + if all(n is None for n in nodes): + continue + opt = {} + for b, n in zip(blocks_with_tf, nodes): + if n is not None: + opt["block_" + str(b)] = n + updown_opts.append(opt) + return updown_opts + + def all_possible_dict_opts(unet, value): + """ + Generate every possible combination for how a lora weight dict can be. + E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ... 
+ """ + + down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")] + up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")] + + layers_per_block = unet.config.layers_per_block + + text_encoder_opts = [None, value] + text_encoder_2_opts = [None, value] + mid_opts = [None, value] + down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value) + up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value) + + opts = [] + + for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts): + if all(o is None for o in (t1, t2, d, m, u)): + continue + opt = {} + if t1 is not None: + opt["text_encoder"] = t1 + if t2 is not None: + opt["text_encoder_2"] = t2 + if all(o is None for o in (d, m, u)): + # no unet scaling + continue + opt["unet"] = {} + if d is not None: + opt["unet"]["down"] = d + if m is not None: + opt["unet"]["mid"] = m + if u is not None: + opt["unet"]["up"] = u + opts.append(opt) + + return opts + + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + + if self.has_two_text_encoders or self.has_three_text_encoders: + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + + for scale_dict in all_possible_dict_opts(pipe.unet, value=1234): + # test if lora block scales can be set with this scale_dict + if not self.has_two_text_encoders and "text_encoder_2" in scale_dict: + del scale_dict["text_encoder_2"] + + pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error + + def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output): + """ + Tests a simple inference with lora attached to text encoder and unet, attaches + multiple adapters and set/delete them + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
+ + if self.has_two_text_encoders or self.has_three_text_encoders: + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + pipe.set_adapters("adapter-1") + output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters("adapter-2") + output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters(["adapter-1", "adapter-2"]) + output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not( + np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), + "Adapter 1 and 2 should give different results", + ) + + assert not( + np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 1 and mixed adapters should give different results", + ) + + assert not( + np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 2 and mixed adapters should give different results", + ) + + pipe.delete_adapters("adapter-1") + output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), + "Adapter 1 and 2 should give different results", + ) + + pipe.delete_adapters("adapter-2") + output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), + "output with no lora and output with lora disabled should give same results", + ) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
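+        # Re-adding both adapters and then deleting them in a single call should also restore the base output.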
+ + pipe.set_adapters(["adapter-1", "adapter-2"]) + pipe.delete_adapters(["adapter-1", "adapter-2"]) + + output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), + "output with no lora and output with lora disabled should give same results", + ) + + def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_pipe_output): + """ + Tests a simple inference with lora attached to text encoder and unet, attaches + multiple adapters and set them + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + + if self.has_two_text_encoders or self.has_three_text_encoders: + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + pipe.set_adapters("adapter-1") + output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters("adapter-2") + output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters(["adapter-1", "adapter-2"]) + output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] + + # Fuse and unfuse should lead to the same results + assert not( + np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), + "Adapter 1 and 2 should give different results", + ) + + assert not( + np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 1 and mixed adapters should give different results", + ) + + assert not( + np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 2 and mixed adapters should give different results", + ) + + pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6]) + output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert not( + np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Weighted adapter and mixed adapter should give different results", + ) + + pipe.disable_lora() + output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), + "output with no lora and output with lora disabled should give same results", + ) + + @skip_mps + @pytest.mark.xfail( + condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), + reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", + strict=False, + ) + def test_lora_fuse_nan(self): + 
components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + if self.unet_kwargs: + pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float( + "inf" + ) + else: + named_modules = [name for name, _ in pipe.transformer.named_modules()] + possible_tower_names = [ + "transformer_blocks", + "blocks", + "joint_transformer_blocks", + "single_transformer_blocks", + ] + filtered_tower_names = [ + tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name) + ] + if len(filtered_tower_names) == 0: + reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}." + raise ValueError(reason) + for tower_name in filtered_tower_names: + transformer_tower = getattr(pipe.transformer, tower_name) + has_attn1 = any("attn1" in name for name in named_modules) + if has_attn1: + transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") + else: + transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") + + # with `safe_fusing=True` we should see an Error + with pytest.raises(ValueError): + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) + + # without we should not see an error, but every image will be black + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) + out = pipe(**inputs)[0] + + assert np.isnan(out).all() + + def test_get_adapters(self): + """ + Tests a simple usecase where we attach multiple adapters and check if the results + are the expected results + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + + adapter_names = pipe.get_active_adapters() + assert adapter_names == ["adapter-1"] + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + + adapter_names = pipe.get_active_adapters() + assert adapter_names == ["adapter-2"] + + pipe.set_adapters(["adapter-1", "adapter-2"]) + assert sorted(pipe.get_active_adapters()) == ["adapter-1", "adapter-2"] + + def test_get_list_adapters(self): + """ + Tests a simple usecase where we attach multiple adapters and check if the results + are the expected results + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + 
pipe.set_progress_bar_config(disable=None) + + # 1. + dicts_to_be_checked = {} + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + dicts_to_be_checked = {"text_encoder": ["adapter-1"]} + + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") + dicts_to_be_checked.update({"unet": ["adapter-1"]}) + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + dicts_to_be_checked.update({"transformer": ["adapter-1"]}) + + assert pipe.get_list_adapters() == dicts_to_be_checked + + # 2. + dicts_to_be_checked = {} + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} + + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") + dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") + dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) + + assert pipe.get_list_adapters() == dicts_to_be_checked + + # 3. + pipe.set_adapters(["adapter-1", "adapter-2"]) + + dicts_to_be_checked = {} + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} + + if self.unet_kwargs is not None: + dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) + else: + dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) + + assert pipe.get_list_adapters() == dicts_to_be_checked + + # 4. + dicts_to_be_checked = {} + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} + + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-3") + dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]}) + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3") + dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]}) + + assert pipe.get_list_adapters() == dicts_to_be_checked + + def test_simple_inference_with_text_lora_denoiser_fused_multi( + self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 + ): + """ + Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model + and makes sure it works as expected - with unet and multi-adapter case + """ + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
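+        # A second adapter is attached so that fusing only a subset of the loaded adapters can be exercised below.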
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2") + + if self.has_two_text_encoders or self.has_three_text_encoders: + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + + # set them to multi-adapter inference mode + pipe.set_adapters(["adapter-1", "adapter-2"]) + outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters(["adapter-1"]) + outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"]) + assert pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" + + # Fusing should still keep the LoRA layers so output should remain the same + outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert ( + np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), + "Fused lora should not change the output", + ) + + pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) + assert pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers" + + assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers" + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + assert ( + check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" + ) + + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]) + assert pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" + + # Fusing should still keep the LoRA layers + output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( + np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol), + "Fused lora should not change the output", + ) + pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) + assert pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" + + def test_lora_scale_kwargs_match_fusion(self, base_pipe_output, expected_atol: float = 1e-3, expected_rtol: float = 1e-3): + attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) + + for lora_scale in [1.0, 0.8]: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + assert ( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    assert check_if_lora_correctly_set(pipe.text_encoder_2), (
+                        "Lora not correctly set in text encoder 2"
+                    )
+
+            pipe.set_adapters(["adapter-1"])
+            attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
+            pipe.fuse_lora(
+                components=self.pipeline_class._lora_loadable_modules,
+                adapter_names=["adapter-1"],
+                lora_scale=lora_scale,
+            )
+            assert pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}"
+
+            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            assert np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), (
+                "Fused lora should not change the output"
+            )
+            assert not np.allclose(base_pipe_output, outputs_lora_1, atol=expected_atol, rtol=expected_rtol), (
+                "LoRA should change the output"
+            )
+
+    def test_simple_inference_with_dora(self):
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert output_no_dora_lora.shape == self.output_shape
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+
+        output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        assert not np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3), (
+            "DoRA lora should change the output"
+        )
+
+    def test_missing_keys_warning(self):
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+        # To make things dynamic since we cannot settle with a single key for all the models where we
+        # offer PEFT support.
+        missing_key = [k for k in state_dict if "lora_A" in k][0]
+        del state_dict[missing_key]
+
+        logger = logging.get_logger("diffusers.utils.peft_utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(state_dict)
+
+        # Since the missing key won't contain the adapter name ("default_0").
+        # Also strip out the component prefix (such as "unet." from `missing_key`).
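+        # Here, `component` is the serialization prefix used in the saved state dict (e.g. "unet" or "transformer").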
+        component = list({k.split(".")[0] for k in state_dict})[0]
+        assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "")
+
+    def test_unexpected_keys_warning(self):
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+        unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
+        state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
+
+        logger = logging.get_logger("diffusers.utils.peft_utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(state_dict)
+
+        assert ".diffusers_cat" in cap_logger.out
+
+    @unittest.skip("This is failing for now - need to investigate")
+    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
+        """
+        Tests a simple inference with lora attached to the text encoder and unet, compiles them with
+        `torch.compile`, and makes sure it works as expected.
+        """
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+        pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
+
+        if self.has_two_text_encoders or self.has_three_text_encoders:
+            pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
+
+        # Just makes sure it works.
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+    def test_modify_padding_mode(self):
+        def set_pad_mode(network, mode="circular"):
+            for _, module in network.named_modules():
+                if isinstance(module, torch.nn.Conv2d):
+                    module.padding_mode = mode
+
+        components, _, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _pad_mode = "circular"
+        set_pad_mode(pipe.vae, _pad_mode)
+        set_pad_mode(pipe.unet, _pad_mode)
+
+        _, _, inputs = self.get_dummy_inputs()
+        _ = pipe(**inputs)[0]
+
+    def test_logs_info_when_no_lora_keys_found(self, base_pipe_output):
+        # Skip text encoder check for now as that is handled with `transformers`.
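+        # Loading a state dict with no recognizable LoRA keys should be a no-op: only an info message is
+        # logged and the output should stay identical to `base_pipe_output`.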
+        components, _, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
+        logger = logging.get_logger("diffusers.loaders.peft")
+        logger.setLevel(logging.WARNING)
+
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(no_op_state_dict)
+            out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
+        assert cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}")
+        assert np.allclose(base_pipe_output, out_after_lora_attempt, atol=1e-5, rtol=1e-5)
+
+        # test only for text encoder
+        for lora_module in self.pipeline_class._lora_loadable_modules:
+            if "text_encoder" in lora_module:
+                text_encoder = getattr(pipe, lora_module)
+                if lora_module == "text_encoder":
+                    prefix = "text_encoder"
+                elif lora_module == "text_encoder_2":
+                    prefix = "text_encoder_2"
+
+                logger = logging.get_logger("diffusers.loaders.lora_base")
+                logger.setLevel(logging.WARNING)
+
+                with CaptureLogger(logger) as cap_logger:
+                    self.pipeline_class.load_lora_into_text_encoder(
+                        no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
+                    )
+
+                assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
+
+    def test_set_adapters_match_attention_kwargs(self, base_pipe_output):
+        """Test to check if outputs after `set_adapters()` and attention kwargs match."""
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+
+        lora_scale = 0.5
+        attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+        output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+        assert not np.allclose(base_pipe_output, output_lora_scale, atol=1e-3, rtol=1e-3), (
+            "Lora + scale should change the output"
+        )
+
+        pipe.set_adapters("default", lora_scale)
+        output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert not np.allclose(base_pipe_output, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), (
+            "Lora + scale should change the output"
+        )
+        assert np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), (
+            "Lora + scale should match the output of `set_adapters()`."
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+
+            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            for module_name, module in modules_to_save.items():
+                assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
+
+            output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+            assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), (
+                "Lora + scale should change the output"
+            )
+            assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), (
+                "Loading from saved checkpoints should give same results as attention_kwargs."
+            )
+            assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), (
+                "Loading from saved checkpoints should give same results as set_adapters()."
+            )
+
+    @require_peft_version_greater("0.13.2")
+    def test_lora_B_bias(self):
+        # Currently, this test is only relevant for Flux Control LoRA as we are not
+        # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # keep track of the bias values of the base layers to perform checks later.
+        bias_values = {}
+        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+        for name, module in denoiser.named_modules():
+            if any(k in name for k in self.denoiser_target_modules):
+                if module.bias is not None:
+                    bias_values[name] = module.bias.data.clone()
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        denoiser_lora_config.lora_bias = False
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        pipe.delete_adapters("adapter-1")
+
+        denoiser_lora_config.lora_bias = True
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        assert not np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)
+        assert not np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)
+        assert not np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)
+
+    def test_correct_lora_configs_with_different_ranks(self):
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
+        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        if self.unet_kwargs is not None:
+            pipe.unet.delete_adapters("adapter-1")
+        else:
+            pipe.transformer.delete_adapters("adapter-1")
+
+        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+        for name, _ in denoiser.named_modules():
+            if "to_k" in name and "attn" in name and "lora" not in name:
+                module_name_to_rank_update = name.replace(".base_layer.", ".")
+                break
+
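+        # `module_name_to_rank_update` now points at an attention `to_k` projection; the `rank_pattern` /
+        # `alpha_pattern` overrides below target exactly this module.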
# change the rank_pattern + updated_rank = denoiser_lora_config.r * 2 + denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} + + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") + updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern + + assert updated_rank_pattern == {module_name_to_rank_update: updated_rank} + + lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3) + assert not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3) + + if self.unet_kwargs is not None: + pipe.unet.delete_adapters("adapter-1") + else: + pipe.transformer.delete_adapters("adapter-1") + + # similarly change the alpha_pattern + updated_alpha = denoiser_lora_config.lora_alpha * 2 + denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") + assert ( + pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} + ) + else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + assert ( + pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} + ) + + lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3) + assert not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3) + + def test_layerwise_casting_inference_denoiser(self): + from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS + from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN + + def check_linear_dtype(module, storage_dtype, compute_dtype): + patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN + if getattr(module, "_skip_layerwise_casting_patterns", None) is not None: + patterns_to_check += tuple(module._skip_layerwise_casting_patterns) + for name, submodule in module.named_modules(): + if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): + continue + dtype_to_check = storage_dtype + if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check): + dtype_to_check = compute_dtype + if getattr(submodule, "weight", None) is not None: + self.assertEqual(submodule.weight.dtype, dtype_to_check) + if getattr(submodule, "bias", None) is not None: + self.assertEqual(submodule.bias.dtype, dtype_to_check) + + def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device, dtype=compute_dtype) + pipe.set_progress_bar_config(disable=None) + + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + if storage_dtype is not None: + denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + check_linear_dtype(denoiser, storage_dtype, compute_dtype) + + return pipe + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe_fp32 = initialize_pipeline(storage_dtype=None) + pipe_fp32(**inputs, generator=torch.manual_seed(0))[0] + + pipe_float8_e4m3_fp32 
= initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) + pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] + + pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] + + @require_peft_version_greater("0.14.0") + def test_layerwise_casting_peft_input_autocast_denoiser(self): + r""" + A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This + is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise + cast hooks on the PEFT layers (relevant logic in `models.modeling_utils.ModelMixin.enable_layerwise_casting`). + In this test, we enable the layerwise casting on the PEFT layers as well. If run with PEFT version <= 0.14.0, + this test will fail with the following error: + + ``` + RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float + ``` + + See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details. + """ + + from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS + from diffusers.hooks.layerwise_casting import ( + _PEFT_AUTOCAST_DISABLE_HOOK, + DEFAULT_SKIP_MODULES_PATTERN, + apply_layerwise_casting, + ) + + storage_dtype = torch.float8_e4m3fn + compute_dtype = torch.float32 + + def check_module(denoiser): + # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser) + for name, module in denoiser.named_modules(): + if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS): + continue + dtype_to_check = storage_dtype + if any(re.search(pattern, name) for pattern in patterns_to_check): + dtype_to_check = compute_dtype + if getattr(module, "weight", None) is not None: + self.assertEqual(module.weight.dtype, dtype_to_check) + if getattr(module, "bias", None) is not None: + self.assertEqual(module.bias.dtype, dtype_to_check) + if isinstance(module, BaseTunerLayer): + assert getattr(module, "_diffusers_hook", None is not None) + assert module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None + + # 1. Test forward with add_adapter + components, _, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device, dtype=compute_dtype) + pipe.set_progress_bar_config(disable=None) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + + patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN + if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None: + patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns) + + apply_layerwise_casting( + denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check + ) + check_module(denoiser) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe(**inputs, generator=torch.manual_seed(0))[0] + + # 2. 
Test forward with load_lora_weights + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts + ) + + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + components, _, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device, dtype=compute_dtype) + pipe.set_progress_bar_config(disable=None) + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + apply_layerwise_casting( + denoiser, + storage_dtype=storage_dtype, + compute_dtype=compute_dtype, + skip_modules_pattern=patterns_to_check, + ) + check_module(denoiser) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe(**inputs, generator=torch.manual_seed(0))[0] + + @parameterized.expand([4, 8, 16]) + def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) + pipe = self.pipeline_class(**components) + + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + + with tempfile.TemporaryDirectory() as tmpdir: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + + out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True) + if len(out) == 3: + _, _, parsed_metadata = out + elif len(out) == 2: + _, parsed_metadata = out + + denoiser_key = ( + f"{self.pipeline_class.transformer_name}" + if self.transformer_kwargs is not None + else f"{self.pipeline_class.unet_name}" + ) + assert any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key + ) + + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + text_encoder_key = self.pipeline_class.text_encoder_name + assert any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key + ) + + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_key = "text_encoder_2" + assert any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key + ) + + @parameterized.expand([4, 8, 16]) + def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) + pipe = self.pipeline_class(**components).to(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with 
tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdir)
+
+            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            assert np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
+                "Lora outputs should match."
+            )
+
+    def test_lora_unload_add_adapter(self):
+        """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # unload and then add.
+        pipe.unload_lora_weights()
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+    def test_inference_load_delete_load_adapters(self, base_pipe_output):
+        "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config)
+            assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
+
+        if self.has_two_text_encoders or self.has_three_text_encoders:
+            lora_loadable_components = self.pipeline_class._lora_loadable_modules
+            if "text_encoder_2" in lora_loadable_components:
+                pipe.text_encoder_2.add_adapter(text_lora_config)
+                assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+
+        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            # First, delete adapter and compare.
+            pipe.delete_adapters(pipe.get_active_adapters()[0])
+            output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            assert not np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3)
+            assert np.allclose(base_pipe_output, output_no_adapter, atol=1e-3, rtol=1e-3)
+
+            # Then load adapter and compare.
+ pipe.load_lora_weights(tmpdirname) + output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3) + + def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): + from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook + + onload_device = torch_device + offload_device = torch.device("cpu") + + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts + ) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + components, _, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + # Test group offloading with load_lora_weights + denoiser.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=1, + use_stream=use_stream, + ) + # Place other model-level components on `torch_device`. + for _, component in pipe.components.items(): + if isinstance(component, torch.nn.Module): + component.to(torch_device) + group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_1 is not None + output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + # Test group offloading after removing the lora + pipe.unload_lora_weights() + group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_2 is not None + output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 + + # Add the lora again and check if group offloading works + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_3 is not None + output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + assert np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3) + + @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) + @require_torch_accelerator + def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + for cls in inspect.getmro(self.__class__): + if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: + # Skip this test if it is overwritten by child class. We need to do this because parameterized + # materializes the test methods on invocation which cannot be overridden. 
+ return + self._test_group_offloading_inference_denoiser(offload_type, use_stream) + + @require_torch_accelerator + def test_lora_loading_model_cpu_offload(self): + components, _, denoiser_lora_config = self.get_dummy_components() + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts + ) + # reinitialize the pipeline to mimic the inference workflow. + components, _, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.enable_model_cpu_offload(device=torch_device) + pipe.load_lora_weights(tmpdirname) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + + output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3) From 9e92f6bb633769669144b29df258406419d45d43 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 12:53:37 +0530 Subject: [PATCH 02/17] up --- fix_asserts_lora.py | 114 -- tests/lora/test_lora_layers_flux.py.bak | 1041 ---------- tests/lora/utils.py.bak | 2328 ----------------------- 3 files changed, 3483 deletions(-) delete mode 100644 fix_asserts_lora.py delete mode 100644 tests/lora/test_lora_layers_flux.py.bak delete mode 100644 tests/lora/utils.py.bak diff --git a/fix_asserts_lora.py b/fix_asserts_lora.py deleted file mode 100644 index 32259574f3a7..000000000000 --- a/fix_asserts_lora.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python3 -""" -Fix F631-style asserts of the form: - assert (, "message") -…into: - assert , "message" - -Scans recursively under tests/lora/. - -Usage: - python fix_assert_tuple.py [--root tests/lora] [--dry-run] -""" - -import argparse -import ast -from pathlib import Path -from typing import Tuple, List, Optional - - -class AssertTupleFixer(ast.NodeTransformer): - """ - Transform `assert (, )` into `assert , `. - We only rewrite when the assert test is a Tuple with exactly 2 elements. - """ - def __init__(self): - super().__init__() - self.fixed_locs: List[Tuple[int, int]] = [] - - def visit_Assert(self, node: ast.Assert) -> ast.AST: - self.generic_visit(node) - if isinstance(node.test, ast.Tuple) and len(node.test.elts) == 2: - cond, msg = node.test.elts - # Convert only if this *looks* like a real assert-with-message tuple, - # i.e. keep anything as msg (string, f-string, name, call, etc.) - new_node = ast.Assert(test=cond, msg=msg) - ast.copy_location(new_node, node) - ast.fix_missing_locations(new_node) - self.fixed_locs.append((node.lineno, node.col_offset)) - return new_node - return node - - -def fix_file(path: Path, dry_run: bool = False) -> int: - """ - Returns number of fixes applied. 
- """ - try: - src = path.read_text(encoding="utf-8") - except Exception as e: - print(f"Could not read {path}: {e}") - return 0 - - try: - tree = ast.parse(src, filename=str(path)) - except SyntaxError: - # Skip files that don’t parse (partial edits, etc.) - return 0 - - fixer = AssertTupleFixer() - new_tree = fixer.visit(tree) - fixes = len(fixer.fixed_locs) - if fixes == 0: - return 0 - - try: - new_src = ast.unparse(new_tree) # Python 3.9+ - except Exception as e: - print(f"Failed to unparse {path}: {e}") - return 0 - - if dry_run: - for (lineno, col) in fixer.fixed_locs: - print(f"[DRY-RUN] {path}:{lineno}:{col} -> fixed assert tuple") - return fixes - - # Backup and write - backup = path.with_suffix(path.suffix + ".bak") - try: - if not backup.exists(): - backup.write_text(src, encoding="utf-8") - path.write_text(new_src, encoding="utf-8") - for (lineno, col) in fixer.fixed_locs: - print(f"Fixed {path}:{lineno}:{col}") - except Exception as e: - print(f"Failed to write {path}: {e}") - return 0 - - return fixes - - -def main(): - ap = argparse.ArgumentParser(description="Fix F631-style tuple asserts.") - ap.add_argument("--root", default="tests/lora", help="Root directory to scan") - ap.add_argument("--dry-run", action="store_true", help="Report changes but don't write") - args = ap.parse_args() - - root = Path(args.root) - if not root.exists(): - print(f"{root} does not exist.") - return - - total_files = 0 - total_fixes = 0 - for pyfile in root.rglob("*.py"): - total_files += 1 - total_fixes += fix_file(pyfile, dry_run=args.dry_run) - - print(f"\nScanned {total_files} file(s). Applied {total_fixes} fix(es).") - if args.dry_run: - print("Run again without --dry-run to apply changes.") - - -if __name__ == "__main__": - main() diff --git a/tests/lora/test_lora_layers_flux.py.bak b/tests/lora/test_lora_layers_flux.py.bak deleted file mode 100644 index ee0235266307..000000000000 --- a/tests/lora/test_lora_layers_flux.py.bak +++ /dev/null @@ -1,1041 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import copy -import gc -import os -import sys -import tempfile -import unittest - -import numpy as np -import pytest -import safetensors.torch -import torch -from parameterized import parameterized -from PIL import Image -from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel - -from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel -from diffusers.utils import load_image, logging - -from ..testing_utils import ( - CaptureLogger, - backend_empty_cache, - floats_tensor, - is_peft_available, - nightly, - numpy_cosine_similarity_distance, - require_big_accelerator, - require_peft_backend, - require_torch_accelerator, - slow, - torch_device, -) - - -if is_peft_available(): - from peft.utils import get_peft_model_state_dict - -sys.path.append(".") - -from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 - - -@require_peft_backend -class TestFluxLoRA(PeftLoraLoaderMixinTests): - pipeline_class = FluxPipeline - scheduler_cls = FlowMatchEulerDiscreteScheduler - scheduler_kwargs = {} - transformer_kwargs = { - "patch_size": 1, - "in_channels": 4, - "num_layers": 1, - "num_single_layers": 1, - "attention_head_dim": 16, - "num_attention_heads": 2, - "joint_attention_dim": 32, - "pooled_projection_dim": 32, - "axes_dims_rope": [4, 4, 8], - } - transformer_cls = FluxTransformer2DModel - vae_kwargs = { - "sample_size": 32, - "in_channels": 3, - "out_channels": 3, - "block_out_channels": (4,), - "layers_per_block": 1, - "latent_channels": 1, - "norm_num_groups": 1, - "use_quant_conv": False, - "use_post_quant_conv": False, - "shift_factor": 0.0609, - "scaling_factor": 1.5035, - } - has_two_text_encoders = True - tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" - tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" - text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" - text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" - - @property - def output_shape(self): - return (1, 8, 8, 3) - - def get_dummy_inputs(self, with_generator=True): - batch_size = 1 - sequence_length = 10 - num_channels = 4 - sizes = (32, 32) - - generator = torch.manual_seed(0) - noise = floats_tensor((batch_size, num_channels) + sizes) - input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - - pipeline_inputs = { - "prompt": "A painting of a squirrel eating a burger", - "num_inference_steps": 4, - "guidance_scale": 0.0, - "height": 8, - "width": 8, - "output_type": "np", - } - if with_generator: - pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs - - def test_with_alpha_in_state_dict(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe.transformer.add_adapter(denoiser_lora_config) - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" - - images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, 
transformer_lora_layers=denoiser_state_dict) - - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - # modify the state dict to have alpha values following - # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors - state_dict_with_alpha = safetensors.torch.load_file( - os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") - ) - alpha_dict = {} - for k, v in state_dict_with_alpha.items(): - # only do for `transformer` and for the k projections -- should be enough to test. - if "transformer" in k and "to_k" in k and "lora_A" in k: - alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) - state_dict_with_alpha.update(alpha_dict) - - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - - pipe.unload_lora_weights() - pipe.load_lora_weights(state_dict_with_alpha) - images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images - - assert np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), "Loading from saved checkpoints should give same results." - - assert not np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3) - - def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - # Modify the config to have a layer which won't be present in the second LoRA we will load. - modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) - modified_denoiser_lora_config.target_modules.add("x_embedder") - - pipe.transformer.add_adapter(modified_denoiser_lora_config) - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" - - images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - assert not( - np.allclose(images_lora, base_pipe_output, atol=1e-3, rtol=1e-3), - "LoRA should lead to different results.", - ) - - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") - - # Modify the state dict to exclude "x_embedder" related LoRA params. 
- lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} - - pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") - pipe.set_adapters(["one", "two"]) - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" - images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images - - assert not( - np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), - "Different LoRAs should lead to different results.", - ) - assert not( - np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), - "LoRA should lead to different results.", - ) - - def test_lora_expansion_works_for_extra_keys(self, base_pipe_output): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - # Modify the config to have a layer which won't be present in the first LoRA we will load. - modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) - modified_denoiser_lora_config.target_modules.add("x_embedder") - - pipe.transformer.add_adapter(modified_denoiser_lora_config) - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" - - images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - assert not( - np.allclose(images_lora, base_pipe_output, atol=1e-3, rtol=1e-3), - "LoRA should lead to different results.", - ) - - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.unload_lora_weights() - # Modify the state dict to exclude "x_embedder" related LoRA params. - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} - pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") - - # Load state dict with `x_embedder`. 
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") - - pipe.set_adapters(["one", "two"]) - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" - images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images - - assert not( - np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), - "Different LoRAs should lead to different results.", - ) - assert not( - np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), - "LoRA should lead to different results.", - ) - - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - - @unittest.skip("Not supported in Flux.") - def test_modify_padding_mode(self): - pass - - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): - pass - - -class TestFluxControlLoRA(PeftLoraLoaderMixinTests): - pipeline_class = FluxControlPipeline - scheduler_cls = FlowMatchEulerDiscreteScheduler - scheduler_kwargs = {} - transformer_kwargs = { - "patch_size": 1, - "in_channels": 8, - "out_channels": 4, - "num_layers": 1, - "num_single_layers": 1, - "attention_head_dim": 16, - "num_attention_heads": 2, - "joint_attention_dim": 32, - "pooled_projection_dim": 32, - "axes_dims_rope": [4, 4, 8], - } - transformer_cls = FluxTransformer2DModel - vae_kwargs = { - "sample_size": 32, - "in_channels": 3, - "out_channels": 3, - "block_out_channels": (4,), - "layers_per_block": 1, - "latent_channels": 1, - "norm_num_groups": 1, - "use_quant_conv": False, - "use_post_quant_conv": False, - "shift_factor": 0.0609, - "scaling_factor": 1.5035, - } - has_two_text_encoders = True - tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" - tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" - text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" - text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" - - @property - def output_shape(self): - return (1, 8, 8, 3) - - def get_dummy_inputs(self, with_generator=True): - batch_size = 1 - sequence_length = 10 - num_channels = 4 - sizes = (32, 32) - - generator = torch.manual_seed(0) - noise = floats_tensor((batch_size, num_channels) + sizes) - input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - - np.random.seed(0) - pipeline_inputs = { - "prompt": "A painting of a squirrel eating a burger", - "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")), - "num_inference_steps": 4, - "guidance_scale": 0.0, - "height": 8, - "width": 8, - "output_type": "np", - } - if with_generator: - pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs - - def test_with_norm_in_state_dict(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - logger = logging.get_logger("diffusers.loaders.lora_pipeline") - logger.setLevel(logging.INFO) - - 
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]: - norm_state_dict = {} - for name, module in pipe.transformer.named_modules(): - if norm_layer not in name or not hasattr(module, "weight") or module.weight is None: - continue - norm_state_dict[f"transformer.{name}.weight"] = torch.randn( - module.weight.shape, device=module.weight.device, dtype=module.weight.dtype - ) - - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(norm_state_dict) - lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert "The provided state dict contains normalization layers in addition to LoRA layers" in cap_logger.out - assert len(pipe.transformer._transformer_norm_layers) > 0 - - pipe.unload_lora_weights() - lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert pipe.transformer._transformer_norm_layers is None - assert np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5) - assert not( - np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested" - ) - - with CaptureLogger(logger) as cap_logger: - for key in list(norm_state_dict.keys()): - norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key) - pipe.load_lora_weights(norm_state_dict) - - assert ( - "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out - ) - - def test_lora_parameter_expanded_shapes(self): - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - logger = logging.get_logger("diffusers.loaders.lora_pipeline") - logger.setLevel(logging.DEBUG) - - # Change the transformer config to mimic a real use case. - num_channels_without_control = 4 - transformer = FluxTransformer2DModel.from_config( - components["transformer"].config, in_channels=num_channels_without_control - ).to(torch_device) - assert transformer.config.in_channels == num_channels_without_control, f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}" - - - original_transformer_state_dict = pipe.transformer.state_dict() - x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight") - incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False) - assert ( - "x_embedder.weight" in incompatible_keys.missing_keys, - "Could not find x_embedder.weight in the missing keys." 
- ) - - transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control]) - pipe.transformer = transformer - - out_features, in_features = pipe.transformer.x_embedder.weight.shape - rank = 4 - - dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) - dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, - "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, - } - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(lora_state_dict, "adapter-1") - - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - - lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features - assert pipe.transformer.config.in_channels == 2 * in_features - assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") - - # Testing opposite direction where the LoRA params are zero-padded. - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - dummy_lora_A = torch.nn.Linear(1, rank, bias=False) - dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, - "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, - } - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(lora_state_dict, "adapter-1") - - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - - lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features - assert pipe.transformer.config.in_channels == 2 * in_features - assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out - - def test_normal_lora_with_expanded_lora_raises_error(self): - # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then - # load shape expanded LoRA (such as Control LoRA). - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - - # Change the transformer config to mimic a real use case. 
- num_channels_without_control = 4 - transformer = FluxTransformer2DModel.from_config( - components["transformer"].config, in_channels=num_channels_without_control - ).to(torch_device) - components["transformer"] = transformer - - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - logger = logging.get_logger("diffusers.loaders.lora_pipeline") - logger.setLevel(logging.DEBUG) - - out_features, in_features = pipe.transformer.x_embedder.weight.shape - rank = 4 - - shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) - shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, - "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, - } - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(lora_state_dict, "adapter-1") - - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - assert pipe.get_active_adapters() == ["adapter-1"] - assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features - assert pipe.transformer.config.in_channels == 2 * in_features - assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) - normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, - "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, - } - - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(lora_state_dict, "adapter-2") - - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out - assert pipe.get_active_adapters() == ["adapter-2"] - - lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3) - - # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. - # This should raise a runtime error on input shapes being incompatible. - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - # Change the transformer config to mimic a real use case. 
-        num_channels_without_control = 4
-        transformer = FluxTransformer2DModel.from_config(
-            components["transformer"].config, in_channels=num_channels_without_control
-        ).to(torch_device)
-        components["transformer"] = transformer
-
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
-        logger.setLevel(logging.DEBUG)
-
-        out_features, in_features = pipe.transformer.x_embedder.weight.shape
-        rank = 4
-
-        lora_state_dict = {
-            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
-            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
-        }
-        pipe.load_lora_weights(lora_state_dict, "adapter-1")
-
-        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
-        assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
-        assert pipe.transformer.config.in_channels == in_features
-
-        lora_state_dict = {
-            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
-            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
-        }
-
-        # Ideally, we would check for incompatible input shapes here. But because the scenario above is not a
-        # supported use case, and because of the PEFT renaming, we currently end up with a shape mismatch error.
-        with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"):
-            pipe.load_lora_weights(lora_state_dict, "adapter-2")
-
-    def test_fuse_expanded_lora_with_regular_lora(self):
-        # This test checks that fusing works when a LoRA with expanded shapes (like a Control LoRA) is loaded
-        # first and a LoRA with regular shapes is loaded on top of it. The opposite direction isn't supported
-        # and is covered by the test above.
-        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
-
-        # Change the transformer config to mimic a real use case.
- num_channels_without_control = 4 - transformer = FluxTransformer2DModel.from_config( - components["transformer"].config, in_channels=num_channels_without_control - ).to(torch_device) - components["transformer"] = transformer - - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - logger = logging.get_logger("diffusers.loaders.lora_pipeline") - logger.setLevel(logging.DEBUG) - - out_features, in_features = pipe.transformer.x_embedder.weight.shape - rank = 4 - - shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) - shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, - "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, - } - pipe.load_lora_weights(lora_state_dict, "adapter-1") - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) - normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, - "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, - } - - pipe.load_lora_weights(lora_state_dict, "adapter-2") - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - - lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0]) - lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) - assert not(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3)) - assert not(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3)) - - pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"]) - lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3) - - def test_load_regular_lora(self): - # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded - # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those - # transformers include Flux Fill, Flux Control, etc. - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - out_features, in_features = pipe.transformer.x_embedder.weight.shape - rank = 4 - in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA. 
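-        # The LoRA A matrix below targets the halved input width, i.e. what a LoRA trained on the base transformer would expect.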
- normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) - normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, - "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, - } - - logger = logging.get_logger("diffusers.loaders.lora_pipeline") - logger.setLevel(logging.INFO) - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(lora_state_dict, "adapter-1") - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - - lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out - assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 - assert not np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3) - - def test_lora_unload_with_parameter_expanded_shapes(self): - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - - logger = logging.get_logger("diffusers.loaders.lora_pipeline") - logger.setLevel(logging.DEBUG) - - # Change the transformer config to mimic a real use case. - num_channels_without_control = 4 - transformer = FluxTransformer2DModel.from_config( - components["transformer"].config, in_channels=num_channels_without_control - ).to(torch_device) - assert ( - transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", - ) - - # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. - components["transformer"] = transformer - pipe = FluxPipeline(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - control_image = inputs.pop("control_image") - original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - control_pipe = self.pipeline_class(**components) - out_features, in_features = control_pipe.transformer.x_embedder.weight.shape - rank = 4 - - dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) - dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, - "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, - } - with CaptureLogger(logger) as cap_logger: - control_pipe.load_lora_weights(lora_state_dict, "adapter-1") - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - - inputs["control_image"] = control_image - lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4) - assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features - assert pipe.transformer.config.in_channels == 2 * in_features - assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") - - control_pipe.unload_lora_weights(reset_to_overwritten_params=True) - assert( - control_pipe.transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", - ) - loaded_pipe = FluxPipeline.from_pipe(control_pipe) - assert ( - loaded_pipe.transformer.config.in_channels == num_channels_without_control, - f"Expected 
{num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}", - ) - inputs.pop("control_image") - unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4) - assert np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4) - assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features - assert pipe.transformer.config.in_channels == in_features - - def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - - logger = logging.get_logger("diffusers.loaders.lora_pipeline") - logger.setLevel(logging.DEBUG) - - # Change the transformer config to mimic a real use case. - num_channels_without_control = 4 - transformer = FluxTransformer2DModel.from_config( - components["transformer"].config, in_channels=num_channels_without_control - ).to(torch_device) - assert ( - transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", - ) - - # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. - components["transformer"] = transformer - pipe = FluxPipeline(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - control_image = inputs.pop("control_image") - original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - control_pipe = self.pipeline_class(**components) - out_features, in_features = control_pipe.transformer.x_embedder.weight.shape - rank = 4 - - dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) - dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, - "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, - } - with CaptureLogger(logger) as cap_logger: - control_pipe.load_lora_weights(lora_state_dict, "adapter-1") - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - - inputs["control_image"] = control_image - lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features - assert pipe.transformer.config.in_channels == 2 * in_features - assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") - - control_pipe.unload_lora_weights(reset_to_overwritten_params=False) - assert( - control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", - ) - no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4) - assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 - assert pipe.transformer.config.in_channels == in_features * 2 - - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in Flux.") - def 
test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - - @unittest.skip("Not supported in Flux.") - def test_modify_padding_mode(self): - pass - - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): - pass - - -@slow -@nightly -@require_torch_accelerator -@require_peft_backend -@require_big_accelerator -class FluxLoRAIntegrationTests(unittest.TestCase): - """internal note: The integration slices were obtained on audace. - - torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the - assertions to pass. - """ - - num_inference_steps = 10 - seed = 0 - - def setUp(self): - super().setUp() - - gc.collect() - backend_empty_cache(torch_device) - - self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) - - def tearDown(self): - super().tearDown() - - del self.pipeline - gc.collect() - backend_empty_cache(torch_device) - - def test_flux_the_last_ben(self): - self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - # Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI - # run supports it. We have about 34GB RAM in the CI runner which kills the test when run with - # `enable_model_cpu_offload()`. We repeat this for the other tests, too. - self.pipeline = self.pipeline.to(torch_device) - - prompt = "jon snow eating pizza with ketchup" - - out = self.pipeline( - prompt, - num_inference_steps=self.num_inference_steps, - guidance_scale=4.0, - output_type="np", - generator=torch.manual_seed(self.seed), - ).images - out_slice = out[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246]) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 - - def test_flux_kohya(self): - self.pipeline.load_lora_weights("Norod78/brain-slug-flux") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) - - prompt = "The cat with a brain slug earring" - out = self.pipeline( - prompt, - num_inference_steps=self.num_inference_steps, - guidance_scale=4.5, - output_type="np", - generator=torch.manual_seed(self.seed), - ).images - - out_slice = out[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484]) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 - - def test_flux_kohya_with_text_encoder(self): - self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) - - prompt = "optimus is cleaning the house with broomstick" - out = self.pipeline( - prompt, - num_inference_steps=self.num_inference_steps, - guidance_scale=4.5, - output_type="np", - generator=torch.manual_seed(self.seed), - ).images - - out_slice = out[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219]) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 - - def test_flux_kohya_embedders_conversion(self): - """Test that embedders load without 
throwing errors""" - self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") - self.pipeline.unload_lora_weights() - - assert True - - def test_flux_xlabs(self): - self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) - - prompt = "A blue jay standing on a large basket of rainbow macarons, disney style" - - out = self.pipeline( - prompt, - num_inference_steps=self.num_inference_steps, - guidance_scale=3.5, - output_type="np", - generator=torch.manual_seed(self.seed), - ).images - out_slice = out[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980]) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 - - def test_flux_xlabs_load_lora_with_single_blocks(self): - self.pipeline.load_lora_weights( - "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors" - ) - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline.enable_model_cpu_offload() - - prompt = "a wizard mouse playing chess" - - out = self.pipeline( - prompt, - num_inference_steps=self.num_inference_steps, - guidance_scale=3.5, - output_type="np", - generator=torch.manual_seed(self.seed), - ).images - out_slice = out[0, -3:, -3:, -1].flatten() - expected_slice = np.array( - [0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625] - ) - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 - - -@nightly -@require_torch_accelerator -@require_peft_backend -@require_big_accelerator -class FluxControlLoRAIntegrationTests(unittest.TestCase): - num_inference_steps = 10 - seed = 0 - prompt = "A robot made of exotic candies and chocolates of different kinds." 
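-    # Both Control LoRA checkpoints (Canny and Depth) are exercised below, on their own and
-    # stacked with the Hyper-SD turbo LoRA.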
- - def setUp(self): - super().setUp() - - gc.collect() - backend_empty_cache(torch_device) - - self.pipeline = FluxControlPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 - ).to(torch_device) - - def tearDown(self): - super().tearDown() - - gc.collect() - backend_empty_cache(torch_device) - - @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) - def test_lora(self, lora_ckpt_id): - self.pipeline.load_lora_weights(lora_ckpt_id) - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - - if "Canny" in lora_ckpt_id: - control_image = load_image( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" - ) - else: - control_image = load_image( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" - ) - - image = self.pipeline( - prompt=self.prompt, - control_image=control_image, - height=1024, - width=1024, - num_inference_steps=self.num_inference_steps, - guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, - output_type="np", - generator=torch.manual_seed(self.seed), - ).images - - out_slice = image[0, -3:, -3:, -1].flatten() - if "Canny" in lora_ckpt_id: - expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516]) - else: - expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359]) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 - - @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) - def test_lora_with_turbo(self, lora_ckpt_id): - self.pipeline.load_lora_weights(lora_ckpt_id) - self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - - if "Canny" in lora_ckpt_id: - control_image = load_image( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" - ) - else: - control_image = load_image( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" - ) - - image = self.pipeline( - prompt=self.prompt, - control_image=control_image, - height=1024, - width=1024, - num_inference_steps=self.num_inference_steps, - guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, - output_type="np", - generator=torch.manual_seed(self.seed), - ).images - - out_slice = image[0, -3:, -3:, -1].flatten() - if "Canny" in lora_ckpt_id: - expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484]) - else: - expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562]) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 diff --git a/tests/lora/utils.py.bak b/tests/lora/utils.py.bak deleted file mode 100644 index 077eb202c47c..000000000000 --- a/tests/lora/utils.py.bak +++ /dev/null @@ -1,2328 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -import os -import re -import tempfile -import unittest -from itertools import product - -import numpy as np -import pytest -import torch -from parameterized import parameterized - -from diffusers import ( - AutoencoderKL, - UNet2DConditionModel, -) -from diffusers.utils import logging -from diffusers.utils.import_utils import is_peft_available - -from ..testing_utils import ( - CaptureLogger, - check_if_dicts_are_equal, - floats_tensor, - is_torch_version, - require_peft_backend, - require_peft_version_greater, - require_torch_accelerator, - require_transformers_version_greater, - skip_mps, - torch_device, -) - - -if is_peft_available(): - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - from peft.tuners.tuners_utils import BaseTunerLayer - from peft.utils import get_peft_model_state_dict - - -def state_dicts_almost_equal(sd1, sd2): - sd1 = dict(sorted(sd1.items())) - sd2 = dict(sorted(sd2.items())) - - models_are_equal = True - for ten1, ten2 in zip(sd1.values(), sd2.values()): - if (ten1 - ten2).abs().max() > 1e-3: - models_are_equal = False - - return models_are_equal - - -def check_if_lora_correctly_set(model) -> bool: - """ - Checks if the LoRA layers are correctly set with peft - """ - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - return True - return False - - -def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str): - extracted = { - k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.") - } - check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"]) - - -def initialize_dummy_state_dict(state_dict): - if not all(v.device.type == "meta" for _, v in state_dict.items()): - raise ValueError("`state_dict` has non-meta values.") - return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()} - - -POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"] - - -def determine_attention_kwargs_name(pipeline_class): - call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys() - - # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release - for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES: - if possible_attention_kwargs in call_signature_keys: - attention_kwargs_name = possible_attention_kwargs - break - assert attention_kwargs_name is not None - return attention_kwargs_name - - -@require_peft_backend -class PeftLoraLoaderMixinTests: - pipeline_class = None - - scheduler_cls = None - scheduler_kwargs = None - - has_two_text_encoders = False - has_three_text_encoders = False - text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, "" - text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, "" - text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, "" - tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" - tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = 
None, None, "" - tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" - - unet_kwargs = None - transformer_cls = None - transformer_kwargs = None - vae_cls = AutoencoderKL - vae_kwargs = None - - text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] - denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - - @pytest.fixture(scope="class") - def base_pipe_output(self): - return self._compute_baseline_output() - - def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): - if self.unet_kwargs and self.transformer_kwargs: - raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") - if self.has_two_text_encoders and self.has_three_text_encoders: - raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.") - - scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls - rank = 4 - lora_alpha = rank if lora_alpha is None else lora_alpha - - torch.manual_seed(0) - if self.unet_kwargs is not None: - unet = UNet2DConditionModel(**self.unet_kwargs) - else: - transformer = self.transformer_cls(**self.transformer_kwargs) - - scheduler = scheduler_cls(**self.scheduler_kwargs) - - torch.manual_seed(0) - vae = self.vae_cls(**self.vae_kwargs) - - text_encoder = self.text_encoder_cls.from_pretrained( - self.text_encoder_id, subfolder=self.text_encoder_subfolder - ) - tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder) - - if self.text_encoder_2_cls is not None: - text_encoder_2 = self.text_encoder_2_cls.from_pretrained( - self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder - ) - tokenizer_2 = self.tokenizer_2_cls.from_pretrained( - self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder - ) - - if self.text_encoder_3_cls is not None: - text_encoder_3 = self.text_encoder_3_cls.from_pretrained( - self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder - ) - tokenizer_3 = self.tokenizer_3_cls.from_pretrained( - self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder - ) - - text_lora_config = LoraConfig( - r=rank, - lora_alpha=lora_alpha, - target_modules=self.text_encoder_target_modules, - init_lora_weights=False, - use_dora=use_dora, - ) - - denoiser_lora_config = LoraConfig( - r=rank, - lora_alpha=lora_alpha, - target_modules=self.denoiser_target_modules, - init_lora_weights=False, - use_dora=use_dora, - ) - - pipeline_components = { - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - } - # Denoiser - if self.unet_kwargs is not None: - pipeline_components.update({"unet": unet}) - elif self.transformer_kwargs is not None: - pipeline_components.update({"transformer": transformer}) - - # Remaining text encoders. 
- if self.text_encoder_2_cls is not None: - pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2}) - if self.text_encoder_3_cls is not None: - pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3}) - - # Remaining stuff - init_params = inspect.signature(self.pipeline_class.__init__).parameters - if "safety_checker" in init_params: - pipeline_components.update({"safety_checker": None}) - if "feature_extractor" in init_params: - pipeline_components.update({"feature_extractor": None}) - if "image_encoder" in init_params: - pipeline_components.update({"image_encoder": None}) - - return pipeline_components, text_lora_config, denoiser_lora_config - - @property - def output_shape(self): - raise NotImplementedError - - def get_dummy_inputs(self, with_generator=True): - batch_size = 1 - sequence_length = 10 - num_channels = 4 - sizes = (32, 32) - - generator = torch.manual_seed(0) - noise = floats_tensor((batch_size, num_channels) + sizes) - input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - - pipeline_inputs = { - "prompt": "A painting of a squirrel eating a burger", - "num_inference_steps": 5, - "guidance_scale": 6.0, - "output_type": "np", - } - if with_generator: - pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs - - def _compute_baseline_output(self): - components, _, _ = self.get_dummy_components(self.scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - # Always ensure the inputs are without the `generator`. Make sure to pass the `generator` - # explicitly. - _, _, inputs = self.get_dummy_inputs(with_generator=False) - return pipe(**inputs, generator=torch.manual_seed(0))[0] - - def _get_lora_state_dicts(self, modules_to_save): - state_dicts = {} - for module_name, module in modules_to_save.items(): - if module is not None: - state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) - return state_dicts - - def _get_lora_adapter_metadata(self, modules_to_save): - metadatas = {} - for module_name, module in modules_to_save.items(): - if module is not None: - metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict() - return metadatas - - def _get_modules_to_save(self, pipe, has_denoiser=False): - modules_to_save = {} - lora_loadable_modules = self.pipeline_class._lora_loadable_modules - - if ( - "text_encoder" in lora_loadable_modules - and hasattr(pipe, "text_encoder") - and getattr(pipe.text_encoder, "peft_config", None) is not None - ): - modules_to_save["text_encoder"] = pipe.text_encoder - - if ( - "text_encoder_2" in lora_loadable_modules - and hasattr(pipe, "text_encoder_2") - and getattr(pipe.text_encoder_2, "peft_config", None) is not None - ): - modules_to_save["text_encoder_2"] = pipe.text_encoder_2 - - if has_denoiser: - if "unet" in lora_loadable_modules and hasattr(pipe, "unet"): - modules_to_save["unet"] = pipe.unet - - if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"): - modules_to_save["transformer"] = pipe.transformer - - return modules_to_save - - def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): - if text_lora_config is not None: - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) - assert 
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - - - if denoiser_lora_config is not None: - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - else: - denoiser = None - - if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) - assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - - return pipe, denoiser - - def test_simple_inference(self, base_pipe_output): - """ - Tests a simple inference and makes sure it works as expected - """ - assert base_pipe_output.shape == self.output_shape - - def test_simple_inference_with_text_lora(self, base_pipe_output): - """ - Tests a simple inference with lora attached on the text encoder - and makes sure it works as expected - """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" - - @require_peft_version_greater("0.13.1") - def test_low_cpu_mem_usage_with_injection(self): - """Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder." - assert ( - "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, - "The LoRA params should be on 'meta' device.", - ) - - te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder)) - set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) - assert ( - "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, - "No param should be on 'meta' device.", - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - assert ( - "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." - ) - - denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser)) - set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True) - assert ( - "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." 
- ) - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True) - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - assert ( - "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, - "The LoRA params should be on 'meta' device.", - ) - - te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2)) - set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True) - assert ( - "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, - "No param should be on 'meta' device.", - ) - - _, _, inputs = self.get_dummy_inputs() - output_lora = pipe(**inputs)[0] - assert output_lora.shape == self.output_shape - - @require_peft_version_greater("0.13.1") - @require_transformers_version_greater("4.45.2") - def test_low_cpu_mem_usage_with_loading(self): - """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) - - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert ( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - # Now, check for `low_cpu_mem_usage.` - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) - - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - - images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert ( - np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", - ) - - def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output): - """ - Tests a simple inference with lora attached on the text encoder + scale argument - and makes sure it works as expected - """ - attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, 
_, inputs = self.get_dummy_inputs(with_generator=False)
-
-        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
-
-        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        assert (
-            not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output"
-        )
-
-        attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
-        output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-
-        assert (
-            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
-            "Lora + scale should change the output",
-        )
-
-        attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
-        output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-
-        assert (
-            np.allclose(base_pipe_output, output_lora_0_scale, atol=1e-3, rtol=1e-3),
-            "Lora + 0 scale should lead to same result as no LoRA",
-        )
-
-    def test_simple_inference_with_text_lora_fused(self, base_pipe_output):
-        """
-        Tests a simple inference with lora attached to the text encoder + fuses the lora weights into base model
-        and makes sure it works as expected
-        """
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
-
-        pipe.fuse_lora()
-        # Fusing should still keep the LoRA layers
-        assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                assert (
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
-
-        output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        assert not(
-            np.allclose(output_fused, base_pipe_output, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
-        )
-
-    def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output):
-        """
-        Tests a simple inference with lora attached to text encoder, then unloads the lora weights
-        and makes sure it works as expected
-        """
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
-
-        pipe.unload_lora_weights()
-        # unloading should remove the LoRA layers
-        assert not(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                assert not(
-                    check_if_lora_correctly_set(pipe.text_encoder_2),
-                    "Lora not correctly unloaded in text encoder 2",
-                )
-
-        output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        assert (
-            np.allclose(output_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3),
-            "Unloading LoRA should restore the original output",
-        )
-
-    def test_simple_inference_with_text_lora_save_load(self):
-        """
-        Tests a simple use case where users could use saving utilities for LoRA.
- """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - def test_simple_inference_with_partial_text_lora(self, base_pipe_output): - """ - Tests a simple inference with lora attached on the text encoder - with different ranks and some adapters removed - and makes sure it works as expected - """ - components, _, _ = self.get_dummy_components() - # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). - text_lora_config = LoraConfig( - r=4, - rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)}, - lora_alpha=4, - target_modules=self.text_encoder_target_modules, - init_lora_weights=False, - use_dora=False, - ) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - - state_dict = {} - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` - # supports missing layers (PR#8324). 
- state_dict = { - f"text_encoder.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() - if "text_model.encoder.layers.4" not in module_name - } - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - state_dict.update( - { - f"text_encoder_2.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() - if "text_model.encoder.layers.4" not in module_name - } - ) - - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert ( - not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - - # Unload lora and load it back using the pipe.load_lora_weights machinery - pipe.unload_lora_weights() - pipe.load_lora_weights(state_dict) - - output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert ( - not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), - "Removing adapters should change the output", - ) - - def test_simple_inference_save_pretrained_with_text_lora(self): - """ - Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained - """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) - - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - assert ( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), - "Lora not correctly set in text encoder", - ) - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - assert ( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), - "Lora not correctly set in text encoder 2", - ) - - images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - def test_simple_inference_with_text_denoiser_lora_save_load(self): - """ - Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - - assert os.path.isfile(os.path.join(tmpdirname, 
"pytorch_lora_weights.bin")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert ( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_output): - """ - Tests a simple inference with lora attached on the text encoder + Unet + scale argument - and makes sure it works as expected - """ - attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert ( - not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - - attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} - output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - assert ( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", - ) - - attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} - output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - assert ( - np.allclose(base_pipe_output, output_lora_0_scale, atol=1e-3, rtol=1e-3), - "Lora + 0 scale should lead to same result as no LoRA", - ) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - assert ( - pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, - "The scaling parameter has not been correctly restored!", - ) - - def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output): - """ - Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model - and makes sure it works as expected - with unet - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) - - # Fusing should still keep the LoRA layers - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser" - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - - output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not( - np.allclose(output_fused, 
base_pipe_output, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
-        )
-
-    def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output):
-        """
-        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
-        and makes sure it works as expected
-        """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
-
-        pipe.unload_lora_weights()
-        # unloading should remove the LoRA layers
-        assert not check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
-        assert not check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser"
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                assert not(
-                    check_if_lora_correctly_set(pipe.text_encoder_2),
-                    "Lora not correctly unloaded in text encoder 2",
-                )
-
-        output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        assert (
-            np.allclose(output_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3),
-            "Unloading LoRA should restore the original output",
-        )
-
-    def test_simple_inference_with_text_denoiser_lora_unfused(
-        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
-    ):
-        """
-        Tests a simple inference with lora attached to text encoder and unet, then fuses and unfuses the lora
-        weights and makes sure it works as expected
-        """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
-
-        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
-        assert pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}"
-        output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
-        assert pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}"
-        output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        # unfusing should still keep the LoRA layers
-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers"
-
-        assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers"
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                assert (
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
-                )
-
-        # Fuse and unfuse should lead to the same results
-        assert (
-            np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
-            "Fusing and then unfusing should not change the output",
-        )
-
-    def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_output):
-        """
-        Tests a simple inference with lora attached to text encoder and unet, attaches
-        multiple adapters and sets them
-        """
-        components, text_lora_config, denoiser_lora_config = 
self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - - pipe.set_adapters("adapter-1") - output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not( - np.allclose(base_pipe_output, output_adapter_1, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", - ) - - pipe.set_adapters("adapter-2") - output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not( - np.allclose(base_pipe_output, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", - ) - - pipe.set_adapters(["adapter-1", "adapter-2"]) - output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not( - np.allclose(base_pipe_output, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", - ) - - # Fuse and unfuse should lead to the same results - assert not( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) - - assert not( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) - - assert not( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) - - pipe.disable_lora() - output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) - - def test_wrong_adapter_name_raises_error(self): - adapter_name = "adapter-1" - - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( - pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name - ) - - with pytest.raises(ValueError) as err_context: - pipe.set_adapters("test") - - assert "not in the list of present adapters" in str(err_context.value) - - # test this works. 
- pipe.set_adapters(adapter_name) - _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - - def test_multiple_wrong_adapter_name_raises_error(self): - adapter_name = "adapter-1" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( - pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name - ) - - scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} - logger = logging.get_logger("diffusers.loaders.lora_base") - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: - pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components) - - wrong_components = sorted(set(scale_with_wrong_components.keys())) - msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. " - assert msg in str(cap_logger.out) - - # test this works. - pipe.set_adapters(adapter_name) - _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - - def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output): - """ - Tests a simple inference with lora attached to text encoder and unet, attaches - one adapter and set different weights for different blocks (i.e. block lora) - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
- - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - - weights_1 = {"text_encoder": 2, "unet": {"down": 5}} - pipe.set_adapters("adapter-1", weights_1) - output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - weights_2 = {"unet": {"up": 5}} - pipe.set_adapters("adapter-1", weights_2) - output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not( - np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), - "LoRA weights 1 and 2 should give different results", - ) - assert not( - np.allclose(base_pipe_output, output_weights_1, atol=1e-3, rtol=1e-3), - "No adapter and LoRA weights 1 should give different results", - ) - assert not( - np.allclose(base_pipe_output, output_weights_2, atol=1e-3, rtol=1e-3), - "No adapter and LoRA weights 2 should give different results", - ) - - pipe.disable_lora() - output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) - - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base_pipe_output): - """ - Tests a simple inference with lora attached to text encoder and unet, attaches - multiple adapters and set different weights for different blocks (i.e. block lora) - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
- - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - - scales_1 = {"text_encoder": 2, "unet": {"down": 5}} - scales_2 = {"unet": {"down": 5, "mid": 5}} - - pipe.set_adapters("adapter-1", scales_1) - output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters("adapter-2", scales_2) - output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2]) - output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # Fuse and unfuse should lead to the same results - assert not( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) - - assert not( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) - - assert not( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) - - pipe.disable_lora() - output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) - - # a mismatching number of adapter_names and adapter_weights should raise an error - with pytest.raises(ValueError): - pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1]) - - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" - - def updown_options(blocks_with_tf, layers_per_block, value): - """ - Generate every possible combination for how a lora weight dict for the up/down part can be. - E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ... - """ - num_val = value - list_val = [value] * layers_per_block - - node_opts = [None, num_val, list_val] - node_opts_foreach_block = [node_opts] * len(blocks_with_tf) - - updown_opts = [num_val] - for nodes in product(*node_opts_foreach_block): - if all(n is None for n in nodes): - continue - opt = {} - for b, n in zip(blocks_with_tf, nodes): - if n is not None: - opt["block_" + str(b)] = n - updown_opts.append(opt) - return updown_opts - - def all_possible_dict_opts(unet, value): - """ - Generate every possible combination for how a lora weight dict can be. - E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ... 
- """ - - down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")] - up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")] - - layers_per_block = unet.config.layers_per_block - - text_encoder_opts = [None, value] - text_encoder_2_opts = [None, value] - mid_opts = [None, value] - down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value) - up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value) - - opts = [] - - for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts): - if all(o is None for o in (t1, t2, d, m, u)): - continue - opt = {} - if t1 is not None: - opt["text_encoder"] = t1 - if t2 is not None: - opt["text_encoder_2"] = t2 - if all(o is None for o in (d, m, u)): - # no unet scaling - continue - opt["unet"] = {} - if d is not None: - opt["unet"]["down"] = d - if m is not None: - opt["unet"]["mid"] = m - if u is not None: - opt["unet"]["up"] = u - opts.append(opt) - - return opts - - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - - for scale_dict in all_possible_dict_opts(pipe.unet, value=1234): - # test if lora block scales can be set with this scale_dict - if not self.has_two_text_encoders and "text_encoder_2" in scale_dict: - del scale_dict["text_encoder_2"] - - pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error - - def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output): - """ - Tests a simple inference with lora attached to text encoder and unet, attaches - multiple adapters and set/delete them - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
- - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - - pipe.set_adapters("adapter-1") - output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters("adapter-2") - output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters(["adapter-1", "adapter-2"]) - output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) - - assert not( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) - - assert not( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) - - pipe.delete_adapters("adapter-1") - output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) - - pipe.delete_adapters("adapter-2") - output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
- - pipe.set_adapters(["adapter-1", "adapter-2"]) - pipe.delete_adapters(["adapter-1", "adapter-2"]) - - output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) - - def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_pipe_output): - """ - Tests a simple inference with lora attached to text encoder and unet, attaches - multiple adapters and set them - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - - pipe.set_adapters("adapter-1") - output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters("adapter-2") - output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters(["adapter-1", "adapter-2"]) - output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # Fuse and unfuse should lead to the same results - assert not( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) - - assert not( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) - - assert not( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) - - pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6]) - output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not( - np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Weighted adapter and mixed adapter should give different results", - ) - - pipe.disable_lora() - output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) - - @skip_mps - @pytest.mark.xfail( - condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), - reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", - strict=False, - ) - def test_lora_fuse_nan(self): - 
components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - # corrupt one LoRA weight with `inf` values - with torch.no_grad(): - if self.unet_kwargs: - pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float( - "inf" - ) - else: - named_modules = [name for name, _ in pipe.transformer.named_modules()] - possible_tower_names = [ - "transformer_blocks", - "blocks", - "joint_transformer_blocks", - "single_transformer_blocks", - ] - filtered_tower_names = [ - tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name) - ] - if len(filtered_tower_names) == 0: - reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}." - raise ValueError(reason) - for tower_name in filtered_tower_names: - transformer_tower = getattr(pipe.transformer, tower_name) - has_attn1 = any("attn1" in name for name in named_modules) - if has_attn1: - transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") - else: - transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") - - # with `safe_fusing=True` we should see an Error - with pytest.raises(ValueError): - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) - - # without we should not see an error, but every image will be black - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) - out = pipe(**inputs)[0] - - assert np.isnan(out).all() - - def test_get_adapters(self): - """ - Tests a simple usecase where we attach multiple adapters and check if the results - are the expected results - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - - adapter_names = pipe.get_active_adapters() - assert adapter_names == ["adapter-1"] - - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - - adapter_names = pipe.get_active_adapters() - assert adapter_names == ["adapter-2"] - - pipe.set_adapters(["adapter-1", "adapter-2"]) - assert sorted(pipe.get_active_adapters()) == ["adapter-1", "adapter-2"] - - def test_get_list_adapters(self): - """ - Tests a simple usecase where we attach multiple adapters and check if the results - are the expected results - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - 
pipe.set_progress_bar_config(disable=None) - - # 1. - dicts_to_be_checked = {} - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - dicts_to_be_checked = {"text_encoder": ["adapter-1"]} - - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - dicts_to_be_checked.update({"unet": ["adapter-1"]}) - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - dicts_to_be_checked.update({"transformer": ["adapter-1"]}) - - assert pipe.get_list_adapters() == dicts_to_be_checked - - # 2. - dicts_to_be_checked = {} - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} - - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") - dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") - dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) - - assert pipe.get_list_adapters() == dicts_to_be_checked - - # 3. - pipe.set_adapters(["adapter-1", "adapter-2"]) - - dicts_to_be_checked = {} - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} - - if self.unet_kwargs is not None: - dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) - else: - dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) - - assert pipe.get_list_adapters() == dicts_to_be_checked - - # 4. - dicts_to_be_checked = {} - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} - - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-3") - dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]}) - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3") - dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]}) - - assert pipe.get_list_adapters() == dicts_to_be_checked - - def test_simple_inference_with_text_lora_denoiser_fused_multi( - self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 - ): - """ - Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model - and makes sure it works as expected - with unet and multi-adapter case - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
- denoiser.add_adapter(denoiser_lora_config, "adapter-2") - - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - - # set them to multi-adapter inference mode - pipe.set_adapters(["adapter-1", "adapter-2"]) - outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters(["adapter-1"]) - outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"]) - assert pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" - - # Fusing should still keep the LoRA layers so output should remain the same - outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", - ) - - pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - assert pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers" - - assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers" - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" - ) - - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]) - assert pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" - - # Fusing should still keep the LoRA layers - output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert ( - np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", - ) - pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - assert pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" - - def test_lora_scale_kwargs_match_fusion(self, base_pipe_output, expected_atol: float = 1e-3, expected_rtol: float = 1e-3): - attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - - for lora_scale in [1.0, 0.8]: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - assert ( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
- - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - assert ( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly set in text encoder 2", - ) - - pipe.set_adapters(["adapter-1"]) - attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} - outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - pipe.fuse_lora( - components=self.pipeline_class._lora_loadable_modules, - adapter_names=["adapter-1"], - lora_scale=lora_scale, - ) - assert pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}" - - outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", - ) - assert not( - np.allclose(base_pipe_output, outputs_lora_1, atol=expected_atol, rtol=expected_rtol), - "LoRA should change the output", - ) - - def test_simple_inference_with_dora(self): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert output_no_dora_lora.shape == self.output_shape - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not( - np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3), - "DoRA lora should change the output", - ) - - def test_missing_keys_warning(self): - # Skip text encoder check for now as that is handled with `transformers`. - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - pipe.unload_lora_weights() - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) - - # To make things dynamic since we cannot settle with a single key for all the models where we - # offer PEFT support. - missing_key = [k for k in state_dict if "lora_A" in k][0] - del state_dict[missing_key] - - logger = logging.get_logger("diffusers.utils.peft_utils") - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(state_dict) - - # Since the missing key won't contain the adapter name ("default_0"). - # Also strip out the component prefix (such as "unet." from `missing_key`). 
- component = list({k.split(".")[0] for k in state_dict})[0] - assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "") - - def test_unexpected_keys_warning(self): - # Skip text encoder check for now as that is handled with `transformers`. - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - pipe.unload_lora_weights() - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) - - unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat" - state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device) - - logger = logging.get_logger("diffusers.utils.peft_utils") - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(state_dict) - - assert ".diffusers_cat" in cap_logger.out - - @unittest.skip("This is failing for now - need to investigate") - def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): - """ - Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights - and makes sure it works as expected - """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) - pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) - - if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True) - - # Just makes sure it works. - _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - - def test_modify_padding_mode(self): - def set_pad_mode(network, mode="circular"): - for _, module in network.named_modules(): - if isinstance(module, torch.nn.Conv2d): - module.padding_mode = mode - - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _pad_mode = "circular" - set_pad_mode(pipe.vae, _pad_mode) - set_pad_mode(pipe.unet, _pad_mode) - - _, _, inputs = self.get_dummy_inputs() - _ = pipe(**inputs)[0] - - def test_logs_info_when_no_lora_keys_found(self, base_pipe_output): - # Skip text encoder check for now as that is handled with `transformers`.
- components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)} - logger = logging.get_logger("diffusers.loaders.peft") - logger.setLevel(logging.WARNING) - - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(no_op_state_dict) - out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0] - - denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer") - assert cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}") - assert np.allclose(base_pipe_output, out_after_lora_attempt, atol=1e-5, rtol=1e-5) - - # test only for text encoder - for lora_module in self.pipeline_class._lora_loadable_modules: - if "text_encoder" in lora_module: - text_encoder = getattr(pipe, lora_module) - if lora_module == "text_encoder": - prefix = "text_encoder" - elif lora_module == "text_encoder_2": - prefix = "text_encoder_2" - - logger = logging.get_logger("diffusers.loaders.lora_base") - logger.setLevel(logging.WARNING) - - with CaptureLogger(logger) as cap_logger: - self.pipeline_class.load_lora_into_text_encoder( - no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix - ) - - assert ( - cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}") - ) - - def test_set_adapters_match_attention_kwargs(self, base_pipe_output): - """Test to check if outputs after `set_adapters()` and attention kwargs match.""" - attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - lora_scale = 0.5 - attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} - output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert not( - np.allclose(base_pipe_output, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", - ) - - pipe.set_adapters("default", lora_scale) - output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert ( - not np.allclose(base_pipe_output, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", - ) - assert ( - np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), - "Lora + scale should match the output of `set_adapters()`.", - ) - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - for module_name, module in 
modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - - output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert ( - not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", - ) - assert ( - np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results as attention_kwargs.", - ) - assert ( - np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results as set_adapters().", - ) - - @require_peft_version_greater("0.13.2") - def test_lora_B_bias(self): - # Currently, this test is only relevant for Flux Control LoRA as we are not - # aware of any other LoRA checkpoint that has its `lora_B` biases trained. - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - # keep track of the bias values of the base layers to perform checks later. - bias_values = {} - denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer - for name, module in denoiser.named_modules(): - if any(k in name for k in self.denoiser_target_modules): - if module.bias is not None: - bias_values[name] = module.bias.data.clone() - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - denoiser_lora_config.lora_bias = False - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.delete_adapters("adapter-1") - - denoiser_lora_config.lora_bias = True - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert not np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3) - assert not np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3) - assert not np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3) - - def test_correct_lora_configs_with_different_ranks(self): - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - - lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] - - if self.unet_kwargs is not None: - pipe.unet.delete_adapters("adapter-1") - else: - pipe.transformer.delete_adapters("adapter-1") - - denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer - for name, _ in denoiser.named_modules(): - if "to_k" in name and "attn" in name and "lora" not in name: - module_name_to_rank_update = name.replace(".base_layer.", ".") - break - - 
# change the rank_pattern - updated_rank = denoiser_lora_config.r * 2 - denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} - - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern - - assert updated_rank_pattern == {module_name_to_rank_update: updated_rank} - - lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3) - assert not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3) - - if self.unet_kwargs is not None: - pipe.unet.delete_adapters("adapter-1") - else: - pipe.transformer.delete_adapters("adapter-1") - - # similarly change the alpha_pattern - updated_alpha = denoiser_lora_config.lora_alpha * 2 - denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - assert ( - pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} - ) - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - assert ( - pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} - ) - - lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3) - assert not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3) - - def test_layerwise_casting_inference_denoiser(self): - from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS - from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN - - def check_linear_dtype(module, storage_dtype, compute_dtype): - patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN - if getattr(module, "_skip_layerwise_casting_patterns", None) is not None: - patterns_to_check += tuple(module._skip_layerwise_casting_patterns) - for name, submodule in module.named_modules(): - if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): - continue - dtype_to_check = storage_dtype - if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check): - dtype_to_check = compute_dtype - if getattr(submodule, "weight", None) is not None: - self.assertEqual(submodule.weight.dtype, dtype_to_check) - if getattr(submodule, "bias", None) is not None: - self.assertEqual(submodule.bias.dtype, dtype_to_check) - - def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device, dtype=compute_dtype) - pipe.set_progress_bar_config(disable=None) - - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - if storage_dtype is not None: - denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) - check_linear_dtype(denoiser, storage_dtype, compute_dtype) - - return pipe - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe_fp32 = initialize_pipeline(storage_dtype=None) - pipe_fp32(**inputs, generator=torch.manual_seed(0))[0] - - pipe_float8_e4m3_fp32 
= initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) - pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] - - pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) - pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] - - @require_peft_version_greater("0.14.0") - def test_layerwise_casting_peft_input_autocast_denoiser(self): - r""" - A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This - is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise - cast hooks on the PEFT layers (relevant logic in `models.modeling_utils.ModelMixin.enable_layerwise_casting`). - In this test, we enable the layerwise casting on the PEFT layers as well. If run with PEFT version <= 0.14.0, - this test will fail with the following error: - - ``` - RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float - ``` - - See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details. - """ - - from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS - from diffusers.hooks.layerwise_casting import ( - _PEFT_AUTOCAST_DISABLE_HOOK, - DEFAULT_SKIP_MODULES_PATTERN, - apply_layerwise_casting, - ) - - storage_dtype = torch.float8_e4m3fn - compute_dtype = torch.float32 - - def check_module(denoiser): - # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser) - for name, module in denoiser.named_modules(): - if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS): - continue - dtype_to_check = storage_dtype - if any(re.search(pattern, name) for pattern in patterns_to_check): - dtype_to_check = compute_dtype - if getattr(module, "weight", None) is not None: - self.assertEqual(module.weight.dtype, dtype_to_check) - if getattr(module, "bias", None) is not None: - self.assertEqual(module.bias.dtype, dtype_to_check) - if isinstance(module, BaseTunerLayer): - assert getattr(module, "_diffusers_hook", None) is not None - assert module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None - - # 1. Test forward with add_adapter - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device, dtype=compute_dtype) - pipe.set_progress_bar_config(disable=None) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN - if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None: - patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns) - - apply_layerwise_casting( - denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check - ) - check_module(denoiser) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe(**inputs, generator=torch.manual_seed(0))[0] - - # 2.
Test forward with load_lora_weights - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device, dtype=compute_dtype) - pipe.set_progress_bar_config(disable=None) - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - apply_layerwise_casting( - denoiser, - storage_dtype=storage_dtype, - compute_dtype=compute_dtype, - skip_modules_pattern=patterns_to_check, - ) - check_module(denoiser) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe(**inputs, generator=torch.manual_seed(0))[0] - - @parameterized.expand([4, 8, 16]) - def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) - pipe = self.pipeline_class(**components) - - pipe, _ = self.add_adapters_to_pipeline( - pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config - ) - - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) - pipe.unload_lora_weights() - - out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True) - if len(out) == 3: - _, _, parsed_metadata = out - elif len(out) == 2: - _, parsed_metadata = out - - denoiser_key = ( - f"{self.pipeline_class.transformer_name}" - if self.transformer_kwargs is not None - else f"{self.pipeline_class.unet_name}" - ) - assert any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata) - check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key - ) - - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - text_encoder_key = self.pipeline_class.text_encoder_name - assert any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata) - check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key - ) - - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - text_encoder_2_key = "text_encoder_2" - assert any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata) - check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key - ) - - @parameterized.expand([4, 8, 16]) - def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) - pipe = self.pipeline_class(**components).to(torch_device) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( - pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config - ) - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with 
tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) - pipe.unload_lora_weights() - pipe.load_lora_weights(tmpdir) - - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert ( - np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." - ) - - def test_lora_unload_add_adapter(self): - """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components).to(torch_device) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( - pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config - ) - _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # unload and then add. - pipe.unload_lora_weights() - pipe, _ = self.add_adapters_to_pipeline( - pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config - ) - _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - - def test_inference_load_delete_load_adapters(self, base_pipe_output): - "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config) - assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - - - output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - # First, delete adapter and compare. - pipe.delete_adapters(pipe.get_active_adapters()[0]) - output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3) - assert np.allclose(base_pipe_output, output_no_adapter, atol=1e-3, rtol=1e-3) - - # Then load adapter and compare. 
- pipe.load_lora_weights(tmpdirname) - output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3) - - def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): - from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook - - onload_device = torch_device - offload_device = torch.device("cpu") - - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - check_if_lora_correctly_set(denoiser) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - # Test group offloading with load_lora_weights - denoiser.enable_group_offload( - onload_device=onload_device, - offload_device=offload_device, - offload_type=offload_type, - num_blocks_per_group=1, - use_stream=use_stream, - ) - # Place other model-level components on `torch_device`. - for _, component in pipe.components.items(): - if isinstance(component, torch.nn.Module): - component.to(torch_device) - group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) - assert group_offload_hook_1 is not None - output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # Test group offloading after removing the lora - pipe.unload_lora_weights() - group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) - assert group_offload_hook_2 is not None - output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 - - # Add the lora again and check if group offloading works - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - check_if_lora_correctly_set(denoiser) - group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) - assert group_offload_hook_3 is not None - output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - assert np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3) - - @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) - @require_torch_accelerator - def test_group_offloading_inference_denoiser(self, offload_type, use_stream): - for cls in inspect.getmro(self.__class__): - if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: - # Skip this test if it is overwritten by child class. We need to do this because parameterized - # materializes the test methods on invocation which cannot be overridden. 
- return - self._test_group_offloading_inference_denoiser(offload_type, use_stream) - - @require_torch_accelerator - def test_lora_loading_model_cpu_offload(self): - components, _, denoiser_lora_config = self.get_dummy_components() - _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - # reinitialize the pipeline to mimic the inference workflow. - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.enable_model_cpu_offload(device=torch_device) - pipe.load_lora_weights(tmpdirname) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - - output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3) From d61bb38fb4b21acc62a6b06a0367ceb30434ff45 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 13:14:05 +0530 Subject: [PATCH 03/17] up --- tests/lora/test_lora_layers_flux.py | 204 ++++++++++++++++++++++------ 1 file changed, 162 insertions(+), 42 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index ff53983ecf52..f75a7b3777c1 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import copy import gc import os @@ -68,10 +82,10 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests): "scaling_factor": 1.5035, } has_two_text_encoders = True - (tokenizer_cls, tokenizer_id) = (CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2") - (tokenizer_2_cls, tokenizer_2_id) = (AutoTokenizer, "hf-internal-testing/tiny-random-t5") - (text_encoder_cls, text_encoder_id) = (CLIPTextModel, "peft-internal-testing/tiny-clip-text-2") - (text_encoder_2_cls, text_encoder_2_id) = (T5EncoderModel, "hf-internal-testing/tiny-random-t5") + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" @property def output_shape(self): @@ -82,9 +96,11 @@ def get_dummy_inputs(self, with_generator=True): sequence_length = 10 num_channels = 4 sizes = (32, 32) + generator = torch.manual_seed(0) noise = floats_tensor((batch_size, num_channels) + sizes) input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + pipeline_inputs = { "prompt": "A painting of a squirrel eating a burger", "num_inference_steps": 4, @@ -95,23 +111,31 @@ def get_dummy_inputs(self, with_generator=True): } if with_generator: pipeline_inputs.update({"generator": generator}) - return (noise, input_ids, pipeline_inputs) + + return noise, input_ids, pipeline_inputs def test_with_alpha_in_state_dict(self): - (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe.transformer.add_adapter(denoiser_lora_config) - assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + with tempfile.TemporaryDirectory() as tmpdirname: denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + # modify the state dict to have alpha values following + # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors state_dict_with_alpha = safetensors.torch.load_file( os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") ) @@ -120,8 +144,10 @@ def test_with_alpha_in_state_dict(self): if "transformer" in k and "to_k" in k and ("lora_A" in k): alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) state_dict_with_alpha.update(alpha_dict) + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + pipe.unload_lora_weights() pipe.load_lora_weights(state_dict_with_alpha) 
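        # The ".alpha" entries injected above follow the PEFT convention: a module's LoRA delta is
        # scaled by alpha / rank (roughly delta_W = (alpha / rank) * (lora_B @ lora_A)), so loading
        # a state dict with randomized alphas changes the effective adapter strength. That is why
        # the output computed below is expected to differ from the earlier LoRA runs.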
images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images @@ -131,15 +157,19 @@ def test_with_alpha_in_state_dict(self): assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001) def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): - (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + # Modify the config to have a layer which won't be present in the second LoRA we will load. modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config.target_modules.add("x_embedder") + pipe.transformer.add_adapter(modified_denoiser_lora_config) assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images assert not ( np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), @@ -148,14 +178,18 @@ def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): with tempfile.TemporaryDirectory() as tmpdirname: denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} + pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") pipe.set_adapters(["one", "two"]) assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images assert not ( np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), @@ -167,15 +201,17 @@ def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): ) def test_lora_expansion_works_for_extra_keys(self, base_pipe_output): - (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config.target_modules.add("x_embedder") pipe.transformer.add_adapter(modified_denoiser_lora_config) assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images assert not ( np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), @@ -185,13 +221,16 @@ def test_lora_expansion_works_for_extra_keys(self, base_pipe_output): denoiser_state_dict = 
get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe.unload_lora_weights() lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") + pipe.set_adapters(["one", "two"]) assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" + images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images assert not ( np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), @@ -250,10 +289,10 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests): "scaling_factor": 1.5035, } has_two_text_encoders = True - (tokenizer_cls, tokenizer_id) = (CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2") - (tokenizer_2_cls, tokenizer_2_id) = (AutoTokenizer, "hf-internal-testing/tiny-random-t5") - (text_encoder_cls, text_encoder_id) = (CLIPTextModel, "peft-internal-testing/tiny-clip-text-2") - (text_encoder_2_cls, text_encoder_2_id) = (T5EncoderModel, "hf-internal-testing/tiny-random-t5") + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" @property def output_shape(self): @@ -264,9 +303,11 @@ def get_dummy_inputs(self, with_generator=True): sequence_length = 10 num_channels = 4 sizes = (32, 32) + generator = torch.manual_seed(0) noise = floats_tensor((batch_size, num_channels) + sizes) input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + np.random.seed(0) pipeline_inputs = { "prompt": "A painting of a squirrel eating a burger", @@ -279,17 +320,22 @@ def get_dummy_inputs(self, with_generator=True): } if with_generator: pipeline_inputs.update({"generator": generator}) - return (noise, input_ids, pipeline_inputs) + + return noise, input_ids, pipeline_inputs def test_with_norm_in_state_dict(self): - (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.INFO) + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]: norm_state_dict = {} for name, module in pipe.transformer.named_modules(): @@ -298,14 +344,17 @@ def test_with_norm_in_state_dict(self): norm_state_dict[f"transformer.{name}.weight"] = torch.randn( module.weight.shape, device=module.weight.device, dtype=module.weight.dtype ) + with CaptureLogger(logger) as cap_logger: 
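                # CaptureLogger collects everything emitted through the diffusers logger into
                # `cap_logger.out`, so the assertion below can check for the warning text that
                # loading normalization layers alongside LoRA layers is expected to produce.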
pipe.load_lora_weights(norm_state_dict) lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert ( "The provided state dict contains normalization layers in addition to LoRA layers" in cap_logger.out ) assert len(pipe.transformer._transformer_norm_layers) > 0 + pipe.unload_lora_weights() lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] assert pipe.transformer._transformer_norm_layers is None @@ -314,6 +363,7 @@ def test_with_norm_in_state_dict(self): np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), f"{norm_layer} is tested", ) + with CaptureLogger(logger) as cap_logger: for key in list(norm_state_dict.keys()): norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key) @@ -321,14 +371,17 @@ def test_with_norm_in_state_dict(self): assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out def test_lora_parameter_expanded_shapes(self): - (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control @@ -336,15 +389,17 @@ def test_lora_parameter_expanded_shapes(self): assert transformer.config.in_channels == num_channels_without_control, ( f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}" ) + original_transformer_state_dict = pipe.transformer.state_dict() x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight") incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False) assert "x_embedder.weight" in incompatible_keys.missing_keys, ( "Could not find x_embedder.weight in the missing keys." 
) + transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control]) pipe.transformer = transformer - (out_features, in_features) = pipe.transformer.x_embedder.weight.shape + out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) @@ -355,12 +410,15 @@ def test_lora_parameter_expanded_shapes(self): with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features assert pipe.transformer.config.in_channels == 2 * in_features assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") - (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + # Testing opposite direction where the LoRA params are zero-padded. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -373,6 +431,7 @@ def test_lora_parameter_expanded_shapes(self): with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features @@ -380,19 +439,27 @@ def test_lora_parameter_expanded_shapes(self): assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out def test_normal_lora_with_expanded_lora_raises_error(self): - (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then + # load shape expanded LoRA (such as Control LoRA). + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + # Change the transformer config to mimic a real use case. 
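        # (The dummy Control-style transformer reserves extra input channels for the concatenated
        # control latents; rebuilding it with `in_channels=4` mimics a regular Flux transformer, so
        # the expanded and regular LoRAs defined below exercise the shape-expansion and zero-padding
        # handling in the loader.)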
num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) components["transformer"] = transformer + pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - (out_features, in_features) = pipe.transformer.x_embedder.weight.shape + + out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 + shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -416,23 +483,32 @@ def test_normal_lora_with_expanded_lora_raises_error(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-2") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out assert pipe.get_active_adapters() == ["adapter-2"] + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001) - (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. + # This should raise a runtime error on input shapes being incompatible. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + # Change the transformer config to mimic a real use case. num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) components["transformer"] = transformer + pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - (out_features, in_features) = pipe.transformer.x_embedder.weight.shape + + out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 lora_state_dict = { "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, @@ -442,27 +518,40 @@ def test_normal_lora_with_expanded_lora_raises_error(self): assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features assert pipe.transformer.config.in_channels == in_features + lora_state_dict = { "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, } + # We should check for input shapes being incompatible here. But because above mentioned issue is + # not a supported use case, and because of the PEFT renaming, we will currently have a shape + # mismatch error. with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"): pipe.load_lora_weights(lora_state_dict, "adapter-2") def test_fuse_expanded_lora_with_regular_lora(self): - (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + # This test checks if it works when a lora with expanded shapes (like control loras) but + # another lora with correct shapes is loaded. 
The opposite direction isn't supported and is + # tested with it. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + # Change the transformer config to mimic a real use case. num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) components["transformer"] = transformer + pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - (out_features, in_features) = pipe.transformer.x_embedder.weight.shape + + out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 + shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -471,34 +560,42 @@ def test_fuse_expanded_lora_with_regular_lora(self): } pipe.load_lora_weights(lora_state_dict, "adapter-1") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } + pipe.load_lora_weights(lora_state_dict, "adapter-2") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0]) lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001) assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001) assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001) + pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"]) lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001) - def test_load_regular_lora(self): - (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + def test_load_regular_lora(self, base_pipe_output): + # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded + # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those + # transformers include Flux Fill, Flux Control, etc. 
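        # When a LoRA's x_embedder weights were trained against fewer input channels than the
        # current transformer expects, the loader zero-pads the missing lora_A columns so the shapes
        # line up. Conceptually (the names here are only for illustration):
        #     lora_A_padded = torch.zeros(rank, model_in_features)
        #     lora_A_padded[:, : lora_A_small.shape[1]] = lora_A_small
        # The padded columns multiply the extra input channels by zero, so the adapter leaves the
        # extra input channels untouched.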
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - (out_features, in_features) = pipe.transformer.x_embedder.weight.shape + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 in_features = in_features // 2 normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) @@ -512,15 +609,19 @@ def test_load_regular_lora(self): with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 - assert not np.allclose(original_output, lora_output, atol=0.001, rtol=0.001) + assert not np.allclose(base_pipe_output, lora_output, atol=0.001, rtol=0.001) def test_lora_unload_with_parameter_expanded_shapes(self): - (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control @@ -528,16 +629,21 @@ def test_lora_unload_with_parameter_expanded_shapes(self): assert transformer.config.in_channels == num_channels_without_control, ( f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}" ) + + # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. 
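        # (FluxPipeline serves as that variant here: it reuses the dummy components of the Control
        # pipeline under test but takes no `control_image`, which is why the control input is popped
        # from the inputs before computing the baseline output.)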
components["transformer"] = transformer pipe = FluxPipeline(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) control_image = inputs.pop("control_image") original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + control_pipe = self.pipeline_class(**components) - (out_features, in_features) = control_pipe.transformer.x_embedder.weight.shape + out_features, in_features = control_pipe.transformer.x_embedder.weight.shape rank = 4 + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -547,20 +653,24 @@ def test_lora_unload_with_parameter_expanded_shapes(self): with CaptureLogger(logger) as cap_logger: control_pipe.load_lora_weights(lora_state_dict, "adapter-1") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + inputs["control_image"] = control_image lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features assert pipe.transformer.config.in_channels == 2 * in_features assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") + control_pipe.unload_lora_weights(reset_to_overwritten_params=True) assert control_pipe.transformer.config.in_channels == num_channels_without_control, ( f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}" ) + loaded_pipe = FluxPipeline.from_pipe(control_pipe) assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, ( f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}" ) + inputs.pop("control_image") unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001) @@ -569,9 +679,11 @@ def test_lora_unload_with_parameter_expanded_shapes(self): assert pipe.transformer.config.in_channels == in_features def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): - (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) + num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control @@ -579,16 +691,21 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): assert transformer.config.in_channels == num_channels_without_control, ( f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}" ) + + # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. 
components["transformer"] = transformer pipe = FluxPipeline(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) control_image = inputs.pop("control_image") original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + control_pipe = self.pipeline_class(**components) - (out_features, in_features) = control_pipe.transformer.x_embedder.weight.shape + out_features, in_features = control_pipe.transformer.x_embedder.weight.shape rank = 4 + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -598,16 +715,19 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): with CaptureLogger(logger) as cap_logger: control_pipe.load_lora_weights(lora_state_dict, "adapter-1") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + inputs["control_image"] = control_image lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features assert pipe.transformer.config.in_channels == 2 * in_features assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") + control_pipe.unload_lora_weights(reset_to_overwritten_params=False) assert control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, ( f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}" ) + no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001) assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 From 7b4bcce602b1edc019ae42388025c7aad41ac21c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 14:10:31 +0530 Subject: [PATCH 04/17] up --- tests/lora/utils.py | 543 +++++++++++++++++++++++++++++++++----------- 1 file changed, 406 insertions(+), 137 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 2fe80b4c1bb2..0b9e1e015296 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1,8 +1,21 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import inspect import os import re import tempfile -import unittest from itertools import product import numpy as np @@ -37,10 +50,12 @@ def state_dicts_almost_equal(sd1, sd2): sd1 = dict(sorted(sd1.items())) sd2 = dict(sorted(sd2.items())) + models_are_equal = True for ten1, ten2 in zip(sd1.values(), sd2.values()): - if (ten1 - ten2).abs().max() > 0.001: + if (ten1 - ten2).abs().max() > 1e-3: models_are_equal = False + return models_are_equal @@ -56,15 +71,15 @@ def check_if_lora_correctly_set(model) -> bool: def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str): extracted = { - k.removeprefix(f"{module_key}."): v for (k, v) in parsed_metadata.items() if k.startswith(f"{module_key}.") + k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.") } check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"]) def initialize_dummy_state_dict(state_dict): - if not all((v.device.type == "meta" for (_, v) in state_dict.items())): + if not all((v.device.type == "meta" for _, v in state_dict.items())): raise ValueError("`state_dict` has non-meta values.") - return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for (k, v) in state_dict.items()} + return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()} POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"] @@ -72,6 +87,8 @@ def initialize_dummy_state_dict(state_dict): def determine_attention_kwargs_name(pipeline_class): call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys() + + # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES: if possible_attention_kwargs in call_signature_keys: attention_kwargs_name = possible_attention_kwargs @@ -83,21 +100,25 @@ def determine_attention_kwargs_name(pipeline_class): @require_peft_backend class PeftLoraLoaderMixinTests: pipeline_class = None + scheduler_cls = None scheduler_kwargs = None + has_two_text_encoders = False has_three_text_encoders = False - (text_encoder_cls, text_encoder_id, text_encoder_subfolder) = (None, None, "") - (text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder) = (None, None, "") - (text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder) = (None, None, "") - (tokenizer_cls, tokenizer_id, tokenizer_subfolder) = (None, None, "") - (tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder) = (None, None, "") - (tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder) = (None, None, "") + text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, "" + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, "" + text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, "" + tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" + tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" + unet_kwargs = None transformer_cls = None transformer_kwargs = None vae_cls = AutoencoderKL vae_kwargs = None + text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] @@ -110,21 +131,26 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No raise ValueError("Both `unet_kwargs` and 
`transformer_kwargs` cannot be specified.") if self.has_two_text_encoders and self.has_three_text_encoders: raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.") + scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls rank = 4 lora_alpha = rank if lora_alpha is None else lora_alpha + torch.manual_seed(0) if self.unet_kwargs is not None: unet = UNet2DConditionModel(**self.unet_kwargs) else: transformer = self.transformer_cls(**self.transformer_kwargs) + scheduler = scheduler_cls(**self.scheduler_kwargs) + torch.manual_seed(0) vae = self.vae_cls(**self.vae_kwargs) text_encoder = self.text_encoder_cls.from_pretrained( self.text_encoder_id, subfolder=self.text_encoder_subfolder ) tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder) + if self.text_encoder_2_cls is not None: text_encoder_2 = self.text_encoder_2_cls.from_pretrained( self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder @@ -132,6 +158,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No tokenizer_2 = self.tokenizer_2_cls.from_pretrained( self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder ) + if self.text_encoder_3_cls is not None: text_encoder_3 = self.text_encoder_3_cls.from_pretrained( self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder @@ -139,6 +166,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No tokenizer_3 = self.tokenizer_3_cls.from_pretrained( self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder ) + text_lora_config = LoraConfig( r=rank, lora_alpha=lora_alpha, @@ -146,6 +174,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No init_lora_weights=False, use_dora=use_dora, ) + denoiser_lora_config = LoraConfig( r=rank, lora_alpha=lora_alpha, @@ -159,14 +188,19 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No "text_encoder": text_encoder, "tokenizer": tokenizer, } + # Denoiser if self.unet_kwargs is not None: pipeline_components.update({"unet": unet}) elif self.transformer_kwargs is not None: pipeline_components.update({"transformer": transformer}) + + # Remaining text encoders. 
if self.text_encoder_2_cls is not None: pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2}) if self.text_encoder_3_cls is not None: pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3}) + + # Remaining stuff init_params = inspect.signature(self.pipeline_class.__init__).parameters if "safety_checker" in init_params: pipeline_components.update({"safety_checker": None}) @@ -174,7 +208,8 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No pipeline_components.update({"feature_extractor": None}) if "image_encoder" in init_params: pipeline_components.update({"image_encoder": None}) - return (pipeline_components, text_lora_config, denoiser_lora_config) + + return pipeline_components, text_lora_config, denoiser_lora_config @property def output_shape(self): @@ -185,6 +220,7 @@ def get_dummy_inputs(self, with_generator=True): sequence_length = 10 num_channels = 4 sizes = (32, 32) + generator = torch.manual_seed(0) noise = floats_tensor((batch_size, num_channels) + sizes) input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) @@ -196,14 +232,18 @@ def get_dummy_inputs(self, with_generator=True): } if with_generator: pipeline_inputs.update({"generator": generator}) + return (noise, input_ids, pipeline_inputs) def _compute_baseline_output(self): - (components, _, _) = self.get_dummy_components(self.scheduler_cls) + components, _, _ = self.get_dummy_components(self.scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + # Always ensure the inputs are without the `generator`. Make sure to pass the `generator` + # explicitly. 
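        # (Passing a fresh `torch.manual_seed(0)` generator at call time keeps the baseline
        # deterministic, so every test that receives `base_pipe_output` compares against the same
        # reference run.)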
+ _, _, inputs = self.get_dummy_inputs(with_generator=False) return pipe(**inputs, generator=torch.manual_seed(0))[0] def _get_lora_state_dicts(self, modules_to_save): @@ -223,23 +263,27 @@ def _get_lora_adapter_metadata(self, modules_to_save): def _get_modules_to_save(self, pipe, has_denoiser=False): modules_to_save = {} lora_loadable_modules = self.pipeline_class._lora_loadable_modules + if ( "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder") and (getattr(pipe.text_encoder, "peft_config", None) is not None) ): modules_to_save["text_encoder"] = pipe.text_encoder + if ( "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2") - and (getattr(pipe.text_encoder_2, "peft_config", None) is not None) + and getattr(pipe.text_encoder_2, "peft_config", None) is not None ): modules_to_save["text_encoder_2"] = pipe.text_encoder_2 + if has_denoiser: if "unet" in lora_loadable_modules and hasattr(pipe, "unet"): modules_to_save["unet"] = pipe.unet if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"): modules_to_save["transformer"] = pipe.transformer + return modules_to_save def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): @@ -257,7 +301,7 @@ def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_co if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - return (pipe, denoiser) + return pipe, denoiser def test_simple_inference(self, base_pipe_output): """ @@ -270,40 +314,47 @@ def test_simple_inference_with_text_lora(self, base_pipe_output): Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ - (components, text_lora_config, _) = self.get_dummy_components() + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" @require_peft_version_greater("0.13.1") def test_low_cpu_mem_usage_with_injection(self): """Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder." assert "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, ( "The LoRA params should be on 'meta' device." 
) + te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder)) set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) assert "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, ( "No param should be on 'meta' device." ) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." assert "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." + denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser)) set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True) assert "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." + if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True) @@ -311,12 +362,14 @@ def test_low_cpu_mem_usage_with_injection(self): assert "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, ( "The LoRA params should be on 'meta' device." ) + te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2)) set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True) assert "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, ( "No param should be on 'meta' device." ) - (_, _, inputs) = self.get_dummy_inputs() + + _, _, inputs = self.get_dummy_inputs() output_lora = pipe(**inputs)[0] assert output_lora.shape == self.output_shape @@ -324,32 +377,39 @@ def test_low_cpu_mem_usage_with_injection(self): @require_transformers_version_greater("4.45.2") def test_low_cpu_mem_usage_with_loading(self): """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights( save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts ) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) for module_name, module in modules_to_save.items(): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( "Loading from saved checkpoints should give 
same results." ) + pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) for module_name, module in modules_to_save.items(): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose( images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001 @@ -361,19 +421,23 @@ def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output): and makes sure it works as expected """ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - (components, text_lora_config, _) = self.get_dummy_components() + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" + attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] assert not np.allclose(output_lora, output_lora_scale, atol=0.001, rtol=0.001), ( "Lora + scale should change the output" ) + attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] assert np.allclose(base_pipe_output, output_lora_0_scale, atol=0.001, rtol=0.001), ( @@ -385,17 +449,21 @@ def test_simple_inference_with_text_lora_fused(self, base_pipe_output): Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ - (components, text_lora_config, _) = self.get_dummy_components() + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + pipe.fuse_lora() assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( np.allclose(ouput_fused, base_pipe_output, atol=0.001, rtol=0.001), @@ -407,20 +475,24 @@ def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output): Tests a simple inference with lora attached to text encoder, then unloads the lora weights and makes sure it works as expected """ - (components, text_lora_config, _) = self.get_dummy_components() + components, 
text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + pipe.unload_lora_weights() assert not (check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder") + if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: assert not ( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2", ) + ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(ouput_unloaded, base_pipe_output, atol=0.001, rtol=0.001), ( "Fused lora should change the output" @@ -430,13 +502,15 @@ def test_simple_inference_with_text_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA. """ - (components, text_lora_config, _) = self.get_dummy_components() + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -444,10 +518,13 @@ def test_simple_inference_with_text_lora_save_load(self): save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts ) assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + for module_name, module in modules_to_save.items(): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( "Loading from saved checkpoints should give same results." 
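For reference, the save/unload/reload round-trip that these save_load tests exercise can be sketched outside the test harness roughly as follows; this is a minimal sketch, and the tiny checkpoint id and throwaway adapter config are illustrative placeholders rather than part of this patch.

    import tempfile

    from diffusers import StableDiffusionPipeline
    from peft import LoraConfig
    from peft.utils import get_peft_model_state_dict

    # Placeholder checkpoint id; any small Stable Diffusion pipeline works for the sketch.
    pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
    pipe.unet.add_adapter(LoraConfig(r=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

    with tempfile.TemporaryDirectory() as tmpdir:
        # Serialize only the adapter weights (pytorch_lora_weights.safetensors), not the base model.
        StableDiffusionPipeline.save_lora_weights(
            save_directory=tmpdir,
            unet_lora_layers=get_peft_model_state_dict(pipe.unet),
        )
        pipe.unload_lora_weights()      # drop the in-memory adapter
        pipe.load_lora_weights(tmpdir)  # reload the serialized adapter from disk

The surrounding tests then assert that the reloaded adapter reproduces the pre-save outputs within an absolute tolerance of 1e-3.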
@@ -459,7 +536,7 @@ def test_simple_inference_with_partial_text_lora(self, base_pipe_output): with different ranks and some adapters removed and makes sure it works as expected """ - (components, _, _) = self.get_dummy_components() + components, _, _ = self.get_dummy_components() text_lora_config = LoraConfig( r=4, rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)}, @@ -471,8 +548,9 @@ def test_simple_inference_with_partial_text_lora(self, base_pipe_output): pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) state_dict = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: state_dict = { @@ -491,6 +569,7 @@ def test_simple_inference_with_partial_text_lora(self, base_pipe_output): ) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" + pipe.unload_lora_weights() pipe.load_lora_weights(state_dict) output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -502,17 +581,20 @@ def test_simple_inference_save_pretrained_with_text_lora(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ - (components, text_lora_config, _) = self.get_dummy_components() + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + with tempfile.TemporaryDirectory() as tmpdirname: pipe.save_pretrained(tmpdirname) pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) pipe_from_pretrained.to(torch_device) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), ( "Lora not correctly set in text encoder" @@ -522,6 +604,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self): assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), ( "Lora not correctly set in text encoder 2" ) + images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(images_lora, images_lora_save_pretrained, atol=0.001, rtol=0.001), ( "Loading from saved checkpoints should give same results." 
@@ -531,13 +614,15 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -547,8 +632,10 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + for module_name, module in modules_to_save.items(): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( "Loading from saved checkpoints should give same results." @@ -560,19 +647,23 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_outp and makes sure it works as expected """ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" + attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] assert not np.allclose(output_lora, output_lora_scale, atol=0.001, rtol=0.001), ( "Lora + scale should change the output" ) + attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] assert np.allclose(base_pipe_output, output_lora_0_scale, atol=0.001, rtol=0.001), ( @@ -588,19 +679,23 @@ def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output): Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config 
= self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, denoiser) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(output_fused, base_pipe_output, atol=0.001, rtol=0.001), ( "Fused lora should change the output" ) @@ -612,21 +707,25 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_outpu """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, denoiser) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + pipe.unload_lora_weights() assert not check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" assert not check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser" + if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: assert not check_if_lora_correctly_set(pipe.text_encoder_2), ( "Lora not correctly unloaded in text encoder 2" ) + output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_unloaded, base_pipe_output, atol=0.001, rtol=0.001), ( "Unloading LoRA should restore the original output" @@ -639,21 +738,25 @@ def test_simple_inference_with_text_denoiser_lora_unfused( """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, denoiser) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
assert pipe.num_fused_loras == 1, ( f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" ) + output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) assert pipe.num_fused_loras == 0, ( f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" ) + output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] if "text_encoder" in self.pipeline_class._lora_loadable_modules: assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers" @@ -670,42 +773,50 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_outpu Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set them """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
+ if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(base_pipe_output, output_adapter_1, atol=0.001, rtol=0.001), ( "Adapter outputs should be different." ) + pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(base_pipe_output, output_adapter_2, atol=0.001, rtol=0.001), ( "Adapter outputs should be different." ) + pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(base_pipe_output, output_adapter_mixed, atol=0.001, rtol=0.001), ( "Adapter outputs should be different." ) + assert not np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), ( "Adapter 1 and 2 should give different results" ) @@ -718,6 +829,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_outpu np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 2 and mixed adapters should give different results", ) + pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( @@ -726,30 +838,36 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_outpu def test_wrong_adapter_name_raises_error(self): adapter_name = "adapter-1" - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline( + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) with pytest.raises(ValueError) as err_context: pipe.set_adapters("test") assert "not in the list of present adapters" in str(err_context.value) + pipe.set_adapters(adapter_name) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] def test_multiple_wrong_adapter_name_raises_error(self): adapter_name = "adapter-1" - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline( + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) + scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} logger = logging.get_logger("diffusers.loaders.lora_base") logger.setLevel(30) @@ -766,20 +884,25 @@ def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output) Tests a simple inference with lora attached to
text encoder and unet, attaches one adapter and set different weights for different blocks (i.e. block lora) """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + weights_1 = {"text_encoder": 2, "unet": {"down": 5}} pipe.set_adapters("adapter-1", weights_1) output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -798,6 +921,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output) np.allclose(base_pipe_output, output_weights_2, atol=0.001, rtol=0.001), "No adapter and LoRA weights 2 should give different results", ) + pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( @@ -809,24 +933,28 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set different weights for different blocks (i.e. block lora) """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
+ if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + scales_1 = {"text_encoder": 2, "unet": {"down": 5}} scales_2 = {"unet": {"down": 5, "mid": 5}} pipe.set_adapters("adapter-1", scales_1) @@ -847,6 +975,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 2 and mixed adapters should give different results", ) + pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( @@ -912,53 +1041,65 @@ def all_possible_dict_opts(unet, value): opts.append(opt) return opts - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components(self.scheduler_cls) + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") + if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + for scale_dict in all_possible_dict_opts(pipe.unet, value=1234): + # test if lora block scales can be set with this scale_dict if not self.has_two_text_encoders and "text_encoder_2" in scale_dict: del scale_dict["text_encoder_2"] - pipe.set_adapters("adapter-1", scale_dict) + pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set/delete them """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
+ if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), ( "Adapter 1 and 2 should give different results" ) @@ -971,23 +1112,28 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 2 and mixed adapters should give different results", ) + pipe.delete_adapters("adapter-1") output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_deleted_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), ( "Deleting adapter 1 should leave only adapter 2 active, so outputs should match" ) + pipe.delete_adapters("adapter-2") output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(base_pipe_output, output_deleted_adapters, atol=0.001, rtol=0.001), ( "output with no lora and output with all adapters deleted should give same results" ) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + pipe.set_adapters(["adapter-1", "adapter-2"]) pipe.delete_adapters(["adapter-1", "adapter-2"]) output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1000,29 +1146,36 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_p Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set them """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
+ if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( @@ -1037,12 +1190,14 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_p np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), "Adapter 2 and mixed adapters should give different results", ) + pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6]) output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=0.001, rtol=0.001), "Weighted adapter and mixed adapter should give different results", ) + pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( @@ -1056,17 +1211,21 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_p strict=False, ) def test_lora_fuse_nan(self): - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
+ with torch.no_grad(): if self.unet_kwargs: pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float( @@ -1095,6 +1254,7 @@ def test_lora_fuse_nan(self): transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") with pytest.raises(ValueError): pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) out = pipe(**inputs)[0] assert np.isnan(out).all() @@ -1104,20 +1264,26 @@ def test_get_adapters(self): Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") + adapter_names = pipe.get_active_adapters() assert adapter_names == ["adapter-1"] + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") denoiser.add_adapter(denoiser_lora_config, "adapter-2") + adapter_names = pipe.get_active_adapters() assert adapter_names == ["adapter-2"] + pipe.set_adapters(["adapter-1", "adapter-2"]) assert sorted(pipe.get_active_adapters()) == ["adapter-1", "adapter-2"] @@ -1126,10 +1292,12 @@ def test_get_list_adapters(self): Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + + # 1. dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1140,7 +1308,10 @@ def test_get_list_adapters(self): else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") dicts_to_be_checked.update({"transformer": ["adapter-1"]}) + assert pipe.get_list_adapters() == dicts_to_be_checked + + # 2. dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") @@ -1152,6 +1323,8 @@ def test_get_list_adapters(self): pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) assert pipe.get_list_adapters() == dicts_to_be_checked + + # 3. pipe.set_adapters(["adapter-1", "adapter-2"]) dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1161,6 +1334,8 @@ def test_get_list_adapters(self): else: dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) assert pipe.get_list_adapters() == dicts_to_be_checked + + # 4. 
dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} @@ -1179,41 +1354,52 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi( Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet and multi-adapter case """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." denoiser.add_adapter(denoiser_lora_config, "adapter-2") + if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + + # set them to multi-adapter inference mode pipe.set_adapters(["adapter-1", "adapter-2"]) outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.set_adapters(["adapter-1"]) outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"]) assert pipe.num_fused_loras == 1, ( f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" ) + + # Fusing should still keep the LoRA layers so output should remain the same outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), ( "Fused lora should not change the output" ) + pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) assert pipe.num_fused_loras == 0, ( f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" ) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers" assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers" @@ -1224,10 +1410,12 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi( assert pipe.num_fused_loras == 2, ( f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" ) + output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol), ( "Fused lora should not change the output" ) + pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) assert 
pipe.num_fused_loras == 0, ( f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" @@ -1238,25 +1426,30 @@ def test_lora_scale_kwargs_match_fusion( ): attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) for lora_scale in [1.0, 0.8]: - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + pipe.set_adapters(["adapter-1"]) attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + pipe.fuse_lora( components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"], @@ -1265,6 +1458,7 @@ def test_lora_scale_kwargs_match_fusion( assert pipe.num_fused_loras == 1, ( f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" ) + outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), ( "Fused lora should not change the output" @@ -1275,14 +1469,17 @@ def test_lora_scale_kwargs_match_fusion( ) def test_simple_inference_with_dora(self): - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components(use_dora=True) + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] assert output_no_dora_lora.shape == self.output_shape - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( np.allclose(output_dora_lora, output_no_dora_lora, atol=0.001, rtol=0.001), @@ -1290,13 +1487,15 @@ def test_simple_inference_with_dora(self): ) def test_missing_keys_warning(self): - (components, _, denoiser_lora_config) = self.get_dummy_components() + components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + denoiser 
= pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -1306,23 +1505,27 @@ def test_missing_keys_warning(self): pipe.unload_lora_weights() assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) + missing_key = [k for k in state_dict if "lora_A" in k][0] del state_dict[missing_key] logger = logging.get_logger("diffusers.utils.peft_utils") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(state_dict) + component = list({k.split(".")[0] for k in state_dict})[0] assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "") def test_unexpected_keys_warning(self): - (components, _, denoiser_lora_config) = self.get_dummy_components() + components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -1332,6 +1535,7 @@ def test_unexpected_keys_warning(self): pipe.unload_lora_weights() assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) + unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat" state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device) logger = logging.get_logger("diffusers.utils.peft_utils") @@ -1340,20 +1544,23 @@ def test_unexpected_keys_warning(self): pipe.load_lora_weights(state_dict) assert ".diffusers_cat" in cap_logger.out - @unittest.skip("This is failing for now - need to investigate") + @pytest.mark.skip("This is failing for now - need to investigate") def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2,
mode="reduce-overhead", fullgraph=True) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1364,31 +1571,36 @@ def set_pad_mode(network, mode="circular"): if isinstance(module, torch.nn.Conv2d): module.padding_mode = mode - (components, _, _) = self.get_dummy_components() + components, _, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + _pad_mode = "circular" set_pad_mode(pipe.vae, _pad_mode) set_pad_mode(pipe.unet, _pad_mode) - (_, _, inputs) = self.get_dummy_inputs() + _, _, inputs = self.get_dummy_inputs() _ = pipe(**inputs)[0] def test_logs_info_when_no_lora_keys_found(self, base_pipe_output): - (components, _, _) = self.get_dummy_components() + components, _, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)} logger = logging.get_logger("diffusers.loaders.peft") logger.setLevel(logging.WARNING) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(no_op_state_dict) + out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0] denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer") assert cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}") assert np.allclose(base_pipe_output, out_after_lora_attempt, atol=1e-05, rtol=1e-05) + for lora_module in self.pipeline_class._lora_loadable_modules: if "text_encoder" in lora_module: text_encoder = getattr(pipe, lora_module) @@ -1407,12 +1619,15 @@ def test_logs_info_when_no_lora_keys_found(self, base_pipe_output): def test_set_adapters_match_attention_kwargs(self, base_pipe_output): """Test to check if outputs after `set_adapters()` and attention kwargs match.""" attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + lora_scale = 0.5 attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] @@ -1420,6 +1635,7 @@ def test_set_adapters_match_attention_kwargs(self, base_pipe_output): np.allclose(base_pipe_output, output_lora_scale, atol=0.001, rtol=0.001), "Lora + scale should change the output", ) + pipe.set_adapters("default", lora_scale) output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(base_pipe_output, output_lora_scale_wo_kwargs, atol=0.001, rtol=0.001), ( @@ -1428,6 +1644,7 @@ def test_set_adapters_match_attention_kwargs(self, base_pipe_output): assert np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=0.001, rtol=0.001), ( "Lora + scale should match the output of 
`set_adapters()`." ) + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -1435,12 +1652,14 @@ def test_set_adapters_match_attention_kwargs(self, base_pipe_output): save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts ) assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) for module_name, module in modules_to_save.items(): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( "Lora + scale should change the output" @@ -1453,25 +1672,28 @@ def test_set_adapters_match_attention_kwargs(self, base_pipe_output): ) @require_peft_version_greater("0.13.2") - def test_lora_B_bias(self): - (components, _, denoiser_lora_config) = self.get_dummy_components() + def test_lora_B_bias(self, base_pipe_output): + components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + bias_values = {} denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer for name, module in denoiser.named_modules(): if any((k in name for k in self.denoiser_target_modules)): if module.bias is not None: bias_values[name] = module.bias.data.clone() - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + denoiser_lora_config.lora_bias = False if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.delete_adapters("adapter-1") denoiser_lora_config.lora_bias = True if self.unet_kwargs is not None: @@ -1479,22 +1701,25 @@ def test_lora_B_bias(self): else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(original_output, lora_bias_false_output, atol=0.001, rtol=0.001) - assert not np.allclose(original_output, lora_bias_true_output, atol=0.001, rtol=0.001) + + assert not np.allclose(base_pipe_output, lora_bias_false_output, atol=0.001, rtol=0.001) + assert not np.allclose(base_pipe_output, lora_bias_true_output, atol=0.001, rtol=0.001) assert not np.allclose(lora_bias_false_output, lora_bias_true_output, atol=0.001, rtol=0.001) - def test_correct_lora_configs_with_different_ranks(self): - (components, _, denoiser_lora_config) = self.get_dummy_components() + def test_correct_lora_configs_with_different_ranks(self, base_pipe_output): + components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + _, _, inputs = 
self.get_dummy_inputs(with_generator=False) + if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + if self.unet_kwargs is not None: pipe.unet.delete_adapters("adapter-1") else: @@ -1506,6 +1731,7 @@ def test_correct_lora_configs_with_different_ranks(self): break updated_rank = denoiser_lora_config.r * 2 denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} + if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern @@ -1513,15 +1739,19 @@ def test_correct_lora_configs_with_different_ranks(self): pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern assert updated_rank_pattern == {module_name_to_rank_update: updated_rank} + lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(original_output, lora_output_same_rank, atol=0.001, rtol=0.001) + assert not np.allclose(base_pipe_output, lora_output_same_rank, atol=0.001, rtol=0.001) assert not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=0.001, rtol=0.001) + if self.unet_kwargs is not None: pipe.unet.delete_adapters("adapter-1") else: pipe.transformer.delete_adapters("adapter-1") + updated_alpha = denoiser_lora_config.lora_alpha * 2 denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} + if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") assert pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} @@ -1531,7 +1761,7 @@ def test_correct_lora_configs_with_different_ranks(self): module_name_to_rank_update: updated_alpha } lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(original_output, lora_output_diff_alpha, atol=0.001, rtol=0.001) + assert not np.allclose(base_pipe_output, lora_output_diff_alpha, atol=0.001, rtol=0.001) assert not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=0.001, rtol=0.001) def test_layerwise_casting_inference_denoiser(self): @@ -1554,17 +1784,20 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): self.assertEqual(submodule.bias.dtype, dtype_to_check) def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device, dtype=compute_dtype) pipe.set_progress_bar_config(disable=None) - (pipe, denoiser) = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + if storage_dtype is not None: denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) check_linear_dtype(denoiser, storage_dtype, compute_dtype) return pipe - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe_fp32 = initialize_pipeline(storage_dtype=None) pipe_fp32(**inputs, generator=torch.manual_seed(0))[0] pipe_float8_e4m3_fp32 = 
initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) @@ -1612,13 +1845,15 @@ def check_module(denoiser): assert getattr(module, "_diffusers_hook", None) is not None assert module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None - (components, _, denoiser_lora_config) = self.get_dummy_components() + components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device, dtype=compute_dtype) + pipe.set_progress_bar_config(disable=None) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None: patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns) @@ -1626,8 +1861,10 @@ def check_module(denoiser): denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check ) check_module(denoiser) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe(**inputs, generator=torch.manual_seed(0))[0] + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -1635,7 +1872,8 @@ def check_module(denoiser): save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts ) assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - (components, _, _) = self.get_dummy_components() + + components, _, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device, dtype=compute_dtype) pipe.set_progress_bar_config(disable=None) @@ -1648,16 +1886,18 @@ def check_module(denoiser): skip_modules_pattern=patterns_to_check, ) check_module(denoiser) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe(**inputs, generator=torch.manual_seed(0))[0] @parameterized.expand([4, 8, 16]) def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components(lora_alpha=lora_alpha) + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components) - (pipe, _) = self.add_adapters_to_pipeline( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) + with tempfile.TemporaryDirectory() as tmpdir: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -1675,15 +1915,18 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): else f"{self.pipeline_class.unet_name}" ) assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) + check_module_lora_metadata( parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key ) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: text_encoder_key = self.pipeline_class.text_encoder_name assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) check_module_lora_metadata( parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas,
module_key=text_encoder_key ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: text_encoder_2_key = "text_encoder_2" assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) @@ -1693,17 +1936,21 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): @parameterized.expand([4, 8, 16]) def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components(lora_alpha=lora_alpha) + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components).to(torch_device) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline( + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + with tempfile.TemporaryDirectory() as tmpdir: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) pipe.unload_lora_weights() pipe.load_lora_weights(tmpdir) @@ -1714,47 +1961,58 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): def test_lora_unload_add_adapter(self): """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components).to(torch_device) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) - (pipe, _) = self.add_adapters_to_pipeline( + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] pipe.unload_lora_weights() - (pipe, _) = self.add_adapters_to_pipeline( + + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] def test_inference_load_delete_load_adapters(self, base_pipe_output): """Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works.""" - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config) assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
+ if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config) assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe.delete_adapters(pipe.get_active_adapters()[0]) output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001) assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001) + pipe.load_lora_weights(tmpdirname) output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001) @@ -1764,13 +2022,15 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): onload_device = torch_device offload_device = torch.device("cpu") - (components, text_lora_config, denoiser_lora_config) = self.get_dummy_components() + components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
+ with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -1778,13 +2038,14 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts ) assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - (components, _, _) = self.get_dummy_components() + + components, _, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) check_if_lora_correctly_set(denoiser) - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + _, _, inputs = self.get_dummy_inputs(with_generator=False) denoiser.enable_group_offload( onload_device=onload_device, offload_device=offload_device, @@ -1795,17 +2056,21 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): for _, component in pipe.components.items(): if isinstance(component, torch.nn.Module): component.to(torch_device) + group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) assert group_offload_hook_1 is not None + output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] pipe.unload_lora_weights() group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) assert group_offload_hook_2 is not None + output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) check_if_lora_correctly_set(denoiser) group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) assert group_offload_hook_3 is not None + output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001) @@ -1819,26 +2084,30 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream): @require_torch_accelerator def test_lora_loading_model_cpu_offload(self): - (components, _, denoiser_lora_config) = self.get_dummy_components() - (_, _, inputs) = self.get_dummy_inputs(with_generator=False) + components, _, denoiser_lora_config = self.get_dummy_components() + _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights( save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts ) - (components, _, denoiser_lora_config) = self.get_dummy_components() + components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.enable_model_cpu_offload(device=torch_device) pipe.load_lora_weights(tmpdirname) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
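The group-offloading setup exercised in this test reduces to a single call on the denoiser. A hedged sketch with the same keyword arguments as the test body above (the wrapper name and the hard-coded devices are assumptions; the test itself uses torch_device):

import torch

def enable_denoiser_group_offload(denoiser, offload_type="block_level", use_stream=False):
    # Mirrors the enable_group_offload(...) call in the test body above.
    denoiser.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type=offload_type,
        num_blocks_per_group=1,
        use_stream=use_stream,
    )
    return denoiser

The test then checks that a top-level group-offload hook remains attached across unload_lora_weights() and load_lora_weights().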
+ output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_lora, output_lora_loaded, atol=0.001, rtol=0.001) From ec866f5de82c3ffafdbdb1bb1e861f5326ddb0a8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 14:25:54 +0530 Subject: [PATCH 05/17] tempfile is now a fixture. --- tests/lora/test_lora_layers_cogview4.py | 10 +- tests/lora/test_lora_layers_flux.py | 72 ++-- tests/lora/test_lora_layers_wanvace.py | 53 ++- tests/lora/utils.py | 468 +++++++++++------------- 4 files changed, 281 insertions(+), 322 deletions(-) diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index de732b85268b..3a39c44a37a6 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -13,7 +13,6 @@ # limitations under the License. import sys -import tempfile import unittest import numpy as np @@ -119,7 +118,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - def test_simple_inference_save_pretrained(self): + def test_simple_inference_save_pretrained(self, tmpdirname): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ @@ -131,11 +130,10 @@ def test_simple_inference_save_pretrained(self): images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) + pipe.save_pretrained(tmpdirname) - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) + pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) + pipe_from_pretrained.to(torch_device) images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index f75a7b3777c1..7c230308ae45 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -16,7 +16,6 @@ import gc import os import sys -import tempfile import unittest import numpy as np @@ -114,7 +113,7 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_with_alpha_in_state_dict(self): + def test_with_alpha_in_state_dict(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -126,24 +125,23 @@ def test_with_alpha_in_state_dict(self): images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - # 
modify the state dict to have alpha values following - # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors - state_dict_with_alpha = safetensors.torch.load_file( - os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") - ) - alpha_dict = {} - for k, v in state_dict_with_alpha.items(): - if "transformer" in k and "to_k" in k and ("lora_A" in k): - alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) - state_dict_with_alpha.update(alpha_dict) + # modify the state dict to have alpha values following + # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors + state_dict_with_alpha = safetensors.torch.load_file( + os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") + ) + alpha_dict = {} + for k, v in state_dict_with_alpha.items(): + if "transformer" in k and "to_k" in k and ("lora_A" in k): + alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) + state_dict_with_alpha.update(alpha_dict) images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" @@ -156,7 +154,7 @@ def test_with_alpha_in_state_dict(self): ) assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001) - def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): + def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -175,16 +173,15 @@ def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), "LoRA should lead to different results.", ) - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") pipe.set_adapters(["one", "two"]) @@ -200,7 +197,7 @@ def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): "LoRA should lead to different results.", ) - def test_lora_expansion_works_for_extra_keys(self, base_pipe_output): + def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname): components, _, denoiser_lora_config = 
self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -217,16 +214,15 @@ def test_lora_expansion_works_for_extra_keys(self, base_pipe_output): np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), "LoRA should lead to different results.", ) - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - pipe.unload_lora_weights() - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} - pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + pipe.unload_lora_weights() + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} + pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") pipe.set_adapters(["one", "two"]) assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index 60246ad2bcc7..9c319b995277 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -14,7 +14,6 @@ import os import sys -import tempfile import unittest import numpy as np @@ -163,7 +162,7 @@ def test_layerwise_casting_inference_denoiser(self): super().test_layerwise_casting_inference_denoiser() @require_peft_version_greater("0.13.2") - def test_lora_exclude_modules_wanvace(self, base_pipe_output): + def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname): exclude_module_name = "vace_blocks.0.proj_out" components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components).to(torch_device) @@ -183,30 +182,26 @@ def test_lora_exclude_modules_wanvace(self, base_pipe_output): assert any("proj_out" in k for k in state_dict_from_model) output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts) - pipe.unload_lora_weights() - - # Check in the loaded state dict. - loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - assert not any(exclude_module_name in k for k in loaded_state_dict) - assert any("proj_out" in k for k in loaded_state_dict) - - # Check in the state dict obtained after loading LoRA. 
- pipe.load_lora_weights(tmpdir) - state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0") - assert not any(exclude_module_name in k for k in state_dict_from_model) - assert any("proj_out" in k for k in state_dict_from_model) - - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), ( - "LoRA should change outputs." - ) - assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), ( - "Lora outputs should match." - ) - - def test_simple_inference_with_text_denoiser_lora_and_scale(self): - super().test_simple_inference_with_text_denoiser_lora_and_scale() + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) + pipe.unload_lora_weights() + + # Check in the loaded state dict. + loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + assert not any(exclude_module_name in k for k in loaded_state_dict) + assert any("proj_out" in k for k in loaded_state_dict) + + # Check in the state dict obtained after loading LoRA. + pipe.load_lora_weights(tmpdirname) + state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0") + assert not any(exclude_module_name in k for k in state_dict_from_model) + assert any("proj_out" in k for k in state_dict_from_model) + + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), ( + "LoRA should change outputs." + ) + assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), ( + "Lora outputs should match." 
+ ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 0b9e1e015296..bc879f769195 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -15,7 +15,6 @@ import inspect import os import re -import tempfile from itertools import product import numpy as np @@ -122,10 +121,18 @@ class PeftLoraLoaderMixinTests: text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + @property + def output_shape(self): + raise NotImplementedError + @pytest.fixture(scope="class") def base_pipe_output(self): return self._compute_baseline_output() + @pytest.fixture(scope="function") + def tmpdirname(self, tmp_path_factory): + return tmp_path_factory.mktemp("tmp") + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") @@ -211,10 +218,6 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No return pipeline_components, text_lora_config, denoiser_lora_config - @property - def output_shape(self): - raise NotImplementedError - def get_dummy_inputs(self, with_generator=True): batch_size = 1 sequence_length = 10 @@ -235,6 +238,23 @@ def get_dummy_inputs(self, with_generator=True): return (noise, input_ids, pipeline_inputs) + def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): + if text_lora_config is not None: + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + if denoiser_lora_config is not None: + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + else: + denoiser = None + if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + return pipe, denoiser + def _compute_baseline_output(self): components, _, _ = self.get_dummy_components(self.scheduler_cls) pipe = self.pipeline_class(**components) @@ -286,23 +306,6 @@ def _get_modules_to_save(self, pipe, has_denoiser=False): return modules_to_save - def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): - if text_lora_config is not None: - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - if denoiser_lora_config is not None: - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
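A minimal, self-contained sketch of how the two fixtures added above are consumed once the tests declare them as parameters (the class name and return values are invented; only the fixture scoping mirrors the diff):

import pytest

class TestFixtureScopes:
    @pytest.fixture(scope="class")
    def base_pipe_output(self):
        # Computed once per test class and shared, like _compute_baseline_output() above.
        return [0.0, 1.0, 2.0]

    @pytest.fixture(scope="function")
    def tmpdirname(self, tmp_path_factory):
        # A fresh directory per test, replacing tempfile.TemporaryDirectory().
        return tmp_path_factory.mktemp("tmp")

    def test_uses_both(self, base_pipe_output, tmpdirname):
        assert len(base_pipe_output) == 3
        assert tmpdirname.exists()

Because tmp_path_factory directories are managed by pytest, the `with tempfile.TemporaryDirectory()` blocks can be flattened in the hunks that follow.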
- else: - denoiser = None - if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) - assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - return pipe, denoiser - def test_simple_inference(self, base_pipe_output): """ Tests a simple inference and makes sure it works as expected @@ -375,7 +378,7 @@ def test_low_cpu_mem_usage_with_injection(self): @require_peft_version_greater("0.13.1") @require_transformers_version_greater("4.45.2") - def test_low_cpu_mem_usage_with_loading(self): + def test_low_cpu_mem_usage_with_loading(self, tmpdirname): """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -386,34 +389,31 @@ def test_low_cpu_mem_usage_with_loading(self): pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( - "Loading from saved checkpoints should give same results." - ) + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results." 
+ ) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose( - images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001 - ), "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results." + images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results." + ) def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output): """ @@ -498,7 +498,7 @@ def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output): "Fused lora should change the output" ) - def test_simple_inference_with_text_lora_save_load(self): + def test_simple_inference_with_text_lora_save_load(self, tmpdirname): """ Tests a simple usecase where users could use saving utilities for LoRA. """ @@ -511,16 +511,13 @@ def test_simple_inference_with_text_lora_save_load(self): pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + modules_to_save = self._get_modules_to_save(pipe) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) for module_name, module in modules_to_save.items(): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" @@ -577,7 +574,7 @@ def test_simple_inference_with_partial_text_lora(self, base_pipe_output): "Removing adapters should change the output" ) - def test_simple_inference_save_pretrained_with_text_lora(self): + def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ @@ -590,10 +587,9 @@ def test_simple_inference_save_pretrained_with_text_lora(self): pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) - 
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) + pipe.save_pretrained(tmpdirname) + pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) + pipe_from_pretrained.to(torch_device) if "text_encoder" in self.pipeline_class._lora_loadable_modules: assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), ( @@ -610,7 +606,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self): "Loading from saved checkpoints should give same results." ) - def test_simple_inference_with_text_denoiser_lora_save_load(self): + def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname): """ Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ @@ -623,15 +619,12 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) for module_name, module in modules_to_save.items(): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" @@ -1486,7 +1479,7 @@ def test_simple_inference_with_dora(self): "DoRA lora should change the output", ) - def test_missing_keys_warning(self): + def test_missing_keys_warning(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1496,15 +1489,12 @@ def test_missing_keys_warning(self): denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
- with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - pipe.unload_lora_weights() - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + pipe.unload_lora_weights() + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) missing_key = [k for k in state_dict if "lora_A" in k][0] del state_dict[missing_key] @@ -1516,7 +1506,7 @@ def test_missing_keys_warning(self): component = list({k.split(".")[0] for k in state_dict})[0] assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "") - def test_unexpected_keys_warning(self): + def test_unexpected_keys_warning(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1526,15 +1516,12 @@ def test_unexpected_keys_warning(self): denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - pipe.unload_lora_weights() - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + pipe.unload_lora_weights() + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat" state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device) @@ -1616,7 +1603,7 @@ def test_logs_info_when_no_lora_keys_found(self, base_pipe_output): ) assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}") - def test_set_adapters_match_attention_kwargs(self, base_pipe_output): + def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname): """Test to check if outputs after `set_adapters()` and attention kwargs match.""" attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) components, text_lora_config, denoiser_lora_config = self.get_dummy_components() @@ -1645,31 +1632,28 @@ def test_set_adapters_match_attention_kwargs(self, base_pipe_output): "Lora + scale should match the output of `set_adapters()`."
) - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( - "Lora + scale should change the output" - ) - assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( - "Loading from saved checkpoints should give same results as attention_kwargs." - ) - assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( - "Loading from saved checkpoints should give same results as set_adapters()." - ) + output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Lora + scale should change the output" + ) + assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results as attention_kwargs." + ) + assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results as set_adapters()." + ) @require_peft_version_greater("0.13.2") def test_lora_B_bias(self, base_pipe_output): @@ -1806,7 +1790,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] @require_peft_version_greater("0.14.0") - def test_layerwise_casting_peft_input_autocast_denoiser(self): + def test_layerwise_casting_peft_input_autocast_denoiser(self, tmpdirname): """ A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. 
This is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise @@ -1865,77 +1849,73 @@ def check_module(denoiser): _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device, dtype=compute_dtype) - pipe.set_progress_bar_config(disable=None) - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - apply_layerwise_casting( - denoiser, - storage_dtype=storage_dtype, - compute_dtype=compute_dtype, - skip_modules_pattern=patterns_to_check, - ) - check_module(denoiser) + components, _, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device, dtype=compute_dtype) + pipe.set_progress_bar_config(disable=None) + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + apply_layerwise_casting( + denoiser, + storage_dtype=storage_dtype, + compute_dtype=compute_dtype, + skip_modules_pattern=patterns_to_check, + ) + check_module(denoiser) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe(**inputs, generator=torch.manual_seed(0))[0] + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe(**inputs, generator=torch.manual_seed(0))[0] @parameterized.expand([4, 8, 16]) - def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): + def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname): components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components) pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) - pipe.unload_lora_weights() - out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True) - if len(out) == 3: - (_, _, parsed_metadata) = out - elif len(out) == 2: - (_, parsed_metadata) = out - denoiser_key = ( - f"{self.pipeline_class.transformer_name}" - if self.transformer_kwargs is not None - else f"{self.pipeline_class.unet_name}" - ) - assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) + modules_to_save = self._get_modules_to_save(pipe, 
has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + out = pipe.lora_state_dict(tmpdirname, return_lora_metadata=True) + if len(out) == 3: + (_, _, parsed_metadata) = out + elif len(out) == 2: + (_, parsed_metadata) = out + denoiser_key = ( + f"{self.pipeline_class.transformer_name}" + if self.transformer_kwargs is not None + else f"{self.pipeline_class.unet_name}" + ) + assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) + + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key + ) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + text_encoder_key = self.pipeline_class.text_encoder_name + assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key ) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - text_encoder_key = self.pipeline_class.text_encoder_name - assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) - check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key - ) - - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - text_encoder_2_key = "text_encoder_2" - assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) - check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_key = "text_encoder_2" + assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key + ) @parameterized.expand([4, 8, 16]) - def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): + def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname): components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components).to(torch_device) @@ -1946,18 +1926,15 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): ) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) - pipe.unload_lora_weights() - pipe.load_lora_weights(tmpdir) - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), ( - "Lora outputs should match." 
- ) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdirname) + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), "Lora outputs should match." def test_lora_unload_add_adapter(self): """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" @@ -1977,7 +1954,7 @@ def test_lora_unload_add_adapter(self): ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_inference_load_delete_load_adapters(self, base_pipe_output): + def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname): """Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works.""" components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -2002,22 +1979,21 @@ def test_inference_load_delete_load_adapters(self, base_pipe_output): output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.delete_adapters(pipe.get_active_adapters()[0]) - output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001) - assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001) + pipe.delete_adapters(pipe.get_active_adapters()[0]) + output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001) + assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001) - pipe.load_lora_weights(tmpdirname) - output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001) + pipe.load_lora_weights(tmpdirname) + output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001) - def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): + def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook onload_device = torch_device @@ -2031,59 +2007,56 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
- with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - check_if_lora_correctly_set(denoiser) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - denoiser.enable_group_offload( - onload_device=onload_device, - offload_device=offload_device, - offload_type=offload_type, - num_blocks_per_group=1, - use_stream=use_stream, - ) - for _, component in pipe.components.items(): - if isinstance(component, torch.nn.Module): - component.to(torch_device) + components, _, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + denoiser.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=1, + use_stream=use_stream, + ) + for _, component in pipe.components.items(): + if isinstance(component, torch.nn.Module): + component.to(torch_device) - group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) - assert group_offload_hook_1 is not None + group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_1 is not None - output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.unload_lora_weights() - group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) - assert group_offload_hook_2 is not None + output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.unload_lora_weights() + group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_2 is not None - output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - check_if_lora_correctly_set(denoiser) - group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) - assert group_offload_hook_3 is not None + output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_3 is not None - output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001) + output_3 = pipe(**inputs, 
generator=torch.manual_seed(0))[0] + assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001) @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) @require_torch_accelerator - def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): for cls in inspect.getmro(self.__class__): if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: return - self._test_group_offloading_inference_denoiser(offload_type, use_stream) + self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname) @require_torch_accelerator - def test_lora_loading_model_cpu_offload(self): + def test_lora_loading_model_cpu_offload(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe = self.pipeline_class(**components) @@ -2096,18 +2069,15 @@ def test_lora_loading_model_cpu_offload(self): output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.enable_model_cpu_offload(device=torch_device) - pipe.load_lora_weights(tmpdirname) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + components, _, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.enable_model_cpu_offload(device=torch_device) + pipe.load_lora_weights(tmpdirname) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
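The CPU-offload path in test_lora_loading_model_cpu_offload boils down to a short sequence: construct the pipeline, enable offloading, then load the LoRA into the offloaded modules. A sketch under the assumption that `components` and a directory holding pytorch_lora_weights.safetensors already exist (the helper name and default device are illustrative, not from the diff):

def load_lora_with_model_cpu_offload(pipeline_class, components, lora_dir, device="cuda"):
    pipe = pipeline_class(**components)
    pipe.enable_model_cpu_offload(device=device)  # hooks modules so they move to the accelerator on demand
    pipe.load_lora_weights(lora_dir)              # LoRA layers are injected into the offloaded denoiser
    return pipe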
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_lora, output_lora_loaded, atol=0.001, rtol=0.001) From 949cc1c326ec738c954d59f3a40e2a41fc14f853 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 14:54:23 +0530 Subject: [PATCH 06/17] up --- tests/lora/utils.py | 148 ++++++++++++++++++++++---------------------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index bc879f769195..db314d692306 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -243,20 +243,23 @@ def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_co if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + if denoiser_lora_config is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." else: denoiser = None + if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + return pipe, denoiser def _compute_baseline_output(self): - components, _, _ = self.get_dummy_components(self.scheduler_cls) + components, _, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -326,7 +329,7 @@ def test_simple_inference_with_text_lora(self, base_pipe_output): pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" + assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" @require_peft_version_greater("0.13.1") def test_low_cpu_mem_usage_with_injection(self): @@ -401,7 +404,7 @@ def test_low_cpu_mem_usage_with_loading(self, tmpdirname): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + assert np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( "Loading from saved checkpoints should give same results." ) @@ -411,7 +414,7 @@ def test_low_cpu_mem_usage_with_loading(self, tmpdirname): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + assert np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results." 
) @@ -430,17 +433,17 @@ def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output): pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" + assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert not np.allclose(output_lora, output_lora_scale, atol=0.001, rtol=0.001), ( + assert not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), ( "Lora + scale should change the output" ) attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert np.allclose(base_pipe_output, output_lora_0_scale, atol=0.001, rtol=0.001), ( + assert np.allclose(base_pipe_output, output_lora_0_scale, atol=1e-3, rtol=1e-3), ( "Lora + 0 scale should lead to same result as no LoRA" ) @@ -466,7 +469,7 @@ def test_simple_inference_with_text_lora_fused(self, base_pipe_output): ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(ouput_fused, base_pipe_output, atol=0.001, rtol=0.001), + np.allclose(ouput_fused, base_pipe_output, atol=1e-3, rtol=1e-3), "Fused lora should change the output", ) @@ -484,18 +487,15 @@ def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output): pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe.unload_lora_weights() - assert not (check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder") + assert not check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - assert not ( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", - ) + assert not check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2" ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(ouput_unloaded, base_pipe_output, atol=0.001, rtol=0.001), ( - "Fused lora should change the output" + assert np.allclose(ouput_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3), ( + "Unloading lora should match the base pipe output" ) def test_simple_inference_with_text_lora_save_load(self, tmpdirname): @@ -523,7 +523,7 @@ def test_simple_inference_with_text_lora_save_load(self, tmpdirname): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + assert np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( "Loading from saved checkpoints should give same results." 
) @@ -565,12 +565,12 @@ def test_simple_inference_with_partial_text_lora(self, base_pipe_output): } ) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" + assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" pipe.unload_lora_weights() pipe.load_lora_weights(state_dict) output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_partial_lora, output_lora, atol=0.001, rtol=0.001), ( + assert not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), ( "Removing adapters should change the output" ) @@ -602,7 +602,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname): ) images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(images_lora, images_lora_save_pretrained, atol=0.001, rtol=0.001), ( + assert np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), ( "Loading from saved checkpoints should give same results." ) @@ -630,7 +630,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + assert np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( "Loading from saved checkpoints should give same results." ) @@ -649,17 +649,17 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_outp pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_lora, base_pipe_output, atol=0.001, rtol=0.001), "Lora should change the output" + assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert not np.allclose(output_lora, output_lora_scale, atol=0.001, rtol=0.001), ( + assert not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), ( "Lora + scale should change the output" ) attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert np.allclose(base_pipe_output, output_lora_0_scale, atol=0.001, rtol=0.001), ( + assert np.allclose(base_pipe_output, output_lora_0_scale, atol=1e-3, rtol=1e-3), ( "Lora + 0 scale should lead to same result as no LoRA" ) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -691,7 +691,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output): output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(output_fused, base_pipe_output, atol=0.001, rtol=0.001), + np.allclose(output_fused, base_pipe_output, atol=1e-3, rtol=1e-3), "Fused lora should change the output", ) @@ -720,12 +720,12 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_outpu ) output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_unloaded, base_pipe_output, 
atol=0.001, rtol=0.001), ( + assert np.allclose(output_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3), ( "Fused lora should change the output" ) def test_simple_inference_with_text_denoiser_lora_unfused( - self, expected_atol: float = 0.001, expected_rtol: float = 0.001 + self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 ): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights @@ -792,40 +792,40 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_outpu pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(base_pipe_output, output_adapter_1, atol=0.001, rtol=0.001), + np.allclose(base_pipe_output, output_adapter_1, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.", ) pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(base_pipe_output, output_adapter_2, atol=0.001, rtol=0.001), + np.allclose(base_pipe_output, output_adapter_2, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.", ) pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(base_pipe_output, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(base_pipe_output, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.", ) assert not ( - np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), + np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), "Adapter 1 and 2 should give different results", ) assert not ( - np.allclose(output_adapter_1, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter 1 and mixed adapters should give different results", ) assert not ( - np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter 2 and mixed adapters should give different results", ) pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( + assert np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), ( "output with no lora and output with lora disabled should give same results" ) @@ -903,21 +903,21 @@ def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output) pipe.set_adapters("adapter-1", weights_2) output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(output_weights_1, output_weights_2, atol=0.001, rtol=0.001), + np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), "LoRA weights 1 and 2 should give different results", ) assert not ( - np.allclose(base_pipe_output, output_weights_1, atol=0.001, rtol=0.001), + np.allclose(base_pipe_output, output_weights_1, atol=1e-3, rtol=1e-3), "No adapter and LoRA weights 1 should give different results", ) assert not ( - np.allclose(base_pipe_output, output_weights_2, atol=0.001, rtol=0.001), + np.allclose(base_pipe_output, output_weights_2, atol=1e-3, rtol=1e-3), "No adapter and LoRA weights 2 should give different results", ) pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( + 
assert np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), ( "output with no lora and output with lora disabled should give same results" ) @@ -957,21 +957,21 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), + np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), "Adapter 1 and 2 should give different results", ) assert not ( - np.allclose(output_adapter_1, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter 1 and mixed adapters should give different results", ) assert not ( - np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter 2 and mixed adapters should give different results", ) pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( + assert np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), ( "output with no lora and output with lora disabled should give same results" ) with pytest.raises(ValueError): @@ -1094,27 +1094,27 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), + np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), "Adapter 1 and 2 should give different results", ) assert not ( - np.allclose(output_adapter_1, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter 1 and mixed adapters should give different results", ) assert not ( - np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter 2 and mixed adapters should give different results", ) pipe.delete_adapters("adapter-1") output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_deleted_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), ( + assert np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( "Adapter 1 and 2 should give different results" ) pipe.delete_adapters("adapter-2") output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(base_pipe_output, output_deleted_adapters, atol=0.001, rtol=0.001), ( + assert np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), ( "output with no lora and output with lora disabled should give same results" ) @@ -1130,7 +1130,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, pipe.set_adapters(["adapter-1", "adapter-2"]) pipe.delete_adapters(["adapter-1", "adapter-2"]) output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(base_pipe_output, output_deleted_adapters, atol=0.001, rtol=0.001), ( + assert np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), ( "output with no lora and output with lora disabled should give same results" ) @@ -1172,28 
+1172,28 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_p pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(output_adapter_1, output_adapter_2, atol=0.001, rtol=0.001), + np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), "Adapter 1 and 2 should give different results", ) assert not ( - np.allclose(output_adapter_1, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter 1 and mixed adapters should give different results", ) assert not ( - np.allclose(output_adapter_2, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter 2 and mixed adapters should give different results", ) pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6]) output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=0.001, rtol=0.001), + np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Weighted adapter and mixed adapter should give different results", ) pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(base_pipe_output, output_disabled, atol=0.001, rtol=0.001), ( + assert np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), ( "output with no lora and output with lora disabled should give same results" ) @@ -1341,7 +1341,7 @@ def test_get_list_adapters(self): assert pipe.get_list_adapters() == dicts_to_be_checked def test_simple_inference_with_text_lora_denoiser_fused_multi( - self, expected_atol: float = 0.001, expected_rtol: float = 0.001 + self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 ): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model @@ -1415,7 +1415,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi( ) def test_lora_scale_kwargs_match_fusion( - self, base_pipe_output, expected_atol: float = 0.001, expected_rtol: float = 0.001 + self, base_pipe_output, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 ): attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) for lora_scale in [1.0, 0.8]: @@ -1475,7 +1475,7 @@ def test_simple_inference_with_dora(self): pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] assert not ( - np.allclose(output_dora_lora, output_no_dora_lora, atol=0.001, rtol=0.001), + np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3), "DoRA lora should change the output", ) @@ -1619,16 +1619,16 @@ def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname) attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] assert not ( - np.allclose(base_pipe_output, output_lora_scale, atol=0.001, rtol=0.001), + np.allclose(base_pipe_output, output_lora_scale, atol=1e-3, rtol=1e-3), "Lora + scale should change the output", ) pipe.set_adapters("default", lora_scale) output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(base_pipe_output, output_lora_scale_wo_kwargs, atol=0.001, 
rtol=0.001), ( + assert not np.allclose(base_pipe_output, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), ( "Lora + scale should change the output" ) - assert np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=0.001, rtol=0.001), ( + assert np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), ( "Lora + scale should match the output of `set_adapters()`." ) @@ -1645,13 +1645,13 @@ def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname) assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( "Lora + scale should change the output" ) - assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( "Loading from saved checkpoints should give same results as attention_kwargs." ) - assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( "Loading from saved checkpoints should give same results as set_adapters()." ) @@ -1686,9 +1686,9 @@ def test_lora_B_bias(self, base_pipe_output): pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(base_pipe_output, lora_bias_false_output, atol=0.001, rtol=0.001) - assert not np.allclose(base_pipe_output, lora_bias_true_output, atol=0.001, rtol=0.001) - assert not np.allclose(lora_bias_false_output, lora_bias_true_output, atol=0.001, rtol=0.001) + assert not np.allclose(base_pipe_output, lora_bias_false_output, atol=1e-3, rtol=1e-3) + assert not np.allclose(base_pipe_output, lora_bias_true_output, atol=1e-3, rtol=1e-3) + assert not np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3) def test_correct_lora_configs_with_different_ranks(self, base_pipe_output): components, _, denoiser_lora_config = self.get_dummy_components() @@ -1725,8 +1725,8 @@ def test_correct_lora_configs_with_different_ranks(self, base_pipe_output): assert updated_rank_pattern == {module_name_to_rank_update: updated_rank} lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(base_pipe_output, lora_output_same_rank, atol=0.001, rtol=0.001) - assert not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=0.001, rtol=0.001) + assert not np.allclose(base_pipe_output, lora_output_same_rank, atol=1e-3, rtol=1e-3) + assert not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3) if self.unet_kwargs is not None: pipe.unet.delete_adapters("adapter-1") @@ -1745,8 +1745,8 @@ def test_correct_lora_configs_with_different_ranks(self, base_pipe_output): module_name_to_rank_update: updated_alpha } lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(base_pipe_output, lora_output_diff_alpha, atol=0.001, rtol=0.001) - assert not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=0.001, rtol=0.001) + assert not np.allclose(base_pipe_output, lora_output_diff_alpha, 
atol=1e-3, rtol=1e-3) + assert not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3) def test_layerwise_casting_inference_denoiser(self): from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS @@ -1934,7 +1934,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname) pipe.unload_lora_weights() pipe.load_lora_weights(tmpdirname) output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), "Lora outputs should match." + assert np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." def test_lora_unload_add_adapter(self): """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" @@ -1986,12 +1986,12 @@ def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname) pipe.delete_adapters(pipe.get_active_adapters()[0]) output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001) - assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001) + assert not np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3) + assert np.allclose(base_pipe_output, output_no_adapter, atol=1e-3, rtol=1e-3) pipe.load_lora_weights(tmpdirname) output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001) + assert np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3) def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook @@ -2045,7 +2045,7 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tm assert group_offload_hook_3 is not None output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001) + assert np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3) @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) @require_torch_accelerator @@ -2080,4 +2080,4 @@ def test_lora_loading_model_cpu_offload(self, tmpdirname): assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        assert np.allclose(output_lora, output_lora_loaded, atol=0.001, rtol=0.001)
+        assert np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3)

From cba82591e80bf4fadc814876d7640f2774bc49c2 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 3 Oct 2025 15:56:37 +0530
Subject: [PATCH 07/17] up

---
 tests/lora/test_lora_layers_auraflow.py     |  32 ++--
 tests/lora/test_lora_layers_cogvideox.py    |  43 ++++--
 tests/lora/test_lora_layers_cogview4.py     |  40 +++--
 tests/lora/test_lora_layers_flux.py         | 149 ++++----------
 tests/lora/test_lora_layers_hunyuanvideo.py |  28 ++--
 tests/lora/test_lora_layers_ltx_video.py    |  26 ++--
 tests/lora/test_lora_layers_lumina2.py      |  25 +--
 tests/lora/test_lora_layers_mochi.py        |  29 ++--
 tests/lora/test_lora_layers_qwenimage.py    |  26 ++--
 tests/lora/test_lora_layers_sana.py         |  26 ++--
 tests/lora/test_lora_layers_sd3.py          |  13 +-
 tests/lora/test_lora_layers_wan.py          |  26 ++--
 tests/lora/test_lora_layers_wanvace.py      |  26 ++--
 tests/lora/utils.py                         | 161 +++++++++-----------
 tests/testing_utils.py                      |  27 ++--
 15 files changed, 342 insertions(+), 335 deletions(-)

diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py
index 55d69b5bfa4f..650301fa4574 100644
--- a/tests/lora/test_lora_layers_auraflow.py
+++ b/tests/lora/test_lora_layers_auraflow.py
@@ -13,16 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
-import unittest
 
+import pytest
 import torch
 from transformers import AutoTokenizer, UMT5EncoderModel
 
-from diffusers import (
-    AuraFlowPipeline,
-    AuraFlowTransformer2DModel,
-    FlowMatchEulerDiscreteScheduler,
-)
+from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler
 
 from ..testing_utils import (
     floats_tensor,
@@ -103,34 +99,42 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    @unittest.skip("Not supported in AuraFlow.")
+    @pytest.mark.skip("Not supported in AuraFlow.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in AuraFlow.")
+    @pytest.mark.skip("Not supported in AuraFlow.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in AuraFlow.")
+    @pytest.mark.skip("Not supported in AuraFlow.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_cogvideox.py 
b/tests/lora/test_lora_layers_cogvideox.py
index 4d407ad420ca..27dc81f7635a 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -13,10 +13,9 @@
 # limitations under the License.
 
 import sys
-import unittest
 
+import pytest
 import torch
-from parameterized import parameterized
 from transformers import AutoTokenizer, T5EncoderModel
 
 from diffusers import (
@@ -128,45 +127,61 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
     def test_lora_scale_kwargs_match_fusion(self):
         super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
 
-    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @pytest.mark.parametrize(
+        "offload_type, use_stream",
+        [
+            ("block_level", True),
+            ("leaf_level", False),
+            ("leaf_level", True),
+        ],
+    )
     @require_torch_accelerator
-    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname):
         # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
         # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
-        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname)
+
+    @pytest.mark.skip("Not supported in CogVideoX.")
 
-    @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in CogVideoX.")
+    @pytest.mark.skip("Not supported in CogVideoX.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in CogVideoX.")
+    @pytest.mark.skip("Not supported in CogVideoX.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
-    @unittest.skip("Not supported in CogVideoX.")
+    @pytest.mark.skip("Not supported in CogVideoX.")
+
     def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py
index 3a39c44a37a6..799e321987f5 100644
--- a/tests/lora/test_lora_layers_cogview4.py
+++ b/tests/lora/test_lora_layers_cogview4.py
@@ -13,11 +13,10 @@
 # limitations under the License.
 
 import sys
-import unittest
 
 import numpy as np
+import pytest
 import torch
-from parameterized import parameterized
 from transformers import AutoTokenizer, GlmModel
 
 from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
@@ -142,41 +141,56 @@ def test_simple_inference_save_pretrained(self, tmpdirname):
             "Loading from saved checkpoints should give same results.",
         )
 
-    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @pytest.mark.parametrize(
+        "offload_type, use_stream",
+        [
+            ("block_level", True),
+            ("leaf_level", False),
+            ("leaf_level", True),
+        ],
+    )
     @require_torch_accelerator
-    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname):
         # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
         # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
-        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname)
+
+    @pytest.mark.skip("Not supported in CogView4.")
 
-    @unittest.skip("Not supported in CogView4.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in CogView4.")
+    @pytest.mark.skip("Not supported in CogView4.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in CogView4.")
+    @pytest.mark.skip("Not supported in CogView4.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 7c230308ae45..8b9e6ec472b0 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -22,12 +22,11 @@
 import pytest
 import safetensors.torch
 import torch
-from parameterized import parameterized
 from PIL import Image
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
 from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils import load_image, logging
+from diffusers.utils import logging
 
 from ..testing_utils import (
     CaptureLogger,
@@ -169,9 +168,8 @@ def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not
correctly set in transformer" images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - assert not ( - np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), - "LoRA should lead to different results.", + assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), ( + "LoRA should lead to different results." ) denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) @@ -188,13 +186,11 @@ def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images - assert not ( - np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), - "Different LoRAs should lead to different results.", + assert not np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), ( + "Different LoRAs should lead to different results." ) - assert not ( - np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), - "LoRA should lead to different results.", + assert not np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), ( + "LoRA should lead to different results." ) def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname): @@ -210,10 +206,10 @@ def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname) assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - assert not ( - np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), - "LoRA should lead to different results.", + assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), ( + "LoRA should lead to different results." ) + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) @@ -228,28 +224,30 @@ def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname) assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images - assert not ( - np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), - "Different LoRAs should lead to different results.", + assert not np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), ( + "Different LoRAs should lead to different results." ) - assert not ( - np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), - "LoRA should lead to different results.", + assert not np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), ( + "LoRA should lead to different results." 
)
 
-    @unittest.skip("Not supported in Flux.")
+    @pytest.mark.skip("Not supported in Flux.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Flux.")
+    @pytest.mark.skip("Not supported in Flux.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in Flux.")
+    @pytest.mark.skip("Not supported in Flux.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Not supported in Flux.")
+    @pytest.mark.skip("Not supported in Flux.")
+
     def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pass
 
@@ -355,9 +353,8 @@ def test_with_norm_in_state_dict(self):
         lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert pipe.transformer._transformer_norm_layers is None
         assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05)
-        assert not (
-            np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06),
-            f"{norm_layer} is tested",
+        assert not np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), (
+            f"{norm_layer} is tested"
         )
 
         with CaptureLogger(logger) as cap_logger:
@@ -729,19 +726,23 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
         assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
         assert pipe.transformer.config.in_channels == in_features * 2
 
-    @unittest.skip("Not supported in Flux.")
+    @pytest.mark.skip("Not supported in Flux.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Flux.")
+    @pytest.mark.skip("Not supported in Flux.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in Flux.")
+    @pytest.mark.skip("Not supported in Flux.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Not supported in Flux.")
+    @pytest.mark.skip("Not supported in Flux.")
+
     def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pass
 
@@ -872,89 +873,3 @@ def test_flux_xlabs_load_lora_with_single_blocks(self):
         )
         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
         assert max_diff < 0.001
-
-
-@nightly
-@require_torch_accelerator
-@require_peft_backend
-@require_big_accelerator
-class FluxControlLoRAIntegrationTests(unittest.TestCase):
-    num_inference_steps = 10
-    seed = 0
-    prompt = "A robot made of exotic candies and chocolates of different kinds."
- - def setUp(self): - super().setUp() - gc.collect() - backend_empty_cache(torch_device) - self.pipeline = FluxControlPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 - ).to(torch_device) - - def tearDown(self): - super().tearDown() - gc.collect() - backend_empty_cache(torch_device) - - @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) - def test_lora(self, lora_ckpt_id): - self.pipeline.load_lora_weights(lora_ckpt_id) - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - if "Canny" in lora_ckpt_id: - control_image = load_image( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" - ) - else: - control_image = load_image( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" - ) - image = self.pipeline( - prompt=self.prompt, - control_image=control_image, - height=1024, - width=1024, - num_inference_steps=self.num_inference_steps, - guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, - output_type="np", - generator=torch.manual_seed(self.seed), - ).images - out_slice = image[0, -3:, -3:, -1].flatten() - if "Canny" in lora_ckpt_id: - expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516]) - else: - expected_slice = np.array([0.8203, 0.832, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359]) - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - assert max_diff < 0.001 - - @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) - def test_lora_with_turbo(self, lora_ckpt_id): - self.pipeline.load_lora_weights(lora_ckpt_id) - self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - if "Canny" in lora_ckpt_id: - control_image = load_image( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" - ) - else: - control_image = load_image( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" - ) - image = self.pipeline( - prompt=self.prompt, - control_image=control_image, - height=1024, - width=1024, - num_inference_steps=self.num_inference_steps, - guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, - output_type="np", - generator=torch.manual_seed(self.seed), - ).images - out_slice = image[0, -3:, -3:, -1].flatten() - if "Canny" in lora_ckpt_id: - expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484]) - else: - expected_slice = np.array([0.668, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562]) - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - assert max_diff < 0.001 diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index 0f31eaf57aa7..ed06f4343e14 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -17,6 +17,7 @@ import unittest import numpy as np +import pytest import torch from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast @@ -156,39 +157,48 @@ def 
test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
 
     # TODO(aryan): Fix the following test
-    @unittest.skip("This test fails with an error I haven't been able to debug yet.")
+    @pytest.mark.skip("This test fails with an error I haven't been able to debug yet.")
+
     def test_simple_inference_save_pretrained(self):
         pass
 
-    @unittest.skip("Not supported in HunyuanVideo.")
+    @pytest.mark.skip("Not supported in HunyuanVideo.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in HunyuanVideo.")
+    @pytest.mark.skip("Not supported in HunyuanVideo.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in HunyuanVideo.")
+    @pytest.mark.skip("Not supported in HunyuanVideo.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py
index b72479de5736..db5ade6f673c 100644
--- a/tests/lora/test_lora_layers_ltx_video.py
+++ b/tests/lora/test_lora_layers_ltx_video.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 import sys
-import unittest
 
+import pytest
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 
@@ -114,34 +114,42 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
     def test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
 
-    @unittest.skip("Not supported in LTXVideo.")
+    @pytest.mark.skip("Not supported in LTXVideo.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in LTXVideo.")
+    @pytest.mark.skip("Not supported in LTXVideo.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in LTXVideo.")
+    @pytest.mark.skip("Not supported in LTXVideo.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py
index a4ddd5457d3c..9d681b619f0f 100644
--- a/tests/lora/test_lora_layers_lumina2.py
+++ b/tests/lora/test_lora_layers_lumina2.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import sys
-import unittest
 
 import numpy as np
 import pytest
@@ -101,35 +100,43 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    @unittest.skip("Not supported in Lumina2.")
+    @pytest.mark.skip("Not supported in Lumina2.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Lumina2.")
+    @pytest.mark.skip("Not supported in Lumina2.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in Lumina2.")
+    @pytest.mark.skip("Not supported in Lumina2.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py
index a34a615b257a..eddf59a696d5 100644
--- a/tests/lora/test_lora_layers_mochi.py
+++ b/tests/lora/test_lora_layers_mochi.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 import sys
-import unittest
 
+import pytest
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 
@@ -105,38 +105,47 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
     def test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
 
-    @unittest.skip("Not supported in Mochi.")
+    @pytest.mark.skip("Not supported in Mochi.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Mochi.")
+    @pytest.mark.skip("Not supported in Mochi.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in Mochi.")
+    @pytest.mark.skip("Not supported in Mochi.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
-    @unittest.skip("Not supported in CogVideoX.")
+    @pytest.mark.skip("Not supported in CogVideoX.")
+
     def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py
index 167373211e90..470c2212c2d7 100644
--- a/tests/lora/test_lora_layers_qwenimage.py
+++ b/tests/lora/test_lora_layers_qwenimage.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
-import unittest
 
+import pytest
 import torch
 from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
 
@@ -96,34 +96,42 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    @unittest.skip("Not supported in Qwen Image.")
+    @pytest.mark.skip("Not supported in Qwen Image.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Qwen Image.")
+    @pytest.mark.skip("Not supported in Qwen Image.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in Qwen Image.")
+    @pytest.mark.skip("Not supported in Qwen Image.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py
index 2323d66e39e2..0f2a3cbe9e05 100644
--- a/tests/lora/test_lora_layers_sana.py
+++ b/tests/lora/test_lora_layers_sana.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
-import unittest
 
+import pytest
 import torch
 from transformers import Gemma2Model, GemmaTokenizer
 
@@ -105,34 +105,42 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    @unittest.skip("Not supported in SANA.")
+    @pytest.mark.skip("Not supported in SANA.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Not supported in SANA.")
+    @pytest.mark.skip("Not supported in SANA.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in SANA.")
+    @pytest.mark.skip("Not supported in SANA.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 02602ddf6fc2..7bea30445dcd 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -17,6 +17,7 @@
 import unittest
 
 import numpy as np
+import pytest
 import torch
 from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
 
@@ -113,19 +114,23 @@ def test_sd3_lora(self):
         lora_filename = "lora_peft_format.safetensors"
         pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
 
-    @unittest.skip("Not supported in SD3.")
+    @pytest.mark.skip("Not supported in SD3.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in SD3.")
+    @pytest.mark.skip("Not supported in SD3.")
+
     def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pass
 
-    @unittest.skip("Not supported in SD3.")
+    @pytest.mark.skip("Not supported in SD3.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in SD3.")
+    @pytest.mark.skip("Not supported in SD3.")
+
     def test_modify_padding_mode(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
index 7066578dc749..18c671aa2f83 100644
--- a/tests/lora/test_lora_layers_wan.py
+++ b/tests/lora/test_lora_layers_wan.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 import sys
-import unittest
 
+import pytest
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 
@@ -110,34 +110,42 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
     def test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
 
-    @unittest.skip("Not supported in Wan.")
+    @pytest.mark.skip("Not supported in Wan.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Wan.")
+    @pytest.mark.skip("Not supported in Wan.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in Wan.")
+    @pytest.mark.skip("Not supported in Wan.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
+
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
+
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py
index 9c319b995277..1c9493068832 100644
--- a/tests/lora/test_lora_layers_wanvace.py
+++ b/tests/lora/test_lora_layers_wanvace.py
@@ -14,9 +14,9 @@
 import os
 import sys
-import unittest
 
 import numpy as np
+import pytest
 import safetensors.torch
 import torch
 from PIL import Image
@@ -126,35 +126,43 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
     def test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
 
-    @unittest.skip("Not supported in Wan VACE.")
+    @pytest.mark.skip("Not supported in Wan VACE.")
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Wan VACE.")
+    @pytest.mark.skip("Not supported in Wan VACE.")
+
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in Wan VACE.")
+    @pytest.mark.skip("Not supported in Wan VACE.")
+
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
+
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
+
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
+
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ 
pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") + def test_simple_inference_with_text_lora_fused(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") + pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") + def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index db314d692306..40183cd9a0c9 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -20,7 +20,6 @@ import numpy as np import pytest import torch -from parameterized import parameterized from diffusers import AutoencoderKL, UNet2DConditionModel from diffusers.utils import logging @@ -243,19 +242,19 @@ def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_co if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - + if denoiser_lora_config is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." else: denoiser = None - + if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - + return pipe, denoiser def _compute_baseline_output(self): @@ -468,9 +467,8 @@ def test_simple_inference_with_text_lora_fused(self, base_pipe_output): assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(ouput_fused, base_pipe_output, atol=1e-3, rtol=1e-3), - "Fused lora should change the output", + assert not np.allclose(ouput_fused, base_pipe_output, atol=1e-3, rtol=1e-3), ( + "Fused lora should change the output" ) def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output): @@ -491,7 +489,9 @@ def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output): if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - assert not check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2" + assert not check_if_lora_correctly_set(pipe.text_encoder_2), ( + "Lora not correctly unloaded in text encoder 2" + ) ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(ouput_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3), ( @@ -690,9 +690,8 @@ def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output): assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(output_fused, base_pipe_output, atol=1e-3, rtol=1e-3), - "Fused lora should change the output", + assert not np.allclose(output_fused, base_pipe_output, atol=1e-3, rtol=1e-3), ( + "Fused lora should change the output" ) def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output): @@ -714,9 +713,8 @@ def 
test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_outpu if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - assert not ( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", + assert not check_if_lora_correctly_set(pipe.text_encoder_2), ( + "Lora not correctly unloaded in text encoder 2" ) output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -791,36 +789,30 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_outpu pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(base_pipe_output, output_adapter_1, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", + assert not np.allclose(base_pipe_output, output_adapter_1, atol=1e-3, rtol=1e-3), ( + "Adapter outputs should be different." ) pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(base_pipe_output, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", + assert not np.allclose(base_pipe_output, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter outputs should be different." ) pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(base_pipe_output, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", + assert not np.allclose(base_pipe_output, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter outputs should be different." ) - assert not ( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + assert not np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and 2 should give different results" ) - assert not ( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and mixed adapters should give different results" ) - assert not ( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 2 and mixed adapters should give different results" ) pipe.disable_lora() @@ -902,17 +894,15 @@ def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output) weights_2 = {"unet": {"up": 5}} pipe.set_adapters("adapter-1", weights_2) output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), - "LoRA weights 1 and 2 should give different results", + + assert not np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), ( + "LoRA weights 1 and 2 should give different results" ) - assert not ( - np.allclose(base_pipe_output, output_weights_1, atol=1e-3, rtol=1e-3), - "No adapter and LoRA weights 1 should give different results", + assert not np.allclose(base_pipe_output, output_weights_1, atol=1e-3, rtol=1e-3), ( + "No adapter and LoRA weights 1 should give different results" ) - assert not ( - np.allclose(base_pipe_output, output_weights_2, atol=1e-3, 
rtol=1e-3), - "No adapter and LoRA weights 2 should give different results", + assert not np.allclose(base_pipe_output, output_weights_2, atol=1e-3, rtol=1e-3), ( + "No adapter and LoRA weights 2 should give different results" ) pipe.disable_lora() @@ -952,21 +942,21 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base scales_2 = {"unet": {"down": 5, "mid": 5}} pipe.set_adapters("adapter-1", scales_1) output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.set_adapters("adapter-2", scales_2) output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + + assert not np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and 2 should give different results" ) - assert not ( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and mixed adapters should give different results" ) - assert not ( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 2 and mixed adapters should give different results" ) pipe.disable_lora() @@ -1093,17 +1083,14 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + assert not np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and 2 should give different results" ) - assert not ( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and mixed adapters should give different results" ) - assert not ( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 2 and mixed adapters should give different results" ) pipe.delete_adapters("adapter-1") @@ -1171,24 +1158,20 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_p pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + assert not np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and 2 should give different results" ) - assert not ( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", + assert not np.allclose(output_adapter_1, 
output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and mixed adapters should give different results" ) - assert not ( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 2 and mixed adapters should give different results" ) pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6]) output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Weighted adapter and mixed adapter should give different results", + assert not np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Weighted adapter and mixed adapter should give different results" ) pipe.disable_lora() @@ -1456,9 +1439,8 @@ def test_lora_scale_kwargs_match_fusion( assert np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), ( "Fused lora should not change the output" ) - assert not ( - np.allclose(base_pipe_output, outputs_lora_1, atol=expected_atol, rtol=expected_rtol), - "LoRA should change the output", + assert not np.allclose(base_pipe_output, outputs_lora_1, atol=expected_atol, rtol=expected_rtol), ( + "LoRA should change the output" ) def test_simple_inference_with_dora(self): @@ -1474,9 +1456,8 @@ def test_simple_inference_with_dora(self): pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not ( - np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3), - "DoRA lora should change the output", + assert not np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3), ( + "DoRA lora should change the output" ) def test_missing_keys_warning(self, tmpdirname): @@ -1504,7 +1485,7 @@ def test_missing_keys_warning(self, tmpdirname): pipe.load_lora_weights(state_dict) component = list({k.split(".")[0] for k in state_dict})[0] - assert missing_key.replace(f"{component}.", "" in cap_logger.out.replace("default_0.", "")) + assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "") def test_unexpected_keys_warning(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components() @@ -1618,9 +1599,8 @@ def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname) lora_scale = 0.5 attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert not ( - np.allclose(base_pipe_output, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", + assert not np.allclose(base_pipe_output, output_lora_scale, atol=1e-3, rtol=1e-3), ( + "Lora + scale should change the output" ) pipe.set_adapters("default", lora_scale) @@ -1763,9 +1743,9 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): if "lora" in name or any((re.search(pattern, name) for pattern in patterns_to_check)): dtype_to_check = compute_dtype if getattr(submodule, "weight", None) is not None: - self.assertEqual(submodule.weight.dtype, dtype_to_check) + assert submodule.weight.dtype == dtype_to_check if getattr(submodule, "bias", None) is not None: - self.assertEqual(submodule.bias.dtype, dtype_to_check) + assert submodule.bias.dtype == 
dtype_to_check def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): components, text_lora_config, denoiser_lora_config = self.get_dummy_components() @@ -1871,7 +1851,7 @@ def check_module(denoiser): _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe(**inputs, generator=torch.manual_seed(0))[0] - @parameterized.expand([4, 8, 16]) + @pytest.mark.parametrize("lora_alpha", [4, 8, 16]) def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname): components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components) @@ -1914,7 +1894,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname) parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key ) - @parameterized.expand([4, 8, 16]) + @pytest.mark.parametrize("lora_alpha", [4, 8, 16]) def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname): components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components).to(torch_device) @@ -2047,7 +2027,14 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tm output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3) - @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) + @pytest.mark.parametrize( + "offload_type, use_stream", + [ + ("block_level", True), + ("leaf_level", False), + ("leaf_level", True), + ], + ) @require_torch_accelerator def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): for cls in inspect.getmro(self.__class__): @@ -2055,7 +2042,7 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmp return self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname) - @require_torch_accelerator + @pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch") def test_lora_loading_model_cpu_offload(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 7f849219c16f..73d60459156c 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -24,6 +24,7 @@ import numpy as np import PIL.Image import PIL.ImageOps +import pytest import requests from numpy.linalg import norm from packaging import version @@ -275,7 +276,7 @@ def nightly(test_case): Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them. 
""" - return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case) + return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case) def is_torch_compile(test_case): @@ -350,9 +351,9 @@ def decorator(test_case): # These decorators are for accelerator-specific behaviours that are not GPU-specific def require_torch_accelerator(test_case): """Decorator marking a test that requires an accelerator backend and PyTorch.""" - return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")( - test_case - ) + return pytest.mark.skipif( + not (is_torch_available() and torch_device != "cpu"), reason="test requires accelerator+PyTorch" + )(test_case) def require_torch_multi_gpu(test_case): @@ -441,9 +442,9 @@ def require_big_accelerator(test_case): device_properties = torch.cuda.get_device_properties(0) total_memory = device_properties.total_memory / (1024**3) - return unittest.skipUnless( - total_memory >= BIG_GPU_MEMORY, - f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory", + return pytest.mark.skipif( + not total_memory >= BIG_GPU_MEMORY, + reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory", )(test_case) @@ -509,7 +510,7 @@ def require_peft_backend(test_case): Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and transformers. """ - return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case) + return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case) def require_timm(test_case): @@ -550,8 +551,8 @@ def decorator(test_case): correct_peft_version = is_peft_available() and version.parse( version.parse(importlib.metadata.version("peft")).base_version ) > version.parse(peft_version) - return unittest.skipUnless( - correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}" + return pytest.mark.skipif( + not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}" )(test_case) return decorator @@ -567,9 +568,9 @@ def decorator(test_case): correct_transformers_version = is_transformers_available() and version.parse( version.parse(importlib.metadata.version("transformers")).base_version ) > version.parse(transformers_version) - return unittest.skipUnless( - correct_transformers_version, - f"test requires transformers with the version greater than {transformers_version}", + return pytest.mark.skipif( + not correct_transformers_version, + reason=f"test requires transformers with the version greater than {transformers_version}", )(test_case) return decorator From 610842af1ad67ca98338a98ca641583339fe7db3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 16:14:36 +0530 Subject: [PATCH 08/17] up --- tests/lora/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 40183cd9a0c9..7d33415d7312 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1802,9 +1802,9 @@ def check_module(denoiser): if any((re.search(pattern, name) for pattern in patterns_to_check)): dtype_to_check = compute_dtype if getattr(module, "weight", None) is not None: - self.assertEqual(module.weight.dtype, dtype_to_check) + assert module.weight.dtype == dtype_to_check if getattr(module, "bias", None) is not None: - self.assertEqual(module.bias.dtype, dtype_to_check) + assert module.bias.dtype == 
dtype_to_check if isinstance(module, BaseTunerLayer): assert getattr(module, "_diffusers_hook", None is not None) assert module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None From 565d674cc42dcf4b51b48886e85b615e93ef9b4b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 16:30:58 +0530 Subject: [PATCH 09/17] change flux lora integration tests to use pytest --- tests/lora/test_lora_layers_flux.py | 203 +++++++++++++++++++++------- tests/testing_utils.py | 2 +- 2 files changed, 154 insertions(+), 51 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 8b9e6ec472b0..b7518d701abc 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -16,7 +16,6 @@ import gc import os import sys -import unittest import numpy as np import pytest @@ -26,7 +25,7 @@ from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel -from diffusers.utils import logging +from diffusers.utils import load_image, logging from ..testing_utils import ( CaptureLogger, @@ -752,7 +751,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): @require_torch_accelerator @require_peft_backend @require_big_accelerator -class FluxLoRAIntegrationTests(unittest.TestCase): +class TestFluxLoRAIntegration: """internal note: The integration slices were obtained on audace. torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the @@ -762,25 +761,25 @@ class FluxLoRAIntegrationTests(unittest.TestCase): num_inference_steps = 10 seed = 0 - def setUp(self): - super().setUp() + @pytest.fixture(scope="function") + def pipeline(self, torch_device): gc.collect() backend_empty_cache(torch_device) - self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) - - def tearDown(self): - super().tearDown() - del self.pipeline - gc.collect() - backend_empty_cache(torch_device) - - def test_flux_the_last_ben(self): - self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + try: + yield pipe + finally: + del pipe + gc.collect() + backend_empty_cache(torch_device) + + def test_flux_the_last_ben(self, pipeline): + pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "jon snow eating pizza with ketchup" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=4.0, @@ -792,13 +791,13 @@ def test_flux_the_last_ben(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 - def test_flux_kohya(self): - self.pipeline.load_lora_weights("Norod78/brain-slug-flux") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) + def test_flux_kohya(self, pipeline): + pipeline.load_lora_weights("Norod78/brain-slug-flux") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "The cat with a brain slug earring" - out = 
self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=4.5, @@ -810,13 +809,13 @@ def test_flux_kohya(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 - def test_flux_kohya_with_text_encoder(self): - self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) + def test_flux_kohya_with_text_encoder(self, pipeline): + pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "optimus is cleaning the house with broomstick" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=4.5, @@ -828,19 +827,18 @@ def test_flux_kohya_with_text_encoder(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 - def test_flux_kohya_embedders_conversion(self): + def test_flux_kohya_embedders_conversion(self, pipeline): """Test that embedders load without throwing errors""" - self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") - self.pipeline.unload_lora_weights() - assert True - - def test_flux_xlabs(self): - self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) + pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") + pipeline.unload_lora_weights() + + def test_flux_xlabs(self, pipeline): + pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "A blue jay standing on a large basket of rainbow macarons, disney style" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=3.5, @@ -852,15 +850,13 @@ def test_flux_xlabs(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 - def test_flux_xlabs_load_lora_with_single_blocks(self): - self.pipeline.load_lora_weights( - "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors" - ) - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline.enable_model_cpu_offload() + def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline): + pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline.enable_model_cpu_offload() prompt = "a wizard mouse playing chess" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=3.5, @@ -873,3 +869,110 @@ def test_flux_xlabs_load_lora_with_single_blocks(self): ) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 + + +@nightly +@require_torch_accelerator +@require_peft_backend +@require_big_accelerator +class TestFluxControlLoRAIntegration: + num_inference_steps = 10 + seed = 0 + prompt = "A robot made of exotic candies and chocolates of different kinds." 
+ + @pytest.fixture(scope="function") + def pipeline(self, torch_device): + gc.collect() + backend_empty_cache(torch_device) + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + try: + yield pipe + finally: + del pipe + gc.collect() + backend_empty_cache(torch_device) + + @pytest.mark.parametrize( + "lora_ckpt_id", + [ + "black-forest-labs/FLUX.1-Canny-dev-lora", + "black-forest-labs/FLUX.1-Depth-dev-lora", + ], + ) + def test_lora(self, pipeline, lora_ckpt_id): + pipeline.load_lora_weights(lora_ckpt_id) + pipeline.fuse_lora() + pipeline.unload_lora_weights() + + if "Canny" in lora_ckpt_id: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" + ) + else: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" + ) + + image = pipeline( + prompt=self.prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=self.num_inference_steps, + guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = image[0, -3:, -3:, -1].flatten() + if "Canny" in lora_ckpt_id: + expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516]) + else: + expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 + + @pytest.mark.parametrize( + "lora_ckpt_id", + [ + "black-forest-labs/FLUX.1-Canny-dev-lora", + "black-forest-labs/FLUX.1-Depth-dev-lora", + ], + ) + def test_lora_with_turbo(self, pipeline, lora_ckpt_id): + pipeline.load_lora_weights(lora_ckpt_id) + pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + + if "Canny" in lora_ckpt_id: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" + ) + else: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" + ) + + image = self.pipeline( + prompt=self.prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=self.num_inference_steps, + guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = image[0, -3:, -3:, -1].flatten() + if "Canny" in lora_ckpt_id: + expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484]) + else: + expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 73d60459156c..988834acf546 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -266,7 +266,7 @@ def slow(test_case): Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. 
""" - return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case) def nightly(test_case): From 1737b710a23fbf46df599a4a2c98c3dbfec44799 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 16:45:04 +0530 Subject: [PATCH 10/17] up --- tests/lora/test_lora_layers_cogview4.py | 5 +- tests/lora/test_lora_layers_flux.py | 2 +- tests/lora/test_lora_layers_hunyuanvideo.py | 41 ++++---- tests/lora/test_lora_layers_lumina2.py | 6 +- tests/lora/test_lora_layers_sd.py | 100 +++++++------------- tests/lora/test_lora_layers_sd3.py | 11 +-- tests/lora/test_lora_layers_sdxl.py | 32 ++----- 7 files changed, 72 insertions(+), 125 deletions(-) diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 799e321987f5..fa614c56cdd0 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -136,9 +136,8 @@ def test_simple_inference_save_pretrained(self, tmpdirname): images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), ( + "Loading from saved checkpoints should give same results." ) @pytest.mark.parametrize( diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b7518d701abc..8db06a801c67 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -119,7 +119,7 @@ def test_with_alpha_in_state_dict(self, tmpdirname): _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.transformer.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index ed06f4343e14..3439fef15c28 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -14,7 +14,6 @@ import gc import sys -import unittest import numpy as np import pytest @@ -38,7 +37,6 @@ require_peft_backend, require_torch_accelerator, skip_mps, - torch_device, ) @@ -207,7 +205,7 @@ def test_simple_inference_with_text_lora_save_load(self): @require_torch_accelerator @require_peft_backend @require_big_accelerator -class HunyuanVideoLoRAIntegrationTests(unittest.TestCase): +class TestHunyuanVideoLoRAIntegration: """internal note: The integration slices were obtained on DGX. torch: 2.5.1+cu124 with CUDA 12.5. 
Need the same setup for the @@ -217,9 +215,8 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase): num_inference_steps = 10 seed = 0 - def setUp(self): - super().setUp() - + @pytest.fixture(scope="function") + def pipeline(self, torch_device): gc.collect() backend_empty_cache(torch_device) @@ -227,27 +224,27 @@ def setUp(self): transformer = HunyuanVideoTransformer3DModel.from_pretrained( model_id, subfolder="transformer", torch_dtype=torch.bfloat16 ) - self.pipeline = HunyuanVideoPipeline.from_pretrained( - model_id, transformer=transformer, torch_dtype=torch.float16 - ).to(torch_device) - - def tearDown(self): - super().tearDown() - - gc.collect() - backend_empty_cache(torch_device) - - def test_original_format_cseti(self): - self.pipeline.load_lora_weights( + pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to( + torch_device + ) + try: + yield pipe + finally: + del pipe + gc.collect() + backend_empty_cache(torch_device) + + def test_original_format_cseti(self, pipeline): + pipeline.load_lora_weights( "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors" ) - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline.vae.enable_tiling() + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline.vae.enable_tiling() prompt = "CSETIARCANE. A cat walks on the grass, realistic" - out = self.pipeline( + out = pipeline( prompt=prompt, height=320, width=512, diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index 9d681b619f0f..6ce70d53a07f 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -155,11 +155,11 @@ def test_lora_fuse_nan(self): if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." # corrupt one LoRA weight with `inf` values with torch.no_grad(): @@ -173,4 +173,4 @@ def test_lora_fuse_nan(self): pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) out = pipe(**inputs)[0] - self.assertTrue(np.isnan(out).all()) + assert np.isnan(out).all() diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index 76ac775a9f1c..a5e640c0b736 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -14,9 +14,9 @@ # limitations under the License. import gc import sys -import unittest import numpy as np +import pytest import torch import torch.nn as nn from huggingface_hub import hf_hub_download @@ -91,16 +91,6 @@ class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests): def output_shape(self): return (1, 64, 64, 3) - def setUp(self): - super().setUp() - gc.collect() - backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() - gc.collect() - backend_empty_cache(torch_device) - # Keeping this test here makes sense because it doesn't look any integration # (value assertions on logits). 
@slow @@ -114,15 +104,8 @@ def test_integration_move_lora_cpu(self): pipe.load_lora_weights(lora_id, adapter_name="adapter-2") pipe = pipe.to(torch_device) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder", - ) - - self.assertTrue( - check_if_lora_correctly_set(pipe.unet), - "Lora not correctly set in unet", - ) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet" # We will offload the first adapter in CPU and check if the offloading # has been performed correctly @@ -130,35 +113,35 @@ def test_integration_move_lora_cpu(self): for name, module in pipe.unet.named_modules(): if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device == torch.device("cpu")) + assert module.weight.device == torch.device("cpu") elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device != torch.device("cpu")) + assert module.weight.device != torch.device("cpu") for name, module in pipe.text_encoder.named_modules(): if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device == torch.device("cpu")) + assert module.weight.device == torch.device("cpu") elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device != torch.device("cpu")) + assert module.weight.device != torch.device("cpu") pipe.set_lora_device(["adapter-1"], 0) for n, m in pipe.unet.named_modules(): if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)): - self.assertTrue(m.weight.device != torch.device("cpu")) + assert m.weight.device != torch.device("cpu") for n, m in pipe.text_encoder.named_modules(): if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)): - self.assertTrue(m.weight.device != torch.device("cpu")) + assert m.weight.device != torch.device("cpu") pipe.set_lora_device(["adapter-1", "adapter-2"], torch_device) for n, m in pipe.unet.named_modules(): if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): - self.assertTrue(m.weight.device != torch.device("cpu")) + assert m.weight.device != torch.device("cpu") for n, m in pipe.text_encoder.named_modules(): if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): - self.assertTrue(m.weight.device != torch.device("cpu")) + assert m.weight.device != torch.device("cpu") @slow @require_torch_accelerator @@ -181,15 +164,9 @@ def test_integration_move_lora_dora_cpu(self): pipe.unet.add_adapter(unet_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder", - ) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - self.assertTrue( - check_if_lora_correctly_set(pipe.unet), - "Lora not correctly set in unet", - ) + assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet" for name, param in pipe.unet.named_parameters(): if "lora_" in name: @@ -225,17 +202,14 @@ def test_integration_set_lora_device_different_target_layers(self): pipe.unet.add_adapter(config1, adapter_name="adapter-1") pipe = pipe.to(torch_device) - self.assertTrue( - check_if_lora_correctly_set(pipe.unet), - "Lora 
not correctly set in unet", - ) + assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet" # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")} modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")} - self.assertNotEqual(modules_adapter_0, modules_adapter_1) - self.assertTrue(modules_adapter_0 - modules_adapter_1) - self.assertTrue(modules_adapter_1 - modules_adapter_0) + assert modules_adapter_0 != modules_adapter_1 + assert modules_adapter_0 - modules_adapter_1 + assert modules_adapter_1 - modules_adapter_0 # setting both separately works pipe.set_lora_device(["adapter-0"], "cpu") @@ -243,32 +217,30 @@ def test_integration_set_lora_device_different_target_layers(self): for name, module in pipe.unet.named_modules(): if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device == torch.device("cpu")) + assert module.weight.device == torch.device("cpu") elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device == torch.device("cpu")) + assert module.weight.device == torch.device("cpu") # setting both at once also works pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device) for name, module in pipe.unet.named_modules(): if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device != torch.device("cpu")) + assert module.weight.device != torch.device("cpu") elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device != torch.device("cpu")) + assert module.weight.device != torch.device("cpu") @slow @nightly @require_torch_accelerator @require_peft_backend -class LoraIntegrationTests(unittest.TestCase): - def setUp(self): - super().setUp() +class TestSDLoraIntegration: + @pytest.fixture(autouse=True) + def _gc_and_cache_cleanup(self, torch_device): gc.collect() backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() + yield gc.collect() backend_empty_cache(torch_device) @@ -280,10 +252,7 @@ def test_integration_logits_with_scale(self): pipe.load_lora_weights(lora_id) pipe = pipe.to(torch_device) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder", - ) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" prompt = "a red sks dog" @@ -312,10 +281,7 @@ def test_integration_logits_no_scale(self): pipe.load_lora_weights(lora_id) pipe = pipe.to(torch_device) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder", - ) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" prompt = "a red sks dog" @@ -587,8 +553,8 @@ def test_unload_kohya_lora(self): ).images unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() - self.assertFalse(np.allclose(initial_images, lora_images)) - self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + assert not np.allclose(initial_images, lora_images) + assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3) release_memory(pipe) @@ -625,8 +591,8 @@ def test_load_unload_load_kohya_lora(self): ).images unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, 
-1].flatten() - self.assertFalse(np.allclose(initial_images, lora_images)) - self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + assert not np.allclose(initial_images, lora_images) + assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3) # make sure we can load a LoRA again after unloading and they don't have # any undesired effects. @@ -637,7 +603,7 @@ def test_load_unload_load_kohya_lora(self): ).images lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten() - self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3)) + assert np.allclose(lora_images, lora_images_again, atol=1e-3) release_memory(pipe) def test_not_empty_state_dict(self): @@ -651,7 +617,7 @@ def test_not_empty_state_dict(self): lcm_lora = load_file(cached_file) pipe.load_lora_weights(lcm_lora, adapter_name="lcm") - self.assertTrue(lcm_lora != {}) + assert lcm_lora != {} release_memory(pipe) def test_load_unload_load_state_dict(self): diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 7bea30445dcd..d60d09841120 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -14,7 +14,6 @@ # limitations under the License. import gc import sys -import unittest import numpy as np import pytest @@ -143,17 +142,15 @@ def test_multiple_wrong_adapter_name_raises_error(self): @require_torch_accelerator @require_peft_backend @require_big_accelerator -class SD3LoraIntegrationTests(unittest.TestCase): +class TestSD3LoraIntegration: pipeline_class = StableDiffusion3Img2ImgPipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def _gc_and_cache_cleanup(self, torch_device): gc.collect() backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() + yield gc.collect() backend_empty_cache(torch_device) diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 405e97cd1b1f..6ee81dac3244 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -17,9 +17,9 @@ import importlib import sys import time -import unittest import numpy as np +import pytest import torch from packaging import version from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer @@ -104,16 +104,6 @@ class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests): def output_shape(self): return (1, 64, 64, 3) - def setUp(self): - super().setUp() - gc.collect() - backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() - gc.collect() - backend_empty_cache(torch_device) - @is_flaky def test_multiple_wrong_adapter_name_raises_error(self): super().test_multiple_wrong_adapter_name_raises_error() @@ -157,14 +147,12 @@ def test_lora_scale_kwargs_match_fusion(self): @nightly @require_torch_accelerator @require_peft_backend -class LoraSDXLIntegrationTests(unittest.TestCase): - def setUp(self): - super().setUp() +class TestLoraSDXLIntegration: + @pytest.fixture(autouse=True) + def _gc_and_cache_cleanup(self, torch_device): gc.collect() backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() + yield gc.collect() backend_empty_cache(torch_device) @@ -383,7 +371,7 @@ def test_sdxl_1_0_lora_fusion_efficiency(self): end_time = time.time() elapsed_time_fusion = end_time - start_time - self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion) + assert elapsed_time_fusion < elapsed_time_non_fusion release_memory(pipe) @@ 
-439,14 +427,14 @@ def remap_key(key, sd): for key, value in text_encoder_1_sd.items(): key = remap_key(key, fused_te_state_dict) - self.assertTrue(torch.allclose(fused_te_state_dict[key], value)) + assert torch.allclose(fused_te_state_dict[key], value) for key, value in text_encoder_2_sd.items(): key = remap_key(key, fused_te_2_state_dict) - self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value)) + assert torch.allclose(fused_te_2_state_dict[key], value) for key, value in unet_state_dict.items(): - self.assertTrue(torch.allclose(unet_state_dict[key], value)) + assert torch.allclose(unet_state_dict[key], value) pipe.fuse_lora() pipe.unload_lora_weights() @@ -589,7 +577,7 @@ def test_integration_logits_multi_adapter(self): pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy") pipe = pipe.to(torch_device) - self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet" prompt = "toy_face of a hacker with a hoodie" From c4bcf72084f4400a98b666e5bcd1a0103ab8bb34 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 16:56:31 +0530 Subject: [PATCH 11/17] up --- tests/lora/test_lora_layers_auraflow.py | 24 +++++---------- tests/lora/test_lora_layers_cogvideox.py | 33 +++++++-------------- tests/lora/test_lora_layers_cogview4.py | 30 ++++++------------- tests/lora/test_lora_layers_flux.py | 24 +++++---------- tests/lora/test_lora_layers_hunyuanvideo.py | 27 ++++++----------- tests/lora/test_lora_layers_ltx_video.py | 24 +++++---------- tests/lora/test_lora_layers_lumina2.py | 24 +++++---------- tests/lora/test_lora_layers_mochi.py | 27 ++++++----------- tests/lora/test_lora_layers_qwenimage.py | 24 +++++---------- tests/lora/test_lora_layers_sana.py | 24 +++++---------- tests/lora/test_lora_layers_sd3.py | 12 +++----- tests/lora/test_lora_layers_wan.py | 24 +++++---------- tests/lora/test_lora_layers_wanvace.py | 27 +++++------------ tests/lora/utils.py | 6 +--- 14 files changed, 106 insertions(+), 224 deletions(-) diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 650301fa4574..e3bbbfb632d9 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -99,42 +99,34 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - pytest.mark.skip("Not supported in AuraFlow.") - + @pytest.mark.skip("Not supported in AuraFlow.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in AuraFlow.") - + @pytest.mark.skip("Not supported in AuraFlow.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in AuraFlow.") - + @pytest.mark.skip("Not supported in AuraFlow.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") - + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") - + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") - + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_text_lora_and_scale(self): pass 
- pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") - + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") - + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index 27dc81f7635a..d2557a5337ae 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -129,11 +129,7 @@ def test_lora_scale_kwargs_match_fusion(self): @pytest.mark.parametrize( "offload_type, use_stream", - [ - ("block_level", True), - ("leaf_level", False), - ("leaf_level", True), - ], + [("block_level", True), ("leaf_level", False), ("leaf_level", True)], ) @require_torch_accelerator def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): @@ -141,47 +137,38 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmp # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname) - pytest.mark.skip("Not supported in CogVideoX.") - + @pytest.mark.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in CogVideoX.") - + @pytest.mark.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in CogVideoX.") - + @pytest.mark.skip("Not supported in CogVideoX.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_text_lora_save_load(self): pass - pytest.mark.skip("Not supported in CogVideoX.") - + @pytest.mark.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index fa614c56cdd0..363da9f26515 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -142,11 +142,7 @@ def test_simple_inference_save_pretrained(self, tmpdirname): @pytest.mark.parametrize( "offload_type, use_stream", - [ - ("block_level", True), - ("leaf_level", False), - ("leaf_level", True), - ], + [("block_level", True), ("leaf_level", 
False), ("leaf_level", True)], ) @require_torch_accelerator def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): @@ -154,42 +150,34 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmp # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname) - pytest.mark.skip("Not supported in CogView4.") - + @pytest.mark.skip("Not supported in CogView4.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in CogView4.") - + @pytest.mark.skip("Not supported in CogView4.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in CogView4.") - + @pytest.mark.skip("Not supported in CogView4.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") - + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 8db06a801c67..556ec00391a8 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -230,23 +230,19 @@ def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname) "LoRA should lead to different results." 
) - pytest.mark.skip("Not supported in Flux.") - + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in Flux.") - + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in Flux.") - + @pytest.mark.skip("Not supported in Flux.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Not supported in Flux.") - + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass @@ -725,23 +721,19 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 assert pipe.transformer.config.in_channels == in_features * 2 - pytest.mark.skip("Not supported in Flux.") - + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in Flux.") - + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in Flux.") - + @pytest.mark.skip("Not supported in Flux.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Not supported in Flux.") - + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index 3439fef15c28..52ee3cd9f752 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -155,48 +155,39 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) # TODO(aryan): Fix the following test - pytest.mark.skip("This test fails with an error I haven't been able to debug yet.") - + @pytest.mark.skip("This test fails with an error I haven't been able to debug yet.") def test_simple_inference_save_pretrained(self): pass - pytest.mark.skip("Not supported in HunyuanVideo.") - + @pytest.mark.skip("Not supported in HunyuanVideo.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in HunyuanVideo.") - + @pytest.mark.skip("Not supported in HunyuanVideo.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in HunyuanVideo.") - + @pytest.mark.skip("Not supported in HunyuanVideo.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") def 
test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index db5ade6f673c..37bad941bf3b 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -114,42 +114,34 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - pytest.mark.skip("Not supported in LTXVideo.") - + @pytest.mark.skip("Not supported in LTXVideo.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in LTXVideo.") - + @pytest.mark.skip("Not supported in LTXVideo.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in LTXVideo.") - + @pytest.mark.skip("Not supported in LTXVideo.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") - + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index 6ce70d53a07f..c0ee9c34e42a 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -100,43 +100,35 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - pytest.mark.skip("Not supported in Lumina2.") - + @pytest.mark.skip("Not supported in Lumina2.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in Lumina2.") - + @pytest.mark.skip("Not supported in Lumina2.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in Lumina2.") - + @pytest.mark.skip("Not supported in Lumina2.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") - 
+ @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index eddf59a696d5..f9da672732c7 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -105,47 +105,38 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - pytest.mark.skip("Not supported in Mochi.") - + @pytest.mark.skip("Not supported in Mochi.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in Mochi.") - + @pytest.mark.skip("Not supported in Mochi.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in Mochi.") - + @pytest.mark.skip("Not supported in Mochi.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") def test_simple_inference_with_text_lora_save_load(self): pass - pytest.mark.skip("Not supported in CogVideoX.") - + @pytest.mark.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py index 470c2212c2d7..c24464653072 100644 --- a/tests/lora/test_lora_layers_qwenimage.py +++ b/tests/lora/test_lora_layers_qwenimage.py @@ -96,42 +96,34 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - pytest.mark.skip("Not supported in Qwen Image.") - + @pytest.mark.skip("Not supported in Qwen Image.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in Qwen Image.") - + @pytest.mark.skip("Not supported in Qwen Image.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in Qwen Image.") - + @pytest.mark.skip("Not supported in Qwen Image.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported 
in Qwen Image.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index 0f2a3cbe9e05..5977aeb9a53c 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -105,42 +105,34 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - pytest.mark.skip("Not supported in SANA.") - + @pytest.mark.skip("Not supported in SANA.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Not supported in SANA.") - + @pytest.mark.skip("Not supported in SANA.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in SANA.") - + @pytest.mark.skip("Not supported in SANA.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in SANA.") - + @pytest.mark.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in SANA.") - + @pytest.mark.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in SANA.") - + @pytest.mark.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in SANA.") - + @pytest.mark.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in SANA.") - + @pytest.mark.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index d60d09841120..a44f6887f41a 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -113,23 +113,19 @@ def test_sd3_lora(self): lora_filename = "lora_peft_format.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pytest.mark.skip("Not supported in SD3.") - + @pytest.mark.skip("Not supported in SD3.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in SD3.") - + @pytest.mark.skip("Not supported in SD3.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass - pytest.mark.skip("Not supported in SD3.") - + @pytest.mark.skip("Not supported in SD3.") def 
test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in SD3.") - + @pytest.mark.skip("Not supported in SD3.") def test_modify_padding_mode(self): pass diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 18c671aa2f83..3393521471e7 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -110,42 +110,34 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - pytest.mark.skip("Not supported in Wan.") - + @pytest.mark.skip("Not supported in Wan.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in Wan.") - + @pytest.mark.skip("Not supported in Wan.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in Wan.") - + @pytest.mark.skip("Not supported in Wan.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index 1c9493068832..95a1638d21d8 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -126,49 +126,38 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - pytest.mark.skip("Not supported in Wan VACE.") - + @pytest.mark.skip("Not supported in Wan VACE.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - pytest.mark.skip("Not supported in Wan VACE.") - + @pytest.mark.skip("Not supported in Wan VACE.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - pytest.mark.skip("Not supported in Wan VACE.") - + @pytest.mark.skip("Not supported in Wan VACE.") def test_modify_padding_mode(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_partial_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_text_lora(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") - + 
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_text_lora_and_scale(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_text_lora_fused(self): pass - pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") - + @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_text_lora_save_load(self): pass - def test_layerwise_casting_inference_denoiser(self): - super().test_layerwise_casting_inference_denoiser() - @require_peft_version_greater("0.13.2") def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname): exclude_module_name = "vace_blocks.0.proj_out" diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 7d33415d7312..8a91a976897e 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2029,11 +2029,7 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tm @pytest.mark.parametrize( "offload_type, use_stream", - [ - ("block_level", True), - ("leaf_level", False), - ("leaf_level", True), - ], + [("block_level", True), ("leaf_level", False), ("leaf_level", True)], ) @require_torch_accelerator def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): From dae161ed260955a906760dc3a8d71b8b04a3cc5b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 17:39:55 +0530 Subject: [PATCH 12/17] up --- tests/lora/test_lora_layers_cogvideox.py | 6 ++++-- tests/lora/test_lora_layers_flux.py | 6 +++--- tests/lora/test_lora_layers_hunyuanvideo.py | 3 ++- tests/lora/test_lora_layers_sdxl.py | 6 ++++-- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index d2557a5337ae..4ba9c9516da8 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -124,8 +124,10 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - def test_lora_scale_kwargs_match_fusion(self): - super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3) + def test_lora_scale_kwargs_match_fusion(self, base_pipe_output): + super().test_lora_scale_kwargs_match_fusion( + base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3 + ) @pytest.mark.parametrize( "offload_type, use_stream", diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 556ec00391a8..1589aa4082a3 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -754,7 +754,7 @@ class TestFluxLoRAIntegration: seed = 0 @pytest.fixture(scope="function") - def pipeline(self, torch_device): + def pipeline(self): gc.collect() backend_empty_cache(torch_device) pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) @@ -873,10 +873,10 @@ class TestFluxControlLoRAIntegration: prompt = "A robot made of exotic candies and chocolates of different kinds." 
@pytest.fixture(scope="function") - def pipeline(self, torch_device): + def pipeline(self): gc.collect() backend_empty_cache(torch_device) - pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) try: yield pipe finally: diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index 52ee3cd9f752..8ee7db8de71d 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -37,6 +37,7 @@ require_peft_backend, require_torch_accelerator, skip_mps, + torch_device, ) @@ -207,7 +208,7 @@ class TestHunyuanVideoLoRAIntegration: seed = 0 @pytest.fixture(scope="function") - def pipeline(self, torch_device): + def pipeline(self): gc.collect() backend_empty_cache(torch_device) diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 6ee81dac3244..7b53464c88e1 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -132,7 +132,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): expected_atol=expected_atol, expected_rtol=expected_rtol ) - def test_lora_scale_kwargs_match_fusion(self): + def test_lora_scale_kwargs_match_fusion(self, base_pipe_output): if torch.cuda.is_available(): expected_atol = 9e-2 expected_rtol = 9e-2 @@ -140,7 +140,9 @@ def test_lora_scale_kwargs_match_fusion(self): expected_atol = 1e-3 expected_rtol = 1e-3 - super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol) + super().test_lora_scale_kwargs_match_fusion( + base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol + ) @slow From bdc95379997cdd4cd5f720d0f73fca4b85ec46f7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 20:01:26 +0530 Subject: [PATCH 13/17] more fixtures. 
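
Adds a function-scoped `pipe` fixture to PeftLoraLoaderMixinTests and threads it
through the tests that previously constructed the pipeline inline.

For reviewers unfamiliar with the pattern, here is a minimal sketch in plain
pytest of the fixture approach this patch moves the suite toward. The names
below (_ToyPipeline, TestToyPipelineFixtures) are illustrative stand-ins only,
not diffusers classes, and the snippet assumes nothing beyond pytest itself:

    import pytest


    class _ToyPipeline:
        # Stand-in for a pipeline object; not a diffusers class.
        def __init__(self, scale=1.0):
            self.scale = scale

        def __call__(self, value):
            return value * self.scale


    class TestToyPipelineFixtures:
        @pytest.fixture(scope="function")
        def pipe(self):
            # One fresh pipeline per test; replaces the repeated
            # get_dummy_components() / .to(device) / set_progress_bar_config()
            # boilerplate that each test used to carry.
            return _ToyPipeline()

        @pytest.fixture(scope="function")
        def base_output(self, pipe):
            # Baseline output computed once and injected into tests
            # that compare against it.
            return pipe(2.0)

        def test_scaled_matches_baseline(self, pipe, base_output):
            assert pipe(2.0) == base_output

Function scope keeps the previous per-test isolation: every test still gets a
fresh pipeline, only the duplicated setup code goes away.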
--- tests/lora/test_lora_layers_cogview4.py | 25 -- tests/lora/test_lora_layers_hunyuanvideo.py | 5 - tests/lora/utils.py | 312 ++++++-------------- 3 files changed, 90 insertions(+), 252 deletions(-) diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 363da9f26515..5631d278d54b 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -14,7 +14,6 @@ import sys -import numpy as np import pytest import torch from transformers import AutoTokenizer, GlmModel @@ -26,7 +25,6 @@ require_peft_backend, require_torch_accelerator, skip_mps, - torch_device, ) @@ -117,29 +115,6 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - def test_simple_inference_save_pretrained(self, tmpdirname): - """ - Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained - """ - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.save_pretrained(tmpdirname) - - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) - - images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - - assert np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), ( - "Loading from saved checkpoints should give same results." - ) - @pytest.mark.parametrize( "offload_type, use_stream", [("block_level", True), ("leaf_level", False), ("leaf_level", True)], diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index 8ee7db8de71d..37d6a61e81be 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -155,11 +155,6 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - # TODO(aryan): Fix the following test - @pytest.mark.skip("This test fails with an error I haven't been able to debug yet.") - def test_simple_inference_save_pretrained(self): - pass - @pytest.mark.skip("Not supported in HunyuanVideo.") def test_simple_inference_with_text_denoiser_block_scale(self): pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 8a91a976897e..bfb242c74df4 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -132,6 +132,14 @@ def base_pipe_output(self): def tmpdirname(self, tmp_path_factory): return tmp_path_factory.mktemp("tmp") + @pytest.fixture(scope="function") + def pipe(self): + components, _, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + return pipe + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") @@ -314,16 +322,12 @@ def test_simple_inference(self, base_pipe_output): """ assert base_pipe_output.shape == self.output_shape - def 
test_simple_inference_with_text_lora(self, base_pipe_output): + def test_simple_inference_with_text_lora(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -331,12 +335,9 @@ def test_simple_inference_with_text_lora(self, base_pipe_output): assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" @require_peft_version_greater("0.13.1") - def test_low_cpu_mem_usage_with_injection(self): + def test_low_cpu_mem_usage_with_injection(self, pipe): """Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() if "text_encoder" in self.pipeline_class._lora_loadable_modules: inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) @@ -380,13 +381,9 @@ def test_low_cpu_mem_usage_with_injection(self): @require_peft_version_greater("0.13.1") @require_transformers_version_greater("4.45.2") - def test_low_cpu_mem_usage_with_loading(self, tmpdirname): + def test_low_cpu_mem_usage_with_loading(self, tmpdirname, pipe): """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -417,17 +414,13 @@ def test_low_cpu_mem_usage_with_loading(self, tmpdirname): "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results." 
) - def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output): + def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -446,16 +439,12 @@ def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output): "Lora + 0 scale should lead to same result as no LoRA" ) - def test_simple_inference_with_text_lora_fused(self, base_pipe_output): + def test_simple_inference_with_text_lora_fused(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -471,16 +460,12 @@ def test_simple_inference_with_text_lora_fused(self, base_pipe_output): "Fused lora should change the output" ) - def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output): + def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -498,14 +483,11 @@ def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output): "Unloading lora should match the base pipe output" ) - def test_simple_inference_with_text_lora_save_load(self, tmpdirname): + def test_simple_inference_with_text_lora_save_load(self, tmpdirname, pipe): """ Tests a simple usecase where users could use saving utilities for LoRA. """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -527,13 +509,12 @@ def test_simple_inference_with_text_lora_save_load(self, tmpdirname): "Loading from saved checkpoints should give same results." 
) - def test_simple_inference_with_partial_text_lora(self, base_pipe_output): + def test_simple_inference_with_partial_text_lora(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached on the text encoder with different ranks and some adapters removed and makes sure it works as expected """ - components, _, _ = self.get_dummy_components() text_lora_config = LoraConfig( r=4, rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)}, @@ -542,9 +523,6 @@ def test_simple_inference_with_partial_text_lora(self, base_pipe_output): init_lora_weights=False, use_dora=False, ) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -574,15 +552,11 @@ def test_simple_inference_with_partial_text_lora(self, base_pipe_output): "Removing adapters should change the output" ) - def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname): + def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname, pipe): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -606,15 +580,11 @@ def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname): "Loading from saved checkpoints should give same results." ) - def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname): + def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname, pipe): """ Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -634,17 +604,13 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname): "Loading from saved checkpoints should give same results." 
) - def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_output): + def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached on the text encoder + Unet + scale argument and makes sure it works as expected """ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -667,16 +633,12 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_outp "The scaling parameter has not been correctly restored!" ) - def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output): + def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -694,16 +656,12 @@ def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output): "Fused lora should change the output" ) - def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output): + def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -723,17 +681,13 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_outpu ) def test_simple_inference_with_text_denoiser_lora_unfused( - self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 + self, pipe, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 ): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -840,15 +794,10 @@ def 
test_wrong_adapter_name_raises_error(self): pipe.set_adapters(adapter_name) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_multiple_wrong_adapter_name_raises_error(self): + def test_multiple_wrong_adapter_name_raises_error(self, pipe): adapter_name = "adapter-1" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) @@ -864,16 +813,12 @@ def test_multiple_wrong_adapter_name_raises_error(self): pipe.set_adapters(adapter_name) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output): + def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, attaches one adapter and set different weights for different blocks (i.e. block lora) """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -911,17 +856,14 @@ def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output) "output with no lora and output with lora disabled should give same results" ) - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base_pipe_output): + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set different weights for different blocks (i.e. 
block lora) """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") @@ -967,7 +909,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base with pytest.raises(ValueError): pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1]) - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self, pipe): """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" def updown_options(blocks_with_tf, layers_per_block, value): @@ -1024,10 +966,7 @@ def all_possible_dict_opts(unet, value): opts.append(opt) return opts - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1045,16 +984,12 @@ def all_possible_dict_opts(unet, value): del scale_dict["text_encoder_2"] pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error - def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output): + def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set/delete them """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1121,16 +1056,12 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, "output with no lora and output with lora disabled should give same results" ) - def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_pipe_output): + def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set them """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1186,12 +1117,8 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_p reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", strict=False, ) - 
def test_lora_fuse_nan(self): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_lora_fuse_nan(self, pipe): + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1235,17 +1162,12 @@ def test_lora_fuse_nan(self): out = pipe(**inputs)[0] assert np.isnan(out).all() - def test_get_adapters(self): + def test_get_adapters(self, pipe): """ Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet @@ -1263,15 +1185,12 @@ def test_get_adapters(self): pipe.set_adapters(["adapter-1", "adapter-2"]) assert sorted(pipe.get_active_adapters()) == ["adapter-1", "adapter-2"] - def test_get_list_adapters(self): + def test_get_list_adapters(self, pipe): """ Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() # 1. 
dicts_to_be_checked = {} @@ -1324,16 +1243,16 @@ def test_get_list_adapters(self): assert pipe.get_list_adapters() == dicts_to_be_checked def test_simple_inference_with_text_lora_denoiser_fused_multi( - self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 + self, + pipe, + expected_atol: float = 1e-3, + expected_rtol: float = 1e-3, ): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet and multi-adapter case """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1443,12 +1362,8 @@ def test_lora_scale_kwargs_match_fusion( "LoRA should change the output" ) - def test_simple_inference_with_dora(self): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_simple_inference_with_dora(self, pipe): + _, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True) _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1460,12 +1375,8 @@ def test_simple_inference_with_dora(self): "DoRA lora should change the output" ) - def test_missing_keys_warning(self, tmpdirname): - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_missing_keys_warning(self, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
@@ -1487,11 +1398,8 @@ def test_missing_keys_warning(self, tmpdirname): component = list({k.split(".")[0] for k in state_dict})[0] assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "") - def test_unexpected_keys_warning(self, tmpdirname): - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_unexpected_keys_warning(self, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) @@ -1513,16 +1421,12 @@ def test_unexpected_keys_warning(self, tmpdirname): assert ".diffusers_cat" in cap_logger.out @pytest.mark.skip("This is failing for now - need to investigate") - def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): + def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self, pipe): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -1533,29 +1437,19 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_modify_padding_mode(self): + def test_modify_padding_mode(self, pipe): def set_pad_mode(network, mode="circular"): for _, module in network.named_modules(): if isinstance(module, torch.nn.Conv2d): module.padding_mode = mode - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _pad_mode = "circular" set_pad_mode(pipe.vae, _pad_mode) set_pad_mode(pipe.unet, _pad_mode) _, _, inputs = self.get_dummy_inputs() _ = pipe(**inputs)[0] - def test_logs_info_when_no_lora_keys_found(self, base_pipe_output): - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_logs_info_when_no_lora_keys_found(self, base_pipe_output, pipe): _, _, inputs = self.get_dummy_inputs(with_generator=False) no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)} @@ -1584,16 +1478,11 @@ def test_logs_info_when_no_lora_keys_found(self, base_pipe_output): ) assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}") - def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname): + def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname, pipe): """Test to check if outputs after `set_adapters()` and attention kwargs match.""" attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = 
pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) lora_scale = 0.5 @@ -1636,12 +1525,8 @@ def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname) ) @require_peft_version_greater("0.13.2") - def test_lora_B_bias(self, base_pipe_output): - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_lora_B_bias(self, base_pipe_output, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() bias_values = {} denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer for name, module in denoiser.named_modules(): @@ -1670,12 +1555,8 @@ def test_lora_B_bias(self, base_pipe_output): assert not np.allclose(base_pipe_output, lora_bias_true_output, atol=1e-3, rtol=1e-3) assert not np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3) - def test_correct_lora_configs_with_different_ranks(self, base_pipe_output): - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_correct_lora_configs_with_different_ranks(self, base_pipe_output, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) if self.unet_kwargs is not None: @@ -1852,9 +1733,8 @@ def check_module(denoiser): pipe(**inputs, generator=torch.manual_seed(0))[0] @pytest.mark.parametrize("lora_alpha", [4, 8, 16]) - def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) - pipe = self.pipeline_class(**components) + def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname, pipe): + _, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) @@ -1895,10 +1775,8 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname) ) @pytest.mark.parametrize("lora_alpha", [4, 8, 16]) - def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) - pipe = self.pipeline_class(**components).to(torch_device) - + def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname, pipe): + _, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline( @@ -1916,11 +1794,9 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname) output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." 
- def test_lora_unload_add_adapter(self): + def test_lora_unload_add_adapter(self, pipe): """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components).to(torch_device) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline( @@ -1934,13 +1810,9 @@ def test_lora_unload_add_adapter(self): ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname): + def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname, pipe): """Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1973,15 +1845,12 @@ def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname) output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3) - def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): + def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe): from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook onload_device = torch_device offload_device = torch.device("cpu") - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, _, denoiser_lora_config = self.get_dummy_components() denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) @@ -1995,6 +1864,7 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tm components, _, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) check_if_lora_correctly_set(denoiser) @@ -2032,19 +1902,16 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tm [("block_level", True), ("leaf_level", False), ("leaf_level", True)], ) @require_torch_accelerator - def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): + def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe): for cls in inspect.getmro(self.__class__): if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: return - self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname) + self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe) @pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch") - def test_lora_loading_model_cpu_offload(self, tmpdirname): - components, _, 
denoiser_lora_config = self.get_dummy_components() + def test_lora_loading_model_cpu_offload(self, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) @@ -2055,6 +1922,7 @@ def test_lora_loading_model_cpu_offload(self, tmpdirname): modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.enable_model_cpu_offload(device=torch_device) From 128535cfcdf66d6233db9b127531e505dc8d7318 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 20:03:50 +0530 Subject: [PATCH 14/17] up --- tests/lora/test_lora_layers_flux.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 1589aa4082a3..c8a38f5b75c9 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -757,7 +757,9 @@ class TestFluxLoRAIntegration: def pipeline(self): gc.collect() backend_empty_cache(torch_device) - pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to( + torch_device + ) try: yield pipe finally: @@ -876,7 +878,9 @@ class TestFluxControlLoRAIntegration: def pipeline(self): gc.collect() backend_empty_cache(torch_device) - pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to( + torch_device + ) try: yield pipe finally: From f8f27891c65fc67df447168260a9769bce934ffd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 20:14:45 +0530 Subject: [PATCH 15/17] up --- tests/lora/test_lora_layers_cogvideox.py | 6 ++-- tests/lora/test_lora_layers_cogview4.py | 6 ++-- tests/lora/test_lora_layers_flux.py | 40 ++++++------------------ tests/lora/test_lora_layers_lumina2.py | 7 ++--- tests/lora/test_lora_layers_wanvace.py | 6 ++-- 5 files changed, 20 insertions(+), 45 deletions(-) diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index 4ba9c9516da8..ed3d0f0d8de5 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -131,13 +131,13 @@ def test_lora_scale_kwargs_match_fusion(self, base_pipe_output): @pytest.mark.parametrize( "offload_type, use_stream", - [("block_level", True), ("leaf_level", False), ("leaf_level", True)], + [("block_level", True), ("leaf_level", False)], ) @require_torch_accelerator - def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): + def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe): # TODO: We don't run the (leaf_level, True) test here that is enabled for other models. 
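For context on the group-offloading cases parametrized above (offload_type in {"block_level", "leaf_level"}, with and without a stream): they load a LoRA into the denoiser and then run inference with the denoiser's weights kept on CPU and onloaded group by group. A hedged sketch of that setup follows; the enable_group_offload keywords are recalled from the diffusers documentation rather than taken from this diff, so treat the exact signature as an assumption.

import torch

from diffusers.utils.testing_utils import torch_device


def run_with_group_offload(pipe, inputs, offload_type="leaf_level", use_stream=False):
    # Keep parameters on the offload device and bring them onto the accelerator
    # group by group during the forward pass.
    pipe.transformer.enable_group_offload(
        onload_device=torch.device(torch_device),
        offload_device=torch.device("cpu"),
        offload_type=offload_type,  # "block_level" or "leaf_level"
        use_stream=use_stream,      # overlap transfers with compute where supported
    )
    return pipe(**inputs, generator=torch.manual_seed(0))[0]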
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 - super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname) + super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe) @pytest.mark.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_block_scale(self): diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 5631d278d54b..8b7daf9c01a6 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -117,13 +117,13 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): @pytest.mark.parametrize( "offload_type, use_stream", - [("block_level", True), ("leaf_level", False), ("leaf_level", True)], + [("block_level", True), ("leaf_level", False)], ) @require_torch_accelerator - def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): + def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe): # TODO: We don't run the (leaf_level, True) test here that is enabled for other models. # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 - super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname) + super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe) @pytest.mark.skip("Not supported in CogView4.") def test_simple_inference_with_text_denoiser_block_scale(self): diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index c8a38f5b75c9..3defa9ea9678 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -111,11 +111,8 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_with_alpha_in_state_dict(self, tmpdirname): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_with_alpha_in_state_dict(self, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.transformer.add_adapter(denoiser_lora_config) @@ -152,11 +149,8 @@ def test_with_alpha_in_state_dict(self, tmpdirname): ) assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001) - def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) _, _, inputs = self.get_dummy_inputs(with_generator=False) # Modify the config to have a layer which won't be present in the second LoRA we will load. @@ -192,11 +186,8 @@ def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname "LoRA should lead to different results." 
) - def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) _, _, inputs = self.get_dummy_inputs(with_generator=False) modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) @@ -312,12 +303,7 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_with_norm_in_state_dict(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_with_norm_in_state_dict(self, pipe): _, _, inputs = self.get_dummy_inputs(with_generator=False) logger = logging.get_logger("diffusers.loaders.lora_pipeline") @@ -346,6 +332,7 @@ def test_with_norm_in_state_dict(self): pipe.unload_lora_weights() lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert pipe.transformer._transformer_norm_layers is None assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05) assert not np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), ( @@ -358,11 +345,8 @@ def test_with_norm_in_state_dict(self): pipe.load_lora_weights(norm_state_dict) assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out - def test_lora_parameter_expanded_shapes(self): + def test_lora_parameter_expanded_shapes(self, pipe): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -573,14 +557,10 @@ def test_fuse_expanded_lora_with_regular_lora(self): lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001) - def test_load_regular_lora(self, base_pipe_output): + def test_load_regular_lora(self, base_pipe_output, pipe): # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those # transformers include Flux Fill, Flux Control, etc. 
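The test introduced here depends on the loader expanding a smaller LoRA to fit a transformer whose input projection has more channels (the Flux Fill / Flux Control case, checked via `pipe.transformer.x_embedder.weight.shape`). As an illustration only, and assuming the expansion amounts to zero-padding the LoRA's input projection (the real mechanism lives in the loader and is not shown in this diff), the shape bookkeeping looks like this:

import torch

rank = 4
lora_A_small = torch.randn(rank, 64)   # trained against x_embedder in_features=64 (Flux.1 Dev)
expanded = torch.zeros(rank, 128)      # target transformer expects in_features=128 (e.g. Flux Control)
expanded[:, :64] = lora_A_small        # the extra input channels contribute nothing
assert expanded.shape == (rank, 128)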
- components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) out_features, in_features = pipe.transformer.x_embedder.weight.shape diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index c0ee9c34e42a..b0f6ab6039f0 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -138,11 +138,8 @@ def test_simple_inference_with_text_lora_save_load(self): reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", strict=False, ) - def test_lora_fuse_nan(self): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_lora_fuse_nan(self, pipe): + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index 95a1638d21d8..ecc5c16365c3 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -31,7 +31,6 @@ require_peft_backend, require_peft_version_greater, skip_mps, - torch_device, ) @@ -159,10 +158,9 @@ def test_simple_inference_with_text_lora_save_load(self): pass @require_peft_version_greater("0.13.2") - def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname): + def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname, pipe): exclude_module_name = "vace_blocks.0.proj_out" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components).to(torch_device) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) assert base_pipe_output.shape == self.output_shape From 4f5e9a665e924d64df04dee15bd02ca99e3fbe73 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 20:49:50 +0530 Subject: [PATCH 16/17] up --- tests/lora/test_lora_layers_cogvideox.py | 12 ++++++------ tests/lora/test_lora_layers_cogview4.py | 8 ++++---- tests/lora/test_lora_layers_hunyuanvideo.py | 8 ++++---- tests/lora/test_lora_layers_ltx_video.py | 8 ++++---- tests/lora/test_lora_layers_mochi.py | 8 ++++---- tests/lora/test_lora_layers_sdxl.py | 12 ++++++------ tests/lora/test_lora_layers_wan.py | 8 ++++---- tests/lora/test_lora_layers_wanvace.py | 8 ++++---- 8 files changed, 36 insertions(+), 36 deletions(-) diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index ed3d0f0d8de5..f5d85648345a 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -118,15 +118,15 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3, pipe=pipe) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - 
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) - def test_lora_scale_kwargs_match_fusion(self, base_pipe_output): + def test_lora_scale_kwargs_match_fusion(self, pipe, base_pipe_output): super().test_lora_scale_kwargs_match_fusion( - base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3 + pipe=pipe, base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3 ) @pytest.mark.parametrize( diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 8b7daf9c01a6..d3902730678d 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -109,11 +109,11 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) @pytest.mark.parametrize( "offload_type, use_stream", diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index 37d6a61e81be..1d3e3dbf6a38 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -149,11 +149,11 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) @pytest.mark.skip("Not supported in HunyuanVideo.") def test_simple_inference_with_text_denoiser_block_scale(self): diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 37bad941bf3b..2ffc39ef2b41 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -108,11 +108,11 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - 
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) @pytest.mark.skip("Not supported in LTXVideo.") def test_simple_inference_with_text_denoiser_block_scale(self): diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index f9da672732c7..9b81e220b28f 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -99,11 +99,11 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) @pytest.mark.skip("Not supported in Mochi.") def test_simple_inference_with_text_denoiser_block_scale(self): diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 7b53464c88e1..fd10d4eeda4e 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -108,7 +108,7 @@ def output_shape(self): def test_multiple_wrong_adapter_name_raises_error(self): super().test_multiple_wrong_adapter_name_raises_error() - def test_simple_inference_with_text_denoiser_lora_unfused(self): + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): if torch.cuda.is_available(): expected_atol = 9e-2 expected_rtol = 9e-2 @@ -117,10 +117,10 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): expected_rtol = 1e-3 super().test_simple_inference_with_text_denoiser_lora_unfused( - expected_atol=expected_atol, expected_rtol=expected_rtol + pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol ) - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): if torch.cuda.is_available(): expected_atol = 9e-2 expected_rtol = 9e-2 @@ -129,10 +129,10 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): expected_rtol = 1e-3 super().test_simple_inference_with_text_lora_denoiser_fused_multi( - expected_atol=expected_atol, expected_rtol=expected_rtol + pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol ) - def test_lora_scale_kwargs_match_fusion(self, base_pipe_output): + def test_lora_scale_kwargs_match_fusion(self, base_pipe_output, pipe): if torch.cuda.is_available(): expected_atol = 9e-2 expected_rtol = 9e-2 @@ -141,7 +141,7 @@ def test_lora_scale_kwargs_match_fusion(self, base_pipe_output): expected_rtol = 1e-3 super().test_lora_scale_kwargs_match_fusion( - base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol + pipe=pipe, base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol ) diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 3393521471e7..2dfe91d6d578 100644 --- 
a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -104,11 +104,11 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) @pytest.mark.skip("Not supported in Wan.") def test_simple_inference_with_text_denoiser_block_scale(self): diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index ecc5c16365c3..48017120ed83 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -119,11 +119,11 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) @pytest.mark.skip("Not supported in Wan VACE.") def test_simple_inference_with_text_denoiser_block_scale(self): From 0d3da485a0ef247566113a2e368bbe6a9d550401 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 21:00:05 +0530 Subject: [PATCH 17/17] up --- tests/lora/test_lora_layers_cogvideox.py | 4 ++-- tests/lora/test_lora_layers_sdxl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index f5d85648345a..ad2943b6816a 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -124,9 +124,9 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) - def test_lora_scale_kwargs_match_fusion(self, pipe, base_pipe_output): + def test_lora_scale_kwargs_match_fusion(self, base_pipe_output): super().test_lora_scale_kwargs_match_fusion( - pipe=pipe, base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3 + base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3 ) @pytest.mark.parametrize( diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index fd10d4eeda4e..e1bc6e8ecb73 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -132,7 +132,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): pipe=pipe, 
expected_atol=expected_atol, expected_rtol=expected_rtol
         )
 
-    def test_lora_scale_kwargs_match_fusion(self, base_pipe_output, pipe):
+    def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
         if torch.cuda.is_available():
             expected_atol = 9e-2
             expected_rtol = 9e-2
@@ -141,7 +141,7 @@ def test_lora_scale_kwargs_match_fusion(self, base_pipe_output, pipe):
             expected_rtol = 1e-3
 
         super().test_lora_scale_kwargs_match_fusion(
-            pipe=pipe, base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol
+            base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol
         )
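As a closing illustration of what the tolerance-tuned override above asserts: applying a LoRA scale at call time should produce the same output as fusing the LoRA into the weights at that scale. The snippet below is a rough sketch of that property using SDXL-style names (`pipe` and `inputs` stand in for the fixture and dummy inputs these tests already use); it is not code from the diff.

import numpy as np
import torch

lora_scale = 0.5

# Scale applied on the fly through the attention kwargs.
out_scaled = pipe(
    **inputs, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
)[0]

# The same scale baked into the weights via fusion.
pipe.fuse_lora(lora_scale=lora_scale)
out_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

# CUDA runs use the looser 9e-2 tolerances chosen above; CPU keeps 1e-3.
assert np.allclose(out_scaled, out_fused, atol=9e-2, rtol=9e-2)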