[Feature] Support torch ZeroRedundancyOptimizer (#551)
* [Feature] Support torch ZeRORedundancyOptimizer
* lint
* Fix saving optimizer state_dict
* Fix handling import error
* Add test case
* fix UT
* Revert "fix UT" (this reverts commit dd64538)
* fix handling import in UT
* Fix saving zero checkpoint and delete redundant master_only
* lint
* test unittest
* Fix handling impor error
* Fix UT condition
* Edit docstrings
* Fix typo
* Skip redundant procudure in checkpoint hook
* fix typo again
* Update mmengine/optim/optimizer/zero_optimizer.py
* Add api info
* lint
* Fix lint
* Handling AmpOptimWrapper case
* handling overlap_with_ddp
* Fix error

Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>
Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
1 parent bf369da · commit 0857f9f
Showing 6 changed files with 156 additions and 5 deletions.
mmengine/optim/optimizer/zero_optimizer.py
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
from torch.distributed.rpc import is_available

from mmengine.dist import is_main_process
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    from torch.distributed.optim import \
        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    _ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
    """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets an
    optimizer type as a string.

    This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards
    its states across ranks in the group as described by ZeRO_. The local
    optimizer instance in each rank is only responsible for updating
    approximately ``1 / world_size`` parameters and hence only needs to keep
    ``1 / world_size`` optimizer states. After parameters are updated
    locally, each rank will broadcast its parameters to all other peers to
    keep all model replicas in the same state. ``ZeroRedundancyOptimizer``
    can be used in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank
    peak memory consumption.

    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a
    number of parameters at each rank. Each parameter belongs to a single
    rank and is not divided among ranks. The partition is arbitrary and
    might not match the parameter registration or usage order.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

    Args:
        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
            or :class:`dict` s giving all parameters, which will be sharded
            across ranks.
        optimizer_type (str): the string of the local optimizer class.

    .. _ZeRO: https://arxiv.org/abs/1910.02054
    """

    def __init__(self, params, optimizer_type: str, **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
            'available when pytorch version >= 1.8.')
        assert is_available(), 'torch.distributed.rpc is not available.'
        optimizer_class = getattr(torch.optim, optimizer_type)
        # TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
        # Currently only `overlap_with_ddp=False` is supported. For more
        # details, please refer to PyTorch's official documentation.
        super().__init__(params, optimizer_class, **kwargs)

    def state_dict(self):
        """Consolidate `state_dict`s from ranks to save the `state_dict`."""
        self.consolidate_state_dict()
        state_dict = super().state_dict() if is_main_process() else dict()
        return state_dict
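The `digit_version` guard in `__init__` compares dotted version strings numerically rather than lexically (so that `'1.13'` correctly sorts after `'1.8'`). A minimal, self-contained sketch of that pattern, as a hypothetical stand-in for `mmengine.utils.digit_version`, not its actual implementation:

```python
def digit_version(version_str):
    """Turn a '1.8.0'-style version string into a comparable tuple of ints.

    Hypothetical simplified version: drops a local-version suffix such as
    '+cu113' and ignores non-numeric components like pre-release tags.
    """
    release = version_str.split('+')[0]
    return tuple(int(p) for p in release.split('.') if p.isdigit())


# Numeric comparison avoids the lexical trap where '1.13' < '1.8' as strings.
print(digit_version('1.13.1') >= digit_version('1.8.0'))  # numeric: True
print('1.13.1' >= '1.8.0')                                # lexical: False
```

Note the simplification: pre-release tags (e.g. `'1.8.0a0'`) are silently dropped here, whereas a production helper would need to rank them explicitly.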
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
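Because the class is registered in `OPTIMIZERS` via `@OPTIMIZERS.register_module()`, it can be selected from a config by name. A hypothetical snippet of what that might look like, assuming mmengine's usual `optim_wrapper` config convention; `AdamW` is just an example local optimizer, and the hyperparameter values are illustrative:

```python
# Hypothetical mmengine-style config: the outer `type` picks the registered
# ZeroRedundancyOptimizer wrapper, `optimizer_type` names the torch.optim
# class resolved via getattr(torch.optim, ...), and the remaining keys
# (lr, weight_decay) are forwarded to that local optimizer via **kwargs.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='ZeroRedundancyOptimizer',
        optimizer_type='AdamW',
        lr=1e-3,
        weight_decay=0.01))
```

At build time each rank would then hold only its shard of the AdamW states, which is where the per-rank memory saving described in the docstring comes from.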