[Fix] Fix build unnecessary loop during train/test/val #1107

Merged · 7 commits · Apr 27, 2023
18 changes: 16 additions & 2 deletions mmengine/hooks/logger_hook.py
@@ -253,11 +253,25 @@ def after_val_epoch(self,
runner, len(runner.val_dataloader), 'val')
runner.logger.info(log_str)
if self.log_metric_by_epoch:
# Accessing the epoch attribute of the runner will trigger
# the construction of the train_loop. Therefore, to avoid
# triggering the construction of the train_loop during
# validation, check before accessing the epoch.
if (isinstance(runner._train_loop, dict)
or runner._train_loop is None):
epoch = 0
else:
epoch = runner.epoch
runner.visualizer.add_scalars(
tag, step=runner.epoch, file_path=self.json_log_path)
tag, step=epoch, file_path=self.json_log_path)
else:
if (isinstance(runner._train_loop, dict)
or runner._train_loop is None):
iter = 0
else:
iter = runner.iter
runner.visualizer.add_scalars(
tag, step=runner.iter, file_path=self.json_log_path)
tag, step=iter, file_path=self.json_log_path)

def after_test_epoch(self,
runner,
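For context, mmengine's Runner builds its loops lazily, and reading `runner.epoch` forces `_train_loop` to be built, which is exactly what the comment in this hunk warns about. The following is a minimal toy sketch of that pattern and of the hook-side guard, not mmengine's actual implementation:

```python
class ToyTrainLoop:
    """Stand-in for the expensive loop object (dataloader, optimizer, ...)."""

    def __init__(self):
        self.epoch = 0


class ToyRunner:
    """Hypothetical runner mimicking the lazy-construction pattern."""

    def __init__(self, train_cfg=None):
        # Until it is built, the loop is stored as its config dict (or None).
        self._train_loop = train_cfg

    @property
    def train_loop(self):
        # First access replaces the config dict with a real loop object.
        if isinstance(self._train_loop, dict) or self._train_loop is None:
            self._train_loop = ToyTrainLoop()
        return self._train_loop

    @property
    def epoch(self):
        # Reading ``epoch`` therefore triggers construction of the loop.
        return self.train_loop.epoch


def epoch_for_logging(runner) -> int:
    # The guard used by the hook: fall back to 0 instead of building the loop.
    if isinstance(runner._train_loop, dict) or runner._train_loop is None:
        return 0
    return runner.epoch


runner = ToyRunner(train_cfg=dict(max_epochs=12))
assert epoch_for_logging(runner) == 0                    # loop left unbuilt
assert not isinstance(runner._train_loop, ToyTrainLoop)  # still a config dict
```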
42 changes: 25 additions & 17 deletions mmengine/runner/log_processor.py
@@ -135,7 +135,6 @@ def get_log_after_iter(self, runner, batch_idx: int,
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
"""
assert mode in ['train', 'test', 'val']
cur_iter = self._get_iter(runner, batch_idx=batch_idx)
# Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
parsed_cfg = self._parse_windows_size(runner, batch_idx,
self.custom_cfg)
@@ -172,19 +171,23 @@ def get_log_after_iter(self, runner, batch_idx: int,
# ... ||| |||
# Epoch(train) [ 10][100/270]
dataloader_len = self._get_dataloader_size(runner, mode)
cur_iter = self._get_iter(runner, batch_idx)
cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))

if mode in ['train', 'val']:
# Right Align the epoch log:
# Epoch(train) [9][100/270]
# ... ||
# Epoch(train) [100][100/270]
cur_epoch = self._get_epoch(runner, mode)
max_epochs = runner.max_epochs
# 3 means the three characters: "[", "]", and " " occupied in
# " [{max_epochs}]"
cur_epoch_str = f'[{cur_epoch}]'.rjust(
len(str(max_epochs)) + 3, ' ')
if not (isinstance(runner._train_loop, dict)
or runner._train_loop is None):
# Right Align the epoch log:
# Epoch(train) [9][100/270]
# ... ||
# Epoch(train) [100][100/270]
max_epochs = runner.max_epochs
# 3 means the three characters: "[", "]", and " " occupied
# in " [{max_epochs}]"
cur_epoch_str = f'[{cur_epoch}]'.rjust(
len(str(max_epochs)) + 3, ' ')
else:
cur_epoch_str = f'[{cur_epoch}]'
tag['epoch'] = cur_epoch
log_str = (f'Epoch({mode}){cur_epoch_str}'
f'[{cur_iter_str}/{dataloader_len}] ')
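The right alignment is plain string padding; a quick sketch outside mmengine shows how short and long epoch numbers line up:

```python
max_epochs = 100
for cur_epoch in (9, 100):
    # 3 accounts for the "[", "]" and leading space in " [{max_epochs}]".
    print(f'[{cur_epoch}]'.rjust(len(str(max_epochs)) + 3, ' '))
# prints "   [9]" and " [100]", so the counters that follow stay aligned:
# Epoch(train)   [9][100/270]
# Epoch(train) [100][100/270]
```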
@@ -193,6 +196,7 @@ def get_log_after_iter(self, runner, batch_idx: int,
f'[{cur_iter_str}/{dataloader_len}] ')
else:
if mode == 'train':
cur_iter = self._get_iter(runner, batch_idx)
cur_iter_str = str(cur_iter).rjust(len(str(runner.max_iters)))
log_str = (f'Iter({mode}) '
f'[{cur_iter_str}/{runner.max_iters}] ')
@@ -492,19 +496,19 @@ def _get_max_memory(self, runner) -> int:
device = getattr(runner.model, 'output_device', None)
return get_max_cuda_memory(device)

def _get_iter(self, runner, batch_idx: int = None) -> int:
def _get_iter(self, runner, batch_idx: int) -> int:
"""Get current iteration index.

Args:
runner (Runner): The runner of the training/testing/validation
process.
batch_idx (int, optional): The iteration index of current
batch_idx (int): The iteration index of current
dataloader. Defaults to None.

Returns:
int: The current global iter or inner iter.
"""
if self.by_epoch and batch_idx is not None:
if self.by_epoch:
current_iter = batch_idx + 1
else:
current_iter = runner.iter + 1
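With `batch_idx` now required, the helper's result depends only on `by_epoch`. A minimal check of the intended semantics (assuming this change is installed, and using a mock runner the same way the test suite does):

```python
from unittest.mock import MagicMock

from mmengine.runner import LogProcessor

runner = MagicMock()
runner.iter = 10  # same value the existing test fixture uses

# by_epoch=True (the default): inner iter within the dataloader, batch_idx + 1.
assert LogProcessor()._get_iter(runner, batch_idx=1) == 2
# by_epoch=False: global iter, runner.iter + 1.
assert LogProcessor(by_epoch=False)._get_iter(runner, batch_idx=1) == 11
```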
@@ -524,9 +528,13 @@ def _get_epoch(self, runner, mode: str) -> int:
if mode == 'train':
epoch = runner.epoch + 1
elif mode == 'val':
# normal val mode
# runner.epoch += 1 has been done before validation
epoch = runner.epoch
if (isinstance(runner._train_loop, dict)
or runner._train_loop is None):
epoch = 0
else:
# normal val mode
# runner.epoch += 1 has been done before validation
epoch = runner.epoch
else:
raise ValueError(
f"runner mode should be 'train' or 'val', but got {mode}")
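Reassembled from the hunk above (with indentation restored), the body of `_get_epoch` after this change reads:

```python
if mode == 'train':
    epoch = runner.epoch + 1
elif mode == 'val':
    if (isinstance(runner._train_loop, dict)
            or runner._train_loop is None):
        # The train loop has not been built (e.g. a standalone
        # ``runner.val()``), so do not touch ``runner.epoch``.
        epoch = 0
    else:
        # normal val mode
        # runner.epoch += 1 has been done before validation
        epoch = runner.epoch
else:
    raise ValueError(
        f"runner mode should be 'train' or 'val', but got {mode}")
```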
5 changes: 1 addition & 4 deletions tests/test_runner/test_log_processor.py
@@ -255,10 +255,7 @@ def test_get_max_memory(self):

def test_get_iter(self):
log_processor = LogProcessor()
# Get global iter when `inner_iter=False`
iter = log_processor._get_iter(self.runner)
assert iter == 11
# Get inner iter
# Get batch_idx
iter = log_processor._get_iter(self.runner, 1)
assert iter == 2
# Still get global iter when `logger_hook.by_epoch==False`
29 changes: 28 additions & 1 deletion tests/test_runner/test_runner.py
@@ -1764,6 +1764,16 @@ def train_step(self, *args, **kwargs):
log = f.read()
self.assertIn('Epoch(train) [1][4/4]', log)

# 14. test_loop will not be built
for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
cfg = copy.deepcopy(cfg)
cfg.experiment_name = 'test_train14'
runner = Runner.from_cfg(cfg)
runner.train()
self.assertIsInstance(runner._train_loop, BaseLoop)
self.assertIsInstance(runner._val_loop, BaseLoop)
self.assertIsInstance(runner._test_loop, dict)

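The assertions above rely on the convention that an unbuilt loop is still stored as its config dict (or None). A small helper along those lines (hypothetical, not part of mmengine) expresses the same check:

```python
def loop_is_built(runner, name: str) -> bool:
    """Return True if the 'train', 'val' or 'test' loop has already been
    built into a loop object rather than still being a config dict / None."""
    loop = getattr(runner, f'_{name}_loop')
    return not (isinstance(loop, dict) or loop is None)


# e.g. after runner.train():
#   loop_is_built(runner, 'train') -> True
#   loop_is_built(runner, 'test')  -> False
```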
@skipIf(
SKIP_TEST_COMPILE,
reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1842,6 +1852,15 @@ def get_outputs_callback(module, inputs, outputs):
self.assertIn(predictions[0].dtype,
(torch.float16, torch.bfloat16))

# train_loop and test_loop will not be built
for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
cfg = copy.deepcopy(cfg)
cfg.experiment_name = 'test_val4'
runner = Runner.from_cfg(cfg)
runner.val()
self.assertIsInstance(runner._train_loop, dict)
self.assertIsInstance(runner._test_loop, dict)

@skipIf(
SKIP_TEST_COMPILE,
reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1901,7 +1920,7 @@ def get_outputs_callback(module, inputs, outputs):
predictions.clear()

# Test fp16 `autocast` context.
cfg.experiment_name = 'test_val3'
cfg.experiment_name = 'test_test3'
cfg.test_cfg = dict(fp16=True)
runner = Runner.from_cfg(cfg)
runner.model.register_forward_hook(get_outputs_callback)
@@ -1913,6 +1932,14 @@ def get_outputs_callback(module, inputs, outputs):
runner.test()
self.assertIn(predictions[0].dtype,
(torch.float16, torch.bfloat16))
# train_loop and val_loop will not be built
for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
cfg = copy.deepcopy(cfg)
cfg.experiment_name = 'test_test4'
runner = Runner.from_cfg(cfg)
runner.test()
self.assertIsInstance(runner._train_loop, dict)
self.assertIsInstance(runner._val_loop, dict)

@skipIf(
SKIP_TEST_COMPILE,