Merged
31 commits
7c3c3e1  [From pretrained] Speed-up loading from cache (patrickvonplaten, Feb 28, 2023)
d1ad4d3  up (patrickvonplaten, Feb 28, 2023)
19a0fdf  Fix more (patrickvonplaten, Feb 28, 2023)
300544a  fix one more bug (patrickvonplaten, Feb 28, 2023)
152f902  make style (patrickvonplaten, Feb 28, 2023)
d13141a  bigger refactor (patrickvonplaten, Feb 28, 2023)
c4a49e6  factor out function (patrickvonplaten, Mar 2, 2023)
513b213  Improve more (patrickvonplaten, Mar 2, 2023)
c4aadde  better (patrickvonplaten, Mar 2, 2023)
b43be19  deprecate return cache folder (patrickvonplaten, Mar 2, 2023)
22eeb11  Merge branch 'main' into avoid_calling_hub_if_already_downlaoded (patrickvonplaten, Mar 8, 2023)
a37cb95  clean up (patrickvonplaten, Mar 8, 2023)
0f39ab7  improve tests (patrickvonplaten, Mar 8, 2023)
d6a1815  up (patrickvonplaten, Mar 8, 2023)
79afaf2  upload (patrickvonplaten, Mar 8, 2023)
e4bff0b  add nice tests (patrickvonplaten, Mar 8, 2023)
30717b0  simplify (patrickvonplaten, Mar 8, 2023)
71fa6b8  finish (patrickvonplaten, Mar 8, 2023)
a07ed0f  correct (patrickvonplaten, Mar 8, 2023)
5f2472e  fix version (patrickvonplaten, Mar 8, 2023)
63bf2f8  rename (patrickvonplaten, Mar 8, 2023)
6aeb8e3  Merge branch 'main' into avoid_calling_hub_if_already_downlaoded (patrickvonplaten, Mar 9, 2023)
aabdde8  Apply suggestions from code review (patrickvonplaten, Mar 9, 2023)
d085f06  Merge branch 'main' into avoid_calling_hub_if_already_downlaoded (patrickvonplaten, Mar 9, 2023)
26aacc0  rename (patrickvonplaten, Mar 9, 2023)
4590c99  finish (patrickvonplaten, Mar 9, 2023)
b569eb8  correct doc string (patrickvonplaten, Mar 9, 2023)
3470424  correct more (patrickvonplaten, Mar 9, 2023)
d28e8d4  Apply suggestions from code review (patrickvonplaten, Mar 10, 2023)
0359a96  apply code suggestions (patrickvonplaten, Mar 10, 2023)
343a330  finish (patrickvonplaten, Mar 10, 2023)
scripts/convert_original_stable_diffusion_to_diffusers.py (4 changes: 2 additions & 2 deletions)

@@ -16,7 +16,7 @@

 import argparse

-from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt


 if __name__ == "__main__":
@@ -125,7 +125,7 @@
     )
     args = parser.parse_args()

-    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
+    pipe = download_from_original_stable_diffusion_ckpt(
         checkpoint_path=args.checkpoint_path,
         original_config_file=args.original_config_file,
         image_size=args.image_size,
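For context, a hypothetical call to the renamed helper could look like the snippet below. Only the function name and the checkpoint_path argument come from the diff; the file paths are placeholders, and the remaining arguments fall back to their defaults.

# Illustrative sketch only: paths are placeholders.
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="./v1-5-pruned-emaonly.ckpt",  # placeholder local .ckpt file
)
pipe.save_pretrained("./converted-pipeline")  # placeholder output directory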
setup.py (4 changes: 3 additions & 1 deletion)

@@ -86,7 +86,8 @@
     "filelock",
     "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.10.0",
+    "huggingface-hub>=0.13.0",
+    "requests-mock==1.10.0",
     "importlib_metadata",
     "isort>=5.5.4",
     "jax>=0.2.8,!=0.3.2",

Review comment (Contributor Author) on "requests-mock==1.10.0":
This testing library is super useful for checking how many HEAD and GET requests were made; it is popular and very lightweight, so I think we can add it here to help with testing.

@@ -192,6 +193,7 @@ def run(self):
     "pytest",
     "pytest-timeout",
     "pytest-xdist",
+    "requests-mock",
     "safetensors",
     "sentencepiece",
     "scipy",
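To illustrate the reviewer's point about counting requests, here is a minimal sketch (not taken from the PR's test suite) of how requests-mock records HTTP traffic; the URL is made up and only stands in for a Hub endpoint.

import requests
import requests_mock

# Every request issued through `requests` inside the mocker is recorded,
# so a test can assert exactly how many HEAD/GET calls a code path made.
with requests_mock.Mocker() as m:
    m.head("https://huggingface.co/api/models/example")          # fake endpoint
    m.get("https://huggingface.co/api/models/example", json={})  # fake endpoint

    requests.head("https://huggingface.co/api/models/example")
    requests.get("https://huggingface.co/api/models/example")

    methods = [r.method for r in m.request_history]
    assert methods.count("HEAD") == 1
    assert methods.count("GET") == 1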
src/diffusers/configuration_utils.py (39 changes: 33 additions & 6 deletions)

@@ -31,7 +31,15 @@
 from requests import HTTPError

 from . import __version__
-from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
+from .utils import (
+    DIFFUSERS_CACHE,
+    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    DummyObject,
+    deprecate,
+    extract_commit_hash,
+    http_user_agent,
+    logging,
+)


 logger = logging.get_logger(__name__)
@@ -231,7 +239,11 @@ def get_config_dict(cls, *args, **kwargs):

     @classmethod
     def load_config(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        return_unused_kwargs=False,
+        return_commit_hash=False,
+        **kwargs,
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         r"""
         Instantiate a Python class from a config dictionary
@@ -271,6 +283,10 @@ def load_config(
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo (either remote in
                 huggingface.co or downloaded locally), you can specify the folder name here.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether unused keyword arguments of the config shall be returned.
+            return_commit_hash (`bool`, *optional*, defaults to `False`):
+                Whether the commit_hash of the loaded configuration shall be returned.

         <Tip>
@@ -295,8 +311,10 @@
         revision = kwargs.pop("revision", None)
         _ = kwargs.pop("mirror", None)
         subfolder = kwargs.pop("subfolder", None)
+        user_agent = kwargs.pop("user_agent", {})

-        user_agent = {"file_type": "config"}
+        user_agent = {**user_agent, "file_type": "config"}
+        user_agent = http_user_agent(user_agent)

         pretrained_model_name_or_path = str(pretrained_model_name_or_path)

@@ -336,7 +354,6 @@
                 subfolder=subfolder,
                 revision=revision,
             )
-
         except RepositoryNotFoundError:
             raise EnvironmentError(
                 f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
@@ -378,13 +395,23 @@
         try:
             # Load config dict
             config_dict = cls._dict_from_json_file(config_file)
+
+            commit_hash = extract_commit_hash(config_file)
         except (json.JSONDecodeError, UnicodeDecodeError):
             raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

+        if not (return_unused_kwargs or return_commit_hash):
+            return config_dict
+
+        outputs = (config_dict,)
+
         if return_unused_kwargs:
-            return config_dict, kwargs
+            outputs += (kwargs,)
+
+        if return_commit_hash:
+            outputs += (commit_hash,)

-        return config_dict
+        return outputs

     @staticmethod
     def _get_init_keys(cls):
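A quick sketch of what calling the extended load_config could look like with both flags set. The repo id and subfolder are examples and not part of the diff; the return order follows the outputs tuple built above (config dict, then unused kwargs, then commit hash).

from diffusers import UNet2DConditionModel

# Example repo id / subfolder; any ConfigMixin subclass exposes load_config.
config, unused_kwargs, commit_hash = UNet2DConditionModel.load_config(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    return_unused_kwargs=True,
    return_commit_hash=True,
)
print(commit_hash)  # the commit the cached config file was resolved from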
src/diffusers/dependency_versions_table.py (3 changes: 2 additions & 1 deletion)

@@ -10,7 +10,8 @@
     "filelock": "filelock",
     "flax": "flax>=0.4.1",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.10.0",
+    "huggingface-hub": "huggingface-hub>=0.13.0",
+    "requests-mock": "requests-mock==1.10.0",
     "importlib_metadata": "importlib_metadata",
     "isort": "isort>=5.5.4",
     "jax": "jax>=0.2.8,!=0.3.2",
src/diffusers/models/modeling_utils.py (78 changes: 28 additions & 50 deletions)

@@ -458,18 +458,34 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
             )

-        # Load config if we don't provide a configuration
-        config_path = pretrained_model_name_or_path
-
         user_agent = {
             "diffusers": __version__,
             "file_type": "model",
             "framework": "pytorch",
         }

+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
-        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
-        # Load model
+        # load config
+        config, unused_kwargs, commit_hash = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            subfolder=subfolder,
+            device_map=device_map,
+            user_agent=user_agent,
+            **kwargs,
+        )
+
+        # load model
         model_file = None
         if from_flax:
             model_file = _get_model_file(
@@ -484,20 +500,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 revision=revision,
                 subfolder=subfolder,
                 user_agent=user_agent,
-            )
-            config, unused_kwargs = cls.load_config(
-                config_path,
-                cache_dir=cache_dir,
-                return_unused_kwargs=True,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                device_map=device_map,
-                **kwargs,
+                commit_hash=commit_hash,
             )
             model = cls.from_config(config, **unused_kwargs)

@@ -520,6 +523,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
+                       commit_hash=commit_hash,
                    )
                except:  # noqa: E722
                    pass
@@ -536,25 +540,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
+                   commit_hash=commit_hash,
                )

         if low_cpu_mem_usage:
             # Instantiate model with empty weights
             with accelerate.init_empty_weights():
-                config, unused_kwargs = cls.load_config(
-                    config_path,
-                    cache_dir=cache_dir,
-                    return_unused_kwargs=True,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    device_map=device_map,
-                    **kwargs,
-                )
                 model = cls.from_config(config, **unused_kwargs)

             # if device_map is None, load the state dict and move the params from meta device to the cpu
@@ -593,20 +584,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 "error_msgs": [],
             }
         else:
-            config, unused_kwargs = cls.load_config(
-                config_path,
-                cache_dir=cache_dir,
-                return_unused_kwargs=True,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                device_map=device_map,
-                **kwargs,
-            )
             model = cls.from_config(config, **unused_kwargs)

             state_dict = load_state_dict(model_file, variant=variant)
@@ -803,6 +780,7 @@ def _get_model_file(
     use_auth_token,
     user_agent,
     revision,
+    commit_hash=None,
 ):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     if os.path.isfile(pretrained_model_name_or_path):

Review comment (Collaborator) on "commit_hash=None,":
It seems that the revision arg is there only to test whether revision is in DEPRECATED_REVISION_ARGS. I would at least put a comment here to explain why we are passing revision and commit_hash separately, since it can be confusing. Another possibility is to pass only commit_hash and add an is_deprecated_revision: bool argument. Since you control the logic calling _get_model_file, you can adapt it (but if you think it's unnecessary, just leave it as it is).

Reply (Contributor Author):
We still need revision for the deprecation path even if commit_hash is defined.

@@ -840,7 +818,7 @@ def _get_model_file(
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
-               revision=revision,
+               revision=revision or commit_hash,
            )
            warnings.warn(
                f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
@@ -865,7 +843,7 @@ def _get_model_file(
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
-               revision=revision,
+               revision=revision or commit_hash,
            )
        return model_file
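The revision or commit_hash fallback discussed in the review thread above leans on a caching behaviour assumed here for huggingface_hub >= 0.13 (the minimum version this PR bumps to): a commit hash is immutable, so a file already present in the local snapshot for that hash can be returned without another round trip to the Hub. A rough sketch with placeholder values:

from huggingface_hub import hf_hub_download

commit_hash = "0123456789abcdef0123456789abcdef01234567"  # placeholder commit hash

# With an immutable revision, a file that is already cached under this snapshot
# can be served from disk; a branch name like "main" would still need a request
# to resolve the latest commit.
config_path = hf_hub_download(
    repo_id="example-org/example-model",  # placeholder repo id
    filename="config.json",
    revision=commit_hash,
)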