Make safetensors the default #2120

Merged: 15 commits, Nov 8, 2023
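In short: model weights saved through `Accelerator.save_model` and `Accelerator.save_state` now default to the `safetensors` format instead of pickle-based `.bin` files. A minimal sketch of the user-facing change, using a toy model and illustrative directory names:

```python
import torch

from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)  # stand-in model for illustration

# safe_serialization now defaults to True, so this writes a .safetensors file
accelerator.save_model(model, "ckpt")

# Opt out to keep the legacy pickle-based .bin file
accelerator.save_model(model, "ckpt_bin", safe_serialization=False)
```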
setup.py — 2 changes: 1 addition & 1 deletion

@@ -54,7 +54,7 @@
]
},
python_requires=">=3.8.0",
install_requires=["numpy>=1.17", "packaging>=20.0", "psutil", "pyyaml", "torch>=1.10.0", "huggingface_hub"],
install_requires=["numpy>=1.17", "packaging>=20.0", "psutil", "pyyaml", "torch>=1.10.0", "huggingface_hub", "safetensors>=0.3.1"],
extras_require=extras,
classifiers=[
"Development Status :: 5 - Production/Stable",
src/accelerate/accelerator.py — 13 changes: 6 additions & 7 deletions

@@ -80,7 +80,6 @@
is_ipex_available,
is_megatron_lm_available,
is_npu_available,
-is_safetensors_available,
is_torch_version,
is_tpu_available,
is_xpu_available,
@@ -2536,7 +2535,7 @@ def save_model(
model: torch.nn.Module,
save_directory: Union[str, os.PathLike],
max_shard_size: Union[int, str] = "10GB",
-safe_serialization: bool = False,
+safe_serialization: bool = True,
):
"""
Save a model so that it can be re-loaded using load_checkpoint_in_model
@@ -2557,7 +2556,7 @@

</Tip>

-safe_serialization (`bool`, *optional*, defaults to `False`):
+safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).

Example:
@@ -2571,9 +2570,6 @@
```
"""

-if safe_serialization and not is_safetensors_available():
-    raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")

if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -2690,7 +2686,7 @@ def register_save_state_pre_hook(self, hook: Callable[..., None]) -> hooks.Remov
self._save_model_state_pre_hook[handle.id] = hook
return handle

-def save_state(self, output_dir: str = None, **save_model_func_kwargs):
+def save_state(self, output_dir: str = None, safe_serialization: bool = True, **save_model_func_kwargs):
"""
Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder.

@@ -2711,6 +2707,8 @@ def save_state(self, output_dir: str = None, **save_model_func_kwargs):
Args:
output_dir (`str` or `os.PathLike`):
The name of the folder to save all relevant weights and states.
+safe_serialization (`bool`, *optional*, defaults to `True`):
+Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
save_model_func_kwargs (`dict`, *optional*):
Additional keyword arguments for saving model which can be passed to the underlying save function, such
as optional arguments for DeepSpeed's `save_checkpoint` function.
@@ -2815,6 +2813,7 @@ def _inner(folder):
self.state.process_index,
self.scaler,
save_on_each_node=self.project_configuration.save_on_each_node,
+safe_serialization=safe_serialization,
)
for i, obj in enumerate(self._custom_objects):
save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
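`save_state` gains a `safe_serialization` flag (default `True`) and forwards it to `save_accelerator_state`, shown in the next file. A hedged usage sketch, with illustrative directory names:

```python
# Model weights are written as safetensors by default; optimizer, scheduler,
# sampler, and RNG states still go through torch.save, as the next file shows.
accelerator.save_state("checkpoint")

# Pass safe_serialization=False to save the model weights with pickle as before.
accelerator.save_state("checkpoint_legacy", safe_serialization=False)
```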
src/accelerate/checkpointing.py — 53 changes: 33 additions & 20 deletions

@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import os
import random
from pathlib import Path
from typing import List

import numpy as np
import torch
+from safetensors.torch import load_file
from torch.cuda.amp import GradScaler

from .utils import (
@@ -54,6 +54,7 @@ def save_accelerator_state(
process_index: int,
scaler: GradScaler = None,
save_on_each_node: bool = False,
+safe_serialization: bool = True,
):
"""
Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
@@ -75,45 +76,49 @@
An optional gradient scaler instance to save
save_on_each_node (`bool`, *optional*):
Whether to save on every node, or only the main node.
+safe_serialization (`bool`, *optional*, defaults to `True`):
+Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""
+output_dir = Path(output_dir)
# Model states
for i, state in enumerate(model_states):
weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
output_model_file = os.path.join(output_dir, weights_name)
save(state, output_model_file, save_on_each_node=save_on_each_node)
weights_name = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
weights_name += ".bin" if not safe_serialization else ".safetensors"
output_model_file = output_dir.joinpath(weights_name)
save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
logger.info(f"Model weights saved in {output_model_file}")
# Optimizer states
for i, opt in enumerate(optimizers):
state = opt.state_dict()
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
-output_optimizer_file = os.path.join(output_dir, optimizer_name)
-save(state, output_optimizer_file, save_on_each_node=save_on_each_node)
+output_optimizer_file = output_dir.joinpath(optimizer_name)
+save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
logger.info(f"Optimizer state saved in {output_optimizer_file}")
# Scheduler states
for i, scheduler in enumerate(schedulers):
state = scheduler.state_dict()
scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
-output_scheduler_file = os.path.join(output_dir, scheduler_name)
-save(state, output_scheduler_file, save_on_each_node=save_on_each_node)
+output_scheduler_file = output_dir.joinpath(scheduler_name)
+save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
logger.info(f"Scheduler state saved in {output_scheduler_file}")
# DataLoader states
for i, dataloader in enumerate(dataloaders):
sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
-output_sampler_file = os.path.join(output_dir, sampler_name)
+output_sampler_file = output_dir.joinpath(sampler_name)
# Only save if we have our custom sampler
from .data_loader import IterableDatasetShard, SeedableRandomSampler

if isinstance(dataloader.dataset, IterableDatasetShard):
sampler = dataloader.sampler.sampler

if isinstance(sampler, SeedableRandomSampler):
-save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

# GradScaler state
if scaler is not None:
state = scaler.state_dict()
-output_scaler_file = os.path.join(output_dir, SCALER_NAME)
+output_scaler_file = output_dir.joinpath(SCALER_NAME)
torch.save(state, output_scaler_file)
logger.info(f"Gradient scaler state saved in {output_scaler_file}")
# Random number generator states
@@ -128,7 +133,7 @@ def save_accelerator_state(
states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
if is_tpu_available():
states["xm_seed"] = xm.get_rng_state()
-output_states_file = os.path.join(output_dir, states_name)
+output_states_file = output_dir.joinpath(states_name)
torch.save(states, output_states_file)
logger.info(f"Random states saved in {output_states_file}")
return output_dir
@@ -174,31 +179,39 @@ def load_accelerator_state(
map_location = "cpu"
elif map_location == "on_device":
map_location = PartialState().device

+input_dir = Path(input_dir)
# Model states
for i, model in enumerate(models):
weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
input_model_file = os.path.join(input_dir, weights_name)
models[i].load_state_dict(torch.load(input_model_file, map_location=map_location), **load_model_func_kwargs)
weights_name = f"{MODEL_NAME}.safetensors" if i == 0 else f"{MODEL_NAME}_{i}.safetensors"
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
input_model_file = input_dir.joinpath(weights_name)
if input_model_file.exists():
state_dict = load_file(input_model_file, device=str(map_location))
else:
# Load with torch
input_model_file = input_model_file.with_suffix(".bin")
state_dict = torch.load(input_model_file, map_location=map_location)
models[i].load_state_dict(state_dict, **load_model_func_kwargs)
logger.info("All model weights loaded successfully")

# Optimizer states
for i, opt in enumerate(optimizers):
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
-input_optimizer_file = os.path.join(input_dir, optimizer_name)
+input_optimizer_file = input_dir.joinpath(optimizer_name)
optimizer_state = torch.load(input_optimizer_file, map_location=map_location)
optimizers[i].load_state_dict(optimizer_state)
logger.info("All optimizer states loaded successfully")

# Scheduler states
for i, scheduler in enumerate(schedulers):
scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
-input_scheduler_file = os.path.join(input_dir, scheduler_name)
+input_scheduler_file = input_dir.joinpath(scheduler_name)
scheduler.load_state_dict(torch.load(input_scheduler_file))
logger.info("All scheduler states loaded successfully")

for i, dataloader in enumerate(dataloaders):
sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
-input_sampler_file = os.path.join(input_dir, sampler_name)
+input_sampler_file = input_dir.joinpath(sampler_name)
# Only load if we have our custom sampler
from .data_loader import IterableDatasetShard, SeedableRandomSampler

@@ -211,13 +224,13 @@

# GradScaler state
if scaler is not None:
-input_scaler_file = os.path.join(input_dir, SCALER_NAME)
+input_scaler_file = input_dir.joinpath(SCALER_NAME)
scaler.load_state_dict(torch.load(input_scaler_file))
logger.info("GradScaler state loaded successfully")

# Random states
try:
-states = torch.load(os.path.join(input_dir, f"{RNG_STATE_NAME}_{process_index}.pkl"))
+states = torch.load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
random.setstate(states["random_state"])
np.random.set_state(states["numpy_random_seed"])
torch.set_rng_state(states["torch_manual_seed"])
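Note the loading side: model weights are looked up as `.safetensors` first, with a `.bin` fallback, so checkpoints written before this PR still load. A standalone sketch of that fallback pattern, assuming the hypothetical file stem `pytorch_model`:

```python
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_model_state(input_dir: str, map_location: str = "cpu") -> dict:
    # Prefer the safetensors file; fall back to the pickle-based .bin.
    input_model_file = Path(input_dir).joinpath("pytorch_model.safetensors")
    if input_model_file.exists():
        return load_file(input_model_file, device=str(map_location))
    return torch.load(input_model_file.with_suffix(".bin"), map_location=map_location)
```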
src/accelerate/test_utils/__init__.py — 3 changes: 2 additions & 1 deletion

@@ -1,15 +1,16 @@
from .testing import (
+SAVE_TYPES,
are_the_same_tensors,
assert_exception,
execute_subprocess_async,
+parameterized_custom_name_func,
require_bnb,
require_cpu,
require_cuda,
require_huggingface_suite,
require_mps,
require_multi_gpu,
require_multi_xpu,
-require_safetensors,
require_single_gpu,
require_single_xpu,
require_torch_min_version,
src/accelerate/test_utils/testing.py — 22 changes: 13 additions & 9 deletions

@@ -26,6 +26,7 @@
from unittest import mock

import torch
+from parameterized import parameterized

from ..state import AcceleratorState, PartialState
from ..utils import (
@@ -37,7 +38,6 @@
is_deepspeed_available,
is_mps_available,
is_pandas_available,
-is_safetensors_available,
is_tensorboard_available,
is_timm_available,
is_torch_version,
@@ -65,6 +65,18 @@ def parse_flag_from_env(key, default=False):
return _value


+def parameterized_custom_name_func(func, param_num, param):
+    # customize the test name generator function as we want both params to appear in the sub-test
+    # name, as by default it shows only the first param
+    param_based_name = parameterized.to_safe_name("_".join(str(x) for x in param.args))
+    return f"{func.__name__}_{param_based_name}"
+
+
+SAFETENSORS = True
+PYTORCH = False
+SAVE_TYPES = (SAFETENSORS, PYTORCH)


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)


@@ -179,14 +191,6 @@ def require_multi_xpu(test_case):
return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)


-def require_safetensors(test_case):
-    """
-    Decorator marking a test that requires safetensors installed. These tests are skipped when safetensors isn't
-    installed
-    """
-    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)


def require_deepspeed(test_case):
"""
Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed
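The new `SAVE_TYPES` tuple and `parameterized_custom_name_func` presumably exist so checkpoint tests can run once per serialization mode with readable sub-test names. A hypothetical test using them might look like:

```python
import unittest

from parameterized import parameterized

from accelerate.test_utils import SAVE_TYPES, parameterized_custom_name_func


class SaveLoadTest(unittest.TestCase):
    @parameterized.expand(SAVE_TYPES, name_func=parameterized_custom_name_func)
    def test_save_state(self, safe_serialization):
        # Expands into test_save_state_True (safetensors) and
        # test_save_state_False (pickle), per the name function above.
        ...
```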
src/accelerate/utils/__init__.py — 1 change: 0 additions & 1 deletion

@@ -59,7 +59,6 @@
is_npu_available,
is_pandas_available,
is_rich_available,
-is_safetensors_available,
is_sagemaker_available,
is_tensorboard_available,
is_timm_available,
src/accelerate/utils/imports.py — 4 changes: 0 additions & 4 deletions

@@ -151,10 +151,6 @@ def is_megatron_lm_available():
return False


-def is_safetensors_available():
-    return _is_package_available("safetensors")


def is_transformers_available():
return _is_package_available("transformers")

src/accelerate/utils/modeling.py — 12 changes: 4 additions & 8 deletions

@@ -30,7 +30,7 @@
from ..state import AcceleratorState
from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from .dataclasses import AutocastKwargs, CustomDtype, DistributedType
-from .imports import is_mps_available, is_npu_available, is_safetensors_available, is_xpu_available
+from .imports import is_mps_available, is_npu_available, is_xpu_available
from .offload import load_offloaded_weight, offload_weight, save_offload_index
from .tqdm import is_tqdm_available, tqdm

@@ -39,9 +39,9 @@
import torch_npu # noqa: F401


-if is_safetensors_available():
-    from safetensors import safe_open
-    from safetensors.torch import load_file as safe_load_file
+from safetensors import safe_open
+from safetensors.torch import load_file as safe_load_file


WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

@@ -1156,10 +1156,6 @@ def load_state_dict(checkpoint_file, device_map=None):
name, once a given module name is inside, every submodule of it will be sent to the same device.
"""
if checkpoint_file.endswith(".safetensors"):
-if not is_safetensors_available():
-    raise ImportError(
-        f"To load {checkpoint_file}, the `safetensors` library is necessary `pip install safetensors`."
-    )
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
weight_names = f.keys()
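With `safetensors` now an unconditional dependency (see setup.py above), `safe_open` and `safe_load_file` are imported at module level and `load_state_dict` drops its availability check. For reference, a minimal sketch of the lazy-read API the function relies on, with a hypothetical file name:

```python
from safetensors import safe_open

with safe_open("pytorch_model.safetensors", framework="pt") as f:
    metadata = f.metadata()        # header metadata, e.g. {"format": "pt"}, or None
    weight_names = list(f.keys())  # tensor names, read without loading any data
    tensor = f.get_tensor(weight_names[0])  # materializes a single tensor
```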
src/accelerate/utils/offload.py — 8 changes: 1 addition & 7 deletions

@@ -19,8 +19,7 @@

import numpy as np
import torch

-from .imports import is_safetensors_available
+from safetensors import safe_open


def offload_weight(weight, weight_name, offload_folder, index=None):
@@ -165,11 +164,6 @@ def __getitem__(self, key: str):
return self.state_dict[key]
weight_info = self.index[key]
if weight_info.get("safetensors_file") is not None:
-if not is_safetensors_available():
-    raise ImportError("These offloaded weights require the use of safetensors: `pip install safetensors`.")
-
-from safetensors import safe_open

device = "cpu" if self.device is None else self.device
with safe_open(weight_info["safetensors_file"], framework="pt", device=device) as f:
tensor = f.get_tensor(weight_info.get("weight_name", key))
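The offload index follows the same pattern: `safe_open` moves to a module-level import and the per-access ImportError disappears. A sketch of the remaining lookup, assuming a `weight_info` entry shaped like the index this file describes (a `safetensors_file` path plus an optional `weight_name`):

```python
from safetensors import safe_open


def read_offloaded_tensor(weight_info: dict, key: str, device=None):
    # Offloaded safetensors weights are read lazily, one tensor per access.
    device = "cpu" if device is None else device
    with safe_open(weight_info["safetensors_file"], framework="pt", device=device) as f:
        return f.get_tensor(weight_info.get("weight_name", key))
```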