[From ckpt] Fix from_ckpt (#3466)
* Correct from_ckpt

* make style
patrickvonplaten committed May 17, 2023
1 parent 88295f9 commit 2858d7e
Showing 2 changed files with 14 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/diffusers/loaders.py
@@ -1326,7 +1326,7 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
 
-        if from_safetensors and use_safetensors is True:
+        if from_safetensors and use_safetensors is False:
             raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
 
         # TODO: For now we only support stable diffusion
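For context, a minimal usage sketch of the corrected check (the local checkpoint filename is an illustrative assumption, not part of this commit): a `.safetensors` file now loads when `use_safetensors` is left unset or set to `True`, and the error fires only when the caller explicitly passes `use_safetensors=False`.

```python
# Hypothetical usage sketch, assuming a diffusers release that ships from_ckpt.
# The checkpoint filename below is illustrative only.
from diffusers import StableDiffusionPipeline

# A .safetensors checkpoint with use_safetensors=True (or left unset) now loads.
pipe = StableDiffusionPipeline.from_ckpt(
    "./v1-5-pruned-emaonly.safetensors",
    use_safetensors=True,
)

# Explicitly passing use_safetensors=False for a .safetensors file now raises:
#   ValueError: Make sure to install `safetensors` with `pip install safetensors`.
```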
22 changes: 13 additions & 9 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("norm.weight", "group_norm.weight")
         new_item = new_item.replace("norm.bias", "group_norm.bias")
 
-        new_item = new_item.replace("q.weight", "query.weight")
-        new_item = new_item.replace("q.bias", "query.bias")
+        new_item = new_item.replace("q.weight", "to_q.weight")
+        new_item = new_item.replace("q.bias", "to_q.bias")
 
-        new_item = new_item.replace("k.weight", "key.weight")
-        new_item = new_item.replace("k.bias", "key.bias")
+        new_item = new_item.replace("k.weight", "to_k.weight")
+        new_item = new_item.replace("k.bias", "to_k.bias")
 
-        new_item = new_item.replace("v.weight", "value.weight")
-        new_item = new_item.replace("v.bias", "value.bias")
+        new_item = new_item.replace("v.weight", "to_v.weight")
+        new_item = new_item.replace("v.bias", "to_v.bias")
 
-        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
-        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
 
         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
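The effect of the new mapping, as a small standalone sketch (the helper and sample keys are illustrative, not from the commit): old LDM-style VAE attention parameter names are now rewritten to the `to_q`/`to_k`/`to_v`/`to_out.0` scheme of diffusers' `Attention` block instead of the deprecated `query`/`key`/`value`/`proj_attn` names.

```python
# Illustrative helper mirroring the replace chain above; the sample keys are
# made-up old-style LDM VAE attention parameter names.
def rename_vae_attention_key(key: str) -> str:
    for old, new in [
        ("norm.", "group_norm."),
        ("q.", "to_q."),
        ("k.", "to_k."),
        ("v.", "to_v."),
        ("proj_out.", "to_out.0."),
    ]:
        key = key.replace(old, new)
    return key

print(rename_vae_attention_key("mid.attn_1.q.weight"))       # mid.attn_1.to_q.weight
print(rename_vae_attention_key("mid.attn_1.proj_out.bias"))  # mid.attn_1.to_out.0.bias
```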

@@ -204,8 +204,12 @@ def assign_to_checkpoint(
                 new_path = new_path.replace(replacement["old"], replacement["new"])
 
         # proj_attn.weight has to be converted from conv 1D to linear
-        if "proj_attn.weight" in new_path:
+        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+        shape = old_checkpoint[path["old"]].shape
+        if is_attn_weight and len(shape) == 3:
             checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        elif is_attn_weight and len(shape) == 4:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
         else:
             checkpoint[new_path] = old_checkpoint[path["old"]]
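The added branch covers VAE attention weights stored as 1x1 convolutions. A standalone sketch of the idea, with illustrative shapes: a conv kernel of shape (out, in, 1) or (out, in, 1, 1) holds exactly the parameters of an (out, in) linear weight, so the trailing singleton dimensions can simply be indexed away, as the commit does above.

```python
import torch

# Illustrative tensors standing in for checkpoint entries.
conv1d_weight = torch.randn(512, 512, 1)      # old Conv1d-style attention weight
conv2d_weight = torch.randn(512, 512, 1, 1)   # old 1x1 Conv2d-style attention weight

# The same squeeze the commit applies, keyed off the number of dimensions.
linear_from_1d = conv1d_weight[:, :, 0]
linear_from_2d = conv2d_weight[:, :, 0, 0]

assert linear_from_1d.shape == (512, 512)
assert linear_from_2d.shape == (512, 512)
```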

