
Merge dd64538 into 36af1f0
nijkah committed Sep 27, 2022
2 parents 36af1f0 + dd64538 commit 7e91b88
Showing 3 changed files with 130 additions and 2 deletions.
3 changes: 2 additions & 1 deletion mmengine/optim/optimizer/__init__.py
@@ -5,9 +5,10 @@
from .default_constructor import DefaultOptimWrapperConstructor
from .optimizer_wrapper import OptimWrapper
from .optimizer_wrapper_dict import OptimWrapperDict
from .zero_optimizer import ZeroRedundancyOptimizer

__all__ = [
    'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
    'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
    'AmpOptimWrapper', 'OptimWrapperDict'
    'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
]
55 changes: 55 additions & 0 deletions mmengine/optim/optimizer/zero_optimizer.py
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    from torch.distributed.optim import \
        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    _ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
"""A wrapper class of :class:`ZeroRedundancyOptimizer` that gets a
optimizer type as string. This class wraps an arbitrary
:class:`optim.Optimizer.
<torch.optim.Optimizer>` and shards its states across ranks in the group as
described by ZeRO_. The local optimizer instance in each rank is only
responsible for updating approximately ``1 / world_size`` parameters and
hence only needs to keep ``1 / world_size`` optimizer states. After
parameters are updated locally, each rank will broadcast its parameters to
all other peers to keep all model replicas in the same state.
``ZeroRedundancyOptimizer`` can be used in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak
memory consumption.
``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
of parameters at each rank. Each parameter belongs to a single rank and is
not divided among ranks. The partition is arbitrary and might not match the
the parameter registration or usage order.
Warnings:
``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.
Args:
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
or :class:`dict` s giving all parameters, which will be sharded
across ranks.
optimizer_type (str): the string of the local optimizer class.
"""

    def __init__(self, params, optimizer_type: str, **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
            'available when PyTorch version >= 1.8.')
        optimizer_class = getattr(torch.optim, optimizer_type)
        super().__init__(params, optimizer_class, **kwargs)

    def state_dict(self):
        """Consolidate `state_dict`s from all ranks to save the full
        `state_dict`."""
        self.consolidate_state_dict()
        return super().state_dict()
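
For readers skimming the diff, here is a minimal usage sketch of the class added above. It is illustrative and not part of this commit: the nn.Linear model and the keyword values are placeholders, while the import path and the constructor signature (params, optimizer_type, extra keyword arguments forwarded to the local optimizer) come from the code above. A real run also requires torch.distributed to be initialized.

# Illustrative sketch (not part of this commit): construct the wrapper
# directly. Assumes torch.distributed.init_process_group() has already
# been called, as in the test added below.
import torch.nn as nn

from mmengine.optim.optimizer import ZeroRedundancyOptimizer

model = nn.Linear(8, 4)  # placeholder model for illustration
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_type='SGD',  # resolved to torch.optim.SGD via getattr
    lr=0.01,
    momentum=0.9)  # remaining kwargs are forwarded to the local SGD

The config-driven route, where the registered name 'ZeroRedundancyOptimizer' is passed through build_optim_wrapper, is exercised by the test added below.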
74 changes: 73 additions & 1 deletion tests/test_optim/test_optimizer/test_optimizer.py
@@ -1,23 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import sys
import unittest
from unittest import TestCase
from unittest.mock import MagicMock

import torch
import torch.nn as nn

from mmengine.model import MMDistributedDataParallel
from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                            DefaultOptimWrapperConstructor, OptimWrapper,
                            build_optim_wrapper)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
from mmengine.utils.version_utils import digit_version

MMCV_FULL_AVAILABLE = mmcv_full_available()
if not MMCV_FULL_AVAILABLE:
    sys.modules['mmcv.ops'] = MagicMock(
        DeformConv2d=dict, ModulatedDeformConv2d=dict)

if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
    from mmengine.optim.optimizer import ZeroRedundancyOptimizer  # noqa: F401


class ExampleModel(nn.Module):

@@ -713,3 +721,67 @@ def test_default_optimizer_constructor_custom_key(self):
            for setting in settings:
                assert param_groups[i][setting] == settings[
                    setting], f'{name} {setting}'


@unittest.skipIf(
    digit_version(TORCH_VERSION) < digit_version('1.8.0'),
    reason='ZeRO needs PyTorch 1.8 or higher')
class TestZeroOptimizer(MultiProcessTestCase):

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def _check_default_optimizer(self, optimizer, model, prefix=''):
        assert isinstance(optimizer.optim, torch.optim.SGD)
        assert optimizer.defaults['lr'] == self.base_lr
        assert optimizer.defaults['momentum'] == self.momentum
        assert optimizer.defaults['weight_decay'] == self.base_wd
        param_groups = optimizer.param_groups[0]
        if MMCV_FULL_AVAILABLE:
            param_names = [
                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
                'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
            ]
        else:
            param_names = [
                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
            ]
        param_dict = dict(model.named_parameters())
        assert len(param_groups['params']) == len(param_names)
        for i in range(len(param_groups['params'])):
            assert torch.equal(param_groups['params'][i],
                               param_dict[prefix + param_names[i]])

    def test_build_zero_redundancy_optimizer(self):
        self._init_dist_env(self.rank, self.world_size)
        model = ExampleModel()
        model = MMDistributedDataParallel(module=model)
        self.base_lr = 0.01
        self.momentum = 0.0001
        self.base_wd = 0.9

        # test build function
        optim_wrapper_cfg = dict(
            optimizer=dict(
                type='ZeroRedundancyOptimizer',
                optimizer_type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
        self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
        self._check_default_optimizer(
            optim_wrapper.optimizer, model, prefix='module.')

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29510'
        os.environ['RANK'] = str(rank)
        torch.distributed.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
