[Enhancement] Support writing data to vis_backend with prefix (#972)
* Log with prefix

* Fix test of loggerhook

* minor refine

* minor refine

* Fix unit test

* clean the code

* deepcopy in method

* replace regex

* Fix as comment

* Enhance readability

* rename reserve_prefix to remove_prefix

* Fix as comment

* Refine unit test

* Adjust sequence

* clean the code

* clean the code

* revert renaming reserve prefix

* Count the dataloader length in _get_dataloader_size
HAOCHENYE committed Mar 13, 2023
1 parent 0d25625 commit 8063d2c
Showing 4 changed files with 153 additions and 64 deletions.
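For orientation, here is a minimal sketch of how the option introduced by this commit might be enabled. Only log_with_hierarchy and its documented behaviour come from the diff below; the surrounding construction is an assumed, typical MMEngine setup, not part of the change.

from mmengine.runner import LogProcessor

# With log_with_hierarchy=True, scalars written to visualizer backends
# (e.g. LocalVisBackend or the TensorBoard backend) keep their mode
# prefix, so `loss` is stored as `train/loss` and accuracy as
# `val/accuracy`, while the terminal log keeps the short, unprefixed keys.
log_processor = LogProcessor(
    window_size=10,
    by_epoch=True,
    num_digits=4,
    log_with_hierarchy=True)  # new in v0.7.0, defaults to False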
8 changes: 1 addition & 7 deletions mmengine/hooks/logger_hook.py
@@ -234,14 +234,8 @@ def after_val_epoch(self,
runner, len(runner.val_dataloader), 'val')
runner.logger.info(log_str)
if self.log_metric_by_epoch:
-            # when `log_metric_by_epoch` is set to True, it's expected
-            # that validation metric can be logged by epoch rather than
-            # by iter. At the same time, scalars related to time should
-            # still be logged by iter to avoid messy visualized result.
-            # see details in PR #278.
-            metric_tags = {k: v for k, v in tag.items() if 'time' not in k}
runner.visualizer.add_scalars(
-                metric_tags, step=runner.epoch, file_path=self.json_log_path)
tag, step=runner.epoch, file_path=self.json_log_path)
else:
runner.visualizer.add_scalars(
tag, step=runner.iter, file_path=self.json_log_path)
148 changes: 116 additions & 32 deletions mmengine/runner/log_processor.py
@@ -52,6 +52,13 @@ class LogProcessor:
`epoch` to statistics log value by epoch.
num_digits (int): The number of significant digit shown in the
logging message.
log_with_hierarchy (bool): Whether to log with hierarchy. If it is
True, the information is written to visualizer backend such as
:obj:`LocalVisBackend` and :obj:`TensorboardBackend`
with hierarchy. For example, ``loss`` will be saved as
``train/loss``, and accuracy will be saved as ``val/accuracy``.
Defaults to False.
`New in version 0.7.0.`
Examples:
>>> # `log_name` is defined, `loss_large_window` will be an additional
@@ -98,11 +105,13 @@ def __init__(self,
window_size=10,
by_epoch=True,
custom_cfg: Optional[List[dict]] = None,
-                 num_digits: int = 4):
num_digits: int = 4,
log_with_hierarchy: bool = False):
self.window_size = window_size
self.by_epoch = by_epoch
self.custom_cfg = custom_cfg if custom_cfg else []
self.num_digits = num_digits
self.log_with_hierarchy = log_with_hierarchy
self._check_custom_cfg()

def get_log_after_iter(self, runner, batch_idx: int,
@@ -120,18 +129,26 @@ def get_log_after_iter(self, runner, batch_idx: int,
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
"""
assert mode in ['train', 'test', 'val']
-        current_loop = self._get_cur_loop(runner, mode)
cur_iter = self._get_iter(runner, batch_idx=batch_idx)
# Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
-        custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
-        # tag is used to write log information to different backends.
-        tag = self._collect_scalars(custom_cfg_copy, runner, mode)
-        # `log_tag` will pop 'lr' and loop other keys to `log_str`.
-        log_tag = copy.deepcopy(tag)
parsed_cfg = self._parse_windows_size(runner, batch_idx,
self.custom_cfg)
# log_tag is used to write log information to terminal
# If `self.log_with_hierarchy` is False, the tag is the same as
# log_tag. Otherwise, each key in tag starts with prefix `train`,
# `test` or `val`
log_tag = self._collect_scalars(parsed_cfg, runner, mode)

if not self.log_with_hierarchy:
tag = copy.deepcopy(log_tag)
else:
tag = self._collect_scalars(parsed_cfg, runner, mode, True)

# Record learning rate.
lr_str_list = []
for key, value in tag.items():
if key.endswith('lr'):
key = self._remove_prefix(key, f'{mode}/')
log_tag.pop(key)
lr_str_list.append(f'{key}: '
f'{value:.{self.num_digits}e}')
@@ -148,7 +165,7 @@ def get_log_after_iter(self, runner, batch_idx: int,
# Epoch(train) [ 9][010/270]
# ... ||| |||
# Epoch(train) [ 10][100/270]
-            dataloader_len = len(current_loop.dataloader)
dataloader_len = self._get_dataloader_size(runner, mode)
cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))

if mode in ['train', 'val']:
@@ -174,23 +191,22 @@ def get_log_after_iter(self, runner, batch_idx: int,
log_str = (f'Iter({mode}) '
f'[{cur_iter_str}/{runner.max_iters}] ')
else:
-                dataloader_len = len(current_loop.dataloader)
dataloader_len = self._get_dataloader_size(runner, mode)
cur_iter_str = str(batch_idx + 1).rjust(
len(str(dataloader_len)))
-                log_str = (f'Iter({mode}) [{cur_iter_str}'
-                           f'/{len(current_loop.dataloader)}] ')
log_str = (f'Iter({mode}) [{cur_iter_str}/{dataloader_len}] ')
# Concatenate lr, momentum string with log header.
log_str += f'{lr_str} '
# If IterTimerHook used in runner, eta, time, and data_time should be
# recorded.
-        if (all(item in tag for item in ['time', 'data_time'])
if (all(item in log_tag for item in ['time', 'data_time'])
and 'eta' in runner.message_hub.runtime_info):
eta = runner.message_hub.get_info('eta')
eta_str = str(datetime.timedelta(seconds=int(eta)))
log_str += f'eta: {eta_str} '
-            log_str += (f'time: {tag["time"]:.{self.num_digits}f} '
log_str += (f'time: {log_tag["time"]:.{self.num_digits}f} '
f'data_time: '
-                        f'{tag["data_time"]:.{self.num_digits}f} ')
f'{log_tag["data_time"]:.{self.num_digits}f} ')
# Pop recorded keys
log_tag.pop('time')
log_tag.pop('data_time')
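# Illustrative sketch (hypothetical values, not part of this commit):
# the comments above describe how log_tag (terminal output) and tag
# (visualizer backends) relate for mode='train'.
#
#     log_tag = {'loss': 0.35, 'lr': 0.01, 'time': 0.21, 'data_time': 0.02}
#
#     # log_with_hierarchy=False: the visualizer gets the same flat keys.
#     tag = dict(log_tag)
#
#     # log_with_hierarchy=True: every key keeps its mode prefix.
#     tag = {'train/loss': 0.35, 'train/lr': 0.01,
#            'train/time': 0.21, 'train/data_time': 0.02}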
@@ -235,15 +251,8 @@ def get_log_after_epoch(self,
'test', 'val'
], ('`_get_metric_log_str` only accept val or test mode, but got '
f'{mode}')
-        cur_loop = self._get_cur_loop(runner, mode)
-        dataloader_len = len(cur_loop.dataloader)
dataloader_len = self._get_dataloader_size(runner, mode)

-        custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
-        # tag is used to write log information to different backends.
-        tag = self._collect_scalars(custom_cfg_copy, runner, mode)
-        non_scalar_tag = self._collect_non_scalars(runner, mode)
-        tag.pop('time', None)
-        tag.pop('data_time', None)
# By epoch:
# Epoch(val) [10][1000/1000] ...
# Epoch(test) [1000/1000] ...
@@ -261,8 +270,42 @@ def get_log_after_epoch(self,

else:
log_str = (f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')
-        # `time` and `data_time` will not be recorded in after epoch log
-        # message.

custom_cfg_copy = copy.deepcopy(self.custom_cfg)
# remove prefix
custom_keys = [
self._remove_prefix(cfg['data_src'], f'{mode}/')
for cfg in custom_cfg_copy
]
# Count the averaged time and data_time by epoch
if 'time' not in custom_keys:
custom_cfg_copy.append(
dict(
data_src=f'{mode}/time',
window_size='epoch',
method_name='mean'))
if 'data_time' not in custom_keys:
custom_cfg_copy.append(
dict(
data_src=f'{mode}/data_time',
window_size='epoch',
method_name='mean'))
parsed_cfg = self._parse_windows_size(runner, batch_idx,
custom_cfg_copy)
# tag is used to write log information to different backends.
ori_tag = self._collect_scalars(parsed_cfg, runner, mode,
self.log_with_hierarchy)
non_scalar_tag = self._collect_non_scalars(runner, mode)
# move `time` or `data_time` to the end of the log
tag = OrderedDict()
time_tag = OrderedDict()
for key, value in ori_tag.items():
if key in (f'{mode}/time', f'{mode}/data_time', 'time',
'data_time'):
time_tag[key] = value
else:
tag[key] = value
# Log other messages.
log_items = []
for name, val in chain(tag.items(), non_scalar_tag.items()):
if isinstance(val, float):
@@ -273,12 +316,19 @@ def get_log_after_epoch(self,
log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items)

for name, val in time_tag.items():
log_str += f'{name}: {val:.{self.num_digits}f} '

if with_non_scalar:
tag.update(non_scalar_tag)
tag.update(time_tag)
return tag, log_str

-    def _collect_scalars(self, custom_cfg: List[dict], runner,
-                         mode: str) -> dict:
def _collect_scalars(self,
custom_cfg: List[dict],
runner,
mode: str,
reserve_prefix: bool = False) -> dict:
"""Collect log information to compose a dict according to mode.
Args:
@@ -287,6 +337,7 @@ def _collect_scalars(self, custom_cfg: List[dict], runner,
runner (Runner): The runner of the training/testing/validation
process.
mode (str): Current mode of runner.
reserve_prefix (bool): Whether to reserve the prefix of the key.
Returns:
dict: Statistical values of logs.
@@ -300,7 +351,10 @@ def _collect_scalars(self, custom_cfg: List[dict], runner,
# according to mode.
for prefix_key, log_buffer in history_scalars.items():
if prefix_key.startswith(mode):
-                key = prefix_key.partition('/')[-1]
if not reserve_prefix:
key = self._remove_prefix(prefix_key, f'{mode}/')
else:
key = prefix_key
mode_history_scalars[key] = log_buffer
for key in mode_history_scalars:
# Update the latest learning rate and smoothed time logs.
@@ -341,10 +395,20 @@ def _collect_non_scalars(self, runner, mode: str) -> dict:
# extract log info and remove prefix to `mode_infos` according to mode.
for prefix_key, value in infos.items():
if prefix_key.startswith(mode):
-                key = prefix_key.partition('/')[-1]
if self.log_with_hierarchy:
key = prefix_key
else:
key = self._remove_prefix(prefix_key, f'{mode}/')
mode_infos[key] = value
return mode_infos

def _remove_prefix(self, string: str, prefix: str):
"""Remove the prefix ``train``, ``val`` and ``test`` of the key."""
if string.startswith(prefix):
return string[len(prefix):]
else:
return string
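    # Illustrative sketch (not part of this commit): the prefix is
    # stripped only when it is actually present, e.g.
    #     self._remove_prefix('train/loss', 'train/')  ->  'loss'
    #     self._remove_prefix('accuracy', 'val/')      ->  'accuracy'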

def _check_custom_cfg(self) -> None:
"""Check the legality of ``self.custom_cfg``."""

Expand Down Expand Up @@ -377,16 +441,24 @@ def _check_repeated_log_name():
_check_repeated_log_name()
_check_window_size()

-    def _parse_windows_size(self, runner, batch_idx: int) -> list:
def _parse_windows_size(self,
runner,
batch_idx: int,
custom_cfg: Optional[list] = None) -> list:
"""Parse window_size defined in custom_cfg to int value.
Args:
runner (Runner): The runner of the training/testing/validation
process.
batch_idx (int): The iteration index of current dataloader.
custom_cfg (list): A copy of ``self.custom_cfg``. Defaults to None
to keep backward compatibility.
"""
-        custom_cfg_copy = copy.deepcopy(self.custom_cfg)
-        for log_cfg in custom_cfg_copy:
if custom_cfg is None:
custom_cfg = copy.deepcopy(self.custom_cfg)
else:
custom_cfg = copy.deepcopy(custom_cfg)
for log_cfg in custom_cfg:
window_size = log_cfg.get('window_size', None)
if window_size is None or isinstance(window_size, int):
continue
@@ -398,7 +470,7 @@ def _parse_windows_size(self, runner, batch_idx: int) -> list:
raise TypeError(
'window_size should be int, epoch or global, but got '
f'invalid {window_size}')
-        return custom_cfg_copy
return custom_cfg

def _get_max_memory(self, runner) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
@@ -474,3 +546,15 @@ def _get_cur_loop(self, runner, mode: str):
return runner.val_loop
else:
return runner.test_loop

def _get_dataloader_size(self, runner, mode) -> int:
"""Get dataloader size of current loop.
Args:
runner (Runner): The runner of the training/validation/testing
mode (str): Current mode of runner.
Returns:
int: The dataloader size of current loop.
"""
return len(self._get_cur_loop(runner=runner, mode=mode).dataloader)
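To make the custom_cfg handling above concrete, here is a hedged sketch of the kind of entry that _parse_windows_size resolves and _collect_scalars consumes. The field names (data_src, log_name, window_size, method_name) are taken from the code above; the concrete values and the combined construction are illustrative assumptions.

custom_cfg = [
    # Additionally report `loss` averaged over the whole epoch; the string
    # window_size 'epoch' is converted to a concrete int by
    # _parse_windows_size before the statistics are collected.
    dict(
        data_src='loss',
        log_name='loss_epoch_mean',
        window_size='epoch',
        method_name='mean'),
]
processor = LogProcessor(custom_cfg=custom_cfg, log_with_hierarchy=True)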
7 changes: 6 additions & 1 deletion tests/test_hooks/test_logger_hook.py
@@ -147,7 +147,12 @@ def test_after_val_epoch(self):
logger_hook.after_val_epoch(runner)
args = {'step': ANY, 'file_path': ANY}
# expect visualizer log `time` and `metric` respectively
-        runner.visualizer.add_scalars.assert_called_with({'acc': 0.8}, **args)
runner.visualizer.add_scalars.assert_called_with(
{
'time': 1,
'datatime': 1,
'acc': 0.8
}, **args)

# Test when `log_metric_by_epoch` is False
logger_hook = LoggerHook(log_metric_by_epoch=False)
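Taken together with the updated assertion above, a rough sketch of the epoch-level write after this change; the values and the exact scalar keys are placeholders. Averaged time scalars now accompany the metric, and with log_with_hierarchy=True the same scalars would carry the val/ prefix in the visualizer backends.

# log_with_hierarchy=False (default): flat keys, as the test asserts.
runner.visualizer.add_scalars(
    {'time': 1.0, 'data_time': 0.1, 'acc': 0.8}, step=runner.epoch)

# log_with_hierarchy=True: the same scalars keep their mode prefix.
runner.visualizer.add_scalars(
    {'val/time': 1.0, 'val/data_time': 0.1, 'val/acc': 0.8},
    step=runner.epoch)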
