Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support torch ZeroRedundancyOptimizer #551

Merged
merged 25 commits
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmengine/optim/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from .default_constructor import DefaultOptimWrapperConstructor
from .optimizer_wrapper import OptimWrapper
from .optimizer_wrapper_dict import OptimWrapperDict
from .zero_optimizer import ZeroRedundancyOptimizer

# Public API of ``mmengine.optim.optimizer``.  ``ZeroRedundancyOptimizer``
# is exported alongside the existing wrappers/constructors.
__all__ = [
    'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
    'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
    'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
]
55 changes: 55 additions & 0 deletions mmengine/optim/optimizer/zero_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    from torch.distributed.optim import \
        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    # ``torch.distributed.optim`` may not be importable (pytorch < 1.8, or a
    # build without distributed support).  Fall back to ``object`` so this
    # module can still be imported; the version assert in ``__init__`` gives
    # a clear error if the class is actually instantiated.
    _ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
    """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets an
    optimizer type as a string.

    This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards
    its states across ranks in the group as described by ZeRO_. The local
    optimizer instance in each rank is only responsible for updating
    approximately ``1 / world_size`` parameters and hence only needs to keep
    ``1 / world_size`` optimizer states. After parameters are updated
    locally, each rank will broadcast its parameters to all other peers to
    keep all model replicas in the same state. ``ZeroRedundancyOptimizer``
    can be used in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank
    peak memory consumption.

    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a
    number of parameters at each rank. Each parameter belongs to a single
    rank and is not divided among ranks. The partition is arbitrary and
    might not match the parameter registration or usage order.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

    Args:
        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
            or :class:`dict` s giving all parameters, which will be sharded
            across ranks.
        optimizer_type (str): the string of the local optimizer class.

    .. _ZeRO: https://arxiv.org/abs/1910.02054
    """

    def __init__(self, params, optimizer_type: str, **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
            'available when pytorch version >= 1.8')
        # Resolve the local optimizer class by name, e.g. 'SGD' ->
        # ``torch.optim.SGD``.  NOTE(review): only optimizers defined in
        # ``torch.optim`` can be resolved here; custom optimizer classes
        # registered elsewhere are not supported by this lookup.
        optimizer_class = getattr(torch.optim, optimizer_type)
        super().__init__(params, optimizer_class, **kwargs)

    def state_dict(self):
        """Consolidate `state_dict`s from ranks to save the `state_dict`."""
        # Gather the sharded optimizer states onto the (default) target rank
        # before delegating to the parent implementation.
        self.consolidate_state_dict()
        return super().state_dict()
86 changes: 85 additions & 1 deletion tests/test_optim/test_optimizer/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import sys
import unittest
from unittest import TestCase
from unittest.mock import MagicMock

Expand All @@ -11,13 +13,20 @@
build_optim_wrapper)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
from mmengine.utils.version_utils import digit_version

# When the full mmcv build (with compiled ops) is absent, stub out
# ``mmcv.ops`` so model definitions referencing DeformConv2d /
# ModulatedDeformConv2d can still be imported by these tests.
MMCV_FULL_AVAILABLE = mmcv_full_available()
if not MMCV_FULL_AVAILABLE:
    sys.modules['mmcv.ops'] = MagicMock(
        DeformConv2d=dict, ModulatedDeformConv2d=dict)

# ``ZeroRedundancyOptimizer`` may fail to import on some platforms/builds
# (e.g. torch without distributed support); fall back to ``None`` so the
# tests below can skip gracefully at runtime.
try:
    from torch.distributed.optim import ZeroRedundancyOptimizer
except ImportError:
    ZeroRedundancyOptimizer = None


class ExampleModel(nn.Module):

Expand Down Expand Up @@ -713,3 +722,78 @@ def test_default_optimizer_constructor_custom_key(self):
for setting in settings:
assert param_groups[i][setting] == settings[
setting], f'{name} {setting}'


@unittest.skipIf(
    digit_version(TORCH_VERSION) < digit_version('1.8.0'),
    reason='ZeRO needs Pytorch 1.8 or higher')
class TestZeroOptimizer(MultiProcessTestCase):
    """Build ``ZeroRedundancyOptimizer`` inside a spawned multi-process
    (gloo) distributed environment and check its configuration."""

    def setUp(self) -> None:
        # Not redundant with the class-level ``skipIf``: even on
        # torch >= 1.8 the import can fail (e.g. Windows CPU builds where
        # the distributed extras are unavailable), leaving the module-level
        # fallback ``None`` — so skip at runtime as well.
        if ZeroRedundancyOptimizer is None:
            self.skipTest('ZeroRedundancyOptimizer is not available.')
        super().setUp()
        self._spawn_processes()

    def _check_default_optimizer(self, optimizer, model, prefix=''):
        """Assert ``optimizer`` wraps an SGD with the expected
        hyper-parameters and covers every parameter of ``model``.

        ``prefix`` is prepended to each expected parameter name (useful for
        wrapped models).
        """
        assert isinstance(optimizer.optim, torch.optim.SGD)
        assert optimizer.defaults['lr'] == self.base_lr
        assert optimizer.defaults['momentum'] == self.momentum
        assert optimizer.defaults['weight_decay'] == self.base_wd
        param_groups = optimizer.param_groups[0]
        # The expected parameter list depends on whether the full mmcv build
        # (providing the DCN layers of ``ExampleModel``) is installed.
        if MMCV_FULL_AVAILABLE:
            param_names = [
                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias',
                'dcn.weight', 'dcn.conv_offset.weight',
                'dcn.conv_offset.bias'
            ]
        else:
            param_names = [
                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
            ]
        param_dict = dict(model.named_parameters())
        assert len(param_groups['params']) == len(param_names)
        for i in range(len(param_groups['params'])):
            assert torch.equal(param_groups['params'][i],
                               param_dict[prefix + param_names[i]])

    def test_build_zero_redundancy_optimizer(self):
        self._init_dist_env(self.rank, self.world_size)
        model = ExampleModel()
        self.base_lr = 0.01
        self.momentum = 0.0001
        self.base_wd = 0.9

        # test build function
        optim_wrapper_cfg = dict(
            optimizer=dict(
                type='ZeroRedundancyOptimizer',
                optimizer_type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
        self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
        self._check_default_optimizer(optim_wrapper.optimizer, model)

        # building without the required ``optimizer_type`` must fail
        with self.assertRaises(TypeError):
            optim_wrapper_cfg = dict(
                optimizer=dict(
                    type='ZeroRedundancyOptimizer',
                    lr=self.base_lr,
                    weight_decay=self.base_wd,
                    momentum=self.momentum))
            build_optim_wrapper(model, optim_wrapper_cfg)

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29510'
        os.environ['RANK'] = str(rank)
        # gloo works on CPU-only hosts, which keeps this test runnable in CI
        torch.distributed.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)