[kernels] refactor flash attention for continuous batching (#361)
abenmao authored and Duyi-Wang committed May 15, 2024
1 parent 451ef21 commit 5e98e6d
Showing 4 changed files with 313 additions and 130 deletions.
128 changes: 127 additions & 1 deletion src/kernels/attention_kernels.h
@@ -778,4 +778,130 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in
} // end for b
}

} // namespace xft
// scaled dot-product attention: bmm1 + softmax + bmm2
// query, key and value are all in [*, seqLen, headNum, headSize] order
template <typename T, typename AttnT>
void selfScaledDpAttention(T *output, const T *query, const AttnT *key, const AttnT *value, int qHeadNum, int kvHeadNum,
int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *inputSeqLens,
const int *pastSeqLens, bool causal, const float *alibiSlopes, const float *attnMask, const float scale,
int threadNum) {
// output = softmax(query * trans(key)) * value
// causal = true: llama family, chatglm2; baichuan additionally passes alibiSlopes
// causal = false: only chatglm (prefixLLM, 0:startid) needs attnMask for now
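// The [srcLen x tgtLen] score matrix is tiled into [srcBlk x tgtBlk] blocks with per-row running
// max/sum (flash-attention style online softmax), so the full score matrix is never materialized.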

// get the max source and target seqLen across the batch
int maxSrcLen = 0, maxTgtLen = 0;
for (int i = 0; i < batchSize; ++i) {
maxSrcLen = std::max(maxSrcLen, inputSeqLens[i]);
maxTgtLen = std::max(maxTgtLen, inputSeqLens[i] + pastSeqLens[i]);
}
// compute seqStartLoc: token offset of each sequence in the packed (continuous-batching) layout
int seqStartLoc[batchSize + 1];
seqStartLoc[0] = 0;
for (int i = 0; i < batchSize; ++i) {
seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
}

// largest power of 2 not exceeding maxSrcLen / 2
int minBlk = (int)std::pow(2, int(std::log2(maxSrcLen / 2)));
// Split the sequence to keep a moderate sync frequency and to keep the intermediate
// [srcBlk x tgtBlk] result in cache. The current block sizes are derived from practical experience.
int srcBlk = std::min(256, minBlk);
int tgtBlk = std::min(512, maxTgtLen);

int numGroup = qHeadNum / kvHeadNum;

int numArr = 7;
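// per-thread scratch: preSum/sum/preMax/max (srcBlk floats each), qk (srcBlk * tgtBlk), expQkv and q (srcBlk * headSize each)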
int arrStride = (4 + tgtBlk + 2 * headSize) * srcBlk;
float *thrBuf
= (float *)SimpleMemPool::instance().getBuffer("threadBuffers", sizeof(float) * threadNum * arrStride);
float **thrPtrBuf
= (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", sizeof(float *) * threadNum * numArr);

float **preSum = thrPtrBuf;
float **sum = thrPtrBuf + threadNum;
float **preMax = thrPtrBuf + threadNum * 2;
float **max = thrPtrBuf + threadNum * 3;
float **qkArr = thrPtrBuf + threadNum * 4;
float **expQkvArr = thrPtrBuf + threadNum * 5;
float **qArr = thrPtrBuf + threadNum * 6;

for (int i = 0; i < threadNum; ++i) {
preSum[i] = thrBuf + srcBlk * i;
sum[i] = thrBuf + srcBlk * threadNum + srcBlk * i;
preMax[i] = thrBuf + srcBlk * threadNum * 2 + srcBlk * i;
max[i] = thrBuf + srcBlk * threadNum * 3 + srcBlk * i;
qkArr[i] = thrBuf + srcBlk * threadNum * 4 + srcBlk * tgtBlk * i;
expQkvArr[i] = thrBuf + srcBlk * threadNum * (4 + tgtBlk) + srcBlk * headSize * i;
qArr[i] = thrBuf + srcBlk * threadNum * (4 + tgtBlk + headSize) + srcBlk * headSize * i;
}
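// parallelize over (batch, query head, source block); dynamic scheduling balances the uneven per-sequence work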

#pragma omp parallel for collapse(3) schedule(dynamic)
for (uint64_t b = 0; b < batchSize; ++b) {
for (int h = 0; h < qHeadNum; ++h) {
for (int m = 0; m < maxSrcLen; m += srcBlk) {
int srcLen = inputSeqLens[b];
int tgtLen = inputSeqLens[b] + pastSeqLens[b];
if (m >= srcLen) { continue; }

int tid = omp_get_thread_num();
int qRealBlk = std::min(srcBlk, srcLen - m);
uint64_t srcOff = seqStartLoc[b] * qStride + h * headSize;
uint64_t outOff = seqStartLoc[b] * oStride + h * headSize;
const T *qbuf = query + srcOff + m * qStride;
AttnT *q = (AttnT *)qArr[tid];
T *out = output + outOff + m * oStride;

// reset the output block and convert the query block to AttnT
for (int ii = 0; ii < qRealBlk; ++ii) {
#pragma omp simd
for (int jj = 0; jj < headSize; ++jj) {
out[ii * oStride + jj] = 0; // reset output
q[ii * headSize + jj] = (AttnT)(qbuf[ii * qStride + jj]); // copy/convert query into the AttnT scratch buffer
}
}
// reset the running softmax sums and maxima
#pragma omp simd
for (int ii = 0; ii < qRealBlk; ++ii) {
preSum[tid][ii] = 0;
sum[tid][ii] = 0;
preMax[tid][ii] = std::numeric_limits<float>::lowest();
max[tid][ii] = std::numeric_limits<float>::lowest();
}
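// grouped-query attention: query head h reads the KV head h / numGroup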

uint64_t tgtOff = seqStartLoc[b] * kvStride + (h / numGroup) * headSize;
const AttnT *k = key + tgtOff;
const AttnT *v = value + tgtOff;
// split the target len dimension
for (int n = 0; n < tgtLen; n += tgtBlk) {
int kvRealBlk = std::min(tgtBlk, tgtLen - n);
// skip blocks that the causal mask fully masks out. TODO: do the same for prefixLLM
if (causal && m + qRealBlk - 1 < n) {
//printf("Skip bs %d head %d src %d tgt %d\n", b, h, m, n);
break;
}

const AttnT *kBlk = k + n * kvStride;
const AttnT *vBlk = v + n * kvStride;

if (causal) {
// causal=true, built-in causal mask
float headSlope = alibiSlopes != nullptr ? alibiSlopes[h] : 0.0f;
DecoderUtil::incrementalTileAttentionCausal(q, kBlk, vBlk, headSlope, m, n, qRealBlk, headSize,
kvRealBlk, preSum[tid], sum[tid], preMax[tid], max[tid], scale, qkArr[tid],
expQkvArr[tid], out, headSize, kvStride, kvStride, oStride);
} else {
// causal=false, needs an explicit mask matrix for now
const float *attnMsk = attnMask + seqStartLoc[b] * tgtLen + m * tgtLen + n;
DecoderUtil::incrementalTileAttention(q, kBlk, vBlk, attnMsk, qRealBlk, headSize, kvRealBlk,
tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], scale, qkArr[tid], expQkvArr[tid],
out, headSize, kvStride, kvStride, oStride);
}
}
}
}
}
return;
}

} // namespace xft
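For orientation, here is a minimal sketch of how a caller might invoke the new kernel on two packed sequences. The sequence lengths, head counts, scale, and buffer setup below are illustrative assumptions, not part of this commit; the real call site is the flash-attention path in src/layers/attention.h shown next (bfloat16_t comes from this repo's headers).

#include <cmath>
#include <vector>
#include <omp.h>

void exampleSelfScaledDpAttention() {
    // Two sequences of 5 and 9 tokens packed back to back, no past KV, causal attention (assumed shapes).
    int inputSeqLens[] = {5, 9};
    int pastSeqLens[] = {0, 0};
    int batchSize = 2, qHeadNum = 32, kvHeadNum = 32, headSize = 128;
    int totalTokens = inputSeqLens[0] + inputSeqLens[1];
    int qStride = qHeadNum * headSize;       // query packed as [totalTokens, qHeadNum, headSize]
    int kvStride = 2 * kvHeadNum * headSize; // K and V interleaved per token, mirroring the attention.h caller
    int oStride = qHeadNum * headSize;
    float scale = 1.0f / std::sqrt((float)headSize); // typical 1/sqrt(headSize); the layer passes ctx->attFactor

    std::vector<float> q(totalTokens * qStride), out(totalTokens * oStride);
    std::vector<bfloat16_t> kv(totalTokens * kvStride);
    bfloat16_t *k = kv.data();                        // K of token t, kv-head g at k[t * kvStride + g * headSize]
    bfloat16_t *v = kv.data() + kvHeadNum * headSize; // V follows K within each token

    xft::selfScaledDpAttention<float, bfloat16_t>(out.data(), q.data(), k, v, qHeadNum, kvHeadNum, headSize,
            oStride, qStride, kvStride, batchSize, inputSeqLens, pastSeqLens, /*causal=*/true,
            /*alibiSlopes=*/nullptr, /*attnMask=*/nullptr, scale, omp_get_max_threads());
}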
145 changes: 30 additions & 115 deletions src/layers/attention.h
@@ -67,6 +67,8 @@ class Attention {
printf("Not supported yet: QHeads=%d, KVHeads=%d\n", ctx->attHeadNum, ctx->kvHeadNum);
exit(-1);
}

alibiSlopes = nullptr;
}

// The interface is for PyTorch, thus the weights are already transposed
@@ -701,8 +703,14 @@
int kvCols = respKVHeads * headSize;
int qkvCols = qCols + kvCols * 2;
float scale = ctx->attFactor;
int srcLen = ctx->inputSeqLen;
int tgtLen = pastSeqLen + srcLen;

int totalTokenSize = 0;
int inputSeqLens[batchSize], pastSeqLens[batchSize];
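// for now every sequence in the batch shares the same input/past length; the kernel itself takes per-sequence values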
for (int i = 0; i < batchSize; ++i) {
inputSeqLens[i] = ctx->inputSeqLen;
pastSeqLens[i] = pastSeqLen;
totalTokenSize += inputSeqLens[i];
}

// TODO: kv dtype conversion for prefixSharing
AttnT *k, *v;
@@ -712,22 +720,21 @@
//Timer tmc(true, "convert KV matrix into bf16");
kvStride = kvCols * 2;
AttnT *kvBuf = (AttnT *)SimpleMemPool::instance().getBuffer(
"flashKVBuf", batchSize * srcLen * kvStride * sizeof(AttnT));
#pragma omp parallel for collapse(3)
for (uint64_t b = 0; b < batchSize; ++b)
for (uint64_t seq = 0; seq < srcLen; ++seq)
for (uint64_t i = 0; i < kvCols * 2; i += headSize) {
const ImT *srcPtr = key.Data() + b * srcLen * qkvCols + seq * qkvCols + i;
AttnT *dstPtr = kvBuf + b * srcLen * kvStride + seq * kvStride + i;
if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bfloat16_t>) {
bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize);
} else {
printf("Not supported Type in Flash Attention yet\n");
exit(-1);
}
"flashKVBuf", totalTokenSize * kvStride * sizeof(AttnT));
#pragma omp parallel for collapse(2)
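// convert the packed K/V tokens from ImT to AttnT, one headSize-wide chunk at a time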
for (uint64_t seq = 0; seq < totalTokenSize; ++seq)
for (uint64_t i = 0; i < kvCols * 2; i += headSize) {
const ImT *srcPtr = key.Data() + seq * qkvCols + i;
AttnT *dstPtr = kvBuf + seq * kvStride + i;
if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bfloat16_t>) {
bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize);
} else {
printf("Not supported Type in Flash Attention yet\n");
exit(-1);
}
}

k = kvBuf;
v = kvBuf + kvCols;
@@ -738,109 +745,14 @@
}

// [batch, src, head, headsize]
scaledDpAttention<AttnT>(query.Data(), k, v, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads,
respKVHeads, headSize, result.Data(), query.Stride(), kvStride, result.Stride());
xft::selfScaledDpAttention<ImT, AttnT>(result.Data(), query.Data(), k, v, respQHeads, respKVHeads, headSize,
result.Stride(), query.Stride(), kvStride, batchSize, inputSeqLens, pastSeqLens, true, alibiSlopes,
attnMask, scale, ctx->numThreads);

// copy current key/values to cache
copyKVCache(ctx, key, value, presentKey, presentValue, pastSeqLen);
}

// scaled dot-product attention: bmm1 + softmax + bmm2
template <typename AttnT>
void scaledDpAttention(const ImT *query, const AttnT *key, const AttnT *value, const float *attnMask, float scale,
int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, ImT *output, int qStride,
int kvStride, int stride) {
// output = trans(softmax(query * trans(key)) * value)
int nth = omp_get_max_threads();
// closest value of power of 2
int minBlk = (int)std::pow(2, int(std::log2(srcLen / 2)));
// Split sequence to make sure a moderate sync frequency and the intermediate
// result [srcSeq * tgtSeq] in cache. The current block size is derived from practical experience.
int srcBlk = std::min(256, minBlk);
int tgtBlk = std::min(512, tgtLen);
float refac = scale;
int numGroup = numQHead / numKVHead;

int numArr = 7;
int arrStride = (4 + tgtBlk + 2 * headSize) * srcBlk;
float *thrBuf = (float *)SimpleMemPool::instance().getBuffer("threadBuffers", sizeof(float) * nth * arrStride);
float **thrPtrBuf
= (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", sizeof(float *) * nth * numArr);

float **preSum = thrPtrBuf;
float **sum = thrPtrBuf + nth;
float **preMax = thrPtrBuf + nth * 2;
float **max = thrPtrBuf + nth * 3;
float **qkArr = thrPtrBuf + nth * 4;
float **expQkvArr = thrPtrBuf + nth * 5;
float **qArr = thrPtrBuf + nth * 6;

for (int i = 0; i < nth; ++i) {
preSum[i] = thrBuf + srcBlk * i;
sum[i] = thrBuf + srcBlk * nth + srcBlk * i;
preMax[i] = thrBuf + srcBlk * nth * 2 + srcBlk * i;
max[i] = thrBuf + srcBlk * nth * 3 + srcBlk * i;
qkArr[i] = thrBuf + srcBlk * nth * 4 + srcBlk * tgtBlk * i;
expQkvArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk) + srcBlk * headSize * i;
qArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk + headSize) + srcBlk * headSize * i;
}

#pragma omp parallel for collapse(3) schedule(dynamic)
for (uint64_t i = 0; i < batchSize; ++i) {
for (int j = 0; j < numQHead; ++j) {
for (int m = 0; m < srcLen; m += srcBlk) {
int tid = omp_get_thread_num();

int qRealBlk = std::min(srcBlk, srcLen - m);
uint64_t srcOff = i * srcLen * qStride + j * headSize;
uint64_t outOff = i * srcLen * stride + j * headSize;
const ImT *qbuf = query + srcOff + m * qStride;
AttnT *q = (AttnT *)qArr[tid];
ImT *out = output + outOff + m * stride;

// reset out
for (int ii = 0; ii < qRealBlk; ++ii) {
#pragma omp simd
for (int jj = 0; jj < headSize; ++jj) {
out[ii * stride + jj] = 0; // reset output
q[ii * headSize + jj] = (AttnT)(qbuf[ii * qStride + jj]); // reset output
}
}
// reset sum
#pragma omp simd
for (int ii = 0; ii < qRealBlk; ++ii) {
preSum[tid][ii] = 0;
sum[tid][ii] = 0;
preMax[tid][ii] = std::numeric_limits<float>::lowest();
max[tid][ii] = std::numeric_limits<float>::lowest();
}

uint64_t tgtOff = i * tgtLen * kvStride + (j / numGroup) * headSize;
const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen;
const AttnT *k = key + tgtOff;
const AttnT *v = value + tgtOff;
// split the target len dimension
for (int b = 0; b < tgtLen; b += tgtBlk) {
int kvRealBlk = std::min(tgtBlk, tgtLen - b);
// TODO: mask out
if (enableSkipMsk() && DecoderUtil::skipMskAttn(attnMsk + b, qRealBlk, kvRealBlk, tgtLen)) {
// printf("Skip bs %d head %d src %d tgt %d\n", i, j, m, b);
break;
}

const AttnT *kBlk = k + b * kvStride;
const AttnT *vBlk = v + b * kvStride;

DecoderUtil::incrementalTileAttention(q, kBlk, vBlk, attnMsk + b, qRealBlk, headSize, kvRealBlk,
tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], refac, qkArr[tid], expQkvArr[tid],
out, headSize, kvStride, kvStride, stride);
}
}
}
}
return;
}

private:
std::pair<int, int> getTaskRange(int N, int splits, int splitIdx) {
int startId, endId;
@@ -906,6 +818,9 @@ class Attention {
NORM_CLS norm;
int layerId;

// ALiBi slopes (allocated by subclasses such as BaichuanAttention)
float *alibiSlopes;

// The responsible head in the global view
// If in single instance, startQHead=startKVHead=0, and endQHead-startQHead=qHeadNum
int startQHead;
17 changes: 9 additions & 8 deletions src/layers/attn_baichuan.h
@@ -28,20 +28,21 @@ template <typename WeiT, typename QKPO_CLS = QKPO_Dummy, typename NORM_CLS = Rms
class BaichuanAttention : public Attention<WeiT, QKPO_CLS, NORM_CLS> {
public:
BaichuanAttention(int layerId, DecoderContext *ctx) : Attention<WeiT, QKPO_CLS, NORM_CLS>(layerId, ctx) {
if (ctx->maxPosEmbed <= 0 && alibiSlopes == nullptr) {
if (ctx->maxPosEmbed <= 0 && this->alibiSlopes == nullptr) {
respBaichuanHeads = this->endQHead - this->startQHead;
alibiSlopes = new float[respBaichuanHeads];
this->alibiSlopes = new float[respBaichuanHeads];
// per-head ALiBi slopes
float ratio = std::pow(2, 8);
int closestPowerOf2 = std::pow(2, int(std::log2(ctx->attHeadNum)));
float x0 = std::pow(ratio, 1.0 / closestPowerOf2);
float x1 = std::pow(ratio, 1.0 / (closestPowerOf2 * 2));
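// standard ALiBi schedule: slope = 2^(-8 * (h + 1) / closestPowerOf2) for the first closestPowerOf2 heads, with the interpolated base x1 for any extra heads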
for (int i = 0, h = this->startQHead; i < respBaichuanHeads; ++i, ++h) {
if (h < closestPowerOf2)
alibiSlopes[i] = 1 / std::pow(x0, h + 1);
this->alibiSlopes[i] = 1 / std::pow(x0, h + 1);
else
alibiSlopes[i] = 1 / std::pow(x1, 2 * (h - closestPowerOf2) + 1);
this->alibiSlopes[i] = 1 / std::pow(x1, 2 * (h - closestPowerOf2) + 1);
}
alibiSlopes = this->alibiSlopes;
}
}

@@ -50,15 +51,15 @@ class BaichuanAttention : public Attention<WeiT, QKPO_CLS, NORM_CLS> {
const static int getResponsibleHeads() { return respBaichuanHeads; }

virtual ~BaichuanAttention() {
if (alibiSlopes != nullptr) {
delete[] alibiSlopes;
alibiSlopes = nullptr;
if (this->alibiSlopes != nullptr) {
delete[] this->alibiSlopes;
this->alibiSlopes = nullptr;
}
}

protected:
const float *getMask(const float *attnMask, int bId, int hId, int srcLen, int tgtLen) override {
if (alibiSlopes != nullptr)
if (this->alibiSlopes != nullptr)
return attnMask + hId * srcLen * tgtLen;
else
return attnMask + bId * srcLen * tgtLen;