Skip to content

Conversation

pcuenca
Copy link
Member

@pcuenca pcuenca commented Apr 10, 2023

When doing generation manually instead of using __call__, and using guidance_scale as a static argument.

See #3039 for a discussion.

When doing generation manually and using guidance_scale as a static
argument.
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

# Ensure model output will be `float32` before going into the scheduler
guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative would be to simply cast to the correct type before invoking the step function. I did it this way for symmetry with what happens here: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L369

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Apr 10, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works for me!

@pcuenca pcuenca merged commit 526827c into main Apr 11, 2023
@pcuenca pcuenca deleted the 3039-flax-type-mismatch branch April 11, 2023 21:20
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
When doing generation manually and using guidance_scale as a static
argument.
dg845 pushed a commit to dg845/diffusers that referenced this pull request May 6, 2023
When doing generation manually and using guidance_scale as a static
argument.
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
When doing generation manually and using guidance_scale as a static
argument.
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
When doing generation manually and using guidance_scale as a static
argument.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants