diff --git a/docs/source/en/main_classes/trainer.md b/docs/source/en/main_classes/trainer.md index 7f85d6d72ad02..7304de8174dcd 100644 --- a/docs/source/en/main_classes/trainer.md +++ b/docs/source/en/main_classes/trainer.md @@ -426,8 +426,7 @@ To read more about it and the benefits, check out the [Fully Sharded Data Parall We have integrated the latest PyTorch's Fully Sharded Data Parallel (FSDP) training feature. All you need to do is enable it through the config. -**Required PyTorch version for FSDP support**: PyTorch Nightly (or 1.12.0 if you read this after it has been released) -as the model saving with FSDP activated is only available with recent fixes. +**Required PyTorch version for FSDP support**: PyTorch >=2.1.0 **Usage**: @@ -440,6 +439,8 @@ as the model saving with FSDP activated is only available with recent fixes. - SHARD_GRAD_OP : Shards optimizer states + gradients across data parallel workers/GPUs. For this, add `--fsdp shard_grad_op` to the command line arguments. - NO_SHARD : No sharding. For this, add `--fsdp no_shard` to the command line arguments. + - HYBRID_SHARD : Apply FULL_SHARD within a node, and replicate parameters across nodes. For this, add `--fsdp hybrid_shard` to the command line arguments. + - HYBRID_SHARD_ZERO2 : Apply SHARD_GRAD_OP within a node, and replicate parameters across nodes. For this, add `--fsdp hybrid_shard_zero2` to the command line arguments. - To offload the parameters and gradients to the CPU, add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments. - To automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`, @@ -449,18 +450,18 @@ as the model saving with FSDP activated is only available with recent fixes. - Remaining FSDP config is passed via `--fsdp_config `. It is either a location of FSDP json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`. - If auto wrapping is enabled, you can either use transformer based auto wrap policy or size based auto wrap policy. - - For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available. + - For transformer based auto wrap policy, it is recommended to specify `transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available. This specifies the list of transformer layer class name (case-sensitive) to wrap ,e.g, [`BertLayer`], [`GPTJBlock`], [`T5Block`] .... This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units. Using this policy, wrapping happens for each block containing Multi-Head Attention followed by couple of MLP layers. Remaining layers including the shared embeddings are conveniently wrapped in same outermost FSDP unit. Therefore, use this for transformer based models. - - For size based auto wrap policy, please add `fsdp_min_num_params` in the config file. + - For size based auto wrap policy, please add `min_num_params` in the config file. It specifies FSDP's minimum number of parameters for auto wrapping. - - `fsdp_backward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters. + - `backward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters. `backward_pre` and `backward_pos` are available options.
For more information refer `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch` - - `fsdp_forward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters. + - `forward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters. If `"True"`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. - `limit_all_gathers` can be specified in the config file. If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. @@ -468,6 +469,20 @@ as the model saving with FSDP activated is only available with recent fixes. If `"True"`, FSDP activation checkpointing is a technique to reduce memory usage by clearing activations of certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time for reduced memory usage. + - `use_orig_params` can be specified in the config file. + If `True`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. This also enables having different optimizer param groups. This should be `True` when creating the optimizer object before preparing/wrapping the model with FSDP. + Please refer to this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019). + +**Saving and loading** +Saving entire intermediate checkpoints using `FULL_STATE_DICT` state_dict_type with CPU offloading on rank 0 takes a lot of time and often results in NCCL Timeout errors due to indefinite hanging during broadcasting. However, at the end of training, we want the whole model state dict instead of the sharded state dict, which is only compatible with FSDP. Use the `SHARDED_STATE_DICT` (default) state_dict_type to save the intermediate checkpoints and optimizer states in this format recommended by the PyTorch team. + +Saving the final checkpoint in the transformers format using the default `safetensors` format requires the changes below. +```python +if trainer.is_fsdp_enabled: + trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + +trainer.save_model(script_args.output_dir) +``` **Few caveats to be aware of** - it is incompatible with `generate`, thus is incompatible with `--predict_with_generate` @@ -492,15 +507,15 @@ Pass `--fsdp "full shard"` along with following changes to be made in `--fsdp_co https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py). - `xla_fsdp_grad_ckpt`. When `True`, uses gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be used when the xla flag is set to true, and an auto wrapping policy is specified through - `fsdp_min_num_params` or `fsdp_transformer_layer_cls_to_wrap`. + `min_num_params` or `transformer_layer_cls_to_wrap`. - You can either use transformer based auto wrap policy or size based auto wrap policy. - - For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available. + - For transformer based auto wrap policy, it is recommended to specify `transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
This specifies the list of transformer layer class name (case-sensitive) to wrap ,e.g, [`BertLayer`], [`GPTJBlock`], [`T5Block`] .... This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units. Using this policy, wrapping happens for each block containing Multi-Head Attention followed by couple of MLP layers. Remaining layers including the shared embeddings are conveniently wrapped in same outermost FSDP unit. Therefore, use this for transformer based models. - - For size based auto wrap policy, please add `fsdp_min_num_params` in the config file. + - For size based auto wrap policy, please add `min_num_params` in the config file. It specifies FSDP's minimum number of parameters for auto wrapping. diff --git a/setup.py b/setup.py index a51f2a7a5266a..eb240c8172f0f 100644 --- a/setup.py +++ b/setup.py @@ -96,7 +96,7 @@ # 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py _deps = [ "Pillow>=10.0.1,<=15.0", - "accelerate>=0.20.3", + "accelerate>=0.21.0", "av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream. "beautifulsoup4", "codecarbon==1.2.0", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 08fddd2e1ecc6..5bef2ec9e22e7 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -3,7 +3,7 @@ # 2. run `make deps_table_update`` deps = { "Pillow": "Pillow>=10.0.1,<=15.0", - "accelerate": "accelerate>=0.20.3", + "accelerate": "accelerate>=0.21.0", "av": "av==9.2.0", "beautifulsoup4": "beautifulsoup4", "codecarbon": "codecarbon==1.2.0", diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e2b27de7d1e51..d60d795a0f93d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -132,8 +132,12 @@ def is_fsdp_enabled(): ) -def is_fsdp_enabled_and_dist_rank_0(): - return is_fsdp_enabled() and int(os.environ.get("LOCAL_RANK", -1)) == 0 +def is_local_dist_rank_0(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and int(os.environ.get("LOCAL_RANK", -1)) == 0 + ) if is_sagemaker_mp_enabled(): @@ -474,13 +478,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): return safe_load_file(checkpoint_file) try: if ( - (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) - and torch.distributed.is_initialized() - and torch.distributed.get_rank() > 0 - ): + is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0 + ) or (is_fsdp_enabled() and not is_local_dist_rank_0()): map_location = "meta" else: map_location = "cpu" + return torch.load(checkpoint_file, map_location=map_location) except Exception as e: try: @@ -3904,7 +3907,18 @@ def _find_mismatched_keys( ignore_mismatched_sizes, ) if low_cpu_mem_usage: - if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0(): + if is_fsdp_enabled() and not is_local_dist_rank_0(): + for key, param in model_to_load.state_dict().items(): + if param.device == torch.device("meta"): + if not (is_quantized): + set_module_tensor_to_device( + model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) + else: + set_module_quantized_tensor_to_device( + model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) + else: new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, 
state_dict, @@ -3922,17 +3936,6 @@ def _find_mismatched_keys( keep_in_fp32_modules=keep_in_fp32_modules, ) error_msgs += new_error_msgs - else: - for key, param in model_to_load.state_dict().items(): - if param.device == torch.device("meta"): - if not (is_quantized): - set_module_tensor_to_device( - model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) - ) - else: - set_module_quantized_tensor_to_device( - model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) - ) else: error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 0bb123d0e7c2e..7a4fcd129cb3e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -99,7 +99,6 @@ BestRun, EvalLoopOutput, EvalPrediction, - FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy, @@ -193,15 +192,15 @@ if is_accelerate_available(): from accelerate import Accelerator, skip_first_batches from accelerate import __version__ as accelerate_version - from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin - - if version.parse(accelerate_version) > version.parse("0.20.3"): - from accelerate.utils import ( - load_fsdp_model, - load_fsdp_optimizer, - save_fsdp_model, - save_fsdp_optimizer, - ) + from accelerate.utils import ( + DistributedDataParallelKwargs, + GradientAccumulationPlugin, + load_fsdp_model, + load_fsdp_optimizer, + save_fsdp_model, + save_fsdp_optimizer, + ) + DATA_SAMPLERS = [RandomSampler] if version.parse(accelerate_version) > version.parse("0.23.0"): from accelerate.data_loader import SeedableRandomSampler @@ -226,6 +225,7 @@ OPTIMIZER_NAME_BIN = "optimizer.bin" SCHEDULER_NAME = "scheduler.pt" SCALER_NAME = "scaler.pt" +FSDP_MODEL_NAME = "pytorch_model_fsdp" class Trainer: @@ -415,7 +415,7 @@ def __init__( " model, please make sure that you have installed `bitsandbytes>=0.37.0`. " ) - self.fsdp = None + self.is_fsdp_xla_enabled = args.fsdp_config["xla"] if len(args.fsdp) > 0: if self.is_deepspeed_enabled: raise ValueError( @@ -424,32 +424,6 @@ def __init__( if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: raise ValueError("Using fsdp only works in distributed training.") - # dep_version_check("torch>=1.12.0") - # Would have to update setup.py with torch>=1.12.0 - # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0 - # below is the current alternative. - if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"): - raise ValueError("FSDP requires PyTorch >= 1.12.0") - - from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy - - if FSDPOption.FULL_SHARD in args.fsdp: - self.fsdp = ShardingStrategy.FULL_SHARD - elif FSDPOption.SHARD_GRAD_OP in args.fsdp: - self.fsdp = ShardingStrategy.SHARD_GRAD_OP - elif FSDPOption.NO_SHARD in args.fsdp: - self.fsdp = ShardingStrategy.NO_SHARD - - self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE - if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get( - "backward_prefetch", [] - ): - self.backward_prefetch = BackwardPrefetch.BACKWARD_POST - - self.limit_all_gathers = False - if self.args.fsdp_config.get("limit_all_gathers", False): - self.limit_all_gathers = True - # one place to sort out whether to place the model on device or not # postpone switching model to cuda when: # 1. 
MP - since we are trying to fit a much bigger than 1 gpu model @@ -462,7 +436,7 @@ def __init__( self.is_model_parallel or self.is_deepspeed_enabled or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) - or (self.fsdp is not None) + or self.is_fsdp_xla_enabled or self.is_fsdp_enabled ): self.place_model_on_device = False @@ -513,7 +487,7 @@ def __init__( " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." ) - if (self.is_deepspeed_enabled or (self.fsdp is not None)) and ( + if (self.is_deepspeed_enabled or self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and ( self.optimizer is not None or self.lr_scheduler is not None ): raise RuntimeError( @@ -1367,7 +1341,7 @@ def _wrap_model(self, model, training=True, dataloader=None): # Distributed training (should be after apex fp16 initialization) # Distributed training using PyTorch FSDP - if self.fsdp is not None and self.args.fsdp_config["xla"]: + if self.is_fsdp_xla_enabled: try: from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP from torch_xla.distributed.fsdp import checkpoint_module @@ -1626,7 +1600,7 @@ def _inner_training_loop( else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled # We need to reset the scheduler, as its parameters may be different on subsequent calls if self._created_lr_scheduler: @@ -1676,8 +1650,6 @@ def _inner_training_loop( use_accelerator_prepare = True if model is self.model else False if delay_optimizer_creation: - if use_accelerator_prepare: - self.model = self.accelerator.prepare(self.model) self.create_optimizer_and_scheduler(num_training_steps=max_steps) # prepare using `accelerator` prepare @@ -1895,9 +1867,7 @@ def _inner_training_loop( ): # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered # in accelerate. So, explicitly enable sync gradients to True in that case. - if is_last_step_and_steps_less_than_grad_acc or ( - version.parse(accelerate_version) <= version.parse("0.20.3") - ): + if is_last_step_and_steps_less_than_grad_acc: self.accelerator.gradient_state._set_sync_gradients(True) # Gradient clipping @@ -2051,7 +2021,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any( - WEIGHTS_NAME.split(".")[0] in folder_name + FSDP_MODEL_NAME in folder_name for folder_name in os.listdir(resume_from_checkpoint) if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) ) @@ -2360,56 +2330,12 @@ def _save_checkpoint(self, model, trial, metrics=None): run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) self.save_model(output_dir, _internal_call=True) - if self.is_deepspeed_enabled: - # under zero3 model file itself doesn't get saved since it's bogus! 
Unless deepspeed - # config `stage3_gather_16bit_weights_on_model_save` is True - self.model_wrapped.save_checkpoint(output_dir) - # Save optimizer and scheduler - if self.fsdp or self.is_fsdp_enabled: - if self.is_fsdp_enabled: - save_fsdp_optimizer( - self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir - ) - else: - # FSDP has a different interface for saving optimizer states. - # Needs to be called on all ranks to gather all states. - # full_optim_state_dict will be deprecated after Pytorch 2.2! - full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer) - torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME)) - - if is_torch_tpu_available(): - xm.rendezvous("saving_optimizer_states") - xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) - with warnings.catch_warnings(record=True) as caught_warnings: - xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) - reissue_pt_warnings(caught_warnings) - elif is_sagemaker_mp_enabled(): - opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) - smp.barrier() - if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: - smp.save( - opt_state_dict, - os.path.join(output_dir, OPTIMIZER_NAME), - partial=True, - v3=smp.state.cfg.shard_optimizer_state, - ) - elif self.args.should_save and not self.is_deepspeed_enabled and not (self.fsdp or self.is_fsdp_enabled): - # deepspeed.save_checkpoint above saves model/optim/sched - torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) - - # Save SCHEDULER & SCALER - is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance( - self.lr_scheduler, DeepSpeedSchedulerWrapper - ) - if ( - self.args.should_save - and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler) - and not is_torch_tpu_available() - ): - with warnings.catch_warnings(record=True) as caught_warnings: - torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) - reissue_pt_warnings(caught_warnings) + if not self.args.save_only_model: + # Save optimizer and scheduler + self._save_optimizer_and_scheduler(output_dir) + # Save RNG state + self._save_rng_state(output_dir) # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: @@ -2431,6 +2357,14 @@ def _save_checkpoint(self, model, trial, metrics=None): if self.args.should_save: self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. 
+ if self.args.should_save: + self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + + def _save_rng_state(self, output_dir): # Save RNG state in non-distributed training rng_states = { "python": random.getstate(), @@ -2462,12 +2396,49 @@ def _save_checkpoint(self, model, trial, metrics=None): else: torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) - if self.args.push_to_hub: - self._push_from_checkpoint(output_dir) + def _save_optimizer_and_scheduler(self, output_dir): + if is_torch_tpu_available(): + xm.rendezvous("saving_optimizer_states") + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) + smp.barrier() + if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: + smp.save( + opt_state_dict, + os.path.join(output_dir, OPTIMIZER_NAME), + partial=True, + v3=smp.state.cfg.shard_optimizer_state, + ) + elif self.is_deepspeed_enabled: + # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed + # config `stage3_gather_16bit_weights_on_model_save` is True + self.model_wrapped.save_checkpoint(output_dir) + elif self.is_fsdp_enabled: + # save fsdp specific ckpt for resuming from ckpt + save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir) + save_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir + ) + elif self.args.should_save: + # deepspeed.save_checkpoint above saves model/optim/sched + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) - # Maybe delete some older checkpoints. 
- if self.args.should_save: - self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + # Save SCHEDULER & SCALER + is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance( + self.lr_scheduler, DeepSpeedSchedulerWrapper + ) + if ( + self.args.should_save + and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler) + and not is_torch_tpu_available() + ): + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) def _load_optimizer_and_scheduler(self, checkpoint): """If optimizer and scheduler states exist, load them.""" @@ -2535,23 +2506,14 @@ def opt_load_hook(mod, opt): # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more # likely to get OOM on CPU (since we load num_gpu times the optimizer state map_location = self.args.device if self.args.world_size > 1 else "cpu" - if self.fsdp or self.is_fsdp_enabled: - if self.is_fsdp_enabled: - load_fsdp_optimizer( - self.accelerator.state.fsdp_plugin, - self.accelerator, - self.optimizer, - self.model, - checkpoint, - ) - else: - full_osd = None - # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it - if self.args.process_index == 0: - full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME)) - # call scatter_full_optim_state_dict on all ranks - sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model) - self.optimizer.load_state_dict(sharded_osd) + if self.is_fsdp_enabled: + load_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, + self.accelerator, + self.optimizer, + self.model, + checkpoint, + ) else: self.optimizer.load_state_dict( torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) @@ -2826,19 +2788,14 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa if IS_SAGEMAKER_MP_POST_1_10: # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 Path(os.path.join(output_dir, "user_content.pt")).touch() - elif self.fsdp is not None or self.is_fsdp_enabled: - state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {} - if self.args.should_save: - self._save(output_dir, state_dict=state_dict) - if self.is_fsdp_enabled: - # remove the dummy state_dict - remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) - save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir) - + elif self.is_fsdp_enabled: + if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and ( + version.parse(accelerate_version) > version.parse("0.24.1") + ): + state_dict = self.accelerator.get_state_dict(self.model) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) elif self.is_deepspeed_enabled: - # this takes care of everything as long as we aren't under zero3 - if version.parse(accelerate_version) <= version.parse("0.20.3"): - raise ValueError("Install Accelerate from main branch") try: state_dict = self.accelerator.get_state_dict(self.deepspeed) if self.args.should_save: @@ -3247,11 +3204,7 @@ def evaluation_loop( self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. 
- if ( - args.eval_accumulation_steps is not None - and (step + 1) % args.eval_accumulation_steps == 0 - and (self.accelerator.sync_gradients or version.parse(accelerate_version) > version.parse("0.20.3")) - ): + if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: if losses_host is not None: losses = nested_numpify(losses_host) all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) @@ -3877,8 +3830,7 @@ def _add_sm_patterns_to_gitignore(self) -> None: def create_accelerator_and_postprocess(self): grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps} - if version.parse(accelerate_version) > version.parse("0.20.3"): - grad_acc_kwargs["sync_with_dataloader"] = False + grad_acc_kwargs["sync_with_dataloader"] = False gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) # create accelerator object diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index e6f26d0df5196..dbd868d112024 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -727,6 +727,8 @@ class FSDPOption(ExplicitEnum): FULL_SHARD = "full_shard" SHARD_GRAD_OP = "shard_grad_op" NO_SHARD = "no_shard" + HYBRID_SHARD = "hybrid_shard" + HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2" OFFLOAD = "offload" AUTO_WRAP = "auto_wrap" diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b368d86e0ed8e..6146395311087 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -304,6 +304,11 @@ class TrainingArguments: This should not be activated when the different nodes use the same storage as the files will be saved with the same names for each node. + save_only_model (`bool`, *optional*, defaults to `False`): + When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state. + Note that when this is true, you won't be able to resume training from checkpoint. + This enables you to save storage by not storing the optimizer, scheduler & rng state. + You can only load the model using `from_pretrained` with this option set to `True`. use_cpu (`bool`, *optional*, defaults to `False`): Whether or not to use cpu. If set to False, we will use cuda or mps device if available. seed (`int`, *optional*, defaults to 42): @@ -418,12 +423,14 @@ class TrainingArguments: - `"full_shard"`: Shard parameters, gradients and optimizer states. - `"shard_grad_op"`: Shard optimizer states and gradients. + - `"hybrid_shard"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes. + - `"hybrid_shard_zero2"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes. - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and `"shard_grad_op"`). - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`. fsdp_config (`str` or `dict`, *optional*): Config to be used with fsdp (Pytorch Distributed Parallel Training). The value is either a location of - deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`. + fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`. A List of config and its options: - min_num_params (`int`, *optional*, defaults to `0`): @@ -452,7 +459,7 @@ class TrainingArguments: FSDP's limit_all_gathers (useful only when `fsdp` field is passed). 
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. - - use_orig_params (`bool`, *optional*, defaults to `False`) + - use_orig_params (`bool`, *optional*, defaults to `True`) If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. Useful in cases such as parameter-efficient fine-tuning. Please refer this @@ -460,6 +467,10 @@ class TrainingArguments: - sync_module_states (`bool`, *optional*, defaults to `True`) If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization + - activation_checkpointing (`bool`, *optional*, defaults to `False`): + If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of + certain layers and recomputing them during a backward pass. Effectively, this trades extra + computation time for reduced memory usage. - xla (`bool`, *optional*, defaults to `False`): Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature and its API may evolve in the future. @@ -472,10 +483,6 @@ class TrainingArguments: Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be used when the xla flag is set to true, and an auto wrapping policy is specified through fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap. - - activation_checkpointing (`bool`, *optional*, defaults to `False`): - If True, activation checkpointing is a technique to reduce memory usage by clearing activations of - certain layers and recomputing them during a backward pass. Effectively, this trades extra - computation time for reduced memory usage. deepspeed (`str` or `dict`, *optional*): Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may @@ -835,6 +842,17 @@ class TrainingArguments: ) }, ) + save_only_model: bool = field( + default=False, + metadata={ + "help": ( + "When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state." + "Note that when this is true, you won't be able to resume training from checkpoint." + "This enables you to save storage by not storing the optimizer, scheduler & rng state." + "You can only load the model using from_pretrained with this option set to True." + ) + }, + ) no_cuda: bool = field( default=False, metadata={"help": "This argument is deprecated. 
It will be removed in version 5.0 of 🤗 Transformers."}, @@ -1670,7 +1688,7 @@ def __post_init__(self): os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefect", "false") os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true") - os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false") + os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true") if self.tpu_metrics_debug: warnings.warn( diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index c4862b197c97e..beb6c4779573e 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -652,7 +652,7 @@ def is_protobuf_available(): return importlib.util.find_spec("google.protobuf") is not None -def is_accelerate_available(min_version: str = None): +def is_accelerate_available(min_version: str = "0.21.0"): if min_version is not None: return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version) return _accelerate_available
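To make the documentation changes above concrete, here is a minimal sketch of how a training script could combine the renamed `fsdp_config` keys, the new `hybrid_shard` and `save_only_model` options, and the final `FULL_STATE_DICT` save shown in the docs diff. This is not part of the patch: the script layout, `ToyDataset`, the `bert-base-uncased` model, and the specific config values are illustrative assumptions; only the argument names and the save pattern come from the changes above, and the final full-state-dict gather assumes `accelerate > 0.24.1`, matching the version check added to `save_model`.

```python
# Minimal, illustrative sketch (assumptions noted in the text above).
# Launch with a distributed launcher, e.g.:
#   torchrun --nproc_per_node 2 train_fsdp.py
import torch
from torch.utils.data import Dataset

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


class ToyDataset(Dataset):
    """Tiny stand-in dataset so the sketch is self-contained (placeholder data)."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor([101, 2023, 2003, 1037, 3231, 102]),
            "attention_mask": torch.ones(6, dtype=torch.long),
            "labels": torch.tensor(idx % 2),
        }


model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Renamed keys from this patch: no `fsdp_` prefix inside `fsdp_config` anymore.
fsdp_config = {
    "transformer_layer_cls_to_wrap": ["BertLayer"],  # transformer based auto wrap policy
    "backward_prefetch": "backward_pre",             # or "backward_post"
    # `use_orig_params` now defaults to "true", so it no longer needs to be set explicitly.
}

training_args = TrainingArguments(
    output_dir="fsdp_output",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    fsdp="full_shard auto_wrap",   # or "hybrid_shard auto_wrap" across nodes
    fsdp_config=fsdp_config,
    save_only_model=True,          # new flag: checkpoints skip optimizer/scheduler/RNG state
)

trainer = Trainer(model=model, args=training_args, train_dataset=ToyDataset())
trainer.train()

# As documented above: switch to FULL_STATE_DICT before the final save so the result is a
# regular safetensors checkpoint loadable with `from_pretrained`, not an FSDP-sharded one.
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(training_args.output_dir)
```

In this sketch, intermediate checkpoints stay in the default `SHARDED_STATE_DICT` format (avoiding the NCCL timeouts described in the docs), and because `save_only_model=True` they contain only model weights, so training cannot be resumed from them.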