
[Bug] Learning rate in log will show the first params_group in optimizer, rather than the base learning rate of optimizer #482

Closed
HAOCHENYE opened this issue Aug 29, 2022 · 4 comments · Fixed by #1019

Comments

@HAOCHENYE
Collaborator

Thanks for your error report and we appreciate it a lot.

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. I have read the FAQ documentation but cannot get the expected help.
  3. The bug has not been fixed in the latest version.

Describe the bug
For example, the config of optim_wrapper is:

optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001),
    clip_grad=dict(max_norm=0.1, norm_type=2),
    paramwise_cfg=dict(
        custom_keys={
            'backbone': dict(lr_mult=0.1),
            'sampling_offsets': dict(lr_mult=0.1),
            'reference_points': dict(lr_mult=0.1)
        }))

The log records the learning rate of model.backbone (2e-5) rather than the base learning rate of the optimizer (2e-4).
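
To see why, here is a minimal plain-PyTorch sketch of the mechanism (module names hypothetical, assuming the backbone parameters are registered before the rest, as is typical):

import torch.nn as nn
from torch.optim import AdamW

# Hypothetical two-part model standing in for the detector.
model = nn.ModuleDict(dict(backbone=nn.Linear(4, 4), head=nn.Linear(4, 2)))

base_lr = 2e-4
optimizer = AdamW(
    [
        # paramwise_cfg puts the backbone group (lr_mult=0.1) first ...
        dict(params=model['backbone'].parameters(), lr=base_lr * 0.1),
        dict(params=model['head'].parameters()),
    ],
    lr=base_lr)

# ... so reading the first group's lr, as the log effectively does,
# yields 2e-5 instead of the base 2e-4.
print(optimizer.param_groups[0]['lr'])  # 2e-05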

Reproduction

  1. What command or script did you run?
A placeholder for the command.
  2. Did you make any modifications to the code or config? Did you understand what you modified?
  3. What dataset did you use?

Environment

  1. Please run python mmdet/utils/collect_env.py to collect necessary environment information and paste it here.
  2. You may add additional information that may be helpful for locating the problem, such as
    • How you installed PyTorch [e.g., pip, conda, source]
    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

Error traceback
If applicable, paste the error traceback here.

A placeholder for the traceback.

Bug fix
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

@HAOCHENYE HAOCHENYE changed the title [Bug] Learning rate in log message will show the first params_group in optimizer, rather than the base learning rate of optimizer [Bug] Learning rate in log will show the first params_group in optimizer, rather than the base learning rate of optimizer Aug 29, 2022
@ZwwWayne ZwwWayne added this to the 0.6.0 milestone Aug 29, 2022
@AkideLiu
Contributor

Hi @HAOCHENYE, this is Akide; I hope you are doing well. We are a course group of five members from The University of Adelaide aiming to complete this issue and contribute. We will try to kick off the work this week. If you have any further ideas, please don't hesitate to contact us. Cheers!

@AkideLiu
Contributor

Hi @HAOCHENYE, could you please provide an example config file that does not depend on a downstream task repository?

@HAOCHENYE
Collaborator Author

> Hi @HAOCHENYE, could you please provide an example config file that does not depend on a downstream task repository?

Hi, you can modify examples/distributed_training.py like this:

# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels


class Accuracy(BaseMetric):

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


def parse_args():
    parser = argparse.ArgumentParser(description='Distributed Training')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
    train_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)
        ]))
    valid_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(**norm_cfg)]))
    train_dataloader = dict(
        batch_size=32,
        dataset=train_set,
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=dict(type='default_collate'))
    val_dataloader = dict(
        batch_size=32,
        dataset=valid_set,
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=dict(type='default_collate'))
    runner = Runner(
        model=MMResNet50(),
        work_dir='./work_dir',
        train_dataloader=train_dataloader,
        optim_wrapper=dict(
            optimizer=dict(type=SGD, lr=0.001, momentum=0.9),
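            # Scale conv1's lr by 0.1 so the optimizer is built with
            # parameter groups at different learning rates.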
            paramwise_cfg=dict(custom_keys=dict(
                conv1=dict(lr_mult=0.1)))),
        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
        val_dataloader=val_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=Accuracy),
        launcher=args.launcher,
    )
    runner.train()


if __name__ == '__main__':
    main()

Then you will see a log like this:

03/25 01:12:03 - mmengine - INFO - Checkpoints will be saved to /home/yehaochen/codebase/mmengine/examples/work_dir.
03/25 01:12:05 - mmengine - INFO - Epoch(train) [1][  10/1563]  lr: 1.0000e-04  eta: 0:09:08  time: 0.1761  data_time: 0.0080  memory: 476  loss: 5.0082
03/25 01:12:05 - mmengine - INFO - Epoch(train) [1][  20/1563]  lr: 1.0000e-04  eta: 0:05:34  time: 0.0392  data_time: 0.0078  memory: 476  loss: 2.7224
03/25 01:12:06 - mmengine - INFO - Epoch(train) [1][  30/1563]  lr: 1.0000e-04  eta: 0:04:23  time: 0.0402  data_time: 0.0078  memory: 476  loss: 2.6029
03/25 01:12:06 - mmengine - INFO - Epoch(train) [1][  40/1563]  lr: 1.0000e-04  eta: 0:03:48  time: 0.0407  data_time: 0.0077  memory: 476  loss: 2.5827
03/25 01:12:07 - mmengine - INFO - Epoch(train) [1][  50/1563]  lr: 1.0000e-04  eta: 0:03:25  time: 0.0376  data_time: 0.0080  memory: 476  loss: 2.6548
03/25 01:12:07 - mmengine - INFO - Epoch(train) [1][  60/1563]  lr: 1.0000e-04  eta: 0:03:09  time: 0.0366  data_time: 0.0074  memory: 476  loss: 2.8956
03/25 01:12:07 - mmengine - INFO - Epoch(train) [1][  70/1563]  lr: 1.0000e-04  eta: 0:02:58  time: 0.0374  data_time: 0.0074  memory: 476  loss: 2.8190

Apparently, the logged lr should be 1e-3 (the base learning rate) instead of 1e-4.
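
One way to confirm the ordering is to inspect the parameter groups directly (a hypothetical snippet; the exact group layout depends on the optim wrapper constructor):

# conv1's parameters (lr = 0.001 * 0.1 = 1e-4) are added before the rest,
# so the first group's lr (the value the log picks up) is 1e-4.
for i, group in enumerate(runner.optim_wrapper.optimizer.param_groups):
    print(i, group['lr'], len(group['params']))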

@AkideLiu
Contributor

Hi, @HAOCHENYE

Thank you for providing the example configuration file. We have carried out an initial investigation and identified the following logic that may be related to the issue at hand:

  1. Printed logs are primarily managed by runner.message_hub.update_scalar.
  2. Logged scalars are exported to the message_hub in RuntimeInfoHook::before_train_iter (see the sketch after this list).
  3. The learning rate is retrieved from runner.optim_wrapper.get_lr().
  4. In OptimWrapper::get_lr, an unsorted list of learning rates is built from the optimizer's parameter groups.
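
A condensed sketch of this flow (simplified, not the verbatim hook implementation):

# Simplified from RuntimeInfoHook.before_train_iter: get_lr() returns a dict
# of lists, and only the first element of each list reaches the message hub,
# i.e. the lr of whichever parameter group happens to come first.
def before_train_iter(runner):
    for key, value in runner.optim_wrapper.get_lr().items():
        runner.message_hub.update_scalar(f'train/{key}', value[0])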

The core issue can potentially be resolved by sorting the learning rate list built from the paramwise_cfg parameter groups. At this stage, we plan to open a preliminary pull request that sorts the learning rate list without any conditional checking. We believe it is essential to define rules that automatically identify the base learning rate among the optimizer's parameter groups. A first sketch of the patched OptimWrapper.get_lr:

    def get_lr(self) -> Dict[str, List[float]]:
        """Get the learning rate of the optimizer.

        Provide unified interface to get learning rate of optimizer.

        Returns:
            Dict[str, List[float]]: Learning rate of the optimizer.
        """
        lr = [group['lr'] for group in self.param_groups]
        # Sort in descending order so the (assumed largest) base lr comes first.
        lr.sort(reverse=True)
        return dict(lr=lr)

The output then shows:

03/26 15:31:49 - mmengine - INFO - Checkpoints will be saved to /home/akide/pychram_remote/mmengine-dev/examples/work_dir.
03/26 15:31:50 - mmengine - INFO - Epoch(train) [1][  10/1563]  lr: 1.0000e-03  eta: 0:06:40  time: 0.1285  data_time: 0.0032  memory: 369  loss: 5.4107
03/26 15:31:50 - mmengine - INFO - Epoch(train) [1][  20/1563]  lr: 1.0000e-03  eta: 0:03:43  time: 0.0152  data_time: 0.0029  memory: 369  loss: 3.0671
03/26 15:31:51 - mmengine - INFO - Epoch(train) [1][  30/1563]  lr: 1.0000e-03  eta: 0:02:43  time: 0.0150  data_time: 0.0033  memory: 369  loss: 2.6957
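
For reference, the sorted behaviour can be checked standalone; note the implicit assumption that the base lr is the largest value in the list, which would no longer hold if some custom key used lr_mult > 1 (plain-PyTorch sketch, module names hypothetical):

import torch.nn as nn
from torch.optim import SGD

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 2))
optimizer = SGD(
    [
        # The conv1 group with the 0.1 multiplier is built first ...
        dict(params=model[0].parameters(), lr=0.001 * 0.1),
        dict(params=model[1].parameters()),
    ],
    lr=0.001)

# ... but after the descending sort the base lr comes first.
lr = [group['lr'] for group in optimizer.param_groups]
lr.sort(reverse=True)
print(lr)  # [0.001, 0.0001]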

Please let us know if you have any concerns or suggestions. We appreciate your input and look forward to collaborating on this matter.
