[Feature] Support torch ZeroRedundancyOptimizer #551
Merged
Commits (25, all authored by nijkah):
d466fb8 [Feature] Support torch ZeRORedundancyOptimizer
d957bdc lint
67bcec8 Fix saving optimizer state_dict
da39cd2 Fix handling import error
a408f69 Add test case
dd64538 fix UT
48806af Revert "fix UT"
00ed3b3 fix handling import in UT
dd5986c Fix saving zero checkpoint and delete redundant master_only
764133d lint
f9a6dad test unittest
1c1da2e Fix handling impor error
0275d08 Fix UT condition
831e5c7 Edit docstrings
543f34e Fix typo
715f1a5 Skip redundant procudure in checkpoint hook
4933933 fix typo again
2b7417c Merge remote-tracking branch 'origin/main' into zero_1_optimizer
9d63ea6 Update mmengine/optim/optimizer/zero_optimizer.py
e244e30 Add api info
41dff5a lint
929cf06 Fix lint
d6d0f45 Handling AmpOptimWrapper case
f0ac4cb handling overlap_with_ddp
adee57c Fix error
New file: mmengine/optim/optimizer/zero_optimizer.py (63 lines added)

```python
# Copyright (c) OpenMMLab. All rights reserved.

import torch
from torch.distributed.rpc import is_available

from mmengine.dist import is_main_process
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    from torch.distributed.optim import \
        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    _ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
    """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets an
    optimizer type as a string.

    This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards
    its states across ranks in the group as described by ZeRO_. The local
    optimizer instance in each rank is only responsible for updating
    approximately ``1 / world_size`` parameters and hence only needs to keep
    ``1 / world_size`` optimizer states. After parameters are updated locally,
    each rank will broadcast its parameters to all other peers to keep all
    model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used
    in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to
    reduce per-rank peak memory consumption.

    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a
    number of parameters at each rank. Each parameter belongs to a single
    rank and is not divided among ranks. The partition is arbitrary and might
    not match the parameter registration or usage order.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

    Args:
        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
            or :class:`dict` s giving all parameters, which will be sharded
            across ranks.
        optimizer_type (str): the string of the local optimizer class.

    .. _ZeRO: https://arxiv.org/abs/1910.02054
    """

    def __init__(self, params, optimizer_type: str, **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
            'available when pytorch version >= 1.8.')
        assert is_available(), 'torch.distributed.rpc is not available.'
        optimizer_class = getattr(torch.optim, optimizer_type)
        super().__init__(params, optimizer_class, **kwargs)

    def state_dict(self):
        """Consolidate `state_dict`s from ranks to save the `state_dict`."""
        self.consolidate_state_dict()
        state_dict = super().state_dict() if is_main_process() else dict()
        return state_dict
```

Inline review comment on the `state_dict` line: Due to this line, using ZeroRedundancyOptimizer with AmpOptimWrapper gave an error, so I modified it to return an empty `dict()` on non-main processes.
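To make the failure mode concrete, here is a small sketch. It paraphrases what an AMP-style optimizer wrapper typically does with the inner optimizer's state; the function name and body are illustrative assumptions, not the actual AmpOptimWrapper source from this PR:

```python
def amp_wrapper_state_dict(inner_optimizer_state, loss_scaler_state):
    """Paraphrased AMP-wrapper behaviour: merge the loss scaler state into
    the optimizer state dict before saving."""
    # If the inner optimizer returned None on non-main ranks, this item
    # assignment would raise ``TypeError: 'NoneType' object does not
    # support item assignment``. Returning an empty dict keeps the call
    # valid on every rank; only the main rank carries the consolidated
    # optimizer state.
    inner_optimizer_state['loss_scaler'] = loss_scaler_state
    return inner_optimizer_state
```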
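Because the class is registered in the `OPTIMIZERS` registry and resolves `optimizer_type` through `getattr(torch.optim, ...)`, it can be selected from a config like any other optimizer. A minimal usage sketch, assuming the standard MMEngine `optim_wrapper` config layout (the optimizer name and hyperparameter values are placeholders):

```python
# Hypothetical config snippet: build the sharded optimizer from a config.
# 'AdamW' names a class inside torch.optim; extra kwargs are forwarded to it.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='ZeroRedundancyOptimizer',
        optimizer_type='AdamW',
        lr=1e-3,
        weight_decay=0.01))
```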
Review comment: Can it support custom Optimizer classes?
Reply: I'm still figuring it out. So far, it does not seem to have a specific dependency on torch's optimizers, so it may be possible to use custom Optimizer classes.
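For reference, the `__init__` shown above only looks the class up inside `torch.optim`, so supporting a custom Optimizer class would require extending that lookup. A hypothetical sketch of such an extension (the registry fallback is an assumption, not part of this PR):

```python
import torch

from mmengine.registry import OPTIMIZERS


def resolve_optimizer_class(optimizer_type: str):
    """Hypothetical lookup that also accepts custom optimizer classes."""
    # Current behaviour in this PR: resolve the name inside torch.optim.
    if hasattr(torch.optim, optimizer_type):
        return getattr(torch.optim, optimizer_type)
    # Possible extension: fall back to MMEngine's OPTIMIZERS registry so a
    # custom class registered with @OPTIMIZERS.register_module() also works.
    return OPTIMIZERS.get(optimizer_type)
```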