diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index ab0d7102ee83..c17a3d0ed6ba 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -649,6 +649,86 @@ def _( # ===== Helper functions to use attention backends with templated CP autograd functions ===== +def _native_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + # Native attention does not return_lse + if return_lse: + raise ValueError("Native attention does not support return_lse=True") + + # used for backward pass + if _save_ctx: + ctx.save_for_backward(query, key, value) + ctx.attn_mask = attn_mask + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.enable_gqa = enable_gqa + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + + return out + + +def _native_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + query, key, value = ctx.saved_tensors + + query.requires_grad_(True) + key.requires_grad_(True) + value.requires_grad_(True) + + query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query_t, + key=key_t, + value=value_t, + attn_mask=ctx.attn_mask, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + enable_gqa=ctx.enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + + grad_out_t = grad_out.permute(0, 2, 1, 3) + grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad( + outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False + ) + + grad_query = grad_query_t.permute(0, 2, 1, 3) + grad_key = grad_key_t.permute(0, 2, 1, 3) + grad_value = grad_value_t.permute(0, 2, 1, 3) + + return grad_query, grad_key, grad_value + + # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 # forward declaration: # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) @@ -1523,6 +1603,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], + supports_context_parallel=True, ) def _native_attention( query: torch.Tensor, @@ -1538,18 +1619,35 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) + if _parallel_config is None: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + else: + out = _templated_context_parallel_attention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op=_native_attention_forward_op, + backward_op=_native_attention_backward_op, + _parallel_config=_parallel_config, + ) + return out