-
Notifications
You must be signed in to change notification settings - Fork 6.4k
Fix scheduler type mismatch #3041
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
When doing generation manually and using guidance_scale as a static argument.
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) | ||
|
||
# Ensure model output will be `float32` before going into the scheduler | ||
guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
An alternative would be to simply cast to the correct type before invoking the step function. I did it this way for symmetry with what happens here: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L369
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Works for me!
When doing generation manually and using guidance_scale as a static argument.
When doing generation manually and using guidance_scale as a static argument.
When doing generation manually and using guidance_scale as a static argument.
When doing generation manually and using guidance_scale as a static argument.
When doing generation manually instead of using
__call__
, and using guidance_scale as a static argument.See #3039 for a discussion.