
Remove compile wrapper to simplify access to model attributes #5581

Merged
merged 10 commits into from
Jun 17, 2024

Conversation

tohtana (Contributor, author) commented May 29, 2024

Having a wrapper of a compiled module brings various restrictions about accessing attributes of the compiled model.
This PR removes the wrapper of compiled module to simplify the access to the compiled model.
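The restriction can be illustrated with a minimal pure-Python sketch (the classes below are hypothetical stand-ins, not DeepSpeed's actual `CompiledModuleWrapper`): a wrapper object that holds the compiled module hides the module's attributes and breaks type checks, which is what removing the wrapper avoids.

```python
class CompileWrapper:
    """Hypothetical stand-in for a wrapper around a compiled module."""

    def __init__(self, module):
        self._module = module

    def forward(self, x):
        return self._module.forward(x)


class UserModule:
    """Hypothetical user model with a custom attribute."""

    def __init__(self):
        self.hidden_size = 16

    def forward(self, x):
        return x * 2


plain = UserModule()
wrapped = CompileWrapper(UserModule())

print(plain.hidden_size)                # 16: direct access works
print(hasattr(wrapped, "hidden_size"))  # False: the wrapper hides the attribute
print(isinstance(wrapped, UserModule))  # False: type checks break too
```

With the wrapper gone, `engine.module` remains the user's own object, so attribute access and `isinstance` checks behave as users expect.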

deepspeed/runtime/engine.py
@@ -90,7 +90,7 @@

 from .pipe.module import PipelineModule
 from .utils import get_ma_status
-from .compiler import CompiledModuleWrapper
+from .compiler import get_backend_fn
Collaborator:

Suggested change:
-from .compiler import get_backend_fn
+from .compiler import get_backend_fn, is_compile_supported

Comment on lines 1795 to 1808
if self._config.compile_config.enabled and not self._is_compiled:
    if self._compiler_backend is None:
        self._compiler_backend = get_backend_fn(self._config.compile_config.backend)

    if self._compiler_fn is None:
        compiled_model = torch.compile(self.module,
                                       backend=self._compiler_backend,
                                       **self._config.compile_config.kwargs)
    else:
        compiled_model = self._compiler_fn(self.module)

    self._set_client_model(compiled_model)
    self._is_compiled = True

BacharL (Collaborator) commented May 29, 2024:

Suggested change:
-if self._config.compile_config.enabled and not self._is_compiled:
-    if self._compiler_backend is None:
-        self._compiler_backend = get_backend_fn(self._config.compile_config.backend)
-
-    if self._compiler_fn is None:
-        compiled_model = torch.compile(self.module,
-                                       backend=self._compiler_backend,
-                                       **self._config.compile_config.kwargs)
-    else:
-        compiled_model = self._compiler_fn(self.module)
-
-    self._set_client_model(compiled_model)
-    self._is_compiled = True
+if self._config.compile_config.enabled and not self._is_compiled and is_compile_supported():
+    backend = self._compiler_fn if self._compiler_fn is not None else self._compiler_backend
+    self.module.compile(backend=backend, **self._compile_kwargs)
+    self._is_compiled = True

Reasons for the above:

  1. Allows passing user-defined kwargs even when the user provides a custom compiler_fn.
  2. The type of self.module does not change, since torch.nn.Module.compile is called instead of torch.compile; the module is compiled in place.
     My run used to fail on this line. It should pass now, as we only compile on forward and not in init; nevertheless, we should consider preserving the type of self.module.
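Point 2 above, why in-place compilation keeps the type of self.module, can be sketched with plain-Python stand-ins (these classes only model the behavior of `torch.compile` and `torch.nn.Module.compile`; they are not the real torch APIs):

```python
class MyModule:
    """Hypothetical user module."""

    def forward(self, x):
        return x + 1


class OptimizedModule:
    """Stand-in for the wrapper type that torch.compile returns."""

    def __init__(self, mod):
        self._orig_mod = mod


def compile_out_of_place(mod):
    """Models torch.compile(module): returns a new wrapper object."""
    return OptimizedModule(mod)


def compile_in_place(mod):
    """Models torch.nn.Module.compile(): mutates the module, returns None."""
    mod._compiled = True


m1, m2 = MyModule(), MyModule()
wrapped = compile_out_of_place(m1)
compile_in_place(m2)

print(type(wrapped) is MyModule)  # False: out-of-place compile changes the type
print(type(m2) is MyModule)       # True: in-place compile preserves it
```

Any code that relies on `isinstance(engine.module, MyModule)` keeps working only with the in-place variant, which is the reviewer's point.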

tohtana (Contributor, author) commented:

Using torch.nn.Module.compile is a great idea. Thank you for this suggestion.

compiler_fn is not a backend, though. The intention of compiler_fn is to enable things that can't be done just by setting a backend and kwargs for torch.compile (e.g. compiling only part of the model). We need to fix that part carefully to make it consistent with in-place compile.

Collaborator commented:

I think it is possible to achieve this with compiler_fn; torch.nn.Module.compile just calls this function for compilation.
For example, in this custom function one can iterate over the module and call compile() in place on every part that should be compiled:

    def custom_compiler_fn(module: torch.nn.Module, example_inputs):
        global custom_compiler_fn_called
        custom_compiler_fn_called = True
        module.l1.compile(backend=get_accelerator().get_compile_backend())

Contributor commented:

> Reasons for the above:
>
>   1. Allows passing user-defined kwargs even when the user provides a custom compiler_fn.
>   2. The type of self.module does not change, since torch.nn.Module.compile is called instead of torch.compile; the module is compiled in place.
>      My run used to fail on this line. It should pass now, as we only compile on forward and not in init; nevertheless, we should consider preserving the type of self.module.
On the other hand:

  1. The model returned by torch.compile(user_model) is actually a smart wrapper around user_model. Torch keeps user_model inside, so the user can still add new attributes to, or change existing ones on, their original model.
  2. The user may not want their module to be compiled by DeepSpeed;
     compiled_model = torch.compile(self.module, ...) solves that.

tohtana (Contributor, author) commented:

@BacharL I think I get the points in your last comment. Can you clarify the intention of this? With this, _compiler_fn seems to be no different from _compiler_backend:

    backend = self._compiler_fn if self._compiler_fn is not None else self._compiler_backend
    self.module.compile(backend=backend, **self._compile_kwargs)

@deepcharm You can pass a compiled module to deepspeed's init, but it makes a slight difference: DeepSpeed sets some hooks, and I thought it can't once the model is compiled. I don't think it is harmful to keep compiler_fn, though such use cases may not be very common.

tohtana (Contributor, author) commented:

@BacharL @deepcharm, I had a discussion with @tjruwase about the design.

He suggested adding a compile API to DeepSpeed instead of running compilation at the first forward pass; that behavior would be easier for users to understand. In this design we no longer need a ds config entry for compile. At least for now, it will be a simple wrapper around engine.module.compile().
Do you have any thoughts on this? I think I can briefly test the feasibility.
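A rough sketch of what such an explicit compile API might look like (the Engine and FakeModule classes here are hypothetical illustrations of the proposed design, not DeepSpeed's actual implementation):

```python
class FakeModule:
    """Stand-in for a torch.nn.Module with an in-place compile()."""

    def __init__(self):
        self.compiled_with = None

    def compile(self, backend, **kwargs):
        # Mutates the module in place and returns None,
        # mirroring torch.nn.Module.compile.
        self.compiled_with = backend


class Engine:
    """Hypothetical engine exposing an explicit compile() entry point,
    as a thin wrapper around engine.module.compile()."""

    def __init__(self, module):
        self.module = module
        self._is_compiled = False

    def compile(self, backend="inductor", **compile_kwargs):
        if self._is_compiled:
            return  # idempotent: repeated calls are harmless
        self.module.compile(backend=backend, **compile_kwargs)
        self._is_compiled = True


engine = Engine(FakeModule())
engine.compile()
print(engine.module.compiled_with)  # inductor
print(engine._is_compiled)          # True
```

Compared with compiling lazily at the first forward pass, an explicit call makes it obvious to the user when (and whether) compilation happens, and no config entry is needed.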

Collaborator commented:

This suggestion seems good.
Thanks.

tohtana (Contributor, author) commented:

@BacharL I implemented the approach. Would it work for you?

Collaborator commented:

Great, thank you!

Contributor commented:

@tohtana This approach is simple and explicit, thanks.

@tohtana tohtana marked this pull request as ready for review May 31, 2024 06:12
if self.is_compiled:
    return

self.module = self.module.compile(backend=backend, **compile_kwargs)
Collaborator commented:

Suggested change:
-self.module = self.module.compile(backend=backend, **compile_kwargs)
+self.module.compile(backend=backend, **compile_kwargs)

nn.Module.compile returns None
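The pitfall this suggestion fixes can be reproduced with a plain-Python stand-in (FakeModule is hypothetical; the point is only that an in-place compile() returns None, so assigning its result back overwrites the module reference):

```python
class FakeModule:
    """Stand-in for an nn.Module whose compile() works in place."""

    def __init__(self):
        self.compiled = False

    def compile(self, **kwargs):
        self.compiled = True  # mutate in place...
        return None           # ...and return None, like nn.Module.compile


good = FakeModule()
good.compile()       # correct: call only for the side effect

bad = FakeModule()
bad = bad.compile()  # buggy pattern from the original line

print(good.compiled)  # True: the module is intact and compiled
print(bad is None)    # True: the module reference has been lost
```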

tohtana (Contributor, author) commented:

Good catch, thank you!

@tohtana tohtana enabled auto-merge June 14, 2024 23:31
@tohtana tohtana added this pull request to the merge queue Jun 17, 2024
@loadams loadams removed this pull request from the merge queue due to a manual request Jun 17, 2024
@loadams loadams merged commit 2a0c0e3 into microsoft:master Jun 17, 2024
14 checks passed