Skip to content

Commit

Permalink
Merge f7dbae0 into eb79d64
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE committed Mar 29, 2023
2 parents eb79d64 + f7dbae0 commit 264aaf5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
13 changes: 13 additions & 0 deletions mmengine/logging/logger.py
Expand Up @@ -291,6 +291,19 @@ def setLevel(self, level):
logger._cache.clear()
_release_lock()

def reset_filehandler(self) -> None:
"""FileHandler will be closed by ``torch.compile`` in PyTorch version
2.0.0. This method will be called to resume the FileHandler.
"""
for i, handler in enumerate(self.handlers):
if isinstance(handler, logging.FileHandler):
filename = handler.baseFilename
handler.close()
new_handler = logging.FileHandler(filename, 'a')
new_handler.setFormatter(
MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
self.handlers[i] = new_handler


def print_log(msg,
logger: Optional[Union[Logger, str]] = None,
Expand Down
2 changes: 2 additions & 0 deletions mmengine/runner/runner.py
Expand Up @@ -2312,6 +2312,8 @@ def _maybe_compile(self, target: str) -> None:
f'`compile` should be a dict or bool, got {type(compile_cfg)}')

func = getattr(self.model, target)
if digit_version(torch.__version__) == digit_version('2.0.0'):
self.logger.reset_filehandler()
compiled_func = torch.compile(func, **compile_cfg)
setattr(self.model, target, compiled_func)
self.logger.info('Model has been "compiled". The first few iterations'
Expand Down
8 changes: 8 additions & 0 deletions tests/test_runner/test_runner.py
Expand Up @@ -1745,6 +1745,14 @@ def test_train_with_compile(self):
runner = Runner.from_cfg(cfg)
runner.train()

runner._maybe_compile('train_step')
# PyTorch 2.0.0 could close the FileHandler after calling of
# ``torch.compile``. So we need to test our file handler still works.
with open(osp.join(f'{runner.log_dir}',
f'{runner.timestamp}.log')) as f:
last_line = f.readlines()[-1]
self.assertTrue(last_line.endswith('please be patient.\n'))

def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1'
Expand Down

0 comments on commit 264aaf5

Please sign in to comment.