[Feature] Support torch ZeroRedundancyOptimizer #551

Merged · 25 commits · Oct 27, 2022
9 changes: 6 additions & 3 deletions mmengine/hooks/checkpoint_hook.py
@@ -6,7 +6,7 @@
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union

from mmengine.dist import master_only
from mmengine.dist import is_main_process
from mmengine.fileio import FileClient, get_file_backend
from mmengine.registry import HOOKS
from mmengine.utils import is_list_of, is_seq_of
@@ -309,7 +309,6 @@ def _get_metric_score(self, metrics, key_indicator):

return eval_res[key_indicator]

@master_only
def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.

@@ -331,6 +330,11 @@ def _save_checkpoint(self, runner) -> None:
backend_args=self.backend_args,
**self.args)

# Model-parallel-like training (e.g. sharded optimizer states) requires
# every rank to take part in saving the checkpoint above; only the
# bookkeeping below is restricted to the main process.
if not is_main_process():
return

runner.message_hub.update_info(
'last_ckpt',
self.file_backend.join_path(self.out_dir, ckpt_filename))
@@ -357,7 +361,6 @@ def _save_checkpoint(self, runner) -> None:
with open(save_file, 'w') as f:
f.write(filepath)

@master_only
def _save_best_checkpoint(self, runner, metrics) -> None:
"""Save the current checkpoint and delete outdated checkpoint.

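The ``@master_only`` decorators are dropped above because consolidating sharded optimizer states (as ``ZeroRedundancyOptimizer`` does) is a collective operation: every rank has to reach ``runner.save_checkpoint``, and only afterwards do non-main ranks bail out. A minimal sketch of that pattern, with a simplified bookkeeping step standing in for the real hook code (illustrative only, using plain ``torch.distributed`` instead of MMEngine's ``is_main_process`` helper):

import os

import torch.distributed as dist


def save_checkpoint_all_ranks(runner, out_dir: str, ckpt_filename: str) -> None:
    """Sketch of the save pattern above; not the actual hook implementation."""
    # Every rank must call this: gathering ZeRO's sharded optimizer states
    # inside ``optimizer.state_dict()`` is a collective call, so it would
    # hang if only the main process reached this point.
    runner.save_checkpoint(out_dir, ckpt_filename)

    # Only the main process performs the bookkeeping afterwards.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    with open(os.path.join(out_dir, 'last_checkpoint'), 'w') as f:
        f.write(os.path.join(out_dir, ckpt_filename))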
3 changes: 2 additions & 1 deletion mmengine/optim/optimizer/__init__.py
@@ -5,9 +5,10 @@
from .default_constructor import DefaultOptimWrapperConstructor
from .optimizer_wrapper import OptimWrapper
from .optimizer_wrapper_dict import OptimWrapperDict
from .zero_optimizer import ZeroRedundancyOptimizer

__all__ = [
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
'AmpOptimWrapper', 'OptimWrapperDict'
'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
]
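With the class exported and registered, ``ZeroRedundancyOptimizer`` can be picked from a config like any other optimizer; the only extra key is ``optimizer_type``, which names the local optimizer wrapped on each rank. A minimal config sketch, mirroring the test added below (the hyper-parameter values are placeholders):

# Sketch of an optim_wrapper config selecting the newly registered optimizer.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='ZeroRedundancyOptimizer',  # registered name from this PR
        optimizer_type='SGD',            # local optimizer class on each rank
        lr=0.01,
        momentum=0.9,
        weight_decay=1e-4))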
59 changes: 59 additions & 0 deletions mmengine/optim/optimizer/zero_optimizer.py
@@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
from torch.distributed.rpc import is_available

from mmengine.dist import is_main_process
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
from torch.distributed.optim import \
ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
_ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
"""A wrapper class of :class:`ZeroRedundancyOptimizer` that gets a
optimizer type as string. This class wraps an arbitrary
:class:`optim.Optimizer.

<torch.optim.Optimizer>` and shards its states across ranks in the group as
described by ZeRO_. The local optimizer instance in each rank is only
responsible for updating approximately ``1 / world_size`` parameters and
hence only needs to keep ``1 / world_size`` optimizer states. After
parameters are updated locally, each rank will broadcast its parameters to
all other peers to keep all model replicas in the same state.
``ZeroRedundancyOptimizer`` can be used in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak
memory consumption.
``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
of parameters at each rank. Each parameter belongs to a single rank and is
not divided among ranks. The partition is arbitrary and might not match
the parameter registration or usage order.
Warnings:
``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.
Args:
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
or :class:`dict` s giving all parameters, which will be sharded
across ranks.
optimizer_type (str): the class name of the local optimizer (e.g. ``'SGD'``).
"""

def __init__(self, params, optimizer_type: str, **kwargs):
assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
'`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
'available when PyTorch >= 1.8.')
assert is_available(), 'torch.distributed.rpc is not available.'
optimizer_class = getattr(torch.optim, optimizer_type)
Member:
Can it support custom Optimizer classes?

@nijkah (Contributor Author), Oct 24, 2022:
I'm still figuring it out. So far it does not seem to have a specific dependency on torch's optimizers, so it may be possible to support custom Optimizer classes.

super().__init__(params, optimizer_class, **kwargs)

def state_dict(self):
"""Consolidate `state_dict`s from ranks to save the `state_dict`."""
self.consolidate_state_dict()
if is_main_process():
return super().state_dict()
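The lookup in ``__init__`` only searches ``torch.optim``, which is what the review thread above asks about. A hypothetical way to also accept user-defined optimizers, not part of this PR, would be to fall back to MMEngine's ``OPTIMIZERS`` registry when the name is not found in ``torch.optim``:

import torch

from mmengine.registry import OPTIMIZERS


def resolve_optimizer_class(optimizer_type: str) -> type:
    """Hypothetical lookup covering custom optimizers; not part of this PR."""
    # Prefer the optimizers shipped with PyTorch, matching the current code.
    optimizer_class = getattr(torch.optim, optimizer_type, None)
    if optimizer_class is None:
        # Fall back to the registry so classes registered with
        # ``@OPTIMIZERS.register_module()`` can also be named by string.
        optimizer_class = OPTIMIZERS.get(optimizer_type)
    if optimizer_class is None:
        raise ValueError(f'Unknown optimizer type: {optimizer_type}')
    return optimizer_class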
81 changes: 80 additions & 1 deletion tests/test_optim/test_optimizer/test_optimizer.py
@@ -1,17 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import sys
import unittest
from unittest import TestCase
from unittest.mock import MagicMock

import torch
import torch.nn as nn
from torch.distributed.rpc import is_available

from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimWrapperConstructor, OptimWrapper,
build_optim_wrapper)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
from mmengine.utils.version_utils import digit_version

MMCV_FULL_AVAILABLE = mmcv_full_available()
if not MMCV_FULL_AVAILABLE:
@@ -713,3 +718,77 @@ def test_default_optimizer_constructor_custom_key(self):
for setting in settings:
assert param_groups[i][setting] == settings[
setting], f'{name} {setting}'


@unittest.skipIf(
(digit_version(TORCH_VERSION) < digit_version('1.8.0'))
or not is_available(),
reason='ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.')
class TestZeroOptimizer(MultiProcessTestCase):

def setUp(self) -> None:
super().setUp()
self._spawn_processes()

def _check_default_optimizer(self, optimizer, model):
self.assertIsInstance(optimizer.optim, torch.optim.SGD)
self.assertEqual(optimizer.defaults['lr'], self.base_lr)
self.assertEqual(optimizer.defaults['momentum'], self.momentum)
self.assertEqual(optimizer.defaults['weight_decay'], self.base_wd)
param_groups = optimizer.param_groups[0]
if MMCV_FULL_AVAILABLE:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
]
else:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
]
param_dict = dict(model.named_parameters())
self.assertEqual(len(param_groups['params']), len(param_names))
for i in range(len(param_groups['params'])):
assert torch.equal(param_groups['params'][i],
param_dict[param_names[i]])

def test_build_zero_redundancy_optimizer(self):
from torch.distributed.optim import ZeroRedundancyOptimizer
self._init_dist_env(self.rank, self.world_size)
model = ExampleModel()
self.base_lr = 0.01
self.momentum = 0.0001
self.base_wd = 0.9

# test build function
optim_wrapper_cfg = dict(
optimizer=dict(
type='ZeroRedundancyOptimizer',
optimizer_type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
self._check_default_optimizer(optim_wrapper.optimizer, model)

# test build optimizer without ``optimizer_type``
with self.assertRaises(TypeError):
optim_wrapper_cfg = dict(
optimizer=dict(
type='ZeroRedundancyOptimizer',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)

def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29510'
os.environ['RANK'] = str(rank)
torch.distributed.init_process_group(
backend='gloo', rank=rank, world_size=world_size)
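As context for the docstring's note that ``ZeroRedundancyOptimizer`` is meant to be combined with ``DistributedDataParallel``, here is a minimal plain-PyTorch sketch. It assumes a process group has already been initialized (e.g. by ``torchrun``) and uses CPU tensors with the gloo backend, like the test above, to keep things simple:

import torch
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP


def train_step() -> None:
    # Assumes ``torch.distributed.init_process_group`` has been called.
    model = DDP(nn.Linear(16, 16))
    # Each rank keeps roughly ``1 / world_size`` of the SGD state.
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(), optimizer_class=torch.optim.SGD, lr=0.01)
    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()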