Skip to content

Commit

Permalink
Add the loop stage in message_hub (#1277)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida committed Jul 31, 2023
1 parent 237aee3 commit 2df93eb
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 2 deletions.
26 changes: 26 additions & 0 deletions mmengine/hooks/runtime_info_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,15 @@ def before_run(self, runner) -> None:
mmengine_version=__version__ + get_git_hash())
runner.message_hub.update_info_dict(metainfo)

self.last_loop_stage = None

def before_train(self, runner) -> None:
"""Update resumed training state.
Args:
runner (Runner): The runner of the training process.
"""
runner.message_hub.update_info('loop_stage', 'train')
runner.message_hub.update_info('epoch', runner.epoch)
runner.message_hub.update_info('iter', runner.iter)
runner.message_hub.update_info('max_epochs', runner.max_epochs)
Expand All @@ -68,6 +71,9 @@ def before_train(self, runner) -> None:
runner.message_hub.update_info(
'dataset_meta', runner.train_dataloader.dataset.metainfo)

def after_train(self, runner) -> None:
    """Drop the ``loop_stage`` entry from the message hub after training.

    Args:
        runner (Runner): The runner of the training process.
    """
    hub = runner.message_hub
    hub.pop_info('loop_stage')

def before_train_epoch(self, runner) -> None:
"""Update current epoch information before every epoch.
Expand Down Expand Up @@ -119,6 +125,10 @@ def after_train_iter(self,
for key, value in outputs.items():
runner.message_hub.update_scalar(f'train/{key}', value)

def before_val(self, runner) -> None:
    """Record the current loop stage, then switch the stage to ``'val'``.

    The stage read from the message hub is stored in
    ``self.last_loop_stage`` before ``loop_stage`` is overwritten.

    Args:
        runner (Runner): The runner of the validation process.
    """
    hub = runner.message_hub
    previous_stage = hub.get_info('loop_stage')
    self.last_loop_stage = previous_stage
    hub.update_info('loop_stage', 'val')

def after_val_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
Expand All @@ -138,6 +148,22 @@ def after_val_epoch(self,
else:
runner.message_hub.update_info(f'val/{key}', value)

def after_val(self, runner) -> None:
    """Restore or clear ``loop_stage`` once validation has finished.

    Args:
        runner (Runner): The runner of the validation process.
    """
    previous_stage = self.last_loop_stage
    hub = runner.message_hub
    if previous_stage != 'train':
        # Validation was not entered from the train stage: nothing to
        # restore, so just remove the entry.
        hub.pop_info('loop_stage')
        return

    # The ValLoop may run inside the TrainLoop
    # (before_train -> before_val -> after_val -> after_train), so put the
    # stage saved by ``before_val`` back and forget it.
    hub.update_info('loop_stage', previous_stage)
    self.last_loop_stage = None

def before_test(self, runner) -> None:
    """Mark the current loop stage as ``'test'`` in the message hub.

    Args:
        runner (Runner): The runner of the testing process.
    """
    hub = runner.message_hub
    hub.update_info('loop_stage', 'test')

def after_test(self, runner) -> None:
    """Drop the ``loop_stage`` entry from the message hub after testing.

    Args:
        runner (Runner): The runner of the testing process.
    """
    hub = runner.message_hub
    hub.pop_info('loop_stage')

def after_test_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
Expand Down
16 changes: 15 additions & 1 deletion mmengine/logging/message_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ def update_info(self, key: str, value: Any, resumed: bool = True) -> None:
self._set_resumed_keys(key, resumed)
self._runtime_info[key] = value

def pop_info(self, key: str, default: Optional[Any] = None) -> Any:
    """Remove a runtime information entry and return its value.

    When ``key`` is not stored, nothing is removed and ``default`` is
    returned instead.

    Args:
        key (str): Key of the runtime information to remove.
        default (Any, optional): Value returned when ``key`` is absent.
            Defaults to None.

    Returns:
        Any: The removed runtime information, or ``default`` if the key
        does not exist.
    """
    removed = self._runtime_info.pop(key, default)
    return removed

def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None:
"""Update runtime information with dictionary.
Expand Down Expand Up @@ -289,7 +303,7 @@ def get_scalar(self, key: str) -> HistoryBuffer:
return self.log_scalars[key]

