[models] Add attnMeta for continuous batching (#364)
abenmao committed Apr 30, 2024
1 parent 9e7bdca commit cfab63a
Showing 3 changed files with 73 additions and 1 deletion.
70 changes: 70 additions & 0 deletions src/common/attn_metadata.h
@@ -0,0 +1,70 @@
// Copyright (c) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#pragma once

#include <cassert>
#include <vector>

class AttnMetaData {

public:
    AttnMetaData() : isPrompt(false), isCausal(true), batchSize(0), attnMask(nullptr) {}

    AttnMetaData(int batchSize, int *inputTokenSizes, int *pastTokenSizes, bool isPrompt, bool isCausal,
            float *attnMask = nullptr)
        : isPrompt(isPrompt), isCausal(isCausal), batchSize(batchSize), attnMask(attnMask) {
        // causal=true: no mask needed; causal=false: mask required
        assert(!isCausal || attnMask == nullptr);
        assert(isCausal || attnMask != nullptr);

        // fill inputSeqLens, pastSeqLens, seqStartLoc
        inputSeqLens.resize(batchSize);
        pastSeqLens.resize(batchSize);
        seqStartLoc.resize(batchSize + 1);

        seqStartLoc[0] = 0;
        for (int i = 0; i < batchSize; i++) {
            inputSeqLens[i] = inputTokenSizes[i];
            pastSeqLens[i] = pastTokenSizes[i];
            seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
        }
    }

    AttnMetaData(const std::vector<int> &inputTokenSizes, const std::vector<int> &pastTokenSizes, bool isPrompt,
            bool isCausal, float *attnMask = nullptr)
        : isPrompt(isPrompt), isCausal(isCausal), batchSize(static_cast<int>(inputTokenSizes.size())),
          inputSeqLens(inputTokenSizes), pastSeqLens(pastTokenSizes), attnMask(attnMask) {
        // causal=true: no mask needed; causal=false: mask required
        assert(!isCausal || attnMask == nullptr);
        assert(isCausal || attnMask != nullptr);

        // fill seqStartLoc as the prefix sum of inputSeqLens
        seqStartLoc.resize(batchSize + 1);

        seqStartLoc[0] = 0;
        for (int i = 0; i < batchSize; i++) {
            seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
        }
    }

private:
    bool isPrompt;
    bool isCausal;

    int batchSize;
    std::vector<int> inputSeqLens;
    std::vector<int> pastSeqLens;
    std::vector<int> seqStartLoc;

    float *attnMask;
};
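
For illustration, here is a minimal usage sketch (not part of the commit) of the vector-based constructor. The derived seqStartLoc is the prefix sum of inputSeqLens, giving each sequence's starting row in the packed token batch:

#include "attn_metadata.h"

#include <vector>

int main() {
    // Two prompts of 4 and 7 tokens; nothing in the KV cache yet.
    std::vector<int> inputTokenSizes = {4, 7};
    std::vector<int> pastTokenSizes = {0, 0};

    // Causal prefill needs no explicit mask, so attnMask stays nullptr.
    AttnMetaData meta(inputTokenSizes, pastTokenSizes, /*isPrompt=*/true, /*isCausal=*/true);

    // The constructor derives seqStartLoc = {0, 4, 11}: sequence i occupies
    // rows [seqStartLoc[i], seqStartLoc[i+1]) of the packed batch.
    return 0;
}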
2 changes: 1 addition & 1 deletion src/layers/attention.h
@@ -507,7 +507,7 @@ class Attention {
    void softmax(DecoderContext *ctx, T1 *score, const T2 *mask, int rows, int cols, int lds, int startSeq) {
        const int keyLen = cols;
        for (int seq = 0; seq < rows; ++seq) {
-           DecoderUtil::computeSoftmax(ctx, score + seq * lds, mask + (seq + startSeq) * keyLen, keyLen);
+           DecoderUtil::computeSoftmax(score + seq * lds, mask + (seq + startSeq) * keyLen, keyLen, ctx->attFactor);
        }
    }

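The one-line change above removes the DecoderContext argument from DecoderUtil::computeSoftmax and passes ctx->attFactor, the attention scale, explicitly instead. The implementation is not shown in this diff; as a reference for the assumed semantics, a plain scalar sketch of a scaled, additively masked softmax over one row of scores:

#include <algorithm>
#include <cmath>
#include <limits>

// Assumed semantics (sketch, not the repository's code): softmax over one
// row, with an explicit scale and an additive mask (0 where visible, a large
// negative value where masked).
inline void computeSoftmaxRef(float *score, const float *mask, int len, float scale) {
    float maxVal = -std::numeric_limits<float>::infinity();
    for (int i = 0; i < len; ++i) {
        score[i] = score[i] * scale + mask[i];
        maxVal = std::max(maxVal, score[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < len; ++i) {
        score[i] = std::exp(score[i] - maxVal); // subtract max for numerical stability
        sum += score[i];
    }
    for (int i = 0; i < len; ++i) {
        score[i] /= sum;
    }
}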
2 changes: 2 additions & 0 deletions src/models/common_decoder.h
@@ -333,6 +333,7 @@ class CommonDecoder : public AbstractDecoder {

        // Prepare attention mask
        this->prepareAttnMask(ids, step + this->prefixSharing);
+       // prepareAttnMeta

        // Token position ids, note: different models may have different impl.
        int *positionIds = this->getPositionIds(ids, batchSize, inputSeqLen, step + this->prefixSharing);
@@ -392,6 +393,7 @@ class CommonDecoder : public AbstractDecoder {

        // Pls be noted: in attention, 'outBuf' is used as intermediate buffer, 'tmpBuf' is used as output
        AttnOutT *attnOut = (AttnOutT *)(this->getContext()->tmpBuf.Data());
+       // attnMeta (inputSeqLens, pastSeqLens, seqStartLoc, is_prompt(useSelfAttn), causal, attnMask)
        this->decoders[i]->forwardAttention(getContext(), embBuf, outBuf, attnOut, attnMask,
                presentKey, // presentKey,
                presentValue, // presentValue,
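The two comments added to common_decoder.h mark where the new metadata would be built and consumed. A hypothetical sketch of that wiring — variable names beyond those visible in this diff (pastSeqLen in particular) are illustrative, not code from the repository:

// Hypothetical prepareAttnMeta step: gather per-sequence lengths from the
// running batch. Under continuous batching these can differ per sequence;
// they are uniform here only for brevity.
std::vector<int> inputTokenSizes(batchSize, inputSeqLen); // new tokens this step
std::vector<int> pastTokenSizes(batchSize, pastSeqLen);   // tokens already in the KV cache

// Prefill (prompt) steps are causal and need no explicit mask; a non-causal
// pass would receive the attnMask pointer instead of nullptr.
AttnMetaData attnMeta(inputTokenSizes, pastTokenSizes,
        /*isPrompt=*/step == 0, /*isCausal=*/true, /*attnMask=*/nullptr);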
