From 9b78ab25010031a81fb3bae85ad2fe29cd5bc9a2 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 5 May 2025 17:15:19 +0530
Subject: [PATCH 1/2] use removeprefix to preserve sanity.

---
 src/diffusers/loaders/lora_base.py     | 4 ++--
 src/diffusers/loaders/lora_pipeline.py | 4 ++--
 src/diffusers/loaders/peft.py          | 6 ++++--
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 280a9fa6e73f..c1050994f00d 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -348,7 +348,7 @@ def _load_lora_into_text_encoder(
 
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
-        state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        state_dict = {k.removeprefix(prefix + "."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -374,7 +374,7 @@ def _load_lora_into_text_encoder(
 
         if network_alphas is not None:
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+            network_alphas = {k.removeprefix(prefix + "."): v for k, v in network_alphas.items() if k in alpha_keys}
 
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
 
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 1a6768e70de4..317b6059d834 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2103,7 +2103,7 @@ def _load_norm_into_transformer(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(prefix + ".")] = state_dict.pop(key)
 
         # Find invalid keys
         transformer_state_dict = transformer.state_dict()
@@ -2425,7 +2425,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(prefix + ".")] = state_dict.pop(key)
 
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index bbef5b1628cb..2bb653166dbc 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -230,7 +230,7 @@ def load_lora_adapter(
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
         if prefix is not None:
-            state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            state_dict = {k.removeprefix(prefix + "."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -261,7 +261,9 @@ def load_lora_adapter(
 
             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+                network_alphas = {
+                    k.removeprefix(prefix + "."): v for k, v in network_alphas.items() if k in alpha_keys
+                }
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)

From 90b2abf199fbe706dfe0232b922885e246c209dd Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 6 May 2025 10:37:13 +0530
Subject: [PATCH 2/2] f-string.

---
 src/diffusers/loaders/lora_base.py     | 4 ++--
 src/diffusers/loaders/lora_pipeline.py | 4 ++--
 src/diffusers/loaders/peft.py          | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index c1050994f00d..1377807b3e85 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -348,7 +348,7 @@ def _load_lora_into_text_encoder(
 
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
-        state_dict = {k.removeprefix(prefix + "."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -374,7 +374,7 @@ def _load_lora_into_text_encoder(
 
         if network_alphas is not None:
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-            network_alphas = {k.removeprefix(prefix + "."): v for k, v in network_alphas.items() if k in alpha_keys}
+            network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
 
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
 
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 317b6059d834..810ec8adb1e0 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2103,7 +2103,7 @@ def _load_norm_into_transformer(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key.removeprefix(prefix + ".")] = state_dict.pop(key)
+                state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
 
         # Find invalid keys
         transformer_state_dict = transformer.state_dict()
@@ -2425,7 +2425,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key.removeprefix(prefix + ".")] = state_dict.pop(key)
+                state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
 
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 2bb653166dbc..b7da4fb746f6 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -230,7 +230,7 @@ def load_lora_adapter(
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
         if prefix is not None:
-            state_dict = {k.removeprefix(prefix + "."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -262,7 +262,7 @@ def load_lora_adapter(
             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {
-                    k.removeprefix(prefix + "."): v for k, v in network_alphas.items() if k in alpha_keys
+                    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)