Skip to content

Commit dee9074

Browse files
Fix AuraFlow LoRA tests by applying to the right denoiser layers.
Co-authored-by: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com>
1 parent 5091757 commit dee9074

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tests/lora/test_lora_layers_auraflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
7474
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
7575
text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5"
7676
text_encoder_target_modules = ["q", "k", "v", "o"]
77+
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
7778

7879
@property
7980
def output_shape(self):

tests/lora/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class PeftLoraLoaderMixinTests:
104104
vae_kwargs = None
105105

106106
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
107+
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
107108

108109
def get_dummy_components(self, scheduler_cls=None, use_dora=False):
109110
if self.unet_kwargs and self.transformer_kwargs:
@@ -157,7 +158,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False):
157158
denoiser_lora_config = LoraConfig(
158159
r=rank,
159160
lora_alpha=rank,
160-
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
161+
target_modules=self.denoiser_target_modules,
161162
init_lora_weights=False,
162163
use_dora=use_dora,
163164
)
@@ -2040,7 +2041,7 @@ def test_lora_B_bias(self):
20402041
bias_values = {}
20412042
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
20422043
for name, module in denoiser.named_modules():
2043-
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
2044+
if any(k in name for k in self.denoiser_target_modules):
20442045
if module.bias is not None:
20452046
bias_values[name] = module.bias.data.clone()
20462047

0 commit comments

Comments
 (0)