From 0256560f32064f01bc990dc269c6eba7553c7f94 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 22 Aug 2024 09:12:57 +0000 Subject: [PATCH] update --- src/diffusers/loaders/single_file_utils.py | 24 ++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 2ca37630e7c4..0c3bd575a0ca 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -79,7 +79,10 @@ "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight", "animatediff_rgb": "controlnet_cond_embedding.weight", - "flux": "double_blocks.0.img_attn.norm.key_norm.scale", + "flux": [ + "double_blocks.0.img_attn.norm.key_norm.scale", + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", + ], } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -258,7 +261,7 @@ "timestep_spacing": "leading", } -LDM_VAE_KEY = "first_stage_model." +LDM_VAE_KEYS = ["first_stage_model.", "vae."] LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215 PLAYGROUND_VAE_SCALING_FACTOR = 0.5 LDM_UNET_KEY = "model.diffusion_model." @@ -267,7 +270,6 @@ "cond_stage_model.transformer.", "conditioner.embedders.0.transformer.", ] -OPEN_CLIP_PREFIX = "conditioner.embedders.0.model." LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] @@ -518,8 +520,10 @@ def infer_diffusers_model_type(checkpoint): else: model_type = "animatediff_v3" - elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint: - if "guidance_in.in_layer.bias" in checkpoint: + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): + if any( + g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] + ): model_type = "flux-dev" else: model_type = "flux-schnell" @@ -1178,7 +1182,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config): # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys vae_state_dict = {} keys = list(checkpoint.keys()) - vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else "" + vae_key = "" + for ldm_vae_key in LDM_VAE_KEYS: + if any(k.startswith(ldm_vae_key) for k in keys): + vae_key = ldm_vae_key + for key in keys: if key.startswith(vae_key): vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) @@ -1883,6 +1891,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401