Skip to content

Commit

Permalink
[Proposal] Support loading from safetensors if file is present. (#1357)
Browse files Browse the repository at this point in the history
* [Proposal] Support loading from safetensors if file is present.

* Style.

* Fix.

* Adding some test to check loading logic.

+ modify download logic to not download pytorch file if not necessary.

* Fixing the logic.

* Adressing comments.

* factor out into a function.

* Remove dead function.

* Typo.

* Extra fetch only if safetensors is there.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
Narsil and patrickvonplaten committed Nov 28, 2022
1 parent 6b02323 commit 5755d16
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 67 deletions.
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
"pytest",
"pytest-timeout",
"pytest-xdist",
"safetensors",
"sentencepiece>=0.1.91,!=0.1.92",
"scipy",
"regex!=2019.12.17",
Expand Down Expand Up @@ -184,10 +185,11 @@ def run(self):
"pytest",
"pytest-timeout",
"pytest-xdist",
"safetensors",
"sentencepiece",
"scipy",
"torchvision",
"transformers"
"transformers",
)
extras["torch"] = deps_list("torch", "accelerate")

Expand Down
1 change: 1 addition & 0 deletions src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"safetensors": "safetensors",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"scipy": "scipy",
"regex": "regex!=2019.12.17",
Expand Down
183 changes: 120 additions & 63 deletions src/diffusers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
logging,
)
Expand All @@ -51,6 +53,9 @@
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version

if is_safetensors_available():
import safetensors


def get_parameter_device(parameter: torch.nn.Module):
try:
Expand Down Expand Up @@ -84,10 +89,13 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:

def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
"""
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
try:
return torch.load(checkpoint_file, map_location="cpu")
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
return torch.load(checkpoint_file, map_location="cpu")
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
except Exception as e:
try:
with open(checkpoint_file) as f:
Expand All @@ -104,7 +112,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
f"at '{checkpoint_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
)
Expand Down Expand Up @@ -375,75 +383,39 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
)
else:

model_file = None
if is_safetensors_available():
try:
# Load from URL or cache if already cached
model_file = hf_hub_download(
model_file = _get_model_file(
pretrained_model_name_or_path,
filename=WEIGHTS_NAME,
weights_name=SAFETENSORS_WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)

except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {WEIGHTS_NAME} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {WEIGHTS_NAME}"
)

# restore default dtype
except:
pass
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)

if low_cpu_mem_usage:
# Instantiate model with empty weights
Expand Down Expand Up @@ -691,3 +663,88 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
return unwrap_model(model.module)
else:
return model


def _get_model_file(
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
return model_file
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
try:
# Load from URL or cache if already cached
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
return model_file

except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)
31 changes: 29 additions & 2 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import diffusers
import PIL
from huggingface_hub import snapshot_download
from huggingface_hub import model_info, snapshot_download
from packaging import version
from PIL import Image
from tqdm.auto import tqdm
Expand All @@ -44,6 +44,7 @@
BaseOutput,
deprecate,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
is_transformers_available,
logging,
Expand Down Expand Up @@ -117,6 +118,23 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray


def is_safetensors_compatible(info) -> bool:
filenames = set(sibling.rfilename for sibling in info.siblings)
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
for pt_filename in pt_filenames:
prefix, raw = os.path.split(pt_filename)
if raw == "pytorch_model.bin":
# transformers specific
sf_filename = os.path.join(prefix, "model.safetensors")
else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
if sf_filename not in filenames:
logger.warning("{sf_filename} not found")
is_safetensors_compatible = False
return is_safetensors_compatible


class DiffusionPipeline(ConfigMixin):
r"""
Base class for all models.
Expand Down Expand Up @@ -459,7 +477,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]

# make sure we don't download flax weights
ignore_patterns = "*.msgpack"
ignore_patterns = ["*.msgpack"]

if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
Expand All @@ -473,6 +491,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)

if is_safetensors_available():
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
if is_safetensors_compatible(info):
ignore_patterns.append("*.bin")

# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
is_inflect_available,
is_modelcards_available,
is_onnx_available,
is_safetensors_available,
is_scipy_available,
is_tf_available,
is_torch_available,
Expand Down Expand Up @@ -69,6 +70,7 @@
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
Expand Down
18 changes: 17 additions & 1 deletion src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

Expand All @@ -55,7 +56,7 @@
except importlib_metadata.PackageNotFoundError:
_torch_available = False
else:
logger.info("Disabling PyTorch because USE_TF is set")
logger.info("Disabling PyTorch because USE_TORCH is set")
_torch_available = False


Expand Down Expand Up @@ -109,6 +110,17 @@
else:
_flax_available = False

if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
_safetensors_available = importlib.util.find_spec("safetensors") is not None
if _safetensors_available:
try:
_safetensors_version = importlib_metadata.version("safetensors")
logger.info(f"Safetensors version {_safetensors_version} available.")
except importlib_metadata.PackageNotFoundError:
_safetensors_available = False
else:
logger.info("Disabling Safetensors because USE_TF is set")
_safetensors_available = False

_transformers_available = importlib.util.find_spec("transformers") is not None
try:
Expand Down Expand Up @@ -190,6 +202,10 @@ def is_torch_available():
return _torch_available


def is_safetensors_available():
return _safetensors_available


def is_tf_available():
return _tf_available

Expand Down

0 comments on commit 5755d16

Please sign in to comment.