
Commit ecded50

realliujiaxuDN6 and Dhruv Nair authored
add convert diffuser pipeline of XL to original stable diffusion (#4596)
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

1 parent e34d9aa, commit ecded50

File tree: 1 file changed (+340, -0 lines)

# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or anything else.

import argparse
import os.path as osp
import re

import torch
from safetensors.torch import load_file, save_file

# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
    # the following are for sdxl
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(3):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i > 0:
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(4):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i < 2:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
    return new_state_dict
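
As an illustrative aside (mine, not part of the committed file), a single HF Diffusers key can be traced through the same replace logic that convert_unet_state_dict applies, using the maps defined above: resnet-level parts are substituted first, then block-level prefixes.

key = "down_blocks.0.resnets.1.norm1.weight"
for sd_part, hf_part in unet_conversion_map_resnet:
    key = key.replace(hf_part, sd_part)  # "norm1" -> "in_layers.0"
for sd_part, hf_part in unet_conversion_map_layer:
    key = key.replace(hf_part, sd_part)  # "down_blocks.0.resnets.1." -> "input_blocks.2.0."
print(key)  # input_blocks.2.0.in_layers.0.weight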

# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    # the following are for SDXL
    ("q.", "to_q."),
    ("k.", "to_k."),
    ("v.", "to_v."),
    ("proj_out.", "to_out.0."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict
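
One detail worth calling out in the VAE path (an editor's aside, not part of the file): the Diffusers VAE stores the mid-block attention projections as 2D linear weights, while the original SD checkpoint layout expects 1x1 conv weights, which is what reshape_weight_for_sd produces. A minimal sketch, with an illustrative tensor size:

import torch

w = torch.randn(512, 512)          # e.g. a mid-block to_q.weight as stored by Diffusers (size is illustrative)
w_sd = w.reshape(*w.shape, 1, 1)   # same arithmetic as reshape_weight_for_sd(w)
print(w_sd.shape)                  # torch.Size([512, 512, 1, 1])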

# =========================#
# Text Encoder Conversion #
# =========================#


textenc_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("transformer.resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "text_model.final_layer_norm."),
    ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))

# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}


def convert_openclip_text_enc_state_dict(text_enc_dict):
    new_state_dict = {}
    capture_qkv_weight = {}
    capture_qkv_bias = {}
    for k, v in text_enc_dict.items():
        if (
            k.endswith(".self_attn.q_proj.weight")
            or k.endswith(".self_attn.k_proj.weight")
            or k.endswith(".self_attn.v_proj.weight")
        ):
            k_pre = k[: -len(".q_proj.weight")]
            k_code = k[-len("q_proj.weight")]
            if k_pre not in capture_qkv_weight:
                capture_qkv_weight[k_pre] = [None, None, None]
            capture_qkv_weight[k_pre][code2idx[k_code]] = v
            continue

        if (
            k.endswith(".self_attn.q_proj.bias")
            or k.endswith(".self_attn.k_proj.bias")
            or k.endswith(".self_attn.v_proj.bias")
        ):
            k_pre = k[: -len(".q_proj.bias")]
            k_code = k[-len("q_proj.bias")]
            if k_pre not in capture_qkv_bias:
                capture_qkv_bias[k_pre] = [None, None, None]
            capture_qkv_bias[k_pre][code2idx[k_code]] = v
            continue

        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
        new_state_dict[relabelled_key] = v

    for k_pre, tensors in capture_qkv_weight.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)

    for k_pre, tensors in capture_qkv_bias.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)

    return new_state_dict


def convert_openai_text_enc_state_dict(text_enc_dict):
    return text_enc_dict
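
For the second text encoder, convert_openclip_text_enc_state_dict above folds the separate q_proj/k_proj/v_proj tensors of each layer into the single in_proj_weight / in_proj_bias that the original checkpoint format expects. A minimal sketch of that merge (mine, not part of the commit; the hidden size is illustrative):

import torch

d = 1280  # illustrative hidden size
q, k, v = (torch.randn(d, d) for _ in range(3))
in_proj_weight = torch.cat([q, k, v])  # same as torch.cat(tensors) above, shape (3 * d, d)
print(in_proj_weight.shape)            # torch.Size([3840, 1280])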

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Save weights using safetensors; the default is a .ckpt file."
    )

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    # Paths for safetensors
    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
    text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors")

    # Load each model from safetensors if the file exists; otherwise fall back to the PyTorch .bin file
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    if osp.exists(text_enc_2_path):
        text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
    else:
        text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin")
        text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the two text encoders
    text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}

    text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
    text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}

    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}

    if args.use_safetensors:
        save_file(state_dict, args.checkpoint_path)
    else:
        state_dict = {"state_dict": state_dict}
        torch.save(state_dict, args.checkpoint_path)
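
Usage note (editor's aside, not from the commit): the script is driven by the flags defined above, with --model_path pointing at the saved Diffusers pipeline directory, --checkpoint_path naming the output file, and optional --half and --use_safetensors; the new file's name is not visible in this view. A rough sanity check of the converted checkpoint, assuming it was saved with --use_safetensors to a hypothetical sd_xl.safetensors, might look like:

from collections import Counter

from safetensors.torch import load_file

sd = load_file("sd_xl.safetensors")  # hypothetical output path
prefixes = (
    "model.diffusion_model.",                # UNet
    "first_stage_model.",                    # VAE
    "conditioner.embedders.0.transformer.",  # first text encoder
    "conditioner.embedders.1.model.",        # second text encoder
)
counts = Counter(next((p for p in prefixes if k.startswith(p)), "other") for k in sd)
print(counts)  # every key should fall under one of the four prefixes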
