Adds bias gradient computation to backward kernel #62
Merged
Changes from all commits (3 commits)
@@ -336,8 +336,325 @@ def _bwd_store_dk_dv(
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))

def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q,
    K,
    V,
    Mask,
    Bias,
    DO,
    DQ,
    DK,
    DV,
    DBias,
    LSE,
    D,
    softmax_scale,
    stride_qm,
    stride_kn,
    stride_vn,
    stride_mm,
    stride_bm,
    stride_dom,
    stride_dqm,
    stride_dkn,
    stride_dvn,
    stride_dbm,
    seqlen_q,
    seqlen_k,
    headdim,
    ATOMIC_ADD: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :])
    b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    db_ptrs = DBias + (offs_qm[:, None] * stride_dbm + offs_n[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255). In this case the for loop may have zero step,
    # and pipelining with the bias matrix could screw it up. So we just exit early.
    if begin_m >= seqlen_q:
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(
            dk_ptrs,
            dv_ptrs,
            dk,
            dv,
            offs_n,
            offs_d,
            seqlen_k,
            headdim,
            EVEN_M=EVEN_M,
            EVEN_N=EVEN_N,
            EVEN_HEADDIM=EVEN_HEADDIM,
        )
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(
                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
            v = tl.load(
                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
        if EVEN_M & EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(
                    q_ptrs,
                    mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # recompute p = softmax(acc_s, dim=-1).T
        acc_s = tl.dot(q, tl.trans(k))

        tl.debug_barrier()
        # Load mask
        if EVEN_M & EVEN_N:
            mask = tl.load(m_ptrs)
        else:
            mask = tl.load(
                m_ptrs,
                mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
                other=0.0,
            )

        # Trying to combine the two masks seems to make the result wrong
        # Apply sequence length mask
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            acc_s = tl.where(offs_n[None, :] < seqlen_k, acc_s, float("-inf"))
        # Apply causal mask
        if IS_CAUSAL:
            acc_s = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), acc_s, float("-inf"))
        # Apply dynamic mask
        acc_s = tl.where(mask > 0.0, acc_s, float("-inf"))

        tl.debug_barrier()  # Race condition otherwise
        # Load bias
        if EVEN_M & EVEN_N:
            bias = tl.load(
                b_ptrs,
                mask=(mask > 0.0),
                other=0.0,
            ).to(tl.float32)
        else:
            bias = tl.load(
                b_ptrs,
                mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) & (mask > 0.0),
                other=0.0,
            ).to(tl.float32)
        acc_s = acc_s * softmax_scale + bias
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        # Also wrong for headdim=64.
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        lse_i = tl.load(LSE + offs_m_curr)
        p = tl.exp(acc_s - lse_i[:, None])
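        # NOTE (review): p is rebuilt from the saved row-wise log-sum-exp, so there is no
        # re-normalization here; masked-out positions carry acc_s = -inf and hence p = 0,
        # which is why dv, dbias, and ds below need no additional masking.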
        # compute dv
        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
        # the output is correct.
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        else:
            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
            do = tl.load(
                do_ptrs,
                mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0,
            )
        # if EVEN_M:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs)
        #     else:
        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
        # else:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
        #     else:
        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
        #                      & (offs_d[None, :] < headdim), other=0.0)
        dv += tl.dot(tl.trans(p.to(do.dtype)), do)
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        dp = tl.dot(do, tl.trans(v))
        # There's a race condition for headdim=48
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        # compute dbias = p * (dp - delta[:, None]) and ds = dbias * softmax_scale
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        dbias = (p * (dp - Di[:, None]))
        ds = (dbias * softmax_scale).to(q.dtype)
        # dbias = tl.where(mask > 0.0, dbias, 0.0)
        # ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        if not (EVEN_M & EVEN_N):
            tl.debug_barrier()
        if not ATOMIC_ADD:
            if EVEN_M & EVEN_N:
                tl.store(
                    db_ptrs,
                    dbias
                )
            else:
                tl.store(
                    db_ptrs,
                    dbias,
                    mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k)
                )
        else:
            if EVEN_M & EVEN_N:
                tl.atomic_add(
                    db_ptrs,
                    dbias
                )
            else:
                tl.atomic_add(
                    db_ptrs,
                    dbias,
                    mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k)
                )
        # compute dk = dot(ds.T, q)
        dk += tl.dot(tl.trans(ds), q)
        # compute dq
        if not (EVEN_M & EVEN_HEADDIM):  # Otherwise there's a race condition
            tl.debug_barrier()
        if not ATOMIC_ADD:
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                dq += tl.dot(ds, k)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(
                        dq_ptrs,
                        mask=offs_m_curr[:, None] < seqlen_q,
                        other=0.0,
                        eviction_policy="evict_last",
                    )
                    dq += tl.dot(ds, k)
                    tl.store(
                        dq_ptrs,
                        dq,
                        mask=offs_m_curr[:, None] < seqlen_q,
                        eviction_policy="evict_last",
                    )
                else:
                    dq = tl.load(
                        dq_ptrs,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0,
                        eviction_policy="evict_last",
                    )
                    dq += tl.dot(ds, k)
                    tl.store(
                        dq_ptrs,
                        dq,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        eviction_policy="evict_last",
                    )
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                tl.atomic_add(dq_ptrs, dq)
            else:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(
                        dq_ptrs,
                        dq,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    )
        # increment pointers
        do_ptrs += BLOCK_M * stride_dom
        dq_ptrs += BLOCK_M * stride_dqm
        db_ptrs += BLOCK_M * stride_dbm
        q_ptrs += BLOCK_M * stride_qm
        m_ptrs += BLOCK_M * stride_mm
        b_ptrs += BLOCK_M * stride_bm

    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    _bwd_store_dk_dv(
        dk_ptrs,
        dv_ptrs,
        dk,
        dv,
        offs_n,
        offs_d,
        seqlen_k,
        headdim,
        EVEN_M=EVEN_M,
        EVEN_N=EVEN_N,
        EVEN_HEADDIM=EVEN_HEADDIM,
    )

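For reference, the quantity the kernel accumulates into DBias matches the standard gradient of an additive pre-softmax bias. A minimal PyTorch sketch, as an illustration only and not part of the diff (single head, no mask, all names made up):

import torch

def reference_dbias(q, k, v, bias, do, softmax_scale):
    # Recompute the forward quantities the kernel reconstructs on-chip.
    s = q @ k.transpose(-1, -2) * softmax_scale + bias
    p = torch.softmax(s, dim=-1)
    o = p @ v
    # dp = dO @ V^T and delta_i = rowsum(dO * O), the kernel's D tensor.
    dp = do @ v.transpose(-1, -2)
    delta = (do * o).sum(dim=-1, keepdim=True)
    # dbias = P * (dP - delta); scaling by softmax_scale then gives ds.
    return p * (dp - delta)
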
def init_to_zero(names):
    if isinstance(names, str):
        names = [names]
    def init_func(nargs):
        for name in names:
            nargs[name].zero_()
    return init_func

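The updated helper accepts either a single name or a list and zeroes each named buffer in place before a launch, which the atomic accumulation into DQ and DBias relies on. A quick sketch of the behaviour, using hypothetical tensors (the dict stands in for the argument mapping Triton passes to a config's pre_hook):

import torch

dq_buf = torch.randn(4)
dbias_buf = torch.randn(4)
hook = init_to_zero(["DQ", "DBias"])
hook({"DQ": dq_buf, "DBias": dbias_buf})  # zeroes both buffers in place
assert not dq_buf.any() and not dbias_buf.any()
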
@triton.autotune(

@@ -346,20 +663,20 @@ def init_to_zero(name):
| {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, | ||
| num_warps=8, | ||
| num_stages=1, | ||
| pre_hook=init_to_zero("DQ"), | ||
| pre_hook=init_to_zero(["DQ", "DBias"]), | ||
| ), | ||
| triton.Config( | ||
| {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, | ||
| num_warps=8, | ||
| num_stages=1, | ||
| pre_hook=init_to_zero("DQ"), | ||
| pre_hook=init_to_zero(["DQ", "DBias"]), | ||
| ), | ||
| # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now | ||
| # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* | ||
| # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), | ||
| # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), | ||
| # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), | ||
| # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), | ||
| # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), | ||
| # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), | ||
| # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), | ||
| # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), | ||
|
    ],
    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "BLOCK_HEADDIM"],
)

@@ -793,4 +1110,4 @@ def backward(ctx, do):
        return dq, dk, dv, None, dbias, None, None

flash_dmattn_func = FlashDMAttnFunc.apply
triton_dmattn_func = FlashDMAttnFunc.apply
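For context, a hedged end-to-end sketch of how the bias gradient reaches the caller. The seven gradients returned by backward correspond positionally to the forward inputs, so dbias lands in the bias argument's slot; the argument order (q, k, v, mask, bias, causal, softmax_scale), the import path, and the tensor layouts below are assumptions inferred from this diff, not confirmed API:

import torch
from flash_dmattn.flash_dmattn_triton import triton_dmattn_func  # hypothetical import path

batch, heads, seqlen, headdim = 1, 8, 1024, 64
q, k, v = (
    torch.randn(batch, seqlen, heads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)
mask = torch.ones(batch, heads, seqlen, seqlen, device="cuda", dtype=torch.float16)  # assumed layout
bias = torch.zeros(batch, heads, seqlen, seqlen, device="cuda", dtype=torch.float16, requires_grad=True)

out = triton_dmattn_func(q, k, v, mask, bias, False, None)  # causal=False, default softmax scale
out.sum().backward()
print(bias.grad.abs().sum())  # populated by the new DBias accumulation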
[nitpick] This kernel function spans 300+ lines and handles multiple responsibilities; consider breaking it into smaller helper functions (e.g., pointer setup, mask application, gradient computation) to improve readability and ease future maintenance.
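As a rough illustration of that suggestion (a sketch only; the helper name is hypothetical, and the existing branches deliberately work around the Triton bugs noted in the kernel comments, so any refactor would have to preserve those exact mask choices), the repeated masked-load pattern could be factored into a jitted helper:

import triton
import triton.language as tl

@triton.jit
def _load_2d(ptrs, row_offs, col_offs, row_limit, col_limit,
             EVEN_ROWS: tl.constexpr, EVEN_COLS: tl.constexpr):
    # Load a 2D tile, masking only the dimensions that may be ragged.
    if EVEN_ROWS & EVEN_COLS:
        x = tl.load(ptrs)
    elif EVEN_ROWS:
        x = tl.load(ptrs, mask=col_offs[None, :] < col_limit, other=0.0)
    elif EVEN_COLS:
        x = tl.load(ptrs, mask=row_offs[:, None] < row_limit, other=0.0)
    else:
        x = tl.load(
            ptrs,
            mask=(row_offs[:, None] < row_limit) & (col_offs[None, :] < col_limit),
            other=0.0,
        )
    return x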