[Fix] ZeroRedundancyOptimizer ambiguous error with param groups when pytorch < 1.12.0 #818

Merged 7 commits on Dec 19, 2022
Changes from 6 commits
13 changes: 13 additions & 0 deletions mmengine/optim/optimizer/zero_optimizer.py
@@ -39,6 +39,10 @@ class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
Warnings:
``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

Warnings:
``ZeroRedundancyOptimizer`` requires PyTorch >= 1.12 to enable param
groups.

Args:
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
or :class:`dict` s giving all parameters, which will be sharded
@@ -53,6 +57,15 @@ def __init__(self, params, optimizer_type: str, **kwargs):
'`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
'available when pytorch version >= 1.8.')
assert is_available(), 'torch.distributed.rpc is not available.'
# Convert ``params`` to a list so a generator is not exhausted by the check below
params = list(params)
assert (
all(isinstance(p, torch.Tensor) for p in params)
or digit_version(TORCH_VERSION) >= digit_version('1.12.0')), (
'PyTorch ZeroRedundancyOptimizer started to support param '
'groups since 1.12.0. Please update your PyTorch version to '
'enable this feature, or disable param groups by deleting the '
'`paramwise_cfg` field in the config file.')
optimizer_class = getattr(torch.optim, optimizer_type)
# TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
# Currently only `overlap_with_ddp=False` is supported. For more
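For context, the guard added in this hunk reduces to the standalone check sketched below. This is a minimal sketch, assuming the third-party `packaging` library as a stand-in for MMEngine's `digit_version`/`TORCH_VERSION` helpers; the `supports_param_groups` name is invented here for illustration and does not exist in the codebase.

import torch
from packaging.version import parse as parse_version  # stand-in for mmengine's digit_version


def supports_param_groups(params) -> bool:
    """Return True if ``ZeroRedundancyOptimizer`` can accept ``params``.

    Plain tensors are always accepted; dict-style param groups (the kind
    produced by ``paramwise_cfg``) require PyTorch >= 1.12.0.
    """
    # Materialize the iterable first so a generator is not exhausted
    # by the ``all(...)`` scan below.
    params = list(params)
    only_tensors = all(isinstance(p, torch.Tensor) for p in params)
    return only_tensors or parse_version(torch.__version__) >= parse_version('1.12.0')

When this check is false, the constructor above raises an `AssertionError` that points the user at the `paramwise_cfg` field, instead of letting PyTorch fail later with an ambiguous error.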
63 changes: 42 additions & 21 deletions tests/test_optim/test_optimizer/test_optimizer.py
@@ -735,28 +735,23 @@ def _check_default_optimizer(self, optimizer, model):
self.assertEqual(optimizer.defaults['lr'], self.base_lr)
self.assertEqual(optimizer.defaults['momentum'], self.momentum)
self.assertEqual(optimizer.defaults['weight_decay'], self.base_wd)
param_groups = optimizer.param_groups[0]
if MMCV_FULL_AVAILABLE:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
]
param_groups = optimizer.param_groups
params_set = set(model.parameters())
self.assertEqual(
sum(len(param_group['params']) for param_group in param_groups),
len(params_set))
self.assertTrue(
all(param in params_set for param_group in param_groups
for param in param_group['params']))
state_dict = optimizer.state_dict()
if torch.distributed.get_rank() == 0:
self.assertEqual(
sum(len(pg['params']) for pg in state_dict['param_groups']),
len(params_set))
else:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
]
param_dict = dict(model.named_parameters())
self.assertEqual(len(param_groups['params']), len(param_names))
for i in range(len(param_groups['params'])):
assert torch.equal(param_groups['params'][i],
param_dict[param_names[i]])
self.assertEqual(state_dict, {})

def test_build_zero_redundancy_optimizer(self):
from torch.distributed.optim import ZeroRedundancyOptimizer
def test_zero_redundancy_optimizer(self):
self._init_dist_env(self.rank, self.world_size)
model = ExampleModel()
self.base_lr = 0.01
@@ -772,7 +767,6 @@ def test_build_zero_redundancy_optimizer(self):
weight_decay=self.base_wd,
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
self._check_default_optimizer(optim_wrapper.optimizer, model)

# test build optimizer without ``optimizer_type``
@@ -785,6 +779,33 @@
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)

@unittest.skipIf(
digit_version(TORCH_VERSION) < digit_version('1.12.0'),
reason='ZeRO started to support param groups since PyTorch 1.12.0')
def test_zero_redundancy_optimizer_with_paramwise_cfg(self):
self._init_dist_env(self.rank, self.world_size)
model = ExampleModel()
self.base_lr = 0.01
self.momentum = 0.0001
self.base_wd = 0.9

# test build function
paramwise_cfg = dict(
custom_keys={
'conv1': dict(lr_mult=0.0, decay_mult=0.0),
'conv2': dict(lr_mult=1.0, decay_mult=2.0)
})
optim_wrapper_cfg = dict(
optimizer=dict(
type='ZeroRedundancyOptimizer',
optimizer_type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum),
paramwise_cfg=paramwise_cfg)
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
self._check_default_optimizer(optim_wrapper.optimizer, model)

def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
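To connect the new test back to a user-facing setup: a config of roughly the following shape is what exercises the stricter check. This is a sketch that mirrors the values used in `test_zero_redundancy_optimizer_with_paramwise_cfg` above; the hyperparameter values are simply the test's, not recommendations.

# ``paramwise_cfg`` makes MMEngine build dict-style param groups, which
# ZeroRedundancyOptimizer only accepts on PyTorch >= 1.12.0. On older
# versions, the new assertion fails with an explicit message instead of
# the ambiguous error previously raised inside ZeroRedundancyOptimizer.
optim_wrapper = dict(
    optimizer=dict(
        type='ZeroRedundancyOptimizer',
        optimizer_type='SGD',  # the inner optimizer class wrapped by ZeRO
        lr=0.01,
        momentum=0.0001,
        weight_decay=0.9),
    paramwise_cfg=dict(
        custom_keys={
            'conv1': dict(lr_mult=0.0, decay_mult=0.0),
            'conv2': dict(lr_mult=1.0, decay_mult=2.0),
        }))

Note also that the revised `_check_default_optimizer` inspects the full `state_dict()` only on rank 0: ZeRO shards optimizer state across ranks, so the other ranks are expected to report an empty state dict in this test.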