From dc8d68a279fdbc345db7cd27898071e2c23677c9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 11 Feb 2024 11:40:49 +0530 Subject: [PATCH 1/4] fix: bias loading bug --- scripts/convert_diffusers_to_original_stable_diffusion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index cc90a5131732..d1b7df070c43 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -170,7 +170,10 @@ def convert_unet_state_dict(unet_state_dict): def reshape_weight_for_sd(w): # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) + if not w.ndim == 1: + return w.reshape(*w.shape, 1, 1) + else: + return w def convert_vae_state_dict(vae_state_dict): From 0c238d821d2b30a3e1461be9015b6728131d6858 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 11 Feb 2024 16:17:18 +0530 Subject: [PATCH 2/4] fixes for SDXL --- scripts/convert_diffusers_to_original_sdxl.py | 5 ++++- src/diffusers/loaders/single_file_utils.py | 3 +-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/scripts/convert_diffusers_to_original_sdxl.py b/scripts/convert_diffusers_to_original_sdxl.py index 1f11ef457068..62ca102554e3 100644 --- a/scripts/convert_diffusers_to_original_sdxl.py +++ b/scripts/convert_diffusers_to_original_sdxl.py @@ -167,7 +167,10 @@ def convert_unet_state_dict(unet_state_dict): def reshape_weight_for_sd(w): # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) + if not w.ndim == 1: + return w.reshape(*w.shape, 1, 1) + else: + return w def convert_vae_state_dict(vae_state_dict): diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 1df964fe413e..5bf1506355e3 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -178,7 +178,7 @@ LDM_UNET_KEY = "model.diffusion_model." LDM_CONTROLNET_KEY = "control_model." LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."] -LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 +LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1280 SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias", @@ -1112,7 +1112,6 @@ def create_text_encoder_from_open_clip_checkpoint( text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim] text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2] text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :] - else: text_model_dict[diffusers_key] = checkpoint[key] From 897257ec5aaaa456872387e1c1d70628a21a4c16 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 12 Feb 2024 11:11:35 +0530 Subject: [PATCH 3/4] apply changes to the conversion script to match single_file_utils.py --- scripts/convert_diffusers_to_original_sdxl.py | 5 +++++ src/diffusers/loaders/single_file_utils.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/scripts/convert_diffusers_to_original_sdxl.py b/scripts/convert_diffusers_to_original_sdxl.py index 62ca102554e3..6eb6b0368bd5 100644 --- a/scripts/convert_diffusers_to_original_sdxl.py +++ b/scripts/convert_diffusers_to_original_sdxl.py @@ -324,11 +324,16 @@ def convert_openai_text_enc_state_dict(text_enc_dict): vae_state_dict = convert_vae_state_dict(vae_state_dict) vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + # Convert text encoder 1 text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict) text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()} + # Convert text encoder 2 text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict) text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()} + text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop( + "conditioner.embedders.1.model.text_projection.weight" + ) # Put together new checkpoint state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict} diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 5bf1506355e3..3a9f6e88238a 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -178,7 +178,7 @@ LDM_UNET_KEY = "model.diffusion_model." LDM_CONTROLNET_KEY = "control_model." LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."] -LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1280 +LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias", From 0210370447d82e95a499e165b70568911e8201fb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 12 Feb 2024 11:41:05 +0530 Subject: [PATCH 4/4] do transpose to match the single file loading logic. --- scripts/convert_diffusers_to_original_sdxl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/convert_diffusers_to_original_sdxl.py b/scripts/convert_diffusers_to_original_sdxl.py index 6eb6b0368bd5..648d0376f72e 100644 --- a/scripts/convert_diffusers_to_original_sdxl.py +++ b/scripts/convert_diffusers_to_original_sdxl.py @@ -331,9 +331,11 @@ def convert_openai_text_enc_state_dict(text_enc_dict): # Convert text encoder 2 text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict) text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()} + # We call the `.T.contiguous()` to match what's done in + # https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085 text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop( "conditioner.embedders.1.model.text_projection.weight" - ) + ).T.contiguous() # Put together new checkpoint state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}