Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add save_state_dict #2084

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
from .file_download import hf_hub_download
from .hf_api import HfApi
from .serialization import save_torch_state_dict
from .utils import (
EntryNotFoundError,
HfHubHTTPError,
Expand All @@ -28,7 +29,6 @@

if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import save_file


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -463,7 +463,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
def _save_pretrained(self, save_directory: Path) -> None:
    """Save weights from a Pytorch model to a local directory.

    Args:
        save_directory (`Path`):
            Directory in which the weights (and, if sharded, the index file)
            will be written.
    """
    # Unwrap containers (e.g. DataParallel-style wrappers) that expose the
    # underlying model via a `.module` attribute.
    model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
    # Delegate to the shared serialization helper, which handles sharding of
    # large state dicts and writing the safetensors index file.
    save_torch_state_dict(model_to_save.state_dict(), save_directory)

@classmethod
def _from_pretrained(
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/serialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
from ._base import StateDictSplit, split_state_dict_into_shards_factory
from ._numpy import split_numpy_state_dict_into_shards
from ._tensorflow import split_tf_state_dict_into_shards
from ._torch import split_torch_state_dict_into_shards
from ._torch import save_torch_state_dict, split_torch_state_dict_into_shards
29 changes: 29 additions & 0 deletions src/huggingface_hub/serialization/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,23 @@
"""Contains pytorch-specific helpers."""

import importlib
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Tuple

from ..utils import is_safetensors_available, is_torch_available
from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory


if is_torch_available():
import torch # type: ignore

if is_safetensors_available():
from safetensors.torch import save_file


if TYPE_CHECKING:
import torch

Expand Down Expand Up @@ -198,3 +209,21 @@ def _get_dtype_size(dtype: "torch.dtype") -> int:
_float8_e5m2: 1,
}
return _SIZE[dtype]


def save_torch_state_dict(state_dict: Dict[str, "torch.Tensor"], save_directory: Path) -> None:
    """Save a PyTorch state dict as (possibly sharded) safetensors files.

    The state dict is split into shards using `split_torch_state_dict_into_shards`.
    Each shard is written to `save_directory` as a safetensors file. If more than
    one shard is produced, a `model.safetensors.index.json` file mapping each
    tensor name to its shard is written alongside them.

    Args:
        state_dict (`Dict[str, torch.Tensor]`):
            The model state dict to save.
        save_directory (`Path`):
            Directory in which the shard files (and optional index) are written.
            Must already exist.
    """
    state_dict_split = split_torch_state_dict_into_shards(state_dict)
    # Fix: iterate over (filename, tensor_names) pairs. The original used
    # `.values()`, which yields only the tensor-name lists and would fail to
    # unpack (or unpack wrongly) into `filename, tensors`.
    for filename, tensors in state_dict_split.filename_to_tensors.items():
        # Materialize the shard from the full state dict, keyed by tensor name.
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        save_file(
            shard,
            os.path.join(save_directory, filename),
            # "pt" marks the file as saved from PyTorch tensors.
            metadata={"format": "pt"},
        )
    # Only write an index file when the weights were actually sharded; a
    # single-file checkpoint needs no index.
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
            f.write(json.dumps(index, indent=2))
Loading