7 changes: 7 additions & 0 deletions scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -177,6 +177,8 @@ def reshape_weight_for_sd(w):


def convert_vae_state_dict(vae_state_dict):
# print(vae_state_dict["encoder.mid_block.attentions.0.to_k.bias"].shape)

mapping = {k: k for k in vae_state_dict.keys()}
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
@@ -188,6 +190,7 @@ def convert_vae_state_dict(vae_state_dict):
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}

weights_to_convert = ["q", "k", "v", "proj_out"]
keys_to_rename = {}
for k, v in new_state_dict.items():
@@ -198,9 +201,13 @@ def convert_vae_state_dict(vae_state_dict):
for weight_name, real_weight_name in vae_extra_conversion_map:
if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
keys_to_rename[k] = k.replace(weight_name, real_weight_name)

# print(keys_to_rename)
for k, v in keys_to_rename.items():
if k in new_state_dict:
print(f"Renaming {k} to {v}")
if "encoder.mid.attn_1.k.bias" in v:
print(new_state_dict[k].shape, reshape_weight_for_sd(new_state_dict[k]).shape)
new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
del new_state_dict[k]
return new_state_dict
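
For context, reshape_weight_for_sd (defined earlier in this script) maps diffusers' 2D nn.Linear attention weights onto the 4D 1x1-conv layout of the original SD checkpoint, and the debug prints added above appear to probe what it does to a 1D bias. A minimal sketch, assuming the helper keeps its usual one-line form in this script:

import torch

def reshape_weight_for_sd(w):
    # HF stores the VAE attention projections as linear weights (out, in);
    # the original SD checkpoint expects 1x1 conv weights (out, in, 1, 1).
    return w.reshape(*w.shape, 1, 1)

w = torch.randn(512, 512)
print(reshape_weight_for_sd(w).shape)  # torch.Size([512, 512, 1, 1])

# Applying the same reshape to a 1D bias (the case the prints above are
# checking) yields (512, 1, 1) rather than the expected (512,).
b = torch.randn(512)
print(reshape_weight_for_sd(b).shape)  # torch.Size([512, 1, 1])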
27 changes: 24 additions & 3 deletions src/diffusers/loaders/single_file_utils.py
@@ -59,6 +59,7 @@
"xl_refiner": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml",
"upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml",
"controlnet": "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml",
"test_single_file_sd": "https://huggingface.co/datasets/sayakpaul/sample-datasets/raw/main/tiny-sd-single-file-config.yaml",
}

CHECKPOINT_KEY_NAMES = {
@@ -278,6 +279,12 @@ def infer_original_config_file(class_name, checkpoint):
elif class_name == "ControlNetModel":
config_url = CONFIG_URLS["controlnet"]

elif len(checkpoint) == 512:
config_url = CONFIG_URLS["test_single_file_sd"]

elif len(checkpoint) == 701:
config_url = CONFIG_URLS["test_single_file_sdxl"]
Comment on lines +282 to +286

Member (Author):

I don't think there's any other sane way for us to verify this.

Collaborator:

Can't we add this into the test module? Why have it in src if it's only relevant to tests? We can pass it in using Pipeline.from_single_file(original_config_file="<file url>"). original_config_file can be a URL now.

else:
config_url = CONFIG_URLS["v1"]
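
Following up on the collaborator's suggestion above: a minimal sketch of passing the test-only config from the call site instead, assuming a Stable Diffusion pipeline and an illustrative checkpoint path (the URL is the one added to CONFIG_URLS in this diff):

from diffusers import StableDiffusionPipeline

# original_config_file accepts a URL, so the test-only YAML can stay out
# of src/ and be supplied explicitly (checkpoint path is illustrative).
pipe = StableDiffusionPipeline.from_single_file(
    "path/to/tiny-sd-checkpoint.safetensors",
    original_config_file=(
        "https://huggingface.co/datasets/sayakpaul/sample-datasets/raw/main/"
        "tiny-sd-single-file-config.yaml"
    ),
)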

@@ -1010,9 +1017,12 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint


def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False):
def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False, subfolder=None):
try:
config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
if subfolder is None:
config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
else:
config = CLIPTextConfig.from_pretrained(config_name, subfolder=subfolder)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
@@ -1189,6 +1199,7 @@ def create_diffusers_vae_model_from_ldm(

if is_accelerate_available():
for param_name, param in diffusers_format_vae_checkpoint.items():
# print(param_name, param.shape)
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
else:
vae.load_state_dict(diffusers_format_vae_checkpoint)
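
For reference, the accelerate path assigns tensors one at a time instead of a single load_state_dict call, which is what allows the model to be built on the meta device without allocating weights first. A minimal sketch of the pattern, with an illustrative stand-in for the converted checkpoint:

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from diffusers import AutoencoderKL

# Build the VAE without allocating weight memory, then materialize each
# parameter on CPU from the (stand-in) converted state dict.
with init_empty_weights():
    vae = AutoencoderKL()

state_dict = AutoencoderKL().state_dict()  # stand-in for the converted checkpoint
for name, tensor in state_dict.items():
    set_module_tensor_to_device(vae, name, "cpu", value=tensor)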
@@ -1231,7 +1242,17 @@ def create_text_encoders_and_tokenizers_from_ldm(
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)

except Exception:
raise ValueError(
try:
if not local_files_only:
config_name = "hf-internal-testing/tiny-sd-pipe"
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
config_name, checkpoint, local_files_only=False, subfolder="text_encoder"
)
tokenizer = CLIPTokenizer.from_pretrained(config_name, subfolder="tokenizer", local_files_only=False)
else:
raise ValueError("This option needs `local_files_only` set to False.")
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'."
)
else:
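
The fallback mirrors the text-encoder change above: when the primary config_name cannot be loaded, both components are fetched from a pipeline-style test repo with explicit subfolders. A minimal sketch of the tokenizer side:

from transformers import CLIPTokenizer

# Pipeline repos keep tokenizer files under a tokenizer/ subfolder.
tokenizer = CLIPTokenizer.from_pretrained(
    "hf-internal-testing/tiny-sd-pipe", subfolder="tokenizer"
)
print(tokenizer("a photo of an astronaut").input_ids)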
4 changes: 1 addition & 3 deletions src/diffusers/models/modeling_utils.py
@@ -124,9 +124,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
f"at '{checkpoint_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
f"Unable to load weights from checkpoint file for '{checkpoint_file} at '{checkpoint_file}'."
)

