Remove compile wrapper to simplify access to model attributes #5581
Conversation
deepspeed/runtime/engine.py (Outdated)

@@ -90,7 +90,7 @@
 from .pipe.module import PipelineModule
 from .utils import get_ma_status
-from .compiler import CompiledModuleWrapper
+from .compiler import get_backend_fn
Suggested change:

-from .compiler import get_backend_fn
+from .compiler import get_backend_fn, is_compile_supported
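For context, a guard of this shape would only need to check that the running PyTorch exposes the compile entry points. This is a minimal sketch of what such a helper might look like; the actual implementation in deepspeed/runtime/compiler.py may differ:

    import torch

    def is_compile_supported() -> bool:
        # torch.compiler and nn.Module.compile are available from PyTorch 2.1 on
        return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")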
deepspeed/runtime/engine.py (Outdated)

if self._config.compile_config.enabled and not self._is_compiled:
    if self._compiler_backend is None:
        self._compiler_backend = get_backend_fn(self._config.compile_config.backend)

    if self._compiler_fn is None:
        compiled_model = torch.compile(self.module,
                                       backend=self._compiler_backend,
                                       **self._config.compile_config.kwargs)
    else:
        compiled_model = self._compiler_fn(self.module)

    self._set_client_model(compiled_model)
    self._is_compiled = True
Suggested change (replacing the block above):

if self._config.compile_config.enabled and not self._is_compiled and is_compile_supported():
    backend = self._compiler_fn if self._compiler_fn is not None else self._compiler_backend
    self.module.compile(backend=backend, **self._compile_kwargs)
    self._is_compiled = True
Reasons for the above:
- Allows passing user-defined kwargs even when the user provides a custom compiler_fn.
- The type of self.module does not change, since torch.nn.Module.compile is called instead of torch.compile; the module is compiled in place.

It used to fail on this line. It should pass now, since we only compile on forward and not in init; nevertheless, we should consider keeping the type of self.module.
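A quick, self-contained illustration of the type-preservation point (a minimal sketch, assuming PyTorch >= 2.1 where nn.Module.compile is available):

    import torch

    model = torch.nn.Linear(4, 4)

    wrapped = torch.compile(model)   # out-of-place: returns a new wrapper object
    print(type(wrapped).__name__)    # OptimizedModule, not Linear

    model.compile()                  # in-place: nn.Module.compile returns None
    print(type(model).__name__)      # still Linear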
Using torch.nn.Module.compile is a great idea, thank you for this suggestion. However, compiler_fn is not a backend. The intention of compiler_fn is to enable something that can't be done just by setting a backend and kwargs for torch.compile (e.g. compiling only part of the model). We need to fix that part carefully to make it consistent with in-place compile.
I think it is possible to achieve this with compiler_fn: torch.nn.Module.compile just calls this function for compilation. For example, in such a custom function one can iterate over the module and call compile() in place on every part to compile:

def custom_compiler_fn(module: torch.nn.Module, example_inputs):
    # assumes: from deepspeed.accelerator import get_accelerator
    global custom_compiler_fn_called
    custom_compiler_fn_called = True
    # compile only the l1 submodule, in place
    module.l1.compile(backend=get_accelerator().get_compile_backend())
> Reasons for the above:
> - Allows passing user-defined kwargs even when the user provides a custom compiler_fn.
> - The type of self.module does not change, since torch.nn.Module.compile is called instead of torch.compile; the module is compiled in place.
On the other hand:
- The model returned by torch.compile(user_model) is actually a smart wrapper of the user_model. Torch keeps the user_model inside, so the user can add new attributes or change existing ones on their original model.
- The user may not want their module to be compiled by DS; compiled_model = torch.compile(self.module, ...) solves that.
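A small sketch of the wrapper behavior described above (assuming PyTorch >= 2.0; _orig_mod is a torch internal and may change across versions):

    import torch

    model = torch.nn.Linear(4, 4)
    compiled = torch.compile(model)

    print(compiled._orig_mod is model)   # True: the original module is kept inside
    model.custom_note = "hello"          # an attribute added to the original model...
    print(compiled.custom_note)          # ...is visible through the wrapper's getattr forwarding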
@BacharL I think I get your points in your last comment. Can you clarify the intention of this? It seems that _compiler_fn would be no different from _compiler_backend here:

backend = self._compiler_fn if self._compiler_fn is not None else self._compiler_backend
self.module.compile(backend=backend, **self._compile_kwargs)

@deepcharm You can pass a compiled module to DeepSpeed's init, but it will make a slight difference: DeepSpeed sets some hooks, and I thought it can't once the model is compiled. I don't think it is harmful to keep compiler_fn, even if such use cases might not be very popular.
@BacharL @deepcharm, I had a discussion with @tjruwase about the design. He suggested having a compile API on DeepSpeed instead of running compilation at the first forward pass. That behavior would be easier for users to understand, and in this design we don't need the ds config for compile any more. At least for now, it will be a simple wrapper around engine.module.compile().
Do you have any thoughts on this? I think I can briefly test the feasibility.
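For illustration, usage of such an API might look like this (a hypothetical sketch based on this discussion; the exact signature of engine.compile is an assumption, not necessarily the merged implementation):

    import deepspeed

    # model, ds_config, and batch are assumed to be defined elsewhere
    engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

    # explicit compile step instead of compiling lazily on the first forward;
    # assumed to be a thin wrapper around engine.module.compile()
    engine.compile(backend="inductor")

    loss = engine(batch)  # runs the module, now compiled in place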
This suggestion seems good.
Thanks.
@BacharL I implemented the approach. Would it work for you?
Great, thank you!
@tohtana This approach is simple and explicit, thanks.
deepspeed/runtime/engine.py (Outdated)

if self.is_compiled:
    return

self.module = self.module.compile(backend=backend, **compile_kwargs)
Suggested change:

-self.module = self.module.compile(backend=backend, **compile_kwargs)
+self.module.compile(backend=backend, **compile_kwargs)
nn.Module.compile returns None
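A quick check of this (assuming PyTorch >= 2.1):

    import torch

    m = torch.nn.Linear(2, 2)
    print(m.compile())  # None: nn.Module.compile compiles in place and returns nothing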
Good catch, thank you!
Wrapping the compiled module brings various restrictions on accessing attributes of the model. This PR removes the wrapper around the compiled module to simplify access to the model's attributes.