-
Notifications
You must be signed in to change notification settings - Fork 322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the issue (#2102): Add cast to float32 from float16 into Stable D… #2124
Fix the issue (#2102): Add cast to float32 from float16 into Stable D… #2124
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
@@ -235,6 +235,8 @@ def generate_image( | |||
+ unconditional_guidance_scale * (latent - unconditional_latent) | |||
) | |||
a_t, a_prev = alphas[index], alphas_prev[index] | |||
if latent.dtype == "float16": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should be able to unconditionally cast latent to the dtype of latent_prev, e.g. ops.cast(latent, latent_prev.dtype)
.
Is there a reason that won't work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ianstenbit
Thank you very much for your comment.
I think you are right!
I will try your advice and modify the code!!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ianstenbit
I tried your suggestion and modified the code!
Thank you!!
If there is no problem, please approve my modification.
…o Stable Diffusion
cc214e6
to
3c2f897
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good -- thanks for the PR!
/gcbrun |
…o Stable Diffusion (keras-team#2124)
…o Stable Diffusion (keras-team#2124)
What does this PR do?
This PR fixed Issue #2102 (InvalidArgumentError occurred when I implemeted image generation using mixed float16) .
Fixes # (issue)
#2102
Who can review?
@ianstenbit
@jbischof
Someone who know this problem.