Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the loop stage in message_hub #1277

Merged
merged 3 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions mmengine/hooks/runtime_info_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,15 @@ def before_run(self, runner) -> None:
mmengine_version=__version__ + get_git_hash())
runner.message_hub.update_info_dict(metainfo)

self.last_loop_stage = None

def before_train(self, runner) -> None:
"""Update resumed training state.

Args:
runner (Runner): The runner of the training process.
"""
runner.message_hub.update_info('loop_stage', 'train')
runner.message_hub.update_info('epoch', runner.epoch)
runner.message_hub.update_info('iter', runner.iter)
runner.message_hub.update_info('max_epochs', runner.max_epochs)
Expand All @@ -68,6 +71,9 @@ def before_train(self, runner) -> None:
runner.message_hub.update_info(
'dataset_meta', runner.train_dataloader.dataset.metainfo)

def after_train(self, runner) -> None:
    """Drop the ``loop_stage`` entry from the runner's message hub.

    Args:
        runner (Runner): The runner of the training process.
    """
    hub = runner.message_hub
    hub.pop_info('loop_stage')

def before_train_epoch(self, runner) -> None:
"""Update current epoch information before every epoch.

Expand Down Expand Up @@ -119,6 +125,10 @@ def after_train_iter(self,
for key, value in outputs.items():
runner.message_hub.update_scalar(f'train/{key}', value)

def before_val(self, runner) -> None:
    """Record the current loop stage, then mark the stage as ``'val'``.

    The stage found in the message hub is kept in ``self.last_loop_stage``
    so that it can be inspected after validation.

    Args:
        runner (Runner): The runner whose message hub is updated.
    """
    hub = runner.message_hub
    previous_stage = hub.get_info('loop_stage')
    self.last_loop_stage = previous_stage
    hub.update_info('loop_stage', 'val')

def after_val_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
Expand All @@ -138,6 +148,22 @@ def after_val_epoch(self,
else:
runner.message_hub.update_info(f'val/{key}', value)

def after_val(self, runner) -> None:
    """Restore or clear the ``loop_stage`` entry after validation.

    ValLoop may be called within the TrainLoop, in which case the hooks
    fire as: before_train -> before_val -> after_val -> after_train.
    When the stage recorded in ``self.last_loop_stage`` is ``'train'`` it
    is written back to the message hub and the record is cleared;
    otherwise the ``loop_stage`` entry is removed from the hub.

    Args:
        runner (Runner): The runner whose message hub is updated.
    """
    previous_stage = self.last_loop_stage
    hub = runner.message_hub
    if previous_stage != 'train':
        # Validation was not nested inside training: nothing to restore.
        hub.pop_info('loop_stage')
        return
    # Nested inside training: put the training stage back.
    hub.update_info('loop_stage', previous_stage)
    self.last_loop_stage = None

def before_test(self, runner) -> None:
    """Set the ``loop_stage`` entry of the message hub to ``'test'``.

    Args:
        runner (Runner): The runner whose message hub is updated.
    """
    hub = runner.message_hub
    hub.update_info('loop_stage', 'test')

def after_test(self, runner) -> None:
    """Drop the ``loop_stage`` entry from the runner's message hub.

    Args:
        runner (Runner): The runner whose message hub is updated.
    """
    hub = runner.message_hub
    hub.pop_info('loop_stage')

def after_test_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
Expand Down
16 changes: 15 additions & 1 deletion mmengine/logging/message_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ def update_info(self, key: str, value: Any, resumed: bool = True) -> None:
self._set_resumed_keys(key, resumed)
self._runtime_info[key] = value

def pop_info(self, key: str, default: Optional[Any] = None) -> Any:
    """Delete a runtime information entry and hand back its value.

    Args:
        key (str): Key of the runtime information to remove.
        default (Any, optional): Value returned when ``key`` is not
            stored. Defaults to None.

    Returns:
        Any: The removed runtime information, or ``default`` if ``key``
        does not exist.
    """
    removed = self._runtime_info.pop(key, default)
    return removed

def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None:
"""Update runtime information with dictionary.

Expand Down Expand Up @@ -289,7 +303,7 @@ def get_scalar(self, key: str) -> HistoryBuffer:
return self.log_scalars[key]