def get_info(self, key: str, default: Optional[Any] = None) -> Any:
"""Get runtime information by key. if the key does not exist, this
"""Get runtime information by key. If the key does not exist, this
method will return default information.
Args:
Expand Down
36 changes: 35 additions & 1 deletion tests/test_hooks/test_runtime_info_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,20 @@ def tearDown(self):
DATASETS.module_dict.pop('DatasetWithMetainfo')
return super().tearDown()

def test_before_train(self):
def test_before_and_after_train(self):

cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.train_dataloader.dataset.type = 'DatasetWithoutMetainfo'
runner = self.build_runner(cfg)
hook = self._get_runtime_info_hook(runner)
hook.before_train(runner)
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train')
self.assertEqual(runner.message_hub.get_info('epoch'), 0)
self.assertEqual(runner.message_hub.get_info('iter'), 0)
self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
self.assertEqual(runner.message_hub.get_info('max_iters'), 8)
hook.after_train(runner)
self.assertIsNone(runner.message_hub.get_info('loop_stage'))

cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
runner = self.build_runner(cfg)
Expand Down Expand Up @@ -110,6 +113,28 @@ def test_after_train_iter(self):
self.assertEqual(
runner.message_hub.get_scalar('train/loss_cls').current(), 1.111)

def test_before_and_after_val(self):
    """Check ``loop_stage`` bookkeeping around ``before_val``/``after_val``."""

    def current_stage(active_runner):
        # Read the stage currently stored in the runner's message hub.
        return active_runner.message_hub.get_info('loop_stage')

    cfg = copy.deepcopy(self.epoch_based_cfg)

    # Validation on its own: no previous stage to remember or restore.
    runner = self.build_runner(cfg)
    hook = self._get_runtime_info_hook(runner)
    hook.before_val(runner)
    self.assertEqual(current_stage(runner), 'val')
    self.assertIsNone(hook.last_loop_stage)
    hook.after_val(runner)
    self.assertIsNone(current_stage(runner))

    # Simulate the workflow of calling the ValLoop within the TrainLoop:
    # the 'train' stage must be remembered and put back afterwards.
    runner = self.build_runner(cfg)
    hook = self._get_runtime_info_hook(runner)
    hook.before_train(runner)
    self.assertEqual(current_stage(runner), 'train')
    hook.before_val(runner)
    self.assertEqual(current_stage(runner), 'val')
    self.assertEqual(hook.last_loop_stage, 'train')
    hook.after_val(runner)
    self.assertEqual(current_stage(runner), 'train')
    self.assertIsNone(hook.last_loop_stage)

def test_after_val_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
Expand All @@ -118,6 +143,15 @@ def test_after_val_epoch(self):
self.assertEqual(
runner.message_hub.get_scalar('val/acc').current(), 0.8)

def test_before_and_after_test(self):
    """Check ``loop_stage`` is set by ``before_test`` and cleared after."""

    def current_stage(active_runner):
        # Read the stage currently stored in the runner's message hub.
        return active_runner.message_hub.get_info('loop_stage')

    runner = self.build_runner(copy.deepcopy(self.epoch_based_cfg))
    hook = self._get_runtime_info_hook(runner)

    hook.before_test(runner)
    self.assertEqual(current_stage(runner), 'test')

    hook.after_test(runner)
    self.assertIsNone(current_stage(runner))

def test_after_test_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_logging/test_message_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def test_update_info(self):
message_hub.update_info('key', 1)
assert message_hub.runtime_info['key'] == 1

def test_pop_info(self):
    """Check ``pop_info`` returns the stored value and removes the key."""
    message_hub = MessageHub.get_instance('mmengine')
    message_hub.update_info('pop_key', 'pop_info')
    assert message_hub.runtime_info['pop_key'] == 'pop_info'
    assert message_hub.pop_info('pop_key') == 'pop_info'
    # Popping must actually delete the entry, not merely read it.
    assert 'pop_key' not in message_hub.runtime_info

    # A missing key falls back to the supplied default.
    assert message_hub.pop_info('not_existed_key', 'info') == 'info'

def test_update_infos(self):
message_hub = MessageHub.get_instance('mmengine')
# test runtime value can be overwritten.
Expand Down

0 comments on commit 2df93eb

Please sign in to comment.