[Bug] Learning rate in log will show the first params_group in optimizer, rather than the base learning rate of optimizer #482

Comments
Hi @HAOCHENYE, this is Akide; I hope you are doing well. We are a course group of five members from The University of Adelaide, aiming to complete this issue and contribute. We plan to kick off this issue this week. If you have any more ideas, please don't hesitate to contact us. Cheers
Hi @HAOCHENYE, could you please provide an example config file, not one from a downstream task repository?
Hi, you can modify this example:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels


class Accuracy(BaseMetric):

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


def parse_args():
    parser = argparse.ArgumentParser(description='Distributed Training')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
    train_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)
        ]))
    valid_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(**norm_cfg)]))
    train_dataloader = dict(
        batch_size=32,
        dataset=train_set,
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=dict(type='default_collate'))
    val_dataloader = dict(
        batch_size=32,
        dataset=valid_set,
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=dict(type='default_collate'))
    runner = Runner(
        model=MMResNet50(),
        work_dir='./work_dir',
        train_dataloader=train_dataloader,
        optim_wrapper=dict(
            optimizer=dict(type=SGD, lr=0.001, momentum=0.9),
            paramwise_cfg=dict(custom_keys=dict(
                conv1=dict(lr_mult=0.1)))),
        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
        val_dataloader=val_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=Accuracy),
        launcher=args.launcher,
    )
    runner.train()


if __name__ == '__main__':
    main()
```

Then you can see a log like this:

```
03/25 01:12:03 - mmengine - INFO - Checkpoints will be saved to /home/yehaochen/codebase/mmengine/examples/work_dir.
03/25 01:12:05 - mmengine - INFO - Epoch(train) [1][ 10/1563] lr: 1.0000e-04 eta: 0:09:08 time: 0.1761 data_time: 0.0080 memory: 476 loss: 5.0082
03/25 01:12:05 - mmengine - INFO - Epoch(train) [1][ 20/1563] lr: 1.0000e-04 eta: 0:05:34 time: 0.0392 data_time: 0.0078 memory: 476 loss: 2.7224
03/25 01:12:06 - mmengine - INFO - Epoch(train) [1][ 30/1563] lr: 1.0000e-04 eta: 0:04:23 time: 0.0402 data_time: 0.0078 memory: 476 loss: 2.6029
03/25 01:12:06 - mmengine - INFO - Epoch(train) [1][ 40/1563] lr: 1.0000e-04 eta: 0:03:48 time: 0.0407 data_time: 0.0077 memory: 476 loss: 2.5827
03/25 01:12:07 - mmengine - INFO - Epoch(train) [1][ 50/1563] lr: 1.0000e-04 eta: 0:03:25 time: 0.0376 data_time: 0.0080 memory: 476 loss: 2.6548
03/25 01:12:07 - mmengine - INFO - Epoch(train) [1][ 60/1563] lr: 1.0000e-04 eta: 0:03:09 time: 0.0366 data_time: 0.0074 memory: 476 loss: 2.8956
03/25 01:12:07 - mmengine - INFO - Epoch(train) [1][ 70/1563] lr: 1.0000e-04 eta: 0:02:58 time: 0.0374 data_time: 0.0074 memory: 476 loss: 2.8190
```

Apparently, the logged `lr: 1.0000e-04` is the learning rate of the first parameter group (the `conv1` group: base lr 0.001 scaled by `lr_mult=0.1`), not the optimizer's base learning rate of 1e-3.
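The mechanism behind the log lines above can be illustrated with a minimal sketch. This is hypothetical stand-in data, not mmengine's actual code: it only mimics the shape of `optimizer.param_groups` after `paramwise_cfg` has split the parameters, and shows why a logger that reads the first group reports the scaled value.

```python
# Hypothetical sketch (not mmengine's implementation): paramwise_cfg with
# lr_mult=0.1 on 'conv1' produces a customized group alongside the rest,
# and the customized group can end up first in param_groups.
base_lr = 0.001

param_groups = [
    {'name': 'conv1', 'lr': base_lr * 0.1},  # 1e-4: what ends up in the log
    {'name': 'rest', 'lr': base_lr},         # 1e-3: the base learning rate
]

# A logger that naively reads the first group reports the scaled lr:
logged_lr = param_groups[0]['lr']
print(f'lr: {logged_lr:.4e}')  # lr: 1.0000e-04, matching the log above
```

The ordering of `param_groups` is an implementation detail of the optimizer builder, which is exactly why "first group" is not a reliable proxy for "base learning rate".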
Hi @HAOCHENYE, thank you for providing the example configuration file. We have started an initial investigation and identified the following logic that may be related to the issue at hand:
The core issue can potentially be resolved by sorting the learning-rate list according to `paramwise_cfg`. At this stage, we plan to create a preliminary pull request that sorts the learning-rate list without conditional checking. We believe it is essential to define rules that automatically identify the base learning rate among the optimizer's parameter groups:

```python
def get_lr(self) -> Dict[str, List[float]]:
    """Get the learning rate of the optimizer.

    Provide a unified interface to get the learning rate of the optimizer.

    Returns:
        Dict[str, List[float]]: Learning rate of the optimizer.
    """
    lr = [group['lr'] for group in self.param_groups]
    lr.sort(reverse=True)
    return dict(lr=lr)
```

With this change, the first entry of the returned list is the largest learning rate across the parameter groups.
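A caveat worth noting about the sorting approach: it assumes every `lr_mult` is below 1, so that the base lr is the maximum. The sketch below (hypothetical values, standalone re-implementation of the snippet above) shows the edge case where a group is scaled up rather than down:

```python
# Standalone sketch of the proposed sorted get_lr and one of its edge cases.
def get_lr_sorted(param_groups):
    lr = [group['lr'] for group in param_groups]
    lr.sort(reverse=True)
    return {'lr': lr}

base_lr = 0.001

# With lr_mult < 1 the base lr is the maximum, so sorting picks it out:
groups = [{'lr': base_lr * 0.1}, {'lr': base_lr}]
print(get_lr_sorted(groups)['lr'][0])  # 0.001, the base lr

# But with lr_mult > 1 (e.g. a head trained faster than the backbone),
# the maximum is the scaled lr, so the first entry is no longer the base lr:
groups = [{'lr': base_lr}, {'lr': base_lr * 10}]
print(get_lr_sorted(groups)['lr'][0])  # the head's lr, not the base lr
```

This is why a rule for identifying the base lr directly, rather than by magnitude, seems necessary.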
Please let us know if you have any concerns or suggestions. We appreciate your input and look forward to collaborating on this matter.
Thanks for your error report and we appreciate it a lot.

Checklist

Describe the bug

For example, the config of optim_wrapper is:

The log will record the learning rate of model.backbone (2e-5), rather than the base learning rate of the optimizer (2e-4).

Reproduction
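The config referenced above did not survive this page capture. A hypothetical reconstruction consistent with the numbers quoted (base lr 2e-4, backbone logged at 2e-5), using mmengine's `optim_wrapper`/`paramwise_cfg` conventions, might look like this; the optimizer type and exact field values are assumptions, not the reporter's original config:

```python
# Hypothetical reconstruction, NOT the original config from the report:
# base lr 2e-4, with model.backbone scaled by lr_mult=0.1, giving 2e-5.
optim_wrapper = dict(
    optimizer=dict(type='AdamW', lr=2e-4),  # base learning rate, 2e-4
    paramwise_cfg=dict(custom_keys={
        'backbone': dict(lr_mult=0.1),      # backbone group lr becomes 2e-5
    }))
```

With such a config, the backbone group's lr (2e-5) is what appears in the log, matching the behavior described.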
Environment

Run python mmdet/utils/collect_env.py to collect necessary environment information and paste it here. You may add other information that may help locate the problem, such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.

Error traceback

If applicable, paste the error traceback here.

Bug fix

If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
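One direction for a fix, sketched under an assumption rather than as the change that was eventually merged: `torch.optim.Optimizer` keeps its constructor defaults in a `.defaults` dict, which survives any per-group rescaling, so the base lr can be read from there instead of from `param_groups[0]`. The `FakeOptimizer` class below is a hypothetical stand-in used only so the sketch runs without torch:

```python
class FakeOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer, which exposes
    .defaults (constructor kwargs) and .param_groups with these shapes."""

    def __init__(self, base_lr):
        self.defaults = {'lr': base_lr}
        # paramwise scaling has reordered/rescaled the groups:
        self.param_groups = [{'lr': base_lr * 0.1}, {'lr': base_lr}]


def get_base_lr(optimizer):
    # Read the base lr from the optimizer's defaults rather than from
    # param_groups[0], so paramwise_cfg cannot shadow it.
    return optimizer.defaults['lr']


opt = FakeOptimizer(0.001)
print(get_base_lr(opt))  # 0.001, regardless of group order or lr_mult
```

Whether logging should show the base lr, all group lrs, or both is a separate design question the thread leaves open.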