From f7e7091afda87780e48927c82aa85c6c3402ccdb Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 21 Aug 2024 16:51:00 +0530
Subject: [PATCH 1/5] support parsing alpha from a flux lora state dict.

---
 src/diffusers/loaders/lora_pipeline.py | 30 ++++++++++----
 tests/lora/test_lora_layers_flux.py    | 57 +++++++++++++++++++++++++-
 2 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index f612cc0c6e53..4de75ceb6845 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1489,7 +1489,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -1577,7 +1576,15 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )

-        return state_dict
+        # For state dicts like
+        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
+        keys = list(state_dict.keys())
+        network_alphas = {}
+        for k in keys:
+            if "alpha" in k:
+                network_alphas[k] = state_dict.pop(k)
+
+        return state_dict, network_alphas

     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1611,7 +1618,7 @@ def load_lora_weights(
         pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -1619,6 +1626,7 @@ def load_lora_weights(

         self.load_lora_into_transformer(
             state_dict,
+            network_alphas=network_alphas,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
             _pipeline=self,
@@ -1628,7 +1636,7 @@ def load_lora_weights(
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_state_dict,
-                network_alphas=None,
+                network_alphas=network_alphas,
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
@@ -1637,8 +1645,7 @@ def load_lora_weights(
         )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1647,6 +1654,10 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+                Refer to [this link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             transformer (`SD3Transformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
@@ -1678,7 +1689,12 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
             if "lora_B" in key:
                 rank[key] = val.shape[1]

-        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+        if network_alphas is not None and len(network_alphas) >= 1:
+            prefix = cls.transformer_name
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                 raise ValueError(
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index c0f0684ac4de..6e91807114b8 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -12,19 +12,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import sys
+import tempfile
 import unittest

+import numpy as np
+import safetensors.torch
 import torch
+from peft.utils import get_peft_model_state_dict
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, torch_device


 sys.path.append(".")

-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402


 @require_peft_backend
@@ -90,3 +95,51 @@ def get_dummy_inputs(self, with_generator=True):
             pipeline_inputs.update({"generator": generator})

         return noise, input_ids, pipeline_inputs
+
+    def test_with_alpha_in_state_dict(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        pipe.transformer.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            # modify the state dict to have alpha values following
+            # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
+            state_dict_with_alpha = safetensors.torch.load_file(
+                os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
+            )
+            alpha_dict = {}
+            for k, v in state_dict_with_alpha.items():
+                # only do for `transformer` and for the k projections -- should be enough to test.
+                if "transformer" in k and "to_k" in k and "lora_A" in k:
+                    alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
+            state_dict_with_alpha.update(alpha_dict)
+
+        images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(state_dict_with_alpha)
+        images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertTrue(
+            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+            "Loading from saved checkpoints should give same results.",
+        )
+        self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))

From 298354363b064c00be9ba3d36247da07f846dce9 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 21 Aug 2024 17:01:47 +0530
Subject: [PATCH 2/5] conditional import.

---
 tests/lora/test_lora_layers_flux.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 6e91807114b8..a82d37665fc5 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -20,13 +20,15 @@
 import numpy as np
 import safetensors.torch
 import torch
-from peft.utils import get_peft_model_state_dict
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, torch_device
+from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device


+if is_peft_available():
+    from peft.utils import get_peft_model_state_dict
+
 sys.path.append(".")

 from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402

From c8bca51c06a477d693fb65bd9862e00a67bba5a8 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 21 Aug 2024 17:13:20 +0530
Subject: [PATCH 3/5] fix breaking changes.

---
 src/diffusers/loaders/lora_pipeline.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 4de75ceb6845..3a207bcd76a4 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1492,6 +1492,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        return_alphas: bool = False,
         **kwargs,
     ):
         r"""
@@ -1584,7 +1585,10 @@ def lora_state_dict(
             if "alpha" in k:
                 network_alphas[k] = state_dict.pop(k)

-        return state_dict, network_alphas
+        if return_alphas:
+            return state_dict, network_alphas
+        else:
+            return state_dict

     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1618,7 +1622,9 @@ def load_lora_weights(
         pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
+        )

         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:

From 67fc491480ab677c4e9a04f8280f47ade5a2fe28 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 21 Aug 2024 21:09:41 +0530
Subject: [PATCH 4/5] safeguard alpha.

---
 src/diffusers/loaders/lora_pipeline.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 3a207bcd76a4..1b04e806e2a0 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1583,7 +1583,15 @@ def lora_state_dict(
         network_alphas = {}
         for k in keys:
             if "alpha" in k:
-                network_alphas[k] = state_dict.pop(k)
+                alpha_value = state_dict.get(k)
+                if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+                    alpha_value, float
+                ):
+                    [k] = state_dict.pop(k)
+                else:
+                    raise ValueError(
+                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
+                    )

         if return_alphas:
             return state_dict, network_alphas

From 27405bce9ec835a6cd727c3fc24689ddb906a993 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 21 Aug 2024 21:18:12 +0530
Subject: [PATCH 5/5] fix

---
 src/diffusers/loaders/lora_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 1b04e806e2a0..d57de182db10 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1587,7 +1587,7 @@ def lora_state_dict(
             if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
                 alpha_value, float
             ):
-                [k] = state_dict.pop(k)
+                network_alphas[k] = state_dict.pop(k)
             else:
                 raise ValueError(
                     f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
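
A minimal usage sketch of the behavior this series adds, assuming the patches
above are applied. The `black-forest-labs/FLUX.1-dev` checkpoint id is an
assumption not taken from the patches; any Flux base model paired with a
kohya-style Flux LoRA carrying `.alpha` keys (for example
TheLastBen/Jon_Snow_Flux_LoRA, referenced above) should exercise the same path:

    import torch
    from diffusers import FluxPipeline

    # Load a Flux base pipeline (assumed checkpoint id, not named in the patches).
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )

    # With the default `return_alphas=False`, `lora_state_dict` keeps its old
    # single-value return, so existing callers are unaffected. Opting in splits
    # the `.alpha` scalars out of the checkpoint into a separate dict.
    state_dict, network_alphas = FluxPipeline.lora_state_dict(
        "TheLastBen/Jon_Snow_Flux_LoRA",
        weight_name="jon_snow.safetensors",
        return_alphas=True,
    )
    print(len(network_alphas))  # number of alpha entries parsed from the file

    # `load_lora_weights` calls `lora_state_dict(..., return_alphas=True)`
    # internally and forwards the alphas to `load_lora_into_transformer`, where
    # `get_peft_kwargs` uses them so PEFT scales each LoRA layer by alpha / rank
    # instead of ignoring the stored alpha values.
    pipe.load_lora_weights(
        "TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors"
    )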