[WIP] continue PR #784 #1221

Merged · 15 commits · Aug 23, 2021
Changes from 10 commits
11 changes: 6 additions & 5 deletions mmcv/runner/__init__.py
@@ -11,10 +11,11 @@
from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
                    DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
                    Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook,
                    LrUpdaterHook, MlflowLoggerHook, NeptuneLoggerHook,
                    OptimizerHook, PaviLoggerHook, SyncBuffersHook,
                    TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
                    Fp16OptimizerHook, GradientCumulativeOptimizerHook, Hook,
                    IterTimerHook, LoggerHook, LrUpdaterHook, MlflowLoggerHook,
                    NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
                    SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
                    WandbLoggerHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
@@ -39,5 +40,5 @@
    'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
    'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
    '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
    'ModuleList'
    'ModuleList', 'GradientCumulativeOptimizerHook'
]
5 changes: 3 additions & 2 deletions mmcv/runner/hooks/__init__.py
@@ -11,7 +11,8 @@
from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook
from .momentum_updater import MomentumUpdaterHook
from .optimizer import Fp16OptimizerHook, OptimizerHook
from .optimizer import (Fp16OptimizerHook, GradientCumulativeOptimizerHook,
                        OptimizerHook)
from .profiler import ProfilerHook
from .sampler_seed import DistSamplerSeedHook
from .sync_buffer import SyncBuffersHook
@@ -23,5 +24,5 @@
    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
    'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
    'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
    'DistEvalHook', 'ProfilerHook'
    'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook'
]
90 changes: 89 additions & 1 deletion mmcv/runner/hooks/optimizer.py
@@ -5,7 +5,7 @@

from torch.nn.utils import clip_grad

from mmcv.utils import TORCH_VERSION, digit_version
from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
from ..dist_utils import allreduce_grads
from ..fp16_utils import LossScaler, wrap_fp16_model
from .hook import HOOKS, Hook
@@ -42,6 +42,94 @@ def after_train_iter(self, runner):
        runner.optimizer.step()


@HOOKS.register_module()
class GradientCumulativeOptimizerHook(OptimizerHook):
"""Optimizer Hook implements multi-iters gradient cumulating.

Args:
grad_clip (dict, optional): Parameters passed to
`torch.nn.utils.clip_grad_norm_`, and if None, disable grad clip.
Defaults to None.
cumulative_iters (int, optional): Num of gradient cumulative iters.
The optimizer will step every `cumulative_iters` iters.
Defaults to 1.

Examples:
>>> # Use cumulative_iters to simulate a large batch size
>>> # It is helpful when the hardware cannot handle a large batch size.
>>> loader = DataLoader(data, batch_size=64)
>>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
>>> # almost equals to
>>> loader = DataLoader(data, batch_size=256)
>>> optim_hook = OptimizerHook()
"""

    def __init__(self, grad_clip=None, cumulative_iters=1):
        super(GradientCumulativeOptimizerHook, self).__init__(grad_clip)

        assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
            f'cumulative_iters only accepts positive int, but got ' \
            f'{type(cumulative_iters)} instead.'

        self.cumulative_iters = cumulative_iters
        self.divisible_iters = 0
        self.remainder_iters = 0
        self.initialized = False

    def has_batch_norm(self, module):
        if isinstance(module, _BatchNorm):
            return True
        for m in module.children():
            if self.has_batch_norm(m):
                return True
        return False

    def _init(self, runner):
        if runner.iter % self.cumulative_iters > 0:
            runner.logger.warning(
                'Resume iter number is not divisible by cumulative_iters in '
                'GradientCumulativeOptimizerHook, which means the gradient of '
                'some iters is lost and the result may be influenced slightly.'
            )

        if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
            runner.logger.warning(
                'GradientCumulativeOptimizerHook may slightly decrease '
                'performance if the model has BatchNorm layers.')

        residual_iters = runner.max_iters - runner.iter

        self.divisible_iters = (
            residual_iters // self.cumulative_iters * self.cumulative_iters)
        self.remainder_iters = residual_iters - self.divisible_iters

        self.initialized = True

    def after_train_iter(self, runner):
        if not self.initialized:
            self._init(runner)

        if runner.iter < self.divisible_iters:
            loss_factor = self.cumulative_iters
        else:
            loss_factor = self.remainder_iters
        loss = runner.outputs['loss']
        loss = loss / loss_factor
        loss.backward()

        if (self.every_n_iters(runner, self.cumulative_iters)
                or self.is_last_iter(runner)):

            if self.grad_clip is not None:
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])
            runner.optimizer.step()
            runner.optimizer.zero_grad()


if (TORCH_VERSION != 'parrots'
        and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):

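For reference, the accumulation pattern the hook automates can be written by hand in a few lines of PyTorch: scale each loss by the number of accumulated iterations, then step and zero the gradients every `cumulative_iters` iterations. The sketch below is illustrative only; `model`, `loader`, `optimizer`, and `criterion` are placeholders, not part of this PR.

def train_with_accumulation(model, loader, optimizer, criterion,
                            cumulative_iters=4):
    # Accumulate gradients over `cumulative_iters` mini-batches before each
    # optimizer step; dividing the loss keeps the update comparable to one
    # step on a batch that is `cumulative_iters` times larger.
    optimizer.zero_grad()
    for i, (x, y) in enumerate(loader):
        loss = criterion(model(x), y)
        (loss / cumulative_iters).backward()
        if (i + 1) % cumulative_iters == 0:
            optimizer.step()
            optimizer.zero_grad()

Unlike this bare loop, the hook also scales the trailing iterations by `remainder_iters` and steps on the last iteration, so no gradients are silently dropped at the end of training.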
84 changes: 83 additions & 1 deletion tests/test_runner/test_hooks.py
@@ -20,7 +20,8 @@
from torch.utils.data import DataLoader

from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
                         IterTimerHook, MlflowLoggerHook, NeptuneLoggerHook,
                         GradientCumulativeOptimizerHook, IterTimerHook,
                         MlflowLoggerHook, NeptuneLoggerHook, OptimizerHook,
                         PaviLoggerHook, WandbLoggerHook, build_runner)
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
@@ -1229,3 +1230,84 @@ def after_epoch():
    # stages output have order, so here is list instead of set.
    expected_stages = ['before_run', 'after_train_epoch', 'after_val_epoch']
    assert hook.get_triggered_stages() == expected_stages


def test_gradient_cumulative_optimizer_hook():

    class ToyModel(nn.Module):

        def __init__(self, with_norm=False):
            super().__init__()
            self.fp16_enabled = False
            self.fc = nn.Linear(3, 2)
            nn.init.constant_(self.fc.weight, 1.)
            nn.init.constant_(self.fc.bias, 1.)
            self.with_norm = with_norm
            if with_norm:
                self.norm = nn.BatchNorm1d(2)

        def forward(self, x):
            x = self.fc(x)
            if self.with_norm:
                x = self.norm(x)
            return x

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x).mean(), num_samples=x.shape[0])

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x).mean(), num_samples=x.shape[0])

    def build_toy_runner(config=dict(type='EpochBasedRunner', max_epochs=3)):
        model = ToyModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.02)
        tmp_dir = tempfile.mkdtemp()

        runner = build_runner(
            config,
            default_args=dict(
                model=model,
                work_dir=tmp_dir,
                optimizer=optimizer,
                logger=logging.getLogger(),
                meta=dict()))
        return runner

    with pytest.raises(AssertionError):
        # cumulative_iters only accepts int
        GradientCumulativeOptimizerHook(cumulative_iters='str')

    with pytest.raises(AssertionError):
        # cumulative_iters only accepts positive number
        GradientCumulativeOptimizerHook(cumulative_iters=-1)

    data = torch.rand((6, 3))
    # optimize with cumulative_iters
    loader_1 = DataLoader(data, batch_size=1)
    runner_1 = build_toy_runner()
    optimizer_hook = GradientCumulativeOptimizerHook(
        grad_clip=dict(max_norm=0.2), cumulative_iters=3)
    runner_1.register_hook(optimizer_hook)
    runner_1.run([loader_1], [('train', 1)])

    # optimize without cumulative_iters
    loader_2 = DataLoader(data, batch_size=3)
    runner_2 = build_toy_runner()
    optimizer_hook = OptimizerHook(grad_clip=dict(max_norm=0.2))
    runner_2.register_hook(optimizer_hook)
    runner_2.run([loader_2], [('train', 1)])

    # test optimizer works well
    assert (runner_1.model.fc.weight < 1).all()
    assert (runner_1.model.fc.bias < 1).all()
    # test optimizer with cumulative_iters gets the same results
    assert torch.allclose(runner_1.model.fc.weight, runner_2.model.fc.weight)
    assert torch.allclose(runner_1.model.fc.bias, runner_2.model.fc.bias)
    shutil.rmtree(runner_1.work_dir)
    shutil.rmtree(runner_2.work_dir)

    # test has_batch_norm
    model = ToyModel(with_norm=True)
    optimizer_hook = GradientCumulativeOptimizerHook(
        grad_clip=dict(max_norm=0.2), cumulative_iters=3)
    assert optimizer_hook.has_batch_norm(model)
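
The test registers the hook directly with `runner.register_hook`. Because the class is decorated with `@HOOKS.register_module()`, it can also be built from a config dict through the registry, which is how downstream configs typically select an optimizer hook. A minimal sketch, with arbitrary example values:

from mmcv.runner import HOOKS
from mmcv.utils import build_from_cfg

# Build the hook from a config dict via the registry.
optimizer_config = dict(
    type='GradientCumulativeOptimizerHook',
    cumulative_iters=4,
    grad_clip=dict(max_norm=0.2))
optimizer_hook = build_from_cfg(optimizer_config, HOOKS)
# The hook is then attached with runner.register_hook(optimizer_hook),
# exactly as in the test above.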