From f55a2bfa5d505fb7c7a6225c1c6206b5926777ab Mon Sep 17 00:00:00 2001
From: Xuan Liao
Date: Fri, 1 Mar 2024 15:05:59 +0800
Subject: [PATCH] [flash attention] calculate logsumexp for backward (#2631)

* [flash attention] calculate logsumexp for backward

* [flash attention] calculate logsumexp for backward reduced type

---------

Co-authored-by: jiayisunx
Co-authored-by: WeizhuoZhang-intel
---
 csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp | 22 ++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp b/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
index 216ff61b9..16b6a179d 100644
--- a/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
+++ b/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
@@ -420,6 +420,9 @@ cpu_flash_attention(
   int64_t oStrideB = output.stride(0);
   int64_t oStrideM = output.stride(1);
   int64_t oStrideH = output.stride(2);
+  int64_t lStrideB = logsumexp.stride(0);
+  int64_t lStrideM = logsumexp.stride(1);
+  int64_t lStrideH = logsumexp.stride(2);
   int64_t mStrideB =
       (attention_mask.has_value() && attention_mask.value().size(0) > 1)
       ? attention_mask.value().stride(0)
@@ -459,6 +462,7 @@ cpu_flash_attention(
       ? attention_mask.value().data_ptr<mask_t>()
       : nullptr;
   scalar_t* out_data = output.data_ptr<scalar_t>();
+  accum_t* lse_data = logsumexp.data_ptr<accum_t>();
   accum_t* buf_data = buf.data_ptr<accum_t>();
 
   at::parallel_for(
@@ -616,6 +620,13 @@ cpu_flash_attention(
               dst_data + row * headSize,
               headSize);
         }
+        // Store logsumexp for backward
+        accum_t* lse_ptr =
+            lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
+        for (const auto row : c10::irange(qBlockSize)) {
+          lse_ptr[row * lStrideM] =
+              qk_max_data[row] + std::log(qk_sum_data[row]);
+        }
         // Move to the next query
         at::native::data_index_step(i, batchSize, j, num_head, k, qSlice);
       }
@@ -684,6 +695,9 @@ cpu_flash_attention(
   int64_t oStrideB = output.stride(0);
   int64_t oStrideM = output.stride(1);
   int64_t oStrideH = output.stride(2);
+  int64_t lStrideB = logsumexp.stride(0);
+  int64_t lStrideM = logsumexp.stride(1);
+  int64_t lStrideH = logsumexp.stride(2);
   int64_t mStrideB =
       (attention_mask.has_value() && attention_mask.value().size(0) > 1)
       ? attention_mask.value().stride(0)
@@ -725,6 +739,7 @@ cpu_flash_attention(
       ? attention_mask.value().data_ptr<mask_t>()
       : nullptr;
   scalar_t* out_data = output.data_ptr<scalar_t>();
+  accum_t* lse_data = logsumexp.data_ptr<accum_t>();
   accum_t* buf_data = buf.data_ptr<accum_t>();
   scalar_t* buf_reduced_data = buf_reduced.data_ptr<scalar_t>();
 
@@ -1310,6 +1325,13 @@ cpu_flash_attention(
               dst_data + row * headSize,
               headSize);
         }
+        // Store logsumexp for backward
+        accum_t* lse_ptr =
+            lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
+        for (const auto row : c10::irange(qBlockSize)) {
+          lse_ptr[row * lStrideM] =
+              qk_max_data[row] + std::log(qk_sum_data[row]);
+        }
         // Move to the next query
         at::native::data_index_step(i, batchSize, j, num_head, k, qSlice);
       }
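Note on the stored quantity: the forward pass already tracks, per query row, the running maximum of the attention scores (qk_max_data) and the running sum of exponentials shifted by that maximum (qk_sum_data), so the logsumexp needed by the backward pass falls out as qk_max + log(qk_sum) = log(sum_i exp(s_i)), without ever exponentiating the raw scores. Below is a minimal standalone sketch of that identity; the function and variable names are illustrative only, not the kernel's actual code.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Stable per-row logsumexp: lse = max_i s_i + log(sum_i exp(s_i - max_i s_i)).
// This mirrors the value the patch stores as
// qk_max_data[row] + std::log(qk_sum_data[row]).
double logsumexp_row(const std::vector<double>& scores) {
  double row_max = *std::max_element(scores.begin(), scores.end());
  double row_sum = 0.0;
  for (double s : scores) {
    row_sum += std::exp(s - row_max); // every term is <= 1, so no overflow
  }
  return row_max + std::log(row_sum);
}

int main() {
  // A naive exp(1000.0) would overflow a double; the shifted form does not.
  std::vector<double> scores = {1000.0, 999.0, 998.0};
  std::printf("lse = %f\n", logsumexp_row(scores)); // ~1000.407606
  return 0;
}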