[From pretrained] Speed-up loading from cache #2515
Changes from all commits
The diff below touches `from_pretrained` and, further down, `_get_model_file`:

@@ -458,18 +458,34 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
             )

+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
         user_agent = {
             "diffusers": __version__,
             "file_type": "model",
             "framework": "pytorch",
         }

-        # Load config if we don't provide a configuration
-        config_path = pretrained_model_name_or_path
-
-        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
-        # Load model
-
+        # load config
+        config, unused_kwargs, commit_hash = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            subfolder=subfolder,
+            device_map=device_map,
+            user_agent=user_agent,
+            **kwargs,
+        )
+
+        # load model
         model_file = None
         if from_flax:
             model_file = _get_model_file(

@@ -484,20 +500,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 revision=revision,
                 subfolder=subfolder,
                 user_agent=user_agent,
-            )
-            config, unused_kwargs = cls.load_config(
-                config_path,
-                cache_dir=cache_dir,
-                return_unused_kwargs=True,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                device_map=device_map,
-                **kwargs,
+                commit_hash=commit_hash,
             )
             model = cls.from_config(config, **unused_kwargs)

@@ -520,6 +523,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         revision=revision,
                         subfolder=subfolder,
                         user_agent=user_agent,
+                        commit_hash=commit_hash,
                     )
                 except:  # noqa: E722
                     pass

@@ -536,25 +540,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     revision=revision,
                     subfolder=subfolder,
                     user_agent=user_agent,
+                    commit_hash=commit_hash,
                 )

            if low_cpu_mem_usage:
                # Instantiate model with empty weights
                with accelerate.init_empty_weights():
-                    config, unused_kwargs = cls.load_config(
-                        config_path,
-                        cache_dir=cache_dir,
-                        return_unused_kwargs=True,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        use_auth_token=use_auth_token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        device_map=device_map,
-                        **kwargs,
-                    )
                     model = cls.from_config(config, **unused_kwargs)

                # if device_map is None, load the state dict and move the params from meta device to the cpu

@@ -593,20 +584,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                    "error_msgs": [],
                }
            else:
-                config, unused_kwargs = cls.load_config(
-                    config_path,
-                    cache_dir=cache_dir,
-                    return_unused_kwargs=True,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    device_map=device_map,
-                    **kwargs,
-                )
                model = cls.from_config(config, **unused_kwargs)

                state_dict = load_state_dict(model_file, variant=variant)
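Taken together, the hunks above replace the three branch-local `cls.load_config(...)` calls with a single call made before any model file is fetched, and that call now also returns the resolved commit hash via `return_commit_hash=True`. A minimal usage sketch of the resulting classmethod follows; the model class and repo id are placeholders for illustration, not taken from this PR:

```python
from diffusers import UNet2DConditionModel

# Placeholder repo/subfolder; any ModelMixin subclass exposes the same classmethod.
config, unused_kwargs, commit_hash = UNet2DConditionModel.load_config(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    return_unused_kwargs=True,
    return_commit_hash=True,
)

# `commit_hash` is the exact revision the config was resolved to; from_pretrained
# threads it into every later `_get_model_file` call so the weights come from the
# same snapshot and, when already cached, need no extra requests.
print(commit_hash)
```

The remaining hunks show `_get_model_file` accepting and using that hash.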
@@ -803,6 +780,7 @@ def _get_model_file(
     use_auth_token,
     user_agent,
     revision,
+    commit_hash=None,
 ):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     if os.path.isfile(pretrained_model_name_or_path):

@@ -840,7 +818,7 @@ def _get_model_file(
                     use_auth_token=use_auth_token,
                     user_agent=user_agent,
                     subfolder=subfolder,
-                    revision=revision,
+                    revision=revision or commit_hash,
                 )
                 warnings.warn(
                     f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",

@@ -865,7 +843,7 @@ def _get_model_file(
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
                 subfolder=subfolder,
-                revision=revision,
+                revision=revision or commit_hash,
             )
             return model_file
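The `revision=revision or commit_hash` change is where the cache speed-up comes from: when the caller did not pin an explicit revision, files are requested at the immutable commit hash the config was already resolved to, so `huggingface_hub` can serve an already-cached file without re-resolving a branch name. A rough illustration of that behaviour, using a placeholder repo id that is not taken from the PR:

```python
from huggingface_hub import hf_hub_download, model_info

repo_id = "runwayml/stable-diffusion-v1-5"  # placeholder repo for illustration

# Resolve the branch once; `sha` is the commit that "main" currently points to.
commit_hash = model_info(repo_id).sha

# First call populates the local cache (network required).
hf_hub_download(repo_id, "unet/config.json", revision=commit_hash)

# Later calls pinned to the immutable commit are answered straight from the cache,
# with no HEAD request to re-resolve a branch name; this is the same effect that
# `revision=revision or commit_hash` gives the weight downloads in from_pretrained.
path = hf_hub_download(repo_id, "unet/config.json", revision=commit_hash)
print(path)
```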
Review comment: This testing library is super useful for checking how many HEAD and GET requests were made; it is popular and very lightweight, so I think we can add it here to help with testing.
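The library the reviewer links to is not preserved above; `requests-mock` is one popular, lightweight package with exactly this capability, so the sketch below assumes it. It records every HTTP request made during a fully cached load while letting real traffic through (the repo id and model class are again placeholders):

```python
import requests_mock

from diffusers import UNet2DConditionModel

repo_id = "runwayml/stable-diffusion-v1-5"  # placeholder repo for illustration

# Warm the cache once so the second load can be served locally.
UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")

# Record every HTTP call during a cached load; real_http=True forwards unmatched
# requests to the Hub so the load still works.
with requests_mock.mock(real_http=True) as m:
    UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")

methods = [r.method for r in m.request_history]
print(methods.count("HEAD"), "HEAD requests,", methods.count("GET"), "GET requests")
# With the commit hash reused for the weight files, a cached load should show
# noticeably fewer HEAD requests than before this PR.
```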