src/diffusers/models/attention_dispatch.py (124 additions, 0 deletions)
@@ -246,6 +246,7 @@ class AttentionBackendName(str, Enum):
_NATIVE_FLASH = "_native_flash"
_NATIVE_MATH = "_native_math"
_NATIVE_NPU = "_native_npu"
_NATIVE_NEURON = "_native_neuron"
_NATIVE_XLA = "_native_xla"

# `sageattention`
@@ -576,6 +577,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
)

elif backend == AttentionBackendName._NATIVE_NEURON:
pass # No extra dependency check needed; torch_neuronx overrides the ATen op at import time.

elif backend == AttentionBackendName._NATIVE_XLA:
if not _CAN_USE_XLA_ATTN:
raise RuntimeError(
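Since no extra dependency check is needed, enabling the backend is just a matter of importing `torch_neuronx` (which patches the fused ATen op at import time) and selecting the backend. A minimal sketch, assuming diffusers' `set_attention_backend` helper; the model choice is illustrative:

```python
import torch_neuronx  # noqa: F401 -- registers the ATen override on import

from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
transformer.set_attention_backend("_native_neuron")
```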
@@ -3218,6 +3222,126 @@ def _native_npu_attention(
return out


def _neuron_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config=None,
):
"""Forward op for Neuron ring attention using _scaled_dot_product_fused_attention_overrideable.

Saves query, key, value, out, lse, philox_seed, philox_offset for backward.
Follows the same pattern as _cudnn_attention_forward_op.
"""
    import math

    # [B, S, H, D] → [B, H, S, D]: the fused ATen op expects a head-first layout.
    q, k, v = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])

result = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
q, k, v,
attn_bias=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
return_debug_mask=False,
scale=scale,
)
out_bhsd, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _ = result

if _save_ctx:
ctx.save_for_backward(q, k, v, out_bhsd, lse, philox_seed, philox_offset)
ctx.attn_mask = attn_mask
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.max_q = max_q
ctx.max_k = max_k

out = out_bhsd.permute(0, 2, 1, 3) # [B, S, H, D]
# [B, H, S] → [B, S, H, 1] for broadcasting in ring accumulation against out [B, S, H, D]
lse_out = lse.permute(0, 2, 1).unsqueeze(-1)
return (out, lse_out) if return_lse else out

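The `[B, S, H, 1]` reshape of the LSE above exists so that ring accumulation can broadcast it against `out`. The merge itself is handled by `_templated_context_parallel_attention`, but here is a minimal sketch of the underlying log-sum-exp combination of two partial results, using the shapes returned by this forward op:

```python
import torch


def merge_ring_partials(out_a, lse_a, out_b, lse_b):
    """Merge two ring-attention partials; out_* is [B, S, H, D], lse_* is [B, S, H, 1]."""
    lse = torch.logaddexp(lse_a, lse_b)  # combined softmax normalizer, in log space
    # Rescale each partial by its share of the combined normalizer, then sum.
    out = out_a * (lse_a - lse).exp() + out_b * (lse_b - lse).exp()
    return out, lse
```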

def _neuron_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
"""Backward op for Neuron ring attention using _scaled_dot_product_fused_attention_overrideable_backward."""
q, k, v, out_bhsd, lse, philox_seed, philox_offset = ctx.saved_tensors

grad_out_bhsd = grad_out.permute(0, 2, 1, 3) # [B, S, H, D] → [B, H, S, D]
    grad_input_mask = [True, True, True, False]  # gradients for q, k, v; none for attn_bias

    # The ATen backward op takes tensor arguments unconditionally, so pass zero
    # placeholders when there is no mask and no varlen (cumulative-sequence) metadata.
    attn_bias = ctx.attn_mask if ctx.attn_mask is not None else torch.zeros((1,), dtype=q.dtype, device=q.device)
    cum_seq_q = cum_seq_k = torch.zeros((1,), dtype=torch.int32, device=q.device)

grad_q, grad_k, grad_v, _ = torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
grad_out_bhsd, q, k, v,
attn_bias,
grad_input_mask,
out_bhsd,
lse,
cum_seq_q, cum_seq_k,
ctx.max_q, ctx.max_k,
ctx.dropout_p,
ctx.is_causal,
philox_seed, philox_offset,
scale=ctx.scale,
)
# [B, H, S, D] → [B, S, H, D]
return grad_q.permute(0, 2, 1, 3), grad_k.permute(0, 2, 1, 3), grad_v.permute(0, 2, 1, 3)

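A forward-parity sanity check one might run in a `torch_neuronx` environment where the overrideable op is registered (off-device, the unpatched ATen op will raise); shapes and tolerances are illustrative:

```python
import torch
import torch.nn.functional as F

B, S, H, D = 2, 128, 8, 64
q, k, v = (torch.randn(B, S, H, D) for _ in range(3))

out = _native_neuron_attention(q, k, v)  # non-ring path, [B, S, H, D] layout
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)
```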

@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_NEURON,
constraints=[],
supports_context_parallel=True,
)
def _native_neuron_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_parallel_config=None,
) -> torch.Tensor:
    if _parallel_config is not None:
        # Ring path: shard the sequence across ranks and merge partial outputs
        # using the log-sum-exp values returned by the forward op.
        return _templated_context_parallel_attention(
query, key, value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
return_lse=return_lse,
forward_op=_neuron_attention_forward_op,
backward_op=_neuron_attention_backward_op,
_parallel_config=_parallel_config,
)
    # Non-ring path: run the fused op directly; ctx is unused and nothing is saved.
return _neuron_attention_forward_op(
None, query, key, value,
attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal,
scale=scale, return_lse=return_lse, _save_ctx=False,
)
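With `supports_context_parallel=True`, the backend can plug into diffusers' ring attention. A hedged sketch, assuming the `ContextParallelConfig` / `enable_parallelism` API and a distributed launch such as `torchrun` (exact entry points may differ by diffusers version); `transformer` is the model from the earlier sketch:

```python
import torch.distributed as dist

from diffusers import ContextParallelConfig

dist.init_process_group()  # pick a backend suited to the Neuron setup
transformer.set_attention_backend("_native_neuron")
transformer.enable_parallelism(
    config=ContextParallelConfig(ring_degree=dist.get_world_size())
)
```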


# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_XLA,