[Enhance] Support writing data to vis_backend with prefix #972

Merged (18 commits, Mar 13, 2023)
8 changes: 1 addition & 7 deletions mmengine/hooks/logger_hook.py
@@ -231,14 +231,8 @@ def after_val_epoch(self,
             runner, len(runner.val_dataloader), 'val')
         runner.logger.info(log_str)
         if self.log_metric_by_epoch:
-            # when `log_metric_by_epoch` is set to True, it's expected
-            # that validation metric can be logged by epoch rather than
-            # by iter. At the same time, scalars related to time should
-            # still be logged by iter to avoid messy visualized result.
-            # see details in PR #278.
-            metric_tags = {k: v for k, v in tag.items() if 'time' not in k}
             runner.visualizer.add_scalars(
-                metric_tags, step=runner.epoch, file_path=self.json_log_path)
+                tag, step=runner.epoch, file_path=self.json_log_path)
         else:
             runner.visualizer.add_scalars(
                 tag, step=runner.iter, file_path=self.json_log_path)
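With this change, `after_val_epoch` forwards the full `tag` dict to `add_scalars` instead of filtering out time-related keys, because the reworked `LogProcessor` (next file) already averages `time`/`data_time` over the whole epoch. A minimal sketch of the resulting call, mirroring the hook body with made-up values:

```python
# Illustrative values only; `tag` as produced by the new LogProcessor.
# time/data_time arrive epoch-averaged, so writing them with
# step=runner.epoch no longer creates the messy per-iter curves that
# PR #278 originally filtered out.
tag = {'acc': 0.92, 'time': 0.31, 'data_time': 0.02}
runner.visualizer.add_scalars(
    tag, step=runner.epoch, file_path=self.json_log_path)
```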
148 changes: 116 additions & 32 deletions mmengine/runner/log_processor.py
@@ -52,6 +52,13 @@ class LogProcessor:
             `epoch` to statistics log value by epoch.
         num_digits (int): The number of significant digit shown in the
             logging message.
+        log_with_hierarchy (bool): Whether to log with hierarchy. If it is
+            True, the information is written to visualizer backend such as
+            :obj:`LocalVisBackend` and :obj:`TensorboardBackend`
+            with hierarchy. For example, ``loss`` will be saved as
+            ``train/loss``, and accuracy will be saved as ``val/accuracy``.
+            Defaults to False.
+            `New in version 0.7.0.`
 
     Examples:
         >>> # `log_name` is defined, `loss_large_window` will be an additional
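A minimal usage sketch for the new flag (a sketch only; values are illustrative):

```python
from mmengine.runner import LogProcessor

# With log_with_hierarchy=True, scalars keep their mode prefix when
# written to visualizer backends (e.g. 'train/loss', 'val/accuracy'),
# while the terminal log still prints the bare key names.
log_processor = LogProcessor(
    window_size=10,
    by_epoch=True,
    log_with_hierarchy=True)
```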
@@ -98,11 +105,13 @@ def __init__(self,
                  window_size=10,
                  by_epoch=True,
                  custom_cfg: Optional[List[dict]] = None,
-                 num_digits: int = 4):
+                 num_digits: int = 4,
+                 log_with_hierarchy: bool = False):
         self.window_size = window_size
         self.by_epoch = by_epoch
         self.custom_cfg = custom_cfg if custom_cfg else []
         self.num_digits = num_digits
+        self.log_with_hierarchy = log_with_hierarchy
         self._check_custom_cfg()
 
     def get_log_after_iter(self, runner, batch_idx: int,
@@ -120,18 +129,26 @@ def get_log_after_iter(self, runner, batch_idx: int,
             recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
         """
         assert mode in ['train', 'test', 'val']
-        current_loop = self._get_cur_loop(runner, mode)
         cur_iter = self._get_iter(runner, batch_idx=batch_idx)
         # Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
-        custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
-        # tag is used to write log information to different backends.
-        tag = self._collect_scalars(custom_cfg_copy, runner, mode)
-        # `log_tag` will pop 'lr' and loop other keys to `log_str`.
-        log_tag = copy.deepcopy(tag)
+        parsed_cfg = self._parse_windows_size(runner, batch_idx,
+                                              self.custom_cfg)
+        # log_tag is used to write log information to terminal
+        # If `self.log_with_hierarchy` is False, the tag is the same as
+        # log_tag. Otherwise, each key in tag starts with prefix `train`,
+        # `test` or `val`
+        log_tag = self._collect_scalars(parsed_cfg, runner, mode)
+
+        if not self.log_with_hierarchy:
+            tag = copy.deepcopy(log_tag)
+        else:
+            tag = self._collect_scalars(parsed_cfg, runner, mode, True)
+
         # Record learning rate.
         lr_str_list = []
         for key, value in tag.items():
             if key.endswith('lr'):
+                key = self._remove_prefix(key, f'{mode}/')
                 log_tag.pop(key)
                 lr_str_list.append(f'{key}: '
                                    f'{value:.{self.num_digits}e}')
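To make the `tag`/`log_tag` split concrete, a sketch of the two dicts for `mode='train'` with `log_with_hierarchy=True` (keys and values illustrative):

```python
# log_tag drives the terminal message and always uses bare keys:
log_tag = {'loss': 0.25, 'lr': 0.01, 'time': 0.30, 'data_time': 0.02}

# tag is what reaches the visualizer backends; with hierarchy enabled
# every key carries the mode prefix:
tag = {
    'train/loss': 0.25,
    'train/lr': 0.01,
    'train/time': 0.30,
    'train/data_time': 0.02,
}
```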
@@ -148,7 +165,7 @@ def get_log_after_iter(self, runner, batch_idx: int,
             # Epoch(train) [ 9][010/270]
             # ...               ||| |||
             # Epoch(train) [ 10][100/270]
-            dataloader_len = len(current_loop.dataloader)
+            dataloader_len = self._get_dataloader_size(runner, mode)
             cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))
 
             if mode in ['train', 'val']:
@@ -174,23 +191,22 @@ def get_log_after_iter(self, runner, batch_idx: int,
                 log_str = (f'Iter({mode}) '
                            f'[{cur_iter_str}/{runner.max_iters}] ')
             else:
-                dataloader_len = len(current_loop.dataloader)
+                dataloader_len = self._get_dataloader_size(runner, mode)
                 cur_iter_str = str(batch_idx + 1).rjust(
                     len(str(dataloader_len)))
-                log_str = (f'Iter({mode}) [{cur_iter_str}'
-                           f'/{len(current_loop.dataloader)}] ')
+                log_str = (f'Iter({mode}) [{cur_iter_str}/{dataloader_len}] ')
         # Concatenate lr, momentum string with log header.
         log_str += f'{lr_str} '
         # If IterTimerHook used in runner, eta, time, and data_time should be
         # recorded.
-        if (all(item in tag for item in ['time', 'data_time'])
+        if (all(item in log_tag for item in ['time', 'data_time'])
                 and 'eta' in runner.message_hub.runtime_info):
             eta = runner.message_hub.get_info('eta')
             eta_str = str(datetime.timedelta(seconds=int(eta)))
             log_str += f'eta: {eta_str} '
-            log_str += (f'time: {tag["time"]:.{self.num_digits}f} '
+            log_str += (f'time: {log_tag["time"]:.{self.num_digits}f} '
                         f'data_time: '
-                        f'{tag["data_time"]:.{self.num_digits}f} ')
+                        f'{log_tag["data_time"]:.{self.num_digits}f} ')
             # Pop recorded keys
             log_tag.pop('time')
             log_tag.pop('data_time')
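For reference, the terminal line assembled here comes out roughly like this for an iter-based run (an illustration, not captured output):

```python
# Shape of the string built by get_log_after_iter; values are made up:
log_str = ('Iter(train) [ 100/1000] lr: 1.0000e-02 eta: 0:05:10 '
           'time: 0.2500 data_time: 0.0100 loss: 0.3000')
```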
@@ -233,15 +249,8 @@ def get_log_after_epoch(self,
             'test', 'val'
         ], ('`_get_metric_log_str` only accept val or test mode, but got '
             f'{mode}')
-        cur_loop = self._get_cur_loop(runner, mode)
-        dataloader_len = len(cur_loop.dataloader)
+        dataloader_len = self._get_dataloader_size(runner, mode)
 
-        custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
-        # tag is used to write log information to different backends.
-        tag = self._collect_scalars(custom_cfg_copy, runner, mode)
-        non_scalar_tag = self._collect_non_scalars(runner, mode)
-        tag.pop('time', None)
-        tag.pop('data_time', None)
         # By epoch:
         #     Epoch(val) [10][1000/1000] ...
         #     Epoch(test) [1000/1000] ...
@@ -259,8 +268,42 @@ def get_log_after_epoch(self,
 
         else:
             log_str = (f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')
-        # `time` and `data_time` will not be recorded in after epoch log
-        # message.
+
+        custom_cfg_copy = copy.deepcopy(self.custom_cfg)
+        # remove prefix
+        custom_keys = [
+            self._remove_prefix(cfg['data_src'], f'{mode}/')
+            for cfg in custom_cfg_copy
+        ]
+        # Count the averaged time and data_time by epoch
+        if 'time' not in custom_keys:
+            custom_cfg_copy.append(
+                dict(
+                    data_src=f'{mode}/time',
+                    window_size='epoch',
+                    method_name='mean'))
+        if 'data_time' not in custom_keys:
+            custom_cfg_copy.append(
+                dict(
+                    data_src=f'{mode}/data_time',
+                    window_size='epoch',
+                    method_name='mean'))
+        parsed_cfg = self._parse_windows_size(runner, batch_idx,
+                                              custom_cfg_copy)
+        # tag is used to write log information to different backends.
+        ori_tag = self._collect_scalars(parsed_cfg, runner, mode,
+                                        self.log_with_hierarchy)
+        non_scalar_tag = self._collect_non_scalars(runner, mode)
+        # move `time` or `data_time` to the end of the log
+        tag = OrderedDict()
+        time_tag = OrderedDict()
+        for key, value in ori_tag.items():
+            if key in (f'{mode}/time', f'{mode}/data_time', 'time',
+                       'data_time'):
+                time_tag[key] = value
+            else:
+                tag[key] = value
+        # Log other messages.
         log_items = []
         for name, val in chain(tag.items(), non_scalar_tag.items()):
             if isinstance(val, float):
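The two `append` calls above act as if the user had written the entries by hand; for `mode='val'` the effective additions are equivalent to this sketch. The `OrderedDict` shuffle that follows only changes display order, moving `time`/`data_time` behind the metrics.

```python
# Equivalent custom_cfg entries appended when the user configured
# neither time nor data_time: average over the entire epoch instead
# of a sliding window.
extra_cfg = [
    dict(data_src='val/time', window_size='epoch', method_name='mean'),
    dict(data_src='val/data_time', window_size='epoch',
         method_name='mean'),
]
```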
@@ -271,12 +314,19 @@ def get_log_after_epoch(self,
                 log_items.append(f'{name}: {val}')
         log_str += ' '.join(log_items)
 
+        for name, val in time_tag.items():
+            log_str += f'{name}: {val:.{self.num_digits}f} '
+
         if with_non_scalar:
             tag.update(non_scalar_tag)
+        tag.update(time_tag)
         return tag, log_str
 
-    def _collect_scalars(self, custom_cfg: List[dict], runner,
-                         mode: str) -> dict:
+    def _collect_scalars(self,
+                         custom_cfg: List[dict],
+                         runner,
+                         mode: str,
+                         reserve_prefix: bool = False) -> dict:
         """Collect log information to compose a dict according to mode.
 
         Args:
@@ -285,6 +335,7 @@ def _collect_scalars(self, custom_cfg: List[dict], runner,
             runner (Runner): The runner of the training/testing/validation
                 process.
             mode (str): Current mode of runner.
+            reserve_prefix (bool): Whether to reserve the prefix of the key.
 
         Returns:
             dict: Statistical values of logs.
@@ -298,7 +349,10 @@ def _collect_scalars(self, custom_cfg: List[dict], runner,
         # according to mode.
         for prefix_key, log_buffer in history_scalars.items():
             if prefix_key.startswith(mode):
-                key = prefix_key.partition('/')[-1]
+                if not reserve_prefix:
+                    key = self._remove_prefix(prefix_key, f'{mode}/')
+                else:
+                    key = prefix_key
                 mode_history_scalars[key] = log_buffer
         for key in mode_history_scalars:
             # Update the latest learning rate and smoothed time logs.
@@ -339,10 +393,20 @@ def _collect_non_scalars(self, runner, mode: str) -> dict:
         # extract log info and remove prefix to `mode_infos` according to mode.
         for prefix_key, value in infos.items():
             if prefix_key.startswith(mode):
-                key = prefix_key.partition('/')[-1]
+                if self.log_with_hierarchy:
+                    key = prefix_key
+                else:
+                    key = self._remove_prefix(prefix_key, f'{mode}/')
                 mode_infos[key] = value
         return mode_infos
 
+    def _remove_prefix(self, string: str, prefix: str):
+        """Remove the prefix ``train``, ``val`` and ``test`` of the key."""
+        if string.startswith(prefix):
+            return string[len(prefix):]
+        else:
+            return string
+
     def _check_custom_cfg(self) -> None:
         """Check the legality of ``self.custom_cfg``."""
 
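`_remove_prefix` does the same job as `str.removeprefix`, which only exists from Python 3.9 on, presumably why a helper is kept here. Expected behavior (sketch; `processor` stands for any `LogProcessor` instance):

```python
# Strips the prefix when present, otherwise returns the key unchanged:
assert processor._remove_prefix('val/accuracy', 'val/') == 'accuracy'
assert processor._remove_prefix('accuracy', 'val/') == 'accuracy'
```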
@@ -375,16 +439,24 @@ def _check_repeated_log_name():
         _check_repeated_log_name()
         _check_window_size()
 
-    def _parse_windows_size(self, runner, batch_idx: int) -> list:
+    def _parse_windows_size(self,
+                            runner,
+                            batch_idx: int,
+                            custom_cfg: Optional[list] = None) -> list:
         """Parse window_size defined in custom_cfg to int value.
 
         Args:
             runner (Runner): The runner of the training/testing/validation
                 process.
             batch_idx (int): The iteration index of current dataloader.
+            custom_cfg (list): A copy of ``self.custom_cfg``. Defaults to None
+                to keep backward compatibility.
         """
-        custom_cfg_copy = copy.deepcopy(self.custom_cfg)
-        for log_cfg in custom_cfg_copy:
+        if custom_cfg is None:
+            custom_cfg = copy.deepcopy(self.custom_cfg)
+        else:
+            custom_cfg = copy.deepcopy(custom_cfg)
+        for log_cfg in custom_cfg:
             window_size = log_cfg.get('window_size', None)
             if window_size is None or isinstance(window_size, int):
                 continue
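With the new optional argument, callers can parse a locally extended list instead of `self.custom_cfg`. A sketch of the conversion, assuming the existing `'epoch'`/`'global'` resolution rules are unchanged:

```python
# 'epoch' resolves to the iteration count of the current epoch
# (batch_idx + 1), 'global' to runner.iter + 1; ints pass through.
custom_cfg = [dict(data_src='loss', window_size='epoch')]
parsed = processor._parse_windows_size(runner, batch_idx, custom_cfg)
# parsed == [dict(data_src='loss', window_size=batch_idx + 1)]
```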
@@ -396,7 +468,7 @@ def _parse_windows_size(self, runner, batch_idx: int) -> list:
                 raise TypeError(
                     'window_size should be int, epoch or global, but got '
                     f'invalid {window_size}')
-        return custom_cfg_copy
+        return custom_cfg
 
     def _get_max_memory(self, runner) -> int:
         """Returns the maximum GPU memory occupied by tensors in megabytes (MB)
@@ -472,3 +544,15 @@ def _get_cur_loop(self, runner, mode: str):
             return runner.val_loop
         else:
             return runner.test_loop
+
+    def _get_dataloader_size(self, runner, mode) -> int:
+        """Get dataloader size of current loop.
+
+        Args:
+            runner (Runner): The runner of the training/validation/testing
+            mode (str): Current mode of runner.
+
+        Returns:
+            int: The dataloader size of current loop.
+        """
+        return len(self._get_cur_loop(runner=runner, mode=mode).dataloader)
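Putting the pieces together, a hypothetical config fragment that exercises the feature end to end (model and dataset settings omitted):

```python
# Scalars land in TensorBoard as 'train/loss', 'val/accuracy', etc.,
# because log_with_hierarchy=True keeps the mode prefix.
log_processor = dict(type='LogProcessor', log_with_hierarchy=True)
visualizer = dict(
    type='Visualizer',
    vis_backends=[dict(type='TensorboardVisBackend')])
```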
7 changes: 6 additions & 1 deletion tests/test_hooks/test_logger_hook.py
@@ -146,7 +146,12 @@ def test_after_val_epoch(self):
         logger_hook.after_val_epoch(runner)
         args = {'step': ANY, 'file_path': ANY}
         # expect visualizer log `time` and `metric` respectively
-        runner.visualizer.add_scalars.assert_called_with({'acc': 0.8}, **args)
+        runner.visualizer.add_scalars.assert_called_with(
+            {
+                'time': 1,
+                'datatime': 1,
+                'acc': 0.8
+            }, **args)
 
         # Test when `log_metric_by_epoch` is False
         logger_hook = LoggerHook(log_metric_by_epoch=False)
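The expected dict now includes the epoch-averaged time scalars; the `datatime` key (spelled without an underscore) presumably mirrors the mocked metric dict used by this test. To exercise the updated case locally, something along these lines should work:

```python
import pytest

# Run only the touched test case:
pytest.main([
    'tests/test_hooks/test_logger_hook.py',
    '-k', 'test_after_val_epoch',
])
```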