def get_info(self, key: str, default: Optional[Any] = None) -> Any:
"""Get runtime information by key. if the key does not exist, this
"""Get runtime information by key. If the key does not exist, this
method will return default information.

Args:
Expand Down
36 changes: 35 additions & 1 deletion tests/test_hooks/test_runtime_info_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,20 @@ def tearDown(self):
DATASETS.module_dict.pop('DatasetWithMetainfo')
return super().tearDown()

def test_before_train(self):
def test_before_and_after_train(self):

cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.train_dataloader.dataset.type = 'DatasetWithoutMetainfo'
runner = self.build_runner(cfg)
hook = self._get_runtime_info_hook(runner)
hook.before_train(runner)
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train')
self.assertEqual(runner.message_hub.get_info('epoch'), 0)
self.assertEqual(runner.message_hub.get_info('iter'), 0)
self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
self.assertEqual(runner.message_hub.get_info('max_iters'), 8)
hook.after_train(runner)
self.assertIsNone(runner.message_hub.get_info('loop_stage'))

cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
runner = self.build_runner(cfg)
Expand Down Expand Up @@ -110,6 +113,28 @@ def test_after_train_iter(self):
self.assertEqual(
runner.message_hub.get_scalar('train/loss_cls').current(), 1.111)

def test_before_and_after_val(self):
    """Check ``loop_stage`` handling around a validation loop."""
    cfg = copy.deepcopy(self.epoch_based_cfg)

    # Stand-alone validation: the stage is set, then removed again.
    runner = self.build_runner(cfg)
    hook = self._get_runtime_info_hook(runner)
    hub = runner.message_hub
    hook.before_val(runner)
    self.assertEqual(hub.get_info('loop_stage'), 'val')
    self.assertIsNone(hook.last_loop_stage)
    hook.after_val(runner)
    self.assertIsNone(hub.get_info('loop_stage'))

    # Simulate the workflow of calling the ValLoop within the TrainLoop
    runner = self.build_runner(cfg)
    hook = self._get_runtime_info_hook(runner)
    hub = runner.message_hub
    hook.before_train(runner)
    self.assertEqual(hub.get_info('loop_stage'), 'train')
    hook.before_val(runner)
    self.assertEqual(hub.get_info('loop_stage'), 'val')
    self.assertEqual(hook.last_loop_stage, 'train')
    hook.after_val(runner)
    self.assertEqual(hub.get_info('loop_stage'), 'train')
    self.assertIsNone(hook.last_loop_stage)

def test_after_val_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
Expand All @@ -118,6 +143,15 @@ def test_after_val_epoch(self):
self.assertEqual(
runner.message_hub.get_scalar('val/acc').current(), 0.8)

def test_before_and_after_test(self):
    """Check that ``loop_stage`` is set before and removed after testing."""
    runner = self.build_runner(copy.deepcopy(self.epoch_based_cfg))
    hook = self._get_runtime_info_hook(runner)
    hub = runner.message_hub

    hook.before_test(runner)
    self.assertEqual(hub.get_info('loop_stage'), 'test')

    hook.after_test(runner)
    self.assertIsNone(hub.get_info('loop_stage'))

def test_after_test_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_logging/test_message_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def test_update_info(self):
message_hub.update_info('key', 1)
assert message_hub.runtime_info['key'] == 1

def test_pop_info(self):
    """Check that ``pop_info`` returns and removes runtime information."""
    message_hub = MessageHub.get_instance('mmengine')
    message_hub.update_info('pop_key', 'pop_info')
    assert message_hub.runtime_info['pop_key'] == 'pop_info'
    assert message_hub.pop_info('pop_key') == 'pop_info'
    # The entry must actually be gone, not merely returned.
    assert 'pop_key' not in message_hub.runtime_info

    # A missing key yields the supplied default ...
    assert message_hub.pop_info('not_existed_key', 'info') == 'info'
    # ... or None when no default is given.
    assert message_hub.pop_info('not_existed_key') is None

def test_update_infos(self):
message_hub = MessageHub.get_instance('mmengine')
# test runtime value can be overwritten.
Expand Down