diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 013b7a4711..3c6c0f01e5 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -8,6 +8,7 @@ from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE from .file_download import hf_hub_download from .hf_api import HfApi +from .serialization import save_torch_state_dict from .utils import ( EntryNotFoundError, HfHubHTTPError, @@ -28,7 +29,6 @@ if is_safetensors_available(): from safetensors import safe_open - from safetensors.torch import save_file logger = logging.get_logger(__name__) @@ -463,7 +463,7 @@ class PyTorchModelHubMixin(ModelHubMixin): def _save_pretrained(self, save_directory: Path) -> None: """Save weights from a Pytorch model to a local directory.""" model_to_save = self.module if hasattr(self, "module") else self # type: ignore - save_file(model_to_save.state_dict(), save_directory / SAFETENSORS_SINGLE_FILE) + save_torch_state_dict(model_to_save.state_dict(), save_directory) @classmethod def _from_pretrained( diff --git a/src/huggingface_hub/serialization/__init__.py b/src/huggingface_hub/serialization/__init__.py index 0bb6c2d0a1..c3241e1857 100644 --- a/src/huggingface_hub/serialization/__init__.py +++ b/src/huggingface_hub/serialization/__init__.py @@ -17,4 +17,4 @@ from ._base import StateDictSplit, split_state_dict_into_shards_factory from ._numpy import split_numpy_state_dict_into_shards from ._tensorflow import split_tf_state_dict_into_shards -from ._torch import split_torch_state_dict_into_shards +from ._torch import save_torch_state_dict, split_torch_state_dict_into_shards diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py index 00ab7e2c80..288f644804 100644 --- a/src/huggingface_hub/serialization/_torch.py +++ b/src/huggingface_hub/serialization/_torch.py @@ -14,12 +14,23 @@ """Contains pytorch-specific helpers.""" import importlib +import json 
+import os
 from functools import lru_cache
+from pathlib import Path
 from typing import TYPE_CHECKING, Dict, Tuple
 
+from ..utils import is_safetensors_available, is_torch_available
 from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
 
 
+if is_torch_available():
+    import torch  # type: ignore
+
+if is_safetensors_available():
+    from safetensors.torch import save_file
+
+
 if TYPE_CHECKING:
     import torch
 
@@ -198,3 +209,27 @@ def _get_dtype_size(dtype: "torch.dtype") -> int:
         _float8_e5m2: 1,
     }
     return _SIZE[dtype]
+
+
+def save_torch_state_dict(state_dict: Dict[str, "torch.Tensor"], save_directory: Path) -> None:
+    """Save a torch state dict under `save_directory` as safetensors shard(s).
+
+    If the state dict had to be split into multiple shards, a
+    `model.safetensors.index.json` file mapping each tensor to its shard is
+    written alongside them.
+    """
+    state_dict_split = split_torch_state_dict_into_shards(state_dict)
+    # Iterate (filename -> tensor names) pairs; `.items()`, not `.values()`.
+    for filename, tensors in state_dict_split.filename_to_tensors.items():
+        shard = {tensor: state_dict[tensor] for tensor in tensors}
+        save_file(
+            shard,
+            os.path.join(save_directory, filename),
+            metadata={"format": "pt"},
+        )
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+        with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
+            f.write(json.dumps(index, indent=2))