From 99687396f33379b84d2303d0ce43ab33c21c6118 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 10 Apr 2023 18:39:23 +0000 Subject: [PATCH] Fix scheduler type mismatch When doing generation manually and using guidance_scale as a static argument. --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 066d1e99acaa..31a62c6ccc16 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -245,6 +245,9 @@ def _generate( negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + latents_shape = ( batch_size, self.unet.in_channels,