[Feature] Add support for full wandb's define_metric arguments #1099

Merged · 7 commits · Jun 1, 2023
Changes from 4 commits
2 changes: 2 additions & 0 deletions mmengine/runner/log_processor.py
@@ -201,6 +201,8 @@ def get_log_after_iter(self, runner, batch_idx: int,
cur_iter_str = str(batch_idx + 1).rjust(
len(str(dataloader_len)))
log_str = (f'Iter({mode}) [{cur_iter_str}/{dataloader_len}] ')
# Add global iter.
tag['iter'] = runner.iter + 1
Collaborator:

The current implementation of runner.iter in mmengine's Runner triggers the instantiation of runner._train_loop as seen in these lines of code:

return self.train_loop.iter

and

return self._train_loop

runner.iter actually represents the training iteration, and constructing the train_loop should be avoided during the validation or testing phases. The current design is somewhat implicit, which could lead to confusion. Recently, I also created a pull request (#1107) to fix this kind of bug 🤣.
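
For context, here is a minimal, self-contained sketch of the lazy-instantiation pattern described above. The `_IterBasedTrainLoop` class and its constructor arguments are illustrative stand-ins, not mmengine's actual implementation:

```python
class _IterBasedTrainLoop:
    """Stand-in for a training loop; it only tracks an iteration counter."""

    def __init__(self, max_iters):
        self.max_iters = max_iters
        self.iter = 0


class Runner:
    def __init__(self, train_loop=None):
        # `_train_loop` may start out as a plain config dict (or None) and is
        # only turned into a real loop object when first accessed.
        self._train_loop = train_loop

    @property
    def train_loop(self):
        if isinstance(self._train_loop, dict):
            # Lazy construction: the first access builds the loop from its config.
            self._train_loop = _IterBasedTrainLoop(**self._train_loop)
        return self._train_loop

    @property
    def iter(self):
        # Delegates to `train_loop`, so reading `runner.iter` during validation
        # or testing constructs the training loop as a side effect.
        return self.train_loop.iter


runner = Runner(train_loop=dict(max_iters=10))
print(isinstance(runner._train_loop, dict))  # True: still just a config
_ = runner.iter                              # touching `iter` builds the loop
print(isinstance(runner._train_loop, dict))  # False: the loop has been constructed
```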

Collaborator:

Suggest assigning iter like this:

if isinstance(runner._train_loop, dict) or runner._train_loop is None:
    tag['iter'] = 0
else:
    tag['iter'] = runner.iter + 1

# Concatenate lr, momentum string with log header.
log_str += f'{lr_str} '
# If IterTimerHook used in runner, eta, time, and data_time should be
22 changes: 17 additions & 5 deletions mmengine/visualization/vis_backend.py
@@ -347,11 +347,17 @@ class WandbVisBackend(BaseVisBackend):
input parameters.
See `wandb.init <https://docs.wandb.ai/ref/python/init>`_ for
details. Defaults to None.
define_metric_cfg (dict, optional):
A dict of metrics and summary for wandb.define_metric.
define_metric_cfg (dict | list[dict], optional):
When a dict is set, it is a dict of metrics and summary for
wandb.define_metric.
The key is metric and the value is summary.
When a list is set, each dict should be a valid argument of
the ``define_metric``.
When ``define_metric_cfg={'coco/bbox_mAP': 'max'}``,
the maximum value of ``coco/bbox_mAP`` is logged on wandb UI.
When ``define_metric_cfg=[dict(name="loss",
step_metric='epoch')]``,
the "loss” will be plotted against the epoch.
See `wandb define_metric <https://docs.wandb.ai/ref/python/
run#define_metric>`_ for details.
Default: None
@@ -373,7 +379,7 @@ class WandbVisBackend(BaseVisBackend):
def __init__(self,
save_dir: str,
init_kwargs: Optional[dict] = None,
define_metric_cfg: Optional[dict] = None,
define_metric_cfg: Optional[Union[dict, list]] = None,
commit: Optional[bool] = True,
log_code_name: Optional[str] = None,
watch_kwargs: Optional[dict] = None):
@@ -400,8 +406,14 @@ def _init_env(self):

wandb.init(**self._init_kwargs)
if self._define_metric_cfg is not None:
for metric, summary in self._define_metric_cfg.items():
wandb.define_metric(metric, summary=summary)
if isinstance(self._define_metric_cfg, dict):
for metric, summary in self._define_metric_cfg.items():
wandb.define_metric(metric, summary=summary)
elif isinstance(self._define_metric_cfg, list):
for metric_cfg in self._define_metric_cfg:
wandb.define_metric(**metric_cfg)
else:
raise ValueError('define_metric_cfg should be dict or list')
self._wandb = wandb

@property # type: ignore
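
As a usage illustration (not part of this diff), a configuration along these lines would exercise both accepted forms of `define_metric_cfg`; the `save_dir` values are arbitrary:

```python
from mmengine.visualization import WandbVisBackend

# Dict form: each key is a metric name and each value is the summary to track.
backend_summary = WandbVisBackend(
    save_dir='work_dirs/wandb',
    define_metric_cfg={'coco/bbox_mAP': 'max'})

# List form: every dict is forwarded as keyword arguments to wandb.define_metric.
backend_step_metric = WandbVisBackend(
    save_dir='work_dirs/wandb',
    define_metric_cfg=[dict(name='loss', step_metric='epoch')])
```

In either form, `wandb.define_metric` is only called once the experiment is actually initialized in `_init_env`.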
9 changes: 8 additions & 1 deletion tests/test_runner/test_log_processor.py
@@ -92,7 +92,8 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
log_processor._collect_scalars = \
lambda *args, **kwargs: copy.deepcopy(train_logs)
_, out = log_processor.get_log_after_iter(self.runner, 1, mode)
tag, out = log_processor.get_log_after_iter(self.runner, 1, mode)

# Verify that the correct context has been logged.
cur_loop = log_processor._get_cur_loop(self.runner, mode)
if by_epoch:
@@ -117,6 +118,9 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
if mode == 'train':
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
assert out == log_str

if mode in ['train', 'val']:
assert 'epoch' in tag
else:
if mode == 'train':
max_iters = self.runner.max_iters
@@ -144,6 +148,9 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
assert out == log_str

# tag always has "iter" key
assert 'iter' in tag

@parameterized.expand(
([True, 'val', True], [True, 'val', False], [False, 'val', True],
[False, 'val', False], [True, 'test', True], [False, 'test', False]))
24 changes: 24 additions & 0 deletions tests/test_visualizer/test_vis_backend.py
@@ -243,6 +243,30 @@ def test_close(self):
wandb_vis_backend.close()
shutil.rmtree('temp_dir')

def test_define_metric_cfg(self):
# list of dict
define_metric_cfg = [
dict(name='test1', step_metric='iter'),
dict(name='test2', step_metric='epoch'),
]
wandb_vis_backend = WandbVisBackend(
'temp_dir', define_metric_cfg=define_metric_cfg)
wandb_vis_backend._init_env()
wandb_vis_backend._wandb.define_metric.assert_any_call(
name='test1', step_metric='iter')
wandb_vis_backend._wandb.define_metric.assert_any_call(
name='test2', step_metric='epoch')

# dict
define_metric_cfg = dict(test3='max')
wandb_vis_backend = WandbVisBackend(
'temp_dir', define_metric_cfg=define_metric_cfg)
wandb_vis_backend._init_env()
wandb_vis_backend._wandb.define_metric.assert_any_call(
'test3', summary='max')

shutil.rmtree('temp_dir')


class TestMLflowVisBackend:
