Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/en/api/loaders/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.

Expand Down Expand Up @@ -112,6 +113,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi

[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin

## ZImageLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.ZImageLoraLoaderMixin

## KandinskyLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def text_encoder_attn_modules(text_encoder):
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
"ZImageLoraLoaderMixin",
"Flux2LoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
Expand Down Expand Up @@ -130,6 +131,7 @@ def text_encoder_attn_modules(text_encoder):
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
ZImageLoraLoaderMixin,
)
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin
Expand Down
118 changes: 118 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2351,3 +2351,121 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict


def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
"""
Convert non-diffusers ZImage LoRA state dict to diffusers format.
Handles:
- `diffusion_model.` prefix removal
- `lora_unet_` prefix conversion with key mapping
- `default.` prefix removal
- `.lora_down.weight`/`.lora_up.weight` → `.lora_A.weight`/`.lora_B.weight` conversion with alpha scaling
"""
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_diffusion_model:
state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}

has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
if has_lora_unet:
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}

def convert_key(key: str) -> str:
# ZImage has: layers, noise_refiner, context_refiner blocks
# Keys may be like: layers_0_attention_to_q.lora_down.weight

if "." in key:
base, suffix = key.rsplit(".", 1)
else:
base, suffix = key, ""

# Protected n-grams that must keep their internal underscores
protected = {
# pairs for attention
("to", "q"),
("to", "k"),
("to", "v"),
("to", "out"),
# feed_forward
("feed", "forward"),
}

prot_by_len = {}
for ng in protected:
prot_by_len.setdefault(len(ng), set()).add(ng)

parts = base.split("_")
merged = []
i = 0
lengths_desc = sorted(prot_by_len.keys(), reverse=True)

while i < len(parts):
matched = False
for L in lengths_desc:
if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
merged.append("_".join(parts[i : i + L]))
i += L
matched = True
break
if not matched:
merged.append(parts[i])
i += 1

converted_base = ".".join(merged)
return converted_base + (("." + suffix) if suffix else "")

state_dict = {convert_key(k): v for k, v in state_dict.items()}

has_default = any("default." in k for k in state_dict)
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}

converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
up_key = ".lora_up.weight"
a_key = ".lora_A.weight"
b_key = ".lora_B.weight"

has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)

if has_non_diffusers_lora_id:

def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up

for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
alpha_key = k.replace(down_key, ".alpha")

down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(down_key, up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up

# Already in diffusers format (lora_A/lora_B), just pop
elif has_diffusers_lora_id:
for k in all_keys:
if a_key in k or b_key in k:
converted_state_dict[k] = state_dict.pop(k)
elif ".alpha" in k:
state_dict.pop(k)

if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
207 changes: 207 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_qwen_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_non_diffusers_z_image_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)
Expand Down Expand Up @@ -5085,6 +5086,212 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)


class ZImageLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`ZImageTransformer2DModel`]. Specific to [`ZImagePipeline`].
"""

_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME

@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)

allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True

user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
has_default = any("default." in k for k in state_dict)
if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
state_dict = _convert_non_diffusers_z_image_lora_to_diffusers(state_dict)

out = (state_dict, metadata) if return_lora_metadata else state_dict
return out

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->ZImageTransformer2DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}

if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)


class Flux2LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`].
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
}


Expand Down
Loading
Loading