Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions src/diffusers/models/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,24 @@
from torch import nn

from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
from ..utils import logging, scale_lora_layers
from ..utils import logging


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
    """Apply `lora_scale` to the LoRA layers of `text_encoder`.

    Args:
        text_encoder: Model whose LoRA layers should be rescaled.
        lora_scale (`float`, defaults to 1.0): Scale to apply.
        use_peft_backend (`bool`, defaults to `False`): When true, the work is
            handed to `scale_lora_layers`; otherwise the `lora_scale` attribute
            of each patched projection is written directly.
    """
    if use_peft_backend:
        scale_lora_layers(text_encoder, weight=lora_scale)
        return

    attn_projection_names = ("q_proj", "k_proj", "v_proj", "out_proj")
    mlp_projection_names = ("fc1", "fc2")

    for _, attention in text_encoder_attn_modules(text_encoder):
        # Only `q_proj` is inspected; when it is patched, all four projections get the scale.
        if not isinstance(attention.q_proj, PatchedLoraProjection):
            continue
        for projection_name in attn_projection_names:
            getattr(attention, projection_name).lora_scale = lora_scale

    for _, feed_forward in text_encoder_mlp_modules(text_encoder):
        # Same convention for the MLP: `fc1` decides for both linear layers.
        if not isinstance(feed_forward.fc1, PatchedLoraProjection):
            continue
        for projection_name in mlp_projection_names:
            getattr(feed_forward, projection_name).lora_scale = lora_scale
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
    """Assign `lora_scale` to the patched LoRA projections of `text_encoder`.

    Args:
        text_encoder: Model whose attention and MLP modules are walked via
            `text_encoder_attn_modules` and `text_encoder_mlp_modules`.
        lora_scale (`float`, defaults to 1.0): Value written to the
            `lora_scale` attribute of each patched projection.
    """
    attn_projection_names = ("q_proj", "k_proj", "v_proj", "out_proj")
    mlp_projection_names = ("fc1", "fc2")

    for _, attention in text_encoder_attn_modules(text_encoder):
        # Only `q_proj` is inspected; when it is patched, all four projections get the scale.
        if not isinstance(attention.q_proj, PatchedLoraProjection):
            continue
        for projection_name in attn_projection_names:
            getattr(attention, projection_name).lora_scale = lora_scale

    for _, feed_forward in text_encoder_mlp_modules(text_encoder):
        # Same convention for the MLP: `fc1` decides for both linear layers.
        if not isinstance(feed_forward.fc1, PatchedLoraProjection):
            continue
        for projection_name in mlp_projection_names:
            getattr(feed_forward, projection_name).lora_scale = lora_scale


class LoRALinearLayer(nn.Module):
Expand Down
11 changes: 9 additions & 2 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
Expand Down Expand Up @@ -304,7 +304,10 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -429,6 +432,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

return prompt_embeds, negative_prompt_embeds

def run_safety_checker(self, image, device, dtype):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
from ...utils import (
PIL_INTERPOLATION,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
Expand Down Expand Up @@ -302,7 +309,10 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -427,6 +437,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

return prompt_embeds, negative_prompt_embeds

def run_safety_checker(self, image, device, dtype):
Expand Down
15 changes: 9 additions & 6 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
logging,
replace_example_docstring,
)
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
Expand Down Expand Up @@ -291,7 +287,10 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -416,6 +415,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
Expand Down Expand Up @@ -315,7 +317,10 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -440,6 +445,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
logging,
replace_example_docstring,
)
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
Expand Down Expand Up @@ -442,7 +438,10 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -567,6 +566,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
is_invisible_watermark_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
Expand Down Expand Up @@ -314,8 +316,12 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
scale_lora_layers(self.text_encoder_2, lora_scale)

prompt = [prompt] if isinstance(prompt, str) else prompt

Expand Down Expand Up @@ -452,6 +458,11 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
Expand Down
18 changes: 12 additions & 6 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils import logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
Expand Down Expand Up @@ -288,8 +285,12 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
scale_lora_layers(self.text_encoder_2, lora_scale)

prompt = [prompt] if isinstance(prompt, str) else prompt

Expand Down Expand Up @@ -426,6 +427,11 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from ...utils import (
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
Expand Down Expand Up @@ -326,8 +328,12 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
scale_lora_layers(self.text_encoder_2, lora_scale)

prompt = [prompt] if isinstance(prompt, str) else prompt

Expand Down Expand Up @@ -464,6 +470,11 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusionPipelineOutput
Expand Down Expand Up @@ -308,7 +308,10 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -433,6 +436,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
logging,
replace_example_docstring,
)
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusionPipelineOutput
Expand Down Expand Up @@ -301,7 +297,10 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if not self.use_peft_backend:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -426,6 +425,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

return prompt_embeds, negative_prompt_embeds

def run_safety_checker(self, image, device, dtype):
Expand Down
Loading