
Refactoring Trainer, adds save_only_model arg and simplifying FSDP integration #27652

Merged: 13 commits, Nov 24, 2023
2 changes: 1 addition & 1 deletion setup.py
@@ -96,7 +96,7 @@
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow>=10.0.1,<=15.0",
"accelerate>=0.20.3",
"accelerate>=0.21.0",
"av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
"beautifulsoup4",
"codecarbon==1.2.0",
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -3,7 +3,7 @@
# 2. run `make deps_table_update`
deps = {
"Pillow": "Pillow>=10.0.1,<=15.0",
"accelerate": "accelerate>=0.20.3",
"accelerate": "accelerate>=0.21.0",
"av": "av==9.2.0",
"beautifulsoup4": "beautifulsoup4",
"codecarbon": "codecarbon==1.2.0",
12 changes: 5 additions & 7 deletions src/transformers/modeling_utils.py
@@ -132,7 +132,7 @@ def is_fsdp_enabled():
)


def is_fsdp_enabled_and_dist_rank_0():
def is_fsdp_enabled_and_local_dist_rank_0():
return is_fsdp_enabled() and int(os.environ.get("LOCAL_RANK", -1)) == 0
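For readers without modeling_utils.py open: the renamed helper distinguishes the first process on each node (local rank 0) from the single global rank 0. A minimal sketch of the pair, assuming FSDP is signalled through the `ACCELERATE_USE_FSDP` environment variable (the body of `is_fsdp_enabled` sits outside this hunk, so that detail is an assumption):

```python
import os

import torch


def is_fsdp_enabled():
    # Sketch only: assumes FSDP is signalled via ACCELERATE_USE_FSDP, which
    # `accelerate launch` sets; the real helper may check additional conditions.
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    )


def is_fsdp_enabled_and_local_dist_rank_0():
    # Matches the renamed helper above: rank 0 of *each* node (LOCAL_RANK == 0),
    # not just the single global rank 0.
    return is_fsdp_enabled() and int(os.environ.get("LOCAL_RANK", -1)) == 0
```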


@@ -473,14 +473,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
)
return safe_load_file(checkpoint_file)
try:
if (
(is_deepspeed_zero3_enabled() or is_fsdp_enabled())
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > 0
):
if (is_deepspeed_zero3_enabled()) and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
map_location = "meta"
else:
map_location = "cpu"

map_location = "cpu" if is_fsdp_enabled_and_local_dist_rank_0 else "meta"
return torch.load(checkpoint_file, map_location=map_location)
except Exception as e:
try:
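To make the intent of the new `map_location` selection explicit: non-zero ranks under DeepSpeed ZeRO-3 load onto the meta device, and under FSDP only local rank 0 of each node materializes real CPU tensors (the other ranks receive the weights later via `sync_module_states`). Note that the bare name `is_fsdp_enabled_and_local_dist_rank_0` in the hunk above is always truthy; the sketch below assumes the intended call and spells out an explicit non-FSDP branch, so it is a reading of the intent rather than a copy of the diff. The import paths are assumptions as of this PR.

```python
import torch

# Assumed import locations for the helpers used below (as of this PR).
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled, is_fsdp_enabled_and_local_dist_rank_0


def select_map_location() -> str:
    # Sketch of the intended device choice when torch.load()-ing a checkpoint.
    if (
        is_deepspeed_zero3_enabled()
        and torch.distributed.is_initialized()
        and torch.distributed.get_rank() > 0
    ):
        # ZeRO-3: only global rank 0 needs real tensors, the rest get meta tensors.
        return "meta"
    if is_fsdp_enabled():
        # FSDP with sync_module_states: one loader per node, everyone else on meta.
        return "cpu" if is_fsdp_enabled_and_local_dist_rank_0() else "meta"
    # Plain single-process / DDP loading.
    return "cpu"
```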
@@ -3904,7 +3902,7 @@ def _find_mismatched_keys(
ignore_mismatched_sizes,
)
if low_cpu_mem_usage:
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
if not is_fsdp_enabled() or is_fsdp_enabled_and_local_dist_rank_0():
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
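The gate above keeps `low_cpu_mem_usage` loading a per-node, local-rank-0-only affair under FSDP; nothing changes on the caller's side. A minimal usage sketch (the checkpoint name is just a placeholder):

```python
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any model works the same way. With FSDP enabled via
# `accelerate launch`, only local rank 0 actually materializes the state dict here,
# the other ranks keep meta tensors until the weights are broadcast.
model = AutoModelForCausalLM.from_pretrained("gpt2", low_cpu_mem_usage=True)
```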
226 changes: 89 additions & 137 deletions src/transformers/trainer.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/transformers/trainer_utils.py
@@ -727,6 +727,8 @@ class FSDPOption(ExplicitEnum):
FULL_SHARD = "full_shard"
SHARD_GRAD_OP = "shard_grad_op"
NO_SHARD = "no_shard"
HYBRID_SHARD = "hybrid_shard"
HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"

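The two new options correspond to PyTorch FSDP's hybrid sharding strategies: shard within a node, replicate parameters across nodes. From the Trainer side they are selected through the `fsdp` argument, for example (a sketch; `output_dir` is arbitrary):

```python
from transformers import TrainingArguments

# FULL_SHARD inside each node, parameter replication across nodes.
args = TrainingArguments(output_dir="out", fsdp="hybrid_shard")

# ZeRO-2-style variant: shard only gradients and optimizer states within a node.
args_zero2 = TrainingArguments(output_dir="out", fsdp="hybrid_shard_zero2")
```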
32 changes: 25 additions & 7 deletions src/transformers/training_args.py
@@ -304,6 +304,11 @@ class TrainingArguments:

This should not be activated when the different nodes use the same storage as the files will be saved with
the same names for each node.
save_only_model (`bool`, *optional*, defaults to `False`):
When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state.
Note that when this is true, you won't be able to resume training from checkpoint.
This enables you to save storage by not storing the optimizer, scheduler & rng state.
You can only load the model using `from_pretrained` with this option set to `True`.
use_cpu (`bool`, *optional*, defaults to `False`):
Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
seed (`int`, *optional*, defaults to 42):
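For completeness, the new flag documented above is used like any other `TrainingArguments` field; a sketch (output directory and save cadence are arbitrary):

```python
from transformers import TrainingArguments

# Sketch: keep only the model weights in each checkpoint. Resuming training from
# such a checkpoint is not possible; reload the weights with `from_pretrained`.
args = TrainingArguments(
    output_dir="out",
    save_strategy="steps",
    save_steps=500,
    save_only_model=True,
)
```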
@@ -418,12 +423,14 @@ class TrainingArguments:

- `"full_shard"`: Shard parameters, gradients and optimizer states.
- `"shard_grad_op"`: Shard optimizer states and gradients.
- `"hybrid_shard"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes.
- `"hybrid_shard_zero2"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes.
- `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
`"shard_grad_op"`).
- `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
fsdp_config (`str` or `dict`, *optional*):
Config to be used with fsdp (Pytorch Distributed Parallel Training). The value is either a location of
deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`.
fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.

A List of config and its options:
- min_num_params (`int`, *optional*, defaults to `0`):
@@ -452,14 +459,18 @@ class TrainingArguments:
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
- use_orig_params (`bool`, *optional*, defaults to `False`)
- use_orig_params (`bool`, *optional*, defaults to `True`)
Collaborator: Why was this set to False by default before - is there an advantage to not having this enabled? Asking in case this introduces a degraded experience for some users.

Contributor Author (pacman100): Hello, this is required now for simplifying the FSDP integration. Please find the explanation in the corresponding Accelerate PR: huggingface/accelerate#2177 (comment)
If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
frozen and trainable paramteres. Useful in cases such as parameter-efficient fine-tuning. Please
refer this
[blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019
- sync_module_states (`bool`, *optional*, defaults to `True`)
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
ensure they are the same across all ranks after initialization
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of
certain layers and recomputing them during a backward pass. Effectively, this trades extra
computation time for reduced memory usage.
- xla (`bool`, *optional*, defaults to `False`):
Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
and its API may evolve in the future.
@@ -472,10 +483,6 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
If True, activation checkpointing is a technique to reduce memory usage by clearing activations of
certain layers and recomputing them during a backward pass. Effectively, this trades extra
computation time for reduced memory usage.
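Putting the `fsdp_config` options above together, a dict-style config passed straight to `TrainingArguments` might look like the following sketch (key spellings follow the docstring excerpt above; the threshold value is arbitrary):

```python
from transformers import TrainingArguments

fsdp_config = {
    "min_num_params": 100_000_000,   # auto-wrap modules above this parameter count
    "limit_all_gathers": True,       # throttle in-flight all-gathers
    "use_orig_params": True,         # new default after this PR
    "sync_module_states": True,      # broadcast weights from (local) rank 0
    "activation_checkpointing": False,
    "xla": False,
}

args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    fsdp_config=fsdp_config,
)
```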

deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
@@ -835,6 +842,17 @@ class TrainingArguments:
)
},
)
save_only_model: bool = field(
default=False,
metadata={
"help": (
"When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state."
"Note that when this is true, you won't be able to resume training from checkpoint."
"This enables you to save storage by not storing the optimizer, scheduler & rng state."
"You can only load the model using from_pretrained with this option set to True."
)
},
)
no_cuda: bool = field(
default=False,
metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
@@ -1670,7 +1688,7 @@ def __post_init__(self):
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefect", "false")
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")

if self.tpu_metrics_debug:
warnings.warn(
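As a reminder of what `__post_init__` does with these keys: each `fsdp_config` entry is exported as an environment variable for accelerate's FSDP plugin to read. A sketch of the effect, assuming the `prefix` variable seen above is `"FSDP_"` (its assignment sits outside this hunk):

```python
import os

fsdp_config = {"sync_module_states": "true"}  # use_orig_params deliberately left unset

prefix = "FSDP_"  # assumption; the assignment is not shown in this diff
os.environ[f"{prefix}SYNC_MODULE_STATES"] = fsdp_config.get("sync_module_states", "true")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = fsdp_config.get("use_orig_params", "true")

# After this PR an unset use_orig_params exports "true"; before, it exported "false".
print(os.environ["FSDP_USE_ORIG_PARAMS"])  # -> "true"
```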
2 changes: 1 addition & 1 deletion src/transformers/utils/import_utils.py
@@ -652,7 +652,7 @@ def is_protobuf_available():
return importlib.util.find_spec("google.protobuf") is not None


def is_accelerate_available(min_version: str = None):
def is_accelerate_available(min_version: str = "0.21.0"):
if min_version is not None:
return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
return _accelerate_available
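With the default minimum bumped to 0.21.0, callers can gate accelerate-dependent code paths on the installed version. A standalone sketch of the equivalent check using `packaging` and `importlib.metadata` (the real helper caches the version at import time instead):

```python
from importlib.metadata import PackageNotFoundError, version as pkg_version

from packaging import version


def accelerate_at_least(min_version: str = "0.21.0") -> bool:
    # Sketch mirroring is_accelerate_available(min_version=...): True only if
    # accelerate is installed and at least `min_version`.
    try:
        installed = pkg_version("accelerate")
    except PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(min_version)


print(accelerate_at_least())       # True when a recent accelerate is installed
print(accelerate_at_least("999"))  # False
```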