Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Support writing data to vis_backend with prefix #972

Merged
merged 18 commits into from Mar 13, 2023
8 changes: 1 addition & 7 deletions mmengine/hooks/logger_hook.py
Expand Up @@ -231,14 +231,8 @@ def after_val_epoch(self,
runner, len(runner.val_dataloader), 'val')
runner.logger.info(log_str)
if self.log_metric_by_epoch:
# when `log_metric_by_epoch` is set to True, it's expected
# that validation metric can be logged by epoch rather than
# by iter. At the same time, scalars related to time should
# still be logged by iter to avoid messy visualized result.
# see details in PR #278.
metric_tags = {k: v for k, v in tag.items() if 'time' not in k}
runner.visualizer.add_scalars(
metric_tags, step=runner.epoch, file_path=self.json_log_path)
tag, step=runner.epoch, file_path=self.json_log_path)
else:
runner.visualizer.add_scalars(
tag, step=runner.iter, file_path=self.json_log_path)
Expand Down
121 changes: 100 additions & 21 deletions mmengine/runner/log_processor.py
Expand Up @@ -52,6 +52,13 @@ class LogProcessor:
`epoch` to statistics log value by epoch.
num_digits (int): The number of significant digit shown in the
logging message.
log_with_hierarchy (bool): Whether to log with hierarchy. If it is
True, the information is written to visualizer backend such as
:obj:`LocalVisBackend` and :obj:`TensorboardBackend`
with hierarchy. For example, ``loss`` will be saved as
``train/loss``, and accuracy will be saved as ``val/accuracy``.
Defaults to False.
`New in version 0.7.0.`

Examples:
>>> # `log_name` is defined, `loss_large_window` will be an additional
Expand Down Expand Up @@ -98,11 +105,13 @@ def __init__(self,
window_size=10,
by_epoch=True,
custom_cfg: Optional[List[dict]] = None,
num_digits: int = 4):
num_digits: int = 4,
log_with_hierarchy: bool = False):
self.window_size = window_size
self.by_epoch = by_epoch
self.custom_cfg = custom_cfg if custom_cfg else []
self.num_digits = num_digits
self.log_with_hierarchy = log_with_hierarchy
self._check_custom_cfg()

def get_log_after_iter(self, runner, batch_idx: int,
Expand All @@ -123,15 +132,24 @@ def get_log_after_iter(self, runner, batch_idx: int,
current_loop = self._get_cur_loop(runner, mode)
cur_iter = self._get_iter(runner, batch_idx=batch_idx)
# Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
# tag is used to write log information to different backends.
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
# `log_tag` will pop 'lr' and loop other keys to `log_str`.
log_tag = copy.deepcopy(tag)
parsed_cfg = self._parse_windows_size(runner, batch_idx,
self.custom_cfg)
# log_tag is used to write log information to terminal
# If `self.log_with_hierarchy` is False, the tag is the same as
# log_tag. Otherwise, each key in tag starts with prefix `train`,
# `test` or `val`
log_tag = self._collect_scalars(parsed_cfg, runner, mode)

if not self.log_with_hierarchy:
tag = copy.deepcopy(log_tag)
else:
tag = self._collect_scalars(parsed_cfg, runner, mode, False)

# Record learning rate.
lr_str_list = []
for key, value in tag.items():
if key.endswith('lr'):
key = self._remove_prefix(key, f'{mode}/')
log_tag.pop(key)
lr_str_list.append(f'{key}: '
f'{value:.{self.num_digits}e}')
Expand Down Expand Up @@ -183,14 +201,14 @@ def get_log_after_iter(self, runner, batch_idx: int,
log_str += f'{lr_str} '
# If IterTimerHook used in runner, eta, time, and data_time should be
# recorded.
if (all(item in tag for item in ['time', 'data_time'])
if (all(item in log_tag for item in ['time', 'data_time'])
and 'eta' in runner.message_hub.runtime_info):
eta = runner.message_hub.get_info('eta')
eta_str = str(datetime.timedelta(seconds=int(eta)))
log_str += f'eta: {eta_str} '
log_str += (f'time: {tag["time"]:.{self.num_digits}f} '
log_str += (f'time: {log_tag["time"]:.{self.num_digits}f} '
f'data_time: '
f'{tag["data_time"]:.{self.num_digits}f} ')
f'{log_tag["data_time"]:.{self.num_digits}f} ')
# Pop recorded keys
log_tag.pop('time')
log_tag.pop('data_time')
Expand Down Expand Up @@ -236,12 +254,42 @@ def get_log_after_epoch(self,
cur_loop = self._get_cur_loop(runner, mode)
dataloader_len = len(cur_loop.dataloader)

custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
custom_cfg_copy = copy.deepcopy(self.custom_cfg)
# remove prefix
custom_keys = [
self._remove_prefix(cfg['data_src'], f'{mode}/')
for cfg in custom_cfg_copy
]

# Count the averaged time and data_time by epoch
RangiLyu marked this conversation as resolved.
Show resolved Hide resolved
if 'time' not in custom_keys:
custom_cfg_copy.append(
dict(
data_src=f'{mode}/time',
window_size='epoch',
method_name='mean'))
if 'data_time' not in custom_keys:
custom_cfg_copy.append(
dict(
data_src=f'{mode}/data_time',
window_size='epoch',
method_name='mean'))
parsed_cfg = self._parse_windows_size(runner, batch_idx,
custom_cfg_copy)
# tag is used to write log information to different backends.
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
ori_tag = self._collect_scalars(parsed_cfg, runner, mode,
not self.log_with_hierarchy)
non_scalar_tag = self._collect_non_scalars(runner, mode)
tag.pop('time', None)
tag.pop('data_time', None)
# move `time` or `data_time` to the end of the log
tag = OrderedDict()
time_tag = OrderedDict()
# pop `time` and `data_time` to log them at last.
for key, value in ori_tag.items():
if 'time' not in key:
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
tag[key] = value
else:
time_tag[key] = value

# By epoch:
# Epoch(val) [10][1000/1000] ...
# Epoch(test) [1000/1000] ...
Expand All @@ -263,6 +311,8 @@ def get_log_after_epoch(self,
# message.
log_items = []
for name, val in chain(tag.items(), non_scalar_tag.items()):
if name in ('time', 'data_time'):
name = self._remove_prefix(name, f'{mode}/')
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(val, float):
val = f'{val:.{self.num_digits}f}'
if isinstance(val, (torch.Tensor, np.ndarray)):
Expand All @@ -271,12 +321,19 @@ def get_log_after_epoch(self,
log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items)

for name, val in time_tag.items():
log_str += f'{name}: {val:.{self.num_digits}f} '

if with_non_scalar:
tag.update(non_scalar_tag)
tag.update(time_tag)
return tag, log_str

def _collect_scalars(self, custom_cfg: List[dict], runner,
mode: str) -> dict:
def _collect_scalars(self,
custom_cfg: List[dict],
runner,
mode: str,
remove_prefix: bool = True) -> dict:
"""Collect log information to compose a dict according to mode.

Args:
Expand All @@ -285,6 +342,7 @@ def _collect_scalars(self, custom_cfg: List[dict], runner,
runner (Runner): The runner of the training/testing/validation
process.
mode (str): Current mode of runner.
remove_prefix (bool): Whether to remove the prefix of the key.

Returns:
dict: Statistical values of logs.
Expand All @@ -298,7 +356,10 @@ def _collect_scalars(self, custom_cfg: List[dict], runner,
# according to mode.
for prefix_key, log_buffer in history_scalars.items():
if prefix_key.startswith(mode):
key = prefix_key.partition('/')[-1]
if remove_prefix:
key = self._remove_prefix(prefix_key, f'{mode}/')
else:
key = prefix_key
mode_history_scalars[key] = log_buffer
for key in mode_history_scalars:
# Update the latest learning rate and smoothed time logs.
Expand Down Expand Up @@ -339,10 +400,20 @@ def _collect_non_scalars(self, runner, mode: str) -> dict:
# extract log info and remove prefix to `mode_infos` according to mode.
for prefix_key, value in infos.items():
if prefix_key.startswith(mode):
key = prefix_key.partition('/')[-1]
if self.log_with_hierarchy:
key = prefix_key
else:
key = self._remove_prefix(prefix_key, f'{mode}/')
mode_infos[key] = value
return mode_infos

def _remove_prefix(self, string: str, prefix: str):
"""Remove the prefix ``train``, ``val`` and ``test`` of the key."""
if string.startswith(prefix):
return string[len(prefix):]
else:
return string

def _check_custom_cfg(self) -> None:
"""Check the legality of ``self.custom_cfg``."""

Expand Down Expand Up @@ -375,16 +446,24 @@ def _check_repeated_log_name():
_check_repeated_log_name()
_check_window_size()

def _parse_windows_size(self, runner, batch_idx: int) -> list:
def _parse_windows_size(self,
runner,
batch_idx: int,
custom_cfg: Optional[list] = None) -> list:
"""Parse window_size defined in custom_cfg to int value.

Args:
runner (Runner): The runner of the training/testing/validation
process.
batch_idx (int): The iteration index of current dataloader.
custom_cfg (list): A copy of ``self.custom_cfg``. Defaults to None
to keep backward compatibility.
"""
custom_cfg_copy = copy.deepcopy(self.custom_cfg)
for log_cfg in custom_cfg_copy:
if custom_cfg is None:
custom_cfg = copy.deepcopy(self.custom_cfg)
else:
custom_cfg = copy.deepcopy(custom_cfg)
RangiLyu marked this conversation as resolved.
Show resolved Hide resolved
for log_cfg in custom_cfg:
window_size = log_cfg.get('window_size', None)
if window_size is None or isinstance(window_size, int):
continue
Expand All @@ -396,7 +475,7 @@ def _parse_windows_size(self, runner, batch_idx: int) -> list:
raise TypeError(
'window_size should be int, epoch or global, but got '
f'invalid {window_size}')
return custom_cfg_copy
return custom_cfg

def _get_max_memory(self, runner) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_hooks/test_logger_hook.py
Expand Up @@ -146,7 +146,12 @@ def test_after_val_epoch(self):
logger_hook.after_val_epoch(runner)
args = {'step': ANY, 'file_path': ANY}
# expect visualizer log `time` and `metric` respectively
runner.visualizer.add_scalars.assert_called_with({'acc': 0.8}, **args)
runner.visualizer.add_scalars.assert_called_with(
{
'time': 1,
'datatime': 1,
'acc': 0.8
}, **args)

# Test when `log_metric_by_epoch` is False
logger_hook = LoggerHook(log_metric_by_epoch=False)
Expand Down
30 changes: 15 additions & 15 deletions tests/test_runner/test_log_processor.py
Expand Up @@ -47,28 +47,27 @@ def test_check_custom_cfg(self):
def test_parse_windows_size(self):
log_processor = LogProcessor()
# Test parse 'epoch' window_size.
log_processor.custom_cfg = [
dict(data_src='loss_cls', window_size='epoch')
]
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
custom_cfg = [dict(data_src='loss_cls', window_size='epoch')]
custom_cfg = log_processor._parse_windows_size(self.runner, 1,
custom_cfg)
assert custom_cfg[0]['window_size'] == 2

# Test parse 'global' window_size.
log_processor.custom_cfg = [
dict(data_src='loss_cls', window_size='global')
]
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
custom_cfg = [dict(data_src='loss_cls', window_size='global')]
custom_cfg = log_processor._parse_windows_size(self.runner, 1,
custom_cfg)
assert custom_cfg[0]['window_size'] == 11

# Test parse int window_size
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=100)]
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
custom_cfg = [dict(data_src='loss_cls', window_size=100)]
custom_cfg = log_processor._parse_windows_size(self.runner, 1,
custom_cfg)
assert custom_cfg[0]['window_size'] == 100

# Invalid type window_size will raise TypeError.
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=[])]
custom_cfg = [dict(data_src='loss_cls', window_size=[])]
with pytest.raises(TypeError):
log_processor._parse_windows_size(custom_cfg, self.runner)
log_processor._parse_windows_size(self.runner, 1, custom_cfg)

@pytest.mark.parametrize('by_epoch,mode',
([True, 'train'], [False, 'train'], [True, 'val'],
Expand All @@ -84,8 +83,9 @@ def test_get_log_after_iter(self, by_epoch, mode):
train_logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0)
else:
train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
log_processor._collect_scalars = MagicMock(return_value=train_logs)
tag, out = log_processor.get_log_after_iter(self.runner, 1, mode)
log_processor._collect_scalars = \
lambda *args, **kwargs: copy.deepcopy(train_logs)
_, out = log_processor.get_log_after_iter(self.runner, 1, mode)
# Verify that the correct context have been logged.
cur_loop = log_processor._get_cur_loop(self.runner, mode)
if by_epoch:
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_log_val(self, by_epoch, mode):
return_value=non_scalar_logs)
_, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
expect_metric_str = ("accuracy: 0.9000 recall: {'cat': 1, 'dog': 0} "
'cm: \ntensor([1, 2, 3])\n')
'cm: \ntensor([1, 2, 3])\ndata_time: 1.0000 ')
if by_epoch:
if mode == 'test':
assert out == 'Epoch(test) [5/5] ' + expect_metric_str
Expand Down