From 3ee349accb60234eb769b2ad4a7a3ee5af8db4af Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 9 Nov 2023 13:02:03 -0500 Subject: [PATCH 1/2] Custom objects are not saved using saftensors --- src/accelerate/checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py index 11d30d9fef1..0d84e0d8882 100644 --- a/src/accelerate/checkpointing.py +++ b/src/accelerate/checkpointing.py @@ -263,7 +263,7 @@ def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False # Should this be the right way to get a qual_name type value from `obj`? save_location = Path(path) / f"custom_checkpoint_{index}.pkl" logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}") - save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node) + save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node, safe_serialization=False) def load_custom_state(obj, path, index: int = 0): From 25a27679e818bb23eb7f7aa30c96f970a32f635d Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 9 Nov 2023 13:04:55 -0500 Subject: [PATCH 2/2] Leave save as false --- src/accelerate/checkpointing.py | 2 +- src/accelerate/utils/other.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py index 0d84e0d8882..11d30d9fef1 100644 --- a/src/accelerate/checkpointing.py +++ b/src/accelerate/checkpointing.py @@ -263,7 +263,7 @@ def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False # Should this be the right way to get a qual_name type value from `obj`? save_location = Path(path) / f"custom_checkpoint_{index}.pkl" logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}") - save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node, safe_serialization=False) + save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node) def load_custom_state(obj, path, index: int = 0): diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 285dd0a5ad9..43a6f799e59 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -115,7 +115,7 @@ def wait_for_everyone(): PartialState().wait_for_everyone() -def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = True): +def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False): """ Save the data to disk. Use in place of `torch.save()`. @@ -126,8 +126,8 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Tru The file (or file-like object) to use to save the data save_on_each_node (`bool`, *optional*, defaults to `False`): Whether to only save on the global main process - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`). """ save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"}) if PartialState().distributed_type == DistributedType.TPU: