
Auto forward non method attribute lookups to the user's model and bind custom methods to ORTModule #8798

Merged
merged 10 commits into from
Sep 3, 2021

Conversation

baijumeswani
Contributor

@baijumeswani baijumeswani commented Aug 20, 2021

This pull request:

  1. For non-method attributes: auto-forwards attribute accesses made through ORTModule to the user's original torch.nn.Module by implementing the methods:
    • __getattr__: implemented for attribute lookups that cannot be resolved on ORTModule itself, in which case the lookup is forwarded to the user's torch.nn.Module.
    • __setattr__: implemented so that users can set attributes on their original module with ease. All attributes are set on the user-provided module instead of on ORTModule. This also serves as a signal to ORTModule that the execution graph has changed and that a re-export must be performed before the next forward call; the re-export happens automatically. The auto re-export can be controlled by enabling the skip checks for re-building the graph, re-creating the execution agent, and so on (export ORTMODULE_SKIPCHECK_POLICY="SKIP_CHECK_BUILD_GRADIENT|SKIP_CHECK_EXECUTION_AGENT").
  2. For method attributes: copies user-defined methods and binds them to the ORTModule instance. This prevents the problem where a user-defined method invokes forward on the model and thereby calls the PyTorch module's implementation of forward instead of ORTModule's.
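The two dunder methods described in point 1 can be sketched in plain Python. This is an illustrative pattern only, not ORTModule's actual code (the real implementation also has to cooperate with torch.nn.Module's own __getattr__/__setattr__ machinery); all names here are hypothetical:

```python
class ModuleWrapper:
    """Illustrative stand-in for ORTModule (names are hypothetical)."""

    def __init__(self, module):
        # Bypass our own __setattr__ so the wrapped module is stored on
        # the wrapper itself rather than being forwarded.
        object.__setattr__(self, "_original_module", module)

    def __getattr__(self, name):
        # Only called when normal lookup on the wrapper fails, so fall
        # back to the user's original module.
        return getattr(self._original_module, name)

    def __setattr__(self, name, value):
        # All attributes are set on the user's module; in ORTModule this
        # is also where the execution graph would be flagged for re-export.
        setattr(self._original_module, name, value)
```

With this pattern, `wrapper.scale = 3.0` updates the wrapped module's `scale`, and `wrapper.scale` reads it back from there.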

Users can now seamlessly use their training script with ORTModule without needing to change how they invoke user defined methods on their original torch.nn.Module. Here is an example:

# User defined torch.nn.Module 
class UserDefinedMethodsNet(torch.nn.Module): 
    def __init__(self): 
        super(UserDefinedMethodsNet, self).__init__() 

    def forward(self, ...): 
        ... 

    def custom_user_method(self, ...): 
        return some_calculation()  

    def training_step(self, ...):
        out = self(...)
        ...

# Instantiation of ORTModule 
model = UserDefinedMethodsNet() 
model = ORTModule(model) 

# Invoke user defined methods 
out = model.custom_user_method() # No AttributeError since ORTModule auto forwards the call to the original module 
model.training_step(...) # ORTModule's forward will be executed and not the user defined forward

In addition, ORTModule checks for any attribute name collisions between the user's model and ORTModule.
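A collision check of the kind mentioned could look roughly like this (a hypothetical helper for illustration, not the actual ORTModule code):

```python
import warnings


def check_attribute_collisions(wrapper_cls, user_module):
    # Hypothetical helper: report public names defined both on the
    # wrapper class and on the user's module, since one would shadow
    # the other during attribute lookup.
    wrapper_names = {n for n in dir(wrapper_cls) if not n.startswith("_")}
    user_names = {n for n in dir(user_module) if not n.startswith("_")}
    collisions = sorted(wrapper_names & user_names)
    if collisions:
        warnings.warn(f"attribute name collisions with the wrapper: {collisions}")
    return collisions
```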

@SherlockNoMad
Contributor

Another thing to consider: what if there is a private variable used in user-defined functions?

def custom_functions(): self.state = update

IIUC, the current implementation only covers stateless user-defined functions?

@baijumeswani
Contributor Author

baijumeswani commented Aug 20, 2021

Another thing to consider: what if there is a private variable used in user-defined functions?

def custom_functions(): self.state = update

IIUC, the current implementation only covers stateless user-defined functions?

Do you mean for auto-detecting an attribute change and marking the model for re-export? Then yes: we cannot always detect that the user has made changes to the model, especially if the changes are made through a path that ORTModule cannot control. One way to ensure that the model gets re-exported is to expose a method (mark_execution_graph_as_stale) that users explicitly call whenever they make a change to the model.
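Such an explicit staleness flag could work roughly as follows; this is a sketch, and the method name mark_execution_graph_as_stale comes from the comment above rather than from an existing public API:

```python
class StaleAwareWrapper:
    """Sketch of a wrapper with an explicit staleness flag (hypothetical)."""

    def __init__(self, module):
        self._original_module = module
        self._graph_is_stale = True  # force an export before the first forward
        self.export_count = 0

    def mark_execution_graph_as_stale(self):
        # For changes made through paths the wrapper cannot observe.
        self._graph_is_stale = True

    def forward(self, *args, **kwargs):
        if self._graph_is_stale:
            self.export_count += 1  # stands in for the actual re-export
            self._graph_is_stale = False
        return self._original_module(*args, **kwargs)
```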

@thiagocrepaldi
Contributor

We are looking into recursive ORTModule soon. Will this PR handle such cases, where one ORTModule wraps another in one of its layers?

@baijumeswani
Contributor Author

We are looking into recursive ORTModule soon. Will this PR handle such cases, where one ORTModule wraps another in one of its layers?

It should work with recursive ORTModules; I don't see anything that prevents it. But I think it would be hard to assess and predict the difficulties this PR brings to recursive ORTModules without more details/implementation.

Contributor

@thiagocrepaldi thiagocrepaldi left a comment


Another thing to consider: what if there is a private variable used in user-defined functions?
def custom_functions(): self.state = update
IIUC, the current implementation only covers stateless user-defined functions?

Do you mean for auto-detecting an attribute change and marking the model for re-export? Then yes: we cannot always detect that the user has made changes to the model, especially if the changes are made through a path that ORTModule cannot control. One way to ensure that the model gets re-exported is to expose a method (mark_execution_graph_as_stale) that users explicitly call whenever they make a change to the model.

Just confirming we are not publicly exposing an API for that purpose. ORTModule should assume attribute changes (method or non-method) mark the model as stale automatically.

Contributor

@thiagocrepaldi thiagocrepaldi left a comment


Great work!
Setting/retrieving attributes based on module initialization status, as opposed to internal names, simplified understanding (although code-wise they were almost the same)!
Depending on _torch_module in an assert was nice too, as it will raise if the attribute is renamed instead of silently failing a check!

My only question was regarding __getattr__'s else branch, which didn't have a return on it: it called super() and reached the end of the method without returning. Maybe a typo?
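The behavior this question points at can be demonstrated in isolation with toy classes (not the ORTModule code): if the else branch calls super().__getattr__ without returning its result, the lookup silently evaluates to None.

```python
class Base:
    def __getattr__(self, name):
        return f"base:{name}"


class MissingReturn(Base):
    def __getattr__(self, name):
        # Bug: the result of the parent lookup is discarded, so this
        # method falls off the end and the attribute evaluates to None.
        super().__getattr__(name)


class WithReturn(Base):
    def __getattr__(self, name):
        # Correct: propagate the parent lookup's result.
        return super().__getattr__(name)
```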

thiagocrepaldi
thiagocrepaldi previously approved these changes Aug 30, 2021
@baijumeswani baijumeswani changed the title Auto forward attribute lookups to the user's model through ORTModule Auto forward non method attribute lookups to the user's model and bind custom methods to ORTModule Sep 1, 2021
thiagocrepaldi
thiagocrepaldi previously approved these changes Sep 2, 2021
Contributor

@thiagocrepaldi thiagocrepaldi left a comment


LGTM

@thiagocrepaldi
Contributor

thiagocrepaldi commented Sep 2, 2021

Avoid adding unrelated changes to important PRs.

thiagocrepaldi
thiagocrepaldi previously approved these changes Sep 2, 2021
@baijumeswani
Contributor Author

baijumeswani commented Sep 2, 2021

Avoid adding unrelated changes to important PRs.

This pull request introduces auto-signaling that the module must be re-exported. We need a way for users to disable the auto re-export using the skip check; the skip-check addition gives users that control.
That said, maybe I could have put it in a follow-up PR.

@baijumeswani baijumeswani merged commit 0cc2909 into master Sep 3, 2021
@baijumeswani baijumeswani deleted the bmeswani/user-defined-methods branch September 3, 2021 15:25