fsdp refactoring #2177

Merged · 8 commits · Nov 24, 2023
Changes from 6 commits
82 changes: 18 additions & 64 deletions docs/source/usage_guides/fsdp.md
@@ -40,23 +40,30 @@ For instance, here is how you would run the NLP example (from the root of the re

```bash
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch_policy: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: 1
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_transformer_layer_cls_to_wrap: BertLayer
fsdp_use_orig_params: true
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: 'no'
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

@@ -66,7 +73,7 @@ accelerate launch examples/nlp_example.py

Currently, `Accelerate` supports the following config through the CLI:

```bash

`Sharding Strategy`: [1] FULL_SHARD (shards optimizer states, gradients and parameters), [2] SHARD_GRAD_OP (shards optimizer states and gradients), [3] NO_SHARD (DDP), [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node while each node has a full copy), [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node while each node has a full copy)

`Offload Params`: Decides whether to offload parameters and gradients to CPU
@@ -94,12 +101,12 @@ all-gather while executing in the forward pass. Only use with static graphs.

`Use Orig Params`: If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters.
Useful in cases such as parameter-efficient fine-tuning.
Please refer to this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
Please refer to this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019). This also enables having different optimizer parameter groups, and should be `True` when creating the optimizer before preparing/wrapping the model with FSDP (a short sketch follows this list of options).

`CPU RAM Efficient Model loading`: If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable to 🤗 Transformers models. This should be set to False if you experience errors when loading the pretrained 🤗 Transformers model via the `from_pretrained` method. When using this, `Sync Module States` needs to be True, otherwise all processes except the main process would have random, empty weights, leading to unexpected behaviour during training.

`Sync Module States`: If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0
```
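As referenced under `Use Orig Params` above, here is a rough sketch of creating different optimizer parameter groups before `prepare` (assuming an `Accelerator` configured with `fsdp_use_orig_params: true`, and that `model`, `lr`, `train_dataloader` and `accelerator` are already defined; the weight-decay split and values are illustrative):

```py
import torch

# Hypothetical split into decay / no-decay groups, following the usual BERT convention
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)

# With `use_orig_params=True`, the optimizer can be created before FSDP wrapping
# and everything can be prepared in a single call
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
```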


For additional and more nuanced control, you can specify other FSDP parameters via `FullyShardedDataParallelPlugin`.
When creating the `FullyShardedDataParallelPlugin` object, pass it the parameters that weren't part of the accelerate config or that you want to override.
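For example, a minimal sketch (the specific `state_dict_config`/`optim_state_dict_config` values shown here are illustrative, not required):

```py
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Override only the state-dict related settings; everything else comes from the accelerate config
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```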
@@ -156,72 +163,19 @@ When using transformers `save_pretrained`, pass `state_dict=accelerator.get_stat
args.output_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
+ state_dict=accelerator.get_state_dict(model, unwrap=False),
+ state_dict=accelerator.get_state_dict(model),
)
```

### State Dict

`accelerator.get_state_dict` will call the underlying `model.state_dict` implementation. With a model wrapped by FSDP, the default behavior of `state_dict` is to gather all of the state in the rank 0 device. This can cause CUDA out of memory errors if the parameters don't fit on a single GPU.

To avoid this, PyTorch provides a context manager that adjusts the behavior of `state_dict`. To offload some of the state dict onto CPU, you can use the following code:

```py
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig

full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(unwrapped_model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
state = accelerator.get_state_dict(unwrapped_model)
```
`accelerator.get_state_dict` will call the underlying `model.state_dict` implementation within a `FullStateDictConfig(offload_to_cpu=True, rank0_only=True)` context manager, so that the state dict is only gathered on rank 0 and is offloaded to CPU.

You can then pass `state` into the `save_pretrained` method. There are several modes for `StateDictType` and `FullStateDictConfig` that you can use to control the behavior of `state_dict`. For more information, see the [PyTorch documentation](https://pytorch.org/docs/stable/fsdp.html).
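For instance, a minimal sketch of saving the gathered state dict outside of `save_pretrained` (assuming `accelerator` and an FSDP-prepared `model`; the output file name is illustrative):

```py
state_dict = accelerator.get_state_dict(model)  # full state dict, gathered on rank 0 and offloaded to CPU

if accelerator.is_main_process:
    accelerator.save(state_dict, "pytorch_model.bin")
```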

## A few caveats to be aware of

- PyTorch FSDP auto-wraps sub-modules, flattens the parameters and shards the parameters in place.
Due to this, any optimizer created before model wrapping gets broken and occupies more memory.
Hence, it is highly recommended and efficient to prepare the model before creating the optimizer.
In the case of a single model, `Accelerate` will automatically wrap the model and create an optimizer for you, with a warning message.
> FSDP Warning: When using FSDP, it is efficient and recommended to call prepare for the model before creating the optimizer

However, below is the recommended way to prepare the model and optimizer while using FSDP:

```diff
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)
+ model = accelerator.prepare(model)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

- model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
- model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
- )

+ optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
+ optimizer, train_dataloader, eval_dataloader, lr_scheduler
+ )
```
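With the diff above applied, the flow reads as follows (a sketch assuming `accelerator`, `lr`, the dataloaders and `lr_scheduler` are created as in the full example):

```py
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)
# Wrap with FSDP first so the optimizer is created on the flattened, sharded parameters
model = accelerator.prepare(model)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
```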

- In case of a single model, if you have created the optimizer with multiple parameter groups and called prepare with them together,
then the parameter groups will be lost and the following warning is displayed:
> FSDP Warning: When using FSDP, several parameter groups will be conflated into
> a single one due to nested module wrapping and parameter flattening.

This is because parameter groups created before wrapping will have no meaning post wrapping due to parameter flattening of nested FSDP modules into 1D arrays (which can consume many layers).
For instance, below are the named parameters of an FSDP model on GPU 0 (when using 2 GPUs; around 55M (110M/2) params are in 1D arrays, as this rank holds the first shard of the parameters).
Here, if one had applied no weight decay to the [bias, LayerNorm.weight] named parameters of an unwrapped BERT model,
it can't be applied to the FSDP-wrapped model below, as there are no named parameters containing either of those strings and
the parameters of those layers are concatenated with the parameters of various other layers.
```py
{
'_fsdp_wrapped_module.flat_param': torch.Size([494209]),
'_fsdp_wrapped_module._fpw_module.bert.embeddings.word_embeddings._fsdp_wrapped_module.flat_param': torch.Size([11720448]),
'_fsdp_wrapped_module._fpw_module.bert.encoder._fsdp_wrapped_module.flat_param': torch.Size([42527232])
}
```


- In case of multiple models, it is necessary to prepare the models before creating the optimizers, or else an error will be thrown.
Then pass the optimizers to the prepare call in the same order as the corresponding models, else `accelerator.save_state()` and `accelerator.load_state()` will result in wrong/unexpected behaviour.
- In case of multiple models, pass the optimizers to the prepare call in the same order as their corresponding models, otherwise `accelerator.save_state()` and `accelerator.load_state()` will result in wrong/unexpected behaviour (see the sketch after this list).
- This feature is incompatible with `--predict_with_generate` in the `run_translation.py` script of 🤗 `Transformers` library.
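A minimal multi-model sketch (hypothetical `model_a`/`model_b` names; assumes `accelerator` is already created):

```py
import torch

# Prepare each model before creating its optimizer
model_a = accelerator.prepare(model_a)
model_b = accelerator.prepare(model_b)

optimizer_a = torch.optim.AdamW(model_a.parameters(), lr=1e-4)
optimizer_b = torch.optim.AdamW(model_b.parameters(), lr=1e-4)

# Pass the optimizers in the same order as their corresponding models so that
# `accelerator.save_state()` / `accelerator.load_state()` can match them up
optimizer_a, optimizer_b = accelerator.prepare(optimizer_a, optimizer_b)
```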

For more control, users can leverage the `FullyShardedDataParallelPlugin`. After creating an instance of this class, users can pass it to the Accelerator class instantiation.
32 changes: 15 additions & 17 deletions examples/by_feature/fsdp_with_peak_mem_tracking.py
@@ -247,16 +247,19 @@ def collate_fn(examples):
args.model_name_or_path, return_dict=True, low_cpu_mem_usage=True
)

# New Code #
# For FSDP feature, it is highly recommended and efficient to prepare the model before creating optimizer
model = accelerator.prepare(model)
accelerator.print(model)

# Instantiate optimizer
# New Code #
# For FSDP feature, at present it doesn't support multiple parameter groups,
# so we need to create a single parameter group for the whole model
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, weight_decay=2e-4)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.003,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]

optimizer = torch.optim.AdamW(params=optimizer_grouped_parameters, lr=lr, weight_decay=2e-4)

# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
@@ -265,13 +268,8 @@ def collate_fn(examples):
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

# New Code #
# For FSDP feature, prepare everything except the model as we have already prepared the model
# before creating the optimizer
# There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
# prepare method.
optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
optimizer, train_dataloader, eval_dataloader, lr_scheduler
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

overall_step = 0
97 changes: 13 additions & 84 deletions src/accelerate/accelerator.py
@@ -1100,52 +1100,6 @@ def _prepare_one(self, obj, first_pass=False, device_placement=None):
# Return the unprocessed object if previous criteria was not met
return obj

def _prepare_fsdp(self, *args):
result = []
for obj in args:
if isinstance(obj, torch.nn.Module):
model = obj
break
optimizers = []

self._schedulers = []
self._models = []
intermediate_result = []
for obj in args:
if isinstance(obj, torch.optim.Optimizer):
if len(obj.param_groups) > 1:
logger.warning(
"FSDP Warning: When using FSDP, several parameter groups will be conflated into "
"a single one due to nested module wrapping and parameter flattening."
)
try:
optimizer = obj.optimizer.__class__(model.parameters(), **obj.optimizer.defaults)
except TypeError:
if "differentiable" in obj.optimizer.defaults:
# https://github.com/huggingface/accelerate/issues/801
defaults = {k: v for k, v in obj.optimizer.defaults.items() if k != "differentiable"}
optimizer = obj.optimizer.__class__(model.parameters(), **defaults)
else:
raise
obj = self.prepare_optimizer(optimizer)
optimizers.append(obj)
elif isinstance(obj, torch.nn.Module):
self._models.append(obj)
intermediate_result.append(obj)

for obj in intermediate_result:
if isinstance(obj, AcceleratedScheduler):
obj.optimizer = optimizers
for i, opt in enumerate(self._optimizers):
if getattr(obj.scheduler, "optimizer", None) == opt.optimizer:
obj.scheduler.optimizer = optimizers[i]
obj.optimizers = [optimizers[i]]
break
self._schedulers.append(obj)
result.append(obj)
self._optimizers = optimizers
return tuple(result)

def prepare(self, *args, device_placement=None):
"""
Prepare all objects passed in `args` for distributed training and mixed precision, then return them in the same
@@ -1214,35 +1168,6 @@ def prepare(self, *args, device_placement=None):
" Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
)

if self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

model_count = 0
optimizer_present = False
is_type_fsdp = False
for obj in args:
if isinstance(obj, torch.nn.Module):
model_count += 1
# if the model is compiled using PyTorch 2.0,
# check that the wrapped model is FSDP or not;
# else check if it is FSDP or not;
is_type_fsdp = isinstance(obj, FSDP) or (
is_compiled_module(obj) and isinstance(obj._orig_mod, FSDP)
)
if isinstance(obj, torch.optim.Optimizer):
optimizer_present = True
if model_count > 1 and optimizer_present:
raise ValueError(
"For FSDP to work with multiple models (>1), "
"prepare must be called for all the models before optimizers are created. "
"Then pass the optimizers to the prepare call in the same order as corresponding models."
)
elif model_count == 1 and not is_type_fsdp and optimizer_present:
logger.warning(
"FSDP Warning: When using FSDP, "
"it is efficient and recommended to call prepare for the model before creating the optimizer"
)

if self.distributed_type == DistributedType.DEEPSPEED:
model_count = 0
for obj in args:
@@ -1298,14 +1223,6 @@ def prepare(self, *args, device_placement=None):
if isinstance(obj, torch.optim.Optimizer):
obj._switch_parameters(mapping)

if (
self.distributed_type == DistributedType.FSDP
and model_count == 1
and not is_type_fsdp
and optimizer_present
):
result = self._prepare_fsdp(*result)

for item in result:
if any(
item in container
@@ -2753,7 +2670,7 @@ def _inner(folder):
# Save the optimizers taking care of FSDP and DeepSpeed nuances
optimizers = []
if self.distributed_type == DistributedType.FSDP:
for opt in self._optimizers:
for i, opt in enumerate(self._optimizers):
logger.info("Saving FSDP Optimizer")
save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
@@ -3068,6 +2985,18 @@ def get_state_dict(self, model, unwrap=True):
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict())
elif self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)

full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
state_dict = model.state_dict()
else:
if unwrap:
model = self.unwrap_model(model)
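For reference, the FSDP branch added above is roughly equivalent to the following raw PyTorch usage (a sketch; `fsdp_model` is a hypothetical FSDP-wrapped module):

```py
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
    # Only rank 0 receives the full (CPU-offloaded) state dict
    state_dict = fsdp_model.state_dict()
```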
4 changes: 2 additions & 2 deletions src/accelerate/commands/config/cluster.py
@@ -381,9 +381,9 @@ def get_cluster_input():
error_message="Please enter yes or no.",
)
fsdp_config["fsdp_use_orig_params"] = _ask_field(
"Do you want to enable FSDP's `use_orig_params` feature? [yes/NO]: ",
"Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ",
_convert_yes_no_to_bool,
default=False,
default=True,
error_message="Please enter yes or no.",
)
fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field(
2 changes: 1 addition & 1 deletion src/accelerate/commands/launch.py
@@ -519,7 +519,7 @@ def launch_command_parser(subparsers=None):
)
fsdp_args.add_argument(
"--fsdp_use_orig_params",
default="false",
default="true",
type=str,
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres."
" (useful only when `use_fsdp` flag is passed).",
2 changes: 1 addition & 1 deletion src/accelerate/utils/constants.py
@@ -34,7 +34,7 @@
FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"]
FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"]
FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
FSDP_PYTORCH_VERSION = "2.0.1"
FSDP_PYTORCH_VERSION = "2.1.0"
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"]
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]

2 changes: 1 addition & 1 deletion src/accelerate/utils/dataclasses.py
@@ -868,7 +868,7 @@ class FullyShardedDataParallelPlugin:
},
)
use_orig_params: bool = field(
default=False,
default=True,
metadata={
"help": "If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. "
"Useful in cases such as parameter-efficient fine-tuning. "
18 changes: 8 additions & 10 deletions src/accelerate/utils/fsdp_utils.py
@@ -164,16 +164,14 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
):
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
optim_state = None
# below check should work but currently it isn't working (mostly a pytorch issue),
# in the meantime disabling it at the cost of excess memory usage
# if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:
optimizer_name = (
f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
)
input_optimizer_file = os.path.join(input_dir, optimizer_name)
logger.info(f"Loading Optimizer state from {input_optimizer_file}")
optim_state = torch.load(input_optimizer_file)
logger.info(f"Optimizer state loaded from {input_optimizer_file}")
if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:
optimizer_name = (
f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
)
input_optimizer_file = os.path.join(input_dir, optimizer_name)
logger.info(f"Loading Optimizer state from {input_optimizer_file}")
optim_state = torch.load(input_optimizer_file)
logger.info(f"Optimizer state loaded from {input_optimizer_file}")
else:
ckpt_dir = (
os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
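From user code, this loading path is typically exercised through the checkpointing helpers rather than called directly (a sketch; `"ckpt"` is an illustrative directory name and assumes a prepared model/optimizer under an FSDP-configured `Accelerator`):

```py
accelerator.save_state("ckpt")  # saves model and optimizer state according to the fsdp_plugin settings

# ... later, or in a resumed run ...
accelerator.load_state("ckpt")  # restores them, going through load_fsdp_optimizer for FSDP runs
```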