From 0a1f8fcda2ee0f73a595a3f4e1672ffb95eacc76 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Wed, 12 Nov 2025 00:39:42 +0100
Subject: [PATCH 1/2] fix _check_shape for the (batch_size, seq_len, num_heads,
 head_dim) layout and parse DIFFUSERS_ATTN_CHECKS case-insensitively

---
 src/diffusers/models/attention_dispatch.py | 16 +++++++++++-----
 src/diffusers/utils/constants.py           |  2 +-
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 6ecf97701fe8..ad0190d0f19c 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -383,12 +383,18 @@ def _check_shape(
     attn_mask: Optional[torch.Tensor] = None,
     **kwargs,
 ) -> None:
+    # Expected shapes:
+    # query: (batch_size, seq_len_q, num_heads, head_dim)
+    # key: (batch_size, seq_len_kv, num_heads, head_dim)
+    # value: (batch_size, seq_len_kv, num_heads, head_dim)
+    # attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv) 
+    # or (batch_size, num_heads, seq_len_q, seq_len_kv)
     if query.shape[-1] != key.shape[-1]:
-        raise ValueError("Query and key must have the same last dimension.")
-    if query.shape[-2] != value.shape[-2]:
-        raise ValueError("Query and value must have the same second to last dimension.")
-    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
-        raise ValueError("Attention mask must match the key's second to last dimension.")
+        raise ValueError("Query and key must have the same head dimension.")
+    if key.shape[-3] != value.shape[-3]:
+        raise ValueError("Key and value must have the same sequence length.")
+    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
+        raise ValueError("Attention mask must match the key's sequence length.")
 
 
 # ===== Helper functions =====
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 42a53e181034..a18f28606b3e 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -42,7 +42,7 @@
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
 DIFFUSERS_REQUEST_TIMEOUT = 60
 DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
-DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
+DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES

From 0d707c4dfde5f674442d89a1e4d09e621484f164 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Wed, 12 Nov 2025 00:41:25 +0100
Subject: [PATCH 2/2] fix trailing whitespace in the attn_mask shape comment

---
 src/diffusers/models/attention_dispatch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ad0190d0f19c..92a4a6a59936 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -387,7 +387,7 @@ def _check_shape(
     # query: (batch_size, seq_len_q, num_heads, head_dim)
     # key: (batch_size, seq_len_kv, num_heads, head_dim)
     # value: (batch_size, seq_len_kv, num_heads, head_dim)
-    # attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv) 
+    # attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
     # or (batch_size, num_heads, seq_len_q, seq_len_kv)
     if query.shape[-1] != key.shape[-1]:
         raise ValueError("Query and key must have the same head dimension.")
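
Why the old checks were wrong: as the new comment block spells out, the dispatcher passes
tensors in (batch_size, seq_len, num_heads, head_dim) layout, so shape[-2] is num_heads
rather than a sequence length. The old code therefore compared head counts, and a
key/value sequence-length mismatch (or a mask of the wrong width) slipped through. A
minimal sketch of the layout the updated checks assume; the tensor sizes below are
illustrative only, not taken from the patch:

    import torch

    batch_size, num_heads, head_dim = 2, 8, 64
    seq_len_q, seq_len_kv = 16, 24

    # Layout per the new comment block: (batch_size, seq_len, num_heads, head_dim)
    query = torch.randn(batch_size, seq_len_q, num_heads, head_dim)
    key = torch.randn(batch_size, seq_len_kv, num_heads, head_dim)
    value = torch.randn(batch_size, seq_len_kv, num_heads, head_dim)
    attn_mask = torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool)

    # shape[-3] is the sequence length in this layout; shape[-2] is num_heads.
    assert key.shape[-3] == value.shape[-3]      # both equal seq_len_kv
    assert attn_mask.shape[-1] == key.shape[-3]  # mask covers every key position

    # A value tensor with a mismatched sequence length now fails the new check,
    # whereas the old shape[-2] comparison (num_heads vs. num_heads) passed it.
    bad_value = torch.randn(batch_size, seq_len_kv + 1, num_heads, head_dim)
    assert bad_value.shape[-3] != key.shape[-3]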
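
The constants.py change matters because ENV_VARS_TRUE_VALUES holds upper-cased strings,
so a lowercase setting such as DIFFUSERS_ATTN_CHECKS=true was silently read as false;
adding .upper() matches how HF_ENABLE_PARALLEL_LOADING and DIFFUSERS_DISABLE_REMOTE_CODE
are parsed on the neighboring lines. A small before/after sketch, assuming
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} as defined elsewhere in diffusers:

    import os

    ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed to mirror diffusers' definition

    os.environ["DIFFUSERS_ATTN_CHECKS"] = "true"
    old = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES          # "true" not in set -> False
    new = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES  # "TRUE" in set -> True
    print(old, new)  # False True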