Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 72 additions & 24 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
Expand All @@ -472,7 +476,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -891,7 +895,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
Expand All @@ -912,7 +920,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_enc
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class SD3LoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -1290,7 +1298,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
Expand All @@ -1312,7 +1324,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class FluxLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -1828,7 +1840,11 @@ def fuse_lora(
)

super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
Expand All @@ -1849,7 +1865,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)

super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)

# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
def unload_lora_weights(self, reset_to_overwritten_params=False):
Expand Down Expand Up @@ -2548,7 +2564,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
Expand All @@ -2566,7 +2586,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class Mochi1LoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -2852,7 +2872,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -2871,7 +2895,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class LTXVideoLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -3157,7 +3181,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -3176,7 +3204,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class SanaLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -3462,7 +3490,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -3481,7 +3513,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -3770,7 +3802,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -3789,7 +3825,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class Lumina2LoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -4079,7 +4115,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
Expand All @@ -4098,7 +4138,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class WanLoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -4384,7 +4424,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -4403,7 +4447,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class CogView4LoraLoaderMixin(LoraBaseMixin):
Expand Down Expand Up @@ -4689,7 +4733,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
Expand All @@ -4708,7 +4756,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)


class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
Expand Down
Loading