Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add support for full wandb's define_metric arguments #1099

Merged
merged 7 commits into from Jun 1, 2023

Conversation

i-aki-y
Copy link
Contributor

@i-aki-y i-aki-y commented Apr 23, 2023

Motivation

In this PR, add support for other original arguments of wandb.define_metric described in: https://docs.wandb.ai/ref/python/run#define_metric
(The current WandbVisBackend implementation supports only the case of the summary metric.)

Especially the step_metric is helpful for defining the x-axis more flexibly.
Specifying the "iter" or "epoch" as an x-axis, we can solve the issue of resumption discussed in issue #1072. I think this PR is a better solution because this does not need additional variables like auto_step norlast_step nor extra logic and uses only the wandb's feature.

Furthermore, users can specify different x-axes depending on the metrics in the same experiment. See below:

iter_epoch

Modification

To allow the visualizers to use the iter as a step, I added an iter key in the tag of the log_processor.
And I added a new argument (define_metric_cfgs) in the WandbVisBackend for users to configure the metric.

Since all visualizers use the same tag, this affects other non-wandb visualizers' behavior.
I think this modification would not cause any negative effect on other visualizers.
But if it causes something bad, I will add a flag like 'add_iter' in the log_processor for user control to include iter or not.

BC-breaking (Optional)

As I described above, this PR added a new key, "iter," in the output of the log_processor.
A new plot for the "iter" might appear in some visualization, but I think it does not cause serious problems.

Use cases (Optional)

Expected usage is defining all the metrics used by the experiment:

vis_backends = [
    dict(type='LocalVisBackend'),
    dict(
        type='WandbVisBackend',        
        init_kwargs={...},
        define_metric_cfgs=[
            dict(name="lr", step_metric='iter’),
            ...
            dict(name="coco/bbox_mAP", step_metric='epoch'),
            dict(name="coco/bbox_mAP_50", step_metric='epoch'),            
        ],
    )
]

Providing a standard preset of the define_metric_cfgs as default (ex., configs/_base_/wandb_metrics.py) might be great, but I don't think it is indispensable. It is not difficult, though cumbersome, for users to define metrics individually for a particular project. And defining all possible metrics in advance is a bit difficult because the metrics would change depending on the models and tasks.

If the user missed defining some metrics, wandb would use the default step as the x-axis and loose nothing.
We can fix the X-axis later in the Web UI.

switch_x

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMCls.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@i-aki-y
Copy link
Contributor Author

i-aki-y commented Apr 26, 2023

Updated: allow define_metric_cfg to accept dict and list of dict, and remove added define_metric_cfgs argument.

I noticed that the define_metric support glob so that we can use the following setting:

 define_metric_cfgs=[
            dict(name="loss*", step_metric='iter’),
            dict(name="coco/*", step_metric='epoch'),
            ...
        ],

https://docs.wandb.ai/guides/track/log/customize-logging-axes

Copy link
Collaborator

@HAOCHENYE HAOCHENYE left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution! But should we make self.wandb.log accept the step argument in this PR? Will the current modification resolve the resume problem?

Comment on lines 204 to 205
# Add global iter.
tag['iter'] = runner.iter + 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation of runner.iter in mmengine's Runner triggers the instantiation of runner._train_loop as seen in these lines of code:

return self.train_loop.iter

and

return self._train_loop

runner.iter actually represents the training iteration, and the construction of the train_loop should be avoided during validation or testing phases. The current design mechanism is somewhat implicit, which could lead to confusion. Recently, I also created a pull request (#1107) to fix this kind of bug 🤣 .

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest assigning iter like this:

if isinstance(runner._train_loop, dict) or runner._train_loop is None:
     tag['iter'] = 0
else:
    tag['iter'] = runner.iter + 1

mmengine/visualization/vis_backend.py Outdated Show resolved Hide resolved
mmengine/visualization/vis_backend.py Outdated Show resolved Hide resolved
mmengine/visualization/vis_backend.py Outdated Show resolved Hide resolved
mmengine/visualization/vis_backend.py Outdated Show resolved Hide resolved
@i-aki-y
Copy link
Contributor Author

i-aki-y commented Apr 27, 2023

@HAOCHENYE Thank you for your review.

Thanks for your contribution! But should we make self.wandb.log accept the step argument in this PR? Will the current modification resolve the resume problem?

We can fix the continuity of the resumed plot by changing the X-axis from the 'step' to 'iter' (or epoch).

fix_resume

In this PR, we include the 'iter' in the log target so that the wandb knows how to map from 'step' to 'iter' even if the user did not set define_metric in advance.

This means that the user needs to change X-axis by themselves by using define_metric or Web UI.
For example, the following settings are working for me so far:

    define_metric_cfg=[
        dict(name="coco/*", step_metric="epoch"),
        dict(name="*", step_metric="iter"),
    ],

@HAOCHENYE
Copy link
Collaborator

@HAOCHENYE Thank you for your review.

Thanks for your contribution! But should we make self.wandb.log accept the step argument in this PR? Will the current modification resolve the resume problem?

We can fix the continuity of the resumed plot by changing the X-axis from the 'step' to 'iter' (or epoch).

fix_resume

In this PR, we include the 'iter' in the log target so that the wandb knows how to map from 'step' to 'iter' even if the user did not set define_metric in advance.

This means that the user needs to change X-axis by themselves by using define_metric or Web UI. For example, the following settings are working for me so far:

    define_metric_cfg=[
        dict(name="coco/*", step_metric="epoch"),
        dict(name="*", step_metric="iter"),
    ],

Thanks for your detailed explanation 😄 ! I've got it!

@i-aki-y i-aki-y force-pushed the add-full-define_metrc-support branch from 5ddd4ca to 9f77ab9 Compare May 4, 2023 23:30
@i-aki-y
Copy link
Contributor Author

i-aki-y commented May 4, 2023

@HAOCHENYE Thank you for your comment.
I added a code to avoid log_processor to build train_loop.

Suggest assigning iter like this:
if isinstance(runner._train_loop, dict) or runner._train_loop is None:
tag['iter'] = 0
else:
tag['iter'] = runner.iter + 1

@zhouzaida zhouzaida modified the milestones: 0.7.3, 0.7.4 May 10, 2023
@HAOCHENYE HAOCHENYE modified the milestones: 0.7.4, 0.7.5 May 31, 2023
@zhouzaida zhouzaida merged commit 6df9621 into open-mmlab:main Jun 1, 2023
16 of 19 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants