Make dynamo wrapped modules work with save_pretrained #2726
I don't see another way around it either.
# Dynamo wraps the original model and changes the class.
# Is there a principled way to obtain the original class?
if "_dynamo.eval_frame" in str(model_cls):
I think this is a bit cleaner:
- if "_dynamo.eval_frame" in str(model_cls):
+ if is_torch_version(">=", "2.0.0") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule):
@pcuenca quick question - why don't we go for this API? I think isinstance is a better check here than string matching.
Because torch._dynamo is not available on all platforms, just on CUDA builds.
How about this then:
- if "_dynamo.eval_frame" in str(model_cls):
+ if is_torch_version(">=", "2.0.0") and hasattr(torch, "_dynamo") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule):
It's a bit too verbose and distracting for my taste, especially if we are going to use it in multiple places. We could create a helper function somewhere. The other issue is that using a private symbol makes me uneasy: the class hierarchy could change in future revisions and make the code crash. That is not necessarily bad, though, because with a string check we'd fail silently instead.
I see! Tbh, both ways are fine by me - I do prefer to avoid silent errors and generally don't like string "in" testing for classes as this is a very brittle test in my opinion.
Would prefer a helper function here, but also ok to leave as is as it's an edge case.
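The helper discussed above could look roughly like the sketch below. It is not necessarily the code that was merged: the name `is_compiled_module` and the optional-import guard are assumptions made for illustration, and the version check is replaced by a plain `hasattr` test so the snippet does not depend on `is_torch_version`. It deliberately accesses the private `OptimizedModule` class directly, so a future rename would raise instead of failing silently.

```python
try:
    import torch  # optional here: the sketch still runs without PyTorch
except ImportError:
    torch = None


def is_compiled_module(module) -> bool:
    """Return True if `module` is a dynamo wrapper produced by torch.compile.

    Hypothetical helper for illustration only.
    """
    # No PyTorch installed, or a build that ships without `_dynamo`:
    # nothing can have been compiled, so report False.
    if torch is None or not hasattr(torch, "_dynamo"):
        return False
    # Private class: if it is renamed upstream this raises AttributeError,
    # which is the "fail loudly" behavior preferred over a string match.
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
```

Call sites would then read `if is_compiled_module(sub_model): ...` instead of repeating the long condition in several places.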
Looks nice - thanks for the fix! Let's maybe do the same in modeling_utils.py. I would just recommend using the isinstance(...) check instead of the string-matching check - I think that's a bit less brittle.
The way you retrieve the original class here is good IMO.
I created a simple test that surfaced additional problems in the configuration (the wrong class was saved to config.json).
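To illustrate the config.json problem, here is a minimal sketch that uses stand-in classes rather than real PyTorch or diffusers objects. `TinyModel`, `FakeOptimizedModule`, `original_class`, and `save_config` are all hypothetical names; the `_orig_mod` attribute and the `_class_name` key are assumptions about how the wrapper and the config are laid out.

```python
import json
import os
import tempfile


class TinyModel:
    """Hypothetical stand-in for a pipeline component such as a UNet."""


# Stand-in wrapper whose class path mimics the dynamo wrapper's module path.
FakeOptimizedModule = type(
    "OptimizedModule", (), {"__module__": "torch._dynamo.eval_frame"}
)


def original_class(component):
    """Return the class to record in the config, unwrapping dynamo-like wrappers."""
    cls = component.__class__
    if "_dynamo.eval_frame" in str(cls):
        # Assumption: the wrapper keeps the wrapped module in `_orig_mod`.
        cls = component._orig_mod.__class__
    return cls


def save_config(component, directory):
    """Write a config.json that names the original class, not the wrapper."""
    path = os.path.join(directory, "config.json")
    with open(path, "w") as f:
        json.dump({"_class_name": original_class(component).__name__}, f)
    return path


wrapped = FakeOptimizedModule()
wrapped._orig_mod = TinyModel()

with tempfile.TemporaryDirectory() as tmp:
    with open(save_config(wrapped, tmp)) as f:
        saved = json.load(f)
```

Without the unwrapping step, the saved class name would be that of the wrapper, and loading the saved model back would look for the wrong class.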
Thanks!
    register_dict = {name: (None, None)}
else:
    # register the original module, not the dynamo compiled one
    if "_dynamo.eval_frame" in str(module.__class__):
Think this check is better:
if is_torch_version(">=", "2.0.0") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule):
As commented above, torch._dynamo does not exist on all platforms.
Ah sorry, now I get it
# I didn't find a public API to get the original class.
sub_model = passed_class_obj[name]
model_cls = sub_model.__class__
if "_dynamo.eval_frame" in str(model_cls):
Think this check is better:
if is_torch_version(">=", "2.0.0") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule):
Can we apply the same changes to the model save_pretrained function?
Thanks for clarifying the design here. Could we then maybe go with: #2726 (comment) or do you prefer how it is now?
It works, actually; at first I thought it wouldn't but I was wrong. When
@patrickvonplaten so the only thing missing here would be to decide whether to use a helper function to verify the class. I'm undecided because I think that testing for a private class that might change is not much better than testing for a name, but we might want to make it fail so we can fix it when it changes. What do you think?
Thanks! Summarized my thoughts here: https://github.com/huggingface/diffusers/pull/2726/files#r1146130475
Good to go whenever you want :-)
I used the helper function and I think it's ok, can one of you give it a final look before merge? @sayakpaul @patrickvonplaten
Nice! Looks clean to me :)
* Workaround for saving dynamo-wrapped models.
* Accept suggestion from code review
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Apply workaround when overriding pipeline components.
* Ensure the correct config.json is saved to disk. Instead of the dynamo class.
* Save correct module (not compiled one)
* Add test
* style
* fix docstrings
* Go back to using string comparisons. PyTorch CPU does not have _dynamo.
* Simple test for save_pretrained of compiled models.
* Helper function to test whether module is compiled.
---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Experimental, not sure if there's a better way to do it. We'd also need to do something similar in modeling_utils.py. Addresses #2709.