ZeRO 3 Offload (#834)
* Squash stage3 v1 (#146)

Co-authored-by: Samyam <samyamr@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>

* Fix correctness bug (#147)

* formatting fix (#150)

* stage3 bugfix (API) update and simplified FP16 Z3 tests (#151)

* fp16 Z3 API update and bugfix

* revert debug change

* ZeRO-3 detach and race condition bugfixes (#149)

* trying out ZeRO-3 race condition fix

* CUDA sync instead of stream

* reduction stream sync

* remove commented code

* Fix optimizer state_dict KeyError (#148)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* fix for smaller SGS sizes, ensures each grad is backed by unique tensors (#152)

* Simplifying the logic for getting averaged gradients (#153)

* skip for now

* Z3 Docs redux (#154)

* removing some TODOs and commented code (#155)

* New Z3 defaults (#156)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* formatting

* megatron external params

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>
6 people committed Mar 8, 2021
1 parent ba33e86 commit 599258f
Showing 41 changed files with 5,747 additions and 321 deletions.
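
As context for the diffs below, this commit enables ZeRO stage 3 with CPU offload through the DeepSpeed config. The following is a minimal sketch of how a user would turn it on; the offload-related keys and the `config_params` keyword are assumptions based on the public ZeRO-3/ZeRO-Offload documentation and may not match the exact defaults introduced in this commit.

import torch
import deepspeed

# Sketch only: "stage": 3 selects ZeRO-3; the offload flag names below are
# assumed from the public docs and may differ in this release.
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 3,                  # partition params, gradients, and optimizer states
        "cpu_offload": True,         # assumed flag: offload optimizer states to CPU
        "cpu_offload_params": True,  # assumed flag: offload parameters to CPU
    },
}

model = torch.nn.Sequential(torch.nn.Linear(1024, 4096),
                            torch.nn.ReLU(),
                            torch.nn.Linear(4096, 1024))
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config_params=ds_config)  # kwarg name assumed

# Standard DeepSpeed training step, per the Getting Started guide;
# run under the deepspeed launcher so distributed init succeeds.
inputs = torch.randn(8, 1024).to(engine.device).half()
loss = engine(inputs).float().pow(2).mean()
engine.backward(loss)
engine.step()
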
.github/workflows/main.yml (2 changes: 1 addition & 1 deletion)
@@ -48,4 +48,4 @@ jobs:
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose tests/unit/
deepspeed/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -16,6 +16,8 @@
from .utils import log_dist
from .utils.distributed import init_distributed

from .runtime import zero

from .pipe import PipelineModule

from .git_version_info import version, git_hash, git_branch
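
The new top-level import exposes the ZeRO module as `deepspeed.zero`, where the ZeRO-3 construction-time API lives. Below is a hedged sketch of the documented usage, assuming the default arguments of `deepspeed.zero.Init`: parameters created inside the context are partitioned across the data-parallel ranks as they are allocated, so a model larger than a single device's memory can still be constructed.

import torch
import deepspeed

# Sketch only (default Init arguments assumed; see the ZeRO-3 docs for the
# full signature). Run with distributed initialized, e.g. under the deepspeed launcher.
with deepspeed.zero.Init():
    # Each layer's weights are partitioned across data-parallel ranks at
    # construction time instead of being fully materialized on every device.
    model = torch.nn.Sequential(
        torch.nn.Linear(8192, 8192),
        torch.nn.ReLU(),
        torch.nn.Linear(8192, 8192),
    )
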
deepspeed/launcher/runner.py (2 changes: 1 addition & 1 deletion)
@@ -304,7 +304,7 @@ def main(args=None):
# encode world info as base64 to make it easier to pass via command line
world_info_base64 = encode_world_info(active_resources)

multi_node_exec = len(active_resources) > 1
multi_node_exec = True # len(active_resources) > 1

if multi_node_exec and not shutil.which('pdsh'):
raise RuntimeError("pdsh is not installed, unable to proceed")
deepspeed/ops/adam/cpu_adam.py (96 changes: 60 additions & 36 deletions)
@@ -10,41 +10,6 @@


class DeepSpeedCPUAdam(torch.optim.Optimizer):
"""Fast vectorized implementation of two variations of Adam optimizer on CPU:
- Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
- AdamW: FIXING WEIGHT DECAY REGULARIZATION IN ADAM (https://arxiv.org/abs/1711.05101v1)
DeepSpeed CPU Adam(W) provides between 5x to 7x speedu over torch.optim.adam(W).
In order to apply this optimizer, the model requires to have its master parameter (in FP32)
reside on the CPU memory.
To train on a hetrogeneous system, such as coordinating CPU and GPU, DeepSpeed offers
the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory,
with minimal impact on training througput. DeepSpeedCPUAdam plays an important role to minimize
the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial
(https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology.
For calling step function, there are two options available: (1) update optimizer's states and (2) update
optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second
option can bring 30% higher throughput than the doing the copy separately using option one.
Arguments:
model_params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in DeepSpeed CPUAdam!
adamw_mode: select between Adam and AdamW implementations (default: AdamW)
"""

optimizer_id = 0

def __init__(self,
@@ -57,6 +22,47 @@ def __init__(self,
weight_decay=0,
amsgrad=False,
adamw_mode=True):
"""Fast vectorized implementation of two variations of Adam optimizer on CPU:
* Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
* AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101)
DeepSpeed CPU Adam(W) provides a 5x to 7x speedup over torch.optim.Adam(W).
To use this optimizer, the model's master parameters (in FP32) must reside
in CPU memory.
To train on a heterogeneous system that coordinates CPU and GPU, DeepSpeed offers
the ZeRO-Offload technology, which efficiently offloads the optimizer states into CPU memory
with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role in minimizing
the optimizer's latency overhead on the CPU. Please refer to the ZeRO-Offload tutorial
(https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology.
When calling the step function, two options are available: (1) update the optimizer's states, or (2) update
the optimizer's states and copy the parameters back to the GPU at the same time. We have seen that the second
option can bring 30% higher throughput than doing the copy separately with option one.
.. note::
We recommend using our `config
<https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`_
to allow :meth:`deepspeed.initialize` to build this optimizer
for you.
Arguments:
model_params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in DeepSpeed CPUAdam!
adamw_mode: select between Adam and AdamW implementations (default: AdamW)
"""

default_args = dict(lr=lr,
betas=betas,
@@ -86,6 +92,24 @@ def __setstate__(self, state):

@torch.no_grad()
def step(self, closure=None, fp16_param_groups=None):
"""Update the model parameters.
.. note::
This method will be called internally by ZeRO-Offload. DeepSpeed
users should still use ``engine.step()`` as shown in the
`Getting Started
<https://www.deepspeed.ai/getting-started/#training>`_ guide.
Args:
closure (callable, optional): closure to compute the loss.
Defaults to ``None``.
fp16_param_groups: FP16 GPU parameters to update. Performing the
copy here reduces communication time. Defaults to ``None``.
Returns:
loss: the loss value if ``closure`` is provided; otherwise ``None``.
"""

loss = None
if closure is not None:
with torch.enable_grad():
@@ -100,7 +124,7 @@ def step(self, closure=None, fp16_param_groups=None):
state = self.state[p]
# State initialization
if len(state) == 0:
print(f'group {group_id} param {param_id} = {p.numel()}')
#print(f'group {group_id} param {param_id} = {p.numel()}')
state['step'] = 0
# gradient momentums
state['exp_avg'] = torch.zeros_like(p.data,
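
To make the new docstrings concrete, here is a hedged, self-contained sketch of constructing DeepSpeedCPUAdam directly on CPU-resident FP32 parameters. Under ZeRO-Offload the engine builds this optimizer from the config and drives `step()` itself (including the `fp16_param_groups` fast path), so direct construction like this is only for illustration; the import path is assumed from the file layout shown above.

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam  # assumed public import path

# Master parameters stay on the CPU in FP32, as the docstring requires.
model = torch.nn.Linear(1024, 1024)
optimizer = DeepSpeedCPUAdam(model.parameters(),
                             lr=1e-3,
                             betas=(0.9, 0.999),
                             eps=1e-8,
                             weight_decay=0.01,
                             adamw_mode=True)  # AdamW-style decoupled weight decay

x = torch.randn(32, 1024)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()       # vectorized CPU update; ZeRO-Offload invokes this via engine.step()
optimizer.zero_grad()
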