[Feature] Support torch ZeroRedundancyOptimizer (#551)
* [Feature] Support torch ZeRORedundancyOptimizer

Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>

* lint

* Fix saving optimizer state_dict

* Fix handling import error

* Add test case

* fix UT

* Revert "fix UT"

This reverts commit dd64538.

* fix handling import in UT

* Fix saving zero checkpoint and delete redundant master_only

* lint

* test unittest

* Fix handling import error

* Fix UT condition

* Edit docstrings

* Fix typo

* Skip redundant procedure in checkpoint hook

* fix typo again

* Update mmengine/optim/optimizer/zero_optimizer.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Add api info

* lint

* Fix lint

* Handling AmpOptimWrapper case

* handling overlap_with_ddp

* Fix error

Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>
Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
3 people committed Oct 27, 2022
1 parent bf369da commit 0857f9f
Showing 6 changed files with 156 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/en/api/optim.rst
@@ -23,6 +23,7 @@ Optimizer
OptimWrapper
OptimWrapperDict
DefaultOptimWrapperConstructor
ZeroRedundancyOptimizer

.. autosummary::
:toctree: generated
1 change: 1 addition & 0 deletions docs/zh_cn/api/optim.rst
@@ -23,6 +23,7 @@ Optimizer
OptimWrapper
OptimWrapperDict
DefaultOptimWrapperConstructor
ZeroRedundancyOptimizer

.. autosummary::
:toctree: generated
9 changes: 6 additions & 3 deletions mmengine/hooks/checkpoint_hook.py
@@ -6,7 +6,7 @@
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union

from mmengine.dist import master_only
from mmengine.dist import is_main_process
from mmengine.fileio import FileClient, get_file_backend
from mmengine.registry import HOOKS
from mmengine.utils import is_list_of, is_seq_of
@@ -309,7 +309,6 @@ def _get_metric_score(self, metrics, key_indicator):

return eval_res[key_indicator]

@master_only
def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.
@@ -331,6 +330,11 @@ def _save_checkpoint(self, runner) -> None:
backend_args=self.backend_args,
**self.args)

# Sharded training (e.g. ZeRO) requires every rank to take part in
# pulling its optimizer states, so the save above runs on all ranks;
# only the main process carries out the bookkeeping below.
if not is_main_process():
return

runner.message_hub.update_info(
'last_ckpt',
self.file_backend.join_path(self.out_dir, ckpt_filename))
@@ -357,7 +361,6 @@ def _save_checkpoint(self, runner) -> None:
with open(save_file, 'w') as f:
f.write(filepath)

@master_only
def _save_best_checkpoint(self, runner, metrics) -> None:
"""Save the current checkpoint and delete outdated checkpoint.
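The change above works because consolidating ZeRO optimizer states is a collective operation: every rank must enter `_save_checkpoint` so the sharded states can be gathered, and only afterwards does the main process continue with the bookkeeping. A rough sketch of the intended call pattern (hypothetical, simplified names and arguments, not the actual hook code):

# Hypothetical sketch of the rank interplay in _save_checkpoint.
def _save_checkpoint_sketch(runner):
    # Runs on every rank: saving triggers optimizer.state_dict(), which for
    # ZeroRedundancyOptimizer calls the collective consolidate_state_dict().
    runner.save_checkpoint(out_dir, ckpt_filename)
    if not is_main_process():
        return  # non-main ranks have contributed their shards; skip bookkeeping
    # Main process only: record the path of the latest checkpoint, etc.
    runner.message_hub.update_info('last_ckpt', ckpt_path)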
3 changes: 2 additions & 1 deletion mmengine/optim/optimizer/__init__.py
@@ -5,9 +5,10 @@
from .default_constructor import DefaultOptimWrapperConstructor
from .optimizer_wrapper import OptimWrapper
from .optimizer_wrapper_dict import OptimWrapperDict
from .zero_optimizer import ZeroRedundancyOptimizer

__all__ = [
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
'AmpOptimWrapper', 'OptimWrapperDict'
'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
]
66 changes: 66 additions & 0 deletions mmengine/optim/optimizer/zero_optimizer.py
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
from torch.distributed.rpc import is_available

from mmengine.dist import is_main_process
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
from torch.distributed.optim import \
ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
_ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
"""A wrapper class of :class:`ZeroRedundancyOptimizer` that gets a
optimizer type as string.
This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards its
states across ranks in the group as described by ZeRO_. The local optimizer
instance in each rank is only responsible for updating approximately
``1 / world_size`` parameters and hence only needs to keep
``1 / world_size`` optimizer states. After parameters are updated locally,
each rank will broadcast its parameters to all other peers to keep all
model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used
in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to
reduce per-rank peak memory consumption.
``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
of parameters at each rank. Each parameter belongs to a single rank and is
not divided among ranks. The partition is arbitrary and might not match
the parameter registration or usage order.
Warnings:
``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.
Args:
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
or :class:`dict` s giving all parameters, which will be sharded
across ranks.
optimizer_type (str): the string of the local optimizer class.

.. _ZeRO: https://arxiv.org/abs/1910.02054
"""

def __init__(self, params, optimizer_type: str, **kwargs):
assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
'`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
'available when pytorch version >= 1.8.')
assert is_available(), 'torch.distributed.rpc is not available.'
optimizer_class = getattr(torch.optim, optimizer_type)
# TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
# Currently only `overlap_with_ddp=False` is supported. For more
# details, please refer to the pytorch's official documentation.
super().__init__(params, optimizer_class, **kwargs)

def state_dict(self):
"""Consolidate `state_dict`s from ranks to save the `state_dict`."""
self.consolidate_state_dict()
state_dict = super().state_dict() if is_main_process() else dict()
return state_dict
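
For context, a minimal usage sketch of the newly registered optimizer through MMEngine's optim wrapper config, mirroring the unit test below; `model` is assumed to be an `nn.Module` in an already initialized distributed process group, and the hyperparameter values are illustrative rather than taken from this commit:

# Minimal sketch (illustrative values); requires torch.distributed to be
# initialized, since ZeRO shards optimizer states across the process group.
from mmengine.optim import build_optim_wrapper

optim_wrapper_cfg = dict(
    optimizer=dict(
        type='ZeroRedundancyOptimizer',
        optimizer_type='SGD',   # resolved via getattr(torch.optim, 'SGD')
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)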
81 changes: 80 additions & 1 deletion tests/test_optim/test_optimizer/test_optimizer.py
@@ -1,17 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import sys
import unittest
from unittest import TestCase
from unittest.mock import MagicMock

import torch
import torch.nn as nn
from torch.distributed.rpc import is_available

from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimWrapperConstructor, OptimWrapper,
build_optim_wrapper)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
from mmengine.utils.version_utils import digit_version

MMCV_FULL_AVAILABLE = mmcv_full_available()
if not MMCV_FULL_AVAILABLE:
@@ -713,3 +718,77 @@ def test_default_optimizer_constructor_custom_key(self):
for setting in settings:
assert param_groups[i][setting] == settings[
setting], f'{name} {setting}'


@unittest.skipIf(
(digit_version(TORCH_VERSION) < digit_version('1.8.0'))
or not is_available(),
reason='ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.')
class TestZeroOptimizer(MultiProcessTestCase):

def setUp(self) -> None:
super().setUp()
self._spawn_processes()

def _check_default_optimizer(self, optimizer, model):
self.assertIsInstance(optimizer.optim, torch.optim.SGD)
self.assertEqual(optimizer.defaults['lr'], self.base_lr)
self.assertEqual(optimizer.defaults['momentum'], self.momentum)
self.assertEqual(optimizer.defaults['weight_decay'], self.base_wd)
param_groups = optimizer.param_groups[0]
if MMCV_FULL_AVAILABLE:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
]
else:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
]
param_dict = dict(model.named_parameters())
self.assertEqual(len(param_groups['params']), len(param_names))
for i in range(len(param_groups['params'])):
assert torch.equal(param_groups['params'][i],
param_dict[param_names[i]])

def test_build_zero_redundancy_optimizer(self):
from torch.distributed.optim import ZeroRedundancyOptimizer
self._init_dist_env(self.rank, self.world_size)
model = ExampleModel()
self.base_lr = 0.01
self.momentum = 0.0001
self.base_wd = 0.9

# test build function
optim_wrapper_cfg = dict(
optimizer=dict(
type='ZeroRedundancyOptimizer',
optimizer_type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
self._check_default_optimizer(optim_wrapper.optimizer, model)

# test build optimizer without ``optimizer_type``
with self.assertRaises(TypeError):
optim_wrapper_cfg = dict(
optimizer=dict(
type='ZeroRedundancyOptimizer',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)

def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29510'
os.environ['RANK'] = str(rank)
torch.distributed.init_process_group(
backend='gloo', rank=rank, world_size=world_size)
