Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] MessageHub.get_info() supports returning a default value #991

Merged
merged 22 commits into from
Apr 23, 2023
23 changes: 15 additions & 8 deletions mmengine/logging/message_hub.py
Expand Up @@ -296,22 +296,29 @@ def get_scalar(self, key: str) -> HistoryBuffer:
f'instance name is: {MessageHub.instance_name}')
return self.log_scalars[key]

def get_info(self, key: str) -> Any:
"""Get runtime information by key.
def get_info(self, key: str, default: Optional[Any] = None) -> Any:
"""Get runtime information by key. if the key does not exist, this
method will return default information.

Args:
key (str): Key of runtime information.
default (Any, optional): The default returned value for the
given key.

Returns:
Any: A copy of corresponding runtime information if the key exists.
"""
if key not in self.runtime_info:
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
f'instance name is: {MessageHub.instance_name}')

# TODO: There are restrictions on objects that can be saved
# return copy.deepcopy(self._runtime_info[key])
return self._runtime_info[key]
if default is not None:
return default
else:
raise KeyError(
f'Can not find {key} in runtime information of message_hub'
)
else:
# TODO: There are restrictions on objects that can be saved
# return copy.deepcopy(self._runtime_info[key])
return self._runtime_info[key]

def _get_valid_value(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_logging/test_message_hub.py
Expand Up @@ -83,7 +83,7 @@ def test_get_scalar(self):
def test_get_runtime(self):
message_hub = MessageHub.get_instance('mmengine')
with pytest.raises(KeyError):
message_hub.get_info('unknown')
message_hub.get_info('unknown', ['test_value'])
enkilee marked this conversation as resolved.
Show resolved Hide resolved
recorded_dict = dict(a=1, b=2)
message_hub.update_info('test_value', recorded_dict)
assert message_hub.get_info('test_value') == recorded_dict
Expand Down