diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 60213dd75685..7c7a4ace4309 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -880,11 +880,16 @@ def save_model_hook(models, weights, output_dir):
             unet_lora_layers_to_save = None
             text_encoder_lora_layers_to_save = None
 
+            unet_lora_config = None
+            text_encoder_lora_config = None
+
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
                     unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                    unet_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
                     text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_lora_config = model.peft_config["default"]
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -895,6 +900,8 @@ def save_model_hook(models, weights, output_dir):
                 output_dir,
                 unet_lora_layers=unet_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+                unet_lora_config=unet_lora_config,
+                text_encoder_lora_config=text_encoder_lora_config,
             )
 
         def load_model_hook(models, input_dir):
@@ -911,10 +918,12 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(
+            lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
+        )
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_, config=metadata
         )
 
     accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1315,17 +1324,22 @@ def compute_text_embeddings(prompt):
         unet = unet.to(torch.float32)
 
         unet_lora_state_dict = get_peft_model_state_dict(unet)
+        unet_lora_config = unet.peft_config["default"]
 
         if args.train_text_encoder:
             text_encoder = accelerator.unwrap_model(text_encoder)
             text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
+            text_encoder_lora_config = text_encoder.peft_config["default"]
         else:
             text_encoder_state_dict = None
+            text_encoder_lora_config = None
 
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_state_dict,
             text_encoder_lora_layers=text_encoder_state_dict,
+            unet_lora_config=unet_lora_config,
+            text_encoder_lora_config=text_encoder_lora_config,
         )
 
     # Final inference
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index c8a9a6ad4812..d8fd11e34c43 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1033,13 +1033,20 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_one_lora_layers_to_save = None
             text_encoder_two_lora_layers_to_save = None
 
+            unet_lora_config = None
+            text_encoder_one_lora_config = None
+            text_encoder_two_lora_config = None
+
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
                     unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                    unet_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_one_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_two_lora_config = model.peft_config["default"]
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1051,6 +1058,9 @@ def save_model_hook(models, weights, output_dir):
                 unet_lora_layers=unet_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                unet_lora_config=unet_lora_config,
+                text_encoder_lora_config=text_encoder_one_lora_config,
+                text_encoder_2_lora_config=text_encoder_two_lora_config,
             )
 
         def load_model_hook(models, input_dir):
@@ -1070,17 +1080,19 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(
+            lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
+        )
 
         text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
         )
 
         text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
         )
 
     accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1616,21 +1628,29 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
         unet_lora_layers = get_peft_model_state_dict(unet)
+        unet_lora_config = unet.peft_config["default"]
 
         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
             text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
+            text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
+            text_encoder_one_lora_config = None
+            text_encoder_two_lora_config = None
 
         StableDiffusionXLPipeline.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+            unet_lora_config=unet_lora_config,
+            text_encoder_lora_config=text_encoder_one_lora_config,
+            text_encoder_2_lora_config=text_encoder_two_lora_config,
         )
 
     # Final inference
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index b63500f906a8..fc55815228e5 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -833,10 +833,12 @@ def collate_fn(examples):
                         accelerator.save_state(save_path)
 
                         unet_lora_state_dict = get_peft_model_state_dict(unet)
+                        unet_lora_config = unet.peft_config["default"]
 
                         StableDiffusionPipeline.save_lora_weights(
                             save_directory=save_path,
                             unet_lora_layers=unet_lora_state_dict,
+                            unet_lora_config=unet_lora_config,
                             safe_serialization=True,
                         )
 
@@ -898,10 +900,12 @@ def collate_fn(examples):
         unet = unet.to(torch.float32)
 
         unet_lora_state_dict = get_peft_model_state_dict(unet)
+        unet_lora_config = unet.peft_config["default"]
         StableDiffusionPipeline.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_state_dict,
             safe_serialization=True,
+            unet_lora_config=unet_lora_config,
         )
 
         if args.push_to_hub:
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 2e70c77e860e..20a6731e716d 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -682,13 +682,20 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_one_lora_layers_to_save = None
             text_encoder_two_lora_layers_to_save = None
 
+            unet_lora_config = None
+            text_encoder_one_lora_config = None
+            text_encoder_two_lora_config = None
+
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
                     unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                    unet_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_one_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_two_lora_config = model.peft_config["default"]
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -700,6 +707,9 @@ def save_model_hook(models, weights, output_dir):
                 unet_lora_layers=unet_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                unet_lora_config=unet_lora_config,
+                text_encoder_lora_config=text_encoder_one_lora_config,
+                text_encoder_2_lora_config=text_encoder_two_lora_config,
             )
 
         def load_model_hook(models, input_dir):
@@ -719,17 +729,19 @@ def load_model_hook(models, input_dir):
            else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(
+            lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
+        )
 
         text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
         )
 
         text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
         )
 
     accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1194,6 +1206,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet_lora_state_dict = get_peft_model_state_dict(unet)
+        unet_lora_config = unet.peft_config["default"]
 
         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
@@ -1201,15 +1214,23 @@ def compute_time_ids(original_size, crops_coords_top_left):
 
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
             text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
+
+            text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
+            text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
+            text_encoder_one_lora_config = None
+            text_encoder_two_lora_config = None
 
         StableDiffusionXLPipeline.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_state_dict,
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+            unet_lora_config=unet_lora_config,
+            text_encoder_lora_config=text_encoder_one_lora_config,
+            text_encoder_2_lora_config=text_encoder_two_lora_config,
         )
 
         del unet
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index c1c3a260ec11..c306a2ea16a0 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union
@@ -103,7 +104,7 @@ def load_lora_weights(
                 `default_{i}` where i is the total number of adapters being loaded.
         """
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -114,6 +115,7 @@ def load_lora_weights(
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,
+            config=metadata,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
@@ -125,6 +127,7 @@ def load_lora_weights(
             text_encoder=getattr(self, self.text_encoder_name)
             if not hasattr(self, "text_encoder")
             else self.text_encoder,
+            config=metadata,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
@@ -219,6 +222,7 @@ def lora_state_dict(
         }
 
         model_file = None
+        metadata = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
             if (use_safetensors and weight_name is None) or (
@@ -248,6 +252,8 @@ def lora_state_dict(
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
+                    metadata = f.metadata()
             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
@@ -294,7 +300,7 @@ def lora_state_dict(
                 state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
 
-        return state_dict, network_alphas
+        return state_dict, network_alphas, metadata
 
     @classmethod
     def _best_guess_weight_name(
@@ -370,7 +376,7 @@ def _optionally_disable_offloading(cls, _pipeline):
 
     @classmethod
     def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
+        cls, state_dict, network_alphas, unet, config=None, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
    ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -384,6 +390,8 @@ def load_lora_into_unet(
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
+            config (`Dict`):
+                LoRA configuration parsed from the state dict.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading only loading the pretrained weights and not initializing the weights. This
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
@@ -443,7 +451,9 @@ def load_lora_into_unet(
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            if config is not None and isinstance(config, dict) and len(config) > 0:
+                config = json.loads(config["unet"])
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
@@ -484,6 +494,7 @@ def load_lora_into_text_encoder(
         network_alphas,
         text_encoder,
         prefix=None,
+        config=None,
         lora_scale=1.0,
         low_cpu_mem_usage=None,
         adapter_name=None,
@@ -502,6 +513,8 @@ def load_lora_into_text_encoder(
                 The text encoder model to load the LoRA layers into.
             prefix (`str`):
                 Expected prefix of the `text_encoder` in the `state_dict`.
+            config (`Dict`):
+                LoRA configuration parsed from the state dict.
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the
                 regular lora layer.
@@ -575,10 +588,11 @@ def load_lora_into_text_encoder(
            if USE_PEFT_BACKEND:
                 from peft import LoraConfig
 
+                if config is not None and len(config) > 0:
+                    config = json.loads(config[prefix])
                 lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
+                    rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
                 )
-
                 lora_config = LoraConfig(**lora_config_kwargs)
 
                 # adapter_name
@@ -786,6 +800,8 @@ def save_lora_weights(
         save_directory: Union[str, os.PathLike],
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_config=None,
+        text_encoder_lora_config=None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -813,21 +829,54 @@ def save_lora_weights(
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
+        if not USE_PEFT_BACKEND and not safe_serialization:
+            if unet_lora_config or text_encoder_lora_config:
+                raise ValueError(
+                    "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`."
+                )
+        elif USE_PEFT_BACKEND and safe_serialization:
+            from peft import LoraConfig
+
+        if not (unet_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
+
         state_dict = {}
+        metadata = {}
 
         def pack_weights(layers, prefix):
             layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
             layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+
             return layers_state_dict
 
-        if not (unet_lora_layers or text_encoder_lora_layers):
-            raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
+        def pack_metadata(config, prefix):
+            local_metadata = {}
+            if config is not None:
+                if isinstance(config, LoraConfig):
+                    config = config.to_dict()
+                for key, value in config.items():
+                    if isinstance(value, set):
+                        config[key] = list(value)
+
+                config_as_string = json.dumps(config, indent=2, sort_keys=True)
+                local_metadata[prefix] = config_as_string
+            return local_metadata
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            prefix = "unet"
+            unet_state_dict = pack_weights(unet_lora_layers, prefix)
+            state_dict.update(unet_state_dict)
+            if unet_lora_config is not None:
+                unet_metadata = pack_metadata(unet_lora_config, prefix)
+                metadata.update(unet_metadata)
 
         if text_encoder_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            prefix = "text_encoder"
+            text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
+            state_dict.update(text_encoder_state_dict)
+            if text_encoder_lora_config is not None:
+                text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
+                metadata.update(text_encoder_metadata)
 
         # Save the model
         cls.write_lora_layers(
@@ -837,6 +886,7 @@ def pack_weights(layers, prefix):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            metadata=metadata,
         )
 
     @staticmethod
@@ -847,7 +897,11 @@ def write_lora_layers(
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        metadata=None,
     ):
+        if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
+            raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
+
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -855,8 +909,10 @@ def write_lora_layers(
         if save_function is None:
             if safe_serialization:
 
-                def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                def save_function(weights, filename, metadata):
+                    if metadata is None:
+                        metadata = {"format": "pt"}
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
             else:
                 save_function = torch.save
@@ -869,7 +925,10 @@ def save_function(weights, filename):
         else:
             weight_name = LORA_WEIGHT_NAME
 
-        save_function(state_dict, os.path.join(save_directory, weight_name))
+        if save_function != torch.save:
+            save_function(state_dict, os.path.join(save_directory, weight_name), metadata)
+        else:
+            save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
     def unload_lora_weights(self):
@@ -1301,7 +1360,7 @@ def load_lora_weights(
         # pipeline.
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(
+        state_dict, network_alphas, metadata = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,
@@ -1311,7 +1370,12 @@ def load_lora_weights(
             raise ValueError("Invalid LoRA checkpoint.")
 
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            config=metadata,
+            adapter_name=adapter_name,
+            _pipeline=self,
         )
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
@@ -1319,6 +1383,7 @@ def load_lora_weights(
                 text_encoder_state_dict,
                 network_alphas=network_alphas,
                 text_encoder=self.text_encoder,
+                config=metadata,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
@@ -1331,6 +1396,7 @@ def load_lora_weights(
                 text_encoder_2_state_dict,
                 network_alphas=network_alphas,
                 text_encoder=self.text_encoder_2,
+                config=metadata,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
@@ -1344,6 +1410,9 @@ def save_lora_weights(
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        unet_lora_config=None,
+        text_encoder_lora_config=None,
+        text_encoder_2_lora_config=None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -1371,24 +1440,63 @@ def save_lora_weights(
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
+        if not USE_PEFT_BACKEND and not safe_serialization:
+            if unet_lora_config or text_encoder_lora_config or text_encoder_2_lora_config:
+                raise ValueError(
+                    "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`."
+                )
+        elif USE_PEFT_BACKEND and safe_serialization:
+            from peft import LoraConfig
+
+        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+            raise ValueError(
+                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+            )
+
         state_dict = {}
+        metadata = {}
 
         def pack_weights(layers, prefix):
             layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
             layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+
             return layers_state_dict
 
-        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
-            raise ValueError(
-                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
-            )
+        def pack_metadata(config, prefix):
+            local_metadata = {}
+            if config is not None:
+                if isinstance(config, LoraConfig):
+                    config = config.to_dict()
+                for key, value in config.items():
+                    if isinstance(value, set):
+                        config[key] = list(value)
+
+                config_as_string = json.dumps(config, indent=2, sort_keys=True)
+                local_metadata[prefix] = config_as_string
+            return local_metadata
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            prefix = "unet"
+            unet_state_dict = pack_weights(unet_lora_layers, prefix)
+            state_dict.update(unet_state_dict)
+            if unet_lora_config is not None:
+                unet_metadata = pack_metadata(unet_lora_config, prefix)
+                metadata.update(unet_metadata)
 
         if text_encoder_lora_layers and text_encoder_2_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
-            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            prefix = "text_encoder"
+            text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
+            state_dict.update(text_encoder_state_dict)
+            if text_encoder_lora_config is not None:
+                text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
+                metadata.update(text_encoder_metadata)
+
+            prefix = "text_encoder_2"
+            text_encoder_2_state_dict = pack_weights(text_encoder_2_lora_layers, prefix)
+            state_dict.update(text_encoder_2_state_dict)
+            if text_encoder_2_lora_config is not None:
+                text_encoder_2_metadata = pack_metadata(text_encoder_2_lora_config, prefix)
+                metadata.update(text_encoder_2_metadata)
 
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -1397,6 +1505,7 @@ def pack_weights(layers, prefix):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            metadata=metadata,
         )
 
     def _remove_text_encoder_monkey_patch(self):
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index c77efc28f62a..2c0a6c7670f7 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -138,11 +138,17 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
             module.set_scale(adapter_name, 1.0)
 
 
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]
 
+    # Try to retrieve config.
+    alpha_retrieved = False
+    if config is not None:
+        lora_alpha = config["lora_alpha"]
+        alpha_retrieved = True
+
     if len(set(rank_dict.values())) > 1:
         # get the rank occuring the most number of times
         r = collections.Counter(rank_dict.values()).most_common()[0][0]
@@ -154,7 +160,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     if network_alpha_dict is not None and len(network_alpha_dict) > 0:
         if len(set(network_alpha_dict.values())) > 1:
             # get the alpha occuring the most number of times
-            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
+            if not alpha_retrieved:
+                lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
 
             # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
             alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
@@ -165,7 +172,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
                 }
             else:
                 alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
-        else:
+        elif not alpha_retrieved:
             lora_alpha = set(network_alpha_dict.values()).pop()
 
     # layer names without the Diffusers specific
diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py
index 6d3ac8b4592a..5a47d2e2c3ac 100644
--- a/tests/lora/test_lora_layers_peft.py
+++ b/tests/lora/test_lora_layers_peft.py
@@ -107,8 +107,9 @@ class PeftLoraLoaderMixinTests:
     unet_kwargs = None
     vae_kwargs = None
 
-    def get_dummy_components(self, scheduler_cls=None):
+    def get_dummy_components(self, scheduler_cls=None, lora_alpha=None):
         scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
+        lora_alpha = 4 if lora_alpha is None else lora_alpha
 
         torch.manual_seed(0)
         unet = UNet2DConditionModel(**self.unet_kwargs)
@@ -123,11 +124,14 @@ def get_dummy_components(self, scheduler_cls=None):
             tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
 
         text_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
+            r=4,
+            lora_alpha=lora_alpha,
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            init_lora_weights=False,
         )
 
         unet_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+            r=4, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )
 
         unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
@@ -714,6 +718,68 @@ def test_simple_inference_with_text_unet_lora_unloaded(self):
                 "Fused lora should change the output",
             )
 
+    def test_if_lora_alpha_is_correctly_parsed(self):
+        lora_alpha = 8
+
+        components, _, text_lora_config, unet_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(self.torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe.unet.add_adapter(unet_lora_config)
+        pipe.text_encoder.add_adapter(text_lora_config)
+        if self.has_two_text_encoders:
+            pipe.text_encoder_2.add_adapter(text_lora_config)
+
+        # Inference works?
+        _ = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            unet_state_dict = get_peft_model_state_dict(pipe.unet)
+            unet_lora_config = pipe.unet.peft_config["default"]
+
+            text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+            text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
+
+            if self.has_two_text_encoders:
+                text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+                text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
+
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname,
+                    unet_lora_layers=unet_state_dict,
+                    text_encoder_lora_layers=text_encoder_state_dict,
+                    text_encoder_2_lora_layers=text_encoder_2_state_dict,
+                    unet_lora_config=unet_lora_config,
+                    text_encoder_lora_config=text_encoder_lora_config,
+                    text_encoder_2_lora_config=text_encoder_2_lora_config,
+                )
+            else:
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname,
+                    unet_lora_layers=unet_state_dict,
+                    text_encoder_lora_layers=text_encoder_state_dict,
+                    unet_lora_config=unet_lora_config,
+                    text_encoder_lora_config=text_encoder_lora_config,
+                )
+            loaded_pipe = self.pipeline_class(**components)
+            loaded_pipe.load_lora_weights(tmpdirname)
+
+        # Inference works?
+        _ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        assert (
+            loaded_pipe.unet.peft_config["default"].lora_alpha == lora_alpha
+        ), "LoRA alpha not correctly loaded for UNet."
+        assert (
+            loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
+        ), "LoRA alpha not correctly loaded for text encoder."
+        if self.has_two_text_encoders:
+            assert (
+                loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
+            ), "LoRA alpha not correctly loaded for text encoder 2."
+
     def test_simple_inference_with_text_unet_lora_unfused(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
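
Taken together, these changes serialize each adapter's `LoraConfig` as JSON in the safetensors metadata (under the `unet`, `text_encoder`, and `text_encoder_2` keys) when `save_lora_weights()` receives the new `*_lora_config` arguments, and `lora_state_dict()` / `load_lora_into_unet()` / `load_lora_into_text_encoder()` feed that metadata back into `get_peft_kwargs()` so `lora_alpha` is restored exactly instead of being inferred from `network_alphas`. A minimal round-trip sketch, not part of the patch; it assumes the `peft` backend is installed and uses a placeholder checkpoint id, rank, and alpha:

# Hypothetical usage sketch of the metadata round trip introduced by this patch
# (placeholder model id and LoRA hyperparameters; requires the `peft` backend).
from diffusers import StableDiffusionPipeline
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Attach a LoRA adapter whose alpha differs from its rank, so recovering `lora_alpha`
# from the stored config (rather than guessing it from `network_alphas`) actually matters.
pipe.unet.add_adapter(
    LoraConfig(r=4, lora_alpha=8, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False)
)

# Save: the adapter's LoraConfig is JSON-encoded into the safetensors metadata under "unet".
StableDiffusionPipeline.save_lora_weights(
    save_directory="lora-with-metadata",
    unet_lora_layers=get_peft_model_state_dict(pipe.unet),
    unet_lora_config=pipe.unet.peft_config["default"],
    safe_serialization=True,
)

# Load: lora_state_dict() now returns the metadata as a third value, and load_lora_weights()
# forwards it so get_peft_kwargs() reads lora_alpha directly from the saved config.
fresh_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
fresh_pipe.load_lora_weights("lora-with-metadata")
assert fresh_pipe.unet.peft_config["default"].lora_alpha == 8

The saved file remains a regular `pytorch_lora_weights.safetensors`; the config travels in the safetensors header metadata rather than as extra tensors, so older loaders that ignore metadata keep working.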