[Fix] Delete frozen parameters when using paramwise_cfg #1441

Merged · 6 commits · Apr 22, 2024
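With this change, when a paramwise_cfg is supplied, parameters frozen with requires_grad=False are no longer appended to the optimizer's param groups; they are skipped with a warning. A minimal sketch of the resulting behavior (the toy model, layer names, and hyperparameter values below are illustrative, not taken from this PR):

import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor


class ToyModel(nn.Module):
    """Illustrative model; any nn.Module with a frozen submodule works."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 8, 3)


model = ToyModel()
model.conv1.requires_grad_(False)  # freeze conv1

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4))
paramwise_cfg = dict(bias_lr_mult=2)  # any non-empty paramwise_cfg

optim_wrapper = DefaultOptimWrapperConstructor(
    optim_wrapper_cfg, paramwise_cfg)(model)

# After this fix the frozen parameters never reach the optimizer.
all_params = []
for pg in optim_wrapper.param_groups:
    all_params.extend(map(id, pg['params']))
assert id(model.conv1.weight) not in all_params
assert id(model.conv2.weight) in all_params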
mmengine/optim/optimizer/default_constructor.py (5 changes: 4 additions & 1 deletion)

@@ -213,7 +213,10 @@ def add_params(self,
                     level=logging.WARNING)
                 continue
             if not param.requires_grad:
-                params.append(param_group)
+                print_log((f'{prefix}.{name} is skipped since its '
+                           f'requires_grad={param.requires_grad}'),
+                          logger='current',
+                          level=logging.WARNING)
                 continue

             # if the parameter match one of the custom keys, ignore other rules
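For context, the hunk above replaces the old `params.append(param_group)` branch for frozen parameters with a skip-plus-warning. A simplified, self-contained sketch of just that filtering idea follows; the helper name is hypothetical, and the real add_params additionally handles custom keys, lr/decay multipliers, bypass_duplicate, and recursion into child modules:

import logging

import torch.nn as nn
from mmengine.logging import print_log


def collect_trainable_param_groups(module: nn.Module, prefix: str = ''):
    """Hypothetical helper mirroring only the requires_grad filtering."""
    params = []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            # Instead of appending the frozen parameter (old behavior),
            # skip it and emit a warning, as the hunk above now does.
            print_log((f'{prefix}.{name} is skipped since its '
                       f'requires_grad={param.requires_grad}'),
                      logger='current',
                      level=logging.WARNING)
            continue
        params.append({'params': [param]})
    return params


conv = nn.Conv2d(3, 8, 3)
conv.bias.requires_grad_(False)
groups = collect_trainable_param_groups(conv, prefix='model')
assert all(p.requires_grad for g in groups for p in g['params'])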
tests/test_optim/test_optimizer/test_optimizer.py (30 changes: 11 additions & 19 deletions)

@@ -549,7 +549,8 @@ def test_default_optimizer_constructor_with_empty_paramwise_cfg(self):
                 weight_decay=self.base_wd,
                 momentum=self.momentum))
         paramwise_cfg = dict()
-        optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
+        optim_constructor = DefaultOptimWrapperConstructor(
+            optim_wrapper_cfg, paramwise_cfg)
         optim_wrapper = optim_constructor(model)
         self._check_default_optimizer(optim_wrapper.optimizer, model)

@@ -591,23 +592,16 @@ def test_default_optimizer_constructor_no_grad(self):
             dwconv_decay_mult=0.1,
             dcn_offset_lr_mult=0.1)

-        for param in self.model.parameters():
-            param.requires_grad = False
+        self.model.conv1.requires_grad_(False)
         optim_constructor = DefaultOptimWrapperConstructor(
             optim_wrapper_cfg, paramwise_cfg)
         optim_wrapper = optim_constructor(self.model)
-        optimizer = optim_wrapper.optimizer
-        param_groups = optimizer.param_groups
-        assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
-        assert optimizer.defaults['lr'] == self.base_lr
-        assert optimizer.defaults['momentum'] == self.momentum
-        assert optimizer.defaults['weight_decay'] == self.base_wd
-        for i, (name, param) in enumerate(self.model.named_parameters()):
-            param_group = param_groups[i]
-            assert torch.equal(param_group['params'][0], param)
-            assert param_group['momentum'] == self.momentum
-            assert param_group['lr'] == self.base_lr
-            assert param_group['weight_decay'] == self.base_wd
+
+        all_params = []
+        for pg in optim_wrapper.param_groups:
+            all_params.extend(map(id, pg['params']))
+        self.assertNotIn(id(self.model.conv1.weight), all_params)
+        self.assertIn(id(self.model.conv2.weight), all_params)

     def test_default_optimizer_constructor_bypass_duplicate(self):
         # paramwise_cfg with bypass_duplicate option
@@ -663,10 +657,8 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
         optim_wrapper = optim_constructor(model)
         model_parameters = list(model.parameters())
         num_params = 14 if MMCV_FULL_AVAILABLE else 11
-        assert len(optim_wrapper.optimizer.param_groups) == len(
-            model_parameters) == num_params
-        self._check_sgd_optimizer(optim_wrapper.optimizer, model,
-                                  **paramwise_cfg)
+        assert len(optim_wrapper.optimizer.param_groups
+                   ) == len(model_parameters) - 1 == num_params - 1

     def test_default_optimizer_constructor_custom_key(self):
         # test DefaultOptimWrapperConstructor with custom_keys and
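The check added in test_default_optimizer_constructor_no_grad compares parameters by object identity (id()) rather than testing tensors with `in`, which can trip over tensor __eq__ when the parameter is absent. The same pattern, shown here with a plain torch.optim.SGD and an illustrative toy model:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
model[0].requires_grad_(False)  # freeze the first conv

# Build the optimizer only from trainable parameters, which is what the
# fixed constructor now effectively does when paramwise_cfg is set.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.01)

# Identity-based membership check, as in the new test.
all_params = []
for pg in optimizer.param_groups:
    all_params.extend(map(id, pg['params']))
assert id(model[0].weight) not in all_params  # frozen, excluded
assert id(model[1].weight) in all_params      # trainable, included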