New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make UNet2DConditionOutput
pickle-able
#3857
Make UNet2DConditionOutput
pickle-able
#3857
Conversation
The documentation is not available anymore as the PR was closed or merged. |
UNet2DConditionOutput
pickle-able
UNet2DConditionOutput
pickle-ableUNet2DConditionOutput
pickle-able
@patrickvonplaten @anton-l can I please get a review on this? |
@patrickvonplaten @anton-l any updates? |
This change is ok for me! Could we add a test here that shows how we can now pickle the output? |
@patrickvonplaten I gave adding a unit test a try. Let me know if I should change it or put it somewhere else. |
@patrickvonplaten any updates? |
@patrickvonplaten this is currently blocking ONNX Runtime integration with Diffusers. Can you please provide an update? Thank you. |
Ok for me! |
@sayakpaul @pcuenca could you maybe also quickly check? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the change! PR looks great to me except for https://github.com/huggingface/diffusers/pull/3857/files#r1252260499.
@sayakpaul can you please review/merge? Thank you. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for iterating!
* add default to unet output to prevent it from being a required arg * add unit test * make style * adjust unit test * mark as fast test * adjust assert statement in test --------- Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net> Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
* add default to unet output to prevent it from being a required arg * add unit test * make style * adjust unit test * mark as fast test * adjust assert statement in test --------- Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net> Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This PR addresses previous concerns that the output of the UNet's forward pass is not copy-able. The root cause appears to be because copy fails on collections.OrderedDict dataclass with required args. The solution presented sets a default value for
sample
such that is it no longer a required parameter of the output class while still erroring when missing since the default setting isNone
(link to similar solution for different model).Reproduction Instructions: