Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions keras_hub/src/models/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ def generate(
self,
inputs,
num_steps,
guidance_scale,
strength,
guidance_scale=None,
seed=None,
):
"""Generate image based on the provided `inputs`.
Expand Down Expand Up @@ -313,30 +313,36 @@ def generate(
- A `tf.data.Dataset` with `"images"`, `"prompts"` and/or
`"negative_prompts"` keys.
num_steps: int. The number of diffusion steps to take.
guidance_scale: float. The classifier free guidance scale defined in
[Classifier-Free Diffusion Guidance](
https://arxiv.org/abs/2207.12598). A higher scale encourages
generating images more closely related to the prompts, typically
at the cost of lower image quality.
strength: float. Indicates the extent to which the reference
`images` are transformed. Must be between `0.0` and `1.0`. When
`strength=1.0`, `images` is essentially ignore and added noise
is maximum and the denoising process runs for the full number of
iterations specified in `num_steps`.
guidance_scale: Optional float. The classifier free guidance scale
defined in [Classifier-Free Diffusion Guidance](
https://arxiv.org/abs/2207.12598). A higher scale encourages
generating images more closely related to the prompts, typically
at the cost of lower image quality. Note that some models don't
utilize classifier-free guidance.
seed: optional int. Used as a random seed.
"""
num_steps = int(num_steps)
guidance_scale = float(guidance_scale)
strength = float(strength)
guidance_scale = (
float(guidance_scale) if guidance_scale is not None else None
)
if strength < 0.0 or strength > 1.0:
raise ValueError(
"`strength` must be between `0.0` and `1.0`. "
f"Received strength={strength}."
)
if guidance_scale is not None and guidance_scale > 1.0:
guidance_scale = ops.convert_to_tensor(float(guidance_scale))
else:
guidance_scale = None
starting_step = int(num_steps * (1.0 - strength))
starting_step = ops.convert_to_tensor(starting_step, "int32")
num_steps = ops.convert_to_tensor(num_steps, "int32")
guidance_scale = ops.convert_to_tensor(guidance_scale)
num_steps = ops.convert_to_tensor(int(num_steps), "int32")

# Check `inputs` format.
required_keys = ["images", "prompts"]
Expand Down
21 changes: 14 additions & 7 deletions keras_hub/src/models/inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,8 @@ def generate(
self,
inputs,
num_steps,
guidance_scale,
strength,
guidance_scale=None,
seed=None,
):
"""Generate image based on the provided `inputs`.
Expand Down Expand Up @@ -406,26 +406,33 @@ def generate(
- A `tf.data.Dataset` with `"images"`, `"masks"`, `"prompts"`
and/or `"negative_prompts"` keys.
num_steps: int. The number of diffusion steps to take.
guidance_scale: float. The classifier free guidance scale defined in
[Classifier-Free Diffusion Guidance](
https://arxiv.org/abs/2207.12598). A higher scale encourages
generating images more closely related to the prompts, typically
at the cost of lower image quality.
strength: float. Indicates the extent to which the reference
`images` are transformed. Must be between `0.0` and `1.0`. When
`strength=1.0`, `images` is essentially ignore and added noise
is maximum and the denoising process runs for the full number of
iterations specified in `num_steps`.
guidance_scale: Optional float. The classifier free guidance scale
defined in [Classifier-Free Diffusion Guidance](
https://arxiv.org/abs/2207.12598). A higher scale encourages
generating images more closely related to the prompts, typically
at the cost of lower image quality. Note that some models don't
utilize classifier-free guidance.
seed: optional int. Used as a random seed.
"""
num_steps = int(num_steps)
guidance_scale = float(guidance_scale)
strength = float(strength)
guidance_scale = (
float(guidance_scale) if guidance_scale is not None else None
)
if strength < 0.0 or strength > 1.0:
raise ValueError(
"`strength` must be between `0.0` and `1.0`. "
f"Received strength={strength}."
)
if guidance_scale is not None and guidance_scale > 1.0:
guidance_scale = ops.convert_to_tensor(guidance_scale)
else:
guidance_scale = None
starting_step = int(num_steps * (1.0 - strength))
starting_step = ops.convert_to_tensor(starting_step, "int32")
num_steps = ops.convert_to_tensor(num_steps, "int32")
Expand Down
41 changes: 41 additions & 0 deletions keras_hub/src/models/stable_diffusion_3/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ class DismantledBlock(layers.Layer):
mlp_ratio: float. The expansion ratio of `MLP`.
use_projection: bool. Whether to use an attention projection layer at
the end of the block.
qk_norm: Optional str. Whether to normalize the query and key tensors.
Available options are `None` and `"rms_norm"`. Defaults to `None`.
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
including `name`, `dtype` etc.
"""
Expand All @@ -364,13 +366,15 @@ def __init__(
hidden_dim,
mlp_ratio=4.0,
use_projection=True,
qk_norm=None,
**kwargs,
):
super().__init__(**kwargs)
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.mlp_ratio = mlp_ratio
self.use_projection = use_projection
self.qk_norm = qk_norm

head_dim = hidden_dim // num_heads
self.head_dim = head_dim
Expand All @@ -391,6 +395,18 @@ def __init__(
self.attention_qkv = layers.Dense(
hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
)
if qk_norm is not None and qk_norm == "rms_norm":
self.q_norm = layers.LayerNormalization(
epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
)
self.k_norm = layers.LayerNormalization(
epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
)
elif qk_norm is not None:
raise NotImplementedError(
"Supported `qk_norm` are `'rms_norm'` and `None`. "
f"Received: qk_norm={qk_norm}."
)
if use_projection:
self.attention_proj = layers.Dense(
hidden_dim, dtype=self.dtype_policy, name="attention_proj"
Expand All @@ -413,6 +429,10 @@ def __init__(
def build(self, inputs_shape, timestep_embedding):
self.ada_layer_norm.build(inputs_shape, timestep_embedding)
self.attention_qkv.build(inputs_shape)
if self.qk_norm is not None:
# [batch_size, sequence_length, num_heads, head_dim]
self.q_norm.build([None, None, self.num_heads, self.head_dim])
self.k_norm.build([None, None, self.num_heads, self.head_dim])
if self.use_projection:
self.attention_proj.build(inputs_shape)
self.norm2.build(inputs_shape)
Expand All @@ -435,6 +455,9 @@ def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
)
q, k, v = ops.unstack(qkv, 3, axis=2)
if self.qk_norm is not None:
q = self.q_norm(q, training=training)
k = self.k_norm(k, training=training)
return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
else:
x = self.ada_layer_norm(
Expand All @@ -445,6 +468,9 @@ def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
)
q, k, v = ops.unstack(qkv, 3, axis=2)
if self.qk_norm is not None:
q = self.q_norm(q, training=training)
k = self.k_norm(k, training=training)
return (q, k, v)

def _compute_post_attention(
Expand Down Expand Up @@ -494,6 +520,7 @@ def get_config(self):
"hidden_dim": self.hidden_dim,
"mlp_ratio": self.mlp_ratio,
"use_projection": self.use_projection,
"qk_norm": self.qk_norm,
}
)
return config
Expand All @@ -513,6 +540,8 @@ class MMDiTBlock(layers.Layer):
mlp_ratio: float. The expansion ratio of `MLP`.
use_context_projection: bool. Whether to use an attention projection
layer at the end of the context block.
qk_norm: Optional str. Whether to normalize the query and key tensors.
Available options are `None` and `"rms_norm"`. Defaults to `None`.
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
including `name`, `dtype` etc.

Expand All @@ -527,13 +556,15 @@ def __init__(
hidden_dim,
mlp_ratio=4.0,
use_context_projection=True,
qk_norm=None,
**kwargs,
):
super().__init__(**kwargs)
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.mlp_ratio = mlp_ratio
self.use_context_projection = use_context_projection
self.qk_norm = qk_norm

head_dim = hidden_dim // num_heads
self.head_dim = head_dim
Expand All @@ -544,6 +575,7 @@ def __init__(
hidden_dim=hidden_dim,
mlp_ratio=mlp_ratio,
use_projection=True,
qk_norm=qk_norm,
dtype=self.dtype_policy,
name="x_block",
)
Expand All @@ -552,6 +584,7 @@ def __init__(
hidden_dim=hidden_dim,
mlp_ratio=mlp_ratio,
use_projection=use_context_projection,
qk_norm=qk_norm,
dtype=self.dtype_policy,
name="context_block",
)
Expand Down Expand Up @@ -629,6 +662,7 @@ def get_config(self):
"hidden_dim": self.hidden_dim,
"mlp_ratio": self.mlp_ratio,
"use_context_projection": self.use_context_projection,
"qk_norm": self.qk_norm,
}
)
return config
Expand Down Expand Up @@ -705,6 +739,9 @@ class MMDiT(Backbone):
latent_shape: tuple. The shape of the latent image.
context_shape: tuple. The shape of the context.
pooled_projection_shape: tuple. The shape of the pooled projection.
qk_norm: Optional str. Whether to normalize the query and key tensors in
the intermediate blocks. Available options are `None` and
`"rms_norm"`. Defaults to `None`.
data_format: `None` or str. If specified, either `"channels_last"` or
`"channels_first"`. The ordering of the dimensions in the
inputs. `"channels_last"` corresponds to inputs with shape
Expand All @@ -729,6 +766,7 @@ def __init__(
latent_shape=(64, 64, 16),
context_shape=(None, 4096),
pooled_projection_shape=(2048,),
qk_norm=None,
data_format=None,
dtype=None,
**kwargs,
Expand Down Expand Up @@ -782,6 +820,7 @@ def __init__(
hidden_dim,
mlp_ratio,
use_context_projection=not (i == num_layers - 1),
qk_norm=qk_norm,
dtype=dtype,
name=f"joint_block_{i}",
)
Expand Down Expand Up @@ -851,6 +890,7 @@ def __init__(
self.latent_shape = latent_shape
self.context_shape = context_shape
self.pooled_projection_shape = pooled_projection_shape
self.qk_norm = qk_norm

def get_config(self):
config = super().get_config()
Expand All @@ -865,6 +905,7 @@ def get_config(self):
"latent_shape": self.latent_shape,
"context_shape": self.context_shape,
"pooled_projection_shape": self.pooled_projection_shape,
"qk_norm": self.qk_norm,
}
)
return config
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ class StableDiffusion3Backbone(Backbone):
transformer in MMDiT.
mmdit_position_size: int. The size of the height and width for the
position embedding in MMDiT.
mmdit_qk_norm: Optional str. Whether to normalize the query and key
tensors for each transformer in MMDiT. Available options are `None`
and `"rms_norm"`. Typically, this is set to `None` for 3.0 version
and to `"rms_norm" for 3.5 version.
vae: The VAE used for transformations between pixel space and latent
space.
clip_l: The CLIP text encoder for encoding the inputs.
Expand Down Expand Up @@ -248,6 +252,7 @@ class StableDiffusion3Backbone(Backbone):
mmdit_hidden_dim=256,
mmdit_depth=4,
mmdit_position_size=192,
mmdit_qk_norm=None,
vae=vae,
clip_l=clip_l,
clip_g=clip_g,
Expand All @@ -262,6 +267,7 @@ def __init__(
mmdit_num_layers,
mmdit_num_heads,
mmdit_position_size,
mmdit_qk_norm,
vae,
clip_l,
clip_g,
Expand Down Expand Up @@ -312,6 +318,7 @@ def __init__(
latent_shape=latent_shape,
context_shape=context_shape,
pooled_projection_shape=pooled_projection_shape,
qk_norm=mmdit_qk_norm,
data_format=data_format,
dtype=dtype,
name="diffuser",
Expand Down Expand Up @@ -446,6 +453,7 @@ def __init__(
self.mmdit_num_layers = mmdit_num_layers
self.mmdit_num_heads = mmdit_num_heads
self.mmdit_position_size = mmdit_position_size
self.mmdit_qk_norm = mmdit_qk_norm
self.latent_channels = latent_channels
self.output_channels = output_channels
self.num_train_timesteps = num_train_timesteps
Expand Down Expand Up @@ -532,17 +540,23 @@ def denoise_step(
embeddings,
step,
num_steps,
guidance_scale,
guidance_scale=None,
):
step = ops.convert_to_tensor(step)
next_step = ops.add(step, 1)
sigma, timestep = self.scheduler(step, num_steps)
next_sigma, _ = self.scheduler(next_step, num_steps)

# Concatenation for classifier-free guidance.
concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
latents, *embeddings, timestep
)
if guidance_scale is not None:
concated_latents, contexts, pooled_projs, timesteps = (
self.cfg_concat(latents, *embeddings, timestep)
)
else:
timesteps = ops.broadcast_to(timestep, ops.shape(latents)[:1])
concated_latents = latents
contexts = embeddings[0]
pooled_projs = embeddings[2]

# Diffusion.
predicted_noise = self.diffuser(
Expand All @@ -556,7 +570,8 @@ def denoise_step(
)

# Classifier-free guidance.
predicted_noise = self.cfg(predicted_noise, guidance_scale)
if guidance_scale is not None:
predicted_noise = self.cfg(predicted_noise, guidance_scale)

# Euler step.
return self.euler_step(latents, predicted_noise, sigma, next_sigma)
Expand All @@ -574,6 +589,7 @@ def get_config(self):
"mmdit_num_layers": self.mmdit_num_layers,
"mmdit_num_heads": self.mmdit_num_heads,
"mmdit_position_size": self.mmdit_position_size,
"mmdit_qk_norm": self.mmdit_qk_norm,
"vae": layers.serialize(self.vae),
"clip_l": layers.serialize(self.clip_l),
"clip_g": layers.serialize(self.clip_g),
Expand Down Expand Up @@ -620,4 +636,9 @@ def from_config(cls, config, custom_objects=None):
config["t5"] = layers.deserialize(
config["t5"], custom_objects=custom_objects
)

# To maintain backward compatibility, we need to ensure that
# `mmdit_qk_norm` is included in the config.
if "mmdit_qk_norm" not in config:
config["mmdit_qk_norm"] = None
return cls(**config)
Loading
Loading