Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion keras_hub/src/models/stable_diffusion_3/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,28 @@ def build(self, inputs_shape, context_shape, timestep_embedding_shape):
self.context_block.build(context_shape, timestep_embedding_shape)

def _compute_attention(self, query, key, value):
batch_size = ops.shape(query)[0]

# Use the fast path when `ops.dot_product_attention` and flash attention
# are available.
if hasattr(ops, "dot_product_attention") and hasattr(
keras.config, "is_flash_attention_enabled"
):
# `ops.dot_product_attention` is slower than the vanilla
# implementation in the tensorflow backend.
encoded = ops.dot_product_attention(
query,
key,
value,
scale=self._inverse_sqrt_key_dim,
flash_attention=keras.config.is_flash_attention_enabled(),
)
return ops.reshape(
encoded, (batch_size, -1, self.num_heads * self.head_dim)
)

# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
batch_size = ops.shape(query)[0]
logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
probs = self.softmax(logits)
Expand Down
Loading