Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] MessageHub.get_info() supports returning a default value #991

Merged
merged 22 commits into from Apr 23, 2023
7 changes: 1 addition & 6 deletions mmengine/logging/message_hub.py
Expand Up @@ -309,12 +309,7 @@ def get_info(self, key: str, default: Optional[Any] = None) -> Any:
Any: A copy of corresponding runtime information if the key exists.
"""
if key not in self.runtime_info:
if default is not None:
return default
else:
raise KeyError(
f'{key} is not found in Messagehub.log_buffers: '
f'instance name is: {MessageHub.instance_name}')
return default
else:
# TODO: There are restrictions on objects that can be saved
# return copy.deepcopy(self._runtime_info[key])
Expand Down
3 changes: 0 additions & 3 deletions tests/test_hooks/test_runtime_info_hook.py
Expand Up @@ -46,9 +46,6 @@ def test_before_train(self):
self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
self.assertEqual(runner.message_hub.get_info('max_iters'), 8)

with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
runner.message_hub.get_info('dataset_meta')

cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
runner = self.build_runner(cfg)
hook.before_train(runner)
Expand Down
7 changes: 0 additions & 7 deletions tests/test_logging/test_message_hub.py
Expand Up @@ -82,8 +82,6 @@ def test_get_scalar(self):

def test_get_runtime(self):
message_hub = MessageHub.get_instance('mmengine')
with pytest.raises(KeyError):
message_hub.get_info('unknown')
enkilee marked this conversation as resolved.
Show resolved Hide resolved
recorded_dict = dict(a=1, b=2)
message_hub.update_info('test_value', recorded_dict)
assert message_hub.get_info('test_value') == recorded_dict
Expand Down Expand Up @@ -186,11 +184,6 @@ def test_getstate(self):
obj = pickle.dumps(message_hub)
instance = pickle.loads(obj)

with pytest.raises(KeyError):
instance.get_info('feat')
with pytest.raises(KeyError):
instance.get_info('lr')
enkilee marked this conversation as resolved.
Show resolved Hide resolved

instance.get_info('iter')
instance.get_scalar('loss')

Expand Down