From b9da674b6cfad572b75e55567cc00795965186e1 Mon Sep 17 00:00:00 2001
From: Niels
Date: Mon, 4 Mar 2024 13:39:23 +0100
Subject: [PATCH 1/2] Add save_state_dict

---
 src/huggingface_hub/hub_mixin.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py
index 013b7a4711..ae48b44cd0 100644
--- a/src/huggingface_hub/hub_mixin.py
+++ b/src/huggingface_hub/hub_mixin.py
@@ -8,6 +8,7 @@
 from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
 from .file_download import hf_hub_download
 from .hf_api import HfApi
+from .serialization import split_torch_state_dict_into_shards
 from .utils import (
     EntryNotFoundError,
     HfHubHTTPError,
@@ -463,7 +464,25 @@ class PyTorchModelHubMixin(ModelHubMixin):
     def _save_pretrained(self, save_directory: Path) -> None:
         """Save weights from a Pytorch model to a local directory."""
         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
-        save_file(model_to_save.state_dict(), save_directory / SAFETENSORS_SINGLE_FILE)
+
+        def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: Path):
+            state_dict_split = split_torch_state_dict_into_shards(state_dict)
+            for filename, tensors in state_dict_split.filename_to_tensors.items():
+                shard = {tensor: state_dict[tensor] for tensor in tensors}
+                save_file(
+                    shard,
+                    os.path.join(save_directory, filename),
+                    metadata={"format": "pt"},
+                )
+            if state_dict_split.is_sharded:
+                index = {
+                    "metadata": state_dict_split.metadata,
+                    "weight_map": state_dict_split.tensor_to_filename,
+                }
+                with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
+                    f.write(json.dumps(index, indent=2))
+
+        save_state_dict(model_to_save.state_dict(), save_directory)
 
     @classmethod
     def _from_pretrained(

From 2a776e950587526ba004056e07e2e942efaf2996 Mon Sep 17 00:00:00 2001
From: Niels
Date: Mon, 4 Mar 2024 13:50:13 +0100
Subject: [PATCH 2/2] More improvements

---
 src/huggingface_hub/hub_mixin.py              | 23 ++---------------------
 src/huggingface_hub/serialization/__init__.py |  2 +-
 src/huggingface_hub/serialization/_torch.py   | 29 +++++++++++++++++++++++++
 3 files changed, 32 insertions(+), 22 deletions(-)

diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py
index ae48b44cd0..3c6c0f01e5 100644
--- a/src/huggingface_hub/hub_mixin.py
+++ b/src/huggingface_hub/hub_mixin.py
@@ -8,7 +8,7 @@
 from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
 from .file_download import hf_hub_download
 from .hf_api import HfApi
-from .serialization import split_torch_state_dict_into_shards
+from .serialization import save_torch_state_dict
 from .utils import (
     EntryNotFoundError,
     HfHubHTTPError,
@@ -29,7 +29,6 @@
 
 if is_safetensors_available():
     from safetensors import safe_open
-    from safetensors.torch import save_file
 
 
 logger = logging.get_logger(__name__)
@@ -464,25 +463,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
     def _save_pretrained(self, save_directory: Path) -> None:
         """Save weights from a Pytorch model to a local directory."""
         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
-
-        def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: Path):
-            state_dict_split = split_torch_state_dict_into_shards(state_dict)
-            for filename, tensors in state_dict_split.filename_to_tensors.items():
-                shard = {tensor: state_dict[tensor] for tensor in tensors}
-                save_file(
-                    shard,
-                    os.path.join(save_directory, filename),
-                    metadata={"format": "pt"},
-                )
-            if state_dict_split.is_sharded:
-                index = {
-                    "metadata": state_dict_split.metadata,
-                    "weight_map": state_dict_split.tensor_to_filename,
-                }
-                with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
-                    f.write(json.dumps(index, indent=2))
-
-        save_state_dict(model_to_save.state_dict(), save_directory)
+        save_torch_state_dict(model_to_save.state_dict(), save_directory)
 
     @classmethod
     def _from_pretrained(

diff --git a/src/huggingface_hub/serialization/__init__.py b/src/huggingface_hub/serialization/__init__.py
index 0bb6c2d0a1..c3241e1857 100644
--- a/src/huggingface_hub/serialization/__init__.py
+++ b/src/huggingface_hub/serialization/__init__.py
@@ -17,4 +17,4 @@
 from ._base import StateDictSplit, split_state_dict_into_shards_factory
 from ._numpy import split_numpy_state_dict_into_shards
 from ._tensorflow import split_tf_state_dict_into_shards
-from ._torch import split_torch_state_dict_into_shards
+from ._torch import save_torch_state_dict, split_torch_state_dict_into_shards

diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py
index 00ab7e2c80..288f644804 100644
--- a/src/huggingface_hub/serialization/_torch.py
+++ b/src/huggingface_hub/serialization/_torch.py
@@ -14,12 +14,23 @@
 """Contains pytorch-specific helpers."""
 
 import importlib
+import json
+import os
 from functools import lru_cache
+from pathlib import Path
 from typing import TYPE_CHECKING, Dict, Tuple
 
+from ..utils import is_safetensors_available, is_torch_available
 from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
 
 
+if is_torch_available():
+    import torch  # type: ignore
+
+if is_safetensors_available():
+    from safetensors.torch import save_file
+
+
 if TYPE_CHECKING:
     import torch
 
@@ -198,3 +209,21 @@ def _get_dtype_size(dtype: "torch.dtype") -> int:
         _float8_e5m2: 1,
     }
     return _SIZE[dtype]
+
+
+def save_torch_state_dict(state_dict: Dict[str, "torch.Tensor"], save_directory: Path):
+    state_dict_split = split_torch_state_dict_into_shards(state_dict)
+    for filename, tensors in state_dict_split.filename_to_tensors.items():
+        shard = {tensor: state_dict[tensor] for tensor in tensors}
+        save_file(
+            shard,
+            os.path.join(save_directory, filename),
+            metadata={"format": "pt"},
+        )
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+        with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
+            f.write(json.dumps(index, indent=2))
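
For reference, a minimal usage sketch of the save_torch_state_dict helper added by these patches. The nn.Linear model and the "my-model" directory are illustrative placeholders, not part of the patch:

    from pathlib import Path

    import torch.nn as nn

    from huggingface_hub.serialization import save_torch_state_dict

    # Any torch.nn.Module works; a single Linear layer keeps the example small.
    model = nn.Linear(4, 2)

    # The helper writes the safetensors shard(s) into an existing directory and,
    # when the state dict is sharded, also writes model.safetensors.index.json
    # mapping each tensor name to its shard file.
    save_directory = Path("my-model")
    save_directory.mkdir(exist_ok=True)
    save_torch_state_dict(model.state_dict(), save_directory)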