4 changes: 3 additions & 1 deletion flash_dmattn/flash_dmattn_flex.py
@@ -61,4 +61,6 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
attention_weights = attention_weights.to(value.dtype)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attention_weights
return attn_output, attention_weights

flex_dmattn_func = flex_attention_forward
335 changes: 326 additions & 9 deletions flash_dmattn/flash_dmattn_triton.py
@@ -336,8 +336,325 @@ def _bwd_store_dk_dv(
tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))


def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
@triton.jit
def _bwd_kernel_one_col_block(

Copilot AI Jul 7, 2025

[nitpick] This kernel function spans 300+ lines and handles multiple responsibilities; consider breaking it into smaller helper functions (e.g., pointer setup, mask application, gradient computation) to improve readability and ease future maintenance.

start_n,
Q,
K,
V,
Mask,
Bias,
DO,
DQ,
DK,
DV,
DBias,
LSE,
D,
softmax_scale,
stride_qm,
stride_kn,
stride_vn,
stride_mm,
stride_bm,
stride_dom,
stride_dqm,
stride_dkn,
stride_dvn,
stride_dbm,
seqlen_q,
seqlen_k,
headdim,
ATOMIC_ADD: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
# initialize row/col offsets
offs_qm = begin_m + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :])
b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
db_ptrs = DBias + (offs_qm[:, None] * stride_dbm + offs_n[None, :])
# initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
# There seems to be some problem with Triton pipelining that makes results wrong for
# headdim=64, seqlen=(113, 255). In this case the for loop may have zero step,
# and pipelining with the bias matrix could screw it up. So we just exit early.
if begin_m >= seqlen_q:
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
_bwd_store_dk_dv(
dk_ptrs,
dv_ptrs,
dk,
dv,
offs_n,
offs_d,
seqlen_k,
headdim,
EVEN_M=EVEN_M,
EVEN_N=EVEN_N,
EVEN_HEADDIM=EVEN_HEADDIM,
)
return
# k and v stay in SRAM throughout
# [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
# if we just call tl.load(k_ptrs), we get the wrong output!
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
else:
k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
else:
k = tl.load(
k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
)
v = tl.load(
v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
)
# loop over rows
num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
start_m = tl.multiple_of(start_m, BLOCK_M)
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
# Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
if EVEN_M & EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(
q_ptrs,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0,
)
# recompute p = softmax(acc_s, dim=-1).T
acc_s = tl.dot(q, tl.trans(k))

tl.debug_barrier()
# Load mask
if EVEN_M & EVEN_N:
mask = tl.load(m_ptrs)
else:
mask = tl.load(
m_ptrs,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
other=0.0,
)

# Trying to combine the two masks seems to make the result wrong
# Apply sequence length mask
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
acc_s = tl.where(offs_n[None, :] < seqlen_k, acc_s, float("-inf"))
# Apply causal mask
if IS_CAUSAL:
acc_s = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), acc_s, float("-inf"))
# Apply dynamic mask
acc_s = tl.where(mask > 0.0, acc_s, float("-inf"))

tl.debug_barrier() # Race condition otherwise
# Load bias
if EVEN_M & EVEN_N:
bias = tl.load(
b_ptrs,
mask=(mask > 0.0),
other=0.0,
).to(tl.float32)
else:
bias = tl.load(
b_ptrs,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) & (mask > 0.0),
other=0.0,
).to(tl.float32)
acc_s = acc_s * softmax_scale + bias
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
# Also wrong for headdim=64.
if not (EVEN_M & EVEN_HEADDIM):
tl.debug_barrier()
lse_i = tl.load(LSE + offs_m_curr)
p = tl.exp(acc_s - lse_i[:, None])
# compute dv
# [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
# in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
# the output is correct.
if EVEN_M & EVEN_HEADDIM:
do = tl.load(do_ptrs)
else:
# [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
do = tl.load(
do_ptrs,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0,
)
# if EVEN_M:
# if EVEN_HEADDIM:
# do = tl.load(do_ptrs)
# else:
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
# else:
# if EVEN_HEADDIM:
# do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
# else:
# do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
# & (offs_d[None, :] < headdim), other=0.0)
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# compute dp = dot(v, do)
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
# Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
if not (EVEN_M & EVEN_HEADDIM):
tl.debug_barrier()
dp = tl.dot(do, tl.trans(v))
# There's a race condition for headdim=48
if not EVEN_HEADDIM:
tl.debug_barrier()
# compute dbias = p * (dp - delta[:, None]) and ds = dbias * softmax_scale
# Putting the subtraction after the dp matmul (instead of before) is slightly faster
Di = tl.load(D + offs_m_curr)
# Converting ds to q.dtype here reduces register pressure and makes it much faster
# for BLOCK_HEADDIM=128
dbias = (p * (dp - Di[:, None]))
ds = (dbias * softmax_scale).to(q.dtype)
# dbias = tl.where(mask > 0.0, dbias, 0.0)
# ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
if not (EVEN_M & EVEN_N):
tl.debug_barrier()
if not ATOMIC_ADD:
if EVEN_M & EVEN_N:
tl.store(
db_ptrs,
dbias
)
else:
tl.store(
db_ptrs,
dbias,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k)
)
else:
if EVEN_M & EVEN_N:
tl.atomic_add(
db_ptrs,
dbias
)
else:
tl.atomic_add(
db_ptrs,
dbias,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k)
)
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds), q)
# compute dq
if not (
EVEN_M & EVEN_HEADDIM
):  # Otherwise there's a race condition
tl.debug_barrier()
if not ATOMIC_ADD:
if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds, k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
else:
if EVEN_HEADDIM:
dq = tl.load(
dq_ptrs,
mask=offs_m_curr[:, None] < seqlen_q,
other=0.0,
eviction_policy="evict_last",
)
dq += tl.dot(ds, k)
tl.store(
dq_ptrs,
dq,
mask=offs_m_curr[:, None] < seqlen_q,
eviction_policy="evict_last",
)
else:
dq = tl.load(
dq_ptrs,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0,
eviction_policy="evict_last",
)
dq += tl.dot(ds, k)
tl.store(
dq_ptrs,
dq,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
eviction_policy="evict_last",
)
else: # If we're parallelizing across the seqlen_k dimension
dq = tl.dot(ds, k)
if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
tl.atomic_add(dq_ptrs, dq)
else:
if EVEN_HEADDIM:
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
else:
tl.atomic_add(
dq_ptrs,
dq,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
)
# increment pointers
do_ptrs += BLOCK_M * stride_dom
dq_ptrs += BLOCK_M * stride_dqm
db_ptrs += BLOCK_M * stride_dbm
q_ptrs += BLOCK_M * stride_qm
m_ptrs += BLOCK_M * stride_mm
b_ptrs += BLOCK_M * stride_bm

# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
_bwd_store_dk_dv(
dk_ptrs,
dv_ptrs,
dk,
dv,
offs_n,
offs_d,
seqlen_k,
headdim,
EVEN_M=EVEN_M,
EVEN_N=EVEN_N,
EVEN_HEADDIM=EVEN_HEADDIM,
)
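The review comment above suggests splitting this kernel into smaller helpers (pointer setup, mask application, gradient computation). As a rough illustration only, not part of this PR, here is a minimal sketch of one such extraction: a hypothetical @triton.jit helper that builds the per-block pointers, so the main kernel body reads as setup, loop, write-back. The helper name, the argument grouping, and the assumption that a jitted helper may return a tuple of pointer blocks are all illustrative.

import triton

@triton.jit
def _bwd_setup_block_ptrs(
    Q, K, V, Mask, Bias, DO, DQ, DBias,
    offs_qm, offs_n, offs_d,
    stride_qm, stride_kn, stride_vn, stride_mm, stride_bm,
    stride_dom, stride_dqm, stride_dbm,
):
    # Build the row/column pointer blocks used throughout the backward pass.
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :])
    b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    db_ptrs = DBias + (offs_qm[:, None] * stride_dbm + offs_n[None, :])
    return q_ptrs, k_ptrs, v_ptrs, m_ptrs, b_ptrs, do_ptrs, dq_ptrs, db_ptrs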


def init_to_zero(names):
if isinstance(names, str):
names = [names]
def init_func(nargs):
for name in names:
nargs[name].zero_()
return init_func
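
The updated init_to_zero accepts either a single tensor name or a list of names and returns a pre-hook that zeroes each named kernel argument in place before the autotuned kernel runs, which matters because DQ and DBias are accumulated into across blocks. A small standalone sketch of that behavior, using the init_to_zero defined just above (the dummy tensors stand in for the nargs dict the Triton autotuner would pass; shapes are arbitrary):

import torch

# Stand-ins for the kernel arguments the autotuner passes as `nargs`.
nargs = {
    "DQ": torch.full((4, 8), 3.0),
    "DBias": torch.full((4, 8), 7.0),
}

# A list of names zeroes every listed buffer...
init_to_zero(["DQ", "DBias"])(nargs)
assert float(nargs["DQ"].abs().sum()) == 0.0
assert float(nargs["DBias"].abs().sum()) == 0.0

# ...and a single string keeps the old behaviour.
nargs["DQ"].fill_(1.0)
init_to_zero("DQ")(nargs)
assert float(nargs["DQ"].abs().sum()) == 0.0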


@triton.autotune(
@@ -346,20 +663,20 @@ def init_to_zero(name):
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
num_warps=8,
num_stages=1,
pre_hook=init_to_zero("DQ"),
pre_hook=init_to_zero(["DQ", "DBias"]),
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
num_warps=8,
num_stages=1,
pre_hook=init_to_zero("DQ"),
pre_hook=init_to_zero(["DQ", "DBias"]),
),
# Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
# # Kernel is buggy (gives wrong results) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
Comment on lines 675 to +679

Copilot AI Jul 7, 2025

[nitpick] There are several commented-out Triton configuration lines that include DBias; once DBias support is verified, consider removing or updating these stale lines to reduce clutter.

],
key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "BLOCK_HEADDIM"],
)
@@ -793,4 +1110,4 @@ def backward(ctx, do):
return dq, dk, dv, None, dbias, None, None


flash_dmattn_func = FlashDMAttnFunc.apply
triton_dmattn_func = FlashDMAttnFunc.apply