From 7ea95c36ae752f5ff662cd893cf1e80bb387fad8 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Tue, 12 Jan 2021 16:24:51 +0800
Subject: [PATCH 1/4] Add gradient cumulative optimizer

fixes #190
---
 mmcv/runner/hooks/optimizer.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py
index 4f27844a4e..d92b3a79e4 100644
--- a/mmcv/runner/hooks/optimizer.py
+++ b/mmcv/runner/hooks/optimizer.py
@@ -34,6 +34,30 @@ def after_train_iter(self, runner):
         runner.optimizer.step()
 
 
+@HOOKS.register_module()
+class GradientCumulativeOptimizerHook(OptimizerHook):
+
+    def __init__(self, grad_clip=None, cumulative_iters=1):
+        super(GradientCumulativeOptimizerHook, self).__init__(grad_clip)
+        self.cumulative_iters = cumulative_iters
+        self.steps = 0
+
+    def after_train_iter(self, runner):
+        self.steps += 1
+        if (runner.iter + 1) % cumulative_iters == 0 or runner.iter == max_iters:
+            runner.optimizer.zero_grad()
+            runner.outputs['loss'] = runner.outputs['loss'] / self.steps
+            runner.outputs['loss'].backward()
+            if self.grad_clip is not None:
+                grad_norm = self.clip_grads(runner.model.parameters())
+                if grad_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                             runner.outputs['num_samples'])
+            runner.optimizer.step()
+            self.steps = 0
+
+
 @HOOKS.register_module()
 class Fp16OptimizerHook(OptimizerHook):
     """FP16 optimizer hook.

From bfa748fda7a255a9cc68ac193aed835c8f79fac8 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Tue, 12 Jan 2021 16:32:29 +0800
Subject: [PATCH 2/4] Update optimizer.py

---
 mmcv/runner/hooks/optimizer.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py
index d92b3a79e4..97b3bc53ea 100644
--- a/mmcv/runner/hooks/optimizer.py
+++ b/mmcv/runner/hooks/optimizer.py
@@ -40,21 +40,19 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
     def __init__(self, grad_clip=None, cumulative_iters=1):
         super(GradientCumulativeOptimizerHook, self).__init__(grad_clip)
         self.cumulative_iters = cumulative_iters
-        self.steps = 0
 
     def after_train_iter(self, runner):
-        self.steps += 1
+        runner.outputs['loss'] = runner.outputs['loss'] / self.cumulative_iters
+        runner.outputs['loss'].backward()
         if (runner.iter + 1) % cumulative_iters == 0 or runner.iter == max_iters:
-            runner.optimizer.zero_grad()
-            runner.outputs['loss'] = runner.outputs['loss'] / self.steps
-            runner.outputs['loss'].backward()
+            runner.optimizer.step()
             if self.grad_clip is not None:
                 grad_norm = self.clip_grads(runner.model.parameters())
                 if grad_norm is not None:
                     # Add grad norm to the logger
                     runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                              runner.outputs['num_samples'])
-            runner.optimizer.step()
+            runner.optimizer.zero_grad()
             self.steps = 0
 
 

From 04a0fef13bddd6984edf8bd6d48774e3ad10736c Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Tue, 12 Jan 2021 16:37:08 +0800
Subject: [PATCH 3/4] Update optimizer.py

---
 mmcv/runner/hooks/optimizer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py
index 97b3bc53ea..89d2a81096 100644
--- a/mmcv/runner/hooks/optimizer.py
+++ b/mmcv/runner/hooks/optimizer.py
@@ -53,7 +53,6 @@ def after_train_iter(self, runner):
                     runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                              runner.outputs['num_samples'])
             runner.optimizer.zero_grad()
-            self.steps = 0
 
 
 @HOOKS.register_module()
 class Fp16OptimizerHook(OptimizerHook):
     """FP16 optimizer hook.
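For reference, the scheme the hook implements as of PATCH 3/4 is ordinary gradient accumulation: the per-iteration loss is divided by cumulative_iters before backward() so the summed gradients match those of one large batch, and optimizer.step()/zero_grad() run only once per cumulative_iters iterations. The sketch below is not part of the patches; it is a minimal plain-PyTorch illustration, with model, optimizer and data_loader as hypothetical stand-ins.

import torch
import torch.nn as nn

cumulative_iters = 4

model = nn.Linear(10, 1)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# stand-in data: 10 batches of (input, target)
data_loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(10)]

optimizer.zero_grad()
for i, (x, y) in enumerate(data_loader):
    loss = nn.functional.mse_loss(model(x), y)
    # Scale the loss so the gradients accumulated over cumulative_iters
    # iterations add up to one large-batch gradient; this mirrors the
    # division the hook performs before calling backward().
    (loss / cumulative_iters).backward()
    if (i + 1) % cumulative_iters == 0 or (i + 1) == len(data_loader):
        optimizer.step()
        optimizer.zero_grad()

Note that when the number of iterations is not divisible by cumulative_iters, the tail window above still divides by cumulative_iters and is therefore under-weighted; PATCH 4/4 below addresses exactly this by dividing the last iterations by the remainder instead.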
From 5087bec5d51b6e93f20f74268be4645cc6576e03 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Thu, 18 Feb 2021 23:18:47 +0800
Subject: [PATCH 4/4] fix loss scale improperly in last equivalent_iter

---
 mmcv/runner/hooks/optimizer.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py
index 89d2a81096..8c03724c9e 100644
--- a/mmcv/runner/hooks/optimizer.py
+++ b/mmcv/runner/hooks/optimizer.py
@@ -40,11 +40,22 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
     def __init__(self, grad_clip=None, cumulative_iters=1):
         super(GradientCumulativeOptimizerHook, self).__init__(grad_clip)
         self.cumulative_iters = cumulative_iters
+        self.divisible_ietrs = 0
+        self.remainder_iters = 0
+        self.initialized = False
+
+    def _init(self, runner):
+        self.divisible_ietrs = runner.max_iters // self.cumulative_iters * self.cumulative_iters
+        self.remainder_iters = runner.max_iters % self.cumulative_iters
+        self.initialized = True
 
     def after_train_iter(self, runner):
-        runner.outputs['loss'] = runner.outputs['loss'] / self.cumulative_iters
+        if not self.initialized:
+            self._init(runner)
+        loss_factor = self.cumulative_iters if runner.iter < self.divisible_ietrs else self.remainder_iters
+        runner.outputs['loss'] = runner.outputs['loss'] / loss_factor
         runner.outputs['loss'].backward()
-        if (runner.iter + 1) % cumulative_iters == 0 or runner.iter == max_iters:
+        if (runner.iter + 1) % self.cumulative_iters == 0 or runner.iter == runner.max_iters:
             runner.optimizer.step()
             if self.grad_clip is not None:
                 grad_norm = self.clip_grads(runner.model.parameters())
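The point of PATCH 4/4 is the loss factor for the final, shorter accumulation window: once runner.iter reaches the last max_iters % cumulative_iters iterations, the loss is divided by that remainder rather than by cumulative_iters, so the tail gradients are not under-weighted. The snippet below is an illustration only, not part of the patch; loss_factor is a hypothetical helper that mirrors the arithmetic of _init() and after_train_iter() (note the patch spells the attribute divisible_ietrs).

def loss_factor(iter_idx, max_iters, cumulative_iters):
    # Same arithmetic as _init(): split training into full accumulation
    # windows plus a remainder window at the end.
    divisible_iters = max_iters // cumulative_iters * cumulative_iters
    remainder_iters = max_iters % cumulative_iters
    return cumulative_iters if iter_idx < divisible_iters else remainder_iters

# Worked example: max_iters=10, cumulative_iters=4 gives divisible_iters=8 and
# remainder_iters=2, so iterations 0-7 divide the loss by 4 while the last
# two iterations divide it by 2.
assert [loss_factor(i, 10, 4) for i in range(10)] == [4] * 8 + [2] * 2

Since the hook is registered with @HOOKS.register_module(), it can presumably replace the default OptimizerHook through a config along the lines of optimizer_config = dict(type='GradientCumulativeOptimizerHook', cumulative_iters=4), although the patches themselves do not include a usage example.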