48 changes: 33 additions & 15 deletions src/diffusers/loaders.py
@@ -121,7 +121,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):

return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

def _fuse_lora(self, lora_scale=1.0):
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_linear_layer is None:
return

@@ -135,6 +135,14 @@ def _fuse_lora(self, lora_scale=1.0):
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
Contributor: Question: do you know whether the NaN usually appears in

w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

or already in

(lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])?

cc @BenjaminBossan - if it only happens in the second case, we could remove the copy in the PEFT PR.

Contributor Author: I don't really know; I haven't tested it. My intuition is that checking for NaN values is quite expensive on GPU anyway, so there is a time overhead either way and we can't make safe_fusing the default.

Contributor: OK sounds great!
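As an aside (not part of the PR), a minimal sketch of how one could check which term the NaN originates in, assuming standalone w_orig, w_up, w_down tensors with the shapes used by the linear LoRA layer:

```python
import torch

# Minimal sketch (not from the PR): check whether the NaN originates in the
# LoRA delta itself or only after it is added to the original weight.
def locate_nan(w_orig, w_up, w_down, lora_scale=1.0):
    delta = lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]
    if torch.isnan(delta).any():
        return "NaN already present in the LoRA delta (bmm term)"
    if torch.isnan(w_orig + delta).any():
        return "NaN only appears after adding the delta to w_orig"
    return "no NaN found"

# Example with a deliberately corrupted down-projection weight.
w_orig = torch.randn(8, 8)
w_up = torch.randn(8, 4)
w_down = torch.randn(4, 8)
w_down[0, 0] = float("nan")
print(locate_nan(w_orig, w_up, w_down))  # -> NaN already present in the LoRA delta
```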


if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
"LoRA weights will not be fused."
)
Comment on lines +139 to +144
Member: Crazy, honestly.


self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)

# we can drop the lora layer now
@@ -672,13 +680,14 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

def fuse_lora(self, lora_scale=1.0):
def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
self.lora_scale = lora_scale
self._safe_fusing = safe_fusing
self.apply(self._fuse_lora_apply)

def _fuse_lora_apply(self, module):
if hasattr(module, "_fuse_lora"):
module._fuse_lora(self.lora_scale)
module._fuse_lora(self.lora_scale, self._safe_fusing)

def unfuse_lora(self):
self.apply(self._unfuse_lora_apply)
@@ -2086,7 +2095,13 @@ def unload_lora_weights(self):
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()

def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
def fuse_lora(
self,
fuse_unet: bool = True,
fuse_text_encoder: bool = True,
lora_scale: float = 1.0,
safe_fusing: bool = False,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.

@@ -2103,6 +2118,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
LoRA parameters then it won't have any effect.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Member: Why are we defaulting to False?

Contributor: I believe it's because torch.isnan(fused_weight).any().item() adds overhead to the fusion operation. For a large tensor this could be a considerable slowdown (one would need to benchmark it), so to be on the safe side I would also default it to False.

Member: Makes sense! Thanks for explaining!
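For reference, a rough timing sketch of that overhead (tensor shape and iteration count are arbitrary assumptions, not from the PR); the .item() call forces a device-to-host sync, so on GPU the check is not free:

```python
import time
import torch

def nan_check_cost(shape=(4096, 4096), iters=20):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    w = torch.randn(shape, device=device)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        _ = torch.isnan(w).any().item()  # .item() syncs with the device
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

print(f"avg per-check cost: {nan_check_cost() * 1e3:.3f} ms")
```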

Whether to check the fused weights for NaN values before fusing and, if NaN values are found, refuse to fuse them.
"""
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
@@ -2112,12 +2129,13 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
)

if fuse_unet:
self.unet.fuse_lora(lora_scale)
self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)

if self.use_peft_backend:
from peft.tuners.tuners_utils import BaseTunerLayer

def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
# TODO(Patrick, Younes): enable "safe" fusing
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
@@ -2129,24 +2147,24 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
if version.parse(__version__) > version.parse("0.23"):
deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)

def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)

for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)

if fuse_text_encoder:
if hasattr(self, "text_encoder"):
fuse_text_encoder_lora(self.text_encoder, lora_scale)
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
if hasattr(self, "text_encoder_2"):
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)

def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
r"""
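For context, a usage sketch of the new flag from the pipeline side (the model ID and LoRA path below are placeholders, not taken from this PR):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/pytorch_lora_weights.safetensors")

# With safe_fusing=True, fusing raises a ValueError on broken (NaN-producing)
# LoRA weights instead of silently baking NaNs into the model.
try:
    pipe.fuse_lora(lora_scale=0.8, safe_fusing=True)
except ValueError as err:
    print(f"Refusing to fuse broken LoRA weights: {err}")
```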
20 changes: 18 additions & 2 deletions src/diffusers/models/lora.py
@@ -112,7 +112,7 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
self.lora_layer = lora_layer

def _fuse_lora(self, lora_scale=1.0):
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_layer is None:
return

@@ -128,6 +128,14 @@ def _fuse_lora(self, lora_scale=1.0):
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
fusion = fusion.reshape((w_orig.shape))
fused_weight = w_orig + (lora_scale * fusion)

if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
"LoRA weights will not be fused."
)

self.weight.data = fused_weight.to(device=device, dtype=dtype)

# we can drop the lora layer now
@@ -179,7 +187,7 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
self.lora_layer = lora_layer

def _fuse_lora(self, lora_scale=1.0):
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_layer is None:
return

@@ -193,6 +201,14 @@ def _fuse_lora(self, lora_scale=1.0):
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
"LoRA weights will not be fused."
)

self.weight.data = fused_weight.to(device=device, dtype=dtype)

# we can drop the lora layer now
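As a standalone illustration of the Conv2d fusion arithmetic above (shapes are made up and assume the usual LoRA conv layout, where down carries the k x k kernel and up is a 1x1 conv):

```python
import torch

# Illustrative only: fuse a rank-4 conv LoRA into a 3x3 conv with 64 in/out channels.
rank, c_in, c_out, k = 4, 64, 64, 3
w_orig = torch.randn(c_out, c_in, k, k)   # original conv weight
w_down = torch.randn(rank, c_in, k, k)    # lora_layer.down.weight
w_up = torch.randn(c_out, rank, 1, 1)     # lora_layer.up.weight
lora_scale = 1.0

# (c_out, rank) @ (rank, c_in * k * k), then reshape back to the conv weight shape.
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
fusion = fusion.reshape(w_orig.shape)
fused_weight = w_orig + lora_scale * fusion
assert fused_weight.shape == w_orig.shape
```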
41 changes: 41 additions & 0 deletions tests/lora/test_lora_layers_old_backend.py
@@ -1028,6 +1028,47 @@ def test_load_lora_locally_safetensors(self):

sd_pipe.unload_lora_weights()

def test_lora_fuse_nan(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)

# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)

with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

# corrupt one LoRA weight with `inf` values
with torch.no_grad():
sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
Contributor Author: Testing the different possibilities here:
#5316 (comment)

@BenjaminBossan

"inf"
)

# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
sd_pipe.fuse_lora(safe_fusing=True)

# without we should not see an error, but every image will be black
sd_pipe.fuse_lora(safe_fusing=False)

out = sd_pipe("test", num_inference_steps=2, output_type="np").images

assert np.isnan(out).all()
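A short illustration of why the inf corruption above surfaces as NaN after fusion rather than as inf (illustrative only, not part of the test):

```python
import torch

# Mixed-sign up-weights multiplied by +inf down-weights produce both +inf and
# -inf terms; summing them in the matmul reduction yields NaN.
w_up = torch.ones(8, 4)
w_up[:, 0] = -1.0
w_down = torch.full((4, 8), float("inf"))
delta = torch.bmm(w_up[None, :], w_down[None, :])[0]
print(torch.isnan(delta).all())  # tensor(True)
```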

def test_lora_fusion(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)