-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Auto forward non method attribute lookups to the user's model and bind custom methods to ORTModule #8798
Conversation
Another thing to consider is that what if there is private variable used in user defined functions? def custom_functions(): self.state = update IIUC, current implementation only cover stateless user defined functions? |
Do you mean for auto detecting attribute change and marking the model for re-export? Then yes, we cannot always detect that the user has made changes to the model, especially if the changes are being made to the model from a path that cannot be controlled by |
1a20398
to
93dce83
Compare
orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
Outdated
Show resolved
Hide resolved
We are looking into recursive ORTModule soon. Will this PR handle such cases, where one ORTModule can wrap another? in one of its layers? |
It should work with recursive ORTModules. Don't see anything that prevents that. But I think it would be hard to assess and predict the difficulties this PR brings to the recursive ORTModules without more details/implementation. |
5ffd42d
to
ba15cc0
Compare
orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another thing to consider is that what if there is private variable used in user defined functions?
def custom_functions(): self.state = update
IIUC, current implementation only cover stateless user defined functions?Do you mean for auto detecting attribute change and marking the model for re-export? Then yes, we cannot always detect that the user has made changes to the model, especially if the changes are being made to the model from a path that cannot be controlled by
ORTModule
. One way to ensure that the model get's re-exported is by exposing a method (mark_execution_graph_as_stale
) to the user that should be explicitly called whenever they made a change to the model.
Just confirming we are not publicly exposing an API for that purpose. ORTModule should assume attribute changes (method or non-method) mark the model as stale automatically.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work!
Setting/Retrieving attributes based on module initialization status as opposed to internal names simplified understanding (although code-wise they were almost the same)!
Depending on _torch_module
on an assert was nice too, as it will raise if renamed, but never silently fail due to check failure!
My only question was regarding to __getattr__
's else
which didnt have a return
on it. It called super()
and reached the end of method without a return. Maybe a typo?
…zation of ORTModule
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
803dac6
to
b874490
Compare
Avoid adding unrelated changes to important PRs. |
This pull request introduces auto signaling the module to be re-exported. We need a way for users to disable the auto re-export using skip check. The skip check addition is a way for users control this. |
This pull request:
torch.nn.Module
throughORTModule
by implementing the methods:__getattr__
: Implemented for attribute lookups that could not be found inORTModule
and so a lookup is performed on the user'storch.nn.Module
.__setattr__
: Implemented so that users can set attributes on their original module with ease. All attributes are set on the user provided module instead of onORTModule
. This is also used as a way to signal toORTModule
that the execution graph has changed and therefore a re-export must be done before the next forward call. This is done automatically by this implementation. The auto re-export can be controlled by enabling the skip check for re-building the graph, re creating the execution agent and so on (export ORTMODULE_SKIPCHECK_POLICY="SKIP_CHECK_BUILD_GRADIENT|SKIP_CHECK_EXECUTION_AGENT"
).ORTModule
instance. This is done in order to prevent the problem where user defined methods invoke the forward on the model thereby calling thePyTorch
module implementation of forward as opposed toORTModule
's implementation of forward.Users can now seamlessly use their training script with
ORTModule
without needing to change how they invoke user defined methods on their originaltorch.nn.Module
. Here is an example:In addition,
ORTModule
checks for any attribute name collisions between the user's model andORTModule
.