From 18190001514cfc9233fd9859906e180cd43c810b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 9 Aug 2024 18:41:56 +0530 Subject: [PATCH 01/16] feat: add non-breaking support to serialize metadata in loras. --- src/diffusers/loaders/lora_base.py | 43 ++++- src/diffusers/loaders/lora_pipeline.py | 209 ++++++++++++++++++++++--- src/diffusers/loaders/unet.py | 12 +- src/diffusers/utils/peft_utils.py | 10 +- tests/lora/test_lora_layers_flux.py | 3 +- tests/lora/utils.py | 88 ++++++++++- 6 files changed, 327 insertions(+), 38 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 4b963270427b..1032a9aa7dd6 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -14,6 +14,7 @@ import copy import inspect +import json import os from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -26,6 +27,7 @@ from ..models.modeling_utils import ModelMixin, load_state_dict from ..utils import ( + SAFETENSORS_FILE_EXTENSION, USE_PEFT_BACKEND, _get_model_file, delete_adapter_layers, @@ -44,6 +46,7 @@ from transformers import PreTrainedModel if is_peft_available(): + from peft import LoraConfig from peft.tuners.tuners_utils import BaseTunerLayer if is_accelerate_available(): @@ -252,6 +255,7 @@ def _fetch_state_dict( from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE model_file = None + metadata = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( @@ -280,6 +284,8 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") + with safetensors.safe_open(model_file, framework="pt", device="cpu") as f: + metadata = f.metadata() except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e @@ -305,10 +311,14 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = load_state_dict(model_file) + file_extension = os.path.basename(model_file).split(".")[-1] + if file_extension == SAFETENSORS_FILE_EXTENSION: + with safetensors.safe_open(model_file, framework="pt", device="cpu") as f: + metadata = f.metadata() else: state_dict = pretrained_model_name_or_path_or_dict - return state_dict + return state_dict, metadata @classmethod def _best_guess_weight_name( @@ -717,9 +727,13 @@ def write_lora_layers( weight_name: str, save_function: Callable, safe_serialization: bool, + metadata=None, ): from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0: + raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.") + if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return @@ -727,8 +741,12 @@ def write_lora_layers( if save_function is None: if safe_serialization: - def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + def save_function(weights, filename, metadata): + if metadata is None: + metadata = {"format": "pt"} + elif len(metadata) > 0: + metadata.update({"format": "pt"}) + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: save_function = torch.save @@ -742,7 +760,10 @@ def save_function(weights, filename): weight_name = LORA_WEIGHT_NAME save_path = Path(save_directory, weight_name).as_posix() - 
save_function(state_dict, save_path) + if save_function != torch.save: + save_function(state_dict, save_path, metadata) + else: + save_function(state_dict, save_path) logger.info(f"Model weights saved in {save_path}") @property @@ -750,3 +771,17 @@ def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. # if _lora_scale has not been set, return 1 return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + + @staticmethod + def pack_metadata(config, prefix): + local_metadata = {} + if config is not None: + if isinstance(config, LoraConfig): + config = config.to_dict() + for key, value in config.items(): + if isinstance(value, set): + config[key] = list(value) + + config_as_string = json.dumps(config, indent=2, sort_keys=True) + local_metadata[prefix] = config_as_string + return local_metadata diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index f612cc0c6e53..0ece7eefc9bb 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import os from typing import Callable, Dict, List, Optional, Union @@ -92,7 +93,9 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, network_alphas, metadata = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True + ) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: @@ -104,6 +107,7 @@ def load_lora_weights( unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, _pipeline=self, + config=metadata, ) self.load_lora_into_text_encoder( state_dict, @@ -114,6 +118,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + config=metadata, ) @classmethod @@ -168,6 +173,7 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_metadata (`bool`) """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. 
@@ -181,6 +187,7 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_metadata = kwargs.pop("return_metadata", False) allow_pickle = False if use_safetensors is None: @@ -192,7 +199,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict, metadata = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -224,10 +231,13 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + if return_metadata: + return state_dict, network_alphas, metadata + else: + return state_dict, network_alphas @classmethod - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -245,6 +255,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -258,7 +269,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") unet.load_attn_procs( - state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + config=config, + _pipeline=_pipeline, ) @classmethod @@ -270,6 +285,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -291,6 +307,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -303,6 +320,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. 
@@ -341,7 +361,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -385,6 +407,8 @@ def save_lora_weights( save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + unet_lora_config: dict = None, + text_encoder_lora_config: dict = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -401,6 +425,8 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. + unet_lora_config: Dict + text_encoder_lora_config: Dict is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -413,19 +439,27 @@ def save_lora_weights( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ state_dict = {} + metadata = {} if not (unet_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") if unet_lora_layers: state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) + if unet_lora_config: + unet_metadata = cls.pack_metadata(unet_lora_config, cls.unet_name) + metadata.update(unet_metadata) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if text_encoder_lora_config: + te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name) + metadata.update(te_metadata) # Save the model cls.write_lora_layers( state_dict=state_dict, + metadata=metadata, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, @@ -550,9 +584,10 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( + state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, + return_metadata=True, **kwargs, ) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) @@ -560,7 +595,12 @@ def load_lora_weights( raise ValueError("Invalid LoRA checkpoint.") self.load_lora_into_unet( - state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self + state_dict, + network_alphas=network_alphas, + unet=self.unet, + adapter_name=adapter_name, + config=metadata, + _pipeline=self, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." 
in k} if len(text_encoder_state_dict) > 0: @@ -571,6 +611,7 @@ def load_lora_weights( prefix="text_encoder", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -583,6 +624,7 @@ def load_lora_weights( prefix="text_encoder_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -639,6 +681,7 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_metadata (`bool`) """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -652,6 +695,7 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_metadata = kwargs.pop("return_metadata", False) allow_pickle = False if use_safetensors is None: @@ -663,7 +707,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict, metadata = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -695,11 +739,14 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + if return_metadata: + return state_dict, network_alphas, metadata + else: + return state_dict, network_alphas @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -717,6 +764,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -730,7 +778,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") unet.load_attn_procs( - state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + config=config, + _pipeline=_pipeline, ) @classmethod @@ -743,6 +795,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -764,6 +817,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
+ config (`dict`, *optional*): """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -776,6 +830,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. @@ -814,7 +871,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -859,6 +918,9 @@ def save_lora_weights( unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + unet_lora_config: dict = None, + text_encoder_lora_config: dict = None, + text_encoder_2_lora_config: dict = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -878,6 +940,9 @@ def save_lora_weights( text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. + unet_lora_config (`dict`): + text_encoder_lora_config (`dict`): + text_encoder_2_lora_config (`dict`): is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -890,6 +955,7 @@ def save_lora_weights( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ state_dict = {} + metadata = {} if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -898,15 +964,25 @@ def save_lora_weights( if unet_lora_layers: state_dict.update(cls.pack_weights(unet_lora_layers, "unet")) + if unet_lora_config is not None: + unet_metadata = cls.pack_metadata(unet_lora_config, "unet") + metadata.update(unet_metadata) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) + if text_encoder_lora_config is not None: + te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder") + metadata.update(te_metadata) if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + if text_encoder_2_lora_config is not None: + te2_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder_2") + metadata.update(te2_metadata) cls.write_lora_layers( state_dict=state_dict, + metadata=metadata, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, @@ -1041,6 +1117,7 @@ def lora_state_dict( allowed by Git. 
subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_metadata (`bool`): """ # Load the main state dict first which has the LoRA layers for either of @@ -1054,6 +1131,7 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_metadata = kwargs.pop("return_metadata", False) allow_pickle = False if use_safetensors is None: @@ -1065,7 +1143,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict, metadata = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1080,7 +1158,11 @@ def lora_state_dict( allow_pickle=allow_pickle, ) - return state_dict + # Otherwise, this would be a breaking change. + if return_metadata: + return state_dict, metadata + else: + return state_dict def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs @@ -1114,7 +1196,9 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, metadata = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True + ) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: @@ -1124,6 +1208,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -1136,6 +1221,7 @@ def load_lora_weights( prefix="text_encoder", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -1148,11 +1234,12 @@ def load_lora_weights( prefix="text_encoder_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @classmethod - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1166,6 +1253,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
+ config (`dict`): """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict @@ -1192,7 +1280,11 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, if "lora_B" in key: rank[key] = val.shape[1] - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if config is not None and isinstance(config, dict) and len(config) > 0: + config = json.loads(config[cls.transformer_name]) + lora_config_kwargs = get_peft_kwargs( + rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): raise ValueError( @@ -1239,6 +1331,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -1260,6 +1353,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1272,6 +1366,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. @@ -1310,7 +1407,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -1355,6 +1454,9 @@ def save_lora_weights( transformer_lora_layers: Dict[str, torch.nn.Module] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + transformer_lora_config: dict = None, + text_encoder_lora_config: dict = None, + text_encoder_2_lora_config: dict = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -1374,6 +1476,9 @@ def save_lora_weights( text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. + transformer_lora_config (`dict`): + text_encoder_lora_config (`dict`): + text_encoder_2_lora_config (`dict`): is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1386,6 +1491,7 @@ def save_lora_weights( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
""" state_dict = {} + metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -1394,16 +1500,26 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_config is not None: + transformer_metadata = cls.pack_metadata(transformer_lora_config, cls.transformer_name) + metadata.update(transformer_metadata) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) + if text_encoder_lora_config is not None: + te_config = cls.pack_metadata(text_encoder_lora_config, "text_encoder") + metadata.update(te_config) if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + if text_encoder_2_lora_config is not None: + te2_config = cls.pack_metadata(text_encoder_lora_config, "text_encoder_2") + metadata.update(te2_config) # Save the model cls.write_lora_layers( state_dict=state_dict, + metadata=metadata, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, @@ -1538,6 +1654,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_metadata (`bool`): """ # Load the main state dict first which has the LoRA layers for either of @@ -1551,6 +1668,7 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_metadata = kwargs.pop("return_metadata", False) allow_pickle = False if use_safetensors is None: @@ -1562,7 +1680,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict, metadata = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1577,7 +1695,11 @@ def lora_state_dict( allow_pickle=allow_pickle, ) - return state_dict + # Otherwise, this would be a breaking change. + if return_metadata: + return state_dict, metadata + else: + return state_dict def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs @@ -1611,7 +1733,9 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, metadata = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True + ) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: @@ -1621,6 +1745,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -1633,12 +1758,13 @@ def load_lora_weights( prefix="text_encoder", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1652,6 +1778,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`): """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict @@ -1678,7 +1805,11 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, if "lora_B" in key: rank[key] = val.shape[1] - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if config is not None and isinstance(config, dict) and len(config) > 0: + config = json.loads(config[cls.transformer_name]) + lora_config_kwargs = get_peft_kwargs( + rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): raise ValueError( @@ -1725,6 +1856,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -1746,6 +1878,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1758,6 +1891,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. 
@@ -1796,7 +1932,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -1841,6 +1979,8 @@ def save_lora_weights( save_directory: Union[str, os.PathLike], transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_config: dict = None, + text_encoder_lora_config: dict = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -1857,6 +1997,8 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. + transformer_lora_config: Dict + text_encoder_lora_config: Dict is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1869,19 +2011,27 @@ def save_lora_weights( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ state_dict = {} + metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_config: + transformer_metadata = cls.pack_metadata(transformer_lora_config, cls.transformer_name) + metadata.update(transformer_metadata) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if text_encoder_lora_config: + te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name) + metadata.update(te_metadata) # Save the model cls.write_lora_layers( state_dict=state_dict, + metadata=metadata, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, @@ -2051,6 +2201,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -2072,6 +2223,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -2084,6 +2236,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. 
@@ -2122,7 +2277,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 32ace77b6224..3e67c90137a0 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import os from collections import defaultdict from contextlib import nullcontext @@ -115,6 +116,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict `default_{i}` where i is the total number of adapters being loaded. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + config: (`dict`, *optional*) Example: @@ -143,6 +145,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict _pipeline = kwargs.pop("_pipeline", None) network_alphas = kwargs.pop("network_alphas", None) allow_pickle = False + config = kwargs.pop("config", None) if use_safetensors is None: use_safetensors = True @@ -208,6 +211,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict unet_identifier_key=self.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, + config=config, _pipeline=_pipeline, ) else: @@ -268,7 +272,7 @@ def _process_custom_diffusion(self, state_dict): return attn_processors - def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline): + def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, config=None): # This method does the following things: # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy # format. For legacy format no filtering is applied. @@ -316,7 +320,11 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter if "lora_B" in key: rank[key] = val.shape[1] - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + if config is not None and isinstance(config, dict) and len(config) > 0: + config = json.loads(config["unet"]) + print(f"{config=}") + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index ca55192ff7ae..4889aed7f8e3 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -147,11 +147,17 @@ def unscale_lora_layers(model, weight: Optional[float] = None): module.set_scale(adapter_name, 1.0) -def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): +def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True): rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] + # Try to retrive config. 
+ alpha_retrieved = False + if config is not None: + lora_alpha = config["lora_alpha"] + alpha_retrieved = True + if len(set(rank_dict.values())) > 1: # get the rank occuring the most number of times r = collections.Counter(rank_dict.values()).most_common()[0][0] @@ -160,7 +166,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} - if network_alpha_dict is not None and len(network_alpha_dict) > 0: + if not alpha_retrieved and network_alpha_dict is not None and len(network_alpha_dict) > 0: if len(set(network_alpha_dict.values())) > 1: # get the alpha occuring the most number of times lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index c0f0684ac4de..b08fceee3937 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -19,7 +19,7 @@ from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, torch_device sys.path.append(".") @@ -28,6 +28,7 @@ @require_peft_backend +@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = FluxPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler() diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 283b9f534766..b6f87349c3a6 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -87,7 +87,7 @@ class PeftLoraLoaderMixinTests: transformer_kwargs = None vae_kwargs = None - def get_dummy_components(self, scheduler_cls=None, use_dora=False): + def get_dummy_components(self, scheduler_cls=None, use_dora=False, alpha=None): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") if self.has_two_text_encoders and self.has_three_text_encoders: @@ -95,6 +95,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls rank = 4 + alpha = rank if alpha is None else alpha torch.manual_seed(0) if self.unet_kwargs is not None: @@ -120,7 +121,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): text_lora_config = LoraConfig( r=rank, - lora_alpha=rank, + lora_alpha=alpha, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False, use_dora=use_dora, @@ -128,7 +129,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): denoiser_lora_config = LoraConfig( r=rank, - lora_alpha=rank, + lora_alpha=alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False, use_dora=use_dora, @@ -1752,3 +1753,84 @@ def set_pad_mode(network, mode="circular"): _, _, inputs = self.get_dummy_inputs() _ = pipe(**inputs).images + + def test_if_lora_alpha_is_correctly_parsed(self): + lora_alpha = 8 + + scheduler_class = FlowMatchEulerDiscreteScheduler if self.uses_flow_matching else DDIMScheduler + components, text_lora_config, denoiser_lora_config = self.get_dummy_components( + alpha=lora_alpha, 
scheduler_cls=scheduler_class + ) + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.text_encoder.add_adapter(text_lora_config) + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config) + else: + pipe.transformer.add_adapter(denoiser_lora_config) + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + + # Inference works? + _ = pipe(**inputs, generator=torch.manual_seed(0)).images + + with tempfile.TemporaryDirectory() as tmpdirname: + denoiser = pipe.unet if self.unet_kwargs else pipe.transformer + denoiser_state_dict = get_peft_model_state_dict(denoiser) + denoiser_lora_config = denoiser.peft_config["default"] + + text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) + text_encoder_lora_config = pipe.text_encoder.peft_config["default"] + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"] + + saving_kwargs = { + "save_directory": tmpdirname, + "text_encoder_lora_layers": text_encoder_state_dict, + "text_encoder_lora_config": text_encoder_lora_config, + } + if self.unet_kwargs is not None: + saving_kwargs.update( + {"unet_lora_layers": denoiser_state_dict, "unet_lora_config": denoiser_lora_config} + ) + else: + saving_kwargs.update( + {"transformer_lora_layers": denoiser_state_dict, "transformer_lora_config": denoiser_lora_config} + ) + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + saving_kwargs.update( + { + "text_encoder_2_lora_layers": text_encoder_2_state_dict, + "text_encoder_2_lora_config": text_encoder_2_lora_config, + } + ) + + self.pipeline_class.save_lora_weights(**saving_kwargs) + loaded_pipe = self.pipeline_class(**components) + loaded_pipe.load_lora_weights(tmpdirname) + + # Inference works? + _ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images + + denoiser_loaded = pipe.unet if self.unet_kwargs is not None else pipe.transformer + assert ( + denoiser_loaded.peft_config["default"].lora_alpha == lora_alpha + ), "LoRA alpha not correctly loaded for UNet." + assert ( + loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha + ), "LoRA alpha not correctly loaded for text encoder." + if self.has_two_text_encoders or self.has_three_text_encoders: + assert ( + loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha + ), "LoRA alpha not correctly loaded for text encoder 2." 
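
For reference, here is a minimal sketch of how the surface added in PATCH 01 is meant to be exercised end to end. It mirrors the new `test_if_lora_alpha_is_correctly_parsed` test rather than any documented example, the model id is only a placeholder, and the adapter config values are illustrative. The point is that passing the `LoraConfig` to `save_lora_weights` persists `lora_alpha` (and the rest of the config) in the safetensors metadata, so a reload no longer silently falls back to `lora_alpha == r`.

import tempfile

from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model_state_dict

# Placeholder checkpoint; any pipeline exposing the LoRA loader mixin works the same way.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Deliberately use lora_alpha != r: this is exactly the value that used to be lost on save/load.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    init_lora_weights=False,
)
pipe.unet.add_adapter(unet_lora_config)

with tempfile.TemporaryDirectory() as tmpdir:
    # The new `unet_lora_config` argument is serialized as JSON into the safetensors metadata.
    StableDiffusionPipeline.save_lora_weights(
        save_directory=tmpdir,
        unet_lora_layers=get_peft_model_state_dict(pipe.unet),
        unet_lora_config=pipe.unet.peft_config["default"],
    )

    # Load into a fresh pipeline: `load_lora_weights` now reads the metadata back and
    # forwards it to `get_peft_kwargs` through the new `config` argument.
    fresh_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    fresh_pipe.load_lora_weights(tmpdir, adapter_name="roundtrip")
    assert fresh_pipe.unet.peft_config["roundtrip"].lora_alpha == 8
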
From 2b9f77e2d39ab60cee2c3ca859cfd193b707c15b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 9 Aug 2024 19:58:29 +0530 Subject: [PATCH 02/16] fix flux test --- tests/lora/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b6f87349c3a6..fd9c15505438 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1831,6 +1831,7 @@ def test_if_lora_alpha_is_correctly_parsed(self): loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha ), "LoRA alpha not correctly loaded for text encoder." if self.has_two_text_encoders or self.has_three_text_encoders: - assert ( - loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha - ), "LoRA alpha not correctly loaded for text encoder 2." + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + assert ( + loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha + ), "LoRA alpha not correctly loaded for text encoder 2." From fb9e86d8065c507992101ffed98ced8564f7c3fc Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 9 Aug 2024 20:04:01 +0530 Subject: [PATCH 03/16] Apply suggestions from code review Co-authored-by: Benjamin Bossan --- src/diffusers/loaders/unet.py | 1 - src/diffusers/utils/peft_utils.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 3e67c90137a0..9caaf2440bde 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -322,7 +322,6 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter if config is not None and isinstance(config, dict) and len(config) > 0: config = json.loads(config["unet"]) - print(f"{config=}") lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True) if "use_dora" in lora_config_kwargs: diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 4889aed7f8e3..e74c19048dea 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -152,7 +152,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] - # Try to retrive config. + # Try to retrieve config. alpha_retrieved = False if config is not None: lora_alpha = config["lora_alpha"] From 79aff1d82b59c854765f5b52df34e971ccce7a9c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 9 Aug 2024 20:15:47 +0530 Subject: [PATCH 04/16] fix lora_alpha --- src/diffusers/utils/peft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index e74c19048dea..c616aab2162f 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -155,7 +155,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, # Try to retrieve config. 
alpha_retrieved = False if config is not None: - lora_alpha = config["lora_alpha"] + lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha alpha_retrieved = True if len(set(rank_dict.values())) > 1: From 733e1d92593ebd7b2f7020039df4487ce31f846f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 10 Aug 2024 08:10:29 +0530 Subject: [PATCH 05/16] documentation --- src/diffusers/loaders/lora_base.py | 28 ++++++------- src/diffusers/loaders/lora_pipeline.py | 56 +++++++++++++++++--------- 2 files changed, 50 insertions(+), 34 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 1032a9aa7dd6..4e37c3e0054e 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -719,6 +719,20 @@ def pack_weights(layers, prefix): layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} return layers_state_dict + @staticmethod + def pack_metadata(config, prefix): + local_metadata = {} + if config is not None: + if isinstance(config, LoraConfig): + config = config.to_dict() + for key, value in config.items(): + if isinstance(value, set): + config[key] = list(value) + + config_as_string = json.dumps(config, indent=2, sort_keys=True) + local_metadata[prefix] = config_as_string + return local_metadata + @staticmethod def write_lora_layers( state_dict: Dict[str, torch.Tensor], @@ -771,17 +785,3 @@ def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. # if _lora_scale has not been set, return 1 return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 - - @staticmethod - def pack_metadata(config, prefix): - local_metadata = {} - if config is not None: - if isinstance(config, LoraConfig): - config = config.to_dict() - for key, value in config.items(): - if isinstance(value, set): - config[key] = list(value) - - config_as_string = json.dumps(config, indent=2, sort_keys=True) - local_metadata[prefix] = config_as_string - return local_metadata diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 0ece7eefc9bb..9ada78f193a9 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -173,7 +173,9 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. - return_metadata (`bool`) + return_metadata (`bool`): + If state dict metadata should be returned. Is only supported when the state dict has a safetensors + extension. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -255,7 +257,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - config (`dict`, *optional*): + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -307,7 +309,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - config (`dict`, *optional*): + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -425,8 +427,9 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - unet_lora_config: Dict - text_encoder_lora_config: Dict + unet_lora_config (`dict`, *optional*): LoRA configuration used to train the `unet_lora_layers`. + text_encoder_lora_config (`dict`, *optional*): + LoRA configuration used to train the `text_encoder_lora_layers`. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -681,7 +684,9 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. - return_metadata (`bool`) + return_metadata (`bool`): + If state dict metadata should be returned. Is only supported when the state dict has a safetensors + extension. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -764,7 +769,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - config (`dict`, *optional*): + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -817,7 +822,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - config (`dict`, *optional*): + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -940,9 +945,11 @@ def save_lora_weights( text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - unet_lora_config (`dict`): - text_encoder_lora_config (`dict`): - text_encoder_2_lora_config (`dict`): + unet_lora_config (`dict`, *optional*): LoRA configuration used to train the `unet_lora_layers`. + text_encoder_lora_config (`dict`, *optional*): + LoRA configuration used to train the `text_encoder_lora_layers`. + text_encoder_2_lora_config (`dict`, *optional*): + LoRA configuration used to train the `text_encoder_2_lora_layers`. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. 
Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1118,6 +1125,8 @@ def lora_state_dict( subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. return_metadata (`bool`): + If state dict metadata should be returned. Is only supported when the state dict has a safetensors + extension. """ # Load the main state dict first which has the LoRA layers for either of @@ -1353,7 +1362,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - config (`dict`, *optional*): + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1476,9 +1485,12 @@ def save_lora_weights( text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - transformer_lora_config (`dict`): - text_encoder_lora_config (`dict`): - text_encoder_2_lora_config (`dict`): + transformer_lora_config (`dict`, *optional*): + LoRA configuration used to train the `transformer_lora_layers`. + text_encoder_lora_config (`dict`, *optional*): + LoRA configuration used to train the `text_encoder_lora_layers`. + text_encoder_2_lora_config (`dict`, *optional*): + LoRA configuration used to train the `text_encoder_2_lora_layers`. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1654,7 +1666,9 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_metadata (`bool`): + return_metadata (`bool`, *optional*): + If state dict metadata should be returned. Is only supported when the state dict has a safetensors + extension. """ # Load the main state dict first which has the LoRA layers for either of @@ -1878,7 +1892,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - config (`dict`, *optional*): + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1997,8 +2011,10 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - transformer_lora_config: Dict - text_encoder_lora_config: Dict + transformer_lora_config (`dict`, *optional*): + LoRA configuration used to train the `transformer_lora_layers`. + text_encoder_lora_config (`dict`, *optional*): + LoRA configuration used to train the `text_encoder_lora_layers`. 
is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -2223,7 +2239,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - config (`dict`, *optional*): + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") From 167557b9bf0945e28cdb42880d08f3f6117c833a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 10 Aug 2024 08:13:08 +0530 Subject: [PATCH 06/16] fix-copues --- src/diffusers/loaders/lora_pipeline.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9ada78f193a9..64656a8deef7 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1124,7 +1124,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_metadata (`bool`): + return_metadata (`bool`, *optional*): If state dict metadata should be returned. Is only supported when the state dict has a safetensors extension. @@ -2011,8 +2011,7 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - transformer_lora_config (`dict`, *optional*): - LoRA configuration used to train the `transformer_lora_layers`. + transformer_lora_config (`dict`, *optional*): LoRA configuration used to train the `transformer_lora_layers`. text_encoder_lora_config (`dict`, *optional*): LoRA configuration used to train the `text_encoder_lora_layers`. is_main_process (`bool`, *optional*, defaults to `True`): From 6bbf629349e076021e9e81a8eec8c2b8a6f649e3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 10 Aug 2024 08:13:28 +0530 Subject: [PATCH 07/16] style --- src/diffusers/loaders/lora_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 64656a8deef7..c0cc0a105bfd 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2011,7 +2011,8 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - transformer_lora_config (`dict`, *optional*): LoRA configuration used to train the `transformer_lora_layers`. + transformer_lora_config (`dict`, *optional*): + LoRA configuration used to train the `transformer_lora_layers`. text_encoder_lora_config (`dict`, *optional*): LoRA configuration used to train the `text_encoder_lora_layers`. 
is_main_process (`bool`, *optional*, defaults to `True`): From 48768b66b666de6b6c012c2816e27bd37cfc345a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 10 Aug 2024 08:34:31 +0530 Subject: [PATCH 08/16] style --- src/diffusers/loaders/lora_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index c0cc0a105bfd..0af2d95c7376 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -427,7 +427,8 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - unet_lora_config (`dict`, *optional*): LoRA configuration used to train the `unet_lora_layers`. + unet_lora_config (`dict`, *optional*): + LoRA configuration used to train the `unet_lora_layers`. text_encoder_lora_config (`dict`, *optional*): LoRA configuration used to train the `text_encoder_lora_layers`. is_main_process (`bool`, *optional*, defaults to `True`): From 632bf78c04ffd1bce198f29035443b02895e7bc5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 10 Aug 2024 18:50:39 +0530 Subject: [PATCH 09/16] fix-copies --- src/diffusers/loaders/lora_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 0af2d95c7376..7a5cc6e06980 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1263,7 +1263,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - config (`dict`): + config (`dict`): Configuration that was used to train this LoRA. """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict @@ -1793,7 +1793,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - config (`dict`): + config (`dict`): Configuration that was used to train this LoRA. """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict From 91cfffc54b8ad81ba2809ccae41c4ee70b15f045 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 10 Aug 2024 18:55:24 +0530 Subject: [PATCH 10/16] check, --- src/diffusers/loaders/lora_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7a5cc6e06980..16c9fd11876f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1178,7 +1178,7 @@ def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. 
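Taken together, the patches up to this point wire a `LoraConfig` through `save_lora_weights` into the safetensors header and back out through `lora_state_dict(..., return_metadata=True)`. The snippet below is a minimal, self-contained sketch of that round trip at the file level; it mirrors what `pack_metadata` and the patched `save_function` do, but the file name, tensor shape, and the `"unet"` prefix are illustrative assumptions rather than part of the patch.

```python
import json

import torch
from peft import LoraConfig
from safetensors import safe_open
from safetensors.torch import save_file

config = LoraConfig(r=4, lora_alpha=8, target_modules=["to_k", "to_q", "to_v", "to_out.0"])

# `pack_metadata`-style serialization: LoraConfig -> dict -> JSON string under a prefix key.
# Sets (e.g. `target_modules`) must become lists first because JSON has no set type.
config_dict = {k: (list(v) if isinstance(v, set) else v) for k, v in config.to_dict().items()}
metadata = {"unet": json.dumps(config_dict, indent=2, sort_keys=True), "format": "pt"}

# Save a dummy LoRA state dict together with the metadata, as the patched `save_function` does.
state_dict = {"unet.dummy.lora_A.weight": torch.zeros(4, 16)}
save_file(state_dict, "pytorch_lora_weights.safetensors", metadata=metadata)

# Read the metadata back the way `_fetch_state_dict` now does for safetensors files.
with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    loaded_metadata = f.metadata()

print(json.loads(loaded_metadata["unet"])["lora_alpha"])  # 8
```

In the pipeline-level API this file handling is hidden behind `save_lora_weights(..., unet_lora_config=config)` and `lora_state_dict(..., return_metadata=True)`; only the JSON-in-header mechanism is shown here.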
From 6fb5987e409cfa3d40675d85a14bea9e145bf3db Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 10 Aug 2024 23:06:21 +0530 Subject: [PATCH 11/16] utilize fix-copies better. --- src/diffusers/loaders/lora_pipeline.py | 38 ++++++++++---------------- 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 16c9fd11876f..3a9ecd1390b5 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -117,8 +117,8 @@ def load_lora_weights( else self.text_encoder, lora_scale=self.lora_scale, adapter_name=adapter_name, - _pipeline=self, config=metadata, + _pipeline=self, ) @classmethod @@ -527,10 +527,6 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) @@ -1054,10 +1050,6 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_enc Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) @@ -1458,10 +1450,11 @@ def load_lora_into_text_encoder( # Unsafe code /> @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer def save_lora_weights( cls, save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, transformer_lora_config: dict = None, @@ -1486,8 +1479,7 @@ def save_lora_weights( text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - transformer_lora_config (`dict`, *optional*): - LoRA configuration used to train the `transformer_lora_layers`. + transformer_lora_config (`dict`, *optional*): LoRA configuration used to train the `transformer_lora_layers`. text_encoder_lora_config (`dict`, *optional*): LoRA configuration used to train the `text_encoder_lora_layers`. text_encoder_2_lora_config (`dict`, *optional*): @@ -1508,28 +1500,27 @@ def save_lora_weights( if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( - "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." + "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." 
) if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, "transformer")) if transformer_lora_config is not None: - transformer_metadata = cls.pack_metadata(transformer_lora_config, cls.transformer_name) + transformer_metadata = cls.pack_metadata(transformer_lora_config, "transformer") metadata.update(transformer_metadata) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) if text_encoder_lora_config is not None: - te_config = cls.pack_metadata(text_encoder_lora_config, "text_encoder") - metadata.update(te_config) + te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder") + metadata.update(te_metadata) if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) if text_encoder_2_lora_config is not None: - te2_config = cls.pack_metadata(text_encoder_lora_config, "text_encoder_2") - metadata.update(te2_config) + te2_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder_2") + metadata.update(te2_metadata) - # Save the model cls.write_lora_layers( state_dict=state_dict, metadata=metadata, @@ -1540,6 +1531,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], @@ -1583,6 +1575,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs): r""" Reverses the effect of @@ -1596,10 +1589,6 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) @@ -2100,6 +2089,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): r""" Reverses the effect of From 1852c3fa3b287b181447fa4112a5cd0f243ddb62 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 10 Aug 2024 23:14:21 +0530 Subject: [PATCH 12/16] style --- src/diffusers/loaders/lora_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 3a9ecd1390b5..09e83e748409 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1479,7 +1479,8 @@ def save_lora_weights( text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. 
Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - transformer_lora_config (`dict`, *optional*): LoRA configuration used to train the `transformer_lora_layers`. + transformer_lora_config (`dict`, *optional*): + LoRA configuration used to train the `transformer_lora_layers`. text_encoder_lora_config (`dict`, *optional*): LoRA configuration used to train the `text_encoder_lora_layers`. text_encoder_2_lora_config (`dict`, *optional*): From 712a110e765362c9a5db0af0fe29d8d04a962093 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 10 Aug 2024 23:18:29 +0530 Subject: [PATCH 13/16] style --- src/diffusers/loaders/lora_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 09e83e748409..613b16f0ee18 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -942,7 +942,8 @@ def save_lora_weights( text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. - unet_lora_config (`dict`, *optional*): LoRA configuration used to train the `unet_lora_layers`. + unet_lora_config (`dict`, *optional*): + LoRA configuration used to train the `unet_lora_layers`. text_encoder_lora_config (`dict`, *optional*): LoRA configuration used to train the `text_encoder_lora_layers`. text_encoder_2_lora_config (`dict`, *optional*): From f7d30de31a95e0eda3bd2da250fe924b08c4aa3c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 12 Aug 2024 15:25:43 +0530 Subject: [PATCH 14/16] alpha_pattern and rank+pattern --- src/diffusers/utils/peft_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index c616aab2162f..3de466dd1fbf 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -21,9 +21,12 @@ from packaging import version +from . import logging from .import_utils import is_peft_available, is_torch_available +logger = logging.get_logger(__name__) + if is_torch_available(): import torch @@ -154,10 +157,20 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, # Try to retrieve config. alpha_retrieved = False + alpha_pattern = None + rank_pattern = None if config is not None: lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha alpha_retrieved = True + if config.get("alpha_pattern", None): + alpha_pattern = config["alpha_pattern"] + logger.warning("`alpha_pattern` found in the LoRA config. This will be ignored.") + + if config.get("rank_pattern", None): + rank_pattern = config["alpha_pattern"] + logger.warning("`rank_pattern` found in the LoRA config. 
This will be ignored.") + if len(set(rank_dict.values())) > 1: # get the rank occuring the most number of times r = collections.Counter(rank_dict.values()).most_common()[0][0] From e7808e4c35885b5a7fb992028a072a76303d5cf5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 12 Aug 2024 16:24:28 +0530 Subject: [PATCH 15/16] fix --- src/diffusers/utils/peft_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 3de466dd1fbf..32e1fa3caefa 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -157,18 +157,14 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, # Try to retrieve config. alpha_retrieved = False - alpha_pattern = None - rank_pattern = None if config is not None: lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha alpha_retrieved = True - if config.get("alpha_pattern", None): - alpha_pattern = config["alpha_pattern"] + if config.get("alpha_pattern", None) is not None: logger.warning("`alpha_pattern` found in the LoRA config. This will be ignored.") - if config.get("rank_pattern", None): - rank_pattern = config["alpha_pattern"] + if config.get("rank_pattern", None) is not None: logger.warning("`rank_pattern` found in the LoRA config. This will be ignored.") if len(set(rank_dict.values())) > 1: From 178a4596b28a9a2c96e6682736909ecba37caa30 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 12 Aug 2024 18:40:23 +0530 Subject: [PATCH 16/16] add: comment. --- src/diffusers/utils/peft_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 32e1fa3caefa..d4cfd4cc4d83 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -161,6 +161,12 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha alpha_retrieved = True + # We simply ignore the `alpha_pattern` and `rank_pattern` if they are found + # in the `config`. This is because: + # 1. We determine `rank_pattern` from the `rank_dict`. + # 2. When `network_alpha_dict` is passed that means the underlying checkpoint + # is a non-diffusers checkpoint. + # More details: https://github.com/huggingface/diffusers/pull/9143#discussion_r1711491175 if config.get("alpha_pattern", None) is not None: logger.warning("`alpha_pattern` found in the LoRA config. This will be ignored.")
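As the new comment spells out, `rank_pattern` is always re-derived from the state dict itself, so the copy stored in the config adds nothing and can be ignored with a warning. The sketch below is a rough, self-contained illustration of that derivation; the toy module names and shapes are made up, the `Counter` line mirrors the existing logic in `get_peft_kwargs`, and the rest is a simplification rather than the library's exact code.

```python
import collections
import json

import torch

# Pretend this JSON string came from `safe_open(...).metadata()`, as stored by `pack_metadata`.
stored_config = json.loads('{"r": 4, "lora_alpha": 8, "rank_pattern": {"to_v": 16}}')

# Toy LoRA state dict: the rank is the first dimension of each `lora_A` weight.
state_dict = {
    "unet.to_q.lora_A.weight": torch.zeros(4, 320),
    "unet.to_k.lora_A.weight": torch.zeros(4, 320),
    "unet.to_v.lora_A.weight": torch.zeros(16, 320),
}
rank_dict = {name: w.shape[0] for name, w in state_dict.items() if "lora_A" in name}

# Base rank = the most common rank; the outliers form the effective rank pattern.
r = collections.Counter(rank_dict.values()).most_common()[0][0]
rank_pattern = {
    name.replace(".lora_A.weight", ""): rank for name, rank in rank_dict.items() if rank != r
}

# Only `lora_alpha` is taken from the stored config; its `rank_pattern`/`alpha_pattern` are ignored.
lora_alpha = stored_config.get("lora_alpha", r)
print(r, rank_pattern, lora_alpha)  # 4 {'unet.to_v': 16} 8
```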