diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index f273451..41bfd28 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -12,7 +12,7 @@ # # This config has a race condition when EVEN_M == False, disabling it for now. # # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), # ], -# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'] # ) @triton.heuristics( { @@ -120,6 +120,38 @@ def _fwd_kernel( end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) + + # Load k + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # compute acc_s + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + acc_s += tl.dot(q, tl.trans(k)) + + # Trying to combine the two masks seem to make the result wrong + # Apply sequence length mask + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + # Apply causal mask + if IS_CAUSAL: + acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) # Load mask if EVEN_M & EVEN_N: @@ -132,41 +164,11 @@ def _fwd_kernel( ) # Check if any element in mask is non-zero - any_active = tl.sum(mask) > 0 - - # compute acc_s - acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if any_active: - # Load k - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn) - else: - k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - acc_s += tl.dot(q, tl.trans(k)) - - # Trying to combine the two masks seem to make the result wrong - # Apply sequence length mask - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - # Apply causal mask - if IS_CAUSAL: - acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + # BUG: Triton needs to determine the control flow at compile time. + # Dynamic conditions at runtime can undermine this optimization. + # any_active = tl.sum(mask) != 0 # Apply dynamic mask - acc_s += tl.where(mask > 0, 0, float("-inf")) + acc_s += tl.where(mask > 0.0, 0.0, float("-inf")) # Load bias if EVEN_M & EVEN_N: @@ -184,7 +186,7 @@ def _fwd_kernel( # can then fuse the mult and add into an fma instruction. But if we have bias we need to # to multiply with softmax_scale here. 
acc_s = acc_s * softmax_scale + bias - acc_s = tl.where(acc_s != float("-inf"), acc_s * softmax_scale + bias, acc_s) + # acc_s = tl.where(acc_s != float("-inf"), acc_s * softmax_scale + bias, acc_s) m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) p = tl.exp(acc_s - m_ij[:, None]) l_ij = tl.sum(p, 1) @@ -198,27 +200,26 @@ def _fwd_kernel( acc_o_scale = tl.load(t_ptrs) acc_o = acc_o * acc_o_scale[:, None] - if any_active: - # load v - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn) - else: - v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + # load v + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) else: - if EVEN_HEADDIM: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - acc_o += tl.dot(p.to(v.dtype), v) + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + acc_o += tl.dot(p.to(v.dtype), v) # update statistics m_i = m_ij @@ -258,7 +259,277 @@ def _fwd_kernel( ) -def _flash_attn_forward(q, k, v, mask=None, bias=None, causal=False, softmax_scale=None): +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + seqlen_q_rounded, + headdim, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load( + Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + do = tl.load( + DO + + off_b * stride_dob + + off_h * stride_doh + + offs_m[:, None] * stride_dom + + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. 
In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + ], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "BLOCK_HEADDIM"], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_mb, + stride_mh, + stride_mm, + stride_bb, + stride_bh, + stride_bm, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dbb, + stride_dbh, + stride_dbm, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + Mask += off_b * stride_mb + off_h * stride_mh + Bias += off_b * stride_bb + off_h * stride_bh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + DBias += off_b * stride_dbb + off_h * stride_dbh + # 
pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_mm, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + stride_dbm, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=False, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_mm, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + stride_dbm, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=True, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + +def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None): # shape constraints batch, seqlen_q, nheads, d = q.shape _, seqlen_k, _, _ = k.shape @@ -269,26 +540,18 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, causal=False, softmax_sca assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.is_cuda and k.is_cuda and v.is_cuda - if mask is not None: - assert mask.shape == (batch, nheads, seqlen_q, seqlen_k), f"mask shape {mask.shape} does not match expected shape {(batch, nheads, seqlen_q, seqlen_k)}" - assert mask.dtype in [torch.float16, torch.bfloat16, torch.float32], "mask must be fp16, bf16, or fp32" - assert mask.is_cuda, "mask must be on CUDA" - if mask.stride(-1) != 1: - mask = mask.contiguous() - else: - # Create a default mask of all ones - mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) - - if bias is not None: - assert bias.dtype in [q.dtype, torch.float], f"bias dtype {bias.dtype} must match q dtype {q.dtype} or be float" - assert bias.is_cuda, "bias must be on CUDA" - assert bias.dim() == 4, f"bias must be 4D, got {bias.dim()}D" - assert bias.shape == (batch, nheads, seqlen_q, seqlen_k), f"bias shape {bias.shape} must be (batch={batch}, nheads={nheads}, seqlen_q={seqlen_q}, seqlen_k={seqlen_k})" - if bias.stride(-1) != 1: - bias = bias.contiguous() - else: - # Create zero bias if none provided - bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) + assert mask.shape == (batch, nheads, seqlen_q, seqlen_k), f"mask shape {mask.shape} does not match expected shape {(batch, nheads, seqlen_q, seqlen_k)}" + assert mask.dtype in [torch.float16, torch.bfloat16, torch.float32], "mask must be fp16, bf16, or fp32" + assert mask.is_cuda, "mask must be on CUDA" + if mask.stride(-1) != 1: + mask = mask.contiguous() + + assert bias.dtype in [q.dtype, torch.float], f"bias dtype {bias.dtype} must match q dtype {q.dtype} or be float" + assert bias.is_cuda, "bias must be on CUDA" + assert bias.dim() == 4, f"bias must be 4D, got {bias.dim()}D" + assert bias.shape == (batch, nheads, seqlen_q, seqlen_k), f"bias shape {bias.shape} must be (batch={batch}, nheads={nheads}, seqlen_q={seqlen_q}, seqlen_k={seqlen_k})" + if bias.stride(-1) != 1: + bias = 
bias.contiguous() softmax_scale = softmax_scale or 1.0 / math.sqrt(d) @@ -348,52 +611,186 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, causal=False, softmax_sca return o, lse, softmax_scale # softmax_scale could have been updated -class FlashAttnFunc(torch.autograd.Function): +def _flash_attn_backward( + do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, causal=False, softmax_scale=None +): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + + assert mask.dtype in [q.dtype, torch.float] + assert mask.is_cuda + assert mask.dim() == 4 + assert mask.stride(-1) == 1 + + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, + do, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + nheads, + seqlen_q, + seqlen_q_rounded, + d, + BLOCK_M=64, + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: ( + triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads, + ) + _bwd_kernel[grid]( + q, + k, + v, + mask, + bias, + do, + dq_accum, + dk, + dv, + dbias, + lse, + delta, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + mask.stride(0), + mask.stride(1), + mask.stride(2), + bias.stride(0), + bias.stride(1), + bias.stride(2), + do.stride(0), + do.stride(2), + do.stride(1), + dq_accum.stride(0), + dq_accum.stride(2), + dq_accum.stride(1), + dk.stride(0), + dk.stride(2), + dk.stride(1), + dv.stride(0), + dv.stride(2), + dv.stride(1), + dbias.stride(0), + dbias.stride(1), + dbias.stride(2), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + causal, + BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + + +class FlashDMAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None): """ q: (batch_size, seqlen_q, nheads, headdim) - k, v: (batch_size, seqlen_k, nheads, headdim) - mask: optional, shape (batch, nheads, seqlen_q, seqlen_k), dynamic attention mask - bias: optional, shape must be exactly (batch, nheads, seqlen_q, seqlen_k), attention bias matrix + k: (batch_size, seqlen_k, nheads, headdim) + v: (batch_size, seqlen_k, nheads, headdim) + mask: optional, (batch, nheads, seqlen_q, seqlen_k) 
+ bias: optional, (batch, nheads, seqlen_q, seqlen_k) causal: bool, whether to apply causal masking softmax_scale: float, scaling factor for attention scores """ + batch, seqlen_q, nheads, _ = q.shape + _, seqlen_k, _, _ = k.shape + if mask is not None: + if mask.dtype == torch.bool: + mask = torch.where(mask, 1.0, 0.0) + else: + mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) + if bias is None: + bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) + # Make sure that the last dimension is contiguous - q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] + q, k, v, mask, bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v, mask, bias]] o, lse, ctx.softmax_scale = _flash_attn_forward( - q, k, v, mask=mask, bias=bias, causal=causal, softmax_scale=softmax_scale + q, k, v, mask, bias, causal=causal, softmax_scale=softmax_scale ) ctx.save_for_backward(q, k, v, o, lse, mask, bias) ctx.causal = causal return o - # @staticmethod - # def backward(ctx, do): - # q, k, v, o, lse, mask, bias = ctx.saved_tensors - # assert not ctx.needs_input_grad[3], "FlashAttention does not support mask gradient yet" - # assert not ctx.needs_input_grad[4], "FlashAttention does not support bias gradient yet" - # # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. - # with torch.inference_mode(): - # dq = torch.empty_like(q) - # dk = torch.empty_like(k) - # dv = torch.empty_like(v) - # _flash_attn_backward( - # do, - # q, - # k, - # v, - # o, - # lse, - # dq, - # dk, - # dv, - # bias=bias, - # causal=ctx.causal, - # softmax_scale=ctx.softmax_scale, - # ) - # return dq, dk, dv, None, None, None, None - - -flash_dmattn_func = FlashAttnFunc.apply + @staticmethod + def backward(ctx, do): + q, k, v, o, lse, mask, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet" + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbias = torch.empty_like(bias) + _flash_attn_backward( + do, + q, + k, + v, + mask, + bias, + o, + lse, + dq, + dk, + dv, + dbias, + causal=ctx.causal, + softmax_scale=ctx.softmax_scale, + ) + return dq, dk, dv, None, dbias, None, None + + +flash_dmattn_func = FlashDMAttnFunc.apply
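
For reference, a minimal usage sketch of the interface this patch exposes (FlashDMAttnFunc via flash_dmattn_func). The import path is inferred from the file path above; the shapes, dtypes, and random inputs are purely illustrative, and a CUDA device with Triton installed is assumed.

    # Illustrative usage only: shapes, dtypes, and values are made up.
    import torch
    from flash_dmattn.flash_dmattn_triton import flash_dmattn_func

    batch, nheads, seqlen, headdim = 2, 4, 128, 64
    device, dtype = "cuda", torch.float16

    q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)

    # mask may be boolean (converted to 0/1 in forward) or float; keep the diagonal
    # active so that no query row is fully masked once the causal mask is applied.
    mask = torch.rand(batch, nheads, seqlen, seqlen, device=device) > 0.1
    idx = torch.arange(seqlen, device=device)
    mask[..., idx, idx] = True

    # bias is additive and receives a gradient (dbias) from the new backward.
    bias = torch.randn(batch, nheads, seqlen, seqlen, device=device, dtype=dtype, requires_grad=True)

    # Positional arguments, since flash_dmattn_func is FlashDMAttnFunc.apply:
    # (q, k, v, mask, bias, causal, softmax_scale)
    out = flash_dmattn_func(q, k, v, mask, bias, True)  # causal=True, default 1/sqrt(headdim) scale
    out.sum().backward()                                # fills q.grad, k.grad, v.grad, bias.grad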
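
After this change the forward kernel computes, in effect, softmax(Q·K^T * softmax_scale + bias) with -inf written into excluded positions: the sequence-length mask (when EVEN_N is false), the causal mask, and the dynamic mask (entries with mask <= 0). Below is a plain-PyTorch sketch of those semantics, handy as a reference when checking the kernel; attention_reference is a made-up helper name and the causal branch is written for the seqlen_q == seqlen_k case.

    # Reference semantics: scale scores, add bias, set masked positions to -inf,
    # then softmax and weight the values. Sketch only, not the kernel itself.
    import math
    import torch

    def attention_reference(q, k, v, mask, bias, causal=False, softmax_scale=None):
        # q, k, v: (batch, seqlen, nheads, headdim); mask, bias: (batch, nheads, seqlen_q, seqlen_k)
        headdim = q.shape[-1]
        seqlen_q, seqlen_k = q.shape[1], k.shape[1]
        softmax_scale = softmax_scale or 1.0 / math.sqrt(headdim)
        # scale the raw scores, then add the bias (mirrors acc_s * softmax_scale + bias)
        scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale + bias.float()
        # dynamic mask: positions with mask <= 0 are excluded from the softmax
        scores = scores.masked_fill(mask.float() <= 0, float("-inf"))
        if causal:
            future = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device).triu(1)
            scores = scores.masked_fill(future, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        return torch.einsum("bhqk,bkhd->bqhd", probs, v.float()).to(q.dtype)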
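
Tying the two sketches above together, a rough forward/backward consistency check. The tolerances are illustrative for fp16 inputs and may need adjusting; the dbias comparison relies on the backward added in this patch (the mask itself receives no gradient).

    # Reset gradients from the usage sketch, then compare the Triton path
    # against the PyTorch reference on the same inputs.
    for t in (q, k, v, bias):
        t.grad = None
    q_ref, k_ref, v_ref, bias_ref = [t.detach().clone().requires_grad_(True) for t in (q, k, v, bias)]

    out_tri = flash_dmattn_func(q, k, v, mask, bias, True)
    out_ref = attention_reference(q_ref, k_ref, v_ref, mask, bias_ref, causal=True)
    torch.testing.assert_close(out_tri.float(), out_ref.float(), atol=2e-2, rtol=2e-2)

    g = torch.randn_like(out_ref)
    out_tri.backward(g)
    out_ref.backward(g)
    torch.testing.assert_close(q.grad.float(), q_ref.grad.float(), atol=2e-2, rtol=2e-2)
    torch.testing.assert_close(bias.grad.float(), bias_ref.grad.float(), atol=2e-2, rtol=2e-2)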