diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 4b963270427b..4e37c3e0054e 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -14,6 +14,7 @@ import copy import inspect +import json import os from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -26,6 +27,7 @@ from ..models.modeling_utils import ModelMixin, load_state_dict from ..utils import ( + SAFETENSORS_FILE_EXTENSION, USE_PEFT_BACKEND, _get_model_file, delete_adapter_layers, @@ -44,6 +46,7 @@ from transformers import PreTrainedModel if is_peft_available(): + from peft import LoraConfig from peft.tuners.tuners_utils import BaseTunerLayer if is_accelerate_available(): @@ -252,6 +255,7 @@ def _fetch_state_dict( from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE model_file = None + metadata = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( @@ -280,6 +284,8 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") + with safetensors.safe_open(model_file, framework="pt", device="cpu") as f: + metadata = f.metadata() except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e @@ -305,10 +311,14 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = load_state_dict(model_file) + file_extension = os.path.basename(model_file).split(".")[-1] + if file_extension == SAFETENSORS_FILE_EXTENSION: + with safetensors.safe_open(model_file, framework="pt", device="cpu") as f: + metadata = f.metadata() else: state_dict = pretrained_model_name_or_path_or_dict - return state_dict + return state_dict, metadata @classmethod def _best_guess_weight_name( @@ -709,6 +719,20 @@ def pack_weights(layers, prefix): layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} return layers_state_dict + @staticmethod + def pack_metadata(config, prefix): + local_metadata = {} + if config is not None: + if isinstance(config, LoraConfig): + config = config.to_dict() + for key, value in config.items(): + if isinstance(value, set): + config[key] = list(value) + + config_as_string = json.dumps(config, indent=2, sort_keys=True) + local_metadata[prefix] = config_as_string + return local_metadata + @staticmethod def write_lora_layers( state_dict: Dict[str, torch.Tensor], @@ -717,9 +741,13 @@ def write_lora_layers( weight_name: str, save_function: Callable, safe_serialization: bool, + metadata=None, ): from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0: + raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.") + if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return @@ -727,8 +755,12 @@ def write_lora_layers( if save_function is None: if safe_serialization: - def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + def save_function(weights, filename, metadata): + if metadata is None: + metadata = {"format": "pt"} + elif len(metadata) > 0: + metadata.update({"format": "pt"}) + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: save_function = torch.save @@ -742,7 +774,10 @@ def save_function(weights, filename): 
weight_name = LORA_WEIGHT_NAME save_path = Path(save_directory, weight_name).as_posix() - save_function(state_dict, save_path) + if save_function != torch.save: + save_function(state_dict, save_path, metadata) + else: + save_function(state_dict, save_path) logger.info(f"Model weights saved in {save_path}") @property diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index f612cc0c6e53..613b16f0ee18 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import os from typing import Callable, Dict, List, Optional, Union @@ -92,7 +93,9 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, network_alphas, metadata = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True + ) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: @@ -104,6 +107,7 @@ def load_lora_weights( unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, _pipeline=self, + config=metadata, ) self.load_lora_into_text_encoder( state_dict, @@ -113,6 +117,7 @@ def load_lora_weights( else self.text_encoder, lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -168,6 +173,9 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_metadata (`bool`): + If state dict metadata should be returned. Is only supported when the state dict has a safetensors + extension. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -181,6 +189,7 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_metadata = kwargs.pop("return_metadata", False) allow_pickle = False if use_safetensors is None: @@ -192,7 +201,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict, metadata = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -224,10 +233,13 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + if return_metadata: + return state_dict, network_alphas, metadata + else: + return state_dict, network_alphas @classmethod - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `unet`. 
@@ -245,6 +257,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -258,7 +271,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") unet.load_attn_procs( - state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + config=config, + _pipeline=_pipeline, ) @classmethod @@ -270,6 +287,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -291,6 +309,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -303,6 +322,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. @@ -341,7 +363,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -385,6 +409,8 @@ def save_lora_weights( save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + unet_lora_config: dict = None, + text_encoder_lora_config: dict = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -401,6 +427,10 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. + unet_lora_config (`dict`, *optional*): + LoRA configuration used to train the `unet_lora_layers`. + text_encoder_lora_config (`dict`, *optional*): + LoRA configuration used to train the `text_encoder_lora_layers`. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. 
In this case, set `is_main_process=True` only on the main @@ -413,19 +443,27 @@ def save_lora_weights( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ state_dict = {} + metadata = {} if not (unet_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") if unet_lora_layers: state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) + if unet_lora_config: + unet_metadata = cls.pack_metadata(unet_lora_config, cls.unet_name) + metadata.update(unet_metadata) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if text_encoder_lora_config: + te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name) + metadata.update(te_metadata) # Save the model cls.write_lora_layers( state_dict=state_dict, + metadata=metadata, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, @@ -489,10 +527,6 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) @@ -550,9 +584,10 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( + state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, + return_metadata=True, **kwargs, ) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) @@ -560,7 +595,12 @@ def load_lora_weights( raise ValueError("Invalid LoRA checkpoint.") self.load_lora_into_unet( - state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self + state_dict, + network_alphas=network_alphas, + unet=self.unet, + adapter_name=adapter_name, + config=metadata, + _pipeline=self, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: @@ -571,6 +611,7 @@ def load_lora_weights( prefix="text_encoder", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -583,6 +624,7 @@ def load_lora_weights( prefix="text_encoder_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -639,6 +681,9 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_metadata (`bool`): + If state dict metadata should be returned. Is only supported when the state dict has a safetensors + extension. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. 
@@ -652,6 +697,7 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_metadata = kwargs.pop("return_metadata", False) allow_pickle = False if use_safetensors is None: @@ -663,7 +709,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict, metadata = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -695,11 +741,14 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + if return_metadata: + return state_dict, network_alphas, metadata + else: + return state_dict, network_alphas @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -717,6 +766,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -730,7 +780,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") unet.load_attn_procs( - state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + config=config, + _pipeline=_pipeline, ) @classmethod @@ -743,6 +797,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -764,6 +819,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -776,6 +832,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. 
@@ -814,7 +873,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -859,6 +920,9 @@ def save_lora_weights( unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, +        unet_lora_config: dict = None, +        text_encoder_lora_config: dict = None, +        text_encoder_2_lora_config: dict = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -878,6 +942,12 @@ def save_lora_weights( text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. +            unet_lora_config (`dict`, *optional*): +                LoRA configuration used to train the `unet_lora_layers`. +            text_encoder_lora_config (`dict`, *optional*): +                LoRA configuration used to train the `text_encoder_lora_layers`. +            text_encoder_2_lora_config (`dict`, *optional*): +                LoRA configuration used to train the `text_encoder_2_lora_layers`. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -890,6 +960,7 @@ def save_lora_weights( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ state_dict = {} +        metadata = {} if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -898,15 +969,25 @@ def save_lora_weights( if unet_lora_layers: state_dict.update(cls.pack_weights(unet_lora_layers, "unet")) +            if unet_lora_config is not None: +                unet_metadata = cls.pack_metadata(unet_lora_config, "unet") +                metadata.update(unet_metadata) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) +            if text_encoder_lora_config is not None: +                te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder") +                metadata.update(te_metadata) if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) +            if text_encoder_2_lora_config is not None: +                te2_metadata = cls.pack_metadata(text_encoder_2_lora_config, "text_encoder_2") +                metadata.update(te2_metadata) cls.write_lora_layers( state_dict=state_dict, +            metadata=metadata, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, @@ -970,10 +1051,6 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_enc Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. -            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. -            unfuse_text_encoder (`bool`, defaults to `True`): -                Whether to unfuse the text encoder LoRA parameters. 
If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) @@ -1041,6 +1118,9 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_metadata (`bool`, *optional*): + If state dict metadata should be returned. Is only supported when the state dict has a safetensors + extension. """ # Load the main state dict first which has the LoRA layers for either of @@ -1054,6 +1134,7 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_metadata = kwargs.pop("return_metadata", False) allow_pickle = False if use_safetensors is None: @@ -1065,7 +1146,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict, metadata = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1080,13 +1161,17 @@ def lora_state_dict( allow_pickle=allow_pickle, ) - return state_dict + # Otherwise, this would be a breaking change. + if return_metadata: + return state_dict, metadata + else: + return state_dict def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. @@ -1114,7 +1199,9 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, metadata = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True + ) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: @@ -1124,6 +1211,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -1136,6 +1224,7 @@ def load_lora_weights( prefix="text_encoder", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -1148,11 +1237,12 @@ def load_lora_weights( prefix="text_encoder_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @classmethod - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1166,6 +1256,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
+ config (`dict`): Configuration that was used to train this LoRA. """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict @@ -1192,7 +1283,11 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, if "lora_B" in key: rank[key] = val.shape[1] - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if config is not None and isinstance(config, dict) and len(config) > 0: + config = json.loads(config[cls.transformer_name]) + lora_config_kwargs = get_peft_kwargs( + rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): raise ValueError( @@ -1239,6 +1334,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -1260,6 +1356,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1272,6 +1369,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. @@ -1310,7 +1410,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -1349,12 +1451,16 @@ def load_lora_into_text_encoder( # Unsafe code /> @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer def save_lora_weights( cls, save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + transformer_lora_config: dict = None, + text_encoder_lora_config: dict = None, + text_encoder_2_lora_config: dict = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -1374,6 +1480,12 @@ def save_lora_weights( text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. + transformer_lora_config (`dict`, *optional*): + LoRA configuration used to train the `transformer_lora_layers`. + text_encoder_lora_config (`dict`, *optional*): + LoRA configuration used to train the `text_encoder_lora_layers`. 
+            text_encoder_2_lora_config (`dict`, *optional*): +                LoRA configuration used to train the `text_encoder_2_lora_layers`. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1386,24 +1498,34 @@ def save_lora_weights( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ state_dict = {} +        metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( -                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." +                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." ) if transformer_lora_layers: -            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) +            state_dict.update(cls.pack_weights(transformer_lora_layers, "transformer")) +            if transformer_lora_config is not None: +                transformer_metadata = cls.pack_metadata(transformer_lora_config, "transformer") +                metadata.update(transformer_metadata) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) +            if text_encoder_lora_config is not None: +                te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder") +                metadata.update(te_metadata) if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) +            if text_encoder_2_lora_config is not None: +                te2_metadata = cls.pack_metadata(text_encoder_2_lora_config, "text_encoder_2") +                metadata.update(te2_metadata) -        # Save the model cls.write_lora_layers( state_dict=state_dict, +            metadata=metadata, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, @@ -1411,6 +1533,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) +    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], @@ -1454,6 +1577,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) +    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs): r""" Reverses the effect of @@ -1467,10 +1591,6 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. -            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. -            unfuse_text_encoder (`bool`, defaults to `True`): -                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the -                LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) @@ -1538,6 +1658,9 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. +            return_metadata (`bool`, *optional*): +                If state dict metadata should be returned. Is only supported when the state dict has a safetensors +                extension. 
""" # Load the main state dict first which has the LoRA layers for either of @@ -1551,6 +1674,7 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_metadata = kwargs.pop("return_metadata", False) allow_pickle = False if use_safetensors is None: @@ -1562,7 +1686,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict, metadata = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1577,7 +1701,11 @@ def lora_state_dict( allow_pickle=allow_pickle, ) - return state_dict + # Otherwise, this would be a breaking change. + if return_metadata: + return state_dict, metadata + else: + return state_dict def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs @@ -1611,7 +1739,9 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, metadata = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True + ) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: @@ -1621,6 +1751,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @@ -1633,12 +1764,13 @@ def load_lora_weights( prefix="text_encoder", lora_scale=self.lora_scale, adapter_name=adapter_name, + config=metadata, _pipeline=self, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1652,6 +1784,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`): Configuration that was used to train this LoRA. 
""" from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict @@ -1678,7 +1811,11 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, if "lora_B" in key: rank[key] = val.shape[1] - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if config is not None and isinstance(config, dict) and len(config) > 0: + config = json.loads(config[cls.transformer_name]) + lora_config_kwargs = get_peft_kwargs( + rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): raise ValueError( @@ -1725,6 +1862,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -1746,6 +1884,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1758,6 +1897,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. @@ -1796,7 +1938,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -1841,6 +1985,8 @@ def save_lora_weights( save_directory: Union[str, os.PathLike], transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_config: dict = None, + text_encoder_lora_config: dict = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -1857,6 +2003,10 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. + transformer_lora_config (`dict`, *optional*): + LoRA configuration used to train the `transformer_lora_layers`. + text_encoder_lora_config (`dict`, *optional*): + LoRA configuration used to train the `text_encoder_lora_layers`. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1869,19 +2019,27 @@ def save_lora_weights( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
""" state_dict = {} + metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_config: + transformer_metadata = cls.pack_metadata(transformer_lora_config, cls.transformer_name) + metadata.update(transformer_metadata) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if text_encoder_lora_config: + te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name) + metadata.update(te_metadata) # Save the model cls.write_lora_layers( state_dict=state_dict, + metadata=metadata, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, @@ -1933,6 +2091,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): r""" Reverses the effect of @@ -2051,6 +2210,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + config=None, _pipeline=None, ): """ @@ -2072,6 +2232,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -2084,6 +2245,9 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + if config is not None and len(config) > 0: + config = json.loads(config[prefix]) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. @@ -2122,7 +2286,9 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False + ) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 32ace77b6224..9caaf2440bde 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import os from collections import defaultdict from contextlib import nullcontext @@ -115,6 +116,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict `default_{i}` where i is the total number of adapters being loaded. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. 
+ config: (`dict`, *optional*) Example: @@ -143,6 +145,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict _pipeline = kwargs.pop("_pipeline", None) network_alphas = kwargs.pop("network_alphas", None) allow_pickle = False + config = kwargs.pop("config", None) if use_safetensors is None: use_safetensors = True @@ -208,6 +211,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict unet_identifier_key=self.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, + config=config, _pipeline=_pipeline, ) else: @@ -268,7 +272,7 @@ def _process_custom_diffusion(self, state_dict): return attn_processors - def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline): + def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, config=None): # This method does the following things: # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy # format. For legacy format no filtering is applied. @@ -316,7 +320,10 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter if "lora_B" in key: rank[key] = val.shape[1] - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + if config is not None and isinstance(config, dict) and len(config) > 0: + config = json.loads(config["unet"]) + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index ca55192ff7ae..d4cfd4cc4d83 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -21,9 +21,12 @@ from packaging import version +from . import logging from .import_utils import is_peft_available, is_torch_available +logger = logging.get_logger(__name__) + if is_torch_available(): import torch @@ -147,11 +150,29 @@ def unscale_lora_layers(model, weight: Optional[float] = None): module.set_scale(adapter_name, 1.0) -def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): +def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True): rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] + # Try to retrieve config. + alpha_retrieved = False + if config is not None: + lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha + alpha_retrieved = True + + # We simply ignore the `alpha_pattern` and `rank_pattern` if they are found + # in the `config`. This is because: + # 1. We determine `rank_pattern` from the `rank_dict`. + # 2. When `network_alpha_dict` is passed that means the underlying checkpoint + # is a non-diffusers checkpoint. + # More details: https://github.com/huggingface/diffusers/pull/9143#discussion_r1711491175 + if config.get("alpha_pattern", None) is not None: + logger.warning("`alpha_pattern` found in the LoRA config. This will be ignored.") + + if config.get("rank_pattern", None) is not None: + logger.warning("`rank_pattern` found in the LoRA config. 
This will be ignored.") + if len(set(rank_dict.values())) > 1: # get the rank occuring the most number of times r = collections.Counter(rank_dict.values()).most_common()[0][0] @@ -160,7 +181,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} - if network_alpha_dict is not None and len(network_alpha_dict) > 0: + if not alpha_retrieved and network_alpha_dict is not None and len(network_alpha_dict) > 0: if len(set(network_alpha_dict.values())) > 1: # get the alpha occuring the most number of times lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index c0f0684ac4de..b08fceee3937 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -19,7 +19,7 @@ from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, torch_device sys.path.append(".") @@ -28,6 +28,7 @@ @require_peft_backend +@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = FluxPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler() diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 283b9f534766..fd9c15505438 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -87,7 +87,7 @@ class PeftLoraLoaderMixinTests: transformer_kwargs = None vae_kwargs = None - def get_dummy_components(self, scheduler_cls=None, use_dora=False): + def get_dummy_components(self, scheduler_cls=None, use_dora=False, alpha=None): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") if self.has_two_text_encoders and self.has_three_text_encoders: @@ -95,6 +95,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls rank = 4 + alpha = rank if alpha is None else alpha torch.manual_seed(0) if self.unet_kwargs is not None: @@ -120,7 +121,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): text_lora_config = LoraConfig( r=rank, - lora_alpha=rank, + lora_alpha=alpha, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False, use_dora=use_dora, @@ -128,7 +129,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): denoiser_lora_config = LoraConfig( r=rank, - lora_alpha=rank, + lora_alpha=alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False, use_dora=use_dora, @@ -1752,3 +1753,85 @@ def set_pad_mode(network, mode="circular"): _, _, inputs = self.get_dummy_inputs() _ = pipe(**inputs).images + + def test_if_lora_alpha_is_correctly_parsed(self): + lora_alpha = 8 + + scheduler_class = FlowMatchEulerDiscreteScheduler if self.uses_flow_matching else DDIMScheduler + components, text_lora_config, denoiser_lora_config = self.get_dummy_components( + alpha=lora_alpha, scheduler_cls=scheduler_class + ) + + pipe = self.pipeline_class(**components) + pipe = 
pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.text_encoder.add_adapter(text_lora_config) + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config) + else: + pipe.transformer.add_adapter(denoiser_lora_config) + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + + # Inference works? + _ = pipe(**inputs, generator=torch.manual_seed(0)).images + + with tempfile.TemporaryDirectory() as tmpdirname: + denoiser = pipe.unet if self.unet_kwargs else pipe.transformer + denoiser_state_dict = get_peft_model_state_dict(denoiser) + denoiser_lora_config = denoiser.peft_config["default"] + + text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) + text_encoder_lora_config = pipe.text_encoder.peft_config["default"] + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"] + + saving_kwargs = { + "save_directory": tmpdirname, + "text_encoder_lora_layers": text_encoder_state_dict, + "text_encoder_lora_config": text_encoder_lora_config, + } + if self.unet_kwargs is not None: + saving_kwargs.update( + {"unet_lora_layers": denoiser_state_dict, "unet_lora_config": denoiser_lora_config} + ) + else: + saving_kwargs.update( + {"transformer_lora_layers": denoiser_state_dict, "transformer_lora_config": denoiser_lora_config} + ) + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + saving_kwargs.update( + { + "text_encoder_2_lora_layers": text_encoder_2_state_dict, + "text_encoder_2_lora_config": text_encoder_2_lora_config, + } + ) + + self.pipeline_class.save_lora_weights(**saving_kwargs) + loaded_pipe = self.pipeline_class(**components) + loaded_pipe.load_lora_weights(tmpdirname) + + # Inference works? + _ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images + + denoiser_loaded = pipe.unet if self.unet_kwargs is not None else pipe.transformer + assert ( + denoiser_loaded.peft_config["default"].lora_alpha == lora_alpha + ), "LoRA alpha not correctly loaded for UNet." + assert ( + loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha + ), "LoRA alpha not correctly loaded for text encoder." + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + assert ( + loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha + ), "LoRA alpha not correctly loaded for text encoder 2."
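
For reference, below is a minimal usage sketch of the round trip this patch enables: the PEFT `LoraConfig` is serialized into the safetensors metadata at save time, and `load_lora_weights` reads it back so that a `lora_alpha` different from the rank survives reloading. The checkpoint ID, output directory, and adapter hyperparameters are illustrative placeholders and not part of the patch; only the `*_lora_config` / `return_metadata` arguments and the returned `metadata` dict come from the diff above.

    # Hedged sketch: assumes the changes in this diff plus a PEFT-enabled diffusers install.
    from peft import LoraConfig, get_peft_model_state_dict
    from diffusers import StableDiffusionPipeline

    model_id = "runwayml/stable-diffusion-v1-5"  # placeholder checkpoint
    pipe = StableDiffusionPipeline.from_pretrained(model_id)

    # lora_alpha intentionally differs from r; without the saved config it would
    # fall back to the rank when the adapter is reloaded.
    unet_cfg = LoraConfig(
        r=4, lora_alpha=8, init_lora_weights=False,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    )
    te_cfg = LoraConfig(
        r=4, lora_alpha=8, init_lora_weights=False,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    )
    pipe.unet.add_adapter(unet_cfg)
    pipe.text_encoder.add_adapter(te_cfg)

    # New in this patch: pass the configs so they are JSON-encoded (keyed by the
    # component prefix) into the safetensors header via pack_metadata/write_lora_layers.
    StableDiffusionPipeline.save_lora_weights(
        save_directory="sd-lora-with-metadata",
        unet_lora_layers=get_peft_model_state_dict(pipe.unet),
        text_encoder_lora_layers=get_peft_model_state_dict(pipe.text_encoder),
        unet_lora_config=pipe.unet.peft_config["default"],
        text_encoder_lora_config=pipe.text_encoder.peft_config["default"],
    )

    # load_lora_weights now fetches the metadata internally (return_metadata=True),
    # so the reloaded adapter keeps lora_alpha=8 instead of defaulting to the rank.
    fresh = StableDiffusionPipeline.from_pretrained(model_id)
    fresh.load_lora_weights("sd-lora-with-metadata")
    assert fresh.unet.peft_config["default"].lora_alpha == 8

    # The metadata can also be inspected directly:
    _, _, metadata = StableDiffusionPipeline.lora_state_dict(
        "sd-lora-with-metadata", return_metadata=True
    )

Note that the metadata can only be carried with `safe_serialization=True`; `write_lora_layers` raises if a non-empty metadata dict is combined with pickle serialization, since plain `torch.save` has no header to store it in.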