From cfab63a0332da4559113ffd6988afe56d4eedf59 Mon Sep 17 00:00:00 2001
From: "Meng,Chen"
Date: Tue, 30 Apr 2024 14:35:29 +0800
Subject: [PATCH] [models] Add attnMeta for continuous batching (#364)

---
 src/common/attn_metadata.h  | 70 +++++++++++++++++++++++++++++++++++++
 src/layers/attention.h      |  2 +-
 src/models/common_decoder.h |  2 ++
 3 files changed, 73 insertions(+), 1 deletion(-)
 create mode 100644 src/common/attn_metadata.h

diff --git a/src/common/attn_metadata.h b/src/common/attn_metadata.h
new file mode 100644
index 00000000..bc2e6031
--- /dev/null
+++ b/src/common/attn_metadata.h
@@ -0,0 +1,70 @@
+// Copyright (c) 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+#pragma once
+
+#include <cassert>
+#include <vector>
+
+class AttnMetaData {
+
+public:
+    AttnMetaData() : isPrompt(false), isCausal(true), batchSize(0), attnMask(nullptr) {}
+
+    AttnMetaData(int batchSize, int *inputTokenSizes, int *pastTokenSizes, bool isPrompt, bool isCausal,
+            float *attnMask = nullptr)
+        : isPrompt(isPrompt), isCausal(isCausal), batchSize(batchSize), attnMask(attnMask) {
+        // causal=True, no mask needed
+        assert(!isCausal || attnMask == nullptr);
+        // causal=False, mask required
+        assert(isCausal || attnMask != nullptr);
+
+        // fill inputSeqLens, pastSeqLens, seqStartLoc
+        inputSeqLens.resize(batchSize);
+        pastSeqLens.resize(batchSize);
+        seqStartLoc.resize(batchSize + 1);
+
+        seqStartLoc[0] = 0;
+        for (int i = 0; i < batchSize; i++) {
+            inputSeqLens[i] = inputTokenSizes[i];
+            pastSeqLens[i] = pastTokenSizes[i];
+            seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
+        }
+    }
+
+    AttnMetaData(const std::vector<int> &inputTokenSizes, const std::vector<int> &pastTokenSizes, bool isPrompt,
+            bool isCausal, float *attnMask = nullptr)
+        : isPrompt(isPrompt), isCausal(isCausal), batchSize(static_cast<int>(inputTokenSizes.size())),
+          inputSeqLens(inputTokenSizes), pastSeqLens(pastTokenSizes), attnMask(attnMask) {
+        // causal=True, no mask needed
+        assert(!isCausal || attnMask == nullptr);
+        // causal=False, mask required
+        assert(isCausal || attnMask != nullptr);
+
+        // fill seqStartLoc
+        seqStartLoc.resize(batchSize + 1);
+
+        seqStartLoc[0] = 0;
+        for (int i = 0; i < batchSize; i++) {
+            seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
+        }
+    }
+
+private:
+    bool isPrompt;
+    bool isCausal;
+
+    int batchSize;
+    std::vector<int> inputSeqLens;
+    std::vector<int> pastSeqLens;
+    std::vector<int> seqStartLoc;
+
+    float *attnMask;
+};
diff --git a/src/layers/attention.h b/src/layers/attention.h
index dc730155..485e5975 100644
--- a/src/layers/attention.h
+++ b/src/layers/attention.h
@@ -507,7 +507,7 @@ class Attention {
     void softmax(DecoderContext *ctx, T1 *score, const T2 *mask, int rows, int cols, int lds, int startSeq) {
         const int keyLen = cols;
         for (int seq = 0; seq < rows; ++seq) {
-            DecoderUtil::computeSoftmax(ctx, score + seq * lds, mask + (seq + startSeq) * keyLen, keyLen);
+            DecoderUtil::computeSoftmax(score + seq * lds, mask + (seq + startSeq) * keyLen, keyLen, ctx->attFactor);
         }
     }

diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h
index 0d5bd85a..e040a637 100644
--- a/src/models/common_decoder.h
+++ b/src/models/common_decoder.h
@@ -333,6 +333,7 @@ class CommonDecoder : public AbstractDecoder {
 
         // Prepare attention mask
         this->prepareAttnMask(ids, step + this->prefixSharing);
+        // prepareAttnMeta
 
         // Token position ids, note: different models may have different impl.
         int *positionIds = this->getPositionIds(ids, batchSize, inputSeqLen, step + this->prefixSharing);
@@ -392,6 +393,7 @@ class CommonDecoder : public AbstractDecoder {
             // Pls be noted: in attention, 'outBuf' is used as imtermediate buffer, 'tmpBuf' is used as output
             AttnOutT *attnOut = (AttnOutT *)(this->getContext()->tmpBuf.Data());
 
+            // attnMeta (inputSeqLens, pastSeqLens, seqStartLoc, is_prompt(useSelfAttn), causal, attnMask)
             this->decoders[i]->forwardAttention(getContext(), embBuf, outBuf, attnOut, attnMask,
                     presentKey, // presentKey,
                     presentValue, // presentValue,
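
Note: the following is a minimal usage sketch, not part of the patch. It shows how AttnMetaData could be filled per scheduling step for continuous batching; the include path, the example sequence lengths, and the prefill/decode split are illustrative assumptions rather than the repository's actual wiring.

// Sketch: building AttnMetaData for a prefill step and a decode step.
// Assumes the new header above is reachable as "attn_metadata.h".
#include <vector>
#include "attn_metadata.h"

int main() {
    // Prefill: three requests with prompt lengths 17, 5 and 42; nothing cached yet.
    // seqStartLoc becomes {0, 17, 22, 64}, i.e. prefix sums of inputSeqLens.
    std::vector<int> inputTokens = {17, 5, 42};
    std::vector<int> pastTokens = {0, 0, 0};
    AttnMetaData prefillMeta(inputTokens, pastTokens, /*isPrompt=*/true, /*isCausal=*/true);

    // Decode: each sequence contributes one new token, past lengths hold the
    // already-cached tokens; seqStartLoc becomes {0, 1, 2, 3}.
    std::vector<int> decodeInput = {1, 1, 1};
    std::vector<int> decodePast = {17, 5, 42};
    AttnMetaData decodeMeta(decodeInput, decodePast, /*isPrompt=*/false, /*isCausal=*/true);

    // Causal attention passes no mask (attnMask stays nullptr); a non-causal
    // variant would instead supply a float mask and set isCausal = false.
    (void)prefillMeta;
    (void)decodeMeta;
    return 0;
}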