Merged
24 commits
fcfe263  add 0/1 Adam implementation (EugeneLYC, Feb 22, 2022)
8885646  add pytest for 0/1 Adam and fix loading issue (EugeneLYC, Feb 24, 2022)
4e7e9bb  Merge branch 'master' into 01Adam (conglongli, Feb 25, 2022)
c17741b  fix formatting issue and a type (EugeneLYC, Feb 25, 2022)
cd69b55  Merge branch 'master' into 01Adam (jeffra, Feb 28, 2022)
adb7151  Merge branch 'master' into 01Adam (jeffra, Mar 3, 2022)
6c9bdc1  Merge branch 'master' into 01Adam (conglongli, Mar 5, 2022)
6ddaf2e  Merge branch 'master' into 01Adam (conglongli, Mar 7, 2022)
6c1946f  fix mask/warning issues and add 0/1 Adam docs/tutorial (EugeneLYC, Mar 7, 2022)
99d35a2  fix formatting and add more details to 0/1 Adam tutorial (EugeneLYC, Mar 8, 2022)
db6f5b8  Merge branch 'master' into 01Adam (jeffra, Mar 8, 2022)
75d09d9  Merge branch 'master' into 01Adam (conglongli, Mar 8, 2022)
dfd6cf6  mark new tests as forced sequential (jeffra, Mar 8, 2022)
3750c2c  disable new tests (testing hang issue in CI) (jeffra, Mar 8, 2022)
f90a5d7  Revert "disable new tests (testing hang issue in CI)" (conglongli, Mar 9, 2022)
727a043  fix naive all reduce hanging issue (still need testing) (conglongli, Mar 9, 2022)
419486e  remove the scaling after naive allreduce (conglongli, Mar 9, 2022)
7190a99  mention 0/1 Adam in 1-bit Adam tutorial (conglongli, Mar 9, 2022)
900cdf2  Merge branch 'master' into 01Adam (conglongli, Mar 9, 2022)
a4a86f8  remove pytest.mark.sequential (conglongli, Mar 9, 2022)
ac08717  Merge branch '01Adam' of github.com:EugeneLYC/DeepSpeed-1 into 01Adam (conglongli, Mar 9, 2022)
a1e1615  fix the comm volume saving number with FP16 (EugeneLYC, Mar 10, 2022)
93680e7  Using GPT results (FP16) to update comm volume saving number (EugeneLYC, Mar 10, 2022)
a2f273e  fail fast during tests (jeffra, Mar 10, 2022)
5 changes: 2 additions & 3 deletions .github/workflows/main.yml
@@ -135,9 +135,8 @@ jobs:
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
-#TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose unit/
-TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/
-TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/
+TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -n 4 -m 'not sequential' unit/
+TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -m 'sequential' unit/

nv-lightning-v100:
runs-on: [self-hosted, nvidia, torch18, v100]
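For context, the two pytest invocations split the unit suite by the `sequential` marker (the parallel `-n 4` run excludes it, the serial run selects it), and the `-x` flag added in this PR makes both runs stop at the first failure. A minimal sketch of how a test would opt into the serial lane; the marker name comes from the invocations above, while the test name and body are placeholders:

```python
# Illustrative only: a unit test tagged for the serial CI lane that the
# workflow selects with `pytest -m 'sequential'`. The test name is hypothetical.
import pytest


@pytest.mark.sequential
def test_zero_one_adam_smoke():
    ...  # the real 0/1 Adam tests added by this PR live under tests/unit/
```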
10 changes: 6 additions & 4 deletions README.md
@@ -13,7 +13,8 @@ Remove until pypi issue is resolved: https://status.python.org/incidents/2jj696s
[![Downloads](https://pepy.tech/badge/deepspeed/month)](https://pepy.tech/project/deepspeed)
-->
## Latest News
-* [2022/1/19] [DeepSpeed: Advancing MoE inference and training to power next-generation AI scale](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/)
+* [2022/03/07] [Maximizing Communication Efficiency for Large-scale Training via 0/1 Adam](https://www.deepspeed.ai/tutorials/zero-one-adam/)
+* [2022/01/19] [DeepSpeed: Advancing MoE inference and training to power next-generation AI scale](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/)
* [Mixture of Experts (MoE) for NLG tutorial](https://www.deepspeed.ai/tutorials/mixture-of-experts-nlg/).
* [Mixture of Experts (MoE) Inference tutorial](https://www.deepspeed.ai/tutorials/moe-inference-tutorial).
* [2021/11/15] [Autotuning: Automatically discover the optimal DeepSpeed configuration that delivers good training speed](https://www.deepspeed.ai/news/2021/11/15/autotuning.html)
@@ -36,7 +37,7 @@ DeepSpeed delivers extreme-scale model training for everyone, from data scientis
* Extreme scale: Using current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters.
* Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of arts, democratizing multi-billion-parameter model training such that many deep learning scientists can explore bigger and better models.
* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution comparing with dense transformers.
-* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam/1-bit LAMB reduce communication volume by up to 5x while achieving similar convergence efficiency to Adam/LAMB, allowing for scaling to different types of GPU clusters and networks.
+* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam, 0/1 Adam and 1-bit LAMB reduce communication volume by up to 26x while achieving similar convergence efficiency to Adam/LAMB, allowing for scaling to different types of GPU clusters and networks.

Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
@@ -130,9 +131,9 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* Memory- and compute-efficient sparse kernels
* Support 10x longer sequences than dense
* Flexible support to different sparse structures
-* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html) and [1-bit LAMB](https://www.deepspeed.ai/tutorials/onebit-lamb/)
+* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html), [0/1 Adam](https://www.deepspeed.ai/tutorials/zero-one-adam/) and [1-bit LAMB](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* Custom communication collective
-* Up to 5x communication volume saving
+* Up to 26x communication volume saving
* [Additional Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#additional-memory-and-bandwidth-optimizations)
* Smart Gradient Accumulation
* Communication/Computation Overlap
@@ -209,6 +210,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857).
7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069).
8. Conglong Li, Minjia Zhang, Yuxiong He. (2021) Curriculum Learning: A Regularization Method for Efficient and Stable Billion-Scale GPT Model Pre-Training. [arXiv:2108.06084](https://arxiv.org/abs/2108.06084).
+9. Yucheng Lu, Conglong Li, Minjia Zhang, Christopher De Sa, Yuxiong He. (2022) Maximizing Communication Efficiency for Large-scale Training via 0/1 Adam. [arXiv:2202.06009](https://arxiv.org/abs/2202.06009).

# Videos
1. DeepSpeed KDD 2020 Tutorial
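The 26x figure that replaces the earlier 5x claim comes, per the commit history, from GPT pre-training measurements under FP16. As a rough, non-authoritative illustration of where savings of that magnitude originate:

```python
# Illustrative arithmetic only, not taken from the paper: a plain FP16
# allreduce moves 16 bits per gradient element, while the compressed stage of
# a 1-bit allreduce moves roughly 1 bit per element.
bits_per_element_fp16 = 16
bits_per_element_1bit = 1
print(f"~{bits_per_element_fp16 // bits_per_element_1bit}x from 1-bit compression alone")

# 0/1 Adam can additionally skip entire synchronization rounds via its
# adaptive local steps, which is how the measured end-to-end saving can exceed
# the raw compression ratio and reach the reported 26x on those GPT runs.
```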
2 changes: 2 additions & 0 deletions deepspeed/runtime/config.py
@@ -52,6 +52,7 @@
ADAMW_OPTIMIZER = 'adamw'
LAMB_OPTIMIZER = 'lamb'
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
+ZERO_ONE_ADAM_OPTIMIZER = 'zerooneadam'
ONEBIT_LAMB_OPTIMIZER = 'onebitlamb'
DEEPSPEED_OPTIMIZERS = [
ADAGRAD_OPTIMIZER,
@@ -60,6 +61,7 @@
LAMB_OPTIMIZER,
ONEBIT_ADAM_OPTIMIZER,
ONEBIT_LAMB_OPTIMIZER,
+ZERO_ONE_ADAM_OPTIMIZER
]

# extra optimizer parameters for adam/adamw
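With the constant registered, the optimizer is selected by name in the DeepSpeed config. A minimal sketch of such a config as a Python dict; the 0/1 Adam-specific keys follow the new tutorial and should be treated as illustrative rather than exhaustive:

```python
# Sketch of a DeepSpeed config that selects the newly registered optimizer.
# The parameter names under "params" are taken from the 0/1 Adam tutorial and
# are shown here as an assumption, not an authoritative reference.
ds_config = {
    "train_batch_size": 16,
    "fp16": {"enabled": True},              # convergence is only verified under FP16
    "zero_optimization": {"stage": 0},      # 0/1 Adam is not compatible with ZeRO
    "optimizer": {
        "type": "ZeroOneAdam",              # matched against the 'zerooneadam' constant above
        "params": {
            "lr": 1e-3,
            "weight_decay": 0.01,
            "bias_correction": False,
            "var_freeze_step": 1000,
            "var_update_scaler": 16,
            "local_step_scaler": 1000,
            "local_step_clipper": 16,
            "cuda_aware": False,
            "comm_backend_name": "nccl",
        },
    },
}
```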
12 changes: 10 additions & 2 deletions deepspeed/runtime/engine.py
@@ -34,7 +34,7 @@

from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
-TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT
+TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
@@ -1169,6 +1169,14 @@ def _configure_basic_optimizer(self, model_parameters):
logger.warning(
f"Currently the convergence of 1-bit Adam is only verified under FP16"
)
+elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER:
+assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO"
+from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
+
+optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters)
+if not self.fp16_enabled():
+logger.warning(
+f'Currently the convergence of 0/1 Adam is only verified under FP16')
elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER:
assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO"
from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb
@@ -1228,7 +1236,7 @@ def _configure_fp16_optimizer(self, optimizer):
else:
fused_opts = FusedAdam
if isinstance(optimizer, fused_opts) \
-or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
+or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]:
if self.dynamic_loss_scale():
log_dist("Creating fp16 optimizer with dynamic loss scale", ranks=[0])
timers = self.timers if self.wall_clock_breakdown() else None
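Putting the pieces together, a hedged end-to-end sketch of how the new branch in `_configure_basic_optimizer` gets exercised. It assumes the `ds_config` dict from the previous sketch and a single-GPU FP16 setup; the model is a placeholder:

```python
# Hypothetical usage sketch: selecting "ZeroOneAdam" in the config makes the
# engine construct ZeroOneAdam internally and wrap it with the FP16 optimizer,
# the same path used for 1-bit Adam above. Recent DeepSpeed versions accept a
# dict via `config=`; older ones use `config_params=`.
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```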
5 changes: 2 additions & 3 deletions deepspeed/runtime/fp16/onebit/adam.py
@@ -167,6 +167,8 @@ def step(self, closure=None, grads=None):
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)

+if not self.initialize or (self.adam_freeze_key
+and 'worker_error' not in state.keys()):
state['tensor_size'] = torch.numel(p.data)
state['corrected_tensor_size'] = state['tensor_size']

@@ -176,9 +178,6 @@
(self.size * self.divider)))
state['server_chunk_size'] = state[
'corrected_tensor_size'] // self.size
-
-if not self.initialize or (self.adam_freeze_key
-and 'worker_error' not in state.keys()):
torch.cuda.empty_cache()
state['worker_error'] = torch.zeros(state['corrected_tensor_size'],
device=p.device)
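For reference, the chunk-size bookkeeping that now sits behind the relocated guard pads the flattened parameter so it splits evenly across workers before the compressed allreduce. A standalone sketch of that arithmetic with made-up numbers (in the optimizer, `size` is the data-parallel world size and `divider` depends on the communication backend):

```python
# Standalone sketch of the padding arithmetic shown above; values are illustrative.
tensor_size = 1_000_003        # elements in the flattened parameter
size = 8                       # data-parallel world size
divider = 4                    # backend-dependent alignment factor

corrected_tensor_size = tensor_size
if tensor_size % (size * divider) != 0:
    # pad up to the next multiple of size * divider
    corrected_tensor_size += (size * divider) - (tensor_size % (size * divider))
server_chunk_size = corrected_tensor_size // size

print(corrected_tensor_size, server_chunk_size)  # 1000032 and 125004
```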