Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gradient cumulative optimizer #784

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 32 additions & 0 deletions mmcv/runner/hooks/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,38 @@ def after_train_iter(self, runner):
runner.optimizer.step()


@HOOKS.register_module()
class GradientCumulativeOptimizerHook(OptimizerHook):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docs are missing.

def __init__(self, grad_clip=None, cumulative_iters=1):
super(GradientCumulativeOptimizerHook, self).__init__(grad_clip)
self.cumulative_iters = cumulative_iters
self.divisible_ietrs = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.divisible_iters

self.remainder_iters = 0
self.initialized = False

def _init(self, runner):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put a warning for the usage of BN.

self.divisible_ietrs = runner.max_iters // self.cumulative_iters * self.cumulative_iters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is another corner case where users resume from iter=2 but the cumulative_iters=4. It seems this implementation will bring wrong gradients. If no good solutions, please put a warning here.

self.remainder_iters = runner.max_iters % self.cumulative_iters
self.initialized = True

def after_train_iter(self, runner):
if not self.initialized:
self._init(runner)
loss_factor = self.cumulative_iters if runner.iter < self.divisible_ietrs else self.remainder_iters
runner.outputs['loss'] = runner.outputs['loss'] / loss_factor
runner.outputs['loss'].backward()
if (runner.iter + 1) % self.cumulative_iters == 0 or runner.iter == runner.max_iters:
runner.optimizer.step()
if self.grad_clip is not None:
grad_norm = self.clip_grads(runner.model.parameters())
if grad_norm is not None:
# Add grad norm to the logger
runner.log_buffer.update({'grad_norm': float(grad_norm)},
runner.outputs['num_samples'])
runner.optimizer.zero_grad()


@HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook):
"""FP16 optimizer hook.
Expand Down