Skip to content

Commit

Permalink
compile train_step, val_step, test_step instead
Browse files Browse the repository at this point in the history
  • Loading branch information
C1rN09 committed Mar 8, 2023
1 parent 7b4060b commit b845d98
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2327,8 +2327,11 @@ def _maybe_compile(self) -> None:
assert isinstance(compile_cfg, dict), (
f'`compile` should be a dict or bool, got {type(compile_cfg)}')

# Compile the model.forward
self.model.forward = torch.compile(self.model.forward, **compile_cfg)
# Compile the model's train_step, val_step and test_step
for target in ('train_step', 'val_step', 'test_step'):
func = getattr(self.model, target)
compiled_func = torch.compile(func, **compile_cfg)
setattr(self.model, target, compiled_func)
self._is_compiled = True
self.logger.info('Model has been "compiled". The first few iterations'
' will be slow, please be patient.')

0 comments on commit b845d98

Please sign in to comment.