[Model/Layer] New forward to support CB (CommonDecoder->DecoderBlock->DecoderLayer->Attention/MLP) #375

Merged: 40 commits, May 9, 2024

Commits
8f351ba
[Common] Add sequenceMeta, sequenceGroup and sequencePool. (#343)
changqi1 Apr 25, 2024
5ab63df
Merge commit 'd2b8df0c85ba57b62169d74c88192de3bf6e4820' into cb_dev
pujiang2018 Apr 26, 2024
dbcb267
merge batchSize and seqLen into one in TokenEmbedding
pujiang2018 Apr 26, 2024
3949abd
merge batchSize and seqLen into one in TokenEmbedding (#350)
pujiang2018 Apr 26, 2024
9a53fb2
[Common] Move Martix into xft namespace. (#351)
Duyi-Wang Apr 26, 2024
25ee312
Merge commit '9a53fb2ea6b9141ba7c045bc0d135c1809e8f22c' into pujiang/…
pujiang2018 Apr 26, 2024
376b2bc
remove unused function in DecoderLayer
pujiang2018 Apr 26, 2024
7112b33
[Layer] Remove unused functions in Decoder layer (#353)
pujiang2018 Apr 26, 2024
4ff4707
Merge commit '819ecccfa06662bddb16e9e1cd1ec8775d1c0180' into cb_dev
pujiang2018 Apr 29, 2024
d281a54
Merge commit '4ff47074fc85a27e13251c3fb618f36e338c456f' into pujiang/…
pujiang2018 Apr 29, 2024
b5b225a
fix compile error of embeddingForward
pujiang2018 Apr 29, 2024
f8f8571
[Model] Fix compile error of embeddingForward in YaRNLlama (#358)
pujiang2018 Apr 29, 2024
b95dac1
[Common] Add sampling params into group seq. (#356)
Duyi-Wang Apr 29, 2024
d5c9407
remove DecoderContext in computeSoftmax
pujiang2018 Apr 29, 2024
be615b2
Merge commit 'f8f85714331c0df2ce4a8344e06972316770ec11' into pujiang/…
pujiang2018 Apr 29, 2024
e48ea1f
[Util] Remove DecoderContext in computeSoftmax (#362)
pujiang2018 Apr 29, 2024
7ad311e
[Common] Refactor sequence.h. (#363)
Duyi-Wang Apr 30, 2024
9e7bdca
[kernels] refactor flash attention for continuous batching (#361)
abenmao Apr 30, 2024
cfab63a
[models] Add attnMeta for continuous batching (#364)
abenmao Apr 30, 2024
deabd33
[Layers] fix build error (#365)
abenmao Apr 30, 2024
2499f60
[Model] add interface for seq meta. (#366)
Duyi-Wang Apr 30, 2024
5833d41
Merge commit '2499f602c22184ca5afaa2f013ae0ff4e3bd4263' into pujiang/…
pujiang2018 May 3, 2024
0514833
refactor resize function in DecoderContext to support CB, and qkScore…
pujiang2018 May 6, 2024
c792aff
[Common] Modify resize() in DecoderContext to support (#367)
pujiang2018 May 6, 2024
5315a5a
Merge commit 'c792aff5f7cc8554afa0399f4b6d241333b0b56c' into pujiang/…
pujiang2018 May 6, 2024
63e895a
add some code to CommonDecoder::forward()
pujiang2018 May 6, 2024
9cfc7c6
SequenceMeta refactor
pujiang2018 May 7, 2024
2e05716
[Model] New CommonDecoder::forward impl. skeleton (#369)
pujiang2018 May 7, 2024
41e692c
new KVCacheMgr supporting CB
pujiang2018 May 7, 2024
32c845c
Merge commit '2e057165b456ef9b88591a880403ffe47e7500c3' into pujiang/…
pujiang2018 May 7, 2024
c2ac8d2
fix typo & set default prefixId to -1 in addSequence()
pujiang2018 May 7, 2024
dfa1d0e
[Common] New KVCacheMgr to support CB (#371)
pujiang2018 May 7, 2024
e79c9c2
[Sampling] Add repetition penalty for new seq type. (#373)
Duyi-Wang May 8, 2024
1e93ef0
Merge commit 'e79c9c21825ad51116f49a290baa7eea80e25b07' into pujiang/…
pujiang2018 May 8, 2024
689b41a
New forward to support CB (CommonDecoder->DecoderBlock->DecoderLayer->…
pujiang2018 May 8, 2024
3ce8d06
add todo
pujiang2018 May 8, 2024
9377f83
[Sampling] Add greedy search for cb path. (#376)
Duyi-Wang May 8, 2024
b71ec05
Merge commit '9377f8371ebfe4999c2171eceab106a52ff2618a' into pujiang/…
pujiang2018 May 8, 2024
3144675
logic issue fix
pujiang2018 May 8, 2024
fcee295
Merge branch 'cb_dev' into pujiang/feature/cb_dev
pujiang2018 May 9, 2024
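
Taken together, these commits thread continuous batching (CB) through the whole forward chain: sequences of different lengths are packed into a single flat token buffer, and a vector of SequenceMeta describing each sequence is handed from CommonDecoder down through DecoderBlock and DecoderLayer into Attention/MLP, which size their work by the summed token count. As a reading aid, below is a minimal, self-contained sketch of that idea; the types and functions are simplified stand-ins invented for illustration, not the actual classes in this repository.

// Illustrative sketch only: simplified stand-ins showing how continuous batching
// packs variable-length sequences into one flat buffer and threads per-sequence
// metadata down the forward chain. Not the real SequenceMeta or decoder classes.
#include <cstdio>
#include <vector>

struct SequenceMeta {            // stand-in; the real class lives in src/common/sequence.h
    int inputSeqLen;             // tokens this sequence contributes in the current step
    int step;                    // 0 => prefill (first token), >0 => decode
    int getInputSeqLen() const { return inputSeqLen; }
    int getStep() const { return step; }
};

// Stand-in for Attention::forward / MLP::forward: consumes the packed token buffer.
void attentionForward(std::vector<SequenceMeta *> &seqs, const float *packed, size_t totInSeqLen) {
    (void)packed; // the real code reads QKV activations from this buffer
    std::printf("attention over %zu packed tokens from %zu sequences\n", totInSeqLen, seqs.size());
}

// Stand-in for CommonDecoder::forward -> DecoderBlock::forward -> DecoderLayer::forward.
void decoderForward(std::vector<SequenceMeta *> &seqs, std::vector<float> &hidden, int hiddenSize) {
    // Total tokens across all sequences; every layer sees one [totInSeqLen, hiddenSize] buffer.
    size_t totInSeqLen = 0;
    for (auto *s : seqs) totInSeqLen += s->getInputSeqLen();
    (void)hiddenSize; // the real code derives strides from DecoderContext
    attentionForward(seqs, hidden.data(), totInSeqLen);
}

int main() {
    // Two prefill sequences of different lengths batched together (hiddenSize assumed 4096).
    SequenceMeta a{5, 0}, b{3, 0};
    std::vector<SequenceMeta *> seqs{&a, &b};
    const int hiddenSize = 4096;
    std::vector<float> hidden((5 + 3) * hiddenSize, 0.0f);
    decoderForward(seqs, hidden, hiddenSize);
    return 0;
}

The real path also carries per-sequence KV caches and step information (prefill vs. decode), as the attention.h changes below show.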
214 changes: 214 additions & 0 deletions src/layers/attention.h
@@ -25,6 +25,7 @@
#include "gemm_kernel_ext.h"
#include "kvcache_tensor.h"
#include "matmul_helper.h"
#include "sequence.h"
#include "simple_mem_pool.h"
#include "transformer_ctx.h"
#include "transformer_util.h"
@@ -378,6 +379,191 @@ class Attention {
}
t5.release();

#ifdef DEBUG
dbg.debugPrint(">>> attention output/projection[%d, %d] (%d):\n", outBuffer.Rows(), outBuffer.Cols(),
outBuffer.Stride());
dbg.dumpMatrix(outBuffer);
#endif

if (!doLnBefore) {
TimeLine t6("result.layer_norm");
norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride());
#ifdef DEBUG
dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(),
outBuffer.Stride());
dbg.dumpMatrix(outBuffer);
#endif
}
}

/**
* Forward computing for the whole Attention layer (QKV MatMul + MHA/GQA + Output MatMul)
*/
template <typename KVCacheT>
void forward(DecoderContext *ctx, std::vector<xft::SequenceMeta *> &seqs, InT *input, OutT *output,
size_t totInSeqLen, std::vector<KVCacheTensor<KVCacheT> *> &keyCaches,
std::vector<KVCacheTensor<KVCacheT> *> &valueCaches, bool doLnBefore = true) {

auto hiddenSize = ctx->hiddenSize;
xft::Matrix<InT> inputBuffer(input, totInSeqLen, hiddenSize, hiddenSize);
ImT *imBuf = (ImT *)ctx->getBuffer<ImT>("tmp", totInSeqLen * hiddenSize);
xft::Matrix<ImT> imBuffer(imBuf, totInSeqLen, hiddenSize, hiddenSize);
xft::Matrix<OutT> outBuffer(output, totInSeqLen, hiddenSize, hiddenSize);

float epsilon = ctx->epsilon;
int headSize = ctx->attHeadSize;
auto qkvRows = totInSeqLen;
int qCols = (this->endQHead - this->startQHead) * headSize;
int kvCols = (this->endKVHead - this->startKVHead) * headSize;
int qkCols = qCols + kvCols;
int qkvCols = qkCols + kvCols;

int qkvStride = qkvCols;
auto &qkvMatMul = ctx->qkvMatMul;
xft::Matrix<ImT> qkvGroupMatMul((ImT *)qkvMatMul.Data(), qkvRows, qkvCols, qkvStride);

#ifdef DEBUG
dbg.debugPrint("---- DecoderLayer.forward (useSelfAttn=%d) ----\n", useSelfAttn);
dbg.debugPrint("input:\n");
dbg.dumpMatrix(inputBuffer);
#endif

if (doLnBefore) {
TimeLine t1("input.layer_norm");
norm.forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(),
imBuffer.Stride(), epsilon);
}
#ifdef DEBUG
dbg.debugPrint("layer norm:\n");
dbg.dumpMatrix(imBuffer);
dbg.debugPrint("qkvWeight [%d, %d]:\n", this->qkvWeight.Rows(), this->qkvWeight.Cols());
dbg.dumpMatrix(this->qkvWeight);
#endif

// Query, Key, Value computed together
TimeLine t2("QKV.linear");
if (qkvBias.Size() == 0) {
ctx->mmHelper->compute(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(),
imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride());
} else {
ctx->mmHelper->compute_bias(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f,
imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data());
}
t2.release();

xft::Matrix<ImT> query(qkvGroupMatMul, 0, inputBuffer.Rows(), 0, qCols);
xft::Matrix<ImT> key(qkvGroupMatMul, 0, inputBuffer.Rows(), qCols, kvCols);
xft::Matrix<ImT> value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kvCols);

#ifdef DEBUG
dbg.debugPrint("Q[%d,%d](%d):\n", query.Rows(), query.Cols(), query.Stride());
dbg.dumpMatrix(query);
dbg.debugPrint("K[%d,%d](%d):\n", key.Rows(), key.Cols(), key.Stride());
dbg.dumpMatrix(key);
dbg.debugPrint("V[%d,%d](%d):\n", value.Rows(), value.Cols(), value.Stride());
dbg.dumpMatrix(value);
#endif

// Apply post operations on query and key
TimeLine t3("QKPO");
// TODO: call into rotary embedding
// int qheads = this->endQHead - this->startQHead;
// int kheads = this->endKVHead - this->startKVHead;
// int qkShape[7] = {ctx->batchSize, ctx->inputSeqLen, qheads, headSize, kheads, ctx->maxSeqLength, pastSeqLen};
// if (positionIds != nullptr) {
// qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, positionIds);
// } else if (ctx->maxPosEmbed > 0) {
// // Use the default position ids
// std::vector<int> posIds(ctx->inputSeqLen);
// if (inputSeqLen == 1) {
// posIds[0] = pastSeqLen;
// } else {
// std::iota(posIds.begin(), posIds.end(), pastSeqLen);
// }
// qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, posIds.data());
// }
t3.release();

#ifdef DEBUG
dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride());
dbg.dumpMatrix(query);
dbg.debugPrint("K[%d,%d](%d) after post op:\n", key.Rows(), key.Cols(), key.Stride());
dbg.dumpMatrix(key);
#endif

// Revise attnFactor before softmax (for some models, attnFactor may not be the default value)
// We initially introduced this code for ChatGLM, but eventually found it made no difference and was unnecessary.
// However, we keep it in the codebase in case it becomes useful for future models.
if (getScalingCoeff() != 0) { ctx->attFactor = getScalingCoeff(); }

TimeLine t4("MHA");
if constexpr (!INPUT_AS_RESID) { // Swap inputBuffer and imBuffer
auto tmp = imBuffer.Data();
int rows = imBuffer.Rows(), cols = imBuffer.Cols(), stride = imBuffer.Stride();
imBuffer.Assign(inputBuffer.Data(), inputBuffer.Rows(), inputBuffer.Cols(), inputBuffer.Stride());
inputBuffer.Assign(tmp, rows, cols, stride);
}

// For multi-node inference, this holds only this split's share of the attention result, not the whole buffer
xft::Matrix<ImT> attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols);

if (seqs[0]->getStep() == 0) { // First token generation
// TODO: add flashAttention
if constexpr (std::is_same_v<InT, bfloat16_t> && std::is_same_v<OutT, bfloat16_t>) {
selfAttentionBF16(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs);
} else {
fusedAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs);
}
} else {
fusedAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs);
}
t4.release();

#ifdef DEBUG
dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(),
attnSplit.Cols(), attnSplit.Stride());
dbg.dumpMatrix(attnSplit);
#endif

TimeLine t5("Output");
// Output projection in attention; the residual input is added only in the first split
if (ctx->splitIdx == 0) {
float gamma = getResidentialScale();

// denseWithScaledSum should be enough, but since its performance has not been verified,
// we still use denseWithSum here
if (gamma == 1) {
float *pbias = attnOutputBias.Data();
if (attnOutputBias.Size() == 0) { pbias = nullptr; }
ctx->mmHelper->compute_residential(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(),
1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(),
attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f,
outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride());
} else {
float *pbias = attnOutputBias.Data();
if (attnOutputBias.Size() == 0) { pbias = nullptr; }
ctx->mmHelper->compute_resext(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride());
}
} else {
if (attnOutputBias.Size() == 0) {
ctx->mmHelper->compute(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride());
} else {
ctx->mmHelper->compute_bias(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride(), attnOutputBias.Data());
}
}
t5.release();

#ifdef DEBUG
dbg.debugPrint(">>> attention output/projection[%d, %d] (%d):\n", outBuffer.Rows(), outBuffer.Cols(),
outBuffer.Stride());
@@ -416,6 +602,34 @@ class Attention {
[&](int b, int headIdx, int seqIdx) { return presentValue.getSequence(seqIdx, b, headIdx); });
}

template <typename KVCacheT>
void selfAttentionBF16(DecoderContext *ctx, xft::Matrix<bfloat16_t> &query, xft::Matrix<bfloat16_t> &key,
xft::Matrix<bfloat16_t> &value, xft::Matrix<bfloat16_t> &result,
std::vector<KVCacheTensor<KVCacheT> *> &keyCaches, std::vector<KVCacheTensor<KVCacheT> *> &valueCaches,
std::vector<xft::SequenceMeta *> &seqs) {
int responsibleQHeads = this->endQHead - this->startQHead;
int responsibleKVHeads = this->endKVHead - this->startKVHead;

int tokenSizes[ctx->batchSize];
for (int i = 0; i < ctx->batchSize; ++i) {
tokenSizes[i] = seqs[i]->getInputSeqLen();
}

xft::selfAttention(
result.Data(), query.Data(), key.Data(), value.Data(), responsibleQHeads, responsibleKVHeads,
ctx->attHeadSize, result.Stride(), query.Stride(), key.Stride(), ctx->batchSize, tokenSizes,
ctx->attFactor, ctx->numThreads,
[&](int b, int headIdx, int seqIdx) { return keyCaches[b]->getSequence(seqIdx, 0, headIdx); },
[&](int b, int headIdx, int seqIdx) { return valueCaches[b]->getSequence(seqIdx, 0, headIdx); });
}

template <typename T, typename KVCacheT>
void fusedAttention(DecoderContext *ctx, xft::Matrix<T> &query, xft::Matrix<T> &key, xft::Matrix<T> &value,
xft::Matrix<T> &result, std::vector<KVCacheTensor<KVCacheT> *> &keyCaches,
std::vector<KVCacheTensor<KVCacheT> *> &valueCaches, std::vector<xft::SequenceMeta *> &seqs) {
// TODO: implement fusedAttention
}

int getMBlockSize(int inputSeqLen, int headSize, int minVal = 6) {
// Special case
if (inputSeqLen == 1) { return 1; }
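
One detail of the new forward worth noting is how the fused QKV output is sliced: qCols, kvCols, qkCols and qkvCols are derived from the query/KV head ranges owned by this split, and query/key/value are column views into the same qkvGroupMatMul buffer at those offsets. The snippet below only reproduces that offset arithmetic for an assumed GQA configuration (32 query heads, 8 KV heads, head size 128, single split); the numbers are illustrative and not taken from any model in this PR.

// Worked example of the qCols/kvCols/qkCols arithmetic from the new Attention::forward,
// for an assumed GQA layout: 32 query heads, 8 KV heads, head size 128, one split.
#include <cstdio>

int main() {
    const int headSize = 128;
    const int startQHead = 0, endQHead = 32;   // query heads owned by this split (assumed)
    const int startKVHead = 0, endKVHead = 8;  // KV heads owned by this split (assumed)

    int qCols = (endQHead - startQHead) * headSize;     // 4096
    int kvCols = (endKVHead - startKVHead) * headSize;  // 1024
    int qkCols = qCols + kvCols;                        // 5120
    int qkvCols = qkCols + kvCols;                      // 6144, also the qkvGroupMatMul stride

    // Matching the xft::Matrix column views in the diff:
    //   query: columns [0, qCols), key: [qCols, qkCols), value: [qkCols, qkvCols)
    std::printf("query [0, %d), key [%d, %d), value [%d, %d), stride %d\n",
            qCols, qCols, qkCols, qkCols, qkvCols, qkvCols);
    return 0;
}

With tensor-parallel splits, startQHead/endQHead and startKVHead/endKVHead would cover only this rank's share of the heads, so these offsets shrink accordingly.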