80 changes: 53 additions & 27 deletions src/diffusers/loaders/lora_pipeline.py
@@ -1674,9 +1674,19 @@ def load_lora_weights(
             and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
         }
 
-        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        transformer = getattr(self, self.transformer_name, None)
+        text_encoder = getattr(self, self.text_encoder_name, None)
+
+        if transformer is None and text_encoder is None:
+            logger.warning(
+                "No loadable LoRA components (transformer, text_encoder) found in this pipeline. "
+                "Skipping LoRA weight loading. This can happen when calling `load_lora_weights` on a "
+                "modular sub-pipeline that does not contain the expected components."
+            )
+            return
+
         has_param_with_expanded_shape = False
-        if len(transformer_lora_state_dict) > 0:
+        if transformer is not None and len(transformer_lora_state_dict) > 0:
             has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
                 transformer, transformer_lora_state_dict, transformer_norm_state_dict
             )
@@ -1687,43 +1697,50 @@ def load_lora_weights(
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
-        if len(transformer_lora_state_dict) > 0:
+        if transformer is not None and len(transformer_lora_state_dict) > 0:
             transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
                 transformer=transformer, lora_state_dict=transformer_lora_state_dict
             )
             for k in transformer_lora_state_dict:
                 state_dict.update({k: transformer_lora_state_dict[k]})
 
-        self.load_lora_into_transformer(
-            state_dict,
-            network_alphas=network_alphas,
-            transformer=transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        if transformer is not None:
+            self.load_lora_into_transformer(
+                state_dict,
+                network_alphas=network_alphas,
+                transformer=transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        elif len(transformer_lora_state_dict) > 0:
+            logger.warning(
+                "LoRA weights contain transformer parameters but the pipeline does not have a transformer component. "
+                "Skipping transformer LoRA loading."
+            )
 
-        if len(transformer_norm_state_dict) > 0:
+        if transformer is not None and len(transformer_norm_state_dict) > 0:
             transformer._transformer_norm_layers = self._load_norm_into_transformer(
                 transformer_norm_state_dict,
                 transformer=transformer,
                 discard_original_layers=False,
             )
 
-        self.load_lora_into_text_encoder(
-            state_dict,
-            network_alphas=network_alphas,
-            text_encoder=self.text_encoder,
-            prefix=self.text_encoder_name,
-            lora_scale=self.lora_scale,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        if text_encoder is not None:
+            self.load_lora_into_text_encoder(
+                state_dict,
+                network_alphas=network_alphas,
+                text_encoder=text_encoder,
+                prefix=self.text_encoder_name,
+                lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     def load_lora_into_transformer(
@@ -5724,9 +5741,18 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
 
+        transformer = getattr(self, self.transformer_name, None)
+        if transformer is None:
+            logger.warning(
+                f"The `{self.transformer_name}` component is not available in this pipeline. "
+                "Skipping LoRA weight loading. This can happen when calling `load_lora_weights` on a "
+                "modular sub-pipeline that does not contain the transformer component."
+            )
+            return
+
         self.load_lora_into_transformer(
             state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            transformer=transformer,
             adapter_name=adapter_name,
             metadata=metadata,
             _pipeline=self,
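A minimal sketch of the behavior these guards are meant to provide, mirroring the new regression test below. The `pipe` argument and `demo_missing_transformer` helper are placeholders for illustration, not APIs from this diff:

import torch

def demo_missing_transformer(pipe):
    # `pipe` is any LoRA-capable diffusers pipeline (e.g. a dummy test pipeline).
    # Simulate a modular sub-pipeline that carries no transformer component.
    pipe.transformer = None

    # Valid-looking LoRA keys: each name contains the required "lora" substring,
    # so the format check passes and the new guard is what gets exercised.
    dummy_lora_state_dict = {
        "transformer.lora_A.weight": torch.randn(4, 4),
        "transformer.lora_B.weight": torch.randn(4, 4),
    }

    # Before this change, the old getattr/hasattr expression could hand a None
    # transformer to load_lora_into_transformer and crash; with the guards, the
    # call logs a warning and returns early, leaving the pipeline untouched.
    pipe.load_lora_weights(dummy_lora_state_dict)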
25 changes: 25 additions & 0 deletions tests/lora/test_lora_layers_flux2.py
@@ -148,3 +148,28 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
     @unittest.skip("Not supported in Flux2.")
     def test_modify_padding_mode(self):
         pass
+
+    def test_load_lora_weights_warns_when_transformer_missing(self):
+        """Regression test for https://github.com/huggingface/diffusers/issues/13487"""
+        components, _, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+
+        # Simulate a modular sub-pipeline that does not have the transformer component.
+        pipe.transformer = None
+
+        # Create a dummy LoRA state dict with valid keys.
+        dummy_lora_state_dict = {
+            "transformer.lora_A.weight": torch.randn(4, 4),
+            "transformer.lora_B.weight": torch.randn(4, 4),
+        }
+
+        from diffusers.utils import logging
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        from ..testing_utils import CaptureLogger
+
+        with CaptureLogger(logger) as cl:
+            pipe.load_lora_weights(dummy_lora_state_dict)
+
+        self.assertIn("not available", cl.out)
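If useful, the new test can be run in isolation with pytest (invocation assumed from the repository's standard test layout):

pytest tests/lora/test_lora_layers_flux2.py -k test_load_lora_weights_warns_when_transformer_missing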