Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/diffusers/models/unet_2d_condition_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,18 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:

added_cond_kwargs = None
if self.addition_embed_type == "text_time":
# TODO: how to get this from the config? It's no longer cross_attention_dim
text_embeds_dim = 1280
# we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
# or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
is_refiner = (
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
== self.config.projection_class_embeddings_input_dim
)
num_micro_conditions = 5 if is_refiner else 6

text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
num_micro_conditions * self.config.addition_time_embed_dim
)

time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
added_cond_kwargs = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,15 @@ def _generate(
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * params["scheduler"].init_noise_sigma

# Prepare scheduler state
scheduler_state = self.scheduler.set_timesteps(
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler_state.init_noise_sigma
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to set the init_noise_sigma atfter scaling

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, great catch!


added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

# Denoising loop
Expand Down