From 109d2dd64afc672e94411bc336872fa317a93ca0 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Wed, 29 Oct 2025 23:10:07 -0700
Subject: [PATCH 1/4] ulysses enabling in native attention path

Signed-off-by: Wang, Yi A
---
 src/diffusers/models/attention_dispatch.py | 53 +++++++++++++++++-----
 1 file changed, 41 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ab0d7102ee83..1f2993592c4a 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1538,18 +1538,47 @@ def _native_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Native attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    out = torch.nn.functional.scaled_dot_product_attention(
-        query=query,
-        key=key,
-        value=value,
-        attn_mask=attn_mask,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        scale=scale,
-        enable_gqa=enable_gqa,
-    )
-    out = out.permute(0, 2, 1, 3)
+    if _parallel_config is None:
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+        out = out.permute(0, 2, 1, 3)
+    elif _parallel_config.context_parallel_config.ring_degree == 1:
+        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+        world_size = _parallel_config.context_parallel_config.ulysses_degree
+        group = ulysses_mesh.get_group()
+
+        B, S_Q_LOCAL, H, D = query.shape
+        _, S_KV_LOCAL, _, _ = key.shape
+        H_LOCAL = H // world_size
+        query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
+        query, key, value = (x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+        out = out.reshape(B, H_LOCAL, world_size, S_Q_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        out = _all_to_all_single(out, group)
+        out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
+    else:
+        raise ValueError("Native attention backend does not support context parallelism with ring_degree > 1, you could try to use ulysses Attention instead")
     return out

From 146cd6dd773b27acbbbc85d7ebb27e5bdfda0ca7 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 30 Oct 2025 19:31:02 -0700
Subject: [PATCH 2/4] address review comment

Signed-off-by: Wang, Yi A
---
 src/diffusers/models/attention_dispatch.py | 34 +++++++++++++++++-----
 1 file changed, 26 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 1f2993592c4a..f82e138b5cba 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1556,12 +1556,24 @@ def _native_attention(
         world_size = _parallel_config.context_parallel_config.ulysses_degree
         group = ulysses_mesh.get_group()
 
-        B, S_Q_LOCAL, H, D = query.shape
-        _, S_KV_LOCAL, _, _ = key.shape
-        H_LOCAL = H // world_size
-        query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
-        key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
-        value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        batch_size, seq_len_q_local, num_heads, head_dim = query.shape
+        _, seq_len_kv_local, _, _ = key.shape
+        num_heads_local = num_heads // world_size
+        query = (
+            query.reshape(batch_size, seq_len_q_local, world_size, num_heads_local, head_dim)
+            .permute(2, 1, 0, 3, 4)
+            .contiguous()
+        )
+        key = (
+            key.reshape(batch_size, seq_len_kv_local, world_size, num_heads_local, head_dim)
+            .permute(2, 1, 0, 3, 4)
+            .contiguous()
+        )
+        value = (
+            value.reshape(batch_size, seq_len_kv_local, world_size, num_heads_local, head_dim)
+            .permute(2, 1, 0, 3, 4)
+            .contiguous()
+        )
         query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (query, key, value))
         out = torch.nn.functional.scaled_dot_product_attention(
@@ -1574,11 +1586,17 @@ def _native_attention(
             scale=scale,
             enable_gqa=enable_gqa,
         )
-        out = out.reshape(B, H_LOCAL, world_size, S_Q_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        out = (
+            out.reshape(batch_size, num_heads_local, world_size, seq_len_q_local, head_dim)
+            .permute(2, 1, 0, 3, 4)
+            .contiguous()
+        )
         out = _all_to_all_single(out, group)
         out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
     else:
-        raise ValueError("Native attention backend does not support context parallelism with ring_degree > 1, you could try to use ulysses Attention instead")
+        raise ValueError(
+            "Native attention backend does not support context parallelism with `ring_degree` > 1, try Ulysses Attention instead by specifying `ulysses_degree` > 1."
+        )
     return out

From 0894b52ac15dfb74d2d881ea7db0c9af772bf683 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Fri, 31 Oct 2025 00:22:41 -0700
Subject: [PATCH 3/4] add supports_context_parallel for native attention

Signed-off-by: Wang, Yi A
---
 src/diffusers/models/attention_dispatch.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index f82e138b5cba..62cc6c7ffed1 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1523,6 +1523,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
 @_AttentionBackendRegistry.register(
     AttentionBackendName.NATIVE,
     constraints=[_check_device, _check_shape],
+    supports_context_parallel=True,
 )
 def _native_attention(
     query: torch.Tensor,

From 31b2c545422acbc4d567aec33791446285eee2f7 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Sun, 2 Nov 2025 20:54:03 -0800
Subject: [PATCH 4/4] update templated attention

Signed-off-by: Wang, Yi A
---
 src/diffusers/models/attention_dispatch.py | 138 ++++++++++++++-------
 1 file changed, 94 insertions(+), 44 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 62cc6c7ffed1..c17a3d0ed6ba 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -649,6 +649,86 @@ def _(
 # ===== Helper functions to use attention backends with templated CP autograd functions =====
 
 
+def _native_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    # Native attention does not return_lse
+    if return_lse:
+        raise ValueError("Native attention does not support return_lse=True")
+
+    # used for backward pass
+    if _save_ctx:
+        ctx.save_for_backward(query, key, value)
+        ctx.attn_mask = attn_mask
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.enable_gqa = enable_gqa
+
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out = torch.nn.functional.scaled_dot_product_attention(
+        query=query,
+        key=key,
+        value=value,
+        attn_mask=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        scale=scale,
+        enable_gqa=enable_gqa,
+    )
+    out = out.permute(0, 2, 1, 3)
+
+    return out
+
+
+def _native_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    query, key, value = ctx.saved_tensors
+
+    query.requires_grad_(True)
+    key.requires_grad_(True)
+    value.requires_grad_(True)
+
+    query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out = torch.nn.functional.scaled_dot_product_attention(
+        query=query_t,
+        key=key_t,
+        value=value_t,
+        attn_mask=ctx.attn_mask,
+        dropout_p=ctx.dropout_p,
+        is_causal=ctx.is_causal,
+        scale=ctx.scale,
+        enable_gqa=ctx.enable_gqa,
+    )
+    out = out.permute(0, 2, 1, 3)
+
+    grad_out_t = grad_out.permute(0, 2, 1, 3)
+    grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
+        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
+    )
+
+    grad_query = grad_query_t.permute(0, 2, 1, 3)
+    grad_key = grad_key_t.permute(0, 2, 1, 3)
+    grad_value = grad_value_t.permute(0, 2, 1, 3)
+
+    return grad_query, grad_key, grad_value
+
+
 # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
 # forward declaration:
 # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -1552,52 +1632,22 @@ def _native_attention(
             enable_gqa=enable_gqa,
         )
         out = out.permute(0, 2, 1, 3)
-    elif _parallel_config.context_parallel_config.ring_degree == 1:
-        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
-        world_size = _parallel_config.context_parallel_config.ulysses_degree
-        group = ulysses_mesh.get_group()
-
-        batch_size, seq_len_q_local, num_heads, head_dim = query.shape
-        _, seq_len_kv_local, _, _ = key.shape
-        num_heads_local = num_heads // world_size
-        query = (
-            query.reshape(batch_size, seq_len_q_local, world_size, num_heads_local, head_dim)
-            .permute(2, 1, 0, 3, 4)
-            .contiguous()
-        )
-        key = (
-            key.reshape(batch_size, seq_len_kv_local, world_size, num_heads_local, head_dim)
-            .permute(2, 1, 0, 3, 4)
-            .contiguous()
-        )
-        value = (
-            value.reshape(batch_size, seq_len_kv_local, world_size, num_heads_local, head_dim)
-            .permute(2, 1, 0, 3, 4)
-            .contiguous()
-        )
-        query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
-        query, key, value = (x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (query, key, value))
-        out = torch.nn.functional.scaled_dot_product_attention(
-            query=query,
-            key=key,
-            value=value,
-            attn_mask=attn_mask,
-            dropout_p=dropout_p,
-            is_causal=is_causal,
-            scale=scale,
-            enable_gqa=enable_gqa,
-        )
-        out = (
-            out.reshape(batch_size, num_heads_local, world_size, seq_len_q_local, head_dim)
-            .permute(2, 1, 0, 3, 4)
-            .contiguous()
-        )
-        out = _all_to_all_single(out, group)
-        out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
     else:
-        raise ValueError(
-            "Native attention backend does not support context parallelism with `ring_degree` > 1, try Ulysses Attention instead by specifying `ulysses_degree` > 1."
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op=_native_attention_forward_op,
+            backward_op=_native_attention_backward_op,
+            _parallel_config=_parallel_config,
         )
+
     return out
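
Reviewer note on PATCH 1/4 and 2/4: the Ulysses path re-shards the inputs twice. Each rank starts with its local sequence shard and all heads; the first all-to-all leaves it with the full sequence but only num_heads // ulysses_degree heads, ordinary SDPA runs on that slice, and the second all-to-all restores the original sequence-sharded, all-heads layout. Below is a minimal single-process sketch of that round trip written against plain PyTorch, not code from the series: fake_all_to_all is a hypothetical stand-in for the dim-0 exchange that the private _all_to_all_single helper is presumed to perform, and the sizes are arbitrary. The final check confirms the result matches ordinary SDPA.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, H, D, world_size = 2, 8, 4, 16, 2
S_LOCAL, H_LOCAL = S // world_size, H // world_size

q, k, v = (torch.randn(B, S, H, D) for _ in range(3))
reference = F.scaled_dot_product_attention(
    q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
).permute(0, 2, 1, 3)


def fake_all_to_all(per_rank):
    # Emulates an all-to-all over dim 0: rank r's chunk j is delivered to rank j,
    # so result[r][j] == per_rank[j][r].
    return list(torch.stack(per_rank).transpose(0, 1))


def to_head_sharded(x):
    # [B, S_LOCAL, H, D] -> [world_size, S_LOCAL, B, H_LOCAL, D], mirroring the patch.
    return x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()


# Each simulated rank starts with a sequence shard and all heads.
q_x = fake_all_to_all([to_head_sharded(x) for x in q.chunk(world_size, dim=1)])
k_x = fake_all_to_all([to_head_sharded(x) for x in k.chunk(world_size, dim=1)])
v_x = fake_all_to_all([to_head_sharded(x) for x in v.chunk(world_size, dim=1)])

shard_outs = []
for r in range(world_size):
    # After the exchange, rank r holds the full sequence but only H_LOCAL heads.
    qr, kr, vr = (t[r].flatten(0, 1).permute(1, 2, 0, 3) for t in (q_x, k_x, v_x))
    o = F.scaled_dot_product_attention(qr, kr, vr)  # [B, H_LOCAL, S, D]
    shard_outs.append(o.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous())

# The inverse exchange puts all heads back on each rank's own sequence shard.
out = torch.cat([t.flatten(0, 1).permute(1, 2, 0, 3) for t in fake_all_to_all(shard_outs)], dim=1)
torch.testing.assert_close(out, reference)

Because attention is computed independently per head, splitting along heads changes nothing numerically; the main requirement the branch adds is that the head count is divisible by ulysses_degree.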
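
Reviewer note on PATCH 4/4: the parallel case now routes through _templated_context_parallel_attention, and the backward op recomputes the forward instead of saving attention intermediates, i.e. it re-runs SDPA on the saved query/key/value and calls torch.autograd.grad on the recomputed graph. The snippet below is a standalone sanity sketch of that recompute-and-differentiate pattern, not code from the series; the detach() calls are an addition so the two paths stay independent, and the shapes are arbitrary.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# [B, S, H, D] layout, as used by the dispatcher before its permutes.
q, k, v = (torch.randn(2, 6, 4, 8, requires_grad=True) for _ in range(3))
grad_out = torch.randn(2, 6, 4, 8)


def sdpa_bshd(query, key, value):
    # Same permute-in / permute-out convention as _native_attention_forward_op.
    out = F.scaled_dot_product_attention(*(x.permute(0, 2, 1, 3) for x in (query, key, value)))
    return out.permute(0, 2, 1, 3)


# Reference: differentiate straight through SDPA.
dq, dk, dv = torch.autograd.grad(sdpa_bshd(q, k, v), (q, k, v), grad_out)

# Recompute path, mirroring the backward op: start from the saved tensors,
# re-enable grad, redo the forward, then differentiate the recomputed graph.
q_s, k_s, v_s = (x.detach().requires_grad_(True) for x in (q, k, v))
dq_r, dk_r, dv_r = torch.autograd.grad(sdpa_bshd(q_s, k_s, v_s), (q_s, k_s, v_s), grad_out)

for a, b in zip((dq, dk, dv), (dq_r, dk_r, dv_r)):
    torch.testing.assert_close(a, b)

The apparent trade-off is one extra SDPA forward during the backward pass in exchange for not keeping the attention output around, which is what lets the plain PyTorch backend slot into the templated context-parallel autograd path without a custom kernel.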