Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add configurations to support torch.compile in Runner #976

Merged
merged 12 commits into from
Mar 12, 2023
37 changes: 37 additions & 0 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,6 +1683,10 @@ def train(self) -> nn.Module:
self._train_loop.iter, # type: ignore
self._train_loop.max_iters) # type: ignore

# Maybe compile the model according to options in self.cfg.compile
# This must be called **AFTER** model has been wrapped.
self._maybe_compile()
C1rN09 marked this conversation as resolved.
Show resolved Hide resolved
C1rN09 marked this conversation as resolved.
Show resolved Hide resolved

model = self.train_loop.run() # type: ignore
self.call_hook('after_run')
return model
Expand All @@ -1706,6 +1710,10 @@ def val(self) -> dict:
# make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume()

# Maybe compile the model according to options in self.cfg.compile
# This must be called **AFTER** model has been wrapped.
self._maybe_compile()
C1rN09 marked this conversation as resolved.
Show resolved Hide resolved

metrics = self.val_loop.run() # type: ignore
self.call_hook('after_run')
return metrics
Expand All @@ -1729,6 +1737,10 @@ def test(self) -> dict:
# make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume()

# Maybe compile the model according to options in self.cfg.compile
# This must be called **AFTER** model has been wrapped.
self._maybe_compile()
C1rN09 marked this conversation as resolved.
Show resolved Hide resolved

metrics = self.test_loop.run() # type: ignore
self.call_hook('after_run')
return metrics
Expand Down Expand Up @@ -2280,3 +2292,28 @@ def _log_env(self, env_cfg: dict) -> None:
'\nRuntime environment:' + runtime_env_info + '\n' +
dash_line + '\n')
self.logger.info(f'Config:\n{self.cfg.pretty_text}')

def _maybe_compile(self):
C1rN09 marked this conversation as resolved.
Show resolved Hide resolved
"""Use `torch.compile` to optimize model/wrapped_model."""
compile_cfg = self.cfg.get('compile', None)
if compile_cfg is None:
# no compile options given, won't compile
return

assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
'PyTorch >= 2.0.0 is required to enable torch.compile')
assert isinstance(compile_cfg, dict), (
f'`compile` option should be a dict, got {type(compile_cfg)}')
assert 'target' in compile_cfg
target = compile_cfg.pop('target')
assert target == '' or hasattr(self.model, target), (
f'compile.target `{target}` should be an attribute of Model')
C1rN09 marked this conversation as resolved.
Show resolved Hide resolved

if target == '':
# Compile the model itself
self.model = torch.compile(self.model, **compile_cfg)
else:
# Compile its function/module: forward, train_step, etc.
func = getattr(self.model, target)
compiled_func = torch.compile(func, **compile_cfg)
setattr(self.model, target, compiled_func)