[SDXL Flax] fix SDXL flax init#5187

Merged
patrickvonplaten merged 3 commits into
mainfrom
fix_sdxl_flax_init
Sep 26, 2023

@patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Sep 26, 2023

Make sure Flax init is correct for different model sizes

is_refiner = 5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim == self.config.projection_class_embeddings_input_dim
num_micro_conditions = 5 if is_refiner else 6

text_embeds_dim = self.config.projection_class_embeddings_input_dim - (num_micro_conditions * self.config.addition_time_embed_dim)
Contributor Author

2816 - 6 * 256 = 1280 for base and 2560 - 5 * 256 = 1280 for refiner
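The arithmetic above can be checked with a minimal sketch that mirrors the condition from the diff. The `UNetDims` dataclass and the `text_embeds_dim` helper are hypothetical names used only for illustration. The values 256, 2816 and 2560 come from the comment above; the `cross_attention_dim` values (2048 for base, 1280 for refiner) are assumptions about the released SDXL configs.

```python
from dataclasses import dataclass


@dataclass
class UNetDims:
    # Hypothetical stand-in for the relevant fields of the UNet config.
    addition_time_embed_dim: int
    cross_attention_dim: int
    projection_class_embeddings_input_dim: int


def text_embeds_dim(cfg: UNetDims) -> int:
    # Same test as in the diff: the refiner uses 5 micro-conditions, the base model 6.
    is_refiner = (
        5 * cfg.addition_time_embed_dim + cfg.cross_attention_dim
        == cfg.projection_class_embeddings_input_dim
    )
    num_micro_conditions = 5 if is_refiner else 6
    return (
        cfg.projection_class_embeddings_input_dim
        - num_micro_conditions * cfg.addition_time_embed_dim
    )


# cross_attention_dim values are assumed, not taken from this PR.
base = UNetDims(256, 2048, 2816)
refiner = UNetDims(256, 1280, 2560)

print(text_embeds_dim(base), text_embeds_dim(refiner))  # 1280 1280
```

Both model sizes resolve to the same pooled text embedding width, which is why a single init path can serve both.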

Member

@pcuenca pcuenca left a comment

Awesome! Thanks a lot! 🙌

)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler_state.init_noise_sigma
Contributor Author

We need to set the init_noise_sigma after scaling

Member

Oh yes, great catch!
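For readers unfamiliar with the step under review, the scaling line can be sketched in isolation. This is only an illustration, not the pipeline code: it uses NumPy rather than `jax.numpy` to stay lightweight, `scale_initial_latents` is a hypothetical helper, and the `init_noise_sigma` value of 14.6 is a placeholder rather than a number taken from this PR.

```python
import numpy as np


def scale_initial_latents(latents: np.ndarray, init_noise_sigma: float) -> np.ndarray:
    # Multiply the starting latents by the standard deviation the scheduler expects.
    return latents * init_noise_sigma


# Deterministic dummy latents, in the spirit of the comparison scripts in this thread.
batch_size, ch, height, width = 1, 4, 32, 32
num_elems = batch_size * ch * height * width
latents = (np.arange(num_elems) / num_elems).reshape(batch_size, ch, height, width)

init_noise_sigma = 14.6  # placeholder value, assumed for illustration only
scaled = scale_initial_latents(latents, init_noise_sigma)

print(scaled.shape, float(np.abs(scaled).sum()))
```

Because the operation is a plain elementwise multiply, using a stale or wrong sigma shifts every latent by the same factor, which would show up directly in a PyTorch vs. Flax comparison.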

@patrickvonplaten
Contributor Author

@pcuenca I'm now getting almost identical results on CPU with PyTorch vs. Flax for dummy inputs:

Flax

from diffusers import FlaxStableDiffusionXLPipeline
import numpy as np
import jax.numpy as jnp
import jax

path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

pipe, params = FlaxStableDiffusionXLPipeline.from_pretrained(path)

prompt = "An astronaut riding a green horse on Mars"
negative_prompt = "ugly"
steps = 3

batch_size, height, width, ch = 1, 32, 32, 4
num_elems = batch_size * height * width * ch
rng = jax.random.PRNGKey(0)
latents = (jnp.arange(num_elems) / num_elems)[:, None, None, None].reshape(batch_size, ch, width, height)

print("latents", np.abs(np.asarray(latents)).sum())

prompt_embeds = pipe.prepare_inputs(prompt)
neg_prompt_ids = pipe.prepare_inputs(negative_prompt)

image = pipe(prompt_embeds, params, rng, neg_prompt_ids=neg_prompt_ids, latents=latents, num_inference_steps=3, output_type="np").images[0]

print(np.abs(np.asarray(image)).sum())

PT

import torch
import numpy as np
from diffusers import StableDiffusionXLPipeline

path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

pipe = StableDiffusionXLPipeline.from_pretrained(path)
pipe.unet.set_default_attn_processor()

prompt = "An astronaut riding a green horse on Mars"
neg_prompt = "ugly"
steps = 3

batch_size, height, width, ch = 1, 32, 32, 4
num_elems = batch_size * height * width * ch
latents = (torch.arange(num_elems) / num_elems)[:, None, None, None].reshape(batch_size, ch, width, height)
print("latents", latents.abs().sum())

image = pipe(prompt, negative_prompt=neg_prompt, latents=latents, num_inference_steps=3, output_type="np", guidance_scale=7.5).images[0]

print(np.abs(image).sum())

Getting:
PT: 6237.967
Flax: 6237.9585
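To put "almost identical" into numbers, the two sums reported above can be compared directly. This is a minimal sketch; the `rel_tol=1e-5` threshold is an assumed tolerance chosen for illustration, not one used by the project's tests.

```python
import math

# Absolute-value sums reported above for the two backends.
pt_sum = 6237.967
flax_sum = 6237.9585

abs_diff = abs(pt_sum - flax_sum)
rel_diff = abs_diff / abs(pt_sum)

print(f"abs diff: {abs_diff:.4f}, rel diff: {rel_diff:.2e}")

# Assumed tolerance; float32 pipelines on different backends rarely match bit for bit.
matches = math.isclose(pt_sum, flax_sum, rel_tol=1e-5)
print("within tolerance:", matches)
```

The relative difference is on the order of 1e-6. Note that a sum of absolute values is a coarse checksum: it can hide per-pixel differences that cancel out, so an elementwise comparison would be a stricter check.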

@patrickvonplaten patrickvonplaten merged commit c82f7ba into main Sep 26, 2023
@patrickvonplaten patrickvonplaten deleted the fix_sdxl_flax_init branch September 26, 2023 17:55
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* fix SDXL flax init

* finish

* Fix
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* fix SDXL flax init

* finish

* Fix