Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] MessageHub.get_info() supports returning a default value #991

Merged
merged 22 commits into from Apr 23, 2023
18 changes: 10 additions & 8 deletions mmengine/logging/message_hub.py
Expand Up @@ -296,22 +296,24 @@ def get_scalar(self, key: str) -> HistoryBuffer:
f'instance name is: {MessageHub.instance_name}')
return self.log_scalars[key]

def get_info(self, key: str) -> Any:
"""Get runtime information by key.
def get_info(self, key: str, default: Optional[Any] = None) -> Any:
"""Get runtime information by key. if the key does not exist, this
method will return default information.

Args:
key (str): Key of runtime information.
default (Any, optional): The default returned value for the
given key.

Returns:
Any: A copy of corresponding runtime information if the key exists.
"""
if key not in self.runtime_info:
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
f'instance name is: {MessageHub.instance_name}')

# TODO: There are restrictions on objects that can be saved
# return copy.deepcopy(self._runtime_info[key])
return self._runtime_info[key]
return default
else:
# TODO: There are restrictions on objects that can be saved
# return copy.deepcopy(self._runtime_info[key])
return self._runtime_info[key]

def _get_valid_value(
self,
Expand Down
3 changes: 0 additions & 3 deletions tests/test_hooks/test_runtime_info_hook.py
Expand Up @@ -46,9 +46,6 @@ def test_before_train(self):
self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
self.assertEqual(runner.message_hub.get_info('max_iters'), 8)

with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
runner.message_hub.get_info('dataset_meta')

cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
runner = self.build_runner(cfg)
hook.before_train(runner)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_logging/test_message_hub.py
Expand Up @@ -82,8 +82,8 @@ def test_get_scalar(self):

def test_get_runtime(self):
message_hub = MessageHub.get_instance('mmengine')
with pytest.raises(KeyError):
message_hub.get_info('unknown')
enkilee marked this conversation as resolved.
Show resolved Hide resolved
tgrresult = message_hub.get_info('unknown')
assert tgrresult is None
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
recorded_dict = dict(a=1, b=2)
message_hub.update_info('test_value', recorded_dict)
assert message_hub.get_info('test_value') == recorded_dict
Expand Down Expand Up @@ -186,10 +186,10 @@ def test_getstate(self):
obj = pickle.dumps(message_hub)
instance = pickle.loads(obj)

with pytest.raises(KeyError):
instance.get_info('feat')
with pytest.raises(KeyError):
instance.get_info('lr')
enkilee marked this conversation as resolved.
Show resolved Hide resolved
featresult = instance.get_info('feat')
assert featresult is None
lrresult = instance.get_info('lr')
assert lrresult is None
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved

instance.get_info('iter')
instance.get_scalar('loss')
Expand Down