Fix loading broken LoRAs that could give NaN #5316
Changes from all commits: 60f99bc, eff91c2, 7def112, 0506cc0, 0e596ec, 22e0c2f, f52f730
```diff
@@ -121,7 +121,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
         return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_linear_layer is None:
             return
 
@@ -135,6 +135,14 @@ def _fuse_lora(self, lora_scale=1.0):
             w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
 
         fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                "LoRA weights will not be fused."
+            )
+
         self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
```

**Member** commented on lines +139 to +144: Crazy, honestly.
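The guard itself is easy to reproduce outside of diffusers. Below is a minimal, self-contained sketch (the helper name and the shapes are invented for illustration) of the same pattern the diff adds: compute the fused weight, then refuse to hand it back if any entry is NaN.

```python
import torch


def fuse_with_nan_guard(w_orig, w_up, w_down, lora_scale=1.0, safe_fusing=False):
    # Same fusion formula as in the diff: W' = W + scale * (W_up @ W_down)
    fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

    if safe_fusing and torch.isnan(fused_weight).any().item():
        raise ValueError("This LoRA weight seems to be broken: NaN values encountered while fusing.")
    return fused_weight


w_orig = torch.randn(32, 32)
w_up, w_down = torch.randn(32, 4), torch.randn(4, 32)

fuse_with_nan_guard(w_orig, w_up, w_down, safe_fusing=True)  # a healthy LoRA fuses fine

w_up[0, 0] = float("nan")  # a single bad value poisons an entire row of the fused matrix
try:
    fuse_with_nan_guard(w_orig, w_up, w_down, safe_fusing=True)
except ValueError as err:
    print(err)
```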
```diff
@@ -672,13 +680,14 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self, lora_scale=1.0):
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         self.lora_scale = lora_scale
+        self._safe_fusing = safe_fusing
         self.apply(self._fuse_lora_apply)
 
     def _fuse_lora_apply(self, module):
         if hasattr(module, "_fuse_lora"):
-            module._fuse_lora(self.lora_scale)
+            module._fuse_lora(self.lora_scale, self._safe_fusing)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
```
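The model-level plumbing stashes the flag on the module and lets `nn.Module.apply` hand it down, since `apply` only passes each submodule to a one-argument callback. A self-contained sketch of that pattern, using invented toy classes rather than the diffusers ones:

```python
import torch.nn as nn


class ToyFusable(nn.Linear):
    def _fuse_lora(self, lora_scale, safe_fusing):
        print(f"fusing {self} with scale={lora_scale}, safe_fusing={safe_fusing}")


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(ToyFusable(4, 4), nn.ReLU(), ToyFusable(4, 4))

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
        # Stash the arguments on `self`, then let `apply` visit every submodule.
        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(self._fuse_lora_apply)

    def _fuse_lora_apply(self, module):
        # Only layers that actually carry LoRA weights define `_fuse_lora`.
        if hasattr(module, "_fuse_lora"):
            module._fuse_lora(self.lora_scale, self._safe_fusing)


ToyModel().fuse_lora(lora_scale=0.5, safe_fusing=True)
```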
```diff
@@ -2086,7 +2095,13 @@ def unload_lora_weights(self):
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
-    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
+    def fuse_lora(
+        self,
+        fuse_unet: bool = True,
+        fuse_text_encoder: bool = True,
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+    ):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
```
```diff
@@ -2103,6 +2118,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
                 LoRA parameters then it won't have any effect.
             lora_scale (`float`, defaults to 1.0):
                 Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
         """
         if fuse_unet or fuse_text_encoder:
             self.num_fused_loras += 1
```

**Member** commented on the `safe_fusing` docstring: Why are we defaulting to `False`?

**Contributor:** I believe because the …

**Member:** Makes sense! Thanks for explaining!
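From the user side, the flag is exposed through the pipeline-level `fuse_lora` shown above. A hedged usage sketch; the checkpoint id and LoRA filename below are placeholders, not something this PR ships:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Hypothetical LoRA file; replace with a real checkpoint.
pipe.load_lora_weights("my_lora.safetensors")

try:
    # With safe_fusing=True, a broken LoRA (NaN after fusing) raises a ValueError
    # instead of silently producing black/NaN images.
    pipe.fuse_lora(lora_scale=1.0, safe_fusing=True)
except ValueError:
    pipe.unload_lora_weights()  # fall back to the base weights
```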
```diff
@@ -2112,12 +2129,13 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
                 )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale)
+            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if self.use_peft_backend:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+                # TODO(Patrick, Younes): enable "safe" fusing
                 for module in text_encoder.modules():
                     if isinstance(module, BaseTunerLayer):
                         if lora_scale != 1.0:
```
```diff
@@ -2129,24 +2147,24 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
             if version.parse(__version__) > version.parse("0.23"):
                 deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
                 for _, attn_module in text_encoder_attn_modules(text_encoder):
                     if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                        attn_module.q_proj._fuse_lora(lora_scale)
-                        attn_module.k_proj._fuse_lora(lora_scale)
-                        attn_module.v_proj._fuse_lora(lora_scale)
-                        attn_module.out_proj._fuse_lora(lora_scale)
+                        attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
 
                 for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                     if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                        mlp_module.fc1._fuse_lora(lora_scale)
-                        mlp_module.fc2._fuse_lora(lora_scale)
+                        mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
+                        mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
 
     def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
         r"""
```
Second changed file (the LoRA tests):

```diff
@@ -1028,6 +1028,47 @@ def test_load_lora_locally_safetensors(self):
 
             sd_pipe.unload_lora_weights()
 
+    def test_lora_fuse_nan(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        # corrupt one LoRA weight with `inf` values
+        with torch.no_grad():
+            sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
+                "inf"
+            )
+
+        # with `safe_fusing=True` we should see an Error
+        with self.assertRaises(ValueError):
+            sd_pipe.fuse_lora(safe_fusing=True)
+
+        # without we should not see an error, but every image will be black
+        sd_pipe.fuse_lora(safe_fusing=False)
+
+        out = sd_pipe("test", num_inference_steps=2, output_type="np").images
+
+        assert np.isnan(out).all()
+
     def test_lora_fusion(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
```

**Contributor (author)** commented on the `inf` corruption lines: Testing the different possibilities here: …
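Why does an `inf`-corrupted `down` weight end up as NaN rather than just inf? Adding `inf` turns every entry into `+inf`, and the subsequent matmul against an up-projection with mixed-sign entries sums `+inf` and `-inf` terms, which is NaN. A tiny reproduction, with values chosen by hand to make the effect deterministic:

```python
import torch

# A tiny up-projection with mixed-sign entries, as any trained LoRA would have.
w_up = torch.tensor([[1.0, -1.0],
                     [2.0,  3.0]])
w_down = torch.ones(2, 2)

# Same corruption as in the test: adding inf turns every entry into +inf.
w_down += float("inf")

delta = w_up @ w_down
print(delta)
# Row 0 sums (+inf) + (-inf), which is NaN; that NaN then poisons the fused
# weight and, eventually, every pixel of the generated image.
print(torch.isnan(delta).any())  # tensor(True)
```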
Review discussion:

- Question: do you know if usually the `nan` happens in … or in …? cc @BenjaminBossan - if it happens in the second case we could remove the copy in the PEFT PR.
- I don't know really. I have not tested it, but my intuition is that checking for NaN values can be quite expensive anyways when on GPU. So no matter what we have a time overhead and can't set `safe_fusing` as a default.
- OK sounds great!
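On the overhead argument: the cost is not just scanning the tensor; calling `.item()` on the result also forces a host-device synchronization, because the boolean has to be copied back to the CPU before Python can branch on it. An illustrative (not rigorous) micro-benchmark sketch, assuming a CUDA device is available:

```python
import time
import torch


def timed(fn, iters=100):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


if torch.cuda.is_available():
    w = torch.randn(4096, 4096, device="cuda")

    # Queued asynchronously; the Python side does not wait for the result.
    async_check = timed(lambda: torch.isnan(w).any())

    # `.item()` copies the flag to the CPU, blocking until the kernel finishes.
    sync_check = timed(lambda: torch.isnan(w).any().item())

    print(f"without .item(): {async_check * 1e3:.3f} ms/call, with .item(): {sync_check * 1e3:.3f} ms/call")
```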