diff --git a/setup.py b/setup.py
index 90ffd3495391..dd3decd8624b 100644
--- a/setup.py
+++ b/setup.py
@@ -101,7 +101,7 @@
     "filelock",
     "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.23.2",
+    "huggingface-hub>=0.27.0",
     "requests-mock==1.10.0",
     "importlib_metadata",
     "invisible-watermark>=0.2.0",
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 9e7bf242eca7..7e2ec1c51459 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -9,7 +9,7 @@
     "filelock": "filelock",
     "flax": "flax>=0.4.1",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.23.2",
+    "huggingface-hub": "huggingface-hub>=0.27.0",
     "requests-mock": "requests-mock==1.10.0",
     "importlib_metadata": "importlib_metadata",
     "invisible-watermark": "invisible-watermark>=0.2.0",
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 934edef5b86e..ee7a78eb0471 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -12,14 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 import importlib
 import os
 import re
 import warnings
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from huggingface_hub import ModelCard, model_info
@@ -41,11 +39,12 @@
     logging,
 )
 from ..utils.torch_utils import is_compiled_module
+from .transformers_loading_utils import load_tokenizer_from_dduf, load_transformers_model_from_dduf
 
 
 if is_transformers_available():
     import transformers
-    from transformers import PreTrainedModel
+    from transformers import PreTrainedModel, PreTrainedTokenizerBase
     from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
     from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
     from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
@@ -664,7 +663,7 @@
             f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
         )
 
-    load_method = getattr(class_obj, load_method_name)
+    load_method = _get_load_method(class_obj, load_method_name, is_dduf=dduf_entries is not None)
 
     # add kwargs to loading method
     diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -750,6 +749,22 @@ def load_sub_model(
     return loaded_sub_model
 
 
+def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> Callable:
+    """
+    Return the method used to load the sub model.
+
+    In practice, this returns the `load_method_name` method of the class object (usually `from_pretrained`),
+    unless we are loading from a DDUF checkpoint. In that case, `transformers` models and tokenizers need
+    dedicated loading helpers that do not go through `from_pretrained`.
+    """
+    if is_dduf:
+        if issubclass(class_obj, PreTrainedTokenizerBase):
+            return lambda *args, **kwargs: load_tokenizer_from_dduf(class_obj, *args, **kwargs)
+        if issubclass(class_obj, PreTrainedModel):
+            return lambda *args, **kwargs: load_transformers_model_from_dduf(class_obj, *args, **kwargs)
+    return getattr(class_obj, load_method_name)
+
+
 def _fetch_class_library_tuple(module):
     # import it here to avoid circular import
     diffusers_module = importlib.import_module(__name__.split(".")[0])
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index f0770778f38a..4defbd5ac92d 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -32,6 +32,7 @@
     create_repo,
     hf_hub_download,
     model_info,
+    read_dduf_file,
     snapshot_download,
 )
 from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
@@ -53,7 +54,6 @@
     PushToHubMixin,
     is_accelerate_available,
     is_accelerate_version,
-    is_huggingface_hub_version,
     is_torch_npu_available,
     is_torch_version,
     is_transformers_version,
@@ -677,7 +677,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                 loading `from_flax`.
             dduf_file(`str`, *optional*):
-                Load weights from the specified dduf file
+                Load weights from the specified DDUF file.
@@ -822,15 +822,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         dduf_entries = None
         if dduf_file:
-            if not is_huggingface_hub_version(">", "0.26.3"):
-                (">=", "0.17.0.dev0")
-                raise RuntimeError(
-                    "To load a dduf file, you need to install huggingface_hub>0.26.3. "
-                    "You can install it with the following: `pip install --upgrade huggingface_hub`."
-                )
-
-            from huggingface_hub import read_dduf_file
-
             dduf_file_path = os.path.join(cached_folder, dduf_file)
             dduf_entries = read_dduf_file(dduf_file_path)
             # The reader contains already all the files needed, no need to check it again
@@ -845,6 +836,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
         # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
         # with variant being `"fp16"`.
+        # TODO: adapt this logic for DDUF files (at the moment it scans the local directory, which doesn't make sense in the DDUF context)
         model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
         if len(model_variants) == 0 and variant is not None:
             error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
diff --git a/src/diffusers/pipelines/transformers_loading_utils.py b/src/diffusers/pipelines/transformers_loading_utils.py
new file mode 100644
index 000000000000..e4aa331eeeeb
--- /dev/null
+++ b/src/diffusers/pipelines/transformers_loading_utils.py
@@ -0,0 +1,98 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import contextlib
+import os
+import tempfile
+from typing import TYPE_CHECKING, Dict
+
+from huggingface_hub import DDUFEntry
+
+from ..utils import is_safetensors_available, is_transformers_available
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+
+if is_transformers_available():
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+
+if is_safetensors_available():
+    import safetensors.torch
+
+
+def load_tokenizer_from_dduf(
+    cls: "PreTrainedTokenizer", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
+) -> "PreTrainedTokenizer":
+    """
+    Load a tokenizer from a DDUF archive.
+
+    In practice, `transformers` does not provide a way to load a tokenizer from a DDUF archive. This function works
+    around that by extracting the tokenizer files to a temporary directory and loading the tokenizer from there. The
+    extraction has an extra cost, but its impact is limited since tokenizer files are usually small.
+    """
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        for entry_name, entry in dduf_entries.items():
+            if entry_name.startswith(name + "/"):
+                tmp_entry_path = os.path.join(tmp_dir, *entry_name.split("/"))
+                os.makedirs(os.path.dirname(tmp_entry_path), exist_ok=True)
+                with open(tmp_entry_path, "wb") as f:
+                    with entry.as_mmap() as mm:
+                        f.write(mm)
+        return cls.from_pretrained(os.path.join(tmp_dir, name), **kwargs)
+
+
+def load_transformers_model_from_dduf(
+    cls: "PreTrainedModel", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
+) -> "PreTrainedModel":
+    """
+    Load a transformers model from a DDUF archive.
+
+    In practice, `transformers` does not provide a way to load a model from a DDUF archive. This function works around
+    that by instantiating the model from its config file and loading the weights directly from the DDUF archive.
+    """
+    config_file = dduf_entries.get(f"{name}/config.json")
+    if config_file is None:
+        raise EnvironmentError(
+            f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})."
+        )
+
+    weight_files = [
+        entry
+        for entry_name, entry in dduf_entries.items()
+        if entry_name.startswith(f"{name}/") and entry_name.endswith(".safetensors")
+    ]
+    if not weight_files:
+        raise EnvironmentError(
+            f"Could not find any weight file for component {name} in DDUF file (contains {dduf_entries.keys()})."
+        )
+    if not is_safetensors_available():
+        raise EnvironmentError(
+            "Safetensors is not available, cannot load the model from a DDUF archive. Please run `pip install safetensors`."
+        )
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        tmp_config_file = os.path.join(tmp_dir, "config.json")
+        with open(tmp_config_file, "w") as f:
+            f.write(config_file.read_text())
+
+        with contextlib.ExitStack() as stack:
+            state_dict = {
+                key: tensor
+                for entry in weight_files  # loop over safetensors files
+                for key, tensor in safetensors.torch.load(  # load tensors from mmap-ed bytes
+                    stack.enter_context(entry.as_mmap())  # use enter_context to close the mmap when done
+                ).items()
+            }
+            return cls.from_pretrained(tmp_dir, state_dict=state_dict, **kwargs)
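
For context, a minimal sketch of how the new `dduf_file` code path is exercised end to end. The repo id and archive
name below are hypothetical placeholders, not taken from this diff:

    from diffusers import DiffusionPipeline

    # `from_pretrained` downloads/caches the repo, `read_dduf_file` indexes the archive's
    # entries, and each sub model is then loaded via the `_get_load_method` dispatch
    # (transformers components go through the new helpers in transformers_loading_utils.py).
    pipe = DiffusionPipeline.from_pretrained(
        "org/pipeline-dduf",        # hypothetical repo id hosting a DDUF archive
        dduf_file="pipeline.dduf",  # hypothetical archive filename inside the repo
    )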