diff --git a/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp b/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp index 216ff61b9..16b6a179d 100644 --- a/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp +++ b/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp @@ -420,6 +420,9 @@ cpu_flash_attention( int64_t oStrideB = output.stride(0); int64_t oStrideM = output.stride(1); int64_t oStrideH = output.stride(2); + int64_t lStrideB = logsumexp.stride(0); + int64_t lStrideM = logsumexp.stride(1); + int64_t lStrideH = logsumexp.stride(2); int64_t mStrideB = (attention_mask.has_value() && attention_mask.value().size(0) > 1) ? attention_mask.value().stride(0) @@ -459,6 +462,7 @@ cpu_flash_attention( ? attention_mask.value().data_ptr() : nullptr; scalar_t* out_data = output.data_ptr(); + accum_t* lse_data = logsumexp.data_ptr(); accum_t* buf_data = buf.data_ptr(); at::parallel_for( @@ -616,6 +620,13 @@ cpu_flash_attention( dst_data + row * headSize, headSize); } + // Store logsumexp for backward + accum_t* lse_ptr = + lse_data + i * lStrideB + j * lStrideH + m * lStrideM; + for (const auto row : c10::irange(qBlockSize)) { + lse_ptr[row * lStrideM] = + qk_max_data[row] + std::log(qk_sum_data[row]); + } // Move to the next query at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); } @@ -684,6 +695,9 @@ cpu_flash_attention( int64_t oStrideB = output.stride(0); int64_t oStrideM = output.stride(1); int64_t oStrideH = output.stride(2); + int64_t lStrideB = logsumexp.stride(0); + int64_t lStrideM = logsumexp.stride(1); + int64_t lStrideH = logsumexp.stride(2); int64_t mStrideB = (attention_mask.has_value() && attention_mask.value().size(0) > 1) ? attention_mask.value().stride(0) @@ -725,6 +739,7 @@ cpu_flash_attention( ? attention_mask.value().data_ptr() : nullptr; scalar_t* out_data = output.data_ptr(); + accum_t* lse_data = logsumexp.data_ptr(); accum_t* buf_data = buf.data_ptr(); scalar_t* buf_reduced_data = buf_reduced.data_ptr(); @@ -1310,6 +1325,13 @@ cpu_flash_attention( dst_data + row * headSize, headSize); } + // Store logsumexp for backward + accum_t* lse_ptr = + lse_data + i * lStrideB + j * lStrideH + m * lStrideM; + for (const auto row : c10::irange(qBlockSize)) { + lse_ptr[row * lStrideM] = + qk_max_data[row] + std::log(qk_sum_data[row]); + } // Move to the next query at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); }