Commit

Merge c0e6c49 into 65eb39b
RangiLyu committed Dec 9, 2022
2 parents 65eb39b + c0e6c49 commit c0c9a96
Showing 2 changed files with 33 additions and 20 deletions.
32 changes: 20 additions & 12 deletions mmengine/optim/optimizer/default_constructor.py
@@ -184,12 +184,13 @@ def add_params(self,
         # first sort with alphabet order and then sort with reversed len of str
         sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
 
-        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
-        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
-        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
-        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
+        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None)
+        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
+        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
+        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None)
+        flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
         bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
-        dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)
+        dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None)
 
         # special rules for norm layers and depth-wise conv layers
         is_norm = isinstance(module,
@@ -225,29 +226,36 @@ def add_params(self,
             if not is_custom:
                 # bias_lr_mult affects all bias parameters
                 # except for norm.bias dcn.conv_offset.bias
-                if name == 'bias' and not (is_norm or is_dcn_module):
+                if name == 'bias' and not (
+                        is_norm or is_dcn_module) and bias_lr_mult is not None:
                     param_group['lr'] = self.base_lr * bias_lr_mult
 
                 if (prefix.find('conv_offset') != -1 and is_dcn_module
+                        and dcn_offset_lr_mult is not None
                         and isinstance(module, torch.nn.Conv2d)):
                     # deal with both dcn_offset's bias & weight
                     param_group['lr'] = self.base_lr * dcn_offset_lr_mult
 
                 # apply weight decay policies
                 if self.base_wd is not None:
                     # norm decay
-                    if is_norm:
+                    if is_norm and norm_decay_mult is not None:
                         param_group[
                             'weight_decay'] = self.base_wd * norm_decay_mult
+                    # bias lr and decay
+                    elif (name == 'bias' and not is_dcn_module
+                          and bias_decay_mult is not None):
+                        param_group[
+                            'weight_decay'] = self.base_wd * bias_decay_mult
                     # depth-wise conv
-                    elif is_dwconv:
+                    elif is_dwconv and dwconv_decay_mult is not None:
                         param_group[
                             'weight_decay'] = self.base_wd * dwconv_decay_mult
-                    # bias lr and decay
-                    elif name == 'bias' and not is_dcn_module:
-                        # TODO: current bias_decay_mult will have affect on DCN
+                    # flatten parameters except dcn offset
+                    elif (param.ndim == 1 and not is_dcn_module
+                          and flat_decay_mult is not None):
                         param_group[
-                            'weight_decay'] = self.base_wd * bias_decay_mult
+                            'weight_decay'] = self.base_wd * flat_decay_mult
             params.append(param_group)
             for key, value in param_group.items():
                 if key == 'params':
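In short, every multiplier is now opt-in: with a None default, a rule fires only when its key is present in paramwise_cfg, and the new flat_decay_mult scales weight decay for any remaining one-dimensional parameter (param.ndim == 1) that is not a DCN offset. Below is a minimal sketch of a config exercising the new behavior; the toy model and the concrete values are illustrative assumptions, not part of this commit:

import torch.nn as nn

from mmengine.optim import DefaultOptimWrapperConstructor

# Toy module mixing multi-dimensional and 1-D parameters (hypothetical).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4))

# Only the keys listed here take effect; omitted multipliers such as
# dwconv_decay_mult now default to None and leave their groups untouched.
paramwise_cfg = dict(
    norm_decay_mult=0,    # zero decay for norm weights and biases
    flat_decay_mult=0.3)  # scaled decay for the remaining 1-D parameters

optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg,
                                                   paramwise_cfg)
optim_wrapper = optim_constructor(model)

Here the BatchNorm weight and bias receive weight_decay 0, while the conv bias, 1-D but not covered by an explicit bias_decay_mult, falls through to the flat_decay_mult branch and gets 1e-4 * 0.3.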
21 changes: 13 additions & 8 deletions tests/test_optim/test_optimizer/test_optimizer.py
@@ -123,6 +123,7 @@ def _check_sgd_optimizer(self,
                              norm_decay_mult=1,
                              dwconv_decay_mult=1,
                              dcn_offset_lr_mult=1,
+                             flat_decay_mult=1,
                              bypass_duplicate=False):
         param_groups = optimizer.param_groups
         assert isinstance(optimizer, torch.optim.SGD)
@@ -139,7 +140,7 @@ def _check_sgd_optimizer(self,
         # param1
         param1 = param_groups[0]
         assert param1['lr'] == self.base_lr
-        assert param1['weight_decay'] == self.base_wd
+        assert param1['weight_decay'] == self.base_wd * flat_decay_mult
         # conv1.weight
         conv1_weight = param_groups[1]
         assert conv1_weight['lr'] == self.base_lr
@@ -163,7 +164,7 @@ def _check_sgd_optimizer(self,
         # sub.param1
         sub_param1 = param_groups[6]
         assert sub_param1['lr'] == self.base_lr
-        assert sub_param1['weight_decay'] == self.base_wd
+        assert sub_param1['weight_decay'] == self.base_wd * flat_decay_mult
         # sub.conv1.weight
         sub_conv1_weight = param_groups[7]
         assert sub_conv1_weight['lr'] == self.base_lr
@@ -172,8 +173,7 @@ def _check_sgd_optimizer(self,
         # sub.conv1.bias
         sub_conv1_bias = param_groups[8]
         assert sub_conv1_bias['lr'] == self.base_lr * bias_lr_mult
-        assert sub_conv1_bias[
-            'weight_decay'] == self.base_wd * dwconv_decay_mult
+        assert sub_conv1_bias['weight_decay'] == self.base_wd * bias_decay_mult
         # sub.gn.weight
         sub_gn_weight = param_groups[9]
         assert sub_gn_weight['lr'] == self.base_lr
@@ -258,7 +258,8 @@ def test_build_default_optimizer_constructor(self):
             bias_decay_mult=0.5,
             norm_decay_mult=0,
             dwconv_decay_mult=0.1,
-            dcn_offset_lr_mult=0.1)
+            dcn_offset_lr_mult=0.1,
+            flat_decay_mult=0.3)
         optim_constructor_cfg = dict(
             type='DefaultOptimWrapperConstructor',
             optim_wrapper_cfg=optim_wrapper,
@@ -390,7 +391,8 @@ def test_default_optimizer_constructor_with_model_wrapper(self):
             bias_decay_mult=0.5,
             norm_decay_mult=0,
             dwconv_decay_mult=0.1,
-            dcn_offset_lr_mult=0.1)
+            dcn_offset_lr_mult=0.1,
+            flat_decay_mult=0.3)
         optim_constructor = DefaultOptimWrapperConstructor(
             optim_wrapper_cfg, paramwise_cfg)
         optim_wrapper = optim_constructor(model)
@@ -429,7 +431,8 @@ def test_default_optimizer_constructor_with_model_wrapper(self):
             bias_decay_mult=0.5,
             norm_decay_mult=0,
             dwconv_decay_mult=0.1,
-            dcn_offset_lr_mult=0.1)
+            dcn_offset_lr_mult=0.1,
+            flat_decay_mult=0.3)
         optim_constructor = DefaultOptimWrapperConstructor(
             optim_wrapper_cfg, paramwise_cfg)
         optim_wrapper = optim_constructor(model)
@@ -484,7 +487,8 @@ def test_default_optimizer_constructor_with_paramwise_cfg(self):
             bias_decay_mult=0.5,
             norm_decay_mult=0,
             dwconv_decay_mult=0.1,
-            dcn_offset_lr_mult=0.1)
+            dcn_offset_lr_mult=0.1,
+            flat_decay_mult=0.3)
         optim_constructor = DefaultOptimWrapperConstructor(
             optim_wrapper_cfg, paramwise_cfg)
         optim_wrapper = optim_constructor(self.model)
@@ -554,6 +558,7 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
             norm_decay_mult=0,
             dwconv_decay_mult=0.1,
             dcn_offset_lr_mult=0.1,
+            flat_decay_mult=0.3,
             bypass_duplicate=True)
         optim_constructor = DefaultOptimWrapperConstructor(
             optim_wrapper_cfg, paramwise_cfg)
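The updated assertions encode the same rule from the consumer side: with flat_decay_mult set, a plain 1-D parameter is expected to carry base_wd * flat_decay_mult rather than the bare base_wd, and a bias is checked against bias_decay_mult instead of the dwconv multiplier. A quick spot-check continuing the illustrative sketch above (assumed behavior, not code from this commit):

# Every 1-D parameter in the sketch above should end up with a scaled
# decay: 0.0 for norm parameters (norm_decay_mult=0), 1e-4 * 0.3 otherwise.
for group in optim_wrapper.optimizer.param_groups:
    for param in group['params']:
        if param.ndim == 1:
            assert group['weight_decay'] in (0.0, 1e-4 * 0.3)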
