[Fix] Fix build unnecessary loop during train/test/val (#1107)
* [Fix] Fix build unnecessary loop during train/test/val

* move unit test to runner

* Update unit test

* Fix unit test

* check train_loop is None

* update comment

* replace `type(None)` with `is not None`
HAOCHENYE committed Apr 27, 2023
1 parent 49b27dd commit 298a4b1
Showing 4 changed files with 70 additions and 24 deletions.
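
Context for the change: the Runner builds its loops lazily, so `runner._train_loop` holds the loop's config dict (or `None`) until the `train_loop` property is first accessed, and reading `runner.epoch` or `runner.iter` goes through that property. The toy sketch below is illustrative only (not mmengine code); it shows why a standalone val/test run used to construct a train loop it never needed, and how checking the private attribute avoids that.

from types import SimpleNamespace


class ToyRunner:
    """Toy stand-in for mmengine's Runner; attribute names mirror the real
    ones, but the implementation is illustrative only."""

    def __init__(self, train_cfg):
        # Stays a plain config dict (or None) until the property is accessed.
        self._train_loop = train_cfg

    @property
    def train_loop(self):
        if isinstance(self._train_loop, dict):
            print('building train loop ...')  # expensive in the real Runner
            self._train_loop = SimpleNamespace(epoch=0)
        return self._train_loop

    @property
    def epoch(self):
        # Reading `epoch` goes through `train_loop`, which triggers the build.
        return self.train_loop.epoch


runner = ToyRunner(train_cfg=dict(max_epochs=10))

# What the fix does: inspect the private attribute instead of the property,
# so a standalone val/test run never constructs the train loop.
if isinstance(runner._train_loop, dict) or runner._train_loop is None:
    epoch = 0  # train loop left unbuilt; nothing is printed
else:
    epoch = runner.epoch

print(epoch)  # -> 0, and 'building train loop ...' never appears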
18 changes: 16 additions & 2 deletions mmengine/hooks/logger_hook.py
@@ -253,11 +253,25 @@ def after_val_epoch(self,
             runner, len(runner.val_dataloader), 'val')
         runner.logger.info(log_str)
         if self.log_metric_by_epoch:
+            # Accessing the epoch attribute of the runner will trigger
+            # the construction of the train_loop. Therefore, to avoid
+            # triggering the construction of the train_loop during
+            # validation, check before accessing the epoch.
+            if (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                epoch = 0
+            else:
+                epoch = runner.epoch
             runner.visualizer.add_scalars(
-                tag, step=runner.epoch, file_path=self.json_log_path)
+                tag, step=epoch, file_path=self.json_log_path)
         else:
+            if (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                iter = 0
+            else:
+                iter = runner.iter
             runner.visualizer.add_scalars(
-                tag, step=runner.iter, file_path=self.json_log_path)
+                tag, step=iter, file_path=self.json_log_path)
 
     def after_test_epoch(self,
                          runner,
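
The same `isinstance(..., dict) or ... is None` guard now appears in both branches of the hook above. A hypothetical helper (not part of this commit or of mmengine) captures the intent:

def _logging_step(runner, by_epoch: bool) -> int:
    """Return the step to log at, falling back to 0 when the train loop
    has not been built (e.g. a standalone ``runner.val()`` run)."""
    if isinstance(runner._train_loop, dict) or runner._train_loop is None:
        return 0
    return runner.epoch if by_epoch else runner.iter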
42 changes: 25 additions & 17 deletions mmengine/runner/log_processor.py
@@ -135,7 +135,6 @@ def get_log_after_iter(self, runner, batch_idx: int,
             recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
         """
         assert mode in ['train', 'test', 'val']
-        cur_iter = self._get_iter(runner, batch_idx=batch_idx)
         # Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
         parsed_cfg = self._parse_windows_size(runner, batch_idx,
                                               self.custom_cfg)
@@ -172,19 +171,23 @@ def get_log_after_iter(self, runner, batch_idx: int,
             # ...                 ||| |||
             # Epoch(train)  [ 10][100/270]
             dataloader_len = self._get_dataloader_size(runner, mode)
+            cur_iter = self._get_iter(runner, batch_idx)
             cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))
-
             if mode in ['train', 'val']:
-                # Right Align the epoch log:
-                # Epoch(train)   [9][100/270]
-                # ...             ||
-                # Epoch(train) [100][100/270]
                 cur_epoch = self._get_epoch(runner, mode)
-                max_epochs = runner.max_epochs
-                # 3 means the three characters: "[", "]", and " " occupied in
-                # " [{max_epochs}]"
-                cur_epoch_str = f'[{cur_epoch}]'.rjust(
-                    len(str(max_epochs)) + 3, ' ')
+                if not (isinstance(runner._train_loop, dict)
+                        or runner._train_loop is None):
+                    # Right Align the epoch log:
+                    # Epoch(train)   [9][100/270]
+                    # ...             ||
+                    # Epoch(train) [100][100/270]
+                    max_epochs = runner.max_epochs
+                    # 3 means the three characters: "[", "]", and " " occupied
+                    # in " [{max_epochs}]"
+                    cur_epoch_str = f'[{cur_epoch}]'.rjust(
+                        len(str(max_epochs)) + 3, ' ')
+                else:
+                    cur_epoch_str = f'[{cur_epoch}]'
                 tag['epoch'] = cur_epoch
                 log_str = (f'Epoch({mode}){cur_epoch_str}'
                            f'[{cur_iter_str}/{dataloader_len}] ')
@@ -193,6 +196,7 @@ def get_log_after_iter(self, runner, batch_idx: int,
                            f'[{cur_iter_str}/{dataloader_len}] ')
         else:
             if mode == 'train':
+                cur_iter = self._get_iter(runner, batch_idx)
                 cur_iter_str = str(cur_iter).rjust(len(str(runner.max_iters)))
                 log_str = (f'Iter({mode}) '
                            f'[{cur_iter_str}/{runner.max_iters}] ')
@@ -492,19 +496,19 @@ def _get_max_memory(self, runner) -> int:
         device = getattr(runner.model, 'output_device', None)
         return get_max_cuda_memory(device)
 
-    def _get_iter(self, runner, batch_idx: int = None) -> int:
+    def _get_iter(self, runner, batch_idx: int) -> int:
         """Get current iteration index.
 
         Args:
             runner (Runner): The runner of the training/testing/validation
                 process.
-            batch_idx (int, optional): The iteration index of current
+            batch_idx (int): The iteration index of current
                 dataloader. Defaults to None.
 
         Returns:
             int: The current global iter or inner iter.
         """
-        if self.by_epoch and batch_idx is not None:
+        if self.by_epoch:
             current_iter = batch_idx + 1
         else:
             current_iter = runner.iter + 1
@@ -524,9 +528,13 @@ def _get_epoch(self, runner, mode: str) -> int:
         if mode == 'train':
             epoch = runner.epoch + 1
         elif mode == 'val':
-            # normal val mode
-            # runner.epoch += 1 has been done before validation
-            epoch = runner.epoch
+            if (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                epoch = 0
+            else:
+                # normal val mode
+                # runner.epoch += 1 has been done before validation
+                epoch = runner.epoch
         else:
             raise ValueError(
                 f"runner mode should be 'train' or 'val', but got {mode}")
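
A quick illustration of the alignment branch above, with made-up values: when the train loop exists, `max_epochs` is known and the epoch field is right-aligned to its width; otherwise the epoch is printed without padding.

cur_epoch, max_epochs = 9, 100

# Train loop built: pad to the width of ' [{max_epochs}]' (3 extra characters).
print(f'[{cur_epoch}]'.rjust(len(str(max_epochs)) + 3, ' '))  # '   [9]'

# Train loop not built: no max_epochs to align against.
print(f'[{cur_epoch}]')  # '[9]'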
5 changes: 1 addition & 4 deletions tests/test_runner/test_log_processor.py
@@ -255,10 +255,7 @@ def test_get_max_memory(self):
 
     def test_get_iter(self):
         log_processor = LogProcessor()
-        # Get global iter when `inner_iter=False`
-        iter = log_processor._get_iter(self.runner)
-        assert iter == 11
-        # Get inner iter
+        # Get batch_idx
         iter = log_processor._get_iter(self.runner, 1)
         assert iter == 2
         # Still get global iter when `logger_hook.by_epoch==False`
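
The dropped assertion called `_get_iter` without `batch_idx`, which the new signature no longer allows. As a rough sketch (not the real method), the post-change behaviour reduces to:

def get_iter(by_epoch: bool, batch_idx: int, global_iter: int) -> int:
    """Inner iter when logging by epoch, otherwise the runner's global
    iter; both are 1-based."""
    return batch_idx + 1 if by_epoch else global_iter + 1


# runner.iter is assumed to be 10 here, matching the removed `assert iter == 11`.
assert get_iter(by_epoch=True, batch_idx=1, global_iter=10) == 2
assert get_iter(by_epoch=False, batch_idx=1, global_iter=10) == 11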
29 changes: 28 additions & 1 deletion tests/test_runner/test_runner.py
@@ -1802,6 +1802,16 @@ def train_step(self, *args, **kwargs):
             log = f.read()
         self.assertIn('Epoch(train) [1][4/4]', log)
 
+        # 14. test_loop will not be built
+        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
+            cfg = copy.deepcopy(cfg)
+            cfg.experiment_name = 'test_train14'
+            runner = Runner.from_cfg(cfg)
+            runner.train()
+            self.assertIsInstance(runner._train_loop, BaseLoop)
+            self.assertIsInstance(runner._val_loop, BaseLoop)
+            self.assertIsInstance(runner._test_loop, dict)
+
     @skipIf(
         SKIP_TEST_COMPILE,
         reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1880,6 +1890,15 @@ def get_outputs_callback(module, inputs, outputs):
             self.assertIn(predictions[0].dtype,
                           (torch.float16, torch.bfloat16))
 
+        # train_loop and test_loop will not be built
+        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
+            cfg = copy.deepcopy(cfg)
+            cfg.experiment_name = 'test_val4'
+            runner = Runner.from_cfg(cfg)
+            runner.val()
+            self.assertIsInstance(runner._train_loop, dict)
+            self.assertIsInstance(runner._test_loop, dict)
+
     @skipIf(
         SKIP_TEST_COMPILE,
         reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1939,7 +1958,7 @@ def get_outputs_callback(module, inputs, outputs):
         predictions.clear()
 
         # Test fp16 `autocast` context.
-        cfg.experiment_name = 'test_val3'
+        cfg.experiment_name = 'test_test3'
         cfg.test_cfg = dict(fp16=True)
         runner = Runner.from_cfg(cfg)
         runner.model.register_forward_hook(get_outputs_callback)
@@ -1951,6 +1970,14 @@ def get_outputs_callback(module, inputs, outputs):
             runner.test()
             self.assertIn(predictions[0].dtype,
                           (torch.float16, torch.bfloat16))
+        # train_loop and val_loop will not be built
+        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
+            cfg = copy.deepcopy(cfg)
+            cfg.experiment_name = 'test_test4'
+            runner = Runner.from_cfg(cfg)
+            runner.test()
+            self.assertIsInstance(runner._train_loop, dict)
+            self.assertIsInstance(runner._val_loop, dict)
 
     @skipIf(
         SKIP_TEST_COMPILE,
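
From the user's side, the new tests pin down the observable behaviour: running only one stage leaves the other loops as plain config dicts. A hedged usage sketch, assuming `cfg` is an otherwise valid Runner config (model, test_dataloader, test_evaluator, test_cfg) that is not spelled out here:

from mmengine.runner import Runner


def check_only_test_loop_built(cfg) -> None:
    """`cfg` is assumed to be a valid Runner config; see note above."""
    runner = Runner.from_cfg(cfg)
    runner.test()
    # After this commit, only the test loop is built; the other loops
    # remain as their config dicts.
    assert isinstance(runner._train_loop, dict)
    assert isinstance(runner._val_loop, dict)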
