
Add gradient cumulative optimizer #784

Closed
wants to merge 4 commits into master

Conversation

ZhiyuanChen
Contributor

fixes #190

@ZhiyuanChen
Contributor Author

Unit tests to be added

@codecov

codecov bot commented Jan 12, 2021

Codecov Report

Merging #784 (72ce7e3) into master (f169fb5) will decrease coverage by 0.07%.
The diff coverage is 81.21%.


@@            Coverage Diff             @@
##           master     #784      +/-   ##
==========================================
- Coverage   62.89%   62.81%   -0.08%     
==========================================
  Files         144      145       +1     
  Lines        8467     8702     +235     
  Branches     1520     1574      +54     
==========================================
+ Hits         5325     5466     +141     
- Misses       2874     2971      +97     
+ Partials      268      265       -3     
Flag Coverage Δ
unittests 62.81% <81.21%> (-0.08%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
mmcv/cnn/alexnet.py 26.08% <0.00%> (-4.35%) ⬇️
mmcv/cnn/resnet.py 12.19% <0.00%> (-0.61%) ⬇️
mmcv/cnn/vgg.py 11.11% <0.00%> (-1.02%) ⬇️
mmcv/onnx/onnx_utils/symbolic_helper.py 0.00% <0.00%> (-19.88%) ⬇️
mmcv/ops/roi_align.py 58.75% <0.00%> (-1.25%) ⬇️
mmcv/visualization/image.py 10.76% <0.00%> (-0.17%) ⬇️
mmcv/ops/nms.py 34.43% <8.33%> (ø)
mmcv/runner/hooks/optimizer.py 22.85% <20.00%> (-0.90%) ⬇️
mmcv/runner/iter_based_runner.py 53.95% <50.00%> (ø)
mmcv/utils/parrots_jit.py 78.94% <66.66%> (+2.47%) ⬆️
... and 20 more

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f169fb5...5087bec. Read the comment docs.

@hellock requested a review from nbei February 4, 2021 09:44
@nbei
Contributor

nbei commented Feb 4, 2021

Please fix the linting error

Contributor

@nbei left a comment


  1. Unit tests are missing.
  2. The detailed implementation should be clarified. Please see my comments.

@@ -34,6 +34,27 @@ def after_train_iter(self, runner):
runner.optimizer.step()


@HOOKS.register_module()
class GradientCumulativeOptimizerHook(OptimizerHook):

Contributor

Docs are missing.
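For illustration, a docstring of the kind being requested might look like the sketch below (the wording is only a suggestion, not the final documentation; the arguments mirror the __init__ shown later in this diff, and HOOKS / OptimizerHook are the mmcv.runner names already used by the PR):

    @HOOKS.register_module()
    class GradientCumulativeOptimizerHook(OptimizerHook):
        """Optimizer hook that accumulates gradients over several iterations.

        Simulates a larger effective batch size when memory only allows a
        small per-iteration batch: the loss is divided by ``cumulative_iters``
        and ``optimizer.step()`` is called once every ``cumulative_iters``
        iterations.

        Args:
            grad_clip (dict, optional): Same as :class:`OptimizerHook`.
            cumulative_iters (int): Number of iterations over which gradients
                are accumulated before an optimizer step. Defaults to 1.
        """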

self.cumulative_iters = cumulative_iters

def after_train_iter(self, runner):
runner.outputs['loss'] = runner.outputs['loss'] / self.cumulative_iters
Contributor

I'm confused about the detailed implementation of this function. Could you please offer a reference for this implementation to show that it is the general approach?

In my opinion, accumulated gradients are adopted to avoid large batch sizes. Is that right? If so, why should we divide by self.cumulative_iters?

Contributor Author

Gradient accumulation is used to achieve an equivalent larger batch size with a small batch size. Therefore, the loss should be normalised. See more at https://discuss.pytorch.org/t/pytorch-gradient-accumulation/55955/2
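For reference, a minimal plain-PyTorch sketch of the idea (illustrative only, not the hook implementation; model, criterion, optimizer and loader are assumed to be defined):

    # Accumulate gradients over cumulative_iters small batches so that the
    # summed gradient matches one batch of size batch_size * cumulative_iters.
    cumulative_iters = 4
    optimizer.zero_grad()
    for i, (data, target) in enumerate(loader):
        loss = criterion(model(data), target)
        (loss / cumulative_iters).backward()   # normalise, then accumulate in .grad
        if (i + 1) % cumulative_iters == 0:
            optimizer.step()                   # apply the accumulated gradient
            optimizer.zero_grad()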

Contributor

Ok, got it. You may specify this in the docs.

In addition, please also fix the corner case where total_iters % cumulative_iters != 0
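One possible way to handle that corner case (a sketch only, not necessarily how the PR resolves it; cur_iter, max_iters and loss come from the runner and are assumed here): the trailing iterations that do not fill a whole group should be divided by the size of that smaller group instead of cumulative_iters.

    # Sketch: choose the loss divisor per group; the last, incomplete group
    # (max_iters % cumulative_iters iterations) uses a smaller divisor.
    divisible_iters = max_iters // cumulative_iters * cumulative_iters
    remainder_iters = max_iters - divisible_iters
    loss_factor = cumulative_iters if cur_iter < divisible_iters else remainder_iters
    loss = loss / loss_factor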

@ggalan87

Hello, isn't BatchNormalization an issue, as mentioned here: https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3#gistcomment-3381285

Contributor Author

@ggalan87 Yes, the behavior of batch norm is different; however, not all networks contain BatchNorm.

@ZhiyuanChen
Contributor Author

As mentioned by @ggalan87, we probably need to raise a warning when runner.model contains nn.BatchNorm.
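A possible shape for that warning (a sketch; _BatchNorm covers BatchNorm1d/2d/3d and SyncBatchNorm, and the exact wording is only illustrative):

    import warnings
    from torch.nn.modules.batchnorm import _BatchNorm

    # Warn because BN running statistics are updated per small batch, so the
    # result is not exactly equivalent to training with one large batch.
    if any(isinstance(m, _BatchNorm) for m in runner.model.modules()):
        warnings.warn('GradientCumulativeOptimizerHook may slightly change '
                      'the behaviour of models that contain BatchNorm layers.')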

@hellock requested a review from nbei February 21, 2021 03:29
@nbei
Contributor

nbei commented Feb 21, 2021

Hi @ZhiyuanChen, many thanks for your contribution. Please fix the linting error first. It seems that you have not adopted the pre-commit hook as requested in CONTRIBUTING.

Contributor

@nbei left a comment


See comments.

@@ -40,11 +40,22 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
def __init__(self, grad_clip=None, cumulative_iters=1):
super(GradientCumulativeOptimizerHook, self).__init__(grad_clip)
self.cumulative_iters = cumulative_iters
self.divisible_ietrs = 0
Contributor

self.divisible_iters

self.initialized = False

def _init(self, runner):
self.divisible_ietrs = runner.max_iters // self.cumulative_iters * self.cumulative_iters
Contributor

There is another corner case where users resume from iter=2 but cumulative_iters=4. It seems this implementation will produce wrong gradients. If there is no good solution, please put a warning here.
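A hedged sketch of such a warning (detecting resumption via runner.iter is an assumption, not part of this diff):

    import warnings

    # If training resumes at an iteration that is not a multiple of
    # cumulative_iters, the first accumulation group after resuming is
    # shorter than expected and its gradient scale will be slightly off.
    if runner.iter % self.cumulative_iters != 0:
        warnings.warn('Resumed iteration is not a multiple of '
                      'cumulative_iters; the gradient of the first '
                      'accumulation group may be incorrect.')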

self.remainder_iters = 0
self.initialized = False

def _init(self, runner):
Contributor

Please put a warning for the usage of BN.

@mzr1996 mentioned this pull request Jul 28, 2021
ZwwWayne pushed a commit that referenced this pull request Aug 23, 2021
* Add gradient cumulative optimizer

fixes #190

* Update optimizer.py

* Update optimizer.py

* fix loss scale improperly in last equivalent_iter

* Add `GradientCumulativeOptimizerHook` in `__init__.py`.

* Add docstring of `GradientCumulativeOptimizerHook`.

* Add type check, BN warning and resume warning. And fix typo, lint the
code.

* Add unit test

* Update docstring example.

* Change GradientCumulativeOptimizerHook `__init__` arguments.

* Add GradientCumulativeOptimizerHook unit tests with IterBasedRunner.

* Add GradientCumulativeFp16OptimizerHook.

* Add unit tests of GradientCumulativeFp16OptimizerHook

* Use '!=' instead of '>' to determine resume

Co-authored-by: Zhiyuan Chen <this@zyc.ai>
@zhouzaida
Member

closed by #1221

@zhouzaida closed this Aug 24, 2021
Development

Successfully merging this pull request may close these issues.

Cumulative gradient? Using small BatchSize to simulate big BatchSize