
Commit

[Fix] No training log when the num of iterations is smaller than the interval (#1046)
shufanwu committed Apr 24, 2023
1 parent 580c9d4 commit 2aef53d
Showing 3 changed files with 23 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmengine/hooks/logger_hook.py
@@ -185,7 +185,8 @@ def after_train_iter(self,
             tag, log_str = runner.log_processor.get_log_after_iter(
                 runner, batch_idx, 'train')
         elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
-              and not self.ignore_last):
+              and (not self.ignore_last
+                   or len(runner.train_dataloader) <= self.interval)):
             # `runner.max_iters` may not be divisible by `self.interval`. if
             # `self.ignore_last==True`, the log of remaining iterations will
             # be recorded (Epoch [4][1000/1007], the logs of 998-1007
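Why the condition change matters: with LoggerHook's default interval of 10 and ignore_last=True, a train_dataloader shorter than the interval never triggers the every-n-iterations branch, and the old end-of-epoch branch was skipped as well, so an epoch could finish without a single training log. The sketch below is a simplified, hypothetical model of that decision (should_log is not part of mmengine; it only mirrors the every_n_inner_iters and end_of_epoch checks used in after_train_iter):

# Hypothetical, simplified model of the logging decision in
# LoggerHook.after_train_iter (illustration only, not mmengine code).
def should_log(batch_idx, loader_len, interval=10, ignore_last=True, patched=True):
    every_n = (batch_idx + 1) % interval == 0       # every_n_inner_iters
    end_of_epoch = (batch_idx + 1) == loader_len    # end_of_epoch
    if every_n:
        return True
    if patched:
        # After the fix: a short epoch still logs its last iteration.
        return end_of_epoch and (not ignore_last or loader_len <= interval)
    # Before the fix: the trailing incomplete interval was always dropped.
    return end_of_epoch and not ignore_last

# A 9-batch epoch with the default interval of 10:
print(any(should_log(i, 9, patched=False) for i in range(9)))  # False: no log at all
print(any(should_log(i, 9, patched=True) for i in range(9)))   # True: logs at [9/9]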
10 changes: 10 additions & 0 deletions tests/test_hooks/test_logger_hook.py
@@ -103,6 +103,16 @@ def test_after_train_iter(self):
         logger_hook.after_train_iter(runner, batch_idx=999)
         runner.logger.info.assert_called()
 
+        # Test print training log when the num of
+        # iterations is smaller than the default interval
+        runner = MagicMock()
+        runner.log_processor.get_log_after_iter = MagicMock(
+            return_value=(dict(), 'log_str'))
+        runner.train_dataloader = [0] * 9
+        logger_hook = LoggerHook()
+        logger_hook.after_train_iter(runner, batch_idx=8)
+        runner.log_processor.get_log_after_iter.assert_called()
+
     def test_after_val_epoch(self):
         logger_hook = LoggerHook()
         runner = MagicMock()
11 changes: 11 additions & 0 deletions tests/test_runner/test_runner.py
@@ -1753,6 +1753,17 @@ def train_step(self, *args, **kwargs):
         with self.assertRaisesRegex(AssertionError, 'If you want to validate'):
             runner.train()
 
+        # 13 Test the logs will be printed when the length of
+        # train_dataloader is smaller than the interval set in LoggerHook
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_train13'
+        cfg.default_hooks = dict(logger=dict(type='LoggerHook', interval=5))
+        runner = Runner.from_cfg(cfg)
+        runner.train()
+        with open(runner.logger._log_file) as f:
+            log = f.read()
+        self.assertIn('Epoch(train) [1][4/4]', log)
+
     @skipIf(
         SKIP_TEST_COMPILE,
         reason='torch.compile is not valid, please install PyTorch>=2.0.0')
