[Feature] Support torch ZeroRedundancyOptimizer #551

Merged 25 commits on Oct 27, 2022
1 change: 1 addition & 0 deletions docs/en/api/optim.rst
@@ -23,6 +23,7 @@ Optimizer
OptimWrapper
OptimWrapperDict
DefaultOptimWrapperConstructor
ZeroRedundancyOptimizer

.. autosummary::
:toctree: generated
1 change: 1 addition & 0 deletions docs/zh_cn/api/optim.rst
@@ -23,6 +23,7 @@ Optimizer
OptimWrapper
OptimWrapperDict
DefaultOptimWrapperConstructor
ZeroRedundancyOptimizer

.. autosummary::
:toctree: generated
9 changes: 6 additions & 3 deletions mmengine/hooks/checkpoint_hook.py
@@ -6,7 +6,7 @@
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union

from mmengine.dist import master_only
from mmengine.dist import is_main_process
from mmengine.fileio import FileClient, get_file_backend
from mmengine.registry import HOOKS
from mmengine.utils import is_list_of, is_seq_of
@@ -309,7 +309,6 @@ def _get_metric_score(self, metrics, key_indicator):

return eval_res[key_indicator]

@master_only
def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.

@@ -331,6 +330,11 @@ def _save_checkpoint(self, runner) -> None:
backend_args=self.backend_args,
**self.args)

# Model parallel-like training (e.g. sharded optimizer states) needs
# every rank to take part in saving above, but only the main process
# should run the bookkeeping below.
if not is_main_process():
return

runner.message_hub.update_info(
'last_ckpt',
self.file_backend.join_path(self.out_dir, ckpt_filename))
@@ -357,7 +361,6 @@ def _save_checkpoint(self, runner) -> None:
with open(save_file, 'w') as f:
f.write(filepath)

@master_only
def _save_best_checkpoint(self, runner, metrics) -> None:
"""Save the current checkpoint and delete outdated checkpoint.

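Both @master_only decorators are removed above so that every rank enters the save path. A minimal standalone sketch of the idea (a hypothetical helper, not the hook code itself, assuming an initialized process group):

import torch
import torch.distributed as dist


def save_checkpoint_with_sharded_optim(model, optimizer, filepath):
    # With the ZeroRedundancyOptimizer wrapper added in this PR, state_dict()
    # first runs consolidate_state_dict(), a collective call, so every rank
    # must reach this line.
    optim_state = optimizer.state_dict()
    # Only the main process writes the file and does any bookkeeping.
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    torch.save({'state_dict': model.state_dict(), 'optimizer': optim_state},
               filepath)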
3 changes: 2 additions & 1 deletion mmengine/optim/optimizer/__init__.py
@@ -5,9 +5,10 @@
from .default_constructor import DefaultOptimWrapperConstructor
from .optimizer_wrapper import OptimWrapper
from .optimizer_wrapper_dict import OptimWrapperDict
from .zero_optimizer import ZeroRedundancyOptimizer

__all__ = [
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
'AmpOptimWrapper', 'OptimWrapperDict'
'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
]
66 changes: 66 additions & 0 deletions mmengine/optim/optimizer/zero_optimizer.py
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
from torch.distributed.rpc import is_available

from mmengine.dist import is_main_process
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
from torch.distributed.optim import \
ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
_ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
"""A wrapper class of :class:`ZeroRedundancyOptimizer` that gets a
optimizer type as string.

This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards its
states across ranks in the group as described by ZeRO_. The local optimizer
instance in each rank is only responsible for updating approximately
``1 / world_size`` parameters and hence only needs to keep
``1 / world_size`` optimizer states. After parameters are updated locally,
each rank will broadcast its parameters to all other peers to keep all
model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used
in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to
reduce per-rank peak memory consumption.

``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
of parameters at each rank. Each parameter belongs to a single rank and is
not divided among ranks. The partition is arbitrary and might not match
the parameter registration or usage order.

Warnings:
``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

Args:
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
or :class:`dict` s giving all parameters, which will be sharded
across ranks.
optimizer_type (str): the name of the local optimizer class.

.. _ZeRO: https://arxiv.org/abs/1910.02054
"""

def __init__(self, params, optimizer_type: str, **kwargs):
assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
'`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
'available when pytorch version >= 1.8.')
assert is_available(), 'torch.distributed.rpc is not available.'
optimizer_class = getattr(torch.optim, optimizer_type)
Member commented:

Can it support custom Optimizer classes?

@nijkah (Contributor, Author) replied on Oct 24, 2022:

I'm still figuring it out. So far it does not seem to have a specific dependency on torch's optimizers, so it may be possible to support custom Optimizer classes.
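Not part of this PR, just a hypothetical sketch of what custom-optimizer support could look like: resolve the local optimizer through MMEngine's OPTIMIZERS registry and fall back to torch.optim (the registry's get() is assumed to return None for unknown names). Since the builder already registers torch's built-in optimizers (see TORCH_OPTIMIZERS in builder.py), a registry lookup would cover both cases.

import torch

from mmengine.registry import OPTIMIZERS


def resolve_optimizer_class(optimizer_type: str):
    # Assumed behavior: Registry.get() returns None when the name is unknown.
    optimizer_class = OPTIMIZERS.get(optimizer_type)
    if optimizer_class is None:
        optimizer_class = getattr(torch.optim, optimizer_type)
    return optimizer_class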

# TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
# Currently only `overlap_with_ddp=False` is supported. For more
# details, please refer to PyTorch's official documentation.
super().__init__(params, optimizer_class, **kwargs)

def state_dict(self):
"""Consolidate `state_dict`s from ranks to save the `state_dict`."""
self.consolidate_state_dict()
state_dict = super().state_dict() if is_main_process() else dict()
@nijkah (Contributor, Author) commented on Oct 24, 2022:

state_dict['loss_scaler'] = self.loss_scaler.state_dict()

Due to this line, using ZeroRedundancyOptimizer with AmpOptimWrapper gave an error like

TypeError: 'NoneType' object does not support item assignment in <mmengine.hooks.checkpoint_hook.CheckpointHook object at XXXXXX>

So I modified it to return dict() instead of None when it is not the main process.
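A paraphrased sketch of the interplay described above (simplified, not the exact AmpOptimWrapper source): the loss-scaler state is written into the dict returned by the wrapped optimizer, so a None from non-main ranks breaks the item assignment.

def amp_optim_wrapper_state_dict(optimizer, loss_scaler):
    # dict() on non-main ranks after this change; previously it was None.
    state_dict = optimizer.state_dict()
    # Raises "TypeError: 'NoneType' object does not support item assignment"
    # when state_dict is None instead of an empty dict.
    state_dict['loss_scaler'] = loss_scaler.state_dict()
    return state_dict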

return state_dict
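For reference, a minimal usage sketch mirroring the tests below (the hyperparameter values are illustrative; model is any nn.Module, and a distributed process group is assumed to be initialized):

from mmengine.optim import build_optim_wrapper

optim_wrapper_cfg = dict(
    optimizer=dict(
        type='ZeroRedundancyOptimizer',
        optimizer_type='SGD',   # name of the local optimizer class
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)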
81 changes: 80 additions & 1 deletion tests/test_optim/test_optimizer/test_optimizer.py
@@ -1,17 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import sys
import unittest
from unittest import TestCase
from unittest.mock import MagicMock

import torch
import torch.nn as nn
from torch.distributed.rpc import is_available

from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimWrapperConstructor, OptimWrapper,
build_optim_wrapper)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
from mmengine.utils.version_utils import digit_version

MMCV_FULL_AVAILABLE = mmcv_full_available()
if not MMCV_FULL_AVAILABLE:
@@ -713,3 +718,77 @@ def test_default_optimizer_constructor_custom_key(self):
for setting in settings:
assert param_groups[i][setting] == settings[
setting], f'{name} {setting}'


@unittest.skipIf(
(digit_version(TORCH_VERSION) < digit_version('1.8.0'))
or not is_available(),
reason='ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.')
class TestZeroOptimizer(MultiProcessTestCase):

def setUp(self) -> None:
super().setUp()
self._spawn_processes()

def _check_default_optimizer(self, optimizer, model):
self.assertIsInstance(optimizer.optim, torch.optim.SGD)
self.assertEqual(optimizer.defaults['lr'], self.base_lr)
self.assertEqual(optimizer.defaults['momentum'], self.momentum)
self.assertEqual(optimizer.defaults['weight_decay'], self.base_wd)
param_groups = optimizer.param_groups[0]
if MMCV_FULL_AVAILABLE:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
]
else:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
]
param_dict = dict(model.named_parameters())
self.assertEqual(len(param_groups['params']), len(param_names))
for i in range(len(param_groups['params'])):
assert torch.equal(param_groups['params'][i],
param_dict[param_names[i]])

def test_build_zero_redundancy_optimizer(self):
from torch.distributed.optim import ZeroRedundancyOptimizer
self._init_dist_env(self.rank, self.world_size)
model = ExampleModel()
self.base_lr = 0.01
self.momentum = 0.0001
self.base_wd = 0.9

# test build function
optim_wrapper_cfg = dict(
optimizer=dict(
type='ZeroRedundancyOptimizer',
optimizer_type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
self._check_default_optimizer(optim_wrapper.optimizer, model)

# test build optimizer without ``optimizer_type``
with self.assertRaises(TypeError):
optim_wrapper_cfg = dict(
optimizer=dict(
type='ZeroRedundancyOptimizer',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)

def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29510'
os.environ['RANK'] = str(rank)
torch.distributed.init_process_group(
backend='gloo', rank=rank, world_size=world_size)