Make safetensors the default #2120

Merged
15 commits merged on Nov 8, 2023
2 changes: 1 addition & 1 deletion setup.py
@@ -54,7 +54,7 @@
]
},
python_requires=">=3.8.0",
install_requires=["numpy>=1.17", "packaging>=20.0", "psutil", "pyyaml", "torch>=1.10.0", "huggingface_hub"],
install_requires=["numpy>=1.17", "packaging>=20.0", "psutil", "pyyaml", "torch>=1.10.0", "huggingface_hub", "safetensors>=0.3.1"],
extras_require=extras,
classifiers=[
"Development Status :: 5 - Production/Stable",
13 changes: 6 additions & 7 deletions src/accelerate/accelerator.py
@@ -80,7 +80,6 @@
is_ipex_available,
is_megatron_lm_available,
is_npu_available,
-is_safetensors_available,
is_torch_version,
is_tpu_available,
is_xpu_available,
@@ -2536,7 +2535,7 @@ def save_model(
model: torch.nn.Module,
save_directory: Union[str, os.PathLike],
max_shard_size: Union[int, str] = "10GB",
-safe_serialization: bool = False,
+safe_serialization: bool = True,
):
"""
Save a model so that it can be re-loaded using load_checkpoint_in_model
@@ -2557,7 +2556,7 @@ def save_model(

</Tip>

-safe_serialization (`bool`, *optional*, defaults to `False`):
+safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).

Example:
@@ -2571,9 +2570,6 @@
```
"""

-if safe_serialization and not is_safetensors_available():
-raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")

if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -2690,7 +2686,7 @@ def register_save_state_pre_hook(self, hook: Callable[..., None]) -> hooks.Remov
self._save_model_state_pre_hook[handle.id] = hook
return handle

-def save_state(self, output_dir: str = None, **save_model_func_kwargs):
+def save_state(self, output_dir: str = None, safe_serialization: bool = True, **save_model_func_kwargs):
"""
Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder.

@@ -2711,6 +2707,8 @@ def save_state(self, output_dir: str = None, **save_model_func_kwargs):
Args:
output_dir (`str` or `os.PathLike`):
The name of the folder to save all relevant weights and states.
+safe_serialization (`bool`, *optional*, defaults to `True`):
+Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
save_model_func_kwargs (`dict`, *optional*):
Additional keyword arguments for saving model which can be passed to the underlying save function, such
as optional arguments for DeepSpeed's `save_checkpoint` function.
@@ -2815,6 +2813,7 @@ def _inner(folder):
self.state.process_index,
self.scaler,
save_on_each_node=self.project_configuration.save_on_each_node,
+safe_serialization=safe_serialization,
)
for i, obj in enumerate(self._custom_objects):
save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
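A minimal usage sketch of the new defaults introduced above (directory names are illustrative):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(10, 10)

# With this PR, model weights are written as `model.safetensors` by default.
accelerator.save_model(model, "checkpoint_dir")

# `save_state` follows the same default; pass `safe_serialization=False`
# to keep the legacy pickle-based `.bin` files.
accelerator.save_state("state_dir", safe_serialization=False)
```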
53 changes: 33 additions & 20 deletions src/accelerate/checkpointing.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import os
import random
from pathlib import Path
from typing import List

import numpy as np
import torch
+from safetensors.torch import load_file
from torch.cuda.amp import GradScaler

from .utils import (
@@ -54,6 +54,7 @@ def save_accelerator_state(
process_index: int,
scaler: GradScaler = None,
save_on_each_node: bool = False,
+safe_serialization: bool = True,
):
"""
Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
@@ -75,45 +76,49 @@ def save_accelerator_state(
An optional gradient scaler instance to save
save_on_each_node (`bool`, *optional*):
Whether to save on every node, or only the main node.
+safe_serialization (`bool`, *optional*, defaults to `True`):
+Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""
+output_dir = Path(output_dir)
# Model states
for i, state in enumerate(model_states):
weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
output_model_file = os.path.join(output_dir, weights_name)
save(state, output_model_file, save_on_each_node=save_on_each_node)
weights_name = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
weights_name += ".bin" if not safe_serialization else ".safetensors"
output_model_file = output_dir.joinpath(weights_name)
save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
logger.info(f"Model weights saved in {output_model_file}")
# Optimizer states
for i, opt in enumerate(optimizers):
state = opt.state_dict()
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
-output_optimizer_file = os.path.join(output_dir, optimizer_name)
-save(state, output_optimizer_file, save_on_each_node=save_on_each_node)
+output_optimizer_file = output_dir.joinpath(optimizer_name)
+save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
logger.info(f"Optimizer state saved in {output_optimizer_file}")
# Scheduler states
for i, scheduler in enumerate(schedulers):
state = scheduler.state_dict()
scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
-output_scheduler_file = os.path.join(output_dir, scheduler_name)
-save(state, output_scheduler_file, save_on_each_node=save_on_each_node)
+output_scheduler_file = output_dir.joinpath(scheduler_name)
+save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
logger.info(f"Scheduler state saved in {output_scheduler_file}")
# DataLoader states
for i, dataloader in enumerate(dataloaders):
sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
-output_sampler_file = os.path.join(output_dir, sampler_name)
+output_sampler_file = output_dir.joinpath(sampler_name)
# Only save if we have our custom sampler
from .data_loader import IterableDatasetShard, SeedableRandomSampler

if isinstance(dataloader.dataset, IterableDatasetShard):
sampler = dataloader.sampler.sampler

if isinstance(sampler, SeedableRandomSampler):
-save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

# GradScaler state
if scaler is not None:
state = scaler.state_dict()
-output_scaler_file = os.path.join(output_dir, SCALER_NAME)
+output_scaler_file = output_dir.joinpath(SCALER_NAME)
torch.save(state, output_scaler_file)
logger.info(f"Gradient scaler state saved in {output_scaler_file}")
# Random number generator states
@@ -128,7 +133,7 @@ def save_accelerator_state(
states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
if is_tpu_available():
states["xm_seed"] = xm.get_rng_state()
-output_states_file = os.path.join(output_dir, states_name)
+output_states_file = output_dir.joinpath(states_name)
torch.save(states, output_states_file)
logger.info(f"Random states saved in {output_states_file}")
return output_dir
@@ -174,31 +179,39 @@ def load_accelerator_state(
map_location = "cpu"
elif map_location == "on_device":
map_location = PartialState().device

+input_dir = Path(input_dir)
# Model states
for i, model in enumerate(models):
weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
input_model_file = os.path.join(input_dir, weights_name)
models[i].load_state_dict(torch.load(input_model_file, map_location=map_location), **load_model_func_kwargs)
weights_name = f"{MODEL_NAME}.safetensors" if i == 0 else f"{MODEL_NAME}_{i}.safetensors"
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
input_model_file = input_dir.joinpath(weights_name)
if input_model_file.exists():
state_dict = load_file(input_model_file, device=str(map_location))
else:
# Load with torch
input_model_file = input_model_file.with_suffix(".bin")
state_dict = torch.load(input_model_file, map_location=map_location)
models[i].load_state_dict(state_dict, **load_model_func_kwargs)
logger.info("All model weights loaded successfully")

# Optimizer states
for i, opt in enumerate(optimizers):
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
-input_optimizer_file = os.path.join(input_dir, optimizer_name)
+input_optimizer_file = input_dir.joinpath(optimizer_name)
optimizer_state = torch.load(input_optimizer_file, map_location=map_location)
optimizers[i].load_state_dict(optimizer_state)
logger.info("All optimizer states loaded successfully")

# Scheduler states
for i, scheduler in enumerate(schedulers):
scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
-input_scheduler_file = os.path.join(input_dir, scheduler_name)
+input_scheduler_file = input_dir.joinpath(scheduler_name)
scheduler.load_state_dict(torch.load(input_scheduler_file))
logger.info("All scheduler states loaded successfully")

for i, dataloader in enumerate(dataloaders):
sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
-input_sampler_file = os.path.join(input_dir, sampler_name)
+input_sampler_file = input_dir.joinpath(sampler_name)
# Only load if we have our custom sampler
from .data_loader import IterableDatasetShard, SeedableRandomSampler

Expand All @@ -211,13 +224,13 @@ def load_accelerator_state(

# GradScaler state
if scaler is not None:
-input_scaler_file = os.path.join(input_dir, SCALER_NAME)
+input_scaler_file = input_dir.joinpath(SCALER_NAME)
scaler.load_state_dict(torch.load(input_scaler_file))
logger.info("GradScaler state loaded successfully")

# Random states
try:
-states = torch.load(os.path.join(input_dir, f"{RNG_STATE_NAME}_{process_index}.pkl"))
+states = torch.load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
random.setstate(states["random_state"])
np.random.set_state(states["numpy_random_seed"])
torch.set_rng_state(states["torch_manual_seed"])
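The model-loading fallback above, in isolation — a sketch assuming a single unsharded checkpoint and illustrative names (the real function also handles numbered shards and the `map_location` variants):

```python
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_model_state(input_dir: str, map_location: str = "cpu") -> dict:
    # Prefer the safetensors checkpoint; fall back to the legacy
    # pickle-based `.bin` file when that is all that exists on disk.
    model_file = Path(input_dir) / "model.safetensors"
    if model_file.exists():
        return load_file(model_file, device=str(map_location))
    return torch.load(model_file.with_suffix(".bin"), map_location=map_location)
```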
1 change: 0 additions & 1 deletion src/accelerate/test_utils/__init__.py
@@ -9,7 +9,6 @@
require_mps,
require_multi_gpu,
require_multi_xpu,
-require_safetensors,
require_single_gpu,
require_single_xpu,
require_torch_min_version,
9 changes: 0 additions & 9 deletions src/accelerate/test_utils/testing.py
@@ -37,7 +37,6 @@
is_deepspeed_available,
is_mps_available,
is_pandas_available,
-is_safetensors_available,
is_tensorboard_available,
is_timm_available,
is_torch_version,
@@ -179,14 +178,6 @@ def require_multi_xpu(test_case):
return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)


-def require_safetensors(test_case):
-"""
-Decorator marking a test that requires safetensors installed. These tests are skipped when safetensors isn't
-installed
-"""
-return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)


def require_deepspeed(test_case):
"""
Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed
1 change: 0 additions & 1 deletion src/accelerate/utils/__init__.py
@@ -59,7 +59,6 @@
is_npu_available,
is_pandas_available,
is_rich_available,
-is_safetensors_available,
is_sagemaker_available,
is_tensorboard_available,
is_timm_available,
4 changes: 0 additions & 4 deletions src/accelerate/utils/imports.py
@@ -151,10 +151,6 @@ def is_megatron_lm_available():
return False


-def is_safetensors_available():
-return _is_package_available("safetensors")


def is_transformers_available():
return _is_package_available("transformers")

12 changes: 4 additions & 8 deletions src/accelerate/utils/modeling.py
@@ -30,7 +30,7 @@
from ..state import AcceleratorState
from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from .dataclasses import AutocastKwargs, CustomDtype, DistributedType
-from .imports import is_mps_available, is_npu_available, is_safetensors_available, is_xpu_available
+from .imports import is_mps_available, is_npu_available, is_xpu_available
from .offload import load_offloaded_weight, offload_weight, save_offload_index
from .tqdm import is_tqdm_available, tqdm

Expand All @@ -39,9 +39,9 @@
import torch_npu # noqa: F401


-if is_safetensors_available():
-from safetensors import safe_open
-from safetensors.torch import load_file as safe_load_file
+from safetensors import safe_open
+from safetensors.torch import load_file as safe_load_file


WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

@@ -1156,10 +1156,6 @@ def load_state_dict(checkpoint_file, device_map=None):
name, once a given module name is inside, every submodule of it will be sent to the same device.
"""
if checkpoint_file.endswith(".safetensors"):
-if not is_safetensors_available():
-raise ImportError(
-f"To load {checkpoint_file}, the `safetensors` library is necessary `pip install safetensors`."
-)
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
weight_names = f.keys()
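For context on why the availability guard could be dropped: `safe_open` is now always importable, and it reads checkpoints lazily. A short sketch (file name is illustrative):

```python
from safetensors import safe_open

# Inspect a safetensors checkpoint without materializing every tensor.
with safe_open("model.safetensors", framework="pt") as f:
    metadata = f.metadata()                 # e.g. {"format": "pt"}
    weight_names = list(f.keys())           # names only; no tensor data read yet
    tensor = f.get_tensor(weight_names[0])  # loads a single tensor on demand
```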
8 changes: 1 addition & 7 deletions src/accelerate/utils/offload.py
@@ -19,8 +19,7 @@

import numpy as np
import torch

-from .imports import is_safetensors_available
+from safetensors import safe_open


def offload_weight(weight, weight_name, offload_folder, index=None):
@@ -165,11 +164,6 @@ def __getitem__(self, key: str):
return self.state_dict[key]
weight_info = self.index[key]
if weight_info.get("safetensors_file") is not None:
-if not is_safetensors_available():
-raise ImportError("These offloaded weights require the use of safetensors: `pip install safetensors`.")
-
-from safetensors import safe_open

device = "cpu" if self.device is None else self.device
with safe_open(weight_info["safetensors_file"], framework="pt", device=device) as f:
tensor = f.get_tensor(weight_info.get("weight_name", key))
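Since `safe_open` also accepts a target device, offloaded weights can be read back directly onto an accelerator without an intermediate CPU copy — a sketch with illustrative file and tensor names:

```python
from safetensors import safe_open

# Load one offloaded tensor straight to the GPU.
with safe_open("offload/block0.safetensors", framework="pt", device="cuda:0") as f:
    weight = f.get_tensor("linear.weight")
```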
12 changes: 5 additions & 7 deletions src/accelerate/utils/other.py
@@ -22,13 +22,14 @@

import torch
from packaging.version import Version
+from safetensors.torch import save_file as safe_save_file

from ..commands.config.default import write_basic_config # noqa: F401
from ..logging import get_logger
from ..state import PartialState
from .constants import FSDP_PYTORCH_VERSION
from .dataclasses import DistributedType
-from .imports import is_deepspeed_available, is_safetensors_available, is_tpu_available
+from .imports import is_deepspeed_available, is_tpu_available
from .transformer_engine import convert_model
from .versions import is_torch_version

Expand All @@ -39,9 +40,6 @@
if is_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm

-if is_safetensors_available():
-from safetensors.torch import save_file as safe_save_file


def is_compiled_module(module):
"""
@@ -117,7 +115,7 @@ def wait_for_everyone():
PartialState().wait_for_everyone()


-def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
+def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = True):
"""
Save the data to disk. Use in place of `torch.save()`.

@@ -128,8 +126,8 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Fal
The file (or file-like object) to use to save the data
save_on_each_node (`bool`, *optional*, defaults to `False`):
Whether to only save on the global main process
-safe_serialization (`bool`, *optional*, defaults to `False`):
-Whether to save `obj` using `safetensors`
+safe_serialization (`bool`, *optional*, defaults to `True`):
+Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""
save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"})
if PartialState().distributed_type == DistributedType.TPU:
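A note on the asymmetry in the checkpointing changes above: safetensors serializes only flat name-to-tensor mappings, which is why model states use the new default while optimizer, scheduler, and sampler states are still saved with `safe_serialization=False`. A minimal sketch:

```python
from functools import partial

import torch
from safetensors.torch import save_file as safe_save_file

# The same save function `save()` builds when safe_serialization=True.
save_func = partial(safe_save_file, metadata={"format": "pt"})

# A model state dict is a flat mapping of names to tensors — exactly
# what safetensors can represent.
save_func(torch.nn.Linear(2, 2).state_dict(), "model.safetensors")

# Optimizer/scheduler states contain nested dicts and Python scalars,
# which safetensors cannot store, so those still go through torch.save.
```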
3 changes: 1 addition & 2 deletions tests/test_accelerator.py
@@ -10,7 +10,7 @@
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights
from accelerate.accelerator import Accelerator
from accelerate.state import GradientState, PartialState
-from accelerate.test_utils import require_bnb, require_multi_gpu, require_safetensors, slow
+from accelerate.test_utils import require_bnb, require_multi_gpu, slow
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda
from accelerate.utils import patch_environment
from accelerate.utils.modeling import load_checkpoint_in_model
@@ -126,7 +126,6 @@ def test_save_model_pytorch(self):
load_checkpoint_in_model(model, tmpdirname)
self.assertTrue(abs(model_signature - get_signature(model)) < 1e-3)

-@require_safetensors
def test_save_model_safetensors(self):
accelerator = Accelerator()
model = torch.nn.Linear(10, 10)
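A condensed version of what `test_save_model_safetensors` exercises now that safetensors is a hard dependency (a sketch; the real test also compares model signatures before and after the round trip):

```python
import tempfile

import torch
from accelerate import Accelerator
from accelerate.utils.modeling import load_checkpoint_in_model

accelerator = Accelerator()
model = torch.nn.Linear(10, 10)

with tempfile.TemporaryDirectory() as tmpdirname:
    # Default safe_serialization=True writes `model.safetensors`.
    accelerator.save_model(model, tmpdirname)
    load_checkpoint_in_model(model, tmpdirname)
```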