diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp
index e6a8b7a..03068ab 100644
--- a/csrc/flash_dmattn/flash_api.cpp
+++ b/csrc/flash_dmattn/flash_api.cpp
@@ -979,13 +979,10 @@ mha_bwd(
     dbias_expanded = has_bias
         ? (
             (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)  // MQA / GQA or dbias has different batch size or seqlen_q
-                ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
+                ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
                 : dbias
         )
         : torch::empty({0}, opts);
-    if (has_bias) {
-        dbias_expanded.zero_();
-    }
 
     Flash_bwd_params params;
 
@@ -1050,7 +1047,7 @@ mha_bwd(
         }
     }
 
-    return { dq, dk, dv, dbias, softmax_d };
+    return {dq, dk, dv, dbias, softmax_d};
 }
 
 std::vector<at::Tensor>
diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py
index 663dc2e..4ee10eb 100644
--- a/flash_dmattn/flash_dmattn_interface.py
+++ b/flash_dmattn/flash_dmattn_interface.py
@@ -409,6 +409,7 @@ def forward(
             deterministic = False
         if return_softmax is None:
             return_softmax = False
+        seqlen_k_bias_og = bias.shape[-1] if bias is not None else 0
 
         # Padding to multiple of 8 for 16-bit memory allocations
         head_size_og = q.size(3)
@@ -446,6 +447,7 @@ def forward(
         ctx.is_causal = is_causal
         ctx.softcap = softcap
         ctx.deterministic = deterministic
+        ctx.seqlen_k_bias_og = seqlen_k_bias_og
 
         out = out_padded[..., :head_size_og]
 
@@ -491,7 +493,7 @@ def backward(
             dv = dv[..., : dout.shape[-1]]
         if dbias is not None:
-            dbias = dbias[..., : k.shape[1]]
+            dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]]
 
         return dq, dk, dv, None, dbias, None, None, None, None, None, None
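The key piece of the Python change is the backward slicing: when the attention bias was passed in with `seqlen_k == 1` (i.e. broadcast across key positions), the kernel still produces a gradient at full `(seqlen_q, seqlen_k)` width, so it has to be reduced back to the broadcast shape with `.sum(dim=-1, keepdim=True)`. The sketch below is a minimal, standalone illustration of that reduction rule using plain autograd; the shapes, variable names, and check are my own and are not code from the repository.

```python
import torch

batch, heads, seqlen_q, seqlen_k = 1, 2, 3, 4
scores = torch.randn(batch, heads, seqlen_q, seqlen_k)
weights = torch.randn(batch, heads, seqlen_q, seqlen_k)  # stand-in for the rest of the backward graph

# Bias stored with seqlen_k == 1 and broadcast across all key positions.
bias = torch.randn(batch, heads, seqlen_q, 1, requires_grad=True)
((scores + bias).softmax(dim=-1) * weights).sum().backward()

# Same computation with the bias materialised at full (seqlen_q, seqlen_k) width,
# roughly what the kernel-side dbias looks like before the wrapper slices/reduces it.
bias_full = bias.detach().expand(batch, heads, seqlen_q, seqlen_k).clone().requires_grad_()
((scores + bias_full).softmax(dim=-1) * weights).sum().backward()

# The gradient of a broadcast bias is the full-width gradient summed over the key dimension.
assert torch.allclose(bias.grad, bias_full.grad.sum(dim=-1, keepdim=True), atol=1e-6)
```

The `forward` hooks (`seqlen_k_bias_og` saved on `ctx`) exist only to remember whether that reduction is needed; the C++ side independently swaps `torch::empty` plus a separate `zero_()` for a single `torch::zeros` allocation of `dbias_expanded`, which is functionally equivalent.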