Skip to content

Commit 6148157

Browse files
committed
update
1 parent 7d7bd7f commit 6148157

File tree

1 file changed

+20
-26
lines changed

1 file changed

+20
-26
lines changed

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,7 @@ def download_from_original_stable_diffusion_ckpt(
11491149
adapter: Optional[bool] = None,
11501150
load_safety_checker: bool = True,
11511151
pipeline_class: DiffusionPipeline = None,
1152-
local_files_only: bool = False,
1152+
local_files_only=False,
11531153
vae_path=None,
11541154
vae=None,
11551155
text_encoder=None,
@@ -1696,7 +1696,7 @@ def download_from_original_stable_diffusion_ckpt(
16961696
elif model_type in ["SDXL", "SDXL-Refiner"]:
16971697
is_refiner = model_type == "SDXL-Refiner"
16981698

1699-
if tokenizer is None:
1699+
if (is_refiner is False) and (tokenizer is None):
17001700
try:
17011701
tokenizer = CLIPTokenizer.from_pretrained(
17021702
"openai/clip-vit-large-patch14", local_files_only=local_files_only
@@ -1705,7 +1705,8 @@ def download_from_original_stable_diffusion_ckpt(
17051705
raise ValueError(
17061706
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
17071707
)
1708-
if text_encoder is None:
1708+
1709+
if (is_refiner is False) and (text_encoder is None):
17091710
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
17101711

17111712
if tokenizer_2 is None:
@@ -1762,29 +1763,22 @@ def download_from_original_stable_diffusion_ckpt(
17621763
)
17631764

17641765
else:
1765-
if pipeline_class == StableDiffusionXLImg2ImgPipeline:
1766-
pipe = pipeline_class(
1767-
vae=vae,
1768-
text_encoder=text_encoder,
1769-
tokenizer=tokenizer,
1770-
text_encoder_2=text_encoder_2,
1771-
tokenizer_2=tokenizer_2,
1772-
unet=unet,
1773-
scheduler=scheduler,
1774-
force_zeros_for_empty_prompt=False,
1775-
requires_aesthetics_score=is_refiner,
1776-
)
1777-
else:
1778-
pipe = pipeline_class(
1779-
vae=vae,
1780-
text_encoder=text_encoder,
1781-
tokenizer=tokenizer,
1782-
text_encoder_2=text_encoder_2,
1783-
tokenizer_2=tokenizer_2,
1784-
unet=unet,
1785-
scheduler=scheduler,
1786-
force_zeros_for_empty_prompt=True,
1787-
)
1766+
pipeline_kwargs = {
1767+
"vae": vae,
1768+
"text_encoder": text_encoder,
1769+
"tokenizer": tokenizer,
1770+
"text_encoder_2": text_encoder_2,
1771+
"tokenizer_2": tokenizer_2,
1772+
"unet": unet,
1773+
"scheduler": scheduler,
1774+
}
1775+
1776+
if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or (
1777+
pipeline_class == StableDiffusionXLInpaintPipeline
1778+
):
1779+
pipeline_kwargs.update({"requires_aesthetics_score": is_refiner})
1780+
1781+
pipe = pipeline_class(**pipeline_kwargs)
17881782
else:
17891783
text_config = create_ldm_bert_config(original_config)
17901784
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)

0 commit comments

Comments
 (0)