Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Delete frozen parameters when using paramwise_cfg #1441

Merged
merged 6 commits into from Apr 22, 2024

Conversation

LZHgrla
Copy link
Contributor

@LZHgrla LZHgrla commented Nov 27, 2023

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. By the way, if you're not familiar with how to use pre-commit to fix lint issues or add unit tests, please refer to Contributing to OpenMMLab.

Motivation

It will cause errors when initializing DeepSpeed optimizer, with

  1. Freezing some parameters of the model
  2. Setting paramwise_cfg for optimizer to set different lr or weight_decay for different parameters

This is because that if setting paramwise_cfg, mmengine will treat each parameter (including frozen parameters) as a separate group, and that will lead to an empty list of trainable_parameters on the below code.

https://github.com/microsoft/DeepSpeed/blob/2afa1c7f2f961ef18042a88467ff5d3373c22c07/deepspeed/runtime/zero/stage_1_and_2.py#L308-L313

Modification

mmengine/_strategy/deepspeed.py

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDetection or MMPretrain.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@LZHgrla LZHgrla changed the title [Fix] Delete the freezing parameters from the DeepSpeed optimizer [Fix] Delete the frozen parameters from the DeepSpeed optimizer Nov 27, 2023
@LZHgrla LZHgrla changed the title [Fix] Delete the frozen parameters from the DeepSpeed optimizer [Fix] Delete frozen parameters from the DeepSpeed optimizer Nov 27, 2023
@LZHgrla LZHgrla mentioned this pull request Nov 27, 2023
16 tasks
@LZHgrla LZHgrla marked this pull request as draft November 27, 2023 16:20
@LZHgrla LZHgrla marked this pull request as ready for review February 4, 2024 05:29
@zhouzaida
Copy link
Member

How about moving this logic deleting frozen parameters to DefaultOptimWrapperConstructor.

for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
if bypass_duplicate and self._is_in(param_group, params):
print_log(
f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}',
logger='current',
level=logging.WARNING)
continue
if not param.requires_grad:
params.append(param_group)
continue
# if the parameter match one of the custom keys, ignore other rules
is_custom = False
for key in sorted_keys:
if key in f'{prefix}.{name}':
is_custom = True
lr_mult = custom_keys[key].get('lr_mult', 1.)
param_group['lr'] = self.base_lr * lr_mult
if self.base_wd is not None:
decay_mult = custom_keys[key].get('decay_mult', 1.)
param_group['weight_decay'] = self.base_wd * decay_mult
# add custom settings to param_group
for k, v in custom_keys[key].items():
param_group[k] = v
break
if not is_custom:
# bias_lr_mult affects all bias parameters
# except for norm.bias dcn.conv_offset.bias
if name == 'bias' and not (
is_norm or is_dcn_module) and bias_lr_mult is not None:
param_group['lr'] = self.base_lr * bias_lr_mult
if (prefix.find('conv_offset') != -1 and is_dcn_module
and dcn_offset_lr_mult is not None
and isinstance(module, torch.nn.Conv2d)):
# deal with both dcn_offset's bias & weight
param_group['lr'] = self.base_lr * dcn_offset_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
if is_norm and norm_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * norm_decay_mult
# bias lr and decay
elif (name == 'bias' and not is_dcn_module
and bias_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
# depth-wise conv
elif is_dwconv and dwconv_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# flatten parameters except dcn offset
elif (param.ndim == 1 and not is_dcn_module
and flat_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * flat_decay_mult
params.append(param_group)

@LZHgrla
Copy link
Contributor Author

LZHgrla commented Feb 18, 2024

How about moving this logic deleting frozen parameters to DefaultOptimWrapperConstructor.

for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
if bypass_duplicate and self._is_in(param_group, params):
print_log(
f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}',
logger='current',
level=logging.WARNING)
continue
if not param.requires_grad:
params.append(param_group)
continue
# if the parameter match one of the custom keys, ignore other rules
is_custom = False
for key in sorted_keys:
if key in f'{prefix}.{name}':
is_custom = True
lr_mult = custom_keys[key].get('lr_mult', 1.)
param_group['lr'] = self.base_lr * lr_mult
if self.base_wd is not None:
decay_mult = custom_keys[key].get('decay_mult', 1.)
param_group['weight_decay'] = self.base_wd * decay_mult
# add custom settings to param_group
for k, v in custom_keys[key].items():
param_group[k] = v
break
if not is_custom:
# bias_lr_mult affects all bias parameters
# except for norm.bias dcn.conv_offset.bias
if name == 'bias' and not (
is_norm or is_dcn_module) and bias_lr_mult is not None:
param_group['lr'] = self.base_lr * bias_lr_mult
if (prefix.find('conv_offset') != -1 and is_dcn_module
and dcn_offset_lr_mult is not None
and isinstance(module, torch.nn.Conv2d)):
# deal with both dcn_offset's bias & weight
param_group['lr'] = self.base_lr * dcn_offset_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
if is_norm and norm_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * norm_decay_mult
# bias lr and decay
elif (name == 'bias' and not is_dcn_module
and bias_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
# depth-wise conv
elif is_dwconv and dwconv_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# flatten parameters except dcn offset
elif (param.ndim == 1 and not is_dcn_module
and flat_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * flat_decay_mult
params.append(param_group)

Good idea!

Shall we delete the L216?

if not param.requires_grad:
params.append(param_group)
continue

@zhouzaida
Copy link
Member

Yes, we can delete it.

How about moving this logic deleting frozen parameters to DefaultOptimWrapperConstructor.

for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
if bypass_duplicate and self._is_in(param_group, params):
print_log(
f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}',
logger='current',
level=logging.WARNING)
continue
if not param.requires_grad:
params.append(param_group)
continue
# if the parameter match one of the custom keys, ignore other rules
is_custom = False
for key in sorted_keys:
if key in f'{prefix}.{name}':
is_custom = True
lr_mult = custom_keys[key].get('lr_mult', 1.)
param_group['lr'] = self.base_lr * lr_mult
if self.base_wd is not None:
decay_mult = custom_keys[key].get('decay_mult', 1.)
param_group['weight_decay'] = self.base_wd * decay_mult
# add custom settings to param_group
for k, v in custom_keys[key].items():
param_group[k] = v
break
if not is_custom:
# bias_lr_mult affects all bias parameters
# except for norm.bias dcn.conv_offset.bias
if name == 'bias' and not (
is_norm or is_dcn_module) and bias_lr_mult is not None:
param_group['lr'] = self.base_lr * bias_lr_mult
if (prefix.find('conv_offset') != -1 and is_dcn_module
and dcn_offset_lr_mult is not None
and isinstance(module, torch.nn.Conv2d)):
# deal with both dcn_offset's bias & weight
param_group['lr'] = self.base_lr * dcn_offset_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
if is_norm and norm_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * norm_decay_mult
# bias lr and decay
elif (name == 'bias' and not is_dcn_module
and bias_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
# depth-wise conv
elif is_dwconv and dwconv_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# flatten parameters except dcn offset
elif (param.ndim == 1 and not is_dcn_module
and flat_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * flat_decay_mult
params.append(param_group)

Good idea!

Shall we delete the L216?

if not param.requires_grad:
params.append(param_group)
continue

@LZHgrla LZHgrla changed the title [Fix] Delete frozen parameters from the DeepSpeed optimizer [Fix] Delete frozen parameters when using paramwise_cfg Feb 19, 2024
@LZHgrla
Copy link
Contributor Author

LZHgrla commented Feb 19, 2024

Hi, @zhouzaida
I have fixed it!
Ready for review and merge.

@zhouzaida
Copy link
Member

Hi, @zhouzaida I have fixed it! Ready for review and merge.

Please fix the ut.

@zhouzaida zhouzaida merged commit acbc5e4 into open-mmlab:main Apr 22, 2024
16 of 20 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants