From ff3d3808242cc546464c0ce68e4251b633829973 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 17:15:05 +0530 Subject: [PATCH 01/29] fix: parse lora_alpha correctly --- .../dreambooth/train_dreambooth_lora_sdxl.py | 18 ++++++++++ src/diffusers/loaders/lora.py | 34 +++++++++++++++---- src/diffusers/utils/peft_utils.py | 14 ++++++-- 3 files changed, 57 insertions(+), 9 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index c8a9a6ad4812..d188eb198073 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1033,13 +1033,20 @@ def save_model_hook(models, weights, output_dir): text_encoder_one_lora_layers_to_save = None text_encoder_two_lora_layers_to_save = None + unet_lora_config = None + text_encoder_one_lora_config = None + text_encoder_two_lora_config = None + for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): unet_lora_layers_to_save = get_peft_model_state_dict(model) + unet_lora_config = model.peft_config elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + text_encoder_one_lora_config = model.peft_config elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) + text_encoder_two_lora_config = model.peft_config else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1051,6 +1058,9 @@ def save_model_hook(models, weights, output_dir): unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, + unet_lora_config=unet_lora_config, + text_encoder_lora_config=text_encoder_one_lora_config, + text_encoder_2_lora_config=text_encoder_two_lora_config, ) def load_model_hook(models, input_dir): @@ -1616,21 +1626,29 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) unet_lora_layers = get_peft_model_state_dict(unet) + unet_lora_config = unet.peft_config if args.train_text_encoder: text_encoder_one = accelerator.unwrap_model(text_encoder_one) text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) text_encoder_two = accelerator.unwrap_model(text_encoder_two) text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32)) + text_encoder_one_lora_config = text_encoder_one.peft_config + text_encoder_two_lora_config = text_encoder_two.peft_config else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None + text_encoder_one_lora_config = None + text_encoder_two_lora_config = None StableDiffusionXLPipeline.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, + unet_lora_config=unet_lora_config, + text_encoder_lora_config=text_encoder_one_lora_config, + text_encoder_2_lora_config=text_encoder_two_lora_config, ) # Final inference diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 3955fc2a1395..1bd24bed9506 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -778,6 +778,8 @@ def save_lora_weights( save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, 
Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + unet_lora_config=None, + text_encoder_lora_config=None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -805,21 +807,29 @@ def save_lora_weights( safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ + if not USE_PEFT_BACKEND: + if unet_lora_config or text_encoder_lora_config: + raise ValueError( + "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`." + ) + state_dict = {} - def pack_weights(layers, prefix): + def pack_weights(layers, prefix, config=None): layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + if config is not None: + layers_state_dict[f"{prefix}_lora_config"] = config return layers_state_dict if not (unet_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.") if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) + state_dict.update(pack_weights(unet_lora_layers, "unet", config=unet_lora_config)) if text_encoder_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder", config=text_encoder_lora_config)) # Save the model cls.write_lora_layers( @@ -1336,6 +1346,9 @@ def save_lora_weights( unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + unet_lora_config=None, + text_encoder_lora_config=None, + text_encoder_2_lora_config=None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -1363,11 +1376,18 @@ def save_lora_weights( safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ + if not USE_PEFT_BACKEND: + if unet_lora_config or text_encoder_lora_config or text_encoder_2_lora_config: + raise ValueError( + "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`." 
+ ) state_dict = {} - def pack_weights(layers, prefix): + def pack_weights(layers, prefix, config=None): layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + if config is not None: + layers_state_dict[f"{prefix}_lora_config"] = config return layers_state_dict if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): @@ -1376,11 +1396,11 @@ def pack_weights(layers, prefix): ) if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) + state_dict.update(pack_weights(unet_lora_layers, "unet", unet_lora_config)) if text_encoder_lora_layers and text_encoder_2_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) - state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder", text_encoder_lora_config)) + state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2", text_encoder_2_lora_config)) cls.write_lora_layers( state_dict=state_dict, diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index c77efc28f62a..c29bc3d199c1 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -143,6 +143,14 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] + # Try to retrive config. + alpha_retrieved_from_state_dict = False + config_key = list(filter(lambda x: "config" in x, list(peft_state_dict.keys()))) + if len(config_key) > 0: + config_key = config_key[0] + lora_alpha = peft_state_dict[config_key].lora_alpha + alpha_retrieved_from_state_dict = True + if len(set(rank_dict.values())) > 1: # get the rank occuring the most number of times r = collections.Counter(rank_dict.values()).most_common()[0][0] @@ -154,7 +162,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True if network_alpha_dict is not None and len(network_alpha_dict) > 0: if len(set(network_alpha_dict.values())) > 1: # get the alpha occuring the most number of times - lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] + if not alpha_retrieved_from_state_dict: + lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern` alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) @@ -166,7 +175,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True else: alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} else: - lora_alpha = set(network_alpha_dict.values()).pop() + if not alpha_retrieved_from_state_dict: + lora_alpha = set(network_alpha_dict.values()).pop() # layer names without the Diffusers specific target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) From 79b16373b706c2a7f9131723dbd499bbae122878 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 17:50:00 +0530 Subject: [PATCH 02/29] fix --- src/diffusers/loaders/lora.py | 99 ++++++++++++++++++++++++------- src/diffusers/utils/peft_utils.py | 16 +++-- 2 files changed, 84 insertions(+), 31 deletions(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 
1bd24bed9506..e70da7da9ee1 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import os from contextlib import nullcontext from typing import Callable, Dict, List, Optional, Union @@ -102,7 +103,7 @@ def load_lora_weights( `default_{i}` where i is the total number of adapters being loaded. """ # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -113,6 +114,7 @@ def load_lora_weights( self.load_lora_into_unet( state_dict, network_alphas=network_alphas, + config=metadata, unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, @@ -124,6 +126,7 @@ def load_lora_weights( text_encoder=getattr(self, self.text_encoder_name) if not hasattr(self, "text_encoder") else self.text_encoder, + config=metadata, lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, @@ -218,6 +221,7 @@ def lora_state_dict( } model_file = None + metadata = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( @@ -245,6 +249,8 @@ def lora_state_dict( user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") + with safetensors.safe_open(model_file, framework="pt", device="cpu") as f: + metadata = f.metadata() except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e @@ -291,7 +297,7 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict) - return state_dict, network_alphas + return state_dict, network_alphas, metadata @classmethod def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"): @@ -362,7 +368,7 @@ def _optionally_disable_offloading(cls, _pipeline): @classmethod def load_lora_into_unet( - cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None + cls, state_dict, network_alphas, unet, config=None, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None ): """ This will load the LoRA layers specified in `state_dict` into `unet`. 
@@ -435,7 +441,9 @@ def load_lora_into_unet( if "lora_B" in key: rank[key] = val.shape[1] - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + if config is not None and len(config) > 0: + config = config["unet"] + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True) lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -476,6 +484,7 @@ def load_lora_into_text_encoder( network_alphas, text_encoder, prefix=None, + config=None, lora_scale=1.0, low_cpu_mem_usage=None, adapter_name=None, @@ -567,10 +576,11 @@ def load_lora_into_text_encoder( if USE_PEFT_BACKEND: from peft import LoraConfig + if config is not None and len(config) > 0: + config = config[prefix] lora_config_kwargs = get_peft_kwargs( - rank, network_alphas, text_encoder_lora_state_dict, is_unet=False + rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False ) - lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -807,10 +817,10 @@ def save_lora_weights( safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ - if not USE_PEFT_BACKEND: + if not USE_PEFT_BACKEND and not safe_serialization: if unet_lora_config or text_encoder_lora_config: raise ValueError( - "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`." + "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`. It also requires `safe_serialization` to be set to True." ) state_dict = {} @@ -849,7 +859,11 @@ def write_lora_layers( weight_name: str, save_function: Callable, safe_serialization: bool, + metadata=None, ): + if not safe_serialization and len(metadata) > 0: + raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.") + if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return @@ -857,8 +871,10 @@ def write_lora_layers( if save_function is None: if safe_serialization: - def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + def save_function(weights, filename, metadata): + if metadata is None: + metadata = {"format": "pt"} + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: save_function = torch.save @@ -871,7 +887,10 @@ def save_function(weights, filename): else: weight_name = LORA_WEIGHT_NAME - save_function(state_dict, os.path.join(save_directory, weight_name)) + if save_function != torch.save: + save_function(state_dict, os.path.join(save_directory, weight_name), metadata) + else: + save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") def unload_lora_weights(self): @@ -1303,7 +1322,7 @@ def load_lora_weights( # pipeline. # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict, network_alphas = self.lora_state_dict( + state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, **kwargs, @@ -1313,7 +1332,12 @@ def load_lora_weights( raise ValueError("Invalid LoRA checkpoint.") self.load_lora_into_unet( - state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self + state_dict, + network_alphas=network_alphas, + unet=self.unet, + config=metadata, + adapter_name=adapter_name, + _pipeline=self, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: @@ -1321,6 +1345,7 @@ def load_lora_weights( text_encoder_state_dict, network_alphas=network_alphas, text_encoder=self.text_encoder, + config=metadata, prefix="text_encoder", lora_scale=self.lora_scale, adapter_name=adapter_name, @@ -1333,6 +1358,7 @@ def load_lora_weights( text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=self.text_encoder_2, + config=metadata, prefix="text_encoder_2", lora_scale=self.lora_scale, adapter_name=adapter_name, @@ -1381,26 +1407,54 @@ def save_lora_weights( raise ValueError( "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`." ) + + if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): + raise ValueError( + "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." + ) + state_dict = {} + metadata = {} def pack_weights(layers, prefix, config=None): + local_metadata = None layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + if config is not None: - layers_state_dict[f"{prefix}_lora_config"] = config - return layers_state_dict + if not isinstance(config, dict): + config = config.to_dict() + local_metadata = {"library": "peft", "has_config": "true"} + for key, value in config.items(): + if isinstance(value, set): + config[key] = list(value) + config_as_string = json.dumps(config, indent=2, sort_keys=True) + local_metadata[prefix] = config_as_string - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." 
- ) + return layers_state_dict, local_metadata if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet", unet_lora_config)) + unet_state_dict, unet_metadata = pack_weights(unet_lora_layers, "unet", unet_lora_config) + state_dict.update(unet_state_dict) + if unet_metadata is not None: + metadata.update(unet_metadata) if text_encoder_lora_layers and text_encoder_2_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder", text_encoder_lora_config)) - state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2", text_encoder_2_lora_config)) + text_encoder_state_dict, text_encoder_metadata = pack_weights( + text_encoder_lora_layers, "text_encoder", text_encoder_lora_config + ) + state_dict.update(text_encoder_state_dict) + + if text_encoder_metadata is not None: + metadata.update(text_encoder_metadata) + + text_encoder_2_state_dict, text_encoder_2_metadata = pack_weights( + text_encoder_2_lora_layers, "text_encoder_2", text_encoder_2_lora_config + ) + state_dict.update(text_encoder_2_state_dict) + + if text_encoder_2_metadata is not None: + metadata.update(text_encoder_2_metadata) cls.write_lora_layers( state_dict=state_dict, @@ -1409,6 +1463,7 @@ def pack_weights(layers, prefix, config=None): weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + metadata=metadata, ) def _remove_text_encoder_monkey_patch(self): diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index c29bc3d199c1..cb0539cfc83f 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -138,18 +138,16 @@ def unscale_lora_layers(model, weight: Optional[float] = None): module.set_scale(adapter_name, 1.0) -def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): +def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True): rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] # Try to retrive config. 
- alpha_retrieved_from_state_dict = False - config_key = list(filter(lambda x: "config" in x, list(peft_state_dict.keys()))) - if len(config_key) > 0: - config_key = config_key[0] - lora_alpha = peft_state_dict[config_key].lora_alpha - alpha_retrieved_from_state_dict = True + alpha_retrieved = False + if config is not None: + lora_alpha = config["lora_alpha"] + alpha_retrieved = True if len(set(rank_dict.values())) > 1: # get the rank occuring the most number of times @@ -162,7 +160,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True if network_alpha_dict is not None and len(network_alpha_dict) > 0: if len(set(network_alpha_dict.values())) > 1: # get the alpha occuring the most number of times - if not alpha_retrieved_from_state_dict: + if not alpha_retrieved: lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern` @@ -175,7 +173,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True else: alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} else: - if not alpha_retrieved_from_state_dict: + if not alpha_retrieved: lora_alpha = set(network_alpha_dict.values()).pop() # layer names without the Diffusers specific From 20fac7bc9d1ae1bf2f4b0d563b2a0c1d3aecf86c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 17:54:07 +0530 Subject: [PATCH 03/29] better conditioning --- src/diffusers/loaders/lora.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index e70da7da9ee1..cf122d32aa54 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1407,6 +1407,8 @@ def save_lora_weights( raise ValueError( "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`." 
) + + from peft import LoraConfig if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -1422,7 +1424,7 @@ def pack_weights(layers, prefix, config=None): layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} if config is not None: - if not isinstance(config, dict): + if isinstance(config, LoraConfig): config = config.to_dict() local_metadata = {"library": "peft", "has_config": "true"} for key, value in config.items(): From 981ea8259127aeb2c963c90f74656d7a5b1532b5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 17:56:31 +0530 Subject: [PATCH 04/29] assertion --- src/diffusers/loaders/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index cf122d32aa54..ab1043c605d1 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1426,10 +1426,10 @@ def pack_weights(layers, prefix, config=None): if config is not None: if isinstance(config, LoraConfig): config = config.to_dict() - local_metadata = {"library": "peft", "has_config": "true"} for key, value in config.items(): if isinstance(value, set): config[key] = list(value) + assert isinstance(config, dict) config_as_string = json.dumps(config, indent=2, sort_keys=True) local_metadata[prefix] = config_as_string From cf132fb6b03a5e66bbd90499f0fd7b97feeaf311 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 17:57:51 +0530 Subject: [PATCH 05/29] debug --- src/diffusers/loaders/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index ab1043c605d1..14d6442f48c0 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1429,7 +1429,8 @@ def pack_weights(layers, prefix, config=None): for key, value in config.items(): if isinstance(value, set): config[key] = list(value) - assert isinstance(config, dict) + print(isinstance(config, dict)) + config_as_string = json.dumps(config, indent=2, sort_keys=True) local_metadata[prefix] = config_as_string From e4c00bc5c2a7e15f57a0294dd40fc44664e79185 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:03:21 +0530 Subject: [PATCH 06/29] debug --- src/diffusers/loaders/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 14d6442f48c0..3cb3f4f4e82b 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1424,11 +1424,11 @@ def pack_weights(layers, prefix, config=None): layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} if config is not None: - if isinstance(config, LoraConfig): - config = config.to_dict() + config = config.to_dict() for key, value in config.items(): if isinstance(value, set): config[key] = list(value) + print(isinstance(config, dict)) config_as_string = json.dumps(config, indent=2, sort_keys=True) From 3b27b230826761d2c179c70bc8551448fe89fc33 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:05:35 +0530 Subject: [PATCH 07/29] dehug --- src/diffusers/loaders/lora.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 3cb3f4f4e82b..11a26353ba31 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1410,6 +1410,8 @@ def save_lora_weights( from peft import 
LoraConfig + print(isinstance(unet_lora_config, LoraConfig)) + if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." From 0d08249c9f5dfce5f6aca50e4b91500d9fdda47f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:16:35 +0530 Subject: [PATCH 08/29] ifx? --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 12 ++++++------ src/diffusers/loaders/lora.py | 5 ++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index d188eb198073..ed9537a114ae 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1040,13 +1040,13 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): unet_lora_layers_to_save = get_peft_model_state_dict(model) - unet_lora_config = model.peft_config + unet_lora_config = model.peft_config["default"] elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) - text_encoder_one_lora_config = model.peft_config + text_encoder_one_lora_config = model.peft_config["default"] elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) - text_encoder_two_lora_config = model.peft_config + text_encoder_two_lora_config = model.peft_config["default"] else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1626,15 +1626,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) unet_lora_layers = get_peft_model_state_dict(unet) - unet_lora_config = unet.peft_config + unet_lora_config = unet.peft_config["default"] if args.train_text_encoder: text_encoder_one = accelerator.unwrap_model(text_encoder_one) text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) text_encoder_two = accelerator.unwrap_model(text_encoder_two) text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32)) - text_encoder_one_lora_config = text_encoder_one.peft_config - text_encoder_two_lora_config = text_encoder_two.peft_config + text_encoder_one_lora_config = text_encoder_one.peft_config["default"] + text_encoder_two_lora_config = text_encoder_two.peft_config["default"] else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 11a26353ba31..2ce73ff7f36b 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1410,8 +1410,6 @@ def save_lora_weights( from peft import LoraConfig - print(isinstance(unet_lora_config, LoraConfig)) - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." 
@@ -1426,7 +1424,8 @@ def pack_weights(layers, prefix, config=None): layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} if config is not None: - config = config.to_dict() + if isinstance(config, LoraConfig): + config = config.to_dict() for key, value in config.items(): if isinstance(value, set): config[key] = list(value) From 41b9cd8787bdf384af7e7ac4f78b409c7514d722 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:17:22 +0530 Subject: [PATCH 09/29] fix? --- src/diffusers/loaders/lora.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 2ce73ff7f36b..2713ae7555a5 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1430,8 +1430,6 @@ def pack_weights(layers, prefix, config=None): if isinstance(value, set): config[key] = list(value) - print(isinstance(config, dict)) - config_as_string = json.dumps(config, indent=2, sort_keys=True) local_metadata[prefix] = config_as_string From b868e8a2fcd0abde486485938d98e71a32c5f5e7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:20:20 +0530 Subject: [PATCH 10/29] ifx --- src/diffusers/loaders/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 2713ae7555a5..a0925e731396 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1419,7 +1419,7 @@ def save_lora_weights( metadata = {} def pack_weights(layers, prefix, config=None): - local_metadata = None + local_metadata = {} layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} From c341111d694c8d45dfc19057f7ab2c22aca2d601 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:22:20 +0530 Subject: [PATCH 11/29] ifx --- src/diffusers/loaders/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index a0925e731396..4f0db508f3b6 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1327,6 +1327,7 @@ def load_lora_weights( unet_config=self.unet.config, **kwargs, ) + print(f"Metadata: {metadata}") is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") From a2792cd942d690cf0213b6079d6d0d3a120ce408 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:24:09 +0530 Subject: [PATCH 12/29] unwrap --- src/diffusers/loaders/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 4f0db508f3b6..ea26e8ec016b 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -250,7 +250,7 @@ def lora_state_dict( ) state_dict = safetensors.torch.load_file(model_file, device="cpu") with safetensors.safe_open(model_file, framework="pt", device="cpu") as f: - metadata = f.metadata() + metadata = json.loads(f.metadata()) except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e From 9ecb271ac8c1cc310ed319c0ac2edd2a5b929211 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:24:15 +0530 Subject: [PATCH 13/29] unwrap --- src/diffusers/loaders/lora.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py 
b/src/diffusers/loaders/lora.py index ea26e8ec016b..92e11da5cb56 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -250,7 +250,9 @@ def lora_state_dict( ) state_dict = safetensors.torch.load_file(model_file, device="cpu") with safetensors.safe_open(model_file, framework="pt", device="cpu") as f: - metadata = json.loads(f.metadata()) + metadata = f.metadata() + if metadata is not None: + metadata = json.loads(metadata) except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e From 32212b6df6d29ee1249de2f8644cb9abb3660c1d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:31:13 +0530 Subject: [PATCH 14/29] json unwrap --- src/diffusers/loaders/lora.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 92e11da5cb56..9299e9bd7d0f 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -251,8 +251,6 @@ def lora_state_dict( state_dict = safetensors.torch.load_file(model_file, device="cpu") with safetensors.safe_open(model_file, framework="pt", device="cpu") as f: metadata = f.metadata() - if metadata is not None: - metadata = json.loads(metadata) except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e @@ -444,7 +442,7 @@ def load_lora_into_unet( rank[key] = val.shape[1] if config is not None and len(config) > 0: - config = config["unet"] + config = json.loads(config["unet"]) lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True) lora_config = LoraConfig(**lora_config_kwargs) @@ -579,7 +577,7 @@ def load_lora_into_text_encoder( from peft import LoraConfig if config is not None and len(config) > 0: - config = config[prefix] + config = json.loads(config[prefix]) lora_config_kwargs = get_peft_kwargs( rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False ) @@ -1410,7 +1408,7 @@ def save_lora_weights( raise ValueError( "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`." 
) - + from peft import LoraConfig if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): @@ -1432,7 +1430,7 @@ def pack_weights(layers, prefix, config=None): for key, value in config.items(): if isinstance(value, set): config[key] = list(value) - + config_as_string = json.dumps(config, indent=2, sort_keys=True) local_metadata[prefix] = config_as_string From ed333f06ae93f01261118f5c4be42ff8d2965b02 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:31:52 +0530 Subject: [PATCH 15/29] remove print --- src/diffusers/loaders/lora.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 9299e9bd7d0f..bce37eb00743 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1327,7 +1327,6 @@ def load_lora_weights( unet_config=self.unet.config, **kwargs, ) - print(f"Metadata: {metadata}") is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") From fdb114618d2faf8861e401c39cde38ce566de727 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:33:00 +0530 Subject: [PATCH 16/29] Empty-Commit Co-authored-by: pacman100 <13534540+pacman100@users.noreply.github.com> From bcf0f4a78980f440297a8a862796b0215c5cf071 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 18:44:05 +0530 Subject: [PATCH 17/29] fix --- src/diffusers/loaders/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index bce37eb00743..3da90dcb410a 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -861,7 +861,7 @@ def write_lora_layers( safe_serialization: bool, metadata=None, ): - if not safe_serialization and len(metadata) > 0: + if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0: raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.") if os.path.isfile(save_directory): From 24cb282d367438593f4f8413c903cbc900d46c03 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 21:38:02 +0530 Subject: [PATCH 18/29] fix --- src/diffusers/loaders/lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 3da90dcb410a..9b17fc2f5f5d 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -441,7 +441,7 @@ def load_lora_into_unet( if "lora_B" in key: rank[key] = val.shape[1] - if config is not None and len(config) > 0: + if config is not None and isinstance(config, dict) and len(config) > 0: config = json.loads(config["unet"]) lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True) lora_config = LoraConfig(**lora_config_kwargs) @@ -1407,8 +1407,8 @@ def save_lora_weights( raise ValueError( "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`." ) - - from peft import LoraConfig + else: + from peft import LoraConfig if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( From f4adaae5cb9a21477a4b36e627b160daebd70c62 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 12:49:01 +0530 Subject: [PATCH 19/29] move config related stuff in a separate utility. 
--- src/diffusers/loaders/lora.py | 78 ++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 25 deletions(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 9b17fc2f5f5d..e4663f46c048 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -820,26 +820,51 @@ def save_lora_weights( if not USE_PEFT_BACKEND and not safe_serialization: if unet_lora_config or text_encoder_lora_config: raise ValueError( - "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`. It also requires `safe_serialization` to be set to True." + "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`." ) + else: + from peft import LoraConfig + + if not (unet_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.") state_dict = {} + metadata = {} - def pack_weights(layers, prefix, config=None): + def pack_weights(layers, prefix): layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - if config is not None: - layers_state_dict[f"{prefix}_lora_config"] = config + return layers_state_dict - if not (unet_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.") + def pack_metadata(config, prefix): + local_metadata = {} + if config is not None: + if isinstance(config, LoraConfig): + config = config.to_dict() + for key, value in config.items(): + if isinstance(value, set): + config[key] = list(value) + + config_as_string = json.dumps(config, indent=2, sort_keys=True) + local_metadata[prefix] = config_as_string + return local_metadata if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet", config=unet_lora_config)) + prefix = "unet" + unet_state_dict = pack_weights(unet_lora_layers, prefix) + state_dict.update(unet_state_dict) + if unet_lora_config is not None: + unet_metadata = pack_metadata(unet_lora_config, prefix) + metadata.update(unet_metadata) if text_encoder_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder", config=text_encoder_lora_config)) + prefix = "text_encoder" + text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder") + state_dict.update(text_encoder_state_dict) + if text_encoder_lora_config is not None: + text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix) + metadata.update(text_encoder_metadata) # Save the model cls.write_lora_layers( @@ -849,6 +874,7 @@ def pack_weights(layers, prefix, config=None): weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + metadata=metadata, ) @staticmethod @@ -1402,7 +1428,7 @@ def save_lora_weights( safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ - if not USE_PEFT_BACKEND: + if not USE_PEFT_BACKEND and not safe_serialization: if unet_lora_config or text_encoder_lora_config or text_encoder_2_lora_config: raise ValueError( "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`." 
@@ -1418,11 +1444,14 @@ def save_lora_weights( state_dict = {} metadata = {} - def pack_weights(layers, prefix, config=None): - local_metadata = {} + def pack_weights(layers, prefix): layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + def pack_metadata(config, prefix): + local_metadata = {} if config is not None: if isinstance(config, LoraConfig): config = config.to_dict() @@ -1432,30 +1461,29 @@ def pack_weights(layers, prefix, config=None): config_as_string = json.dumps(config, indent=2, sort_keys=True) local_metadata[prefix] = config_as_string - - return layers_state_dict, local_metadata + return local_metadata if unet_lora_layers: - unet_state_dict, unet_metadata = pack_weights(unet_lora_layers, "unet", unet_lora_config) + prefix = "unet" + unet_state_dict = pack_weights(unet_lora_layers, prefix) state_dict.update(unet_state_dict) - if unet_metadata is not None: + if unet_lora_config is not None: + unet_metadata = pack_metadata(unet_lora_config, prefix) metadata.update(unet_metadata) if text_encoder_lora_layers and text_encoder_2_lora_layers: - text_encoder_state_dict, text_encoder_metadata = pack_weights( - text_encoder_lora_layers, "text_encoder", text_encoder_lora_config - ) + prefix = "text_encoder" + text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder") state_dict.update(text_encoder_state_dict) - - if text_encoder_metadata is not None: + if text_encoder_lora_config is not None: + text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix) metadata.update(text_encoder_metadata) - text_encoder_2_state_dict, text_encoder_2_metadata = pack_weights( - text_encoder_2_lora_layers, "text_encoder_2", text_encoder_2_lora_config - ) + prefix = "text_encoder_2" + text_encoder_2_state_dict = pack_weights(text_encoder_2_lora_layers, prefix) state_dict.update(text_encoder_2_state_dict) - - if text_encoder_2_metadata is not None: + if text_encoder_2_lora_config is not None: + text_encoder_2_metadata = pack_metadata(text_encoder_2_lora_config, prefix) metadata.update(text_encoder_2_metadata) cls.write_lora_layers( From 57a16f35eef00cee586e8001bab6dc2e49cb6b5f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 12:59:32 +0530 Subject: [PATCH 20/29] fix: import error --- src/diffusers/loaders/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index e4663f46c048..5cdd75b213a8 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -822,7 +822,7 @@ def save_lora_weights( raise ValueError( "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`." ) - else: + elif USE_PEFT_BACKEND and safe_serialization: from peft import LoraConfig if not (unet_lora_layers or text_encoder_lora_layers): @@ -1433,7 +1433,7 @@ def save_lora_weights( raise ValueError( "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`." 
) - else: + elif USE_PEFT_BACKEND and safe_serialization: from peft import LoraConfig if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): From d24e7d3ea9bf63682e152ef53ef532f8449b9428 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 14:14:33 +0530 Subject: [PATCH 21/29] debug --- src/diffusers/loaders/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 5cdd75b213a8..40e6dc83c20d 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -443,6 +443,7 @@ def load_lora_into_unet( if config is not None and isinstance(config, dict) and len(config) > 0: config = json.loads(config["unet"]) + print(f"From LoRA loading: {config} ") lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True) lora_config = LoraConfig(**lora_config_kwargs) From 09618d09a6aaf99f1e2e09ed8e90a5668035ac14 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 14:24:28 +0530 Subject: [PATCH 22/29] remove print --- src/diffusers/loaders/lora.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 40e6dc83c20d..5cdd75b213a8 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -443,7 +443,6 @@ def load_lora_into_unet( if config is not None and isinstance(config, dict) and len(config) > 0: config = json.loads(config["unet"]) - print(f"From LoRA loading: {config} ") lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True) lora_config = LoraConfig(**lora_config_kwargs) From ec9df6fc480d5d1f9e2bc860610566774fd4b5cb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 17:14:39 +0530 Subject: [PATCH 23/29] simplify condition. --- src/diffusers/utils/peft_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index cb0539cfc83f..2c0a6c7670f7 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -172,9 +172,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, } else: alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} - else: - if not alpha_retrieved: - lora_alpha = set(network_alpha_dict.values()).pop() + elif not alpha_retrieved: + lora_alpha = set(network_alpha_dict.values()).pop() # layer names without the Diffusers specific target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) From 16ac1b2f4fa75a608448323f87df1cbc3d950409 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 17:46:48 +0530 Subject: [PATCH 24/29] propagate changes to sd dreambooth lora. 
--- examples/dreambooth/train_dreambooth_lora.py | 20 ++++++++++++++++--- .../dreambooth/train_dreambooth_lora_sdxl.py | 10 ++++++---- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 60213dd75685..7c7a4ace4309 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -880,11 +880,16 @@ def save_model_hook(models, weights, output_dir): unet_lora_layers_to_save = None text_encoder_lora_layers_to_save = None + unet_lora_config = None + text_encoder_lora_config = None + for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): unet_lora_layers_to_save = get_peft_model_state_dict(model) + unet_lora_config = model.peft_config["default"] elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): text_encoder_lora_layers_to_save = get_peft_model_state_dict(model) + text_encoder_lora_config = model.peft_config["default"] else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -895,6 +900,8 @@ def save_model_hook(models, weights, output_dir): output_dir, unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_lora_layers_to_save, + unet_lora_config=unet_lora_config, + text_encoder_lora_config=text_encoder_lora_config, ) def load_model_hook(models, input_dir): @@ -911,10 +918,12 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) + lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet( + lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata + ) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_ + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_, config=metadata ) accelerator.register_save_state_pre_hook(save_model_hook) @@ -1315,17 +1324,22 @@ def compute_text_embeddings(prompt): unet = unet.to(torch.float32) unet_lora_state_dict = get_peft_model_state_dict(unet) + unet_lora_config = unet.peft_config["default"] if args.train_text_encoder: text_encoder = accelerator.unwrap_model(text_encoder) text_encoder_state_dict = get_peft_model_state_dict(text_encoder) + text_encoder_lora_config = text_encoder.peft_config["default"] else: text_encoder_state_dict = None + text_encoder_lora_config = None LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=text_encoder_state_dict, + unet_lora_config=unet_lora_config, + text_encoder_lora_config=text_encoder_lora_config, ) # Final inference diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index ed9537a114ae..d8fd11e34c43 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1080,17 +1080,19 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) + lora_state_dict, 
network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet( + lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata + ) text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ + text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata ) text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k} LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ + text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata ) accelerator.register_save_state_pre_hook(save_model_hook) From ece6d89cf22b9b7c615c29b2cc7a46bcd977ddb8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 17:49:42 +0530 Subject: [PATCH 25/29] propagate to sd t2i lora fine-tuning --- examples/text_to_image/train_text_to_image_lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index b63500f906a8..fc55815228e5 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -833,10 +833,12 @@ def collate_fn(examples): accelerator.save_state(save_path) unet_lora_state_dict = get_peft_model_state_dict(unet) + unet_lora_config = unet.peft_config["default"] StableDiffusionPipeline.save_lora_weights( save_directory=save_path, unet_lora_layers=unet_lora_state_dict, + unet_lora_config=unet_lora_config, safe_serialization=True, ) @@ -898,10 +900,12 @@ def collate_fn(examples): unet = unet.to(torch.float32) unet_lora_state_dict = get_peft_model_state_dict(unet) + unet_lora_config = unet.peft_config["default"] StableDiffusionPipeline.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, safe_serialization=True, + unet_lora_config=unet_lora_config, ) if args.push_to_hub: From 8c98a187c78904ce132f31704badc0f4cce0c51a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 17:53:54 +0530 Subject: [PATCH 26/29] propagate to sdxl t2i lora fine-tuning --- .../train_text_to_image_lora_sdxl.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 2e70c77e860e..20a6731e716d 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -682,13 +682,20 @@ def save_model_hook(models, weights, output_dir): text_encoder_one_lora_layers_to_save = None text_encoder_two_lora_layers_to_save = None + unet_lora_config = None + text_encoder_one_lora_config = None + text_encoder_two_lora_config = None + for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): unet_lora_layers_to_save = get_peft_model_state_dict(model) + unet_lora_config = model.peft_config["default"] elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + text_encoder_one_lora_config = model.peft_config["default"] elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): 
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) + text_encoder_two_lora_config = model.peft_config["default"] else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -700,6 +707,9 @@ def save_model_hook(models, weights, output_dir): unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, + unet_lora_config=unet_lora_config, + text_encoder_lora_config=text_encoder_one_lora_config, + text_encoder_2_lora_config=text_encoder_two_lora_config, ) def load_model_hook(models, input_dir): @@ -719,17 +729,19 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) + lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet( + lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata + ) text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ + text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata ) text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k} LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ + text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata ) accelerator.register_save_state_pre_hook(save_model_hook) @@ -1194,6 +1206,7 @@ def compute_time_ids(original_size, crops_coords_top_left): if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) unet_lora_state_dict = get_peft_model_state_dict(unet) + unet_lora_config = unet.peft_config["default"] if args.train_text_encoder: text_encoder_one = accelerator.unwrap_model(text_encoder_one) @@ -1201,15 +1214,23 @@ def compute_time_ids(original_size, crops_coords_top_left): text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one) text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two) + + text_encoder_one_lora_config = text_encoder_one.peft_config["default"] + text_encoder_two_lora_config = text_encoder_two.peft_config["default"] else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None + text_encoder_one_lora_config = None + text_encoder_two_lora_config = None StableDiffusionXLPipeline.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, + unet_lora_config=unet_lora_config, + text_encoder_lora_config=text_encoder_one_lora_config, + text_encoder_2_lora_config=text_encoder_two_lora_config, ) del unet From 765fef71347f49fbb783a98ac1bad3a65cea9367 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 17:58:26 +0530 Subject: [PATCH 27/29] add: doc strings. 
--- src/diffusers/loaders/lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 5cdd75b213a8..231ed3ade88f 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -382,6 +382,8 @@ def load_lora_into_unet( See `LoRALinearLayer` for more details. unet (`UNet2DConditionModel`): The UNet model to load the LoRA layers into. + config (`Dict`): + LoRA configuration parsed from the state dict. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading only loading the pretrained weights and not initializing the weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. @@ -503,6 +505,8 @@ def load_lora_into_text_encoder( The text encoder model to load the LoRA layers into. prefix (`str`): Expected prefix of the `text_encoder` in the `state_dict`. + config (`Dict`): + LoRA configuration parsed from the state dict. lora_scale (`float`): How much to scale the output of the lora linear layer before it is added with the output of the regular lora layer. From f145d48ed7ad0d1eb12d54fba607646372b12739 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 18:23:07 +0530 Subject: [PATCH 28/29] add test --- tests/lora/test_lora_layers_peft.py | 72 +++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 6d3ac8b4592a..ed3784d66df2 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -107,8 +107,9 @@ class PeftLoraLoaderMixinTests: unet_kwargs = None vae_kwargs = None - def get_dummy_components(self, scheduler_cls=None): + def get_dummy_components(self, scheduler_cls=None, lora_alpha=None): scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler + lora_alpha = 4 if lora_alpha is None else lora_alpha torch.manual_seed(0) unet = UNet2DConditionModel(**self.unet_kwargs) @@ -123,11 +124,14 @@ def get_dummy_components(self, scheduler_cls=None): tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") text_lora_config = LoraConfig( - r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False + r=4, + lora_alpha=lora_alpha, + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + init_lora_weights=False, ) unet_lora_config = LoraConfig( - r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False + r=4, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False ) unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) @@ -714,6 +718,68 @@ def test_simple_inference_with_text_unet_lora_unloaded(self): "Fused lora should change the output", ) + def test_if_lora_alpha_is_correctly_parsed(self): + lora_alpha = 8 + + components, _, text_lora_config, unet_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.unet.add_adapter(unet_lora_config) + pipe.text_encoder.add_adapter(text_lora_config) + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config) + + # Inference works?
+ _ = pipe(**inputs, generator=torch.manual_seed(0)).images + + with tempfile.TemporaryDirectory() as tmpdirname: + unet_state_dict = get_peft_model_state_dict(pipe.unet) + unet_lora_config = pipe.unet.peft_config["default"] + + text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) + text_encoder_lora_config = pipe.text_encoder.peft_config["default"] + + if self.has_two_text_encoders: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"] + + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=unet_state_dict, + text_encoder_lora_layers=text_encoder_state_dict, + text_encoder_2_lora_layers=text_encoder_2_state_dict, + unet_lora_config=unet_lora_config, + text_encoder_lora_config=text_encoder_lora_config, + text_encoder_2_lora_config=text_encoder_2_lora_config, + ) + else: + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=unet_state_dict, + text_encoder_lora_layers=text_encoder_state_dict, + unet_lora_config=unet_lora_config, + text_encoder_lora_config=text_encoder_lora_config, + ) + loaded_pipe = self.pipeline_class(**components) + loaded_pipe.load_lora_weights(tmpdirname) + + # Inference works? + _ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images + + assert ( + loaded_pipe.unet.peft_config["default"]["lora_alpha"] == lora_alpha + ), "LoRA alpha not correctly loaded for UNet." + assert ( + loaded_pipe.text_encoder.peft_config["default"]["lora_alpha"] == lora_alpha + ), "LoRA alpha not correctly loaded for text encoder." + if self.has_two_text_encoders: + assert ( + loaded_pipe.text_encoder_2.peft_config["default"]["lora_alpha"] == lora_alpha + ), "LoRA alpha not correctly loaded for text encoder 2." + def test_simple_inference_with_text_unet_lora_unfused(self): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights From 5d04eebd1f08ff9ed884c1b94b2f692bdeaf0ffc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Dec 2023 18:32:57 +0530 Subject: [PATCH 29/29] fix attribute access. --- tests/lora/test_lora_layers_peft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index ed3784d66df2..5a47d2e2c3ac 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -770,14 +770,14 @@ def test_if_lora_alpha_is_correctly_parsed(self): _ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images assert ( - loaded_pipe.unet.peft_config["default"]["lora_alpha"] == lora_alpha + loaded_pipe.unet.peft_config["default"].lora_alpha == lora_alpha ), "LoRA alpha not correctly loaded for UNet." assert ( - loaded_pipe.text_encoder.peft_config["default"]["lora_alpha"] == lora_alpha + loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha ), "LoRA alpha not correctly loaded for text encoder." if self.has_two_text_encoders: assert ( - loaded_pipe.text_encoder_2.peft_config["default"]["lora_alpha"] == lora_alpha + loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha ), "LoRA alpha not correctly loaded for text encoder 2." def test_simple_inference_with_text_unet_lora_unfused(self):
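Reviewer note (illustration only, not part of any patch above): the behavior this series is after is that the lora_alpha used at training time survives a save_lora_weights / load_lora_weights round trip instead of being lost and replaced by a default. Below is a minimal sketch of that round trip. It is a hedged example, not code from the patches: it assumes the unet_lora_config argument introduced by this series, relies on peft_config["default"] holding the active LoraConfig (as in the test above), and uses a placeholder base checkpoint ("runwayml/stable-diffusion-v1-5") and output directory ("sd-lora-out").

from diffusers import StableDiffusionPipeline
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

# Placeholder base model; any SD 1.x checkpoint works for this sketch.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Attach a LoRA adapter whose alpha differs from its rank, so a parsing bug would be visible.
unet_lora_config = LoraConfig(
    r=4, lora_alpha=8, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
pipe.unet.add_adapter(unet_lora_config)

# Save the adapter weights together with their LoraConfig (the behavior added by this series).
StableDiffusionPipeline.save_lora_weights(
    save_directory="sd-lora-out",
    unet_lora_layers=get_peft_model_state_dict(pipe.unet),
    unet_lora_config=pipe.unet.peft_config["default"],
    safe_serialization=True,
)

# Reload into a fresh pipeline and check that alpha was parsed from the stored config,
# mirroring the assertion in test_if_lora_alpha_is_correctly_parsed.
fresh = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
fresh.load_lora_weights("sd-lora-out")
assert fresh.unet.peft_config["default"].lora_alpha == 8

Without the stored config, the loader cannot recover the trained alpha from the state dict alone, so the effective scaling lora_alpha / r applied at inference can differ from training; that is exactly the regression the new test guards against.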