Commit 74d902e

zuojianghuajianghuazuoDN6 authored
add config_file to from_single_file (#4614)
* Update loaders.py: add config_file to from_single_file, for use when download_from_original_stable_diffusion_ckpt is called
* Update loaders.py: add config_file to from_single_file, for use when download_from_original_stable_diffusion_ckpt is called
* change config_file to original_config_file
* make style && make quality

Co-authored-by: jianghua.zuo <jianghua.zuo@weimob.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent d7c4ae6 commit 74d902e

File tree

src/diffusers/loaders.py

1 file changed: +5 -0 lines changed

src/diffusers/loaders.py

Lines changed: 5 additions & 0 deletions
@@ -1790,6 +1790,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
             tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
                 An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
                 of `CLIPTokenizer` by itself if needed.
+            original_config_file (`str`):
+                Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
+                automatically inferred by looking for a key that only exists in SD2.0 models.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (for example the pipeline components of the
                 specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
@@ -1820,6 +1823,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         # import here to avoid circular dependency
         from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
 
+        original_config_file = kwargs.pop("original_config_file", None)
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
@@ -1936,6 +1940,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
             text_encoder=text_encoder,
             vae=vae,
             tokenizer=tokenizer,
+            original_config_file=original_config_file,
         )
 
         if torch_dtype is not None:
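For context, a minimal usage sketch (not part of the commit): it assumes the standard StableDiffusionPipeline entry point and uses hypothetical local file paths, and shows how a caller can pass the new original_config_file keyword to from_single_file instead of relying on the automatic config inference described in the docstring.

from diffusers import StableDiffusionPipeline

# Hypothetical paths for illustration only; replace with your own files.
ckpt_path = "./v2-1_768-ema-pruned.safetensors"
config_path = "./v2-inference-v.yaml"

# original_config_file is popped from kwargs and forwarded to
# download_from_original_stable_diffusion_ckpt, so the architecture is read
# from the given .yaml instead of being inferred from checkpoint keys.
pipe = StableDiffusionPipeline.from_single_file(
    ckpt_path,
    original_config_file=config_path,
)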

0 commit comments
