Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ZeRO++ to DeepSpeed usage docs #2166

Merged
merged 2 commits into from
Nov 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions docs/source/usage_guides/deepspeed.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ rendered properly in your Markdown viewer.

# DeepSpeed

[DeepSpeed](https://github.com/microsoft/DeepSpeed) implements everything described in the [ZeRO paper](https://arxiv.org/abs/1910.02054). Currently, it provides full support for:
[DeepSpeed](https://github.com/microsoft/DeepSpeed) implements everything described in the [ZeRO paper](https://arxiv.org/abs/1910.02054). Some of the salient optimizations are:

1. Optimizer state partitioning (ZeRO stage 1)
2. Gradient partitioning (ZeRO stage 2)
3. Parameter partitioning (ZeRO stage 3)
4. Custom mixed precision training handling
5. A range of fast CUDA-extension-based optimizers
6. ZeRO-Offload to CPU and Disk/NVMe
7. Heirarchical partitioning of model parameters (ZeRO++)

ZeRO-Offload has its own dedicated paper: [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840). And NVMe-support is described in the paper [ZeRO-Infinity: Breaking the GPU
Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857).
Expand All @@ -44,7 +45,7 @@ won't be possible on a single GPU.

Training:

1. DeepSpeed ZeRO training supports the full ZeRO stages 1, 2 and 3 as well as CPU/Disk offload of optimizer states, gradients and parameters.
1. 🤗 Accelerate integrates all features of DeepSpeed ZeRO. This includes all the ZeRO stages 1, 2 and 3 as well as ZeRO-Offload, ZeRO-Infinity (which can offload to disk/NVMe) and ZeRO++.
Below is a short description of Data Parallelism using ZeRO - Zero Redundancy Optimizer along with diagram from this [blog post](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
![ZeRO Data Parallelism](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-zero.png)

Expand All @@ -60,6 +61,8 @@ Below is a short description of Data Parallelism using ZeRO - Zero Redundancy Op

e. **Param Offload**: Offloads the model parameters to CPU/Disk building on top of ZERO Stage 3

f. **Heirarchical Paritioning**: Enables efficient multi-node training with data-parallel training across nodes and ZeRO-3 sharding within a node, built on top of ZeRO Stage 3.

<u>Note</u>: With respect to Disk Offload, the disk should be an NVME for decent speed but it technically works on any Disk

Inference:
Expand Down Expand Up @@ -349,6 +352,27 @@ accelerate launch examples/by_feature/deepspeed_with_config_support.py \
--report_to "wandb"\
```

**ZeRO++ Config Example**
You can use the the features of ZeRO++ by using the appropriate config parameters. Note that ZeRO++ is an extension for ZeRO Stage 3. Here is how the config file can be modified, from [DeepSpeed's ZeRO++ tutorial](https://www.deepspeed.ai/tutorials/zeropp/):

```json
{
"zero_optimization": {
"stage": 3,
"reduce_bucket_size": "auto",

"zero_quantized_weights": true,
"zero_hpz_partition_size": 8,
"zero_quantized_gradients": true,

"contiguous_gradients": true,
"overlap_comm": true
}
}
```

For heirarchical partitioning, the partition size `zero_hpz_partition_size` should ideally be set to the number of GPUs per node. (For example, the above config file assumes 8 GPUs per node)

**Important code changes when using DeepSpeed Config File**

1. DeepSpeed Optimizers and Schedulers. For more information on these,
Expand Down Expand Up @@ -683,6 +707,8 @@ Papers:
- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054)
- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)
- [ZeRO++: Extremely Efficient Collective Communication for Giant Model Training](https://arxiv.org/abs/2306.10209)


Finally, please, remember that 🤗 `Accelerate` only integrates DeepSpeed, therefore if you
have any problems or questions with regards to DeepSpeed usage, please, file an issue with [DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/issues).
Expand Down
Loading