Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make UNet2DConditionOutput pickle-able #3857

Merged
merged 6 commits into from Jul 6, 2023

Conversation

prathikr
Copy link
Contributor

@prathikr prathikr commented Jun 22, 2023

This PR addresses previous concerns that the output of the UNet's forward pass is not copy-able. The root cause appears to be because copy fails on collections.OrderedDict dataclass with required args. The solution presented sets a default value for sample such that is it no longer a required parameter of the output class while still erroring when missing since the default setting is None (link to similar solution for different model).

Reproduction Instructions:

from diffusers.utils import BaseOutput
from dataclasses import dataclass
import copy

@dataclass
class NetParams(BaseOutput):
    sample: torch.FloatTensor

m = NetParams(sample=torch.randn(1, 10))
n = copy.copy(m)

@prathikr prathikr changed the title add default to unet output to prevent it from being a required arg [WIP] add default to unet output to prevent it from being a required arg Jun 22, 2023
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 22, 2023

The documentation is not available anymore as the PR was closed or merged.

@prathikr prathikr changed the title [WIP] add default to unet output to prevent it from being a required arg [WIP] Make UNet2DConditionOutput pickle-able Jun 22, 2023
@prathikr prathikr changed the title [WIP] Make UNet2DConditionOutput pickle-able Make UNet2DConditionOutput pickle-able Jun 22, 2023
@prathikr prathikr marked this pull request as ready for review June 22, 2023 23:15
@prathikr
Copy link
Contributor Author

@patrickvonplaten @anton-l can I please get a review on this?

@prathikr
Copy link
Contributor Author

@patrickvonplaten @anton-l any updates?

@patrickvonplaten
Copy link
Contributor

This change is ok for me! Could we add a test here that shows how we can now pickle the output?

@prathikr
Copy link
Contributor Author

prathikr commented Jun 28, 2023

@patrickvonplaten I gave adding a unit test a try. Let me know if I should change it or put it somewhere else.

@prathikr
Copy link
Contributor Author

@patrickvonplaten any updates?

@prathikr
Copy link
Contributor Author

prathikr commented Jul 3, 2023

@patrickvonplaten this is currently blocking ONNX Runtime integration with Diffusers. Can you please provide an update? Thank you.

@patrickvonplaten
Copy link
Contributor

Ok for me!

@patrickvonplaten
Copy link
Contributor

@sayakpaul @pcuenca could you maybe also quickly check?

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change! PR looks great to me except for https://github.com/huggingface/diffusers/pull/3857/files#r1252260499.

@prathikr prathikr requested a review from sayakpaul July 5, 2023 18:08
@prathikr
Copy link
Contributor Author

prathikr commented Jul 6, 2023

@sayakpaul can you please review/merge? Thank you.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for iterating!

@sayakpaul sayakpaul merged commit de14261 into huggingface:main Jul 6, 2023
8 checks passed
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* add default to unet output to prevent it from being a required arg

* add unit test

* make style

* adjust unit test

* mark as fast test

* adjust assert statement in test

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add default to unet output to prevent it from being a required arg

* add unit test

* make style

* adjust unit test

* mark as fast test

* adjust assert statement in test

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants