
[Model/Layer] New forward to support CB (CommonDecoder->DecoderBlock->DecoderLayer->Attention/MLP) (#375)
pujiang2018 committed May 9, 2024
1 parent 4a90500 commit 77dd36b
Showing 4 changed files with 594 additions and 35 deletions.
214 changes: 214 additions & 0 deletions src/layers/attention.h
@@ -25,6 +25,7 @@
#include "gemm_kernel_ext.h"
#include "kvcache_tensor.h"
#include "matmul_helper.h"
#include "sequence.h"
#include "simple_mem_pool.h"
#include "transformer_ctx.h"
#include "transformer_util.h"
@@ -378,6 +379,191 @@ class Attention {
}
t5.release();

#ifdef DEBUG
dbg.debugPrint(">>> attention output/projection[%d, %d] (%d):\n", outBuffer.Rows(), outBuffer.Cols(),
outBuffer.Stride());
dbg.dumpMatrix(outBuffer);
#endif

if (!doLnBefore) {
TimeLine t6("result.layer_norm");
norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride());
#ifdef DEBUG
dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(),
outBuffer.Stride());
dbg.dumpMatrix(outBuffer);
#endif
}
}

/**
* Forward computing for the whole Attention layer (QKV MatMul + MHA/GQA + Output MatMul).
* This overload supports continuous batching (CB): seqs carries per-sequence
* metadata, and totInSeqLen is the total number of tokens across all sequences.
*/
template <typename KVCacheT>
void forward(DecoderContext *ctx, std::vector<xft::SequenceMeta *> &seqs, InT *input, OutT *output,
size_t totInSeqLen, std::vector<KVCacheTensor<KVCacheT> *> &keyCaches,
std::vector<KVCacheTensor<KVCacheT> *> &valueCaches, bool doLnBefore = true) {

auto hiddenSize = ctx->hiddenSize;
xft::Matrix<InT> inputBuffer(input, totInSeqLen, hiddenSize, hiddenSize);
ImT *imBuf = (ImT *)ctx->getBuffer<ImT>("tmp", totInSeqLen * hiddenSize);
xft::Matrix<ImT> imBuffer(imBuf, totInSeqLen, hiddenSize, hiddenSize);
xft::Matrix<OutT> outBuffer(output, totInSeqLen, hiddenSize, hiddenSize);
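// inputBuffer aliases the caller's input and outBuffer the caller's output;
// imBuffer is scratch space drawn from the context's "tmp" memory pool buffer.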

float epsilon = ctx->epsilon;
int headSize = ctx->attHeadSize;
auto qkvRows = totInSeqLen;
int qCols = (this->endQHead - this->startQHead) * headSize;
int kvCols = (this->endKVHead - this->startKVHead) * headSize;
int qkCols = qCols + kvCols;
int qkvCols = qkCols + kvCols;

int qkvStride = qkvCols;
auto &qkvMatMul = ctx->qkvMatMul;
xft::Matrix<ImT> qkvGroupMatMul((ImT *)qkvMatMul.Data(), qkvRows, qkvCols, qkvStride);
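// Worked example (illustrative numbers only, assuming GQA with headSize = 128,
// 32 Q heads and 8 KV heads on this rank): qCols = 4096, kvCols = 1024,
// qkCols = 5120, qkvCols = 6144; each row of qkvGroupMatMul then packs
// [Q | K | V] for one token.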

#ifdef DEBUG
dbg.debugPrint("---- DecoderLayer.forward (useSelfAttn=%d) ----\n", useSelfAttn);
dbg.debugPrint("input:\n");
dbg.dumpMatrix(inputBuffer);
#endif

if (doLnBefore) {
TimeLine t1("input.layer_norm");
norm.forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(),
imBuffer.Stride(), epsilon);
}
#ifdef DEBUG
dbg.debugPrint("layer norm:\n");
dbg.dumpMatrix(imBuffer);
dbg.debugPrint("qkvWeight [%d, %d]:\n", this->qkvWeight.Rows(), this->qkvWeight.Cols());
dbg.dumpMatrix(this->qkvWeight);
#endif

// Query, Key, Value computed together
TimeLine t2("QKV.linear");
if (qkvBias.Size() == 0) {
ctx->mmHelper->compute(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(),
imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride());
} else {
ctx->mmHelper->compute_bias(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f,
imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data());
}
t2.release();

xft::Matrix<ImT> query(qkvGroupMatMul, 0, inputBuffer.Rows(), 0, qCols);
xft::Matrix<ImT> key(qkvGroupMatMul, 0, inputBuffer.Rows(), qCols, kvCols);
xft::Matrix<ImT> value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kvCols);
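// query/key/value are views into qkvGroupMatMul (same stride, no copies):
// columns [0, qCols) hold Q, [qCols, qkCols) hold K, [qkCols, qkvCols) hold V.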

#ifdef DEBUG
dbg.debugPrint("Q[%d,%d](%d):\n", query.Rows(), query.Cols(), query.Stride());
dbg.dumpMatrix(query);
dbg.debugPrint("K[%d,%d](%d):\n", key.Rows(), key.Cols(), key.Stride());
dbg.dumpMatrix(key);
dbg.debugPrint("V[%d,%d](%d):\n", value.Rows(), value.Cols(), value.Stride());
dbg.dumpMatrix(value);
#endif

// Apply post operations on query and key
TimeLine t3("QKPO");
// TODO: call into rotary embedding
// int qheads = this->endQHead - this->startQHead;
// int kheads = this->endKVHead - this->startKVHead;
// int qkShape[7] = {ctx->batchSize, ctx->inputSeqLen, qheads, headSize, kheads, ctx->maxSeqLength, pastSeqLen};
// if (positionIds != nullptr) {
// qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, positionIds);
// } else if (ctx->maxPosEmbed > 0) {
// // Use the default position ids
// std::vector<int> posIds(ctx->inputSeqLen);
// if (inputSeqLen == 1) {
// posIds[0] = pastSeqLen;
// } else {
// std::iota(posIds.begin(), posIds.end(), pastSeqLen);
// }
// qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, posIds.data());
// }
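// A possible continuous-batching variant (sketch only; getPastSeqLen() on
// SequenceMeta is an assumption here, not verified against this codebase;
// std::iota needs <numeric>):
// int rowOff = 0;
// for (int i = 0; i < ctx->batchSize; ++i) {
//     int inLen = seqs[i]->getInputSeqLen();
//     std::vector<int> posIds(inLen);
//     std::iota(posIds.begin(), posIds.end(), seqs[i]->getPastSeqLen());
//     // apply qkpo to rows [rowOff, rowOff + inLen) of query/key with posIds
//     rowOff += inLen;
// }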
t3.release();

#ifdef DEBUG
dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride());
dbg.dumpMatrix(query);
dbg.debugPrint("K[%d,%d](%d) after post op:\n", key.Rows(), key.Cols(), key.Stride());
dbg.dumpMatrix(key);
#endif

// Revise attnFactor before softmax (for some models, attnFactor may not be the default value)
// We initially introduced this code for ChatGLM, but eventually found it made no difference and was unnecessary.
// However, we keep it in the codebase in case it becomes useful for future models.
if (getScalingCoeff() != 0) { ctx->attFactor = getScalingCoeff(); }
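// (attFactor is conventionally 1/sqrt(headSize), the standard scaled-dot-product
// attention factor; a nonzero getScalingCoeff() overrides that default.)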

TimeLine t4("MHA");
if constexpr (!INPUT_AS_RESID) { // Swap inputBuffer and imBuffer
auto tmp = imBuffer.Data();
int rows = imBuffer.Rows(), cols = imBuffer.Cols(), stride = imBuffer.Stride();
imBuffer.Assign(inputBuffer.Data(), inputBuffer.Rows(), inputBuffer.Cols(), inputBuffer.Stride());
inputBuffer.Assign(tmp, rows, cols, stride);
}
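// (Rationale, as read from this code: when the raw input is not used as the
// residual, the swap leaves the layer-normed data in inputBuffer, which the
// output projection below adds back as the residual.)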

// For multi-node inference, this holds only this split's share of the heads, not the whole result buffer
xft::Matrix<ImT> attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols);

if (seqs[0]->getStep() == 0) { // First token generation
// TODO: add flashAttention
if constexpr (std::is_same_v<InT, bfloat16_t> && std::is_same_v<OutT, bfloat16_t>) {
selfAttentionBF16(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs);
} else {
fusedAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs);
}
} else {
fusedAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs);
}
t4.release();

#ifdef DEBUG
dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(),
attnSplit.Cols(), attnSplit.Stride());
dbg.dumpMatrix(attnSplit);
#endif

TimeLine t5("Output");
// Output projection in attention; the residual input is added only in the first split
if (ctx->splitIdx == 0) {
float gamma = getResidentialScale();

// denseWithScaledSum should be enough, but since its performance is not yet verified,
// we still use denseWithSum here
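// In effect: out = attnSplit * W_o (+ bias) + gamma * input, with the residual
// term applied only on split 0.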
if (gamma == 1) {
float *pbias = attnOutputBias.Data();
if (attnOutputBias.Size() == 0) { pbias = nullptr; }
ctx->mmHelper->compute_residential(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(),
1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(),
attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f,
outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride());
} else {
float *pbias = attnOutputBias.Data();
if (attnOutputBias.Size() == 0) { pbias = nullptr; }
ctx->mmHelper->compute_resext(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride());
}
} else {
if (attnOutputBias.Size() == 0) {
ctx->mmHelper->compute(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride());
} else {
ctx->mmHelper->compute_bias(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride(), attnOutputBias.Data());
}
}
t5.release();

#ifdef DEBUG
dbg.debugPrint(">>> attention output/projection[%d, %d] (%d):\n", outBuffer.Rows(), outBuffer.Cols(),
outBuffer.Stride());
@@ -416,6 +602,34 @@ class Attention {
[&](int b, int headIdx, int seqIdx) { return presentValue.getSequence(seqIdx, b, headIdx); });
}

template <typename KVCacheT>
void selfAttentionBF16(DecoderContext *ctx, xft::Matrix<bfloat16_t> &query, xft::Matrix<bfloat16_t> &key,
xft::Matrix<bfloat16_t> &value, xft::Matrix<bfloat16_t> &result,
std::vector<KVCacheTensor<KVCacheT> *> &keyCaches, std::vector<KVCacheTensor<KVCacheT> *> &valueCaches,
std::vector<xft::SequenceMeta *> &seqs) {
int responsibleQHeads = this->endQHead - this->startQHead;
int responsibleKVHeads = this->endKVHead - this->startKVHead;

int tokenSizes[ctx->batchSize];
for (int i = 0; i < ctx->batchSize; ++i) {
tokenSizes[i] = seqs[i]->getInputSeqLen();
}

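// For each sequence b and responsible head h, xft::selfAttention computes
// softmax(Q * K^T * attFactor) * V over tokenSizes[b] tokens; the two lambdas
// map (batch, head, position) to this sequence's KV cache slots (our reading
// of the callback contract, mirroring the presentKey/presentValue usage above).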
xft::selfAttention(
result.Data(), query.Data(), key.Data(), value.Data(), responsibleQHeads, responsibleKVHeads,
ctx->attHeadSize, result.Stride(), query.Stride(), key.Stride(), ctx->batchSize, tokenSizes,
ctx->attFactor, ctx->numThreads,
[&](int b, int headIdx, int seqIdx) { return keyCaches[b]->getSequence(seqIdx, 0, headIdx); },
[&](int b, int headIdx, int seqIdx) { return valueCaches[b]->getSequence(seqIdx, 0, headIdx); });
}

template <typename T, typename KVCacheT>
void fusedAttention(DecoderContext *ctx, xft::Matrix<T> &query, xft::Matrix<T> &key, xft::Matrix<T> &value,
xft::Matrix<T> &result, std::vector<KVCacheTensor<KVCacheT> *> &keyCaches,
std::vector<KVCacheTensor<KVCacheT> *> &valueCaches, std::vector<xft::SequenceMeta *> &seqs) {
// TODO: implement fusedAttention
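// A minimal sketch of what this path needs to cover for a decoding step
// (assumed shape of the work, not a tested implementation):
//   1. append this step's new K/V rows into keyCaches[b]/valueCaches[b];
//   2. scores = q * K_cached^T * ctx->attFactor over all cached tokens;
//   3. result rows = softmax(scores) * V_cached;
// with GQA mapping several Q heads onto one shared KV head.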
}

int getMBlockSize(int inputSeqLen, int headSize, int minVal = 6) {
// Special case
if (inputSeqLen == 1) { return 1; }
