[flash attention] calculate logsumexp for backward (#2631)
* [flash attention] calculate logsumexp for backward

* [flash attention] calculate logsumexp for backward reduced type

---------

Co-authored-by: jiayisunx <jiayi.sun@intel.com>
Co-authored-by: WeizhuoZhang-intel <weizhuo.zhang@intel.com>
Committed by 3 people on Mar 1, 2024
1 parent c7b1a8b commit f55a2bf
Showing 1 changed file with 22 additions and 0 deletions.
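For context: the backward pass of flash attention recomputes the softmax rather than storing the full attention matrix, and for that it needs each query row's log of the softmax denominator, the "logsumexp". This commit saves it using the numerically stable identity log(sum_i exp(x_i)) = m + log(sum_i exp(x_i - m)) with m = max_i x_i, which is exactly the `qk_max + log(qk_sum)` expression added below. A minimal standalone sketch of that identity (illustrative only, not IPEX code):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable logsumexp over one row of attention scores.
// Equivalent to std::log of the sum of exponentials, but safe for
// large-magnitude scores that would overflow a naive exp().
float logsumexp_row(const std::vector<float>& x) {
  float m = *std::max_element(x.begin(), x.end()); // row max (cf. qk_max)
  float sum = 0.0f;                                // cf. qk_sum
  for (float v : x) {
    sum += std::exp(v - m); // exponentiate relative to the max
  }
  return m + std::log(sum); // same form as qk_max_data[row] + log(qk_sum_data[row])
}
```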
csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp (22 additions, 0 deletions)
@@ -420,6 +420,9 @@ cpu_flash_attention(
   int64_t oStrideB = output.stride(0);
   int64_t oStrideM = output.stride(1);
   int64_t oStrideH = output.stride(2);
+  int64_t lStrideB = logsumexp.stride(0);
+  int64_t lStrideM = logsumexp.stride(1);
+  int64_t lStrideH = logsumexp.stride(2);
   int64_t mStrideB =
       (attention_mask.has_value() && attention_mask.value().size(0) > 1)
           ? attention_mask.value().stride(0)
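The three new strides read logsumexp as a 3-D tensor indexed by (batch, query position, head), mirroring how the output strides above are used, so the same i/j/m loop indices can address both tensors. A hedged sketch of how a caller might allocate such a buffer (the allocation itself is outside this diff; the shape, the `query` tensor, and the float dtype are assumptions inferred from the strides and from lse_data being accum_t):

```cpp
#include <ATen/ATen.h>

// Assumed allocation helper, inferred from how lStrideB/lStrideM/lStrideH
// are used: one accumulator-precision value per (batch, query, head).
at::Tensor alloc_logsumexp(
    const at::Tensor& query, // assumed shape [batchSize, qSize, num_head, headSize]
    int64_t batchSize,
    int64_t qSize,
    int64_t num_head) {
  return at::empty(
      {batchSize, qSize, num_head},
      query.options().dtype(at::kFloat)); // accum_t is float even for bf16/fp16
}
```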
@@ -459,6 +462,7 @@ cpu_flash_attention(
           ? attention_mask.value().data_ptr<accum_t>()
           : nullptr;
   scalar_t* out_data = output.data_ptr<scalar_t>();
+  accum_t* lse_data = logsumexp.data_ptr<accum_t>();
   accum_t* buf_data = buf.data_ptr<accum_t>();
 
   at::parallel_for(
@@ -616,6 +620,13 @@ cpu_flash_attention(
                 dst_data + row * headSize,
                 headSize);
           }
+          // Store logsumexp for backward
+          accum_t* lse_ptr =
+              lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
+          for (const auto row : c10::irange(qBlockSize)) {
+            lse_ptr[row * lStrideM] =
+                qk_max_data[row] + std::log(qk_sum_data[row]);
+          }
           // Move to the next query
           at::native::data_index_step(i, batchSize, j, num_head, k, qSlice);
         }
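At this point in the kernel, every key/value block for the current query block has been consumed, so qk_max_data holds each row's global max and qk_sum_data the matching sum of exp(score - max), the standard online-softmax invariant (the variable roles are inferred from the flash-attention algorithm, not spelled out in the diff). The stored value max + log(sum) is therefore the exact row logsumexp. A self-contained sketch of the online update and the final store:

```cpp
#include <cmath>
#include <cstdio>

// Demonstrates the online-softmax invariant the kernel relies on: after all
// key blocks are processed, qk_max is the row max and qk_sum is
// sum(exp(score - qk_max)), so qk_max + log(qk_sum) is the exact logsumexp.
int main() {
  // Scores for one query row, arriving in two "blocks" as in flash attention.
  float block1[] = {1.0f, 3.0f};
  float block2[] = {2.0f, 5.0f};
  float* blocks[] = {block1, block2};

  float qk_max = -INFINITY; // running row max
  float qk_sum = 0.0f;      // running sum of exp(score - qk_max)
  for (float* blk : blocks) {
    float blk_max = std::fmax(blk[0], blk[1]);
    float new_max = std::fmax(qk_max, blk_max);
    // Rescale the old sum to the new max, then add this block's terms.
    qk_sum = qk_sum * std::exp(qk_max - new_max) +
             std::exp(blk[0] - new_max) + std::exp(blk[1] - new_max);
    qk_max = new_max;
  }

  // What the kernel stores: lse = qk_max + log(qk_sum).
  float lse = qk_max + std::log(qk_sum);
  // Direct reference: log(exp(1) + exp(3) + exp(2) + exp(5)).
  float ref = std::log(std::exp(1.0f) + std::exp(3.0f) +
                       std::exp(2.0f) + std::exp(5.0f));
  std::printf("online lse = %f, reference = %f\n", lse, ref); // both ~5.1852
}
```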
@@ -684,6 +695,9 @@ cpu_flash_attention(
   int64_t oStrideB = output.stride(0);
   int64_t oStrideM = output.stride(1);
   int64_t oStrideH = output.stride(2);
+  int64_t lStrideB = logsumexp.stride(0);
+  int64_t lStrideM = logsumexp.stride(1);
+  int64_t lStrideH = logsumexp.stride(2);
   int64_t mStrideB =
       (attention_mask.has_value() && attention_mask.value().size(0) > 1)
           ? attention_mask.value().stride(0)
@@ -725,6 +739,7 @@ cpu_flash_attention(
           ? attention_mask.value().data_ptr<accum_t>()
           : nullptr;
   scalar_t* out_data = output.data_ptr<scalar_t>();
+  accum_t* lse_data = logsumexp.data_ptr<accum_t>();
   accum_t* buf_data = buf.data_ptr<accum_t>();
   scalar_t* buf_reduced_data = buf_reduced.data_ptr<scalar_t>();
 
@@ -1310,6 +1325,13 @@ cpu_flash_attention(
                 dst_data + row * headSize,
                 headSize);
           }
+          // Store logsumexp for backward
+          accum_t* lse_ptr =
+              lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
+          for (const auto row : c10::irange(qBlockSize)) {
+            lse_ptr[row * lStrideM] =
+                qk_max_data[row] + std::log(qk_sum_data[row]);
+          }
           // Move to the next query
           at::native::data_index_step(i, batchSize, j, num_head, k, qSlice);
         }
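Note that the second hunk pair applies the same change to the reduced-type (bf16/fp16) kernel; the logsumexp is still written in accum_t, so no precision is lost in the saved statistics. Why this matters downstream: with lse saved, a backward kernel can reconstruct attention probabilities row by row as p = exp(score - lse) without storing the full softmax matrix or redoing the max/sum reduction. A hedged sketch of that recomputation (illustrative only, not the actual IPEX backward kernel; all names here are made up):

```cpp
#include <cmath>
#include <vector>

// Illustrative only: recompute one row of attention probabilities from the
// raw scores and the saved logsumexp, p_j = exp(s_j - lse). This stands in
// for a full max/sum softmax pass during the backward computation.
std::vector<float> recompute_probs(const std::vector<float>& scores, float lse) {
  std::vector<float> probs(scores.size());
  for (size_t j = 0; j < scores.size(); ++j) {
    probs[j] = std::exp(scores[j] - lse);
  }
  return probs;
}
```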
