From 1cc9c1d9d8c97e09b6de22350f42ec9c07aea497 Mon Sep 17 00:00:00 2001
From: Changqing Li
Date: Wed, 15 May 2024 08:45:07 +0800
Subject: [PATCH 01/35] [Common] Add sequenceMeta, sequenceGroup and sequencePool. (#343)

---
 src/common/sequence.h           | 330 ++++++++++++++++++++++++++++++++
 src/common/transformer_ctx.h    |   6 +-
 src/models/common_decoder.h     |  48 ++++-
 src/models/models.cpp           |   1 +
 src/searchers/greedy_search.cpp |  30 ++-
 src/searchers/greedy_search.h   |   1 +
 src/utils/environment.h         |  21 ++
 src/utils/thread_util.h         |  79 +++++++-
 8 files changed, 502 insertions(+), 14 deletions(-)
 create mode 100644 src/common/sequence.h

diff --git a/src/common/sequence.h b/src/common/sequence.h
new file mode 100644
index 00000000..d4d82db1
--- /dev/null
+++ b/src/common/sequence.h
@@ -0,0 +1,330 @@
+// Copyright (c) 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+#pragma once
+
+#include <cstdint>
+#include <queue>
+#include <unordered_map>
+#include <vector>
+
+/*
+    SequencePool
+    ┌──────┬──────┬──────┐
+    │ │ │ ◄───┼──┬─ SequenceGroupMeta
+    ├──────┼──────┼──────┤ │
+    BatchInputs │ │ │ ◄───┼──┘
+    │ └▲─┬─▲─┴──────┴──────┘
+    │ │ │ └───────────────────────────────────┐
+    ▼ ┌──┬──┬──┬──┐ │ │ ┌──┬──┬──┬──┬──┬──┬──┬──┬──┐ │
+    Input ─►│ │ │ │ ├──┘ └─────►│ │ │ │ │ │ │ │ │ ├─┐ │
+    └──┴──┴──┴──┘ └──┴──┴──┴──┴──┴──┴──┴──┴──┘ │ │
+    InputQueue TaskWaitingQueue0 │ │
+    ┌───────────────────────────────┘ │
+    │ ┌──┬──┬──┬──┬──┬──┬──┬──┬──┐ │
+    └─►│ │ │ │ │ │ │ │ │ ├───┘
+    └──┴──┴──┴──┴──┴──┴──┴──┴──┘
+    TaskWaitingQueue1
+*/
+
+namespace xft {
+
+// A SequenceMeta represents one sequence from the batch inputs, including the tokens generated for it.
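Before the class definitions, a minimal usage sketch of the lifecycle these types are meant to support (illustrative only, not part of the patch; the function name sequenceLifecycle and the token value 42 are invented for the example):

    // Sketch: one sequence's lifecycle through the new metadata classes.
    void sequenceLifecycle(std::vector<int32_t> &promptIds) {
        auto &pool = xft::SequencePool::getInstance();
        int32_t id = pool.createSequenceID();
        // Wrap the prompt into a single-sequence group and register it.
        xft::SequenceGroupMeta *group = pool.newMeta(id, (int32_t)promptIds.size(), promptIds);
        pool.add(id, group);

        xft::SequenceMeta *seq = group->get(0);
        seq->stepForward();     // step 0 (prefill): pastSeqLen becomes the prompt length
        seq->stepForward(42);   // later steps: record one generated token per step
        int32_t latest = seq->getLatestToken(); // 42
    }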
+class SequenceMeta {
+public:
+    SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen, std::vector<int32_t> &_inputTokens)
+        : sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), pastSeqLen(0), step(0) {
+        inputTokens.reserve(_inputSeqLen);
+        inputTokens.assign(_inputTokens.begin(), _inputTokens.end());
+        nextTokens.reserve(_inputSeqLen);
+        setPastSeqLen(getPastSeqLen());
+    }
+
+    SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen)
+        : sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), inputTokens(_inputSeqLen, 0), pastSeqLen(0), step(0) {
+        nextTokens.reserve(_inputSeqLen);
+    }
+
+    ~SequenceMeta() {}
+
+    int32_t getSequenceID() const { return sequenceID; }
+
+    // For the first step (prefill)
+    void stepForward() {
+        if (getStep() == 0) {
+            setPastSeqLen(inputTokens.size());
+            setStep(getStep() + 1);
+        }
+    }
+
+    // For each subsequent step: record the next generated token
+    void stepForward(int32_t token) {
+        addNextToken(token);
+        setPastSeqLen(getPastSeqLen() + 1);
+        setStep(getStep() + 1);
+    }
+
+    // Input length of the sequence
+    int32_t getInputSeqLen() const { return inputSeqLen; }
+
+    const int32_t *getInputTokens() const { return inputTokens.data(); }
+
+    int32_t getPastSeqLen() const { return pastSeqLen; }
+
+    void setPastSeqLen(int32_t _pastSeqLen) { pastSeqLen = _pastSeqLen; }
+
+    // For next tokens
+    void addNextToken(int32_t token) {
+        nextTokens.clear();
+        nextTokens.push_back(token);
+        inputTokens.push_back(token);
+    }
+
+    int32_t getLatestToken() const { return nextTokens.back(); }
+
+    const int32_t *getTotalTokens() const { return getInputTokens(); }
+
+    int32_t getStep() const { return step; }
+
+    void setStep(int32_t _step) { step = _step; }
+
+private:
+    int32_t sequenceID;
+    int32_t inputSeqLen;
+    int32_t pastSeqLen;
+    std::vector<int32_t> inputTokens; // input tokens + next tokens
+    std::vector<int32_t> nextTokens; // next tokens
+    int32_t step;
+
+#ifdef PIPELINE_PARALLEL
+public:
+    template <typename T>
+    void allocBuffer(int32_t hiddenSize, void *_hiddenStates) {
+        hiddenStates = xft::alloc(sizeof(T) * getInputSeqLen() * hiddenSize);
+        memcpy(hiddenStates, _hiddenStates, sizeof(T) * getInputSeqLen() * hiddenSize);
+    }
+
+private:
+    int32_t hiddenSize;
+    void *hiddenStates;
+#endif
+};
+
+// For the beam searcher
+class SequenceGroupMeta {
+public:
+    SequenceGroupMeta(std::vector<SequenceMeta> &seq) {
+        size_per_group = seq.size();
+        sequences.reserve(size_per_group);
+        sequences.assign(seq.begin(), seq.end());
+    }
+
+    int32_t getGroupSize() { return size_per_group; }
+
+    SequenceMeta *get() { return sequences.data(); }
+
+    SequenceMeta *get(int index) { return &sequences[index]; }
+
+    SequenceMeta &operator[](int index) { return sequences[index]; }
+
+private:
+    int32_t size_per_group;
+    std::vector<SequenceMeta> sequences;
+};
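A SequenceGroupMeta holds one SequenceMeta per beam. A hedged sketch of building a two-beam group directly (the IDs and lengths are arbitrary; the newGroupMeta helpers in the SequencePool below build the same thing from parallel arrays):

    // Sketch: a group with two beams of the same prompt length (illustrative).
    std::vector<xft::SequenceMeta> beams;
    beams.emplace_back(xft::SequenceMeta(/*sequenceID=*/100, /*inputSeqLen=*/16));
    beams.emplace_back(xft::SequenceMeta(/*sequenceID=*/101, /*inputSeqLen=*/16));
    xft::SequenceGroupMeta group(beams); // group.getGroupSize() == 2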
+// SequencePool
+// ┌──────┬──────┬──────┐
+// │ │ │ ◄───┼──┬─ SequenceGroupMeta
+// ├──────┼──────┼──────┤ │
+// │ │ │ ◄───┼──┘
+// └──────┴──────┴──────┘
+class SequencePool {
+public:
+    static SequencePool &getInstance() {
+        static SequencePool instance;
+        return instance;
+    }
+
+    int32_t createSequenceID() {
+        int32_t id = globalSequenceID++;
+        if (id >= 10 * 1024) {
+            globalSequenceID = 0;
+            id = globalSequenceID++;
+        }
+        return id;
+    }
+
+    SequenceGroupMeta *newMeta(int32_t sequenceID, int32_t inputSeqLen, std::vector<int32_t> &inputTokens) {
+        std::vector<SequenceMeta> sequence;
+        sequence.emplace_back(SequenceMeta(sequenceID, inputSeqLen, inputTokens));
+
+        auto *group = new SequenceGroupMeta(sequence);
+        return group;
+    }
+
+    SequenceGroupMeta *newMeta(int32_t sequenceID, int32_t inputSeqLen) {
+        std::vector<SequenceMeta> sequence;
+        sequence.emplace_back(SequenceMeta(sequenceID, inputSeqLen));
+
+        auto *group = new SequenceGroupMeta(sequence);
+        return group;
+    }
+
+    SequenceGroupMeta *newGroupMeta(std::vector<int32_t> &sequenceIDs, std::vector<int32_t> &inputSeqLens,
+            std::vector<std::vector<int32_t>> &inputTokens) {
+        assert(sequenceIDs.size() == inputSeqLens.size());
+        assert(sequenceIDs.size() == inputTokens.size());
+
+        std::vector<SequenceMeta> sequences;
+        for (int i = 0; i < sequenceIDs.size(); ++i) {
+            sequences.emplace_back(SequenceMeta(sequenceIDs[i], inputSeqLens[i], inputTokens[i]));
+        }
+
+        auto *group = new SequenceGroupMeta(sequences);
+        return group;
+    }
+
+    SequenceGroupMeta *newGroupMeta(std::vector<int32_t> &sequenceIDs, std::vector<int32_t> &inputSeqLens) {
+        assert(sequenceIDs.size() == inputSeqLens.size());
+
+        std::vector<SequenceMeta> sequences;
+        for (int i = 0; i < sequenceIDs.size(); ++i) {
+            sequences.emplace_back(SequenceMeta(sequenceIDs[i], inputSeqLens[i]));
+        }
+
+        auto *group = new SequenceGroupMeta(sequences);
+        return group;
+    }
+
+    // A group created by beam search (e.g., num_beams = 4) is keyed by its first sequenceID
+    bool add(int32_t sequenceID, SequenceGroupMeta *sequence, bool force = false) {
+        bool isSuccess = false;
+        if (force) {
+            auto it = hub.find(sequenceID);
+            if (it != hub.end()) { remove(it->first, true); }
+
+            hub[sequenceID] = sequence;
+            isSuccess = true;
+        } else {
+            bool exist = has(sequenceID);
+            if (!exist) {
+                hub[sequenceID] = sequence;
+                isSuccess = true;
+            }
+        }
+
+        return isSuccess;
+    }
+
+    bool has(int32_t sequenceID) const { return hub.find(sequenceID) != hub.end(); }
+
+    SequenceGroupMeta *get(int32_t sequenceID) const {
+        auto it = hub.find(sequenceID);
+        if (it != hub.end()) {
+            return it->second;
+        } else {
+            return nullptr;
+        }
+    }
+
+    bool remove(int32_t sequenceID, bool deep = false) {
+        bool isSuccess = false;
+        if (has(sequenceID)) {
+            if (deep == true) {
+                auto it = hub.find(sequenceID);
+                if (it != hub.end()) { delete it->second; }
+            }
+            isSuccess = hub.erase(sequenceID);
+        }
+
+        return isSuccess;
+    }
+
+    bool replace(int32_t sequenceID, SequenceGroupMeta *sequences) {
+        bool isSuccess = false;
+        auto it = hub.find(sequenceID);
+        if (it != hub.end()) {
+            remove(it->first, true);
+            hub[sequenceID] = sequences;
+            isSuccess = true;
+        }
+
+        return isSuccess;
+    }
+
+private:
+    SequencePool() {}
+
+    int32_t globalSequenceID = 0;
+    std::unordered_map<int32_t, SequenceGroupMeta *> hub;
+};
+
+// Manage input sequenceMeta
+class InputQueue {
+public:
+    static InputQueue &getInstance() {
+        static InputQueue instance;
+        return instance;
+    }
+
+    bool empty() { return queue.empty(); }
+
+    SequenceGroupMeta *pop() {
+        auto seq = queue.front();
+        queue.pop();
+        return seq;
+    }
+
+    void push(SequenceGroupMeta *seq) { queue.push(seq); }
+
+private:
+    InputQueue() {}
+
+    std::queue<SequenceGroupMeta *> queue;
+};
+
+// Manage executable sequenceMeta
+class TaskWaitingQueue {
+public:
+    static TaskWaitingQueue &getInstance() {
+        static TaskWaitingQueue instance;
+        return instance;
+    }
+
+    bool empty() { return queue.empty(); }
+
+    int32_t size() { return queue.size(); }
+
+    bool isFull() {
+        bool full = false;
+        if (this->size() >= Env::getInstance().getMaxRequestNum()) { full = true; }
+        return full;
+    }
+
+    SequenceGroupMeta *pop() {
+        auto seq = queue.front();
+        queue.pop();
+        return seq;
+    }
+
+    void push(SequenceGroupMeta *seq) { queue.push(seq); }
+
+private:
+    TaskWaitingQueue() {}
+
+    std::queue<SequenceGroupMeta *> queue;
+};
+
+} // namespace xft
\ No newline at end of file
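That completes sequence.h. Tying it back to the diagram at the top of the file, a sketch of one scheduling step under the intended flow (the function name scheduleOne is invented; the policy mirrors the common_decoder.h changes later in this patch):

    // Sketch: move one waiting request toward execution (illustrative).
    void scheduleOne() {
        using namespace xft;
        if (!InputQueue::getInstance().empty() && !TaskWaitingQueue::getInstance().isFull()) {
            SequenceGroupMeta *group = InputQueue::getInstance().pop();
            // Register the group so later pipeline stages can look it up by ID.
            SequencePool::getInstance().add(group->get(0)->getSequenceID(), group);
            TaskWaitingQueue::getInstance().push(group);
        }
    }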
diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h
index 3685baae..0a48c88b 100644
--- a/src/common/transformer_ctx.h
+++ b/src/common/transformer_ctx.h
@@ -65,6 +65,10 @@ struct DecoderContext {
     // For custom usage
     int reserved1;
 
+#ifdef PIPELINE_PARALLEL
+    int sequenceID;
+#endif
+
     // Model structure configuration
     int vocabSize;
     int embeddingSize;
@@ -319,4 +323,4 @@ struct DecoderContext {
     }
 
     ~DecoderContext() { free(this->rawBuffer); }
-};
+};
\ No newline at end of file
diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h
index ab289027..a752c7eb 100644
--- a/src/models/common_decoder.h
+++ b/src/models/common_decoder.h
@@ -35,6 +35,7 @@
 #include "transformer_ctx.h"
 #include "transpose_util.h"
 #include "weight_util.h"
+#include "sequence.h"
 
 using namespace xft;
 
@@ -278,7 +279,7 @@ class CommonDecoder : public AbstractDecoder {
         int userSideBS = dims[0];
         int beamSize = dims[1];
-        int batchSize = (step == 0 ? userSideBS : userSideBS * beamSize); // as samples are duplicated at step 0
+        int batchSize = (step == 0 ? userSideBS : userSideBS * beamSize); // as sequences are duplicated at step 0
         int seqLen = dims[2];
         int pastSeqLen = step == 0 ? 0 : this->accSeqLen;
         int inputSeqLen = seqLen;
@@ -286,6 +287,7 @@ class CommonDecoder : public AbstractDecoder {
         // Prepare context
         DecoderContext *ctx = this->getContext();
         ctx->resize(batchSize, seqLen, pastSeqLen);
+        int hiddenSize = ctx->hiddenSize;
 
         if (step == 0) {
             // Reset initial and accumulated sequence length at the first step
@@ -314,7 +316,7 @@ class CommonDecoder : public AbstractDecoder {
         }
 
         AttnInT *embBuf = (AttnInT *)actBuffers->Data();
-        MlpOutT *outBuf = (MlpOutT *)(embBuf + batchSize * inputSeqLen * ctx->hiddenSize);
+        MlpOutT *outBuf = (MlpOutT *)(embBuf + batchSize * inputSeqLen * hiddenSize);
 
         // Embedding
         this->embeddingForward(ids, embBuf, batchSize, inputSeqLen);
         this->accSeqLen += seqLen;
 
 #ifdef DEBUG
@@ -325,8 +327,8 @@ class CommonDecoder : public AbstractDecoder {
         dbg.debugPrint("ids:\n");
         dbg.dumpMatrix(ids, batchSize, inputSeqLen, inputSeqLen);
         dbg.debugPrint(
-                "embBuf(rows: %d, cols: %d, stride: %d):\n", batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize);
-        dbg.dumpMatrix(embBuf, batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize);
+                "embBuf(rows: %d, cols: %d, stride: %d):\n", batchSize * inputSeqLen, hiddenSize, hiddenSize);
+        dbg.dumpMatrix(embBuf, batchSize * inputSeqLen, hiddenSize, hiddenSize);
 #endif
 
         // Prepare attention mask
@@ -341,15 +343,43 @@ class CommonDecoder : public AbstractDecoder {
         if (ctx->ppSize > 1 && ctx->ppRank > 0) {
             int curr_world_rank = ctx->ppRank * ctx->tpSize + ctx->tpRank;
             int prev_world_rank = (ctx->ppRank - 1) * ctx->tpSize + ctx->tpRank;
-            int count = batchSize * inputSeqLen * ctx->hiddenSize;
+            int count = batchSize * inputSeqLen * hiddenSize;
+            int32_t sequenceID;
+            MPI_Recv(&sequenceID, 1, MPI_INT32_T, prev_world_rank, curr_world_rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+            TimeLine t("Decoder.Seq" + std::to_string(sequenceID) + ".MPI_Recv");
             MPI_Recv(embBuf, count, MPI_FLOAT, prev_world_rank, curr_world_rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
             // TODO: Error: different scope when dynamic loading so file
             // this->messenger.worldRecvFP32(embBuf, count, prev_world_rank, curr_world_rank);
+            if (!SequencePool::getInstance().has(sequenceID)) {
+                auto *seqs = SequencePool::getInstance().newMeta(sequenceID, seqLen);
+                seqs->get(0)->setPastSeqLen(pastSeqLen);
+                seqs->get(0)->allocBuffer<AttnInT>(hiddenSize, embBuf);
+                SequencePool::getInstance().add(seqs->get(0)->getSequenceID(), seqs);
+            }
+            TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(sequenceID));
+        }
+
+        if (!InputQueue::getInstance().empty()) {
+            if
(!TaskWaitingQueue::getInstance().isFull()) { + auto *seqs = InputQueue::getInstance().pop(); + seqs->get(0)->setPastSeqLen(pastSeqLen); + seqs->get(0)->allocBuffer(hiddenSize, embBuf); + SequencePool::getInstance().add(seqs->get(0)->getSequenceID(), seqs); + TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(seqs->get(0)->getSequenceID())); + } } + + while(TaskWaitingQueue::getInstance().empty()); + + SequenceGroupMeta *runningTask = nullptr; + int32_t sequenceID = -1; + if (!TaskWaitingQueue::getInstance().empty()) { + runningTask = TaskWaitingQueue::getInstance().pop(); + sequenceID = runningTask->get(0)->getSequenceID(); + TimeLine t("Decoder.Seq" + std::to_string(sequenceID) + ".Step"); #endif // Decoder: forward - int hiddenSize = ctx->hiddenSize; int layers_per_pp_stage = this->decoders.size(); for (int i = 0; i < layers_per_pp_stage; ++i) { int workers = this->messenger.getSize(); @@ -402,10 +432,14 @@ class CommonDecoder : public AbstractDecoder { } #ifdef PIPELINE_PARALLEL + } + // If current pipeline stage isn't the end of stage, should send data to next stage and return nullptr if (ctx->ppSize > 1 && ctx->ppRank < ctx->ppSize - 1) { + TimeLine t("Decoder.Seq" + std::to_string(sequenceID) + ".MPI_Send"); int next_world_rank = (ctx->ppRank + 1) * ctx->tpSize + ctx->tpRank; - int count = batchSize * inputSeqLen * ctx->hiddenSize; + int count = batchSize * inputSeqLen * hiddenSize; + MPI_Send(&sequenceID, 1, MPI_INT32_T, next_world_rank, next_world_rank, MPI_COMM_WORLD); MPI_Send(embBuf, count, MPI_FLOAT, next_world_rank, next_world_rank, MPI_COMM_WORLD); // TODO: Error: different scope when dynamic loading so file // this->messenger.worldSendFP32(embBuf, count, next_world_rank, next_world_rank); diff --git a/src/models/models.cpp b/src/models/models.cpp index 81b1b685..dee1ab4e 100644 --- a/src/models/models.cpp +++ b/src/models/models.cpp @@ -33,6 +33,7 @@ #include "searcher.h" #include "timeline.h" #include "yarn_llama.h" +#include "sequence.h" namespace xft { enum class GenerationMode { GREEDY_SEARCH, BEAM_SEARCH, SAMPLE }; diff --git a/src/searchers/greedy_search.cpp b/src/searchers/greedy_search.cpp index 0e55648e..19add295 100644 --- a/src/searchers/greedy_search.cpp +++ b/src/searchers/greedy_search.cpp @@ -14,7 +14,11 @@ // ============================================================================ #include "greedy_search.h" #include "messenger.h" +#include "sequence.h" #include "search_utils.h" +#include "thread_util.h" + +using namespace xft; GreedySearch::GreedySearch(AbstractDecoder &dec, const SearcherConfig &config) : decoder(dec), maxLen(config.maxLen), step(0), repetitionPenalty(config.repetitionPenalty) { @@ -36,18 +40,34 @@ std::vector GreedySearch::syncToken(std::tuple &result) if (std::get<0>(result) == nullptr) { // The first embedding pipeline parallel stage this->nextTokens = std::vector(batchSize, 0); - if (ctx->ppSize > 1 && ctx->ppRank == 0) { + if (ctx->ppSize > 1 && ctx->ppRank == 0 && enabledBackgroundSync == false) { + enabledBackgroundSync = true; int predictor_world_rank = (ctx->ppSize - 1) * ctx->tpSize + ctx->tpRank; - MPI_Recv(this->nextTokens.data(), batchSize, MPI_INT32_T, predictor_world_rank, predictor_world_rank, - MPI_COMM_WORLD, MPI_STATUS_IGNORE); - // TODO: Error: different scope when dynamic loading so file - // messenger.worldRecvINT32(this->nextTokens.data(), batchSize, predictor_world_rank, predictor_world_rank); + ThreadPool::getInstance().addTask([predictor_world_rank, this] { + while (true) { + int32_t 
sequenceID;
+                    MPI_Recv(&sequenceID, 1, MPI_INT32_T, predictor_world_rank, predictor_world_rank, MPI_COMM_WORLD,
+                            MPI_STATUS_IGNORE);
+                    TimeLine t("GreedySearch.Seq" + std::to_string(sequenceID) + ".MPI_Recv");
+                    MPI_Recv(this->nextTokens.data(), this->batchSize, MPI_INT32_T, predictor_world_rank,
+                            predictor_world_rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+                    if (SequencePool::getInstance().has(sequenceID)) {
+                        auto seq = SequencePool::getInstance().get(sequenceID);
+                        TaskWaitingQueue::getInstance().push(seq);
+                    } else {
+                        printf("Error: should have sequenceID\n");
+                        fflush(stdout);
+                    }
+                }
+            });
         }
     } else {
         // The last predictor pipeline parallel stage
         this->nextTokens = this->search(result);
         if (ctx->ppSize > 1 && ctx->ppRank == ctx->ppSize - 1) {
+            TimeLine t("GreedySearch.Seq" + std::to_string(ctx->sequenceID) + ".MPI_Send");
             int embedding_world_rank = 0 * ctx->tpSize + ctx->tpRank;
             int predictor_world_rank = (ctx->ppSize - 1) * ctx->tpSize + ctx->tpRank;
+            MPI_Send(&ctx->sequenceID, 1, MPI_INT32_T, embedding_world_rank, predictor_world_rank, MPI_COMM_WORLD);
             MPI_Send(this->nextTokens.data(), batchSize, MPI_INT32_T, embedding_world_rank, predictor_world_rank,
                     MPI_COMM_WORLD);
             // TODO: Error: different scope when dynamic loading so file

diff --git a/src/searchers/greedy_search.h b/src/searchers/greedy_search.h
index 607d9737..5b4ec164 100644
--- a/src/searchers/greedy_search.h
+++ b/src/searchers/greedy_search.h
@@ -47,6 +47,7 @@ class GreedySearch : public AbstractSearcher {
     std::vector<std::vector<int>> cachedRepetVec;
     std::vector<int> doneBatch;
 
+    bool enabledBackgroundSync = false;
     int batchSize;
     int step;
     int curLen;

diff --git a/src/utils/environment.h b/src/utils/environment.h
index e6d94338..2630ff74 100644
--- a/src/utils/environment.h
+++ b/src/utils/environment.h
@@ -41,6 +41,9 @@ class Env {
     // get Pipeline Stage
     int getPipelineStage() { return pipelineStageValue; }
 
+    // get Max Request Num
+    int getMaxRequestNum() { return maxRequestNumValue; }
+
     // get AMX Threshold M
     int getAMXThresholdM() { return AMXThresholdMValue; }
 
@@ -73,6 +76,9 @@ class Env {
         // init Pipeline Parallel
         initPipelineStage();
 
+        // init Max request number
+        initMaxRequestNum();
+
         // init Engine Kind and Index
         initEngineKindIndex();
 
@@ -173,6 +179,21 @@ class Env {
         }
     }
 
+    // Max request number
+    int maxRequestNumValue = 1;
+    void initMaxRequestNum() {
+        char *xft_max_request_num_value = getenv("XFT_MAX_REQUEST_NUM");
+        if (xft_max_request_num_value != NULL) {
+            int value = atoi(xft_max_request_num_value);
+            if (value >= 1)
+                maxRequestNumValue = value;
+            else
+                printf("[ERROR] XFT_MAX_REQUEST_NUM value needs to be greater than 0.\n");
+        } else {
+            maxRequestNumValue = 1;
+        }
+    }
+
     // AMX Threshold M
     int AMXThresholdMValue = 1;
     void initAMXThresholdM() {
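The new XFT_MAX_REQUEST_NUM variable bounds how many groups TaskWaitingQueue accepts (default 1; values below 1 are rejected). It can be set in the launching environment, e.g. XFT_MAX_REQUEST_NUM=8, or from code before Env is first used (a sketch; setenv is POSIX):

    #include <cstdlib>
    // Must run before Env::getInstance() first reads the variable.
    setenv("XFT_MAX_REQUEST_NUM", "8", /*overwrite=*/1);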
diff --git a/src/utils/thread_util.h b/src/utils/thread_util.h
index c6826051..a44b08d7 100644
--- a/src/utils/thread_util.h
+++ b/src/utils/thread_util.h
@@ -1,6 +1,29 @@
+// Copyright (c) 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
 #pragma once
 
 #include <omp.h>
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <vector>
+
+namespace xft {
+
 template <typename Lambda>
 void parallel_for(int tasks, const Lambda &fn) {
 #pragma omp parallel for
@@ -15,4 +38,58 @@ void parallel_for_dschedule(int tasks, const Lambda &fn) {
     for (int i = 0; i < tasks; i++) {
         fn(i);
     }
-}
\ No newline at end of file
+}
+
+class ThreadPool {
+public:
+    static ThreadPool &getInstance() {
+        static ThreadPool instance;
+        return instance;
+    }
+
+    template <typename F, typename... Args>
+    void addTask(F &&f, Args &&...args) {
+        {
+            std::unique_lock<std::mutex> lock(queueMutex);
+            tasks.emplace(std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+        }
+        condition.notify_one();
+    }
+
+    ~ThreadPool() {
+        {
+            // Set under the lock so workers cannot miss the shutdown signal
+            std::unique_lock<std::mutex> lock(queueMutex);
+            stop = true;
+        }
+        condition.notify_all();
+        for (std::thread &worker : workers) {
+            worker.join();
+        }
+    }
+
+private:
+    ThreadPool() : stop(false) {
+        for (size_t i = 0; i < numThreads; ++i) {
+            workers.emplace_back([this] {
+                while (true) {
+                    std::function<void()> task;
+                    {
+                        std::unique_lock<std::mutex> lock(queueMutex);
+                        condition.wait(lock, [this] { return stop || !tasks.empty(); });
+                        if (stop && tasks.empty()) { return; }
+                        task = std::move(tasks.front());
+                        tasks.pop();
+                    }
+                    task();
+                }
+            });
+        }
+    }
+
+    static constexpr size_t numThreads = 1;
+    std::vector<std::thread> workers;
+    std::queue<std::function<void()>> tasks;
+
+    std::mutex queueMutex;
+    std::condition_variable condition;
+    bool stop;
+};
+
+} // namespace xft
\ No newline at end of file
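A hedged usage sketch of the ThreadPool above (the lambda and its argument are invented). With numThreads fixed at 1, queued tasks run sequentially, in submission order, on the single background worker:

    #include <cstdio>
    void example() {
        // addTask returns immediately; the worker thread runs the lambda later.
        xft::ThreadPool::getInstance().addTask([](int requestId) {
            printf("handling request %d\n", requestId);
        }, 7);
    }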
From d01be1a2fcb3a45bb79643ea68fd2588b4ff27f3 Mon Sep 17 00:00:00 2001
From: pujiang2018
Date: Wed, 15 May 2024 08:45:07 +0800
Subject: [PATCH 02/35] merge batchSize and seqLen into one in TokenEmbedding (#350)

---
 src/kernels/token_embedding_kernels.cpp | 33 ++++++++++---------
 src/kernels/token_embedding_kernels.h   |  6 ++--
 src/layers/opt_embedding.h              | 32 ++++++++-----------
 src/layers/token_embedding.h            |  4 +--
 src/layers/token_embedding_gemma.h      |  4 +--
 src/models/baichuan.cpp                 |  4 +--
 src/models/baichuan.h                   |  2 +-
 src/models/chatglm.cpp                  |  4 +--
 src/models/chatglm.h                    |  2 +-
 src/models/chatglm2.cpp                 |  8 ++---
 src/models/chatglm2.h                   |  4 +--
 src/models/common_decoder.h             |  8 ++---
 src/models/gemma.cpp                    |  8 ++---
 src/models/gemma.h                      |  4 +--
 src/models/llama.cpp                    |  8 ++---
 src/models/llama.h                      |  4 +--
 src/models/opt_decoder.cpp              | 42 +++++++++++++++++++++++--
 src/models/opt_decoder.h                |  3 +-
 src/models/qwen.cpp                     |  4 +--
 src/models/qwen.h                       |  2 +-
 src/models/qwen2.cpp                    |  8 ++---
 src/models/qwen2.h                      |  4 +--
 src/models/yarn_llama.cpp               |  4 +--
 src/models/yarn_llama.h                 |  4 +--
 tests/ut/token_embedding_test.cpp       | 18 +++++------
 25 files changed, 127 insertions(+), 97 deletions(-)

diff --git a/src/kernels/token_embedding_kernels.cpp b/src/kernels/token_embedding_kernels.cpp
index b2bbdd02..af9b1acd 100644
--- a/src/kernels/token_embedding_kernels.cpp
+++ b/src/kernels/token_embedding_kernels.cpp
@@ -22,27 +22,26 @@ namespace xft {
 
 template <typename OutT, typename WeiT>
-void tokenEmbedding(OutT *output, const int *tokenId, const WeiT *embTable, const int batchSize, const int seqLen,
-        const int hiddenSize) {
-    for (int i = 0; i < batchSize * seqLen; ++i) {
+void tokenEmbedding(OutT *output, const int *tokenId, const WeiT *embTable, const int tokenSize, const int hiddenSize) {
+    for (int i = 0; i < tokenSize; ++i) {
         int id = tokenId[i];
         xft::copy(output + i * hiddenSize, embTable + id * hiddenSize, hiddenSize);
     }
 }
 
-template void tokenEmbedding(float *output, const int *tokenId, const float *weight, const int batchSize,
-        const int seqLen, const int hiddenSize);
-template void tokenEmbedding(float16_t *output, const int *tokenId, const float16_t *weight,
-        const int batchSize, const int seqLen, const int hiddenSize);
-template void tokenEmbedding(bfloat16_t *output, const int *tokenId, const bfloat16_t *weight,
-        const int batchSize, const int seqLen, const int hiddenSize);
+template void tokenEmbedding(
+        float *output, const int *tokenId, const float *weight, const int tokenSize, const int hiddenSize);
+template void tokenEmbedding(
+        float16_t *output, const int *tokenId, const float16_t *weight, const int tokenSize, const int hiddenSize);
+template void tokenEmbedding(
+        bfloat16_t *output, const int *tokenId, const bfloat16_t *weight, const int tokenSize, const int hiddenSize);
 
-template void tokenEmbedding(float *output, const int *tokenId, const float16_t *weight,
-        const int batchSize, const int seqLen, const int hiddenSize);
-template void tokenEmbedding(float *output, const int *tokenId, const bfloat16_t *weight,
-        const int batchSize, const int seqLen, const int hiddenSize);
-template void tokenEmbedding(bfloat16_t *output, const int *tokenId, const float16_t *weight,
-        const int batchSize, const int seqLen, const int hiddenSize);
-template void tokenEmbedding(float16_t *output, const int *tokenId, const bfloat16_t *weight,
-        const int batchSize, const int seqLen, const int hiddenSize);
+template void tokenEmbedding(
+        float *output, const int *tokenId, const float16_t *weight, const int tokenSize, const int hiddenSize);
+template void tokenEmbedding(
+        float *output, const int *tokenId, const bfloat16_t *weight, const int tokenSize, const int hiddenSize);
+template void tokenEmbedding(
+        bfloat16_t *output, const int *tokenId, const float16_t *weight, const int tokenSize, const int hiddenSize);
+template void tokenEmbedding(
+        float16_t *output, const int *tokenId, const bfloat16_t *weight, const int tokenSize, const int hiddenSize);
 } // namespace xft
\ No newline at end of file
diff --git a/src/kernels/token_embedding_kernels.h b/src/kernels/token_embedding_kernels.h
index 3062ec79..9e4e768e 100644
--- a/src/kernels/token_embedding_kernels.h
+++ b/src/kernels/token_embedding_kernels.h
@@ -31,13 +31,11 @@ namespace xft {
  * @param output Pointer to the output array where embeddings will be stored.
  * @param tokenId Pointer to the array containing token IDs.
  * @param weight Pointer to the array containing token weights for embedding lookup.
- * @param batchSize Number of sequences in the batch.
- * @param seqLen Length of each sequence (number of tokens).
+ * @param tokenSize Total number of tokens in the input array.
  * @param hiddenSize Size of the hidden dimension for each token embedding.
 */
template <typename OutT, typename weiT>
-void tokenEmbedding(OutT *output, const int *tokenId, const weiT *weight, const int batchSize, const int seqLen,
-        const int hiddenSize);
-
+void tokenEmbedding(OutT *output, const int *tokenId, const weiT *weight, const int tokenSize, const int hiddenSize);
 } // namespace xft
\ No newline at end of file
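With batchSize and seqLen merged, callers flatten both dimensions into one token count. A sketch of the new call shape (the function name embExample, the table pointer, and the ID values are invented; the float/float16_t combination is one of the instantiations above):

    #include <vector>
    void embExample(const float16_t *embTable, int hiddenSize) {
        // 2 sequences x 4 tokens each are passed as tokenSize = 8 flattened tokens.
        int tokenIds[8] = {11, 3, 7, 42, 11, 9, 2, 5};
        std::vector<float> out(8 * hiddenSize);
        xft::tokenEmbedding<float, float16_t>(out.data(), tokenIds, embTable, /*tokenSize=*/8, hiddenSize);
    }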
diff --git a/src/layers/opt_embedding.h b/src/layers/opt_embedding.h
index 7669a81d..9e715a6d 100644
--- a/src/layers/opt_embedding.h
+++ b/src/layers/opt_embedding.h
@@ -47,8 +47,7 @@ class OptEmbedding {
     }
 
     // TODO: mask is not considered
-    // tokenIds and positions are 2-dimension array with batchSize rows, and seqLen cols
-    void forward(int *tokenIds, int *positions, float *output, int batchSize, int seqLen) {
+    void forward(int *tokenIds, int *positions, float *output, int tokenSize) {
         if (embeddingSize != hiddenSize) {
             printf("Not supported yet: embeddingSize != hiddenSize\n");
             exit(-1);
@@ -57,24 +56,19 @@ class OptEmbedding {
         int row = 0;
 
         if constexpr (std::is_same_v<T, float16_t>) {
-            for (int i = 0; i < batchSize; ++i) {
-                for (int j = 0; j < seqLen; ++j) {
-                    // Embedding
-                    int id = tokenIds[i * seqLen + j];
-                    float16_t::cvt_float16_to_float(
-                            embTable + id * embeddingSize, output + row * hiddenSize, embeddingSize);
+            for (int i = 0; i < tokenSize; ++i) {
+                // Embedding
+                int id = tokenIds[i];
+                float16_t::cvt_float16_to_float(embTable + id * embeddingSize, output + i * hiddenSize, embeddingSize);
 
-                    // Positional embedding
-                    int pos = positions[i * seqLen + j];
-                    // # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
-                    // # and adjust num_embeddings appropriately. Other models don't have this hack
-                    // Do not add the offset if the embedding table is already handled it (like FasterTransformer)
-                    //pos += 2;
-                    float16_t::float_add_float16(output + row * hiddenSize, positionalTable + pos * hiddenSize,
-                            output + row * hiddenSize, hiddenSize);
-
-                    row += 1;
-                }
+                // Positional embedding
+                int pos = positions[i];
+                // # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+                // # and adjust num_embeddings appropriately. Other models don't have this hack
+                // Do not add the offset if the embedding table has already handled it (like FasterTransformer)
+                //pos += 2;
+                float16_t::float_add_float16(output + i * hiddenSize, positionalTable + pos * hiddenSize,
+                        output + i * hiddenSize, hiddenSize);
             }
         } else {
             printf("Type %s not supported!\n", typeid(T).name());

diff --git a/src/layers/token_embedding.h b/src/layers/token_embedding.h
index 9ebe7e0c..f49a135e 100644
--- a/src/layers/token_embedding.h
+++ b/src/layers/token_embedding.h
@@ -44,8 +44,8 @@ class TokenEmbedding {
 
-    // tokenIds ia a 2-dimension array with batchSize rows, and seqLen cols
+    // tokenIds is a flattened array with tokenSize elements
     template <typename OutT>
-    void forward(int *tokenIds, OutT *output, int batchSize, int seqLen) {
-        xft::tokenEmbedding(output, tokenIds, embTable, batchSize, seqLen, hiddenSize);
+    void forward(int *tokenIds, OutT *output, int tokenSize) {
+        xft::tokenEmbedding(output, tokenIds, embTable, tokenSize, hiddenSize);
     }
 
     int getVocabSize() { return vocabSize; }

diff --git a/src/layers/token_embedding_gemma.h b/src/layers/token_embedding_gemma.h
index 2d6a90ce..519f691a 100644
--- a/src/layers/token_embedding_gemma.h
+++ b/src/layers/token_embedding_gemma.h
@@ -43,14 +43,14 @@ class GemmaTokenEmbedding {
 
-    // tokenIds ia a 2-dimension array with batchSize rows, and seqLen cols
+    // tokenIds is a flattened array with tokenSize elements
     template <typename OutT>
-    void forward(int *tokenIds, OutT *output, int batchSize, int seqLen) {
+    void forward(int *tokenIds, OutT *output, int tokenSize) {
         __m512 vdim = _mm512_set1_ps(sqrtf(this->hiddenSize));
         constexpr int kStep = 16;
         int blockSize = hiddenSize / kStep;
         int remainder = hiddenSize % kStep;
 
 #pragma omp parallel for
-        for (int i = 0; i < batchSize * seqLen; ++i) {
+        for (int i = 0; i < tokenSize; ++i) {
             int id = tokenIds[i];
             auto src = this->embTable + id * hiddenSize;
             auto dst = output + i * hiddenSize;

diff --git a/src/models/baichuan.cpp b/src/models/baichuan.cpp
index 30d1e379..2cfb607e 100644
--- a/src/models/baichuan.cpp
+++ b/src/models/baichuan.cpp
@@ -165,8 +165,8 @@ void Baichuan::prepareAttnMask(int *ids, int step) {
 }
 
 template <typename WeiT>
-void Baichuan<WeiT>::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
-    embedding->forward(ids, output, batchSize, seqLen);
+void Baichuan<WeiT>::embeddingForward(int *ids, float *output, int tokenSize) {
+    embedding->forward(ids, output, tokenSize);
 }
 
 template

diff --git a/src/models/baichuan.h b/src/models/baichuan.h
index 8b03c206..ac7857d4 100644
--- a/src/models/baichuan.h
+++ b/src/models/baichuan.h
@@ -30,7 +30,7 @@ class Baichuan
     void prepareAttnMaskBase(int *ids, int step);
     void prepareAttnMask(int *ids, int step);
-    void embeddingForward(int *ids, float *output, int batchSize, int seqLen);
+    void embeddingForward(int *ids, float *output, int tokenSize);
     void lastLayerNormForward(float *input, float *output, int rows);
 
 private:

diff --git a/src/models/chatglm.cpp b/src/models/chatglm.cpp
index da8da06a..9df467db 100644
--- a/src/models/chatglm.cpp
+++ b/src/models/chatglm.cpp
@@ -104,8 +104,8 @@ void ChatGLM::prepareAttnMask(int *ids, int step) {
 }
 
 template <typename WeiT>
-void ChatGLM<WeiT>::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
-    embedding->forward(ids, output, batchSize, seqLen);
+void ChatGLM<WeiT>::embeddingForward(int *ids, float *output, int tokenSize) {
+    embedding->forward(ids, output, tokenSize);
 }
 
 template

diff --git a/src/models/chatglm.h b/src/models/chatglm.h
index 3315afe1..3e5fe204 100644
--- a/src/models/chatglm.h
+++ b/src/models/chatglm.h
@@ -29,7 +29,7 @@ class ChatGLM : public CommonDecoder::prepareAttnMask(int *ids, int step) {
 }
 
 template
-void ChatGLM2::embeddingForward(int *ids, float *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void ChatGLM2::embeddingForward(int *ids, float *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template -void ChatGLM2::embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void ChatGLM2::embeddingForward(int *ids, bfloat16_t *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template diff --git a/src/models/chatglm2.h b/src/models/chatglm2.h index ce378e82..ca430675 100644 --- a/src/models/chatglm2.h +++ b/src/models/chatglm2.h @@ -34,8 +34,8 @@ class ChatGLM2 ~ChatGLM2(); virtual void prepareAttnMask(int *ids, int step); - virtual void embeddingForward(int *ids, float *output, int batchSize, int seqLen); - virtual void embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen); + virtual void embeddingForward(int *ids, float *output, int tokenSize); + virtual void embeddingForward(int *ids, bfloat16_t *output, int tokenSize); virtual void lastLayerNormForward(float *input, float *output, int rows); virtual void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows); virtual int *getPositionIds(int *ids, int batchSize, int seqLen, int step) override; diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index a752c7eb..1b8f1179 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -319,7 +319,7 @@ class CommonDecoder : public AbstractDecoder { MlpOutT *outBuf = (MlpOutT *)(embBuf + batchSize * inputSeqLen * hiddenSize); // Embedding - this->embeddingForward(ids, embBuf, batchSize, inputSeqLen); + this->embeddingForward(ids, embBuf, batchSize * inputSeqLen); this->accSeqLen += seqLen; #ifdef DEBUG @@ -546,7 +546,7 @@ class CommonDecoder : public AbstractDecoder { MlpOutT *outBuf = (MlpOutT *)(embBuf + 1 * seqLen * ctx->hiddenSize); // Embedding - this->embeddingForward(ids, embBuf, 1, seqLen); + this->embeddingForward(ids, embBuf, 1 * seqLen); // Prepare attention mask this->prepareAttnMask(ids, 0); @@ -949,11 +949,11 @@ class CommonDecoder : public AbstractDecoder { int getStartId() { return startId; } - virtual void embeddingForward(int *ids, float *output, int batchSize, int seqLen) { + virtual void embeddingForward(int *ids, float *output, int tokenSize) { printf("embeddingForward(float) must be implemented.\n"); exit(-1); } - virtual void embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen) { + virtual void embeddingForward(int *ids, bfloat16_t *output, int tokenSize) { printf("embeddingForward(bfloat16_t) must be implemented.\n"); exit(-1); } diff --git a/src/models/gemma.cpp b/src/models/gemma.cpp index 64a81621..9a2266a3 100644 --- a/src/models/gemma.cpp +++ b/src/models/gemma.cpp @@ -105,13 +105,13 @@ void GemmaLLM::prepareAttnMask(int *ids, int step) { } template -void GemmaLLM::embeddingForward(int *ids, float *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void GemmaLLM::embeddingForward(int *ids, float *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template -void GemmaLLM::embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void GemmaLLM::embeddingForward(int *ids, bfloat16_t *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template diff --git 
a/src/models/gemma.h b/src/models/gemma.h index 9a0033fb..4e7f7fbc 100644 --- a/src/models/gemma.h +++ b/src/models/gemma.h @@ -34,8 +34,8 @@ class GemmaLLM void prepareAttnMask(int *ids, int step); - void embeddingForward(int *ids, float *output, int batchSize, int seqLen); - void embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen); + void embeddingForward(int *ids, float *output, int tokenSize); + void embeddingForward(int *ids, bfloat16_t *output, int tokenSize); void lastLayerNormForward(float *input, float *output, int rows); void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows); diff --git a/src/models/llama.cpp b/src/models/llama.cpp index 75c520cd..37d429d4 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -105,13 +105,13 @@ void LlamaLLM::prepareAttnMask(int *ids, int step) { } template -void LlamaLLM::embeddingForward(int *ids, float *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void LlamaLLM::embeddingForward(int *ids, float *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template -void LlamaLLM::embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void LlamaLLM::embeddingForward(int *ids, bfloat16_t *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template diff --git a/src/models/llama.h b/src/models/llama.h index 5fad1e24..d117c219 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -34,8 +34,8 @@ class LlamaLLM void prepareAttnMask(int *ids, int step); - void embeddingForward(int *ids, float *output, int batchSize, int seqLen); - void embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen); + void embeddingForward(int *ids, float *output, int tokenSize); + void embeddingForward(int *ids, bfloat16_t *output, int tokenSize); void lastLayerNormForward(float *input, float *output, int rows); void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows); diff --git a/src/models/opt_decoder.cpp b/src/models/opt_decoder.cpp index 9d7eb6b3..929a25de 100644 --- a/src/models/opt_decoder.cpp +++ b/src/models/opt_decoder.cpp @@ -103,10 +103,13 @@ void OptDecoder::prepareAttnMask(int *ids, int step) { } template -void OptDecoder::embeddingForward(int *ids, float *buf, int batchSize, int seqLen) { +void OptDecoder::embeddingForward(int *ids, float *buf, int tokenSize) { int pastSeqLen = this->accSeqLen; if (pastSeqLen == 0 && this->prefixSharing) { pastSeqLen += this->prefixSeqLen; } + // Prepare position data for positional embedding + int batchSize = 1; + int seqLen = tokenSize; int positions[batchSize * seqLen]; for (int b = 0; b < batchSize; ++b) { for (int i = 0; i < seqLen; ++i) { @@ -115,7 +118,42 @@ void OptDecoder::embeddingForward(int *ids, float *buf, int batc } // Embedding - embedding->forward(ids, positions, buf, batchSize, seqLen); + embedding->forward(ids, positions, buf, tokenSize); +} + +template +void OptDecoder::embeddingForward(float *output, const std::vector &sequences) { + // Calculate the total number of input tokens + int inputTokens = 0; + for (int i = 0; i < sequences.size(); ++i) { + inputTokens += sequences[i]->getInputSeqLen(); + } + + // Prepare position data for positional embedding + int idBuf[256]; + int posBuf[256]; + + int *ids = inputTokens <= 256 ? idBuf : (int *)malloc(inputTokens * sizeof(int)); + int *positions = inputTokens <= 256 ? 
posBuf : (int *)malloc(inputTokens * sizeof(int)); + + int idx = 0; + for (int i = 0; i < sequences.size(); ++i) { + auto pastSeqLen = sequences[i]->getPastSeqLen(); + auto inputTokens = sequences[i]->getInputTokens(); + for (int j = 0; j < sequences[i]->getInputSeqLen(); ++j) { + ids[idx] = inputTokens[pastSeqLen + j]; + positions[idx] = pastSeqLen + j; + idx += 1; + } + } + + // Embedding + embedding->forward(ids, positions, output, inputTokens); + + if (inputTokens > 256) { + free(ids); + free(positions); + } } template diff --git a/src/models/opt_decoder.h b/src/models/opt_decoder.h index 98482e0e..f4baac34 100644 --- a/src/models/opt_decoder.h +++ b/src/models/opt_decoder.h @@ -35,7 +35,8 @@ class OptDecoder : public CommonDecoder, ~OptDecoder(); void prepareAttnMask(int *ids, int step); - void embeddingForward(int *ids, float *output, int batchSize, int seqLen); + void embeddingForward(int *ids, float *output, int tokenSize); + void embeddingForward(float *output, const std::vector &sequences); void lastLayerNormForward(float *input, float *output, int rows); private: diff --git a/src/models/qwen.cpp b/src/models/qwen.cpp index dbd03c3d..49145062 100644 --- a/src/models/qwen.cpp +++ b/src/models/qwen.cpp @@ -101,8 +101,8 @@ void QwenLLM::prepareAttnMask(int *ids, int step) { } template -void QwenLLM::embeddingForward(int *ids, float *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void QwenLLM::embeddingForward(int *ids, float *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template diff --git a/src/models/qwen.h b/src/models/qwen.h index 712f19f3..2f3bab6e 100644 --- a/src/models/qwen.h +++ b/src/models/qwen.h @@ -28,7 +28,7 @@ class QwenLLM : public CommonDecoder::prepareAttnMask(int *ids, int step) { } template -void Qwen2LLM::embeddingForward(int *ids, float *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void Qwen2LLM::embeddingForward(int *ids, float *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template -void Qwen2LLM::embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void Qwen2LLM::embeddingForward(int *ids, bfloat16_t *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template diff --git a/src/models/qwen2.h b/src/models/qwen2.h index 476784d9..b1d99771 100644 --- a/src/models/qwen2.h +++ b/src/models/qwen2.h @@ -34,8 +34,8 @@ class Qwen2LLM void prepareAttnMask(int *ids, int step); - void embeddingForward(int *ids, float *output, int batchSize, int seqLen); - void embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen); + void embeddingForward(int *ids, float *output, int tokenSize); + void embeddingForward(int *ids, bfloat16_t *output, int tokenSize); void lastLayerNormForward(float *input, float *output, int rows); void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows); diff --git a/src/models/yarn_llama.cpp b/src/models/yarn_llama.cpp index 85e58136..813bf706 100644 --- a/src/models/yarn_llama.cpp +++ b/src/models/yarn_llama.cpp @@ -89,8 +89,8 @@ void YaRNLlama::prepareAttnMask(int *ids, int step) { } template -void YaRNLlama::embeddingForward(int *ids, float *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void YaRNLlama::embeddingForward(int *ids, float *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template diff --git 
a/src/models/yarn_llama.h b/src/models/yarn_llama.h index 8c03425a..695c00f2 100644 --- a/src/models/yarn_llama.h +++ b/src/models/yarn_llama.h @@ -36,8 +36,8 @@ class YaRNLlama void prepareAttnMask(int *ids, int step); - void embeddingForward(int *ids, float *output, int batchSize, int seqLen); - void embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen); + void embeddingForward(int *ids, float *output, int tokenSize); + void embeddingForward(int *ids, bfloat16_t *output, int tokenSize); void lastLayerNormForward(float *input, float *output, int rows); void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows); diff --git a/tests/ut/token_embedding_test.cpp b/tests/ut/token_embedding_test.cpp index acf6dd6e..981b7607 100644 --- a/tests/ut/token_embedding_test.cpp +++ b/tests/ut/token_embedding_test.cpp @@ -23,10 +23,10 @@ #include "token_embedding_kernels.h" template -static void TestTokenEmbeddingKernel(const int vocabSize, const int hiddenSize, const int batchSize, const int seqLen) { +static void TestTokenEmbeddingKernel(const int vocabSize, const int hiddenSize, const int tokenSize) { WeiT *embTable = (WeiT *)aligned_alloc(64, vocabSize * hiddenSize * sizeof(WeiT)); - int *tokenId = (int *)aligned_alloc(64, batchSize * seqLen * sizeof(int)); - OutT *output = (OutT *)aligned_alloc(64, batchSize * seqLen * hiddenSize * sizeof(OutT)); + int *tokenId = (int *)aligned_alloc(64, tokenSize * sizeof(int)); + OutT *output = (OutT *)aligned_alloc(64, tokenSize * hiddenSize * sizeof(OutT)); for (int i = 0; i < vocabSize; i++) { for (int j = 0; j < hiddenSize; j++) { @@ -34,13 +34,13 @@ static void TestTokenEmbeddingKernel(const int vocabSize, const int hiddenSize, } } - for (int i = 0; i < batchSize * seqLen; i++) { + for (int i = 0; i < tokenSize; i++) { tokenId[i] = rand() % vocabSize; } - xft::tokenEmbedding(output, tokenId, embTable, batchSize, seqLen, hiddenSize); + xft::tokenEmbedding(output, tokenId, embTable, tokenSize, hiddenSize); - for (int i = 0; i < batchSize * seqLen; i++) { + for (int i = 0; i < tokenSize; i++) { int id = tokenId[i]; for (int j = 0; j < hiddenSize; j++) { EXPECT_FLOAT_EQ(float(output[i * hiddenSize + j]), float(embTable[id * hiddenSize + j])); @@ -54,18 +54,18 @@ static void TestTokenEmbeddingKernel(const int vocabSize, const int hiddenSize, #define UT_EMBEDDING(MN, OutT, WeiT, VS, HS, BS, SL) \ TEST(TokenEmbeddingKernel, MN##_BS##BS##_Lens##SL##_OutT##OutT##_WeiT##WeiT) { \ - TestTokenEmbeddingKernel(VS, HS, BS, SL); \ + TestTokenEmbeddingKernel((VS), (HS), (BS) * (SL)); \ } UT_EMBEDDING(Llama2_7B, float_t, float16_t, 32000, 4096, 1, 64); UT_EMBEDDING(Llama2_7B, float_t, float16_t, 32000, 4096, 1, 256); UT_EMBEDDING(Llama2_7B, float_t, float16_t, 32000, 4096, 1, 512); -UT_EMBEDDING(Llama2_7B, float_t, float16_t, 32000, 4096, 512, 512); +UT_EMBEDDING(Llama2_7B, float_t, float16_t, 32000, 4096, 8, 512); UT_EMBEDDING(Llama2_7B, float16_t, float16_t, 32000, 4096, 1, 64); UT_EMBEDDING(Llama2_7B, float16_t, float16_t, 32000, 4096, 1, 256); UT_EMBEDDING(Llama2_7B, float16_t, float16_t, 32000, 4096, 1, 512); -UT_EMBEDDING(Llama2_7B, float16_t, float16_t, 32000, 4096, 512, 512); +UT_EMBEDDING(Llama2_7B, float16_t, float16_t, 32000, 4096, 8, 512); // UT_EMBEDDING(Llama2_7B, bfloat16_t, bfloat16_t, 32000, 4096, 1, 64); // UT_EMBEDDING(Llama2_7B, bfloat16_t, bfloat16_t, 32000, 4096, 1, 256); From db0c4e9004a751b675303f3edd58ec52c9428e27 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 03/35] 
[Common] Move Matrix into xft namespace. (#351)

---
 src/common/my_types.h        |  4 +--
 src/common/transformer_ctx.h |  8 ++---
 src/layers/attention.h       | 64 ++++++++++++++++-----------------
 src/layers/decoder_layer.h   | 10 +++---
 src/layers/dist_linear.h     | 10 +++---
 src/layers/mlp_chatglm2.h    |  4 +--
 src/layers/mlp_llama.h       | 70 ++++++++++++++++++------------------
 src/layers/mlp_standard.h    | 34 +++++++++---------
 src/models/common_decoder.h  |  8 ++---
 src/utils/debugger.h         |  2 +-
 src/utils/decoder_util.h     | 14 ++++----
 src/utils/matmul_helper.h    | 18 +++++-----
 12 files changed, 123 insertions(+), 123 deletions(-)

diff --git a/src/common/my_types.h b/src/common/my_types.h
index 05861e74..0896c73f 100644
--- a/src/common/my_types.h
+++ b/src/common/my_types.h
@@ -36,7 +36,7 @@ void *xft_numa_alloc(size_t size);
 void xft_numa_free(void *start, size_t size);
 }
 
-namespace hpj {
+namespace xft {
 
 template <typename T>
 struct is_quantization_type {
@@ -366,4 +366,4 @@ class Vector {
     }
     uint64_t Size() { return size; }
 };
-} // namespace hpj
+} // namespace xft

diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h
index 0a48c88b..0c1d77cf 100644
--- a/src/common/transformer_ctx.h
+++ b/src/common/transformer_ctx.h
@@ -109,10 +109,10 @@ struct DecoderContext {
     float *qkScores; // attention score
 
     // Please look into the comments in resize function to see how buffers are arranged
-    hpj::Matrix<float> normBuf; // buf for the first layer norm
-    hpj::Matrix<float> tmpBuf; // tmp buffer, same size as output
-    hpj::Matrix<float> qkvMatMul; // query, key, value
-    hpj::Matrix<float> imOut; // intermediate output
+    xft::Matrix<float> normBuf; // buf for the first layer norm
+    xft::Matrix<float> tmpBuf; // tmp buffer, same size as output
+    xft::Matrix<float> qkvMatMul; // query, key, value
+    xft::Matrix<float> imOut; // intermediate output
 
     MMHelper *mmHelper;

diff --git a/src/layers/attention.h b/src/layers/attention.h
index 0e3977eb..c69ddba5 100644
--- a/src/layers/attention.h
+++ b/src/layers/attention.h
@@ -132,7 +132,7 @@ class Attention {
                     kvResponsibleCols * sizeof(float));
         }
 
-        hpj::Matrix convertedqkvWeight;
+        xft::Matrix convertedqkvWeight;
         ctx->mmHelper->convertWeight(trans, hiddenSize, responsibleCols, concatBuf, concatScale, concatZero,
                 convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum);
         ctx->mmHelper->packWeight(trans, convertedqkvWeight, qkvWeight);
@@ -162,7 +162,7 @@ class Attention {
         // Weights for attention output
         // Horizontally split the weight, as the source (PyTorch weight) is transposed, thus looks like vertically
-        hpj::Matrix convertedWeight;
+        xft::Matrix convertedWeight;
         ctx->mmHelper->convertWeight(trans, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, attnOutWeight,
                 attnOutScale, attnOutZero, this->startQHead * headSize, qResponsibleCols, false, convertedWeight,
                 attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, true);
@@ -220,9 +220,9 @@ class Attention {
             bool useSelfAttn, bool doLnBefore, bool doLnAfter, int *positionIds = nullptr) {
         auto hiddenSize = ctx->hiddenSize;
-        hpj::Matrix inputBuffer(input, ctx->batchSize * inputSeqLen, hiddenSize, hiddenSize);
-        hpj::Matrix imBuffer(imBuf, ctx->batchSize * inputSeqLen, hiddenSize, hiddenSize);
-        hpj::Matrix outBuffer(output, ctx->batchSize * inputSeqLen, hiddenSize, hiddenSize);
+        xft::Matrix inputBuffer(input, ctx->batchSize * inputSeqLen, hiddenSize, hiddenSize);
+        xft::Matrix imBuffer(imBuf, ctx->batchSize * inputSeqLen, hiddenSize, hiddenSize);
+        xft::Matrix outBuffer(output, ctx->batchSize * inputSeqLen, hiddenSize, hiddenSize);
 
         float epsilon = ctx->epsilon;
         int
headSize = ctx->attHeadSize; @@ -234,7 +234,7 @@ class Attention { int qkvStride = qkvCols; auto &qkvMatMul = ctx->qkvMatMul; - hpj::Matrix qkvGroupMatMul((ImT *)qkvMatMul.Data(), qkvRows, qkvCols, qkvStride); + xft::Matrix qkvGroupMatMul((ImT *)qkvMatMul.Data(), qkvRows, qkvCols, qkvStride); #ifdef DEBUG dbg.debugPrint("---- DecoderLayer.forward (useSelfAttn=%d) ----\n", useSelfAttn); @@ -267,9 +267,9 @@ class Attention { } t2.release(); - hpj::Matrix query(qkvGroupMatMul, 0, inputBuffer.Rows(), 0, qCols); - hpj::Matrix key(qkvGroupMatMul, 0, inputBuffer.Rows(), qCols, kvCols); - hpj::Matrix value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kvCols); + xft::Matrix query(qkvGroupMatMul, 0, inputBuffer.Rows(), 0, qCols); + xft::Matrix key(qkvGroupMatMul, 0, inputBuffer.Rows(), qCols, kvCols); + xft::Matrix value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kvCols); #ifdef DEBUG dbg.debugPrint("Q[%d,%d](%d):\n", query.Rows(), query.Cols(), query.Stride()); @@ -320,7 +320,7 @@ class Attention { } // For multiple nodes inference, not the whole result buffer - hpj::Matrix attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); + xft::Matrix attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); if (pastSeqLen == 0) { if (ctx->inputSeqLen > getFlashThresh()) { @@ -397,8 +397,8 @@ class Attention { protected: template - void selfAttentionBF16(DecoderContext *ctx, hpj::Matrix &query, hpj::Matrix &key, - hpj::Matrix &value, hpj::Matrix &result, KVCacheTensor &presentKey, + void selfAttentionBF16(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, + xft::Matrix &value, xft::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue) { int responsibleQHeads = this->endQHead - this->startQHead; int responsibleKVHeads = this->endKVHead - this->startKVHead; @@ -447,7 +447,7 @@ class Attention { // Copy all keys and values to KV cache template - void copyKVCache(DecoderContext *ctx, hpj::Matrix &key, hpj::Matrix &value, + void copyKVCache(DecoderContext *ctx, xft::Matrix &key, xft::Matrix &value, KVCacheTensor &presentKey, KVCacheTensor &presentValue, int pastSeqLen) { int batchSize = ctx->batchSize; int headSize = ctx->attHeadSize; @@ -475,7 +475,7 @@ class Attention { // Copy one head from key or value to K cache or V cache // bdx: batch index; hdx: head index template - void copyKVCache(DecoderContext *ctx, hpj::Matrix &kv, KVCacheTensor &presentKV, int pastSeqLen, + void copyKVCache(DecoderContext *ctx, xft::Matrix &kv, KVCacheTensor &presentKV, int pastSeqLen, int bdx, int hdx) { for (int seq = 0; seq < ctx->inputSeqLen; ++seq) { auto src = kv.Row(bdx * ctx->inputSeqLen + seq) + hdx * ctx->attHeadSize; @@ -529,8 +529,8 @@ class Attention { // Note: the result here is still the intermediate result from the whole attention scope template - void fusedAttention(DecoderContext *ctx, hpj::Matrix &query, hpj::Matrix &key, hpj::Matrix &value, - hpj::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue, + void fusedAttention(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, xft::Matrix &value, + xft::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue, const float *attnMask, int pastSeqLen) { // How many heads this task should do int responsibleHeads = this->endQHead - this->startQHead; @@ -574,8 +574,8 @@ class Attention { } template - void slimAttention(DecoderContext *ctx, hpj::Matrix &query, hpj::Matrix &key, hpj::Matrix &value, - hpj::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue, + void 
slimAttention(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, xft::Matrix &value, + xft::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue, const float *attnMask, int pastSeqLen, int mBlockSize, bool kvCopied) { // How many heads this task should do int responsibleHeads = this->endQHead - this->startQHead; @@ -668,8 +668,8 @@ class Attention { // When #heads is very few, need to shard each head to use more resources template - void crossAttnShardHead(DecoderContext *ctx, hpj::Matrix &query, hpj::Matrix &key, - hpj::Matrix &value, hpj::Matrix &result, KVCacheTensor &presentKey, + void crossAttnShardHead(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, + xft::Matrix &value, xft::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue, const float *attnMask, int pastSeqLen) { const int responsibleHeads = this->endQHead - this->startQHead; const int batchSize = ctx->batchSize; @@ -684,8 +684,8 @@ class Attention { } template - void flashAttention(DecoderContext *ctx, hpj::Matrix &query, hpj::Matrix &key, hpj::Matrix &value, - hpj::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue, + void flashAttention(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, xft::Matrix &value, + xft::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue, const float *attnMask, int pastSeqLen) { #if defined(AVX512_BF16_WEIGHT_ONLY_BF16) using AttnT = bfloat16_t; @@ -886,18 +886,18 @@ class Attention { } // query, key, value weighs - hpj::Matrix qkvWeight; - hpj::Vector qkvWeightScale; // if weight is int8 - hpj::Vector qkvWeightZero; // if weight is int8 - hpj::Vector qkvWeightSum; // if weight is int8 + xft::Matrix qkvWeight; + xft::Vector qkvWeightScale; // if weight is int8 + xft::Vector qkvWeightZero; // if weight is int8 + xft::Vector qkvWeightSum; // if weight is int8 // query, key, value bias - hpj::Vector qkvBias; + xft::Vector qkvBias; - hpj::Matrix attnOutputWeight; - hpj::Vector attnOutputWeightScale; // if weight is int8 - hpj::Vector attnOutputWeightZero; // if weight is int8 - hpj::Vector attnOutputWeightSum; // if weight is int8 - hpj::Vector attnOutputBias; + xft::Matrix attnOutputWeight; + xft::Vector attnOutputWeightScale; // if weight is int8 + xft::Vector attnOutputWeightZero; // if weight is int8 + xft::Vector attnOutputWeightSum; // if weight is int8 + xft::Vector attnOutputBias; // Query/Key post op QKPO_CLS qkpo; diff --git a/src/layers/decoder_layer.h b/src/layers/decoder_layer.h index 3cbc9d55..02e99a1e 100644 --- a/src/layers/decoder_layer.h +++ b/src/layers/decoder_layer.h @@ -112,13 +112,13 @@ class Decoder { } private: - void copyWeights(hpj::Matrix &w, int start_col, int end_col, const float *data) { - hpj::Matrix subW(w, 0, w.Rows(), start_col, end_col - start_col); + void copyWeights(xft::Matrix &w, int start_col, int end_col, const float *data) { + xft::Matrix subW(w, 0, w.Rows(), start_col, end_col - start_col); copyWeights(subW, data); } // Copy the transposed weight into the non-transposed matrix - void copyWeights(hpj::Matrix &w, const float *data) { + void copyWeights(xft::Matrix &w, const float *data) { for (int j = 0; j < w.Cols(); ++j) { for (int i = 0; i < w.Rows(); ++i) { w(i, j) = *data++; @@ -126,7 +126,7 @@ class Decoder { } } - void copyTransposed(hpj::Matrix &dst, hpj::Matrix &src) { + void copyTransposed(xft::Matrix &dst, xft::Matrix &src) { dst.Resize(src.Cols(), src.Rows()); for (int i = 0; i < dst.Rows(); ++i) { for (int j = 0; j < dst.Cols(); ++j) { @@ -136,7 
+136,7 @@ class Decoder { } // Add bias to matrix - void biasAdd(hpj::Matrix &m, hpj::Vector &bias) { + void biasAdd(xft::Matrix &m, xft::Vector &bias) { float *pbias = bias.Data(); #pragma omp parallel for for (int i = 0; i < m.Rows(); ++i) { diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h index 6bd581a6..41c17bf0 100644 --- a/src/layers/dist_linear.h +++ b/src/layers/dist_linear.h @@ -63,7 +63,7 @@ class DistLinear { scaleWeight.Resize(N); zeroWeight.Resize(N); - hpj::Matrix quantizedWeight; + xft::Matrix quantizedWeight; ctx->mmHelper->convertWeight( true, K, N, w + splitOffset * K, nullptr, nullptr, quantizedWeight, scaleWeight, zeroWeight, sumWeight); ctx->mmHelper->packWeight(true, quantizedWeight, weight); @@ -108,9 +108,9 @@ class DistLinear { int splitSize; int splitOffset; - hpj::Matrix weight; - hpj::Vector scaleWeight; // if weight is int8 - hpj::Vector zeroWeight; // if weight is int8 - hpj::Vector sumWeight; // if weight is int8 + xft::Matrix weight; + xft::Vector scaleWeight; // if weight is int8 + xft::Vector zeroWeight; // if weight is int8 + xft::Vector sumWeight; // if weight is int8 float *bias = nullptr; }; diff --git a/src/layers/mlp_chatglm2.h b/src/layers/mlp_chatglm2.h index dbc83cd8..aed14763 100644 --- a/src/layers/mlp_chatglm2.h +++ b/src/layers/mlp_chatglm2.h @@ -33,7 +33,7 @@ class ChatGLM2MLP : public LlamaMLP { REQUIRES(ctx->actType == DecoderContext::SWIGLU, "unsupported activation."); // Vertically split the gate weight and up weight - hpj::Matrix convertedGateWeight, convertedUpWeight, convertedDownWeight; + xft::Matrix convertedGateWeight, convertedUpWeight, convertedDownWeight; auto range = SplitUtil::getTaskRange(intermediateSize, ctx->numSplit, ctx->splitIdx); int colSplit = range.second - range.first; @@ -83,7 +83,7 @@ class ChatGLM2MLP : public LlamaMLP { colSplit * sizeof(OriWeiT)); weightPTR += intermediateSize; } - hpj::Matrix quantizedCatWeights; + xft::Matrix quantizedCatWeights; ctx->mmHelper->convertWeight(trans, hiddenSize, colSplitStride, gateUpW, nullptr, nullptr, quantizedCatWeights, this->catWeightsScale, this->catWeightsZero, this->catWeightsSum); this->catWeights.Resize(quantizedCatWeights.Rows(), quantizedCatWeights.Cols()); diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index d3f102eb..06c7a1ef 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -58,7 +58,7 @@ class LlamaMLP : public SingletonBase> { "unsupported activation."); // Vertically split the gate weight and up weight - hpj::Matrix quantizedGateWeight, quantizedUpWeight, quantizedDownWeight; + xft::Matrix quantizedGateWeight, quantizedUpWeight, quantizedDownWeight; auto it = SplitUtil::getTaskRange(imSize, ctx->numSplit, ctx->splitIdx); downWeight.Resize(it.second - it.first, hiddenSize); @@ -74,7 +74,7 @@ class LlamaMLP : public SingletonBase> { ctx->mmHelper->packWeight(trans, quantizedGateWeight, gateWeight); ctx->mmHelper->packWeight(trans, quantizedUpWeight, upWeight); } else { - hpj::Matrix quantizedCatWeights; + xft::Matrix quantizedCatWeights; catGateUpWeights(quantizedGateWeight, quantizedUpWeight, gateWeightScale, gateWeightZero, gateWeightSum, upWeightScale, upWeightZero, upWeightSum, quantizedCatWeights, catWeightsScale, catWeightsZero, catWeightsSum); @@ -119,9 +119,9 @@ class LlamaMLP : public SingletonBase> { static_assert(sizeof(ctx->normBuf.Data()[0]) >= sizeof(ImT), "normBuff is not big enough!"); - hpj::Matrix inBuffer(input, M, hiddenSize, iStride); - hpj::Matrix outBuffer(output, M, hiddenSize, oStride); - 
hpj::Matrix normBuffer( + xft::Matrix inBuffer(input, M, hiddenSize, iStride); + xft::Matrix outBuffer(output, M, hiddenSize, oStride); + xft::Matrix normBuffer( (ImT *)ctx->normBuf.Data(), ctx->normBuf.Rows(), ctx->normBuf.Cols(), ctx->normBuf.Stride()); if (doLnBefore == true) { @@ -137,7 +137,7 @@ class LlamaMLP : public SingletonBase> { #endif if (!enableCATMLP()) { - hpj::Matrix imBuffer( + xft::Matrix imBuffer( (ImT *)ctx->imOut.Data(), ctx->imOut.Rows(), ctx->imOut.Cols(), ctx->imOut.Stride()); gateProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer); @@ -162,13 +162,13 @@ class LlamaMLP : public SingletonBase> { } else { auto M = normBuffer.Rows(); auto N = catWeights.Cols(); - hpj::Matrix imBuffer((ImT *)ctx->imOut.Data(), M, N, N); + xft::Matrix imBuffer((ImT *)ctx->imOut.Data(), M, N, N); // Need to allocate extra buffer as oneDNN does not support the case of stride > cols const int cols = N / 2; auto bufSize = sizeof(ImT) * M * cols; ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize); - hpj::Matrix siluBuf(t, M, cols, cols); + xft::Matrix siluBuf(t, M, cols, cols); #ifdef DEBUG dbg.debugPrint( ">>> enableCATMLP imBuffer: [%d, %d] (%d)\n", imBuffer.Rows(), imBuffer.Cols(), imBuffer.Stride()); @@ -199,7 +199,7 @@ class LlamaMLP : public SingletonBase> { } private: - void gateProj(DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output) { + void gateProj(DecoderContext *ctx, xft::Matrix &input, xft::Matrix &output) { TimeLine t("GateProj"); assert(input.Rows() == output.Rows()); @@ -228,7 +228,7 @@ class LlamaMLP : public SingletonBase> { } } - void upProj(DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output) { + void upProj(DecoderContext *ctx, xft::Matrix &input, xft::Matrix &output) { TimeLine t("UpProj"); assert(input.Rows() == output.Rows()); @@ -248,8 +248,8 @@ class LlamaMLP : public SingletonBase> { ctx->mmHelper->compute_resmul(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, C, ldc); } - void downProj(DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output, - hpj::Matrix &residential, bool isMaster) { + void downProj(DecoderContext *ctx, xft::Matrix &input, xft::Matrix &output, + xft::Matrix &residential, bool isMaster) { TimeLine t("DownProj"); assert(input.Rows() == output.Rows()); @@ -276,7 +276,7 @@ class LlamaMLP : public SingletonBase> { } template - void catGateUpProj(DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output, hpj::Matrix &siluBuf) { + void catGateUpProj(DecoderContext *ctx, xft::Matrix &input, xft::Matrix &output, xft::Matrix &siluBuf) { TimeLine t("catGateUpProj"); assert(input.Rows() == output.Rows()); @@ -308,11 +308,11 @@ class LlamaMLP : public SingletonBase> { } } - void catGateUpWeights(hpj::Matrix &gateWeight, hpj::Matrix &upWeight, - hpj::Vector &gateWeightScale, hpj::Vector &gateWeightZero, hpj::Vector &gateWeightSum, - hpj::Vector &upWeightScale, hpj::Vector &upWeightZero, hpj::Vector &upWeightSum, - hpj::Matrix &catWeights, hpj::Vector &catWeightsScale, hpj::Vector &catWeightsZero, - hpj::Vector &catWeightsSum) { + void catGateUpWeights(xft::Matrix &gateWeight, xft::Matrix &upWeight, + xft::Vector &gateWeightScale, xft::Vector &gateWeightZero, xft::Vector &gateWeightSum, + xft::Vector &upWeightScale, xft::Vector &upWeightZero, xft::Vector &upWeightSum, + xft::Matrix &catWeights, xft::Vector &catWeightsScale, xft::Vector &catWeightsZero, + xft::Vector &catWeightsSum) { catWeights.Resize(gateWeight.Rows(), gateWeight.Cols() + upWeight.Cols()); 
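        // Packed layout: catWeights stores the two matrices column-wise as [gate | up],
        // and the scale/zero/sum vectors below are concatenated in the same order, so a
        // single packed GEMM produces both projections; siluSum/geluSum then combine the
        // left (gate) and right (up) halves of that intermediate result.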
catWeightsScale.Resize(gateWeightScale.Size() + upWeightScale.Size()); catWeightsZero.Resize(gateWeightZero.Size() + upWeightZero.Size()); @@ -345,25 +345,25 @@ class LlamaMLP : public SingletonBase> { } protected: - hpj::Matrix gateWeight; - hpj::Vector gateWeightScale; // For int8_t weight - hpj::Vector gateWeightZero; // For int8_t weight - hpj::Vector gateWeightSum; // For int8_t weight - hpj::Matrix upWeight; - hpj::Vector upWeightScale; // For int8_t weight - hpj::Vector upWeightZero; // For int8_t weight - hpj::Vector upWeightSum; // For int8_t weight - hpj::Matrix catWeights; - hpj::Vector catWeightsScale; // For int8_t weight - hpj::Vector catWeightsZero; // For int8_t weight - hpj::Vector catWeightsSum; // For int8_t weight - hpj::Matrix downWeight; - hpj::Vector downWeightScale; // For int8_t weight - hpj::Vector downWeightZero; // For int8_t weight - hpj::Vector downWeightSum; // For int8_t weight + xft::Matrix gateWeight; + xft::Vector gateWeightScale; // For int8_t weight + xft::Vector gateWeightZero; // For int8_t weight + xft::Vector gateWeightSum; // For int8_t weight + xft::Matrix upWeight; + xft::Vector upWeightScale; // For int8_t weight + xft::Vector upWeightZero; // For int8_t weight + xft::Vector upWeightSum; // For int8_t weight + xft::Matrix catWeights; + xft::Vector catWeightsScale; // For int8_t weight + xft::Vector catWeightsZero; // For int8_t weight + xft::Vector catWeightsSum; // For int8_t weight + xft::Matrix downWeight; + xft::Vector downWeightScale; // For int8_t weight + xft::Vector downWeightZero; // For int8_t weight + xft::Vector downWeightSum; // For int8_t weight // LlamaRMSNorm param - hpj::Vector normWeight; + xft::Vector normWeight; #ifdef DEBUG Debugger dbg; diff --git a/src/layers/mlp_standard.h b/src/layers/mlp_standard.h index c0f54cbe..5f26b069 100644 --- a/src/layers/mlp_standard.h +++ b/src/layers/mlp_standard.h @@ -37,7 +37,7 @@ class MLP { int intermediateSize = ctx->intermediateSize; // Vertically split intermediate(FC1) weight - hpj::Matrix quantizedIntermediateWeight; + xft::Matrix quantizedIntermediateWeight; ctx->mmHelper->convertWeight(ctx, trans, hiddenSize, intermediateSize, _imWeight, nullptr, nullptr, true, quantizedIntermediateWeight, intermediateWeightScale, intermediateWeightZero, intermediateWeightSum); ctx->mmHelper->packWeight(trans, quantizedIntermediateWeight, intermediateWeight); @@ -49,7 +49,7 @@ class MLP { memcpy(intermediateBias.Data(), _imBias + colsPerSplit * ctx->splitIdx, sizeof(float) * colsPerSplit); // Horizontally split the output(FC2) weight - hpj::Matrix quantizedOutputWeight; + xft::Matrix quantizedOutputWeight; ctx->mmHelper->convertWeight(ctx, trans, intermediateSize, hiddenSize, _outputWeight, nullptr, nullptr, false, quantizedOutputWeight, outputWeightScale, outputWeightZero, outputWeightSum); ctx->mmHelper->packWeight(trans, quantizedOutputWeight, outputWeight); @@ -79,7 +79,7 @@ class MLP { void forward(DecoderContext *ctx, float *input, float *output, int iStride, int oStride, bool doLnBefore) { TimeLine t("StandardMLP"); int M = ctx->batchSize * ctx->inputSeqLen; - hpj::Matrix outBuffer(output, M, ctx->hiddenSize, ctx->hiddenSize); + xft::Matrix outBuffer(output, M, ctx->hiddenSize, ctx->hiddenSize); auto &resultBuffer1 = outBuffer; auto &resultBuffer2 = ctx->tmpBuf; @@ -163,14 +163,14 @@ class MLP { } protected: - void intermediate_relu(DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output) { + void intermediate_relu(DecoderContext *ctx, xft::Matrix &input, xft::Matrix &output) { 
+    void intermediate_relu(DecoderContext *ctx, xft::Matrix &input, xft::Matrix &output) {
ctx->mmHelper->compute_biasadd_relu(false, input.Rows(), output.Cols(), input.Cols(), 1.0f, input.Data(), input.Stride(), intermediateWeight.Data(), intermediateWeightScale.Data(), intermediateWeightZero.Data(), intermediateWeightSum.Data(), 0.0f, output.Data(), output.Stride(), intermediateBias.Data()); } - void intermediate_gelu(DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output) { + void intermediate_gelu(DecoderContext *ctx, xft::Matrix &input, xft::Matrix &output) { ctx->mmHelper->compute(false, input.Rows(), output.Cols(), input.Cols(), 1.0f, input.Data(), input.Stride(), intermediateWeight.Data(), intermediateWeightScale.Data(), intermediateWeightZero.Data(), intermediateWeightSum.Data(), 0.0f, output.Data(), output.Stride()); @@ -223,20 +223,20 @@ class MLP { } // private: - hpj::Matrix intermediateWeight; - hpj::Vector intermediateWeightScale; - hpj::Vector intermediateWeightZero; - hpj::Vector intermediateWeightSum; - hpj::Vector intermediateBias; - - hpj::Matrix outputWeight; - hpj::Vector outputWeightScale; - hpj::Vector outputWeightZero; - hpj::Vector outputWeightSum; - hpj::Vector outputBias; + xft::Matrix intermediateWeight; + xft::Vector intermediateWeightScale; + xft::Vector intermediateWeightZero; + xft::Vector intermediateWeightSum; + xft::Vector intermediateBias; + + xft::Matrix outputWeight; + xft::Vector outputWeightScale; + xft::Vector outputWeightZero; + xft::Vector outputWeightSum; + xft::Vector outputBias; // layerNorm param - hpj::Vector gamma2, beta2; + xft::Vector gamma2, beta2; #ifdef DEBUG Debugger dbg; diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 1b8f1179..0d5bd85a 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -220,7 +220,7 @@ class CommonDecoder : public AbstractDecoder { this->inputTokens = nullptr; this->maskSize = 0; this->attnMask = nullptr; - actBuffers.reset(new hpj::Matrix()); + actBuffers.reset(new xft::Matrix()); // Context DecoderContext *ctx = getDecoderContext(layers, hiddenSize, size_per_head, attHeadNum, kvHeadNum, imSize, act, @@ -620,13 +620,13 @@ class CommonDecoder : public AbstractDecoder { int getInitSeqLen() { return initSeqLen; } std::tuple, std::shared_ptr>, - std::shared_ptr>> + std::shared_ptr>> getSharedResources() { return std::make_tuple(context, kvCacheMgr, actBuffers); } void setSharedResources(const std::tuple, std::shared_ptr>, - std::shared_ptr>> &r) { + std::shared_ptr>> &r) { this->context = std::get<0>(r); this->kvCacheMgr = std::get<1>(r); this->actBuffers = std::get<2>(r); @@ -1000,7 +1000,7 @@ class CommonDecoder : public AbstractDecoder { using MlpOutT = typename MlpTypeExtractor::Tout; // Activation buffers (declared as float, but the actual data type may be different) - std::shared_ptr> actBuffers; + std::shared_ptr> actBuffers; protected: // Components most LLMs may use diff --git a/src/utils/debugger.h b/src/utils/debugger.h index c244a165..ca74d115 100644 --- a/src/utils/debugger.h +++ b/src/utils/debugger.h @@ -115,7 +115,7 @@ class Debugger { } template - void dumpMatrix(hpj::Matrix &m, bool print_all = false) { + void dumpMatrix(xft::Matrix &m, bool print_all = false) { std::ostringstream oss; uint64_t rows = m.Rows(); uint64_t cols = m.Cols(); diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 0989b4da..eb94ae35 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -35,7 +35,7 @@ extern bool enableSkipMsk(); class DecoderUtil { public: #if __AVX512F__ - static void rmsNorm(hpj::Matrix &x, 
hpj::Matrix &y, hpj::Vector &normWeight, float epsilon) { + static void rmsNorm(xft::Matrix &x, xft::Matrix &y, xft::Vector &normWeight, float epsilon) { TimeLine t("DecoderUtil::rmsNorm"); float *pweight = normWeight.Data(); int size = x.Cols(); @@ -86,7 +86,7 @@ class DecoderUtil { } static void layerNorm( - hpj::Matrix &x, hpj::Matrix &y, hpj::Vector &gamma, hpj::Vector &beta) { + xft::Matrix &x, xft::Matrix &y, xft::Vector &gamma, xft::Vector &beta) { TimeLine t("DecoderUtil::layerNorm"); float *pgamma = gamma.Data(); float *pbeta = beta.Data(); @@ -147,7 +147,7 @@ class DecoderUtil { } // Layer norm for small matrix with just a one Rrow - static void LayerNormOneRow(hpj::Matrix &x, hpj::Matrix &y, float *pgamma, float *pbeta, int size) { + static void LayerNormOneRow(xft::Matrix &x, xft::Matrix &y, float *pgamma, float *pbeta, int size) { TimeLine t("DecoderUtil::LayerNormOneRow"); constexpr int BLKSIZE = 128; const int splitSize = (size > BLKSIZE && size % BLKSIZE == 0) ? BLKSIZE : size; // size of each split @@ -209,8 +209,8 @@ class DecoderUtil { } } #else - static void layerNorm(DecoderContext *ctx, hpj::Matrix &x, hpj::Matrix &y, hpj::Vector &gamma, - hpj::Vector &beta) { + static void layerNorm(DecoderContext *ctx, xft::Matrix &x, xft::Matrix &y, xft::Vector &gamma, + xft::Vector &beta) { TimeLine t("DecoderUtil::layerNorm"); assert(x.Rows() == ctx->batchSize * ctx->inputSeqLen); assert(x.Cols() == ctx->hiddenSize); @@ -471,7 +471,7 @@ class DecoderUtil { // compute silu on the left half and then add it with the right half template - static void siluSum(hpj::Matrix &src, hpj::Matrix &dst) { + static void siluSum(xft::Matrix &src, xft::Matrix &dst) { __m512 one = _mm512_set1_ps(1.f); __m512 negOne = _mm512_set1_ps(-1.f); int M = src.Rows(); @@ -497,7 +497,7 @@ class DecoderUtil { // compute gelu on the left half and then add it with the right half template - static void geluSum(hpj::Matrix &src, hpj::Matrix &dst) { + static void geluSum(xft::Matrix &src, xft::Matrix &dst) { const __m512 c1 = _mm512_set1_ps(0.044715f); const __m512 c2 = _mm512_set1_ps(0.7978845608f); const __m512 vone = _mm512_set1_ps(1.0f); diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 5ad5fbc9..39839c35 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -95,8 +95,8 @@ class MMHelper { template void convertWeight(bool trans, int rows, int cols, const OriWeiT *weight, const float *scales, const float *zeros, - int splitOffset, int splitSize, bool verticalSplit, hpj::Matrix &convertedWeight, - hpj::Vector &scaleWeight, hpj::Vector &zeroWeight, hpj::Vector &sumWeight, + int splitOffset, int splitSize, bool verticalSplit, xft::Matrix &convertedWeight, + xft::Vector &scaleWeight, xft::Vector &zeroWeight, xft::Vector &sumWeight, bool unused) { // transform trans cases to no trans cases if (trans) { @@ -261,8 +261,8 @@ class MMHelper { template void convertWeight(bool trans, int rows, int cols, const OriWeiT *weight, const float *scales, const float *zeros, - int numSplit, int splitIdx, bool verticalSplit, hpj::Matrix &quantizedWeight, - hpj::Vector &scaleWeight, hpj::Vector &zeroWeight, hpj::Vector &sumWeight) { + int numSplit, int splitIdx, bool verticalSplit, xft::Matrix &quantizedWeight, + xft::Vector &scaleWeight, xft::Vector &zeroWeight, xft::Vector &sumWeight) { int totalSize = verticalSplit ? 
cols : rows; std::pair range = SplitUtil::getTaskRange(totalSize, numSplit, splitIdx); @@ -275,22 +275,22 @@ class MMHelper { template void convertWeight(bool trans, int rows, int cols, const OriWeiT *weight, const float *scales, const float *zeros, - hpj::Matrix &quantizedWeight, hpj::Vector &scaleWeight, hpj::Vector &zeroWeight, - hpj::Vector &sumWeight) { + xft::Matrix &quantizedWeight, xft::Vector &scaleWeight, xft::Vector &zeroWeight, + xft::Vector &sumWeight) { convertWeight(trans, rows, cols, weight, scales, zeros, 1, 0, true, quantizedWeight, scaleWeight, zeroWeight, sumWeight); } template void convertWeight(DecoderContext *ctx, bool trans, int rows, int cols, const OriWeiT *weight, const float *scales, - const float *zeros, bool verticalSplit, hpj::Matrix &quantizedWeight, hpj::Vector &scaleWeight, - hpj::Vector &zeroWeight, hpj::Vector &sumWeight) { + const float *zeros, bool verticalSplit, xft::Matrix &quantizedWeight, xft::Vector &scaleWeight, + xft::Vector &zeroWeight, xft::Vector &sumWeight) { convertWeight(trans, rows, cols, weight, scales, zeros, ctx->numSplit, ctx->splitIdx, verticalSplit, quantizedWeight, scaleWeight, zeroWeight, sumWeight); } template - void packWeight(bool trans, hpj::Matrix &src, hpj::Matrix &weight) { + void packWeight(bool trans, xft::Matrix &src, xft::Matrix &weight) { int K = trans ? src.Cols() : src.Rows(); int N = trans ? src.Rows() : src.Cols(); From a4f4b25c693a16021ed2a551cda59c604cab48d7 Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 04/35] [Layer] Remove unused functions in Decoder layer (#353) --- src/layers/decoder_layer.h | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/src/layers/decoder_layer.h b/src/layers/decoder_layer.h index 02e99a1e..c86730ee 100644 --- a/src/layers/decoder_layer.h +++ b/src/layers/decoder_layer.h @@ -111,43 +111,6 @@ class Decoder { mlp.forward(ctx, input, output, iStride, oStride, doLnBefore); } -private: - void copyWeights(xft::Matrix &w, int start_col, int end_col, const float *data) { - xft::Matrix subW(w, 0, w.Rows(), start_col, end_col - start_col); - copyWeights(subW, data); - } - - // Copy the transposed weight into the non-transposed matrix - void copyWeights(xft::Matrix &w, const float *data) { - for (int j = 0; j < w.Cols(); ++j) { - for (int i = 0; i < w.Rows(); ++i) { - w(i, j) = *data++; - } - } - } - - void copyTransposed(xft::Matrix &dst, xft::Matrix &src) { - dst.Resize(src.Cols(), src.Rows()); - for (int i = 0; i < dst.Rows(); ++i) { - for (int j = 0; j < dst.Cols(); ++j) { - dst(i, j) = src(j, i); - } - } - } - - // Add bias to matrix - void biasAdd(xft::Matrix &m, xft::Vector &bias) { - float *pbias = bias.Data(); -#pragma omp parallel for - for (int i = 0; i < m.Rows(); ++i) { - float *p = m.Row(i); -#pragma omp simd - for (int j = 0; j < m.Cols(); ++j) { - p[j] += pbias[j]; - } - } - } - private: // For debug usage int layerIdx; From 45bcfa3aac3defb1337410d2259c92e70df1e502 Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 05/35] [Model] Fix compile error of embeddingForward in YaRNLlama (#358) --- src/models/yarn_llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models/yarn_llama.cpp b/src/models/yarn_llama.cpp index 813bf706..1d035951 100644 --- a/src/models/yarn_llama.cpp +++ b/src/models/yarn_llama.cpp @@ -94,8 +94,8 @@ void YaRNLlama::embeddingForward(int *ids, float *output, int to } template -void 
YaRNLlama::embeddingForward(int *ids, bfloat16_t *output, int batchSize, int seqLen) { - embedding->forward(ids, output, batchSize, seqLen); +void YaRNLlama::embeddingForward(int *ids, bfloat16_t *output, int tokenSize) { + embedding->forward(ids, output, tokenSize); } template From a704873369270d7d478f75736eab9997bf7e2105 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 06/35] [Common] Add sampling params into group seq. (#356) --- src/common/sequence.h | 128 +++++++++++++++++++------------- src/searchers/sampling_params.h | 36 +++++++++ 2 files changed, 112 insertions(+), 52 deletions(-) create mode 100644 src/searchers/sampling_params.h diff --git a/src/common/sequence.h b/src/common/sequence.h index d4d82db1..53cc84c7 100644 --- a/src/common/sequence.h +++ b/src/common/sequence.h @@ -16,6 +16,7 @@ #include #include +#include "sampling_params.h" #include /* @@ -42,16 +43,22 @@ namespace xft { // The SequenceMeta is one sequence of batch inputs and includes the generated tokens. class SequenceMeta { public: - SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen, std::vector &_inputTokens) - : sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), pastSeqLen(0), step(0) { - inputTokens.reserve(_inputSeqLen); - inputTokens.assign(_inputTokens.begin(), _inputTokens.end()); - nextTokens.reserve(_inputSeqLen); + SequenceMeta(std::vector &_inputTokens) + : sequenceID(SequencePool::getInstance().createSequenceID()) + , inputSeqLen(_inputTokens.size()) + , inputTokens(_inputTokens) + , pastSeqLen(0) + , step(0) { + nextTokens.reserve(inputSeqLen); setPastSeqLen(getPastSeqLen()); } - SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen) - : sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), inputTokens(_inputSeqLen, 0), pastSeqLen(0), step(0) { + SequenceMeta(int32_t _inputSeqLen) + : sequenceID(SequencePool::getInstance().createSequenceID()) + , inputSeqLen(_inputSeqLen) + , inputTokens(_inputSeqLen, 0) + , pastSeqLen(0) + , step(0) { nextTokens.reserve(_inputSeqLen); } @@ -123,25 +130,47 @@ class SequenceMeta { // For beam searcher class SequenceGroupMeta { public: - SequenceGroupMeta(std::vector &seq) { - size_per_group = seq.size(); - sequences.reserve(size_per_group); + SequenceGroupMeta(std::vector &seq, SamplingMeta samplingMeta_) : samplingMeta(samplingMeta_) { + assert(samplingMeta.config.numBeams == seq.size()); + sequences.reserve(samplingMeta.config.numBeams); sequences.assign(seq.begin(), seq.end()); + groupID = sequences[0].getSequenceID(); + } + + SequenceGroupMeta(std::vector &_inputTokens, SamplingMeta samplingMeta_) : samplingMeta(samplingMeta_) { + sequences.reserve(samplingMeta.config.numBeams); + for (int i = 0; i < samplingMeta.config.numBeams; ++i) { + sequences.emplace_back(SequenceMeta(_inputTokens)); + } + groupID = sequences[0].getSequenceID(); + } + + SequenceGroupMeta(int32_t _inputSeqLen, SamplingMeta samplingMeta_) : samplingMeta(samplingMeta_) { + sequences.reserve(samplingMeta.config.numBeams); + for (int i = 0; i < samplingMeta.config.numBeams; ++i) { + sequences.emplace_back(SequenceMeta(_inputSeqLen)); + } + groupID = sequences[0].getSequenceID(); } - int32_t getGroupSize() { return size_per_group; } + int32_t getGroupID() { return groupID; } + + int32_t getGroupSize() { return samplingMeta.config.numBeams; } SequenceMeta *get() { return sequences.data(); } SequenceMeta *get(int index) { return &sequences[index]; } - SequenceMeta &operator[](int index) { - return sequences[index]; - } + SequenceMeta 
&operator[](int index) { return sequences[index]; } private: - int32_t size_per_group; + // using 1st sequence ID as group ID. + int32_t groupID; + + // The number of sequences in the group, equal to num beams + int32_t size; std::vector sequences; + SamplingMeta samplingMeta; }; // SequencePool @@ -166,61 +195,56 @@ class SequencePool { return id; } - SequenceGroupMeta *newMeta(int32_t sequenceID, int32_t inputSeqLen, std::vector &inputTokens) { - std::vector sequence; - sequence.emplace_back(SequenceMeta(sequenceID, inputSeqLen, inputTokens)); + SequenceGroupMeta *newMeta(std::vector &inputTokens, SamplingMeta samplingMeta_) { + std::vector sequences; + sequences.emplace_back(SequenceMeta(inputTokens)); - auto *group = new SequenceGroupMeta(sequence); + auto *group = new SequenceGroupMeta(sequences, samplingMeta_); return group; } - SequenceGroupMeta *newMeta(int32_t sequenceID, int32_t inputSeqLen) { - std::vector sequence; - sequence.emplace_back(SequenceMeta(sequenceID, inputSeqLen)); + SequenceGroupMeta *newMeta(int32_t inputSeqLen, SamplingMeta samplingMeta_) { + std::vector sequences; + sequences.emplace_back(SequenceMeta(inputSeqLen)); - auto *group = new SequenceGroupMeta(sequence); + auto *group = new SequenceGroupMeta(sequences, samplingMeta_); return group; } - SequenceGroupMeta *newGroupMeta(std::vector &sequenceIDs, std::vector &inputSeqLens, - std::vector> &inputTokens) { - assert(sequenceIDs.size() == inputSeqLens.size()); - assert(sequenceIDs.size() == inputTokens.size()); - + SequenceGroupMeta *newGroupMeta(std::vector> &inputTokens, SamplingMeta samplingMeta_) { std::vector sequences; - for (int i = 0; i < sequenceIDs.size(); ++i) { - sequences.emplace_back(SequenceMeta(sequenceIDs[i], inputSeqLens[i], inputTokens[i])); + for (int i = 0; i < inputTokens.size(); ++i) { + sequences.emplace_back(SequenceMeta(inputTokens[i])); } - auto *group = new SequenceGroupMeta(sequences); + auto *group = new SequenceGroupMeta(sequences, samplingMeta_); return group; } - SequenceGroupMeta *newGroupMeta(std::vector &sequenceIDs, std::vector &inputSeqLens) { - assert(sequenceIDs.size() == inputSeqLens.size()); + SequenceGroupMeta *newGroupMeta(std::vector &inputSeqLens, SamplingMeta samplingMeta_) { std::vector sequences; - for (int i = 0; i < sequenceIDs.size(); ++i) { - sequences.emplace_back(SequenceMeta(sequenceIDs[i], inputSeqLens[i])); + for (int i = 0; i < inputSeqLens.size(); ++i) { + sequences.emplace_back(SequenceMeta(inputSeqLens[i])); } - auto *group = new SequenceGroupMeta(sequences); + auto *group = new SequenceGroupMeta(sequences, samplingMeta_); return group; } - // Use first sequenceID if num_beam = 4 - bool add(int32_t sequenceID, SequenceGroupMeta *sequence, bool force = false) { + bool add(SequenceGroupMeta *sequenceGroup, bool force = false) { + int32_t groupID = sequenceGroup->getGroupID(); bool isSuccess = false; if (force) { - auto it = hub.find(sequenceID); + auto it = hub.find(groupID); if (it != hub.end()) { remove(it->first, true); } - hub[sequenceID] = sequence; + hub[groupID] = sequenceGroup; isSuccess = true; } else { - bool exist = has(sequenceID); + bool exist = has(groupID); if (!exist) { - hub[sequenceID] = sequence; + hub[groupID] = sequenceGroup; isSuccess = true; } } @@ -228,10 +252,10 @@ class SequencePool { return isSuccess; } - bool has(int32_t sequenceID) const { return hub.find(sequenceID) != hub.end(); } + bool has(int32_t groupID) const { return hub.find(groupID) != hub.end(); } - SequenceGroupMeta *get(int32_t sequenceID) const { - auto it = 
hub.find(sequenceID); + SequenceGroupMeta *get(int32_t groupID) const { + auto it = hub.find(groupID); if (it != hub.end()) { return it->second; } else { @@ -239,25 +263,25 @@ class SequencePool { } } - bool remove(int32_t sequenceID, bool deep = false) { + bool remove(int32_t groupID, bool deep = false) { bool isSuccess = false; - if (has(sequenceID)) { + if (has(groupID)) { if (deep == true) { - auto it = hub.find(sequenceID); + auto it = hub.find(groupID); if (it != hub.end()) { delete it->second; } } - isSuccess = hub.erase(sequenceID); + isSuccess = hub.erase(groupID); } return isSuccess; } - bool replace(int32_t sequenceID, SequenceGroupMeta *sequences) { + bool replace(int32_t groupID, SequenceGroupMeta *sequenceGroup) { bool isSuccess = false; - auto it = hub.find(sequenceID); + auto it = hub.find(groupID); if (it != hub.end()) { remove(it->first, true); - hub[sequenceID] = sequences; + hub[groupID] = sequenceGroup; isSuccess = true; } diff --git a/src/searchers/sampling_params.h b/src/searchers/sampling_params.h new file mode 100644 index 00000000..1c298f01 --- /dev/null +++ b/src/searchers/sampling_params.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +#pragma once +#include +#include "abstract_searcher.h" + +namespace xft { +struct SamplingMeta { + bool done; + std::vector> stopWordsList; + std::vector stopWordsIndex; + std::vector cachedRepetVec; + SearcherConfig config; + + SamplingMeta() { done = false; } + + SamplingMeta(SearcherConfig config, std::vector> stopWordsList_) + : config(config), stopWordsList(stopWordsList_) { + done = false; + // TODO: stopWordsIndex is not initialized + } +}; + +}; // namespace xft \ No newline at end of file From 987a87442a67c3abeb066dad88c85d8b3bbfaf43 Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 07/35] [Util] Remove DecoderContext in computeSoftmax (#362) --- src/utils/decoder_util.h | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index eb94ae35..835dba48 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -247,7 +247,7 @@ class DecoderUtil { #endif // General version - static void computeSoftmax(DecoderContext *ctx, float *data, const float *attnMask, int size) { + static void computeSoftmax(float *data, const float *attnMask, int size, float scale) { int vecs = (size + 15) / 16; // how many avx512 vectors __mmask16 tailMask = (size % 16 == 0 ? 
0xffff : (1 << (size % 16)) - 1); // mask of last vector @@ -256,7 +256,7 @@ class DecoderUtil { // maxVal is used to avoid exp(x) = inf float maxVal = std::numeric_limits::lowest(); __m512 vmax = _mm512_set1_ps(maxVal); - __m512 vfactor = _mm512_set1_ps(ctx->attFactor); + __m512 vfactor = _mm512_set1_ps(scale); int i = 0; for (i = 0; i < vecs; ++i) { @@ -334,7 +334,7 @@ class DecoderUtil { } // Softmax: skip the calculation when attention mask is the lowest value - static void softmaxSkipMask(DecoderContext *ctx, float *data, const float *attnMask, int size) { + static void softmaxSkipMask(float *data, const float *attnMask, int size, float scale) { int vecs = (size + 15) / 16; // how many avx512 vectors __mmask16 tailMask = (size % 16 == 0 ? 0xffff : (1 << (size % 16)) - 1); // mask of last vector @@ -345,7 +345,7 @@ class DecoderUtil { float maxVal = std::numeric_limits::lowest(); __m512 vlowest = _mm512_set1_ps(maxVal); __m512 vmax = _mm512_set1_ps(maxVal); - __m512 vfactor = _mm512_set1_ps(ctx->attFactor); + __m512 vfactor = _mm512_set1_ps(scale); int i = 0; for (i = 0; i < vecs; ++i) { @@ -396,31 +396,6 @@ class DecoderUtil { } } - // input and output are both in qkScores - // attnMask: attention mask with the shape of (bs, 1, queryLen, keyLen) - // Note: the source has the shape of (bs, attHeadNum/num_spit, queryLen, keyLen) - static void computeSoftmax(DecoderContext *ctx, const float *attnMask, int queryLen, int keyLen, int stride = -1) { - TimeLine t("DecoderUtil::computeSoftmax"); - const int batchStride = queryLen * keyLen; - if (stride == -1) { stride = keyLen; } - - auto range = SplitUtil::getTaskRange(ctx->attHeadNum, ctx->numSplit, ctx->splitIdx); - int responsibleHeads = range.second - range.first; - -#pragma omp parallel for collapse(2) - for (int b = 0; b < ctx->batchSize; ++b) { - for (int i = 0; i < responsibleHeads; ++i) { - int idx = b * responsibleHeads + i; - float *result = ctx->qkScores + idx * queryLen * stride; - - for (int seq = 0; seq < queryLen; ++seq) { - computeSoftmax(ctx, result, attnMask + b * batchStride + seq * keyLen, keyLen); - result += stride; - } - } - } - } - // Same implementation with softmax, but: // Return max value, and the sum value of exp static std::pair softmaxWithStats(float *data, const float *attnMask, int size, float scale) { From 451ef21affc338088fbe6e8bac324fa038a47d8e Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 08/35] [Common] Refactor sequence.h. 
(#363) --- src/common/sequence.h | 96 +++++++++++++++++++++---------------------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/src/common/sequence.h b/src/common/sequence.h index 53cc84c7..e6b1a686 100644 --- a/src/common/sequence.h +++ b/src/common/sequence.h @@ -14,11 +14,14 @@ // ============================================================================ #pragma once +#include #include #include -#include "sampling_params.h" #include +#include "environment.h" +#include "sampling_params.h" + /* SequencePool ┌──────┬──────┬──────┐ @@ -39,12 +42,33 @@ */ namespace xft { +// Global sequence ID manager +class SequenceIDManager { +public: + static SequenceIDManager &getInstance() { + static SequenceIDManager instance; + return instance; + } + + int32_t createSequenceID() { + int32_t id = globalSequenceID++; + if (id >= 10 * 1024) { + globalSequenceID = 0; + id = globalSequenceID++; + } + return id; + } + +private: + SequenceIDManager() {} + int32_t globalSequenceID = 0; +}; // The SequenceMeta is one sequence of batch inputs and includes the generated tokens. class SequenceMeta { public: SequenceMeta(std::vector &_inputTokens) - : sequenceID(SequencePool::getInstance().createSequenceID()) + : sequenceID(SequenceIDManager::getInstance().createSequenceID()) , inputSeqLen(_inputTokens.size()) , inputTokens(_inputTokens) , pastSeqLen(0) @@ -54,7 +78,7 @@ class SequenceMeta { } SequenceMeta(int32_t _inputSeqLen) - : sequenceID(SequencePool::getInstance().createSequenceID()) + : sequenceID(SequenceIDManager::getInstance().createSequenceID()) , inputSeqLen(_inputSeqLen) , inputTokens(_inputSeqLen, 0) , pastSeqLen(0) @@ -130,14 +154,14 @@ class SequenceMeta { // For beam searcher class SequenceGroupMeta { public: - SequenceGroupMeta(std::vector &seq, SamplingMeta samplingMeta_) : samplingMeta(samplingMeta_) { + SequenceGroupMeta(std::vector &seq, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) { assert(samplingMeta.config.numBeams == seq.size()); sequences.reserve(samplingMeta.config.numBeams); sequences.assign(seq.begin(), seq.end()); groupID = sequences[0].getSequenceID(); } - SequenceGroupMeta(std::vector &_inputTokens, SamplingMeta samplingMeta_) : samplingMeta(samplingMeta_) { + SequenceGroupMeta(std::vector &_inputTokens, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) { sequences.reserve(samplingMeta.config.numBeams); for (int i = 0; i < samplingMeta.config.numBeams; ++i) { sequences.emplace_back(SequenceMeta(_inputTokens)); @@ -145,7 +169,7 @@ class SequenceGroupMeta { groupID = sequences[0].getSequenceID(); } - SequenceGroupMeta(int32_t _inputSeqLen, SamplingMeta samplingMeta_) : samplingMeta(samplingMeta_) { + SequenceGroupMeta(int32_t _inputSeqLen, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) { sequences.reserve(samplingMeta.config.numBeams); for (int i = 0; i < samplingMeta.config.numBeams; ++i) { sequences.emplace_back(SequenceMeta(_inputSeqLen)); @@ -157,12 +181,17 @@ class SequenceGroupMeta { int32_t getGroupSize() { return samplingMeta.config.numBeams; } + // using 1st sequence'step as group step. + int32_t getStep() { return sequences[0].getStep(); } + SequenceMeta *get() { return sequences.data(); } SequenceMeta *get(int index) { return &sequences[index]; } SequenceMeta &operator[](int index) { return sequences[index]; } + bool isDone() { return samplingMeta.done; } + private: // using 1st sequence ID as group ID. 
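    // (IDs are handed out by the global SequenceIDManager above, so the first
    // member's ID is unique and can safely identify the whole group.)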
int32_t groupID; @@ -186,49 +215,16 @@ class SequencePool { return instance; } - int32_t createSequenceID() { - int32_t id = globalSequenceID++; - if (id >= 10 * 1024) { - globalSequenceID = 0; - id = globalSequenceID++; - } - return id; - } - - SequenceGroupMeta *newMeta(std::vector &inputTokens, SamplingMeta samplingMeta_) { - std::vector sequences; - sequences.emplace_back(SequenceMeta(inputTokens)); - - auto *group = new SequenceGroupMeta(sequences, samplingMeta_); - return group; - } - - SequenceGroupMeta *newMeta(int32_t inputSeqLen, SamplingMeta samplingMeta_) { - std::vector sequences; - sequences.emplace_back(SequenceMeta(inputSeqLen)); - - auto *group = new SequenceGroupMeta(sequences, samplingMeta_); + // New sequenceGroupMeta will be added into pool. + SequenceGroupMeta *newGroupMeta(std::vector &inputTokens, SamplingMeta &samplingMeta_) { + auto *group = new SequenceGroupMeta(inputTokens, samplingMeta_); + this->add(group); return group; } - SequenceGroupMeta *newGroupMeta(std::vector> &inputTokens, SamplingMeta samplingMeta_) { - std::vector sequences; - for (int i = 0; i < inputTokens.size(); ++i) { - sequences.emplace_back(SequenceMeta(inputTokens[i])); - } - - auto *group = new SequenceGroupMeta(sequences, samplingMeta_); - return group; - } - - SequenceGroupMeta *newGroupMeta(std::vector &inputSeqLens, SamplingMeta samplingMeta_) { - - std::vector sequences; - for (int i = 0; i < inputSeqLens.size(); ++i) { - sequences.emplace_back(SequenceMeta(inputSeqLens[i])); - } - - auto *group = new SequenceGroupMeta(sequences, samplingMeta_); + SequenceGroupMeta *newGroupMeta(int32_t inputSeqLen, SamplingMeta &samplingMeta_) { + auto *group = new SequenceGroupMeta(inputSeqLen, samplingMeta_); + this->add(group); return group; } @@ -333,7 +329,7 @@ class TaskWaitingQueue { bool isFull() { bool full = false; - if (this->size() >= Env::getInstance().getMaxRequestNum()) { full = true; } + if (this->size() >= MaxRequestNum) { full = true; } return full; } @@ -346,9 +342,13 @@ class TaskWaitingQueue { void push(SequenceGroupMeta *seq) { queue.push(seq); } private: - TaskWaitingQueue() {} + TaskWaitingQueue() : MaxRequestNum(Env::getInstance().getMaxRequestNum()) {} std::queue queue; + + int32_t MaxRequestNum; }; +static std::vector workingGroup; + } // namespace xft \ No newline at end of file From 5e98e6ddf14995f8ce2f8283d5b16beabcb5cbb1 Mon Sep 17 00:00:00 2001 From: "Meng,Chen" Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 09/35] [kernels] refactor flash attention for continuous batching (#361) --- src/kernels/attention_kernels.h | 128 +++++++++++++++++++++++++- src/layers/attention.h | 145 +++++++----------------------- src/layers/attn_baichuan.h | 17 ++-- src/utils/decoder_util.h | 153 ++++++++++++++++++++++++++++++-- 4 files changed, 313 insertions(+), 130 deletions(-) diff --git a/src/kernels/attention_kernels.h b/src/kernels/attention_kernels.h index 7459cda5..d247ee89 100644 --- a/src/kernels/attention_kernels.h +++ b/src/kernels/attention_kernels.h @@ -778,4 +778,130 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in } // end for b } -} // namespace xft \ No newline at end of file +// scaled dot-product attention: bmm1 + softmax + bmm2 +// query key value are all in [*, seqLen, headnum, headsize] order +template +void selfScaledDpAttention(T *output, const T *query, const AttnT *key, const AttnT *value, int qHeadNum, int kvHeadNum, + int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *inputSeqLens, + const int 
*pastSeqLens, bool causal, const float *alibiSlopes, const float *attnMask, const float scale,
+        int threadNum) {
+    // output = softmax(query * trans(key)) * value
+    // causal = True: llama-family, chatglm2; extra alibiSlopes for baichuan
+    // causal = False: just chatglm (prefixLLM, 0:startid) needs attnMask for now
+
+    // get the max seqLen
+    int maxSrcLen = 0, maxTgtLen = 0;
+    for (int i = 0; i < batchSize; ++i) {
+        maxSrcLen = std::max(maxSrcLen, inputSeqLens[i]);
+        maxTgtLen = std::max(maxTgtLen, inputSeqLens[i] + pastSeqLens[i]);
+    }
+    // compute the seqStartLoc
+    int seqStartLoc[batchSize + 1];
+    seqStartLoc[0] = 0;
+    for (int i = 0; i < batchSize; ++i) {
+        seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
+    }
+
+    // closest power of 2
+    int minBlk = (int)std::pow(2, int(std::log2(maxSrcLen / 2)));
+    // Split sequence to make sure a moderate sync frequency and the intermediate
+    // result [srcSeq * tgtSeq] in cache. The current block size is derived from practical experience.
+    int srcBlk = std::min(256, minBlk);
+    int tgtBlk = std::min(512, maxTgtLen);
+
+    int numGroup = qHeadNum / kvHeadNum; // GQA: how many query heads share one KV head
+
+    int numArr = 7;
+    int arrStride = (4 + tgtBlk + 2 * headSize) * srcBlk;
+    float *thrBuf
+            = (float *)SimpleMemPool::instance().getBuffer("threadBuffers", sizeof(float) * threadNum * arrStride);
+    float **thrPtrBuf
+            = (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", sizeof(float *) * threadNum * numArr);
+
+    float **preSum = thrPtrBuf;
+    float **sum = thrPtrBuf + threadNum;
+    float **preMax = thrPtrBuf + threadNum * 2;
+    float **max = thrPtrBuf + threadNum * 3;
+    float **qkArr = thrPtrBuf + threadNum * 4;
+    float **expQkvArr = thrPtrBuf + threadNum * 5;
+    float **qArr = thrPtrBuf + threadNum * 6;
+
+    for (int i = 0; i < threadNum; ++i) {
+        preSum[i] = thrBuf + srcBlk * i;
+        sum[i] = thrBuf + srcBlk * threadNum + srcBlk * i;
+        preMax[i] = thrBuf + srcBlk * threadNum * 2 + srcBlk * i;
+        max[i] = thrBuf + srcBlk * threadNum * 3 + srcBlk * i;
+        qkArr[i] = thrBuf + srcBlk * threadNum * 4 + srcBlk * tgtBlk * i;
+        expQkvArr[i] = thrBuf + srcBlk * threadNum * (4 + tgtBlk) + srcBlk * headSize * i;
+        qArr[i] = thrBuf + srcBlk * threadNum * (4 + tgtBlk + headSize) + srcBlk * headSize * i;
+    }
+
+#pragma omp parallel for collapse(3) schedule(dynamic)
+    for (uint64_t b = 0; b < batchSize; ++b) {
+        for (int h = 0; h < qHeadNum; ++h) {
+            for (int m = 0; m < maxSrcLen; m += srcBlk) {
+                int srcLen = inputSeqLens[b];
+                int tgtLen = inputSeqLens[b] + pastSeqLens[b];
+                if (m >= srcLen) { continue; }
+
+                int tid = omp_get_thread_num();
+                int qRealBlk = std::min(srcBlk, srcLen - m);
+                uint64_t srcOff = seqStartLoc[b] * qStride + h * headSize;
+                uint64_t outOff = seqStartLoc[b] * oStride + h * headSize;
+                const T *qbuf = query + srcOff + m * qStride;
+                AttnT *q = (AttnT *)qArr[tid];
+                T *out = output + outOff + m * oStride;
+
+                // reset the output block and convert the query block to AttnT
+                for (int ii = 0; ii < qRealBlk; ++ii) {
+#pragma omp simd
+                    for (int jj = 0; jj < headSize; ++jj) {
+                        out[ii * oStride + jj] = 0; // reset output
+                        q[ii * headSize + jj] = (AttnT)(qbuf[ii * qStride + jj]); // convert query
+                    }
+                }
+                // reset the running sum/max trackers
+#pragma omp simd
+                for (int ii = 0; ii < qRealBlk; ++ii) {
+                    preSum[tid][ii] = 0;
+                    sum[tid][ii] = 0;
+                    preMax[tid][ii] = std::numeric_limits::lowest();
+                    max[tid][ii] = std::numeric_limits::lowest();
+                }
+
+                uint64_t tgtOff = seqStartLoc[b] * kvStride + (h / numGroup) * headSize;
+                const AttnT *k = key + tgtOff;
+                const AttnT *v = value + tgtOff;
+                // split the target len dimension
+                for (int n = 0; n
< tgtLen; n += tgtBlk) { + int kvRealBlk = std::min(tgtBlk, tgtLen - n); + // mask out. TODO: for prefixLLM + if (causal && m + qRealBlk - 1 < n) { + //printf("Skip bs %d head %d src %d tgt %d\n", b, h, m, n); + break; + } + + const AttnT *kBlk = k + n * kvStride; + const AttnT *vBlk = v + n * kvStride; + + if (causal) { + // causal=True, build-in mask + float headSlope = alibiSlopes != nullptr ? alibiSlopes[h] : 0.0f; + DecoderUtil::incrementalTileAttentionCausal(q, kBlk, vBlk, headSlope, m, n, qRealBlk, headSize, + kvRealBlk, preSum[tid], sum[tid], preMax[tid], max[tid], scale, qkArr[tid], + expQkvArr[tid], out, headSize, kvStride, kvStride, oStride); + } else { + // causal=False, need mask matrix for now + const float *attnMsk = attnMask + seqStartLoc[b] * tgtLen + m * tgtLen + n; + DecoderUtil::incrementalTileAttention(q, kBlk, vBlk, attnMsk, qRealBlk, headSize, kvRealBlk, + tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], scale, qkArr[tid], expQkvArr[tid], + out, headSize, kvStride, kvStride, oStride); + } + } + } + } + } + return; +} + +} // namespace xft diff --git a/src/layers/attention.h b/src/layers/attention.h index c69ddba5..abff5335 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -67,6 +67,8 @@ class Attention { printf("Not supported yet: QHeads=%d, KVHeads=%d\n", ctx->attHeadNum, ctx->kvHeadNum); exit(-1); } + + alibiSlopes = nullptr; } // The inerface is for PyTorch, thus the weights are already transposed @@ -701,8 +703,14 @@ class Attention { int kvCols = respKVHeads * headSize; int qkvCols = qCols + kvCols * 2; float scale = ctx->attFactor; - int srcLen = ctx->inputSeqLen; - int tgtLen = pastSeqLen + srcLen; + + int totalTokenSize = 0; + int inputSeqLens[batchSize], pastSeqLens[batchSize]; + for (int i = 0; i < batchSize; ++i) { + inputSeqLens[i] = ctx->inputSeqLen; + pastSeqLens[i] = pastSeqLen; + totalTokenSize += inputSeqLens[i]; + } // TODO: kv dtype conversion for prefixSharing AttnT *k, *v; @@ -712,22 +720,21 @@ class Attention { //Timer tmc(true, "convert KV matrix into bf16"); kvStride = kvCols * 2; AttnT *kvBuf = (AttnT *)SimpleMemPool::instance().getBuffer( - "flashKVBuf", batchSize * srcLen * kvStride * sizeof(AttnT)); -#pragma omp parallel for collapse(3) - for (uint64_t b = 0; b < batchSize; ++b) - for (uint64_t seq = 0; seq < srcLen; ++seq) - for (uint64_t i = 0; i < kvCols * 2; i += headSize) { - const ImT *srcPtr = key.Data() + b * srcLen * qkvCols + seq * qkvCols + i; - AttnT *dstPtr = kvBuf + b * srcLen * kvStride + seq * kvStride + i; - if constexpr (std::is_same_v && std::is_same_v) { - bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize); - } else if constexpr (std::is_same_v && std::is_same_v) { - bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize); - } else { - printf("Not supported Type in Flash Attention yet\n"); - exit(-1); - } + "flashKVBuf", totalTokenSize * kvStride * sizeof(AttnT)); +#pragma omp parallel for collapse(2) + for (uint64_t seq = 0; seq < totalTokenSize; ++seq) + for (uint64_t i = 0; i < kvCols * 2; i += headSize) { + const ImT *srcPtr = key.Data() + seq * qkvCols + i; + AttnT *dstPtr = kvBuf + seq * kvStride + i; + if constexpr (std::is_same_v && std::is_same_v) { + bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize); + } else if constexpr (std::is_same_v && std::is_same_v) { + bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize); + } else { + printf("Not supported Type in Flash Attention yet\n"); + exit(-1); } + } k = kvBuf; v = kvBuf + kvCols; @@ -738,109 +745,14 @@ class 
Attention { } // [batch, src, head, headsize] - scaledDpAttention(query.Data(), k, v, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, - respKVHeads, headSize, result.Data(), query.Stride(), kvStride, result.Stride()); + xft::selfScaledDpAttention(result.Data(), query.Data(), k, v, respQHeads, respKVHeads, headSize, + result.Stride(), query.Stride(), kvStride, batchSize, inputSeqLens, pastSeqLens, true, alibiSlopes, + attnMask, scale, ctx->numThreads); // copy current key/values to cache copyKVCache(ctx, key, value, presentKey, presentValue, pastSeqLen); } - // scaled dot-product attention: bmm1 + softmax + bmm2 - template - void scaledDpAttention(const ImT *query, const AttnT *key, const AttnT *value, const float *attnMask, float scale, - int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, ImT *output, int qStride, - int kvStride, int stride) { - // output = trans(softmax(query * trans(key)) * value) - int nth = omp_get_max_threads(); - // closest value of power of 2 - int minBlk = (int)std::pow(2, int(std::log2(srcLen / 2))); - // Split sequence to make sure a moderate sync frequency and the intermediate - // result [srcSeq * tgtSeq] in cache. The current block size is derived from practical experience. - int srcBlk = std::min(256, minBlk); - int tgtBlk = std::min(512, tgtLen); - float refac = scale; - int numGroup = numQHead / numKVHead; - - int numArr = 7; - int arrStride = (4 + tgtBlk + 2 * headSize) * srcBlk; - float *thrBuf = (float *)SimpleMemPool::instance().getBuffer("threadBuffers", sizeof(float) * nth * arrStride); - float **thrPtrBuf - = (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", sizeof(float *) * nth * numArr); - - float **preSum = thrPtrBuf; - float **sum = thrPtrBuf + nth; - float **preMax = thrPtrBuf + nth * 2; - float **max = thrPtrBuf + nth * 3; - float **qkArr = thrPtrBuf + nth * 4; - float **expQkvArr = thrPtrBuf + nth * 5; - float **qArr = thrPtrBuf + nth * 6; - - for (int i = 0; i < nth; ++i) { - preSum[i] = thrBuf + srcBlk * i; - sum[i] = thrBuf + srcBlk * nth + srcBlk * i; - preMax[i] = thrBuf + srcBlk * nth * 2 + srcBlk * i; - max[i] = thrBuf + srcBlk * nth * 3 + srcBlk * i; - qkArr[i] = thrBuf + srcBlk * nth * 4 + srcBlk * tgtBlk * i; - expQkvArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk) + srcBlk * headSize * i; - qArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk + headSize) + srcBlk * headSize * i; - } - -#pragma omp parallel for collapse(3) schedule(dynamic) - for (uint64_t i = 0; i < batchSize; ++i) { - for (int j = 0; j < numQHead; ++j) { - for (int m = 0; m < srcLen; m += srcBlk) { - int tid = omp_get_thread_num(); - - int qRealBlk = std::min(srcBlk, srcLen - m); - uint64_t srcOff = i * srcLen * qStride + j * headSize; - uint64_t outOff = i * srcLen * stride + j * headSize; - const ImT *qbuf = query + srcOff + m * qStride; - AttnT *q = (AttnT *)qArr[tid]; - ImT *out = output + outOff + m * stride; - - // reset out - for (int ii = 0; ii < qRealBlk; ++ii) { -#pragma omp simd - for (int jj = 0; jj < headSize; ++jj) { - out[ii * stride + jj] = 0; // reset output - q[ii * headSize + jj] = (AttnT)(qbuf[ii * qStride + jj]); // reset output - } - } - // reset sum -#pragma omp simd - for (int ii = 0; ii < qRealBlk; ++ii) { - preSum[tid][ii] = 0; - sum[tid][ii] = 0; - preMax[tid][ii] = std::numeric_limits::lowest(); - max[tid][ii] = std::numeric_limits::lowest(); - } - - uint64_t tgtOff = i * tgtLen * kvStride + (j / numGroup) * headSize; - const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * 
tgtLen; - const AttnT *k = key + tgtOff; - const AttnT *v = value + tgtOff; - // split the target len dimension - for (int b = 0; b < tgtLen; b += tgtBlk) { - int kvRealBlk = std::min(tgtBlk, tgtLen - b); - // TODO: mask out - if (enableSkipMsk() && DecoderUtil::skipMskAttn(attnMsk + b, qRealBlk, kvRealBlk, tgtLen)) { - // printf("Skip bs %d head %d src %d tgt %d\n", i, j, m, b); - break; - } - - const AttnT *kBlk = k + b * kvStride; - const AttnT *vBlk = v + b * kvStride; - - DecoderUtil::incrementalTileAttention(q, kBlk, vBlk, attnMsk + b, qRealBlk, headSize, kvRealBlk, - tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], refac, qkArr[tid], expQkvArr[tid], - out, headSize, kvStride, kvStride, stride); - } - } - } - } - return; - } - private: std::pair getTaskRange(int N, int splits, int splitIdx) { int startId, endId; @@ -906,6 +818,9 @@ class Attention { NORM_CLS norm; int layerId; + // Alibi Slopes + float *alibiSlopes; + // The responsible head in the global view // If in single instance, startQHead=startKVHead=0, and endQHead-startQHead=qHeadNum int startQHead; diff --git a/src/layers/attn_baichuan.h b/src/layers/attn_baichuan.h index ad00a2b2..84cb6b87 100644 --- a/src/layers/attn_baichuan.h +++ b/src/layers/attn_baichuan.h @@ -28,9 +28,9 @@ template { public: BaichuanAttention(int layerId, DecoderContext *ctx) : Attention(layerId, ctx) { - if (ctx->maxPosEmbed <= 0 && alibiSlopes == nullptr) { + if (ctx->maxPosEmbed <= 0 && this->alibiSlopes == nullptr) { respBaichuanHeads = this->endQHead - this->startQHead; - alibiSlopes = new float[respBaichuanHeads]; + this->alibiSlopes = new float[respBaichuanHeads]; // alibi mask element float ratio = std::pow(2, 8); int closestPowerOf2 = std::pow(2, int(std::log2(ctx->attHeadNum))); @@ -38,10 +38,11 @@ class BaichuanAttention : public Attention { float x1 = std::pow(ratio, 1.0 / (closestPowerOf2 * 2)); for (int i = 0, h = this->startQHead; i < respBaichuanHeads; ++i, ++h) { if (h < closestPowerOf2) - alibiSlopes[i] = 1 / std::pow(x0, h + 1); + this->alibiSlopes[i] = 1 / std::pow(x0, h + 1); else - alibiSlopes[i] = 1 / std::pow(x1, 2 * (h - closestPowerOf2) + 1); + this->alibiSlopes[i] = 1 / std::pow(x1, 2 * (h - closestPowerOf2) + 1); } + alibiSlopes = this->alibiSlopes; } } @@ -50,15 +51,15 @@ class BaichuanAttention : public Attention { const static int getResponsibleHeads() { return respBaichuanHeads; } virtual ~BaichuanAttention() { - if (alibiSlopes != nullptr) { - delete[] alibiSlopes; - alibiSlopes = nullptr; + if (this->alibiSlopes != nullptr) { + delete[] this->alibiSlopes; + this->alibiSlopes = nullptr; } } protected: const float *getMask(const float *attnMask, int bId, int hId, int srcLen, int tgtLen) override { - if (alibiSlopes != nullptr) + if (this->alibiSlopes != nullptr) return attnMask + hId * srcLen * tgtLen; else return attnMask + bId * srcLen * tgtLen; diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 835dba48..08289a16 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -532,10 +532,10 @@ class DecoderUtil { // need to do for res. 
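    // Online-softmax bookkeeping shared by the tile kernels below: for each row, max[]
    // and sum[] carry the running maximum and exp-sum across tiles. When a later tile
    // raises the row maximum, the previously accumulated values are rescaled by
    // exp(oldMax - newMax) (the merr/fac factor below) before the new tile is merged,
    // so after the last tile the result equals an exact softmax over the whole row.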
template - static void softmaxTile(float *AB, ImT *ABout, float *sum, float *max, float *preSum, float *preMax, float refac, + static void softmaxTile(float *AB, ImT *ABout, float *sum, float *max, float *preSum, float *preMax, float scale, const float *attnMask, int m, int k, int attnMskStride) { float maxVal = std::numeric_limits::lowest(); - __m512 vrefac = _mm512_set1_ps(refac); + __m512 vscale = _mm512_set1_ps(scale); for (int i = 0; i < m; ++i) { float *buf = AB + i * k; ImT *obuf = ABout + i * k; @@ -548,7 +548,7 @@ class DecoderUtil { __m512 vx = xft::load_avx512(mask, buf + off); __m512 vmask = xft::load_avx512(mask, attnMsk + off); - vmax = _mm512_mask_max_ps(vmax, mask, vmax, vx * vrefac + vmask); + vmax = _mm512_mask_max_ps(vmax, mask, vmax, vx * vscale + vmask); } float _max = _mm512_reduce_max_ps(vmax); @@ -566,7 +566,7 @@ class DecoderUtil { __m512 vx = xft::load_avx512(mask, buf + off); __m512 vmask = xft::load_avx512(mask, attnMsk + off); - vx = BertUtil::vexp(vx * vrefac + vmask - vmax); + vx = BertUtil::vexp(vx * vscale + vmask - vmax); xft::store_avx512(obuf + off, mask, vx); @@ -591,6 +591,135 @@ class DecoderUtil { } } + template + static void alibiSoftmax(ImT *buf, float scale, float headSlope, int elements) { + float maxVal = std::numeric_limits::lowest(); + __m512 vpos = _mm512_set_ps(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + __m512 vmax = _mm512_set1_ps(maxVal); + __m512 vscale = _mm512_set1_ps(scale); + for (int off = 0; off < elements; off += 16) { + int remain = elements - off; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + __m512 vx = xft::load_avx512(mask, buf + off); + // compute avx512 var vmask that is pos * alibiSlopes[hidx] + __m512 vpositions = _mm512_add_ps(vpos, _mm512_set1_ps(off)); + __m512 vmask = _mm512_mul_ps(vpositions, _mm512_set1_ps(headSlope)); + vmax = _mm512_mask_max_ps(vmax, mask, vmax, vx * vscale + vmask); + } + float _max = _mm512_reduce_max_ps(vmax); + + // exp and get sum + __m512 vsum = _mm512_set1_ps(0); + vmax = _mm512_set1_ps(_max); + for (int off = 0; off < elements; off += 16) { + int remain = elements - off; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + + __m512 vx = xft::load_avx512(mask, buf + off); + // compute avx512 var vmask that is pos * alibiSlopes[hidx] + __m512 vpositions = _mm512_add_ps(vpos, _mm512_set1_ps(off)); + __m512 vmask = _mm512_mul_ps(vpositions, _mm512_set1_ps(headSlope)); + vx = BertUtil::vexp(vx * vscale + vmask - vmax); + + xft::store_avx512(buf + off, mask, vx); + + vsum = _mm512_mask_add_ps(vsum, mask, vsum, vx); + } + float _sum = _mm512_reduce_add_ps(vsum); + + // Compute exp/sum(exp) and store + __m512 vrsum = _mm512_set1_ps(1.0f / _sum); + for (int off = 0; off < elements; off += 16) { + int remain = elements - off; + __mmask16 mask = (remain >= 16 ? 
0xffff : (1 << remain) - 1); + + __m512 vx = xft::load_avx512(mask, buf + off); + vx = vx * vrsum; + + xft::store_avx512(buf + off, mask, vx); + } + } + + template + static void softmaxTileCausal(float *AB, ImT *ABout, float *sum, float *max, float *preSum, float *preMax, + float scale, float headSlope, int qLoc, int kLoc, int tRows, int tCols) { + // build-in mask softmax computing + float maxVal = std::numeric_limits::lowest(); + __m512 vscale = _mm512_set1_ps(scale); + + __m512 vpos = _mm512_set_ps(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + vpos = _mm512_add_ps(vpos, _mm512_set1_ps(kLoc)); + + for (int i = 0; i < tRows; ++i) { + float *buf = AB + i * tCols; + ImT *obuf = ABout + i * tCols; + int k = qLoc + i + 1 - kLoc; + k = std::max(k, 0); + k = std::min(k, tCols); + // max val for avoiding inf and nan + __m512 vmax = _mm512_set1_ps(maxVal); + for (int off = 0; off < k; off += 16) { + int remain = k - off; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + __m512 vx = xft::load_avx512(mask, buf + off); + + if (headSlope != 0) { + // compute avx512 var vmask that is pos * alibiSlopes[hidx] + __m512 vpositions = _mm512_add_ps(vpos, _mm512_set1_ps(off)); + __m512 vmask = _mm512_mul_ps(vpositions, _mm512_set1_ps(headSlope)); + vmax = _mm512_mask_max_ps(vmax, mask, vmax, vx * vscale + vmask); + } else { + vmax = _mm512_mask_max_ps(vmax, mask, vmax, vx * vscale); + } + } + float _max = _mm512_reduce_max_ps(vmax); + + _max = _max > max[i] ? _max : max[i]; + __m512 merr = _mm512_set1_ps(max[i] - _max); + merr = BertUtil::vexp(merr); + max[i] = _max; + + // exp and get sum + __m512 vsum = _mm512_set1_ps(0); + vmax = _mm512_set1_ps(_max); + for (int off = 0; off < k; off += 16) { + int remain = k - off; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + + __m512 vx = xft::load_avx512(mask, buf + off); + if (headSlope != 0) { + // compute avx512 var vmask that is pos * alibiSlopes[hidx] + __m512 vpositions = _mm512_add_ps(vpos, _mm512_set1_ps(off)); + __m512 vmask = _mm512_mul_ps(vpositions, _mm512_set1_ps(headSlope)); + vx = BertUtil::vexp(vx * vscale + vmask - vmax); + } else { + vx = BertUtil::vexp(vx * vscale - vmax); + } + + xft::store_avx512(obuf + off, mask, vx); + + vsum = _mm512_mask_add_ps(vsum, mask, vsum, vx); + } + float _sum = _mm512_reduce_add_ps(vsum); + float fac = _mm512_cvtss_f32(merr); + sum[i] = sum[i] * fac + _sum; + _sum = sum[i]; + + // Compute exp/sum(exp) and store + __m512 vrsum = _mm512_set1_ps(1.0f / _sum); + for (int off = 0; off < k; off += 16) { + int remain = k - off; + __mmask16 mask = (remain >= 16 ? 
0xffff : (1 << remain) - 1); + + __m512 vx = xft::load_avx512(mask, obuf + off); + vx = vx * vrsum; + + xft::store_avx512(obuf + off, mask, vx); + } + if (tCols > k) { memset(obuf + k, 0, (tCols - k) * sizeof(ImT)); } + } + } + template static void updateOutTile(T *output, const float *expABC, float *preSum, float *sum, float *preMax, float *max, int m, int n, int stride) { @@ -619,11 +748,23 @@ class DecoderUtil { // preSum = sum template static void incrementalTileAttention(const T *A, const T *B, const T *C, const float *attnMask, int m, int n, int k, - int attnMskStride, float *preSum, float *sum, float *preMax, float *max, float refac, float *AB, + int attnMskStride, float *preSum, float *sum, float *preMax, float *max, float scale, float *AB, float *expABC, ImT *output, int qStride, int kStride, int vStride, int stride) { sgemm(A, B, AB, m, k, n, qStride, kStride, k, false, true); // TODO:optimize - softmaxTile(AB, (T *)AB, sum, max, preSum, preMax, refac, attnMask, m, k, attnMskStride); + softmaxTile(AB, (T *)AB, sum, max, preSum, preMax, scale, attnMask, m, k, attnMskStride); + + sgemm((T *)AB, C, expABC, m, n, k, k, vStride, n, false, false); + updateOutTile(output, expABC, preSum, sum, preMax, max, m, n, stride); + } + + template + static void incrementalTileAttentionCausal(const T *A, const T *B, const T *C, float headSlope, int srcLoc, + int tgtLoc, int m, int n, int k, float *preSum, float *sum, float *preMax, float *max, float scale, + float *AB, float *expABC, ImT *output, int qStride, int kStride, int vStride, int stride) { + sgemm(A, B, AB, m, k, n, qStride, kStride, k, false, true); + // TODO:optimize + softmaxTileCausal(AB, (T *)AB, sum, max, preSum, preMax, scale, headSlope, srcLoc, tgtLoc, m, k); sgemm((T *)AB, C, expABC, m, n, k, k, vStride, n, false, false); updateOutTile(output, expABC, preSum, sum, preMax, max, m, n, stride); From 2b5e266256b529d2e0140c61a8f750062d56be03 Mon Sep 17 00:00:00 2001 From: "Meng,Chen" Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 10/35] [models] Add attnMeta for continuous batching (#364) --- src/common/attn_metadata.h | 70 +++++++++++++++++++++++++++++++++++++ src/layers/attention.h | 2 +- src/models/common_decoder.h | 2 ++ 3 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 src/common/attn_metadata.h diff --git a/src/common/attn_metadata.h b/src/common/attn_metadata.h new file mode 100644 index 00000000..bc2e6031 --- /dev/null +++ b/src/common/attn_metadata.h @@ -0,0 +1,70 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================
+#pragma once
+
+#include <cassert>
+#include <vector>
+
+class AttnMetaData {
+
+public:
+    AttnMetaData() : isPrompt(false), isCausal(true), batchSize(0), attnMask(nullptr) {}
+
+    AttnMetaData(int batchSize, int *inputTokenSizes, int *pastTokenSizes, bool isPrompt, bool isCausal,
+            float *attnMask = nullptr)
+        : isPrompt(isPrompt), isCausal(isCausal), batchSize(batchSize), attnMask(attnMask) {
+        // causal=True, no mask needed
+        assert(!isCausal || attnMask == nullptr);
+        // causal=False, mask required
+        assert(isCausal || attnMask != nullptr);
+
+        // fill inputSeqLens, pastSeqLens, seqStartLoc
+        inputSeqLens.resize(batchSize);
+        pastSeqLens.resize(batchSize);
+        seqStartLoc.resize(batchSize + 1);
+
+        seqStartLoc[0] = 0;
+        for (int i = 0; i < batchSize; i++) {
+            inputSeqLens[i] = inputTokenSizes[i];
+            pastSeqLens[i] = pastTokenSizes[i];
+            seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
+        }
+    }
+
+    AttnMetaData(std::vector<int> &inputTokenSizes, std::vector<int> &pastTokenSizes, bool isPrompt, bool isCausal,
+            float *attnMask = nullptr)
+        : isPrompt(isPrompt), isCausal(isCausal), batchSize(inputTokenSizes.size()),
+          inputSeqLens(inputTokenSizes), pastSeqLens(pastTokenSizes), attnMask(attnMask) {
+        // causal=True, no mask needed
+        assert(!isCausal || attnMask == nullptr);
+        // causal=False, mask required
+        assert(isCausal || attnMask != nullptr);
+
+        // fill seqStartLoc
+        seqStartLoc.resize(batchSize + 1);
+
+        seqStartLoc[0] = 0;
+        for (int i = 0; i < batchSize; i++) {
+            seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
+        }
+    }
+
+private:
+    bool isPrompt;
+    bool isCausal;
+
+    int batchSize;
+    std::vector<int> inputSeqLens;
+    std::vector<int> pastSeqLens;
+    std::vector<int> seqStartLoc;
+
+    float *attnMask;
+};
diff --git a/src/layers/attention.h b/src/layers/attention.h
index abff5335..635d649c 100644
--- a/src/layers/attention.h
+++ b/src/layers/attention.h
@@ -509,7 +509,7 @@ class Attention {
     void softmax(DecoderContext *ctx, T1 *score, const T2 *mask, int rows, int cols, int lds, int startSeq) {
         const int keyLen = cols;
         for (int seq = 0; seq < rows; ++seq) {
-            DecoderUtil::computeSoftmax(ctx, score + seq * lds, mask + (seq + startSeq) * keyLen, keyLen);
+            DecoderUtil::computeSoftmax(score + seq * lds, mask + (seq + startSeq) * keyLen, keyLen, ctx->attFactor);
         }
     }
 
diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h
index 0d5bd85a..e040a637 100644
--- a/src/models/common_decoder.h
+++ b/src/models/common_decoder.h
@@ -333,6 +333,7 @@ class CommonDecoder : public AbstractDecoder {
 
         // Prepare attention mask
         this->prepareAttnMask(ids, step + this->prefixSharing);
+        // prepareAttnMeta
 
         // Token position ids, note: different models may have different impl.
         int *positionIds = this->getPositionIds(ids, batchSize, inputSeqLen, step + this->prefixSharing);
@@ -392,6 +393,7 @@ class CommonDecoder : public AbstractDecoder {
 
         // Pls be noted: in attention, 'outBuf' is used as intermediate buffer, 'tmpBuf' is used as output
         AttnOutT *attnOut = (AttnOutT *)(this->getContext()->tmpBuf.Data());
+        // attnMeta (inputSeqLens, pastSeqLens, seqStartLoc, is_prompt(useSelfAttn), causal, attnMask)
         this->decoders[i]->forwardAttention(getContext(), embBuf, outBuf, attnOut, attnMask,
                 presentKey, // presentKey,
                 presentValue, // presentValue,
From e12ffa871609e787efe43fba9cc28f495f1048cf Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Wed, 15 May 2024 08:45:07 +0800
Subject: [PATCH 11/35] [Model] add interface for seq meta. 
(#366) --- include/abstract_decoder.h | 3 + include/models.h | 10 +- src/models/common_decoder.h | 15 +++ src/models/hybrid_model.h | 6 ++ src/models/models.cpp | 184 ++++++++++++++++++++++++++++---- src/searchers/sampling_params.h | 16 ++- 6 files changed, 209 insertions(+), 25 deletions(-) diff --git a/include/abstract_decoder.h b/include/abstract_decoder.h index 4cf84148..ca8132db 100644 --- a/include/abstract_decoder.h +++ b/include/abstract_decoder.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include "sequence.h" class DecoderContext; class Messenger; @@ -37,6 +38,8 @@ class AbstractDecoder { // |<----------------------- vocabSize ----------------------------->| virtual std::tuple forward(int *ids, int64_t *dims, int step, bool logits_all = false) = 0; + virtual std::tuple forward(std::vector &seq, bool logits_all = false) = 0; + // Reorder cached keys and values, size=batchSize*beamSize virtual void reorderCache(int *idx, int size) = 0; diff --git a/include/models.h b/include/models.h index f850806b..5086fe1e 100644 --- a/include/models.h +++ b/include/models.h @@ -36,9 +36,17 @@ class Model { void config(SearcherConfig &config_, const std::vector> &stopWordsList_ = {}); + void set_input(std::vector &inputIds_, int batchSize_, int maxLen_ = -1, int numBeams_ = 1, + int numBeamHypsToKeep_ = 1, float lenPenalty_ = 1.0, bool doEarlyStopping_ = false, int eosTokenId_ = -1, + int padTokenId_ = -1, bool doSample_ = false, float temperature_ = 1.0, int topK_ = 50, float topP_ = 1.0, + float repetitionPenalty_ = 1.0, const std::vector> &stopWordsList_ = {}); + + void set_input(std::vector &inputIds_, int batchSize_, SearcherConfig &config_, + const std::vector> &stopWordsList_ = {}); + bool isDone(); - std::tuple forward(); + std::tuple forward(bool logits_all = true); std::vector generate(); diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index e040a637..ba84edc2 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -524,6 +524,21 @@ class CommonDecoder : public AbstractDecoder { finalOut, this->predictor->getSplitOffset(), this->predictor->getSplitSize()); } + std::tuple forward(std::vector &seqs, bool logitsAll = false) { + // Assume all sequences are all prompts(step==0) or all decodes(step>0) + // Assume input has been synced with master in higher level. 
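+        // Illustrative shapes (hypothetical batch): two step-0 sequences with 5- and
+        // 3-token prompts flatten to 8 rows, so the returned logits cover [8 x vocabSize]
+        // when logitsAll is true; otherwise only each sequence's last position is needed.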
+ TimeLine t("Decoder.forward"); + TimeLine t1("Decoder.embedding"); + + int batchSize = seqs.size(); + int userSideBS = seqs.size(); + int step = seqs[0]->getStep(); + + // TODO + throw std::logic_error("Method not implemented"); + return std::tuple(nullptr, 0, 0); + } + void setPrefix(int *ids, int seqLen) { this->prefixSharing = true; this->prefixSeqLen = seqLen; diff --git a/src/models/hybrid_model.h b/src/models/hybrid_model.h index 5d1e0edc..03dfa7a9 100644 --- a/src/models/hybrid_model.h +++ b/src/models/hybrid_model.h @@ -72,6 +72,12 @@ class HybridModel : public AbstractDecoder { } } + // TODO + std::tuple forward(std::vector &seq, bool logits_all = false) { + throw std::logic_error("Method not implemented"); + return std::make_tuple(nullptr, 0, 0); + } + void reorderCache(int *idx, int size) { return firstModel->reorderCache(idx, size); } DecoderContext *getContext() { return firstModel->getContext(); } diff --git a/src/models/models.cpp b/src/models/models.cpp index dee1ab4e..b429ec72 100644 --- a/src/models/models.cpp +++ b/src/models/models.cpp @@ -31,9 +31,9 @@ #include "qwen.h" #include "qwen2.h" #include "searcher.h" +#include "sequence.h" #include "timeline.h" #include "yarn_llama.h" -#include "sequence.h" namespace xft { enum class GenerationMode { GREEDY_SEARCH, BEAM_SEARCH, SAMPLE }; @@ -71,6 +71,7 @@ void Model::exitSlaves() { } } +// TODO: deprecate the following function void Model::input(std::vector &inputIds_, int batchSize_) { isNewInput = true; Messenger &messenger = decoder->getMessenger(); @@ -88,6 +89,7 @@ void Model::input(std::vector &inputIds_, int batchSize_) { messenger.broadcast(inputIds.data(), dims[1]); } +// TODO: deprecate the following function void Model::config(int maxLen_, int numBeams_, int numBeamHypsToKeep_, float lenPenalty_, bool doEarlyStopping_, int eosTokenId_, int padTokenId_, bool doSample_, float temperature_, int topK_, float topP_, float repetitionPenalty_, const std::vector> &stopWordsList_) { @@ -107,6 +109,118 @@ void Model::config(int maxLen_, int numBeams_, int numBeamHypsToKeep_, float len this->config(configuration, stopWordsList_); } +void syncStopWordsList(std::vector> &stopWordsList) { + Messenger &messenger = Messenger::getInstance(); + + int listSize = stopWordsList.size(); + messenger.broadcast(&listSize, 1); + // If stopWordsList is empty, stop broadcasting and return. 
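+    // Wire format (illustrative values): stopWordsList = {{13, 198}, {50256}} is sent as
+    // listSize = 2, wordsSize = {2, 1}, and the flattened wordsData = {13, 198, 50256}.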
+    if (listSize == 0) { return; }
+
+    std::vector<int> wordsSize(listSize);
+    if (messenger.getRank() == 0) {
+        for (int i = 0; i < listSize; i++) {
+            wordsSize[i] = stopWordsList[i].size();
+        }
+    }
+    messenger.broadcast(wordsSize.data(), listSize);
+
+    int wordsDataLen = 0;
+    for (auto x : wordsSize) {
+        wordsDataLen += x;
+    }
+
+    // flatten to 1-D vector
+    std::vector<int> wordsData(wordsDataLen);
+    if (messenger.getRank() == 0) {
+        int currentIndex = 0;
+        for (const auto &words : stopWordsList) {
+            std::copy(words.begin(), words.end(), wordsData.begin() + currentIndex);
+            currentIndex += words.size();
+        }
+    }
+    messenger.broadcast(wordsData.data(), wordsDataLen);
+
+    if (messenger.getRank() != 0) {
+        // restore stop words list to 2-D vector
+        std::vector<std::vector<int>> restoredList;
+        int currentIndex = 0;
+        for (size_t i = 0; i < wordsSize.size(); ++i) {
+            int size = wordsSize[i];
+            std::vector<int> subVector(wordsData.begin() + currentIndex, wordsData.begin() + currentIndex + size);
+            currentIndex += size;
+            restoredList.emplace_back(subVector);
+        }
+        // Assign the restored list back so non-root ranks actually keep the result
+        stopWordsList = std::move(restoredList);
+    }
+}
+
+void Model::set_input(std::vector<int32_t> &inputIds_, int batchSize_, int maxLen_, int numBeams_,
+        int numBeamHypsToKeep_, float lenPenalty_, bool doEarlyStopping_, int eosTokenId_, int padTokenId_,
+        bool doSample_, float temperature_, int topK_, float topP_, float repetitionPenalty_,
+        const std::vector<std::vector<int>> &stopWordsList_) {
+    configuration.maxLen = maxLen_;
+    configuration.numBeams = numBeams_;
+    configuration.numBeamHypsToKeep = numBeamHypsToKeep_;
+    configuration.lenPenalty = lenPenalty_;
+    configuration.doEarlyStopping = doEarlyStopping_;
+    configuration.eosTokenId = eosTokenId_;
+    configuration.padTokenId = padTokenId_;
+    configuration.doSample = doSample_;
+    configuration.temperature = temperature_;
+    configuration.topK = topK_;
+    configuration.topP = topP_;
+    configuration.repetitionPenalty = repetitionPenalty_;
+
+    this->set_input(inputIds_, batchSize_, configuration, stopWordsList_);
+}
+
+void Model::set_input(std::vector<int32_t> &inputIds_, int batchSize_, SearcherConfig &config_,
+        const std::vector<std::vector<int>> &stopWordsList_) {
+    // TODO: remove new_input flag
+    if (config_.eosTokenId == -1) { config_.eosTokenId = decoder->getEndId(); }
+    if (config_.padTokenId == -1) { config_.padTokenId = config_.eosTokenId; }
+    SamplingMeta samplingMeta(config_, stopWordsList_);
+
+    Messenger &messenger = Messenger::getInstance();
+    if (isMaster()) { inputIds = inputIds_; }
+
+    // Sync input and sampling param in distributed mode. 
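+    // e.g. (hypothetical) two 8-token prompts across 2 ranks: rank 0 broadcasts
+    // dims = {2, 16}, so every rank derives batchSize = 2 and seqLen = 8 before
+    // receiving the flattened ids.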
+ if (messenger.getSize() > 1) { + // [batch size, inputIds size] + int dims[2]; + if (isMaster()) { + dims[0] = batchSize_; + dims[1] = inputIds_.size(); + } + messenger.broadcast(dims, 2); + batchSize = dims[0]; + seqLen = dims[1] / batchSize; + + inputIds.resize(dims[1]); + messenger.broadcast(inputIds.data(), dims[1]); + + messenger.broadcast((int *)&samplingMeta.config, sizeof(SearcherConfig) / sizeof(int)); + + syncStopWordsList(samplingMeta.stopWordsList); + } else { + batchSize = batchSize_; + seqLen = inputIds_.size() / batchSize_; + } + + SequencePool &seqPool = SequencePool::getInstance(); + InputQueue &inputQueue = InputQueue::getInstance(); + for (int i = 0; i < batchSize; i++) { + auto group = seqPool.newGroupMeta(inputIds, samplingMeta); + inputQueue.push(group); + } + + xft::workingGroup.clear(); + while (!inputQueue.empty()) { + xft::workingGroup.push_back(inputQueue.pop()); + } +} + +// TODO: Deprecate the following function void Model::config(SearcherConfig &config_, const std::vector> &stopWordsList_) { isNewInput = true; if (decoder->getRank() == 0) { configuration = config_; } @@ -121,33 +235,63 @@ void Model::config(SearcherConfig &config_, const std::vector> } bool Model::isDone() { - if (searcher == nullptr || inputIds.empty()) { - printf("Please set input and config first.\n"); - exit(-1); + // TODO: Deprecate the following Path + if (searcher != nullptr) { + if (inputIds.empty()) { + printf("Please set input and config first.\n"); + exit(-1); + } + return !isNewInput && searcher->isDone(); } - return !isNewInput && searcher->isDone(); -} - -std::tuple Model::forward() { - int64_t dims[3] = {batchSize, 1, seqLen}; - return decoder->forward(inputIds.data(), dims, 0, true); + for (auto x : xft::workingGroup) { + if (!x->isDone()) { return false; } + } + return true; } -std::vector Model::generate() { - if (inputIds.empty()) { - printf("Please set input tokens by model.input().\n"); - exit(-1); +std::tuple Model::forward(bool logits_all) { + // TODO: Deprecate the following Path + if (searcher != nullptr) { + int64_t dims[3] = {batchSize, 1, seqLen}; + return decoder->forward(inputIds.data(), dims, 0, logits_all); } - if (searcher == nullptr) { - printf("Please set generation config by model.config().\n"); + // TODO: checking waiting queue + if (workingGroup.empty()) { + printf("Please input prompt first.\n"); exit(-1); } + // Assume that all sequences in the group are all prompts or all decodes. + // Prepare input data for the decoder. 
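+    // Before step 2 only the lead sequence of each group is forwarded, since beam
+    // candidates still share the same prompt; afterwards every sequence in the group runs.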
+    std::vector<SequenceMeta *> workingSeqs;
+    for (auto x : workingGroup) {
+        workingSeqs.push_back(x->get(0));
+        if (x->getGroupSize() > 1 && x->getStep() > 1) {
+            for (int32_t i = 1; i < x->getGroupSize(); i++) {
+                workingSeqs.push_back(x->get(i));
+            }
+        }
+    }
+
+    return decoder->forward(workingSeqs, logits_all);
+}
 
-    if (isNewInput) {
-        isNewInput = false;
-        return searcher->getNextToken(inputIds.data(), batchSize, inputIds.size() / batchSize);
+std::vector<int32_t> Model::generate() {
+    // TODO: Deprecate the following Path
+    if (searcher != nullptr) {
+        if (inputIds.empty()) {
+            printf("Please set input tokens by model.input().\n");
+            exit(-1);
+        }
+        if (isNewInput) {
+            isNewInput = false;
+            return searcher->getNextToken(inputIds.data(), batchSize, inputIds.size() / batchSize);
+        } else {
+            return searcher->getNextToken();
+        }
     } else {
-        return searcher->getNextToken();
+        // TODO
+        throw std::logic_error("Method not implemented");
+        return {};
     }
 }
diff --git a/src/searchers/sampling_params.h b/src/searchers/sampling_params.h
index 1c298f01..68a47127 100644
--- a/src/searchers/sampling_params.h
+++ b/src/searchers/sampling_params.h
@@ -24,12 +24,20 @@ struct SamplingMeta {
     std::vector<int> cachedRepetVec;
     SearcherConfig config;
 
-    SamplingMeta() { done = false; }
+    SamplingMeta() : done(false) {}
 
-    SamplingMeta(SearcherConfig config, std::vector<std::vector<int>> stopWordsList_)
-        : config(config), stopWordsList(stopWordsList_) {
-        done = false;
-        // TODO: stopWordsIndex is not initialized
+    SamplingMeta(SearcherConfig config, std::vector<std::vector<int>> stopWordsList_)
+        : done(false), config(config), stopWordsList(stopWordsList_) {
+        // Remove empty words, the bare EOS id, and words containing non-positive elements.
+        // Forward iteration with erase() keeps the loop iterator valid.
+        for (auto it = stopWordsList.begin(); it != stopWordsList.end();) {
+            bool invalid = it->empty() || (it->size() == 1 && (*it)[0] == config.eosTokenId);
+            for (auto x : *it) {
+                if (x <= 0) { invalid = true; break; }
+            }
+            it = invalid ? stopWordsList.erase(it) : it + 1;
+        }
     }
 };
From a4442f02ec11788f5e797243877c3df8cd1aace1 Mon Sep 17 00:00:00 2001
From: pujiang2018
Date: Wed, 15 May 2024 08:45:07 +0800
Subject: [PATCH 12/35] [Common] Modify resize() in DecoderContext to support (#367)
---
 src/common/transformer_ctx.h | 47 +++++++++++++++---------------------
 src/layers/attention.h | 11 +++------
 2 files changed, 23 insertions(+), 35 deletions(-)

diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h
index 0c1d77cf..f99cd486 100644
--- a/src/common/transformer_ctx.h
+++ b/src/common/transformer_ctx.h
@@ -106,8 +106,6 @@ struct DecoderContext {
     // # of thread
     int numThreads;
 
-    float *qkScores; // attention score
-
     // Please look into the comments in resize function to see how buffers are arranged
     xft::Matrix<float> normBuf; // buf for the first layer norm
     xft::Matrix<float> tmpBuf; // tmp buffer, same size as output
@@ -262,43 +260,29 @@ struct DecoderContext {
 
     // Resize to make sure the buffer is big enough
     // |---------|---------|--------|
-    // | normBuf |qkvMatMul|qkScores|
+    // | normBuf |qkvMatMul|        |
     // |         |  imOut  | tmpBuf |
-    void resize(int batchSize, int inputSeqLen, bool preSeqLen) {
-        this->batchSize = batchSize;
-        this->inputSeqLen = inputSeqLen;
-
+    void resize(int totalInSeqLen) {
         // Check total required size
-        const int pad = 0; // 4;
-        int hiddenStride = (hiddenSize % 512 == 0 ? hiddenSize + pad
-                                                  : hiddenSize); // stride for matrix with columns of hiddenSize
         int responsibleHead = splitIdx < (attHeadNum % numSplit) ? 
(attHeadNum / numSplit + 1) : (attHeadNum / numSplit); int qCols = responsibleHead * attHeadSize; int kCols = qCols / (attHeadNum / kvHeadNum); int vCols = kCols; int qkvCols = qCols + kCols + vCols; - int qkvStride = (qkvCols % 512 == 0 ? qkvCols + pad : qkvCols); // stride for the concated QKV int mlpFactor = (this->actType == GELU || this->actType == SILU || this->actType == SWIGLU) ? 2 : 1; auto range = SplitUtil::getTaskRange(intermediateSize, numSplit, splitIdx); int imCols = range.second - range.first; - int imStride = (imCols % 512 == 0 ? imCols + pad : imCols); // stride for intermediate output - - uint64_t normSize = (uint64_t)batchSize * inputSeqLen * hiddenStride; - uint64_t qkvSize = (uint64_t)batchSize * inputSeqLen * qkvStride; - uint64_t imOutSize = (uint64_t)batchSize * inputSeqLen * imStride * mlpFactor; - int presentSeqLen = preSeqLen + 1; - int paddedSize = (presentSeqLen + 15) / 16 * 16; + uint64_t normSize = (uint64_t)batchSize * inputSeqLen * hiddenSize; + uint64_t qkvSize = (uint64_t)batchSize * inputSeqLen * qkvCols; + uint64_t imOutSize = (uint64_t)batchSize * inputSeqLen * imCols * mlpFactor; - // Note: the score buffer for first token generation is not padded - uint64_t scoreBufSize = preSeqLen > 0 ? (uint64_t)batchSize * responsibleHead * inputSeqLen * paddedSize - : (uint64_t)batchSize * responsibleHead * inputSeqLen * inputSeqLen; - uint64_t tmpBufSize = (uint64_t)batchSize * inputSeqLen * hiddenStride; + uint64_t tmpBufSize = (uint64_t)batchSize * inputSeqLen * hiddenSize; size1 = normSize; size2 = qkvSize < imOutSize ? imOutSize : qkvSize; - size3 = tmpBufSize < scoreBufSize ? scoreBufSize : tmpBufSize; + size3 = tmpBufSize; uint64_t total = size1 + size2 + size3; if (total > this->rawBufSize) { @@ -310,11 +294,18 @@ struct DecoderContext { } // Assign the buffer - this->qkScores = this->rawBuffer + size1 + size2; - normBuf.Assign(this->rawBuffer, batchSize * inputSeqLen, hiddenSize, hiddenStride); - tmpBuf.Assign(this->qkScores, batchSize * inputSeqLen, hiddenSize, hiddenStride); - imOut.Assign(this->rawBuffer + size1, batchSize * inputSeqLen, imCols, imStride); - qkvMatMul.Assign(this->rawBuffer + size1, batchSize * inputSeqLen, qkvCols, qkvStride); + normBuf.Assign(this->rawBuffer, batchSize * inputSeqLen, hiddenSize, hiddenSize); + tmpBuf.Assign(this->rawBuffer + size1 + size2, batchSize * inputSeqLen, hiddenSize, hiddenSize); + imOut.Assign(this->rawBuffer + size1, batchSize * inputSeqLen, imCols, imCols); + qkvMatMul.Assign(this->rawBuffer + size1, batchSize * inputSeqLen, qkvCols, qkvCols); + } + + // TODO: deprecate it + void resize(int batchSize, int inputSeqLen, bool preSeqLen) { + this->batchSize = batchSize; + this->inputSeqLen = inputSeqLen; + + this->resize(inputSeqLen * batchSize); } uint64_t getScoreCapacity() { diff --git a/src/layers/attention.h b/src/layers/attention.h index 635d649c..75ddeb5f 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -588,13 +588,10 @@ class Attention { // How many blocks in M dimension int mBlockNum = (ctx->inputSeqLen + mBlockSize - 1) / mBlockSize; - // To get score buffer according to openmp thread ID or not (see below) - float *scoreBuf = ctx->qkScores; + // To get score buffer according to openmp thread num int scoreStride = pastSeqLen > 0 ? 
(pastSeqLen + ctx->inputSeqLen + 15) / 16 * 16 : ctx->inputSeqLen; auto bufSizeRequired = ctx->numThreads * mBlockSize * scoreStride; - if (bufSizeRequired > ctx->getScoreCapacity()) { - scoreBuf = (float *)SimpleMemPool::instance().getBuffer("scoreBuf", sizeof(float) * bufSizeRequired); - } + float *scoreBuf = (float *)SimpleMemPool::instance().getBuffer("scoreBuf", sizeof(float) * bufSizeRequired); #pragma omp parallel for collapse(3) for (int b = 0; b < batchSize; ++b) { @@ -627,7 +624,7 @@ class Attention { #ifdef DEBUG if (b == 0 && i == 0) { dbg.debugPrint("Q * K, first head:\n"); - auto p = ctx->qkScores; + auto p = scoreBuf; dbg.debugPrint("%f, %f, %f ... %f %f %f\n", p[0] * ctx->attFactor, p[1] * ctx->attFactor, p[2] * ctx->attFactor, p[n - 3] * ctx->attFactor, p[n - 2] * ctx->attFactor, p[n - 1] * ctx->attFactor); @@ -640,7 +637,7 @@ class Attention { #ifdef DEBUG if (b == 0 && i == 0) { dbg.debugPrint("Softmax(Q * K), first head:\n"); - auto p = ctx->qkScores; + auto p = scoreBuf; dbg.debugPrint("%f, %f, %f ... %f %f %f\n", p[0], p[1], p[2], p[keyLen - 3], p[keyLen - 2], p[keyLen - 1]); } From fb525947f9b6cd0811c494429fc729a74dc069f4 Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 13/35] [Model] New CommonDecoder::forward impl. skeleton (#369) --- src/common/sequence.h | 70 +++++++++++++++++++++---------------- src/models/common_decoder.h | 64 ++++++++++++++++++++++++++++++--- 2 files changed, 98 insertions(+), 36 deletions(-) diff --git a/src/common/sequence.h b/src/common/sequence.h index e6b1a686..24b54bd5 100644 --- a/src/common/sequence.h +++ b/src/common/sequence.h @@ -67,63 +67,71 @@ class SequenceIDManager { // The SequenceMeta is one sequence of batch inputs and includes the generated tokens. 
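+// Typical lifecycle (illustrative):
+//     SequenceMeta seq(promptTokens); // step 0: inputSeqLen == promptTokens.size(), pastSeqLen == 0
+//     seq.stepForward(idA);           // step 1: inputSeqLen == 1, pastSeqLen == promptTokens.size()
+//     seq.stepForward(idB);           // step 2 and later: pastSeqLen grows by 1 per decode step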
class SequenceMeta { public: - SequenceMeta(std::vector &_inputTokens) + SequenceMeta(std::vector &_promptTokens) : sequenceID(SequenceIDManager::getInstance().createSequenceID()) - , inputSeqLen(_inputTokens.size()) - , inputTokens(_inputTokens) + , inputSeqLen(_promptTokens.size()) , pastSeqLen(0) - , step(0) { - nextTokens.reserve(inputSeqLen); - setPastSeqLen(getPastSeqLen()); - } + , promptTokens(_promptTokens) + , step(0) {} SequenceMeta(int32_t _inputSeqLen) : sequenceID(SequenceIDManager::getInstance().createSequenceID()) , inputSeqLen(_inputSeqLen) - , inputTokens(_inputSeqLen, 0) , pastSeqLen(0) - , step(0) { - nextTokens.reserve(_inputSeqLen); - } + , promptTokens(_inputSeqLen, 0) + , step(0) {} ~SequenceMeta() {} int32_t getSequenceID() const { return sequenceID; } - // For first tokens - void stepForward() { + // Step forward given the generated token ID + void stepForward(int32_t genToken) { + inputSeqLen = 1; if (getStep() == 0) { - setPastSeqLen(inputTokens.size()); - setStep(getStep() + 1); + setPastSeqLen(promptTokens.size()); + } else { + setPastSeqLen(getPastSeqLen() + 1); } + addNextToken(genToken); + setStep(getStep() + 1); } - // For next token - void stepForward(int32_t token) { - addNextToken(token); - setPastSeqLen(getPastSeqLen() + 1); + // Step forward given the candidate token IDs (for verification) + void stepForward(const std::vector &candidateIDs) { + inputSeqLen = candidateIDs.size(); + if (getStep() == 0) { + setPastSeqLen(promptTokens.size()); + } else { + setPastSeqLen(getPastSeqLen() + 1); + } + generatedTokens.insert(generatedTokens.end(), candidateIDs.begin(), candidateIDs.end()); setStep(getStep() + 1); } - // Get the input tokens in sequence + // Get current input sequence length int32_t getInputSeqLen() const { return inputSeqLen; } - const int32_t *getInputTokens() const { return inputTokens.data(); } + const std::vector getInputTokens() const { + if (getStep() == 0) { + return promptTokens; + } else { + return std::vector(generatedTokens.end() - inputSeqLen, generatedTokens.end()); + } + } int32_t getPastSeqLen() const { return pastSeqLen; } void setPastSeqLen(int32_t _pastSeqLen) { pastSeqLen = _pastSeqLen; } // For next tokens - void addNextToken(int32_t token) { - nextTokens.clear(); - nextTokens.push_back(token); - inputTokens.push_back(token); - } - - int32_t getLatestToken() const { return nextTokens.back(); } + void addNextToken(int32_t token) { generatedTokens.push_back(token); } - const int32_t *getTotalTokens() const { return getInputTokens(); } + const std::vector getTotalTokens() const { + std::vector totalTokens = promptTokens; + totalTokens.insert(totalTokens.end(), generatedTokens.begin(), generatedTokens.end()); + return totalTokens; + } int32_t getStep() const { return step; } @@ -133,8 +141,8 @@ class SequenceMeta { int32_t sequenceID; int32_t inputSeqLen; int32_t pastSeqLen; - std::vector inputTokens; // input tokens + next tokens - std::vector nextTokens; // next tokens + std::vector promptTokens; // prompt tokens (user's input) + std::vector generatedTokens; // all generated tokens int32_t step; #ifdef PIPELINE_PARALLEL diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index ba84edc2..dbe703ac 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -530,13 +530,67 @@ class CommonDecoder : public AbstractDecoder { TimeLine t("Decoder.forward"); TimeLine t1("Decoder.embedding"); + // Prepare input + int totInputSeqLen = 0; + std::vector allInputIds; + for (auto seq : seqs) { + 
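+            // Flatten every sequence's current input ids into one continuous batch,
+            // e.g. (illustrative) prompts of 5 and 3 tokens give totInputSeqLen == 8.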
totInputSeqLen += seq->getInputSeqLen(); + auto ids = seq->getInputTokens(); + allInputIds.insert(allInputIds.end(), ids.begin(), ids.end()); + } + + // Prepare context + DecoderContext *ctx = this->getContext(); + ctx->resize(totInputSeqLen); + int batchSize = seqs.size(); - int userSideBS = seqs.size(); - int step = seqs[0]->getStep(); + int hiddenSize = ctx->hiddenSize; + + AttnInT *embBuf = (AttnInT *)actBuffers->Data(); + MlpOutT *outBuf = (MlpOutT *)(embBuf + totInputSeqLen * hiddenSize); + + // Embedding + this->embeddingForward(allInputIds.data(), embBuf, totInputSeqLen); + + // TODO: Decoder layers + + // Prepare input for final Layer Norm (only care about the last row of the result) + // Shape of embBuf: (bs, seqLen, hiddenSize) + MlpOutT *lnIn = embBuf; + auto logitRows = totInputSeqLen; + if (!logitsAll) { + // TODO: copy needed data + } + +#ifdef DEBUG + dbg.debugPrint(">>> DecoderLayer Output[%d, %d] (%d):\n", logitRows, hiddenSize, hiddenSize); + dbg.dumpMatrix(embBuf, logitRows, hiddenSize, hiddenSize); + dbg.debugPrint("LayerNorm In:\n"); + + dbg.dumpMatrix(lnIn, logitRows, hiddenSize, hiddenSize); +#endif + + // Last normalization layer + MlpOutT *lnOut = embBuf; + lastLayerNormForward(lnIn, lnOut, logitRows); + +#ifdef DEBUG + dbg.debugPrint("LayerNorm Out:\n"); + dbg.dumpMatrix(lnOut, logitRows, hiddenSize, hiddenSize); +#endif - // TODO - throw std::logic_error("Method not implemented"); - return std::tuple(nullptr, 0, 0); + // Predictor + float *finalOut = (float *)outBuf; + this->predictor->forward(ctx, lnOut, finalOut, logitRows); + +#ifdef DEBUG + auto splitSize = this->predictor->getSplitSize(); + dbg.debugPrint("finalOut:\n"); + dbg.dumpMatrix(finalOut, logitRows, splitSize, splitSize); +#endif + + return std::tuple( + finalOut, this->predictor->getSplitOffset(), this->predictor->getSplitSize()); } void setPrefix(int *ids, int seqLen) { From aa48f7ea3cac39d270abe30f7f52653fa3577483 Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 14/35] [Common] New KVCacheMgr to support CB (#371) --- src/common/kvcache_mgr.h | 216 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 216 insertions(+) create mode 100644 src/common/kvcache_mgr.h diff --git a/src/common/kvcache_mgr.h b/src/common/kvcache_mgr.h new file mode 100644 index 00000000..98f115a0 --- /dev/null +++ b/src/common/kvcache_mgr.h @@ -0,0 +1,216 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================ +#pragma once + +#include +#include "kvcache_tensor.h" +#include + +namespace xft { + +class KVCacheMgrImplBase { +public: + virtual ~KVCacheMgrImplBase() = default; + virtual bool delSequence(int seqID) = 0; + virtual bool addSequence(int seqID, int prefixId = -1) = 0; + virtual bool reorderCache(const std::vector &seqIDs, const std::vector &prevSeqIDs) = 0; + virtual bool addPrefix(int prefixId, int seqID) = 0; + virtual bool prepareCache(const std::vector &seqIDs) = 0; + virtual std::vector getKey(int layerId) = 0; + virtual std::vector getValue(int layerId) = 0; +}; + +template +class KVCacheMgrImpl : public KVCacheMgrImplBase { +public: + KVCacheMgrImpl(int layers) { this->layers = layers; } + + ~KVCacheMgrImpl() { + // Free resource in cachePool (readyCaches are in cachePool too) + for (auto &it : sequenceCaches) { + delete it.second; + } + // Free resource in prefixCaches + for (auto &it : prefixCaches) { + delete it.second; + } + // Free resource in freeCaches + for (auto &it : freeCaches) { + delete it; + } + } + + // Free KVCache by sample ID. + bool delSequence(int seqID) override { + auto it = sequenceCaches.find(seqID); + + // Fail if not exist + if (it == sequenceCaches.end()) { return false; } + + // Move from sequenceCaches to freeCaches + freeCaches.push_back(it->second); + + sequenceCaches.erase(it); + + return true; + } + + bool addSequence(int seqID, int prefixId = -1) override { + // Fail if already exist + if (sequenceCaches.find(seqID) != sequenceCaches.end()) { return false; } + + // Get a free cache or create a new one + KVCacheTensor *cache = nullptr; + if (!freeCaches.empty()) { + cache = freeCaches.back(); + freeCaches.pop_back(); + } else { + cache = new KVCacheTensor[2 * layers]; + } + + sequenceCaches.insert({seqID, cache}); + + return true; + } + + // Reorder cache based on prevSeqIDs for beam search (caches reordered from prevSeqIDs to seqIDs) + // For example, if seqIDs = {1, 2, 3, 4} and prevSeqIDs = {1, 1, 1, 1}, then means to expand cache for sample 1 + bool reorderCache(const std::vector &seqIDs, const std::vector &prevSeqIDs) override { + // TODO: implement reorderCache + return false; + } + + // Create KVCache for prefix sharing + bool addPrefix(int prefixId, int seqID) override { + // Fail if already exist + if (prefixCaches.find(prefixId) != prefixCaches.end()) { return false; } + + // Cannot find the sample cache + if (sequenceCaches.find(seqID) == sequenceCaches.end()) { return false; } + + // Create a new one + KVCacheTensor *cache = new KVCacheTensor[2 * layers]; + + for (int i = 0; i < 2 * layers; i++) { + // TODO: add from method in KVCacheTensor + //cache[i].from(sequenceCaches[seqID][i]); + } + + prefixCaches.insert({prefixId, cache}); + + return true; + } + + // Set cache to be ready for this order of sampleIds + bool prepareCache(const std::vector &seqIDs) override { + std::vector *> readyList; + readyList.reserve(seqIDs.size()); + + for (auto seqID : seqIDs) { + auto it = sequenceCaches.find(seqID); + if (it == sequenceCaches.end()) { return false; } + readyList.push_back(it->second); + } + + readyCaches = std::move(readyList); + + return true; + } + + // Get key caches for a layer + std::vector getKey(int layerId) override { + std::vector keyCaches; + keyCaches.reserve(readyCaches.size()); + for (auto cache : readyCaches) { + keyCaches.push_back(&cache[2 * layerId]); + } + return keyCaches; + } + + // Get value caches for a layer + std::vector getValue(int 
layerId) override { + std::vector valueCaches; + valueCaches.reserve(readyCaches.size()); + for (auto cache : readyCaches) { + valueCaches.push_back(&cache[2 * layerId + 1]); + } + return valueCaches; + } + +private: + // seqID -> pointer to an array of caches (each element is a KVCacheTensor, size=2*layers) + // Layout of each array is: + // + // + // + // + // ... + std::unordered_map *> sequenceCaches; + + // prefixID -> pointer to an array of caches (each element is a KVCacheTensor, size=2*layers) + std::unordered_map *> prefixCaches; + + // List of ready caches, each element is for a sample; subset of sequenceCaches + std::vector *> readyCaches; + + // List of pending free caches, each element is for a sample + std::vector *> freeCaches; + + int layers; +}; + +class KVCacheMgr { +public: + static KVCacheMgr &instance() { + static KVCacheMgr inst; + return inst; + } + + void configure(int layers, DataType dataType) { + switch (dataType) { + case DataType::int8: cacheMgrImpl = new KVCacheMgrImpl(layers); break; + case DataType::fp16: cacheMgrImpl = new KVCacheMgrImpl(layers); break; + default: cacheMgrImpl = new KVCacheMgrImpl(layers); break; + } + } + + bool delSequence(int seqID) { return cacheMgrImpl->delSequence(seqID); } + + bool addSequence(int seqID, int prefixId = -1) { return cacheMgrImpl->addSequence(seqID, prefixId); } + + bool reorderCache(const std::vector &seqIDs, const std::vector &prevSeqIDs) { + return cacheMgrImpl->reorderCache(seqIDs, prevSeqIDs); + } + + bool addPrefix(int prefixId, int seqID) { return cacheMgrImpl->addPrefix(prefixId, seqID); } + + bool prepareCache(const std::vector &seqIDs) { return cacheMgrImpl->prepareCache(seqIDs); } + + std::vector getKey(int layerId) { return cacheMgrImpl->getKey(layerId); } + + std::vector getValue(int layerId) { return cacheMgrImpl->getValue(layerId); } + +private: + KVCacheMgrImplBase *cacheMgrImpl; + + KVCacheMgr() : cacheMgrImpl(nullptr) {} + + ~KVCacheMgr() { delete cacheMgrImpl; } + + KVCacheMgr(const KVCacheMgr &) = delete; + KVCacheMgr &operator=(const KVCacheMgr &) = delete; +}; + +} // namespace xft \ No newline at end of file From 3f15904506970ad52150ea0c77f9a6343b9fdf52 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 15/35] [Sampling] Add repetition penalty for new seq type. 
(#373) --- include/models.h | 1 + src/common/sequence.h | 20 ++++++-- src/models/models.cpp | 10 ++-- src/searchers/search_utils.cpp | 77 +++++++++++++++++++++++++++- src/searchers/search_utils.h | 6 +++ tests/ut/CMakeLists.txt | 2 + tests/ut/repetition_penalty_test.cpp | 54 +++++++++++++++++++ 7 files changed, 162 insertions(+), 8 deletions(-) create mode 100644 tests/ut/repetition_penalty_test.cpp diff --git a/include/models.h b/include/models.h index 5086fe1e..9ae19fca 100644 --- a/include/models.h +++ b/include/models.h @@ -87,6 +87,7 @@ class Model { int vocabSize; SearcherConfig configuration; bool isNewInput; + std::vector workingGroup; }; class AutoModel : public Model { diff --git a/src/common/sequence.h b/src/common/sequence.h index 24b54bd5..dfaf4924 100644 --- a/src/common/sequence.h +++ b/src/common/sequence.h @@ -85,6 +85,10 @@ class SequenceMeta { int32_t getSequenceID() const { return sequenceID; } + std::vector getPromptTokens() const { return promptTokens; } + + std::vector getGeneratedTokens() const { return generatedTokens; } + // Step forward given the generated token ID void stepForward(int32_t genToken) { inputSeqLen = 1; @@ -112,7 +116,7 @@ class SequenceMeta { // Get current input sequence length int32_t getInputSeqLen() const { return inputSeqLen; } - const std::vector getInputTokens() const { + std::vector getInputTokens() const { if (getStep() == 0) { return promptTokens; } else { @@ -127,7 +131,7 @@ class SequenceMeta { // For next tokens void addNextToken(int32_t token) { generatedTokens.push_back(token); } - const std::vector getTotalTokens() const { + std::vector getTotalTokens() const { std::vector totalTokens = promptTokens; totalTokens.insert(totalTokens.end(), generatedTokens.begin(), generatedTokens.end()); return totalTokens; @@ -185,6 +189,14 @@ class SequenceGroupMeta { groupID = sequences[0].getSequenceID(); } + SequenceGroupMeta(std::vector &_inputTokens) { + sequences.reserve(samplingMeta.config.numBeams); + for (int i = 0; i < samplingMeta.config.numBeams; ++i) { + sequences.emplace_back(SequenceMeta(_inputTokens)); + } + groupID = sequences[0].getSequenceID(); + } + int32_t getGroupID() { return groupID; } int32_t getGroupSize() { return samplingMeta.config.numBeams; } @@ -200,6 +212,8 @@ class SequenceGroupMeta { bool isDone() { return samplingMeta.done; } + SamplingMeta *getSamplingMeta() { return &samplingMeta; } + private: // using 1st sequence ID as group ID. 
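+    // e.g. a group built with numBeams == 4 holds 4 SequenceMeta entries, and
+    // groupID == sequences[0].getSequenceID().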
int32_t groupID; @@ -357,6 +371,4 @@ class TaskWaitingQueue { int32_t MaxRequestNum; }; -static std::vector workingGroup; - } // namespace xft \ No newline at end of file diff --git a/src/models/models.cpp b/src/models/models.cpp index b429ec72..c6bbe088 100644 --- a/src/models/models.cpp +++ b/src/models/models.cpp @@ -214,9 +214,9 @@ void Model::set_input(std::vector &inputIds_, int batchSize_, SearcherC inputQueue.push(group); } - xft::workingGroup.clear(); + workingGroup.clear(); while (!inputQueue.empty()) { - xft::workingGroup.push_back(inputQueue.pop()); + workingGroup.push_back(inputQueue.pop()); } } @@ -243,7 +243,7 @@ bool Model::isDone() { } return !isNewInput && searcher->isDone(); } - for (auto x : xft::workingGroup) { + for (auto x : workingGroup) { if (!x->isDone()) { return false; } } return true; @@ -290,6 +290,10 @@ std::vector Model::generate() { } } else { // TODO + std::tuple result = forward(false); + float *outBuf = std::get<0>(result); + int sampleOffset = std::get<1>(result); + int sampleSize = std::get<2>(result); throw std::logic_error("Method not implemented"); return {}; } diff --git a/src/searchers/search_utils.cpp b/src/searchers/search_utils.cpp index eadba188..ea01bf1e 100644 --- a/src/searchers/search_utils.cpp +++ b/src/searchers/search_utils.cpp @@ -15,6 +15,7 @@ #include #include #include +#include "messenger.h" #include "search_utils.h" // Insert an element into a sorted vector while maintaining the order @@ -94,4 +95,78 @@ void stopWordsCheck(std::vector &nextTokenIds, std::vector } } } -} \ No newline at end of file +} + +namespace xft { +// Assume all sequences are all prompts or decodes. +// TODO: support num_beams > 1 (beam search) +void repetitionPenaltyLogitsProcess( + float *logits, int sampleOffset, int sampleSize, std::vector &seqGroups) { + bool multiRank = Messenger::getInstance().getSize() > 1; + + std::vector groupIndex; + // TODO: Num_beam > 1 (beam search) + int batchSize = seqGroups.size(); + + // Assume all seqences are all prompts or decodes. 
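+    // Penalty rule applied below: logit < 0 ? logit * penalty : logit / penalty.
+    // Worked example (penalty = 2.0): logits {0.2, -0.4} on already-seen tokens
+    // become {0.1, -0.8}, discouraging repetition for both signs.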
+ int step = seqGroups[0]->getStep(); + + // For prompts + if (step == 0) { +#pragma omp parallel for + for (int b = 0; b < batchSize; b++) { + if (seqGroups[b]->getSamplingMeta()->config.repetitionPenalty == 1.0) { continue; } + SequenceMeta *seqMeta = seqGroups[b]->get(0); + std::vector &cachedVec = seqGroups[b]->getSamplingMeta()->cachedRepetVec; + cachedVec = seqMeta->getPromptTokens(); + std::sort(cachedVec.begin(), cachedVec.end()); + cachedVec.erase(std::unique(cachedVec.begin(), cachedVec.end()), cachedVec.end()); + + if (multiRank) { + // Get (sampleOffset, sampleOffset + sampleSize) + auto boundBegin = std::upper_bound(cachedVec.begin(), cachedVec.end(), sampleOffset); + auto boundEnd = std::lower_bound(cachedVec.begin(), cachedVec.end(), sampleOffset + sampleSize); + + cachedVec.erase(boundEnd, cachedVec.end()); + cachedVec.erase(cachedVec.begin(), boundBegin); + + std::transform(cachedVec.begin(), cachedVec.end(), cachedVec.begin(), + [sampleOffset](int num) { return num - sampleOffset; }); + } + } + } else { + if (multiRank) { +#pragma omp parallel for + for (int b = 0; b < batchSize; b++) { + if (seqGroups[b]->getSamplingMeta()->config.repetitionPenalty == 1.0) { continue; } + std::vector inputIds = seqGroups[b]->get(0)->getInputTokens(); + for (auto x : inputIds) { + if (x >= sampleOffset && x < sampleOffset + sampleSize) { + insertAndSort(seqGroups[b]->getSamplingMeta()->cachedRepetVec, x - sampleOffset); + } + } + } + } else { +#pragma omp parallel for + for (int b = 0; b < batchSize; b++) { + if (seqGroups[b]->getSamplingMeta()->config.repetitionPenalty == 1.0) { continue; } + std::vector inputIds = seqGroups[b]->get(0)->getInputTokens(); + for (auto x : inputIds) { + insertAndSort(seqGroups[b]->getSamplingMeta()->cachedRepetVec, x); + } + } + } + } + +#pragma omp parallel for + for (int b = 0; b < batchSize; b++) { + if (seqGroups[b]->getSamplingMeta()->config.repetitionPenalty == 1.0) { continue; } + int startLogits = b * sampleSize; + auto &penalty = seqGroups[b]->getSamplingMeta()->config.repetitionPenalty; + for (int index : seqGroups[b]->getSamplingMeta()->cachedRepetVec) { + float &logit = logits[startLogits + index]; + logit = logit < 0 ? logit * penalty : logit / penalty; + } + } +} +} // namespace xft \ No newline at end of file diff --git a/src/searchers/search_utils.h b/src/searchers/search_utils.h index ed8c6b45..9748f518 100644 --- a/src/searchers/search_utils.h +++ b/src/searchers/search_utils.h @@ -13,6 +13,7 @@ // limitations under the License. 
// ============================================================================ #pragma once +#include "sequence.h" // Insert an element into a sorted vector while maintaining the order void insertAndSort(std::vector &targetVector, int num); @@ -22,3 +23,8 @@ void repetitionPenaltyLogitsProcess(float penalty, float *logits, int sampleOffs void stopWordsCheck(std::vector &nextTokenIds, std::vector> &stopWordsList, std::vector> &stopWordsIndex, std::vector &doneBatch); + +namespace xft { +void repetitionPenaltyLogitsProcess( + float *logits, int sampleOffset, int sampleSize, std::vector &seqGroups); +} // namespace xft \ No newline at end of file diff --git a/tests/ut/CMakeLists.txt b/tests/ut/CMakeLists.txt index 940a45d2..2f7a77c5 100644 --- a/tests/ut/CMakeLists.txt +++ b/tests/ut/CMakeLists.txt @@ -75,6 +75,8 @@ foreach(src ${sources}) continue() endif() add_executable(timeline_test ${src}) + elseif(${executable} STREQUAL "repetition_penalty_test") + add_executable(repetition_penalty_test ${src} ${SRC_DIR}/searchers/search_utils.cpp) else() add_executable(${executable} ${src}) endif() diff --git a/tests/ut/repetition_penalty_test.cpp b/tests/ut/repetition_penalty_test.cpp new file mode 100644 index 00000000..40efede6 --- /dev/null +++ b/tests/ut/repetition_penalty_test.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +#include "search_utils.h" +#include "gtest/gtest.h" + +namespace xft { +TEST(RepetitionPenaltyTest, repetitionPenaltyLogitsProcessTest) { + // Test input + float logits[] = {0.2, 0.2, 0.2, 0.2, 0.2}; + int sampleOffset = 0; + int sampleSize = 5; + std::vector seqGroups; + std::vector promptTokens = {0, 2, 1}; + seqGroups.push_back(new SequenceGroupMeta(promptTokens)); + seqGroups[0]->getSamplingMeta()->config.repetitionPenalty = 2; + + // Call the function + repetitionPenaltyLogitsProcess(logits, sampleOffset, sampleSize, seqGroups); + + // Check logits + float expectedLogits_1[] = {0.1, 0.1, 0.1, 0.2, 0.2}; + for (int i = 0; i < sampleSize; i++) { + EXPECT_NEAR(logits[i], expectedLogits_1[i], 0.001); + } + + seqGroups[0]->get(0)->stepForward(3); + + repetitionPenaltyLogitsProcess(logits, sampleOffset, sampleSize, seqGroups); + + // Check logits + float expectedLogits_2[] = {0.05, 0.05, 0.05, 0.1, 0.2}; + for (int i = 0; i < sampleSize; i++) { + EXPECT_NEAR(logits[i], expectedLogits_2[i], 0.001); + } +} +} // namespace xft + +int main(int argc, char **argv) { + srand(time(NULL)); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 8c2e6b4a0b609ecdf7a36c7bb909f15cf2e0c600 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 16/35] [Sampling] Add greedy search for cb path. 
(#376) --- examples/cpp/example.cpp | 13 ++- include/models.h | 2 +- src/common/sequence.h | 2 + src/models/models.cpp | 55 ++++++++++++ src/searchers/sampling.cpp | 155 +++++++++++++++++++++++++++++++++ src/searchers/sampling.h | 25 ++++++ src/searchers/search_utils.cpp | 2 + 7 files changed, 251 insertions(+), 3 deletions(-) create mode 100644 src/searchers/sampling.cpp create mode 100644 src/searchers/sampling.h diff --git a/examples/cpp/example.cpp b/examples/cpp/example.cpp index c3d0cb70..ce8bd1f9 100644 --- a/examples/cpp/example.cpp +++ b/examples/cpp/example.cpp @@ -444,11 +444,20 @@ int main(int argc, char **argv) { for (int i = 0; i < loop; ++i) { secondIdCount = 0; - model.config(/*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1, /*lenPenalty*/ 1.0, + + model.set_input(input, batchSize, /*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1, + /*lenPenalty*/ 1.0, /*doEarlyStopping*/ false, /*eosTokenId*/ -1, /*padTokenId*/ -1, /*doSample*/ doSample, /*temperature*/ temperature, /*topK*/ topK, /*topP*/ topP, /*repetitionPenalty*/ repetitionPenalty); - model.input(input, batchSize); + + // TODO: Deprecated + // Old Path + // model.config(/*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1, /*lenPenalty*/ 1.0, + // /*doEarlyStopping*/ false, /*eosTokenId*/ -1, /*padTokenId*/ -1, + // /*doSample*/ doSample, /*temperature*/ temperature, + // /*topK*/ topK, /*topP*/ topP, /*repetitionPenalty*/ repetitionPenalty); + // model.input(input, batchSize); std::vector firstIds; std::vector secondIds; diff --git a/include/models.h b/include/models.h index 9ae19fca..65683b6d 100644 --- a/include/models.h +++ b/include/models.h @@ -68,7 +68,7 @@ class Model { void setDecoder(AbstractDecoder *dec); - std::vector finalize() { return searcher->finalize(); } + std::vector finalize(); void exitSlaves(); diff --git a/src/common/sequence.h b/src/common/sequence.h index dfaf4924..f34bb08f 100644 --- a/src/common/sequence.h +++ b/src/common/sequence.h @@ -124,6 +124,8 @@ class SequenceMeta { } } + int getTotalLen() const { return promptTokens.size() + generatedTokens.size(); } + int32_t getPastSeqLen() const { return pastSeqLen; } void setPastSeqLen(int32_t _pastSeqLen) { pastSeqLen = _pastSeqLen; } diff --git a/src/models/models.cpp b/src/models/models.cpp index c6bbe088..513705fc 100644 --- a/src/models/models.cpp +++ b/src/models/models.cpp @@ -30,6 +30,8 @@ #include "opt_decoder.h" #include "qwen.h" #include "qwen2.h" +#include "sampling.h" +#include "search_utils.h" #include "searcher.h" #include "sequence.h" #include "timeline.h" @@ -249,6 +251,21 @@ bool Model::isDone() { return true; } +std::vector Model::finalize() { + // TODO: Deprecate the following Path + if (searcher != nullptr) { + return searcher->finalize(); + } else { + std::vector result; + // TODO: Unequal-length input & output + for (auto x : workingGroup) { + std::vector seq = x->get(0)->getTotalTokens(); + result.insert(result.end(), seq.begin(), seq.end()); + } + return result; + } +} + std::tuple Model::forward(bool logits_all) { // TODO: Deprecate the following Path if (searcher != nullptr) { @@ -275,6 +292,8 @@ std::tuple Model::forward(bool logits_all) { return decoder->forward(workingSeqs, logits_all); } +// We assume all gen kwargs in the batch are the same +// and all sequences are all prompts(step==0) or all decodes(step>0) std::vector Model::generate() { // TODO: Deprecate the following Path if (searcher != nullptr) { @@ -294,11 +313,47 @@ std::vector Model::generate() { float 
*outBuf = std::get<0>(result); int sampleOffset = std::get<1>(result); int sampleSize = std::get<2>(result); + + // Assume all gen kwargs in the batch are the same + auto &config = workingGroup[0]->getSamplingMeta()->config; + + if (config.numBeams != 1) { + // TODO: BeamSearch + throw std::logic_error("Beam Search Method not implemented"); + } else { + + // Logits processor + // Repetition penalty + if (config.repetitionPenalty != 1.0) { + repetitionPenaltyLogitsProcess(outBuf, sampleOffset, sampleSize, workingGroup); + } + + std::vector result; + + if (config.doSample) { + //TODO: samling + throw std::logic_error("Sampling Method not implemented"); + } else { + // Greedy search + result = greedySearch(outBuf, sampleOffset, sampleSize, batchSize); + } + + // Check stop status + stopCheck(result, workingGroup); + + // Step forward on all seqs + for (int i = 0; i < workingGroup.size(); i++) { + workingGroup[i]->get(0)->stepForward(result[i]); + } + + return result; + } throw std::logic_error("Method not implemented"); return {}; } } +// TODO: Deprecate the following function void Model::createSearcher(SearcherConfig &config_) { if (searcher != nullptr) { delete searcher; } diff --git a/src/searchers/sampling.cpp b/src/searchers/sampling.cpp new file mode 100644 index 00000000..bfc82714 --- /dev/null +++ b/src/searchers/sampling.cpp @@ -0,0 +1,155 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +#include + +#include "sampling.h" +#include "timeline.h" + +namespace xft { +// Assume all samples have the same sampling params. +std::vector greedySearch(float *logits, int sampleOffset, int sampleSize, int batchSize) { + TimeLine t("GreedySearch"); + + Messenger &messenger = Messenger::getInstance(); + int numThreads = 0; +#pragma omp parallel + { + int tid = omp_get_thread_num(); + if (tid == 0) { numThreads = omp_get_num_threads(); } + } + + auto msgerSize = messenger.getSize(); + + // Max ID and value for each sample + std::vector maxIds(batchSize); + float maxVals[batchSize]; + + // Small batch size (each sample can have at least 2 threads) + if (numThreads / batchSize >= 2) { + int thrPerSample = numThreads / batchSize; + int sizePerThr = (sampleSize + thrPerSample - 1) / thrPerSample; + int maxIndices[batchSize * thrPerSample]; + float maxValues[batchSize * thrPerSample]; + + // TODO: if size is small, possible to cause out of boundary +#pragma omp parallel for collapse(2) + for (int b = 0; b < batchSize; ++b) { + for (int t = 0; t < thrPerSample; ++t) { // thread index inside the sample + int start = t * sizePerThr; + int end = (start + sizePerThr) > sampleSize ? 
sampleSize : (start + sizePerThr); + float *p = logits + b * sampleSize; + + int maxIdx = start; + float maxVal = p[start]; + for (int off = start + 1; off < end; ++off) { + if (p[off] > maxVal) { + maxVal = p[off]; + maxIdx = off; + } + } + + // False sharing happens, but since only one time, not avoided + maxIndices[b * thrPerSample + t] = maxIdx; + maxValues[b * thrPerSample + t] = maxVal; + } + } + + // Local reduction + for (int i = 0; i < batchSize; ++i) { + int *pIndices = maxIndices + i * thrPerSample; + float *pValues = maxValues + i * thrPerSample; + int maxIdx = pIndices[0]; + float maxVal = pValues[0]; + for (int j = 1; j < thrPerSample; ++j) { + if (pValues[j] > maxVal) { + maxVal = pValues[j]; + maxIdx = pIndices[j]; + } + } + maxIds[i] = maxIdx; + maxVals[i] = maxVal; + } + } + + // Each thread handle one sample (one row) + else { +#pragma omp parallel for + for (int i = 0; i < batchSize; ++i) { + int maxId = 0; + float *p = logits + i * sampleSize; + float maxVal = p[0]; + for (int j = 1; j < sampleSize; ++j) { + if (p[j] > maxVal) { + maxVal = p[j]; + maxId = j; + } + } + maxIds[i] = maxId; + maxVals[i] = maxVal; + } + } + + // Reduce to get the max index (any better method??) + if (msgerSize > 1) { + float sendBuf[2 * batchSize]; + float recvBuf[2 * batchSize * msgerSize]; + + for (int i = 0; i < batchSize; ++i) { + sendBuf[2 * i] = (float)(maxIds[i] + sampleOffset); + sendBuf[2 * i + 1] = maxVals[i]; + } + + std::vector recvCount(msgerSize, static_cast(2 * batchSize)); + messenger.allgatherv(sendBuf, 2 * batchSize, recvBuf, recvCount); + + for (int i = 0; i < batchSize; ++i) { + int maxId = (int)(recvBuf[2 * i] + 0.5f); + float maxVal = recvBuf[2 * i + 1]; + for (int j = 1; j < msgerSize; ++j) { + if (recvBuf[2 * j * batchSize + 2 * i + 1] > maxVal) { + maxVal = recvBuf[2 * j * batchSize + 2 * i + 1]; + maxId = (int)(recvBuf[2 * j * batchSize + 2 * i] + 0.5f); + } + } + maxIds[i] = maxId; + } + } + + return maxIds; +} + +// For greedy search and samlping, not for beam search +void stopCheck(std::vector &generatedIds, std::vector &seqGroups) { + int batchSize = generatedIds.size(); +#pragma omp parallel for + for (int b = 0; b < batchSize; b++) { + // TODO: Deprecate this check, since no need for unequal-length output + if (seqGroups[b]->getSamplingMeta()->done) { + generatedIds[b] = seqGroups[b]->getSamplingMeta()->config.eosTokenId; + continue; + } + + // If the generated token is EOS, mark the sequence as done + if (seqGroups[b]->getSamplingMeta()->config.eosTokenId == generatedIds[b]) { + seqGroups[b]->getSamplingMeta()->done = true; + } + // If the sequence meets the max length, mark the sequence as done + else if (seqGroups[b]->get(0)->getTotalLen() + 1 >= seqGroups[b]->getSamplingMeta()->config.maxLen) { + seqGroups[b]->getSamplingMeta()->done = true; + } + // TODO: stop words check + } +} +} // namespace xft \ No newline at end of file diff --git a/src/searchers/sampling.h b/src/searchers/sampling.h new file mode 100644 index 00000000..21b5a438 --- /dev/null +++ b/src/searchers/sampling.h @@ -0,0 +1,25 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +#pragma once + +#include "messenger.h" +#include "sampling_params.h" +#include "sequence.h" + +namespace xft { +std::vector greedySearch(float *logits, int sampleOffset, int sampleSize, int batchSize); + +void stopCheck(std::vector &generatedIds, std::vector &seqGroups); +} // namespace xft \ No newline at end of file diff --git a/src/searchers/search_utils.cpp b/src/searchers/search_utils.cpp index ea01bf1e..686103a2 100644 --- a/src/searchers/search_utils.cpp +++ b/src/searchers/search_utils.cpp @@ -17,6 +17,7 @@ #include #include "messenger.h" #include "search_utils.h" +#include "timeline.h" // Insert an element into a sorted vector while maintaining the order void insertAndSort(std::vector &targetVector, int num) { @@ -102,6 +103,7 @@ namespace xft { // TODO: support num_beams > 1 (beam search) void repetitionPenaltyLogitsProcess( float *logits, int sampleOffset, int sampleSize, std::vector &seqGroups) { + TimeLine t("RepetitionPenaltyLogitsProcess"); bool multiRank = Messenger::getInstance().getSize() > 1; std::vector groupIndex; From aac016764599ed493e35fe79d2414c209c70ed0c Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 17/35] [Model/Layer] New forward to support CB (CommonDecoder->DecoderBlock->DecoderLayer->Attention/MLP) (#375) --- src/layers/attention.h | 214 ++++++++++++++++++++++ src/layers/decoder_block.h | 349 ++++++++++++++++++++++++++++++++++++ src/layers/decoder_layer.h | 8 + src/models/common_decoder.h | 58 +++--- 4 files changed, 594 insertions(+), 35 deletions(-) create mode 100644 src/layers/decoder_block.h diff --git a/src/layers/attention.h b/src/layers/attention.h index 75ddeb5f..7484aa36 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -25,6 +25,7 @@ #include "gemm_kernel_ext.h" #include "kvcache_tensor.h" #include "matmul_helper.h" +#include "sequence.h" #include "simple_mem_pool.h" #include "transformer_ctx.h" #include "transformer_util.h" @@ -397,6 +398,191 @@ class Attention { } } + /** + * Forward computing for the whole Attention layer (QKV MatMul + MHA/GQA + Output MatMul) + */ + template + void forward(DecoderContext *ctx, std::vector &seqs, InT *input, OutT *output, + size_t totInSeqLen, std::vector *> &keyCaches, + std::vector *> &valueCaches, bool doLnBefore = true) { + + auto hiddenSize = ctx->hiddenSize; + xft::Matrix inputBuffer(input, totInSeqLen, hiddenSize, hiddenSize); + ImT *imBuf = (ImT *)ctx->getBuffer("tmp", totInSeqLen * hiddenSize); + xft::Matrix imBuffer(imBuf, totInSeqLen, hiddenSize, hiddenSize); + xft::Matrix outBuffer(output, totInSeqLen, hiddenSize, hiddenSize); + + float epsilon = ctx->epsilon; + int headSize = ctx->attHeadSize; + auto qkvRows = totInSeqLen; + int qCols = (this->endQHead - this->startQHead) * headSize; + int kvCols = (this->endKVHead - this->startKVHead) * headSize; + int qkCols = qCols + kvCols; + int qkvCols = qkCols + kvCols; + + int qkvStride = qkvCols; + auto &qkvMatMul = ctx->qkvMatMul; + xft::Matrix qkvGroupMatMul((ImT *)qkvMatMul.Data(), 
qkvRows, qkvCols, qkvStride); + +#ifdef DEBUG + dbg.debugPrint("---- DecoderLayer.forward (useSelfAttn=%d) ----\n", useSelfAttn); + dbg.debugPrint("input:\n"); + dbg.dumpMatrix(inputBuffer); +#endif + + if (doLnBefore) { + TimeLine t1("input.layer_norm"); + norm.forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), + imBuffer.Stride(), epsilon); + } +#ifdef DEBUG + dbg.debugPrint("layer norm:\n"); + dbg.dumpMatrix(imBuffer); + dbg.debugPrint("qkvWeight [%d, %d]:\n", this->qkvWeight.Rows(), this->qkvWeight.Cols()); + dbg.dumpMatrix(this->qkvWeight); +#endif + + // Query, Key, Value computed together + TimeLine t2("QKV.linear"); + if (qkvBias.Size() == 0) { + ctx->mmHelper->compute(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(), + imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(), + qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride()); + } else { + ctx->mmHelper->compute_bias(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, + imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(), + qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data()); + } + t2.release(); + + xft::Matrix query(qkvGroupMatMul, 0, inputBuffer.Rows(), 0, qCols); + xft::Matrix key(qkvGroupMatMul, 0, inputBuffer.Rows(), qCols, kvCols); + xft::Matrix value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kvCols); + +#ifdef DEBUG + dbg.debugPrint("Q[%d,%d](%d):\n", query.Rows(), query.Cols(), query.Stride()); + dbg.dumpMatrix(query); + dbg.debugPrint("K[%d,%d](%d):\n", key.Rows(), key.Cols(), key.Stride()); + dbg.dumpMatrix(key); + dbg.debugPrint("V[%d,%d](%d):\n", value.Rows(), value.Cols(), value.Stride()); + dbg.dumpMatrix(value); +#endif + + // Apply post operations on query and key + TimeLine t3("QKPO"); + // TODO: call into rotary embedding + // int qheads = this->endQHead - this->startQHead; + // int kheads = this->endKVHead - this->startKVHead; + // int qkShape[7] = {ctx->batchSize, ctx->inputSeqLen, qheads, headSize, kheads, ctx->maxSeqLength, pastSeqLen}; + // if (positionIds != nullptr) { + // qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, positionIds); + // } else if (ctx->maxPosEmbed > 0) { + // // Use the default position ids + // std::vector posIds(ctx->inputSeqLen); + // if (inputSeqLen == 1) { + // posIds[0] = pastSeqLen; + // } else { + // std::iota(posIds.begin(), posIds.end(), pastSeqLen); + // } + // qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, posIds.data()); + // } + t3.release(); + +#ifdef DEBUG + dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride()); + dbg.dumpMatrix(query); + dbg.debugPrint("K[%d,%d](%d) after post op:\n", key.Rows(), key.Cols(), key.Stride()); + dbg.dumpMatrix(key); +#endif + + // Revise attnFactor before softmax (for some models, attnFactor may be not the default value) + // We initially introduced the code for ChatGLM, but eventually found it has no difference and was unnecessary. + // However, we have chosen to keep it in the codebase in case it becomes useful for future models. 
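// [Editor's sketch, not part of this patch] attFactor is the softmax scaling
// in Attention(Q, K, V) = softmax(Q * K^T * attFactor) * V. For most models
// it is simply 1/sqrt(headSize); a hypothetical helper (names assumed, needs
// <cmath>) would be:
//
//     float defaultAttFactor(int headSize) {
//         return 1.0f / std::sqrt(static_cast<float>(headSize));
//     }
//
// getScalingCoeff() below lets a model override this default.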
+ if (getScalingCoeff() != 0) { ctx->attFactor = getScalingCoeff(); } + + TimeLine t4("MHA"); + if constexpr (!INPUT_AS_RESID) { // Swap inputBuffer and imBuffer + auto tmp = imBuffer.Data(); + int rows = imBuffer.Rows(), cols = imBuffer.Cols(), stride = imBuffer.Stride(); + imBuffer.Assign(inputBuffer.Data(), inputBuffer.Rows(), inputBuffer.Cols(), inputBuffer.Stride()); + inputBuffer.Assign(tmp, rows, cols, stride); + } + + // For multiple nodes inference, not the whole result buffer + xft::Matrix attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); + + if (seqs[0]->getStep() == 0) { // First token generation + // TODO: add flashAttention + if constexpr (std::is_same_v && std::is_same_v) { + selfAttentionBF16(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs); + } else { + fusedAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs); + } + } else { + fusedAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs); + } + t4.release(); + +#ifdef DEBUG + dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(), + attnSplit.Cols(), attnSplit.Stride()); + dbg.dumpMatrix(attnSplit); +#endif + + TimeLine t5("Output"); + // Output/projection in attention, only add the input in the first split + if (ctx->splitIdx == 0) { + float gamma = getResidentialScale(); + + // denseWithScaledSum should be enough, but as the performance of denseWithScaledSum is not verified, + // So here still use denseWithSum + if (gamma == 1) { + float *pbias = attnOutputBias.Data(); + if (attnOutputBias.Size() == 0) { pbias = nullptr; } + ctx->mmHelper->compute_residential(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), + 1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), + attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, + outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride()); + } else { + float *pbias = attnOutputBias.Data(); + if (attnOutputBias.Size() == 0) { pbias = nullptr; } + ctx->mmHelper->compute_resext(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, + attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), + attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(), + outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride()); + } + } else { + if (attnOutputBias.Size() == 0) { + ctx->mmHelper->compute(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, + attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), + attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(), + outBuffer.Stride()); + } else { + ctx->mmHelper->compute_bias(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f, + attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(), + attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(), + outBuffer.Stride(), attnOutputBias.Data()); + } + } + t5.release(); + +#ifdef DEBUG + dbg.debugPrint(">>> attention output/projection[%d, %d] (%d):\n", outBuffer.Rows(), outBuffer.Cols(), + outBuffer.Stride()); + dbg.dumpMatrix(outBuffer); +#endif + + if (!doLnBefore) { + TimeLine t6("result.layer_norm"); + norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); +#ifdef DEBUG + 
dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(), + outBuffer.Stride()); + dbg.dumpMatrix(outBuffer); +#endif + } + } + protected: template void selfAttentionBF16(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, @@ -418,6 +604,34 @@ class Attention { [&](int b, int headIdx, int seqIdx) { return presentValue.getSequence(seqIdx, b, headIdx); }); } + template + void selfAttentionBF16(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, + xft::Matrix &value, xft::Matrix &result, + std::vector *> &keyCaches, std::vector *> &valueCaches, + std::vector &seqs) { + int responsibleQHeads = this->endQHead - this->startQHead; + int responsibleKVHeads = this->endKVHead - this->startKVHead; + + int tokenSizes[ctx->batchSize]; + for (int i = 0; i < ctx->batchSize; ++i) { + tokenSizes[i] = seqs[i]->getInputSeqLen(); + } + + xft::selfAttention( + result.Data(), query.Data(), key.Data(), value.Data(), responsibleQHeads, responsibleKVHeads, + ctx->attHeadSize, result.Stride(), query.Stride(), key.Stride(), ctx->batchSize, tokenSizes, + ctx->attFactor, ctx->numThreads, + [&](int b, int headIdx, int seqIdx) { return keyCaches[b]->getSequence(seqIdx, 0, headIdx); }, + [&](int b, int headIdx, int seqIdx) { return valueCaches[b]->getSequence(seqIdx, 0, headIdx); }); + } + + template + void fusedAttention(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, xft::Matrix &value, + xft::Matrix &result, std::vector *> &keyCaches, + std::vector *> &valueCaches, std::vector &seqs) { + // TODO: implement fusedAttention + } + int getMBlockSize(int inputSeqLen, int headSize, int minVal = 6) { // Special case if (inputSeqLen == 1) { return 1; } diff --git a/src/layers/decoder_block.h b/src/layers/decoder_block.h new file mode 100644 index 00000000..0a454ce6 --- /dev/null +++ b/src/layers/decoder_block.h @@ -0,0 +1,349 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +#pragma once +#include +#include "decoder_layer.h" +#include "dtype.h" +#include "kvcache_mgr.h" +#include "messenger.h" +#include "weight_util.h" + +template +class DecoderBlock { +public: + using DECODER = Decoder; + + DecoderBlock(DecoderContext *ctx, const std::string &modelPath, int layers, xft::DataType dt) { + if (layers % ctx->ppSize != 0) { + std::cerr << "Warning: layers cannot be evenly divided by pipeline parallel stage size(ppSize)." 
+ << std::endl; + std::exit(-1); + } + + int layersOnDuty = layers / ctx->ppSize; + int startLayer = ctx->ppRank * layersOnDuty; + for (int i = startLayer; i < startLayer + layersOnDuty; ++i) { + auto pdec = new DECODER(ctx, i); + if (dt == xft::DataType::int8) { + this->setDecoderWeights(ctx, pdec, modelPath, i); + } else if (dt == xft::DataType::int4) { + this->setDecoderWeights(ctx, pdec, modelPath, i); + } else if (dt == xft::DataType::fp32) { + this->setDecoderWeights(ctx, pdec, modelPath, i); + } else { + std::cerr << "Error: The data type is NOT supported." << std::endl; + std::exit(-1); + } + this->decoders.push_back(pdec); + } + } + + virtual ~DecoderBlock() { + for (auto dec : this->decoders) { + delete dec; + } + } + + // To make it compatible with the old impl. + DECODER *get(int layerId) { return this->decoders[layerId]; } + + int size() const { return this->decoders.size(); } + + template + void forward(DecoderContext *ctx, std::vector &seqs, InT *input, OutT *output) { + using AttnOutT = typename AttnTypeExtractor::Tout; + + Messenger &messenger = Messenger::getInstance(); + xft::KVCacheMgr &kvCacheMgr = xft::KVCacheMgr::instance(); + + // Data preparation + std::vector seqIDs(seqs.size()); + size_t totInSeqLen = 0; + for (int i = 0; i < seqs.size(); ++i) { + seqIDs[i] = seqs[i]->getSequenceID(); + totInSeqLen += seqs[i]->getInputSeqLen(); + } + + // TODO: check and prepare KV cache only needed + // kvCacheMgr.prepareCache(seqIDs); + + // All layers forward + int layersOnDuty = this->decoders.size(); + for (int i = 0; i < layersOnDuty; ++i) { + int workers = messenger.getSize(); + + std::vector keyCaches = kvCacheMgr.getKey(i); + std::vector valueCaches = kvCacheMgr.getValue(i); + + std::vector *> keyCachesVec(keyCaches.size()); + std::vector *> valueCachesVec(valueCaches.size()); + + for (int j = 0; j < keyCaches.size(); ++j) { + keyCachesVec[j] = static_cast *>(keyCaches[j]); + } + + for (int j = 0; j < valueCaches.size(); ++j) { + valueCachesVec[j] = static_cast *>(valueCaches[j]); + } + + AttnOutT *attnOut = (AttnOutT *)(ctx->tmpBuf.Data()); + + this->decoders[i]->forwardAttention(ctx, seqs, input, attnOut, totInSeqLen, keyCachesVec, valueCachesVec); + + // Merge the result of attention + // When attention and FFN/MLP are in parallel, do not need to reduce after attention + if constexpr (!ATTN_MLP_PARALLEL) { + if (messenger.getSize() > 1) { messenger.reduceAdd(attnOut, attnOut, totInSeqLen * ctx->hiddenSize); } + } + + // When attention and FFN/MLP are in parallel, use the initial embedding as input + if constexpr (ATTN_MLP_PARALLEL) { + std::cerr << "Error: ATTN_MLP_PARALLEL=true is not supported." 
<< std::endl; + std::exit(-1); + } else { + if (messenger.getSize() > 1) { + this->decoders[i]->forwardFFN(ctx, attnOut, output, ctx->hiddenSize, ctx->hiddenSize, true); + messenger.reduceAdd(output, output, totInSeqLen * ctx->hiddenSize); + } else { + this->decoders[i]->forwardFFN(ctx, attnOut, output, ctx->hiddenSize, ctx->hiddenSize, true); + } + } + } + } + +private: + static bool fileExists(const std::string &filename) { + std::ifstream file(filename); + return file.good(); + } + + // OriWeiT: float, int8_t or uint4x2_t + template + void setDecoderWeights(DecoderContext *ctx, DECODER *pdecoder, const std::string &modelPath, int layerIdx) { + using xft::DataType; + using xft::loadWeight; + + const int hiddenSize = ctx->hiddenSize; + const int imSize = ctx->intermediateSize; + const int kvHeadNum = ctx->kvHeadNum; + const int attHeadNum = ctx->attHeadNum; + const int attHeadSize = ctx->attHeadSize; + const int mlpFactor = (ctx->actType == DecoderContext::SWIGLU) ? 2 : 1; + int qSize = attHeadSize * attHeadNum; + int kvSize = attHeadSize * kvHeadNum; + int qkvSize = qSize + 2 * kvSize; + +#define ALLOC(size, alignment) xft::alloc((size), (alignment)) + OriWeiT *qkvWeight = (OriWeiT *)ALLOC(hiddenSize * qkvSize * sizeof(OriWeiT), 64); + float *qkvScales = nullptr; + float *qkvZeros = nullptr; + float *qkvBias = (float *)ALLOC(qkvSize * sizeof(float), 64); + + OriWeiT *attnOutWeight = (OriWeiT *)ALLOC(qSize * hiddenSize * sizeof(OriWeiT), 64); + float *attnOutScales = nullptr; + float *attnOutZeros = nullptr; + float *attnOutBias = (float *)ALLOC(hiddenSize * sizeof(float), 64); + + OriWeiT *fc1Weight = (OriWeiT *)ALLOC(hiddenSize * imSize * mlpFactor * sizeof(OriWeiT), 64); + float *fc1Scales = nullptr; + float *fc1Zeros = nullptr; + float *fc1Bias = (float *)ALLOC(imSize * sizeof(float), 64); + + OriWeiT *fc2Weight = (OriWeiT *)ALLOC(hiddenSize * imSize * sizeof(OriWeiT), 64); + float *fc2Scales = nullptr; + float *fc2Zeros = nullptr; + float *fc2Bias = (float *)ALLOC(hiddenSize * sizeof(float), 64); + + float *ln1Gamma = (float *)ALLOC(hiddenSize * sizeof(float), 64); + float *ln1Beta = (float *)ALLOC(hiddenSize * sizeof(float), 64); + float *ln2Gamma = (float *)ALLOC(hiddenSize * sizeof(float), 64); + float *ln2Beta = (float *)ALLOC(hiddenSize * sizeof(float), 64); + + OriWeiT *fc3Weight = nullptr; + float *fc3Scales = nullptr; + float *fc3Zeros = nullptr; + + // INT8/INT4 quant, wbits = 8/4, qweight dtype: int8_t/uint4x2_t + if constexpr (std::is_same_v || std::is_same_v) { + DataType dt = std::is_same_v ? DataType::int8 : DataType::int4; + + qkvZeros = (float *)ALLOC(qkvSize * sizeof(float), 64); + qkvScales = (float *)ALLOC(qkvSize * sizeof(float), 64); + attnOutZeros = (float *)ALLOC(hiddenSize * sizeof(float), 64); + attnOutScales = (float *)ALLOC(hiddenSize * sizeof(float), 64); + fc1Zeros = (float *)ALLOC(imSize * mlpFactor * sizeof(float), 64); + fc1Scales = (float *)ALLOC(imSize * mlpFactor * sizeof(float), 64); + fc2Zeros = (float *)ALLOC(imSize * sizeof(float), 64); + fc2Scales = (float *)ALLOC(imSize * sizeof(float), 64); + + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + + ".attention.query_key_value.qweight.0.bin", + qkvWeight, hiddenSize * qkvSize, dt); + loadWeight( + modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.query_key_value.zeros.0.bin", + qkvZeros, qkvSize, DataType::fp32); + loadWeight( + modelPath + "/model.layers." 
+ std::to_string(layerIdx) + ".attention.query_key_value.scales.0.bin", + qkvScales, qkvSize, DataType::fp32); + + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.qweight.0.bin", + attnOutWeight, qSize * hiddenSize, dt); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.zeros.0.bin", + attnOutZeros, hiddenSize, DataType::fp32); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.scales.0.bin", + attnOutScales, hiddenSize, DataType::fp32); + + // Stardard 2 layer MLP + if (fileExists( + modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.qweight.0.bin")) { + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.qweight.0.bin", + fc1Weight, hiddenSize * imSize * mlpFactor, dt); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.zeros.0.bin", + fc1Zeros, imSize * mlpFactor, DataType::fp32); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.scales.0.bin", + fc1Scales, imSize * mlpFactor, DataType::fp32); + + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.qweight.0.bin", + fc2Weight, hiddenSize * imSize, dt); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.zeros.0.bin", + fc2Zeros, hiddenSize, DataType::fp32); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.scales.0.bin", + fc2Scales, hiddenSize, DataType::fp32); + } + // gate, up, down weights for Llama like model + else { + fc3Weight = (OriWeiT *)ALLOC(hiddenSize * imSize * sizeof(OriWeiT), 64); + fc3Zeros = (float *)ALLOC(hiddenSize * sizeof(float), 64); + fc3Scales = (float *)ALLOC(hiddenSize * sizeof(float), 64); + + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.qweight.0.bin", + fc1Weight, hiddenSize * imSize * mlpFactor, dt); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.zeros.0.bin", + fc1Zeros, imSize * mlpFactor, DataType::fp32); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.scales.0.bin", + fc1Scales, imSize * mlpFactor, DataType::fp32); + + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.qweight.0.bin", + fc2Weight, hiddenSize * imSize, dt); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.zeros.0.bin", + fc2Zeros, imSize, DataType::fp32); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.scales.0.bin", + fc2Scales, imSize, DataType::fp32); + + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.qweight.0.bin", + fc3Weight, hiddenSize * imSize, dt); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.zeros.0.bin", + fc3Zeros, hiddenSize, DataType::fp32); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.scales.0.bin", + fc3Scales, hiddenSize, DataType::fp32); + } + + } else if constexpr (std::is_same_v) { + loadWeight( + modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.query_key_value.weight.0.bin", + qkvWeight, hiddenSize * qkvSize); + loadWeight(modelPath + "/model.layers." 
+ std::to_string(layerIdx) + ".attention.dense.weight.0.bin", + attnOutWeight, qSize * hiddenSize); + + // Stardard 2 layer MLP + if (fileExists( + modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.weight.0.bin")) { + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.weight.0.bin", + fc1Weight, hiddenSize * imSize * mlpFactor); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.weight.0.bin", + fc2Weight, hiddenSize * imSize); + } + // gate, up, down weights for Llama like model + else { + fc3Weight = (OriWeiT *)ALLOC(hiddenSize * imSize * sizeof(OriWeiT), 64); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.weight.0.bin", + fc1Weight, hiddenSize * imSize * mlpFactor); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.weight.0.bin", + fc2Weight, hiddenSize * imSize); + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.weight.0.bin", + fc3Weight, hiddenSize * imSize); + } + } + + loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".input_layernorm.weight.bin", + ln1Gamma, hiddenSize); + loadWeight( + modelPath + "/model.layers." + std::to_string(layerIdx) + ".post_attention_layernorm.weight.bin", + ln2Gamma, hiddenSize); + +#define READ_OPTIONAL(filename, addr, size, errmsg) \ + { \ + int ret = loadWeight((filename), (addr), (size), DataType::unknown, false); \ + if (ret == 0) { \ + free(addr); \ + addr = nullptr; \ + } else { \ + if (ret != (size)) { \ + printf("%s\n", (errmsg)); \ + exit(-1); \ + } \ + } \ + } + + // The bias is optional + READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.query_key_value.bias.0.bin", + qkvBias, qkvSize, "read QKV bias error"); + READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.bias.bin", + attnOutBias, hiddenSize, "read attn dense bias error"); + READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".input_layernorm.bias.bin", ln1Beta, + hiddenSize, "read LN1 beta error"); + READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".post_attention_layernorm.bias.bin", + ln2Beta, hiddenSize, "read LN2 beta error"); + READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.bias.0.bin", + fc1Bias, imSize, "read FC1 bias error"); + READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.bias.bin", fc2Bias, + hiddenSize, "read FC2 bias error"); + + constexpr int sizeFactor = std::is_same_v ? 
2 : 1; + pdecoder->setWeights(ctx, qkvWeight, qkvScales, qkvZeros, qkvBias, qkvWeight + qSize / sizeFactor, + qkvScales + qSize, qkvZeros + qSize, qkvBias + qSize, + qkvWeight + qSize / sizeFactor + kvSize / sizeFactor, qkvScales + qSize + kvSize, + qkvZeros + qSize + kvSize, qkvBias + qSize + kvSize, attnOutWeight, attnOutScales, attnOutZeros, + attnOutBias, ln1Gamma, ln1Beta, fc1Weight, fc1Scales, fc1Zeros, fc1Bias, fc2Weight, fc2Scales, fc2Zeros, + fc2Bias, ln2Gamma, ln2Beta, fc3Weight, fc3Scales, fc3Zeros, false); + + free(qkvWeight); + free(attnOutWeight); + free(fc1Weight); + free(fc2Weight); + free(fc3Weight); + free(qkvZeros); + free(attnOutZeros); + free(fc1Zeros); + free(fc2Zeros); + free(fc3Zeros); + free(qkvScales); + free(attnOutScales); + free(fc1Scales); + free(fc2Scales); + free(fc3Scales); + free(qkvBias); + free(attnOutBias); + free(fc1Bias); + free(fc2Bias); + free(ln1Gamma); + free(ln1Beta); + free(ln2Gamma); + free(ln2Beta); + } + +private: + std::vector decoders; +}; \ No newline at end of file diff --git a/src/layers/decoder_layer.h b/src/layers/decoder_layer.h index c86730ee..8134c594 100644 --- a/src/layers/decoder_layer.h +++ b/src/layers/decoder_layer.h @@ -105,6 +105,14 @@ class Decoder { useSelfAttn, doLnBefore, false, positionIds); } + template + void forwardAttention(DecoderContext *ctx, std::vector &seqs, InT *input, OutT *output, + size_t totInSeqLen, std::vector *> &keyCaches, + std::vector *> &valueCaches) { + TimeLine t("Decoder.forwardAttention"); + attn.forward(ctx, seqs, input, output, totInSeqLen, keyCaches, valueCaches); + } + template void forwardFFN(DecoderContext *ctx, InT *input, OutT *output, int iStride, int oStride, bool doLnBefore = true) { TimeLine t("Decoder.forwardFFN"); diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index dbe703ac..177a90cb 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -23,6 +23,7 @@ #include "abstract_decoder.h" #include "attention.h" #include "debugger.h" +#include "decoder_block.h" #include "decoder_layer.h" #include "dist_linear.h" #include "dtype.h" @@ -31,6 +32,7 @@ #include "mlp_chatglm2.h" #include "mlp_standard.h" #include "model_factory.h" +#include "sequence.h" #include "timeline.h" #include "transformer_ctx.h" #include "transpose_util.h" @@ -236,19 +238,7 @@ class CommonDecoder : public AbstractDecoder { std::exit(-1); } - int layers_per_pp_stage = layers / ctx->ppSize; - int start_layer = ctx->ppRank * layers_per_pp_stage; - for (int i = start_layer; i < start_layer + layers_per_pp_stage; ++i) { - auto pdec = new DECODER(ctx, i); - if (dt == DataType::int8) { - this->setDecoderWeights(pdec, modelPath, i); - } else if (dt == DataType::int4) { - this->setDecoderWeights(pdec, modelPath, i); - } else if (dt == DataType::fp32) { - this->setDecoderWeights(pdec, modelPath, i); - } - this->decoders.push_back(pdec); - } + decoderBlock = new DecoderBlock(ctx, modelPath, layers, dt); // Predictor int workers = messenger.getSize(); @@ -264,11 +254,8 @@ class CommonDecoder : public AbstractDecoder { if (this->inputTokens) free(this->inputTokens); if (this->attnMask) free(this->attnMask); + delete this->decoderBlock; delete this->predictor; - - for (auto dec : this->decoders) { - delete dec; - } } std::tuple forward(int *ids, int64_t *dims, int step, bool logitsAll = false) { @@ -381,7 +368,7 @@ class CommonDecoder : public AbstractDecoder { #endif // Decoder: forward - int layers_per_pp_stage = this->decoders.size(); + int layers_per_pp_stage = decoderBlock->size(); 
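// [Editor's sketch, not part of this patch] With pipeline parallelism, each
// rank owns a contiguous slice of the layers, so decoderBlock->size() above
// is layers / ppSize rather than the full layer count. The index mapping,
// written out with assumed variable names:
//
//     int layersOnDuty = totalLayers / ppSize;   // layers per pipeline stage
//     int firstLayer   = ppRank * layersOnDuty;  // first global layer on this rank
//     int globalLayer  = firstLayer + i;         // local index i -> global layer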
for (int i = 0; i < layers_per_pp_stage; ++i) { int workers = this->messenger.getSize(); if (step == 0 && this->prefixSharing) { @@ -394,7 +381,7 @@ class CommonDecoder : public AbstractDecoder { // Pls be noted: in attention, 'outBuf' is used as imtermediate buffer, 'tmpBuf' is used as output AttnOutT *attnOut = (AttnOutT *)(this->getContext()->tmpBuf.Data()); // attnMeta (inputSeqLens, pastSeqLens, seqStartLoc, is_prompt(useSelfAttn), causal, attnMask) - this->decoders[i]->forwardAttention(getContext(), embBuf, outBuf, attnOut, attnMask, + decoderBlock->get(i)->forwardAttention(getContext(), embBuf, outBuf, attnOut, attnMask, presentKey, // presentKey, presentValue, // presentValue, inputSeqLen, // inputSeqLen, @@ -417,18 +404,18 @@ class CommonDecoder : public AbstractDecoder { // When attention and FFN/MLP are in parallel, use the initial embedding as input if constexpr (ATTN_MLP_PARALLEL) { if (this->messenger.getSize() > 1) { - this->decoders[i]->forwardFFN(getContext(), embBuf, outBuf, hiddenSize, hiddenSize, true); + decoderBlock->get(i)->forwardFFN(getContext(), embBuf, outBuf, hiddenSize, hiddenSize, true); this->messenger.reduceAdd(outBuf, embBuf, batchSize * inputSeqLen * hiddenSize); } else { - this->decoders[i]->forwardFFN(getContext(), embBuf, embBuf, hiddenSize, hiddenSize, true); + decoderBlock->get(i)->forwardFFN(getContext(), embBuf, embBuf, hiddenSize, hiddenSize, true); } } else { // FFN (for multiple workers, output into outBuf and then reduce add to embBuf) if (this->messenger.getSize() > 1) { - this->decoders[i]->forwardFFN(getContext(), attnOut, outBuf, hiddenSize, hiddenSize, true); + decoderBlock->get(i)->forwardFFN(getContext(), attnOut, outBuf, hiddenSize, hiddenSize, true); this->messenger.reduceAdd(outBuf, embBuf, batchSize * inputSeqLen * hiddenSize); } else { - this->decoders[i]->forwardFFN(getContext(), attnOut, embBuf, hiddenSize, hiddenSize, true); + decoderBlock->get(i)->forwardFFN(getContext(), attnOut, embBuf, hiddenSize, hiddenSize, true); } } } @@ -552,7 +539,8 @@ class CommonDecoder : public AbstractDecoder { // Embedding this->embeddingForward(allInputIds.data(), embBuf, totInputSeqLen); - // TODO: Decoder layers + // Decoder block (all layers) + decoderBlock->forward(ctx, seqs, embBuf, outBuf); // Prepare input for final Layer Norm (only care about the last row of the result) // Shape of embBuf: (bs, seqLen, hiddenSize) @@ -629,14 +617,14 @@ class CommonDecoder : public AbstractDecoder { // Decoder: forward // TODO: Add PIPELINE_PARALLEL feature int hiddenSize = ctx->hiddenSize; - for (int i = 0; i < this->decoders.size(); ++i) { + for (int i = 0; i < this->decoderBlock->size(); ++i) { int workers = this->messenger.getSize(); KVCacheTensor &presentKey = this->kvCacheMgr->getPrefixKey(i); KVCacheTensor &presentValue = this->kvCacheMgr->getPrefixValue(i); // Pls be noted: in attention, 'outBuf' is used as imtermediate buffer, 'tmpBuf' is used as output AttnOutT *attnOut = (AttnOutT *)(this->getContext()->tmpBuf.Data()); - this->decoders[i]->forwardAttention(getContext(), embBuf, outBuf, attnOut, attnMask, + decoderBlock->get(i)->forwardAttention(getContext(), embBuf, outBuf, attnOut, attnMask, presentKey, // presentKey, presentValue, // presentValue, seqLen, // inputSeqLen, @@ -654,18 +642,18 @@ class CommonDecoder : public AbstractDecoder { // When attention and FFN/MLP are in parallel, use the initial embedding as input if constexpr (ATTN_MLP_PARALLEL) { if (this->messenger.getSize() > 1) { - this->decoders[i]->forwardFFN(getContext(), embBuf, 
outBuf, hiddenSize, hiddenSize, true); + decoderBlock->get(i)->forwardFFN(getContext(), embBuf, outBuf, hiddenSize, hiddenSize, true); this->messenger.reduceAdd(outBuf, embBuf, seqLen * hiddenSize); } else { - this->decoders[i]->forwardFFN(getContext(), embBuf, embBuf, hiddenSize, hiddenSize, true); + decoderBlock->get(i)->forwardFFN(getContext(), embBuf, embBuf, hiddenSize, hiddenSize, true); } } else { // FFN (for multiple workers, output into outBuf and then reduce add to embBuf) if (this->messenger.getSize() > 1) { - this->decoders[i]->forwardFFN(getContext(), attnOut, outBuf, hiddenSize, hiddenSize, true); + decoderBlock->get(i)->forwardFFN(getContext(), attnOut, outBuf, hiddenSize, hiddenSize, true); this->messenger.reduceAdd(outBuf, embBuf, seqLen * hiddenSize); } else { - this->decoders[i]->forwardFFN(getContext(), attnOut, embBuf, hiddenSize, hiddenSize, true); + decoderBlock->get(i)->forwardFFN(getContext(), attnOut, embBuf, hiddenSize, hiddenSize, true); } } } @@ -677,8 +665,8 @@ class CommonDecoder : public AbstractDecoder { // Get decoder context DecoderContext *getContext() { return context.get(); } - // How many layers - int getLayers() { return decoders.size(); } + // How many layers on Duty + int getLayers() { return decoderBlock->size(); } Messenger &getMessenger() { return messenger; } @@ -984,7 +972,7 @@ class CommonDecoder : public AbstractDecoder { int seqLen = ctx->inputSeqLen; int vocabSize = ctx->vocabSize; int maxPositions = ctx->maxPositions; - int layers = this->decoders.size(); + int layers = this->decoderBlock->size(); int workers = this->messenger.getSize(); // Prepare buffers @@ -1074,8 +1062,8 @@ class CommonDecoder : public AbstractDecoder { std::shared_ptr> actBuffers; protected: - // Components most LLMs may use - std::vector decoders; + // Decoder block (all decoder layers) + DecoderBlock *decoderBlock; using LinearWeiT = typename std::conditional, bfloat16_t, float16_t>::type; DistLinear *predictor; From 0e35c8fb45abeb89c99b4768a4d212221ed8d115 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 18/35] [Model] Return seqIDs when set input. 
(#377) --- include/models.h | 15 +++- src/models/model_factory.h | 22 +++--- src/models/models.cpp | 142 ++++++++++++++++++++++++++++++++++--- 3 files changed, 158 insertions(+), 21 deletions(-) diff --git a/include/models.h b/include/models.h index 65683b6d..f63512b4 100644 --- a/include/models.h +++ b/include/models.h @@ -36,18 +36,29 @@ class Model { void config(SearcherConfig &config_, const std::vector> &stopWordsList_ = {}); - void set_input(std::vector &inputIds_, int batchSize_, int maxLen_ = -1, int numBeams_ = 1, + // Return the sequences' IDs in the order of the input batch + std::vector set_input(std::vector &inputIds_, int batchSize_, int maxLen_ = -1, int numBeams_ = 1, int numBeamHypsToKeep_ = 1, float lenPenalty_ = 1.0, bool doEarlyStopping_ = false, int eosTokenId_ = -1, int padTokenId_ = -1, bool doSample_ = false, float temperature_ = 1.0, int topK_ = 50, float topP_ = 1.0, float repetitionPenalty_ = 1.0, const std::vector> &stopWordsList_ = {}); - void set_input(std::vector &inputIds_, int batchSize_, SearcherConfig &config_, + std::vector set_input(std::vector &inputIds_, int batchSize_, SearcherConfig &config_, const std::vector> &stopWordsList_ = {}); + std::vector set_input(std::vector> &inputIds_, SearcherConfig &config_, + const std::vector> &stopWordsList_ = {}); + + std::vector set_input(std::vector> &inputIds_, int maxLen_ = -1, int numBeams_ = 1, + int numBeamHypsToKeep_ = 1, float lenPenalty_ = 1.0, bool doEarlyStopping_ = false, int eosTokenId_ = -1, + int padTokenId_ = -1, bool doSample_ = false, float temperature_ = 1.0, int topK_ = 50, float topP_ = 1.0, + float repetitionPenalty_ = 1.0, const std::vector> &stopWordsList_ = {}); + bool isDone(); std::tuple forward(bool logits_all = true); + std::tuple forward(const std::vector &seqIDs, bool logits_all = true); + std::vector generate(); void createSearcher(SearcherConfig &config_); diff --git a/src/models/model_factory.h b/src/models/model_factory.h index 2730347d..5407d70a 100644 --- a/src/models/model_factory.h +++ b/src/models/model_factory.h @@ -60,31 +60,31 @@ class DecoderRegister { static DecoderRegister decoder_##CLASS##_##T##_##CacheT( \ #NAME "-" #T "-" #CacheT, [](const std::string &modelPath) { return new CLASS(modelPath); }); -#define REGISTER_HYBRID_MODEL(CLASS, NAME, T1, T2, CacheT) \ - static DecoderRegister hybridModel_##CLASS##_##T1##_##T2##_##CacheT(#NAME "-" #T1 "-" #T2 "-" #CacheT, \ +#define REGISTER_HYBRID_MODEL(CLASS, NAME, T1, T2, CacheT) \ + static DecoderRegister hybridModel_##CLASS##_##T1##_##T2##_##CacheT(#NAME "-" #T1 "-" #T2 "-" #CacheT, \ [](const std::string &modelPath) { return new HybridModel(modelPath); }); #define DECODER_ALL_CACHETYPE(KIND, CLASS, NAME, T) \ - KIND##_DECODER(CLASS, NAME, T, float) \ - KIND##_DECODER(CLASS, NAME, T, float16_t) \ + KIND##_DECODER(CLASS, NAME, T, float) \ + KIND##_DECODER(CLASS, NAME, T, float16_t) \ KIND##_DECODER(CLASS, NAME, T, int8_t) #define HYBRID_MODEL_ALL_CACHETYPE(KIND, CLASS, NAME, T1, T2) \ - KIND##_HYBRID_MODEL(CLASS, NAME, T1, T2, float) \ - KIND##_HYBRID_MODEL(CLASS, NAME, T1, T2, float16_t) \ + KIND##_HYBRID_MODEL(CLASS, NAME, T1, T2, float) \ + KIND##_HYBRID_MODEL(CLASS, NAME, T1, T2, float16_t) \ KIND##_HYBRID_MODEL(CLASS, NAME, T1, T2, int8_t) // Kernels in BF16 PATH not support FP32 KVCache #define DECODER_ALL_TYPE(KIND, CLASS, NAME) \ - KIND##_DECODER(CLASS, NAME, bfloat16_t, float16_t) \ - KIND##_DECODER(CLASS, NAME, bfloat16_t, int8_t) \ + KIND##_DECODER(CLASS, NAME, bfloat16_t, float16_t) \ + KIND##_DECODER(CLASS, 
NAME, bfloat16_t, int8_t) \ DECODER_ALL_CACHETYPE(KIND, CLASS, NAME, float16_t) \ DECODER_ALL_CACHETYPE(KIND, CLASS, NAME, int8_t) \ DECODER_ALL_CACHETYPE(KIND, CLASS, NAME, w8a8_t) \ DECODER_ALL_CACHETYPE(KIND, CLASS, NAME, uint4x2_t) \ DECODER_ALL_CACHETYPE(KIND, CLASS, NAME, nf4x2_t) -#define HYBRID_MODEL_ALL_TYPE(KIND, CLASS, NAME) \ +#define HYBRID_MODEL_ALL_TYPE(KIND, CLASS, NAME) \ KIND##_HYBRID_MODEL(CLASS, NAME, bfloat16_t, float16_t, float16_t) \ KIND##_HYBRID_MODEL(CLASS, NAME, bfloat16_t, int8_t, float16_t) \ KIND##_HYBRID_MODEL(CLASS, NAME, bfloat16_t, w8a8_t, float16_t) \ @@ -95,8 +95,8 @@ class DecoderRegister { KIND##_HYBRID_MODEL(CLASS, NAME, bfloat16_t, w8a8_t, int8_t) \ KIND##_HYBRID_MODEL(CLASS, NAME, bfloat16_t, uint4x2_t, int8_t) \ KIND##_HYBRID_MODEL(CLASS, NAME, bfloat16_t, nf4x2_t, int8_t) \ - HYBRID_MODEL_ALL_CACHETYPE(KIND, CLASS, NAME, w8a8_t, int8_t) \ - HYBRID_MODEL_ALL_CACHETYPE(KIND, CLASS, NAME, w8a8_t, uint4x2_t) \ + HYBRID_MODEL_ALL_CACHETYPE(KIND, CLASS, NAME, w8a8_t, int8_t) \ + HYBRID_MODEL_ALL_CACHETYPE(KIND, CLASS, NAME, w8a8_t, uint4x2_t) \ HYBRID_MODEL_ALL_CACHETYPE(KIND, CLASS, NAME, w8a8_t, nf4x2_t) // Please implement the model in your header file; diff --git a/src/models/models.cpp b/src/models/models.cpp index 513705fc..326b96c2 100644 --- a/src/models/models.cpp +++ b/src/models/models.cpp @@ -26,6 +26,7 @@ #include "datatypes.h" #include "gemma.h" #include "hybrid_model.h" +#include "kvcache_mgr.h" #include "llama.h" #include "opt_decoder.h" #include "qwen.h" @@ -75,6 +76,7 @@ void Model::exitSlaves() { // TODO: deprecate the following function void Model::input(std::vector &inputIds_, int batchSize_) { + // TODO: remove new_input flag isNewInput = true; Messenger &messenger = decoder->getMessenger(); int dims[2]; @@ -156,7 +158,7 @@ void syncStopWordsList(std::vector> &stopWordsList) { } } -void Model::set_input(std::vector &inputIds_, int batchSize_, int maxLen_, int numBeams_, +std::vector Model::set_input(std::vector &inputIds_, int batchSize_, int maxLen_, int numBeams_, int numBeamHypsToKeep_, float lenPenalty_, bool doEarlyStopping_, int eosTokenId_, int padTokenId_, bool doSample_, float temperature_, int topK_, float topP_, float repetitionPenalty_, const std::vector> &stopWordsList_) { @@ -173,12 +175,11 @@ void Model::set_input(std::vector &inputIds_, int batchSize_, int maxLe configuration.topP = topP_; configuration.repetitionPenalty = repetitionPenalty_; - this->set_input(inputIds_, batchSize_, configuration, stopWordsList_); + return this->set_input(inputIds_, batchSize_, configuration, stopWordsList_); } -void Model::set_input(std::vector &inputIds_, int batchSize_, SearcherConfig &config_, +std::vector Model::set_input(std::vector &inputIds_, int batchSize_, SearcherConfig &config_, const std::vector> &stopWordsList_) { - // TODO: remove new_input flag if (config_.eosTokenId == -1) { config_.eosTokenId = decoder->getEndId(); } if (config_.padTokenId == -1) { config_.padTokenId = config_.eosTokenId; } SamplingMeta samplingMeta(config_, stopWordsList_); @@ -209,17 +210,110 @@ void Model::set_input(std::vector &inputIds_, int batchSize_, SearcherC seqLen = inputIds_.size() / batchSize_; } + std::vector seqIDs; + SequencePool &seqPool = SequencePool::getInstance(); - InputQueue &inputQueue = InputQueue::getInstance(); + KVCacheMgr &kvCacheMgr = KVCacheMgr::instance(); + workingGroup.clear(); for (int i = 0; i < batchSize; i++) { auto group = seqPool.newGroupMeta(inputIds, samplingMeta); - inputQueue.push(group); + 
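// [Editor's sketch, not part of this patch] A caller-side flow for the
// sequence IDs returned by set_input(), assuming an already-constructed
// `model` and the batched overload added here (names and token values are
// illustrative only):
//
//     std::vector<std::vector<int32_t>> prompts = {{1, 2, 3}, {4, 5, 6, 7}};
//     SearcherConfig cfg = model.getConfig();
//     std::vector<int> ids = model.set_input(prompts, cfg);
//     while (!model.isDone()) {
//         auto nextTokens = model.generate();  // one step for all working groups
//     }
//     auto outputs = model.finalize();         // also releases the KV cache entries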
workingGroup.push_back(group); + seqIDs.push_back(group->getGroupID()); + // TODO: inin KVCache for beamsearch + kvCacheMgr.addSequence(group->getGroupID()); } + return seqIDs; +} + +std::vector Model::set_input(std::vector> &inputIds_, int maxLen_, int numBeams_, + int numBeamHypsToKeep_, float lenPenalty_, bool doEarlyStopping_, int eosTokenId_, int padTokenId_, + bool doSample_, float temperature_, int topK_, float topP_, float repetitionPenalty_, + const std::vector> &stopWordsList_) { + configuration.maxLen = maxLen_; + configuration.numBeams = numBeams_; + configuration.numBeamHypsToKeep = numBeamHypsToKeep_; + configuration.lenPenalty = lenPenalty_; + configuration.doEarlyStopping = doEarlyStopping_; + configuration.eosTokenId = eosTokenId_; + configuration.padTokenId = padTokenId_; + configuration.doSample = doSample_; + configuration.temperature = temperature_; + configuration.topK = topK_; + configuration.topP = topP_; + configuration.repetitionPenalty = repetitionPenalty_; + + return this->set_input(inputIds_, configuration, stopWordsList_); +} + +std::vector Model::set_input(std::vector> &inputIds_, SearcherConfig &config_, + const std::vector> &stopWordsList_) { + if (config_.eosTokenId == -1) { config_.eosTokenId = decoder->getEndId(); } + if (config_.padTokenId == -1) { config_.padTokenId = config_.eosTokenId; } + SamplingMeta samplingMeta(config_, stopWordsList_); + + Messenger &messenger = Messenger::getInstance(); + + batchSize = inputIds_.size(); + + std::vector seqIDs; + SequencePool &seqPool = SequencePool::getInstance(); + KVCacheMgr &kvCacheMgr = KVCacheMgr::instance(); workingGroup.clear(); - while (!inputQueue.empty()) { - workingGroup.push_back(inputQueue.pop()); + + // Sync input and sampling param in distributed mode. + if (messenger.getSize() > 1) { + // [batch size, inputIds size] + std::vector seqLens; + int dims[2]; + if (isMaster()) { + inputIds.clear(); + for (auto &ids : inputIds_) { + seqLens.push_back(ids.size()); + inputIds.insert(inputIds.end(), ids.begin(), ids.end()); + } + dims[0] = batchSize; + dims[1] = inputIds.size(); + } + + messenger.broadcast(dims, 2); + batchSize = dims[0]; + + inputIds.resize(dims[1]); + + messenger.broadcast(seqLens.data(), batchSize); + messenger.broadcast(inputIds.data(), dims[1]); + + messenger.broadcast((int *)&samplingMeta.config, sizeof(SearcherConfig) / sizeof(int)); + + syncStopWordsList(samplingMeta.stopWordsList); + + if (!isMaster()) { + auto it = inputIds.begin(); + for (int i = 0; i < batchSize; i++) { + std::vector input_(it, it + seqLens[i]); + auto group = seqPool.newGroupMeta(input_, samplingMeta); + workingGroup.push_back(group); + seqIDs.push_back(group->getGroupID()); + // TODO: inin KVCache for beamsearch + kvCacheMgr.addSequence(group->getGroupID()); + + it += seqLens[i]; + } + + return seqIDs; + } + } + + for (int i = 0; i < batchSize; i++) { + auto group = seqPool.newGroupMeta(inputIds, samplingMeta); + workingGroup.push_back(group); + seqIDs.push_back(group->getGroupID()); + // TODO: inin KVCache for beamsearch + kvCacheMgr.addSequence(group->getGroupID()); } + + return seqIDs; } // TODO: Deprecate the following function @@ -262,6 +356,13 @@ std::vector Model::finalize() { std::vector seq = x->get(0)->getTotalTokens(); result.insert(result.end(), seq.begin(), seq.end()); } + // Clear KVCache + KVCacheMgr &kvCacheMgr = KVCacheMgr::instance(); + for (auto x : workingGroup) { + kvCacheMgr.delSequence(x->getGroupID()); + } + workingGroup.clear(); + return result; } } @@ -292,6 +393,31 @@ std::tuple 
Model::forward(bool logits_all) { return decoder->forward(workingSeqs, logits_all); } +std::tuple Model::forward(const std::vector &seqIDs, bool logits_all) { + // TODO:Sync IDs in distributed mode. + // Assume that all sequences in the group are all prompts or all decodes. + // Prepare input data for the decoder. + SequencePool &seqPool = SequencePool::getInstance(); + std::vector workingSeqs; + for (auto &x : seqIDs) { + SequenceGroupMeta *group = seqPool.get(x); + if (group == nullptr) { + // TODO: Address error + printf("Sequence ID %d not found.\n", x); + continue; + } + + workingSeqs.push_back(group->get(0)); + if (group->getGroupSize() > 1 && group->getStep() > 1) { + for (int32_t i = 1; i < group->getGroupSize(); i++) { + workingSeqs.push_back(group->get(i)); + } + } + } + + return decoder->forward(workingSeqs, logits_all); +} + // We assume all gen kwargs in the batch are the same // and all sequences are all prompts(step==0) or all decodes(step>0) std::vector Model::generate() { From f441906a19557a70ecc55820a5384af35ec5b471 Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 19/35] [Framework] Code fix to make new path for CB work (#379) --- src/common/datatypes.h | 29 +++++++++++++++++++++- src/common/kvcache_mgr.h | 48 +++++++++++++++++++++++++------------ src/layers/decoder_block.h | 2 +- src/models/common_decoder.h | 32 +++++++++++++++++++------ 4 files changed, 87 insertions(+), 24 deletions(-) diff --git a/src/common/datatypes.h b/src/common/datatypes.h index 12f8f7b9..e9df1ec6 100644 --- a/src/common/datatypes.h +++ b/src/common/datatypes.h @@ -22,7 +22,7 @@ #include "uint4x2.h" namespace xft { -std::string getTypeIdName(xft::DataType dtype) { +inline std::string getTypeIdName(xft::DataType dtype) { switch (dtype) { case xft::DataType::fp32: return "float"; case xft::DataType::bf16: return "bfloat16_t"; @@ -43,4 +43,31 @@ std::string getTypeIdName(xft::DataType dtype) { } return std::string("unknown"); } + +// Get DataType according to c++ types +template +inline DataType getDataType() { + static_assert(sizeof(T) == 0, "Unsupported type"); + return DataType::unknown; +} + +template <> +inline DataType getDataType() { + return DataType::fp32; +} + +template <> +inline DataType getDataType() { + return DataType::bf16; +} + +template <> +inline DataType getDataType() { + return DataType::fp16; +} + +template <> +inline DataType getDataType() { + return int8; +} } // namespace xft diff --git a/src/common/kvcache_mgr.h b/src/common/kvcache_mgr.h index 98f115a0..15af5b4e 100644 --- a/src/common/kvcache_mgr.h +++ b/src/common/kvcache_mgr.h @@ -24,7 +24,7 @@ class KVCacheMgrImplBase { public: virtual ~KVCacheMgrImplBase() = default; virtual bool delSequence(int seqID) = 0; - virtual bool addSequence(int seqID, int prefixId = -1) = 0; + virtual bool addSequence(int seqID, int maxSeqLen = -1, int prefixId = -1) = 0; virtual bool reorderCache(const std::vector &seqIDs, const std::vector &prevSeqIDs) = 0; virtual bool addPrefix(int prefixId, int seqID) = 0; virtual bool prepareCache(const std::vector &seqIDs) = 0; @@ -35,20 +35,25 @@ class KVCacheMgrImplBase { template class KVCacheMgrImpl : public KVCacheMgrImplBase { public: - KVCacheMgrImpl(int layers) { this->layers = layers; } + KVCacheMgrImpl(int maxSeqLen, int headNum, int headSize, int layers) { + this->maxSeqLen_ = maxSeqLen; + this->headNum_ = headNum; + this->headSize_ = headSize; + this->layers_ = layers; + } ~KVCacheMgrImpl() { // Free resource in cachePool (readyCaches are in 
cachePool too) for (auto &it : sequenceCaches) { - delete it.second; + delete[] it.second; } // Free resource in prefixCaches for (auto &it : prefixCaches) { - delete it.second; + delete[] it.second; } // Free resource in freeCaches for (auto &it : freeCaches) { - delete it; + delete[] it; } } @@ -67,7 +72,7 @@ class KVCacheMgrImpl : public KVCacheMgrImplBase { return true; } - bool addSequence(int seqID, int prefixId = -1) override { + bool addSequence(int seqID, int maxSeqLen = -1, int prefixId = -1) override { // Fail if already exist if (sequenceCaches.find(seqID) != sequenceCaches.end()) { return false; } @@ -77,7 +82,13 @@ class KVCacheMgrImpl : public KVCacheMgrImplBase { cache = freeCaches.back(); freeCaches.pop_back(); } else { - cache = new KVCacheTensor[2 * layers]; + cache = new KVCacheTensor[2 * layers_]; + } + + // User specified maxSeqLen needs to be <= model's configured maxSeqLen + auto maxLen = maxSeqLen > 0 ? std::min(maxSeqLen, maxSeqLen_) : maxSeqLen_; + for (int i = 0; i < 2 * layers_; ++i) { + cache[i].resize(maxLen, 1, headNum_, headSize_); } sequenceCaches.insert({seqID, cache}); @@ -101,9 +112,9 @@ class KVCacheMgrImpl : public KVCacheMgrImplBase { if (sequenceCaches.find(seqID) == sequenceCaches.end()) { return false; } // Create a new one - KVCacheTensor *cache = new KVCacheTensor[2 * layers]; + KVCacheTensor *cache = new KVCacheTensor[2 * layers_]; - for (int i = 0; i < 2 * layers; i++) { + for (int i = 0; i < 2 * layers_; i++) { // TODO: add from method in KVCacheTensor //cache[i].from(sequenceCaches[seqID][i]); } @@ -168,7 +179,10 @@ class KVCacheMgrImpl : public KVCacheMgrImplBase { // List of pending free caches, each element is for a sample std::vector *> freeCaches; - int layers; + int maxSeqLen_; + int headNum_; + int headSize_; + int layers_; }; class KVCacheMgr { @@ -178,17 +192,21 @@ class KVCacheMgr { return inst; } - void configure(int layers, DataType dataType) { + void configure(int maxSeqLen, int headNum, int headSize, int layers, DataType dataType) { switch (dataType) { - case DataType::int8: cacheMgrImpl = new KVCacheMgrImpl(layers); break; - case DataType::fp16: cacheMgrImpl = new KVCacheMgrImpl(layers); break; - default: cacheMgrImpl = new KVCacheMgrImpl(layers); break; + case DataType::int8: cacheMgrImpl = new KVCacheMgrImpl(maxSeqLen, headNum, headSize, layers); break; + case DataType::fp16: + cacheMgrImpl = new KVCacheMgrImpl(maxSeqLen, headNum, headSize, layers); + break; + default: cacheMgrImpl = new KVCacheMgrImpl(maxSeqLen, headNum, headSize, layers); break; } } bool delSequence(int seqID) { return cacheMgrImpl->delSequence(seqID); } - bool addSequence(int seqID, int prefixId = -1) { return cacheMgrImpl->addSequence(seqID, prefixId); } + bool addSequence(int seqID, int maxSeqLen = -1, int prefixId = -1) { + return cacheMgrImpl->addSequence(seqID, maxSeqLen, prefixId); + } bool reorderCache(const std::vector &seqIDs, const std::vector &prevSeqIDs) { return cacheMgrImpl->reorderCache(seqIDs, prevSeqIDs); diff --git a/src/layers/decoder_block.h b/src/layers/decoder_block.h index 0a454ce6..a3463e5c 100644 --- a/src/layers/decoder_block.h +++ b/src/layers/decoder_block.h @@ -77,7 +77,7 @@ class DecoderBlock { } // TODO: check and prepare KV cache only needed - // kvCacheMgr.prepareCache(seqIDs); + kvCacheMgr.prepareCache(seqIDs); // All layers forward int layersOnDuty = this->decoders.size(); diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 177a90cb..4f1fb59a 100644 --- a/src/models/common_decoder.h +++ 
b/src/models/common_decoder.h @@ -22,6 +22,7 @@ #include "INIReader.h" #include "abstract_decoder.h" #include "attention.h" +#include "datatypes.h" #include "debugger.h" #include "decoder_block.h" #include "decoder_layer.h" @@ -168,7 +169,7 @@ class CommonDecoder : public AbstractDecoder { const int attHeadNum = reader.GetInteger(modelType, "head_num"); // Use the same head number for the default multi-head attention const int kvHeadNum = reader.GetInteger(modelType, "kv_head_num", attHeadNum); - const int size_per_head = reader.GetInteger(modelType, "size_per_head"); + const int headSize = reader.GetInteger(modelType, "size_per_head"); const int imSize = reader.GetInteger(modelType, "inter_size"); const int layers = reader.GetInteger(modelType, "num_layer"); const int vocabSize = reader.GetInteger(modelType, "vocab_size"); @@ -180,7 +181,7 @@ class CommonDecoder : public AbstractDecoder { const int maxSeqLength = reader.GetInteger(modelType, "seq_length", -1); const bool useLogN = reader.GetInteger(modelType, "use_logn_attn", true); const bool useNTK = reader.GetInteger(modelType, "use_dynamic_ntk", true); - const int hiddenSize = reader.GetInteger(modelType, "hidden_size", attHeadNum * size_per_head); + const int hiddenSize = reader.GetInteger(modelType, "hidden_size", attHeadNum * headSize); const int embeddingSize = hiddenSize; const int multi_query_group_num = reader.GetInteger(modelType, "multi_query_group_num", attHeadNum); const float epsilon = reader.GetFloat(modelType, "layernorm_eps", 1e-6); @@ -225,7 +226,7 @@ class CommonDecoder : public AbstractDecoder { actBuffers.reset(new xft::Matrix()); // Context - DecoderContext *ctx = getDecoderContext(layers, hiddenSize, size_per_head, attHeadNum, kvHeadNum, imSize, act, + DecoderContext *ctx = getDecoderContext(layers, hiddenSize, headSize, attHeadNum, kvHeadNum, imSize, act, epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, useLogN, useNTK, ropeParamsPtr); @@ -239,6 +240,8 @@ class CommonDecoder : public AbstractDecoder { } decoderBlock = new DecoderBlock(ctx, modelPath, layers, dt); + auto maxSeqLen = maxSeqLength > 0 ? maxSeqLength : maxPositions; + KVCacheMgr::instance().configure(maxSeqLen, kvHeadNum, headSize, layers, getDataType()); // Predictor int workers = messenger.getSize(); @@ -517,6 +520,12 @@ class CommonDecoder : public AbstractDecoder { TimeLine t("Decoder.forward"); TimeLine t1("Decoder.embedding"); + if (unlikely(seqs.empty())) { return std::tuple(nullptr, 0, 0); } + + DecoderContext *ctx = this->getContext(); + int batchSize = seqs.size(); + int hiddenSize = ctx->hiddenSize; + // Prepare input int totInputSeqLen = 0; std::vector allInputIds; @@ -527,11 +536,11 @@ class CommonDecoder : public AbstractDecoder { } // Prepare context - DecoderContext *ctx = this->getContext(); ctx->resize(totInputSeqLen); - int batchSize = seqs.size(); - int hiddenSize = ctx->hiddenSize; + // Prepare buffers + int logitRows = (!logitsAll && seqs[0]->getStep() == 0) ? 
seqs.size() : totInputSeqLen; + prepareBuffer(ctx, totInputSeqLen, logitRows); AttnInT *embBuf = (AttnInT *)actBuffers->Data(); MlpOutT *outBuf = (MlpOutT *)(embBuf + totInputSeqLen * hiddenSize); @@ -545,7 +554,6 @@ class CommonDecoder : public AbstractDecoder { // Prepare input for final Layer Norm (only care about the last row of the result) // Shape of embBuf: (bs, seqLen, hiddenSize) MlpOutT *lnIn = embBuf; - auto logitRows = totInputSeqLen; if (!logitsAll) { // TODO: copy needed data } @@ -997,6 +1005,16 @@ class CommonDecoder : public AbstractDecoder { ctx->attHeadSize, prefix); } + void prepareBuffer(DecoderContext *ctx, int totInputSeqLen, int logitRows) { + int hiddenSize = ctx->hiddenSize; + int vocabSize = ctx->vocabSize; + + // Convert final output buffer size into units of hiddenSize + int outRows = std::ceil(logitRows * vocabSize / hiddenSize); + + this->actBuffers->Resize(totInputSeqLen + outRows, hiddenSize); + } + float *getAttnMask(int sizeRequired) { if (this->maskSize < sizeRequired) { if (this->attnMask) free(this->attnMask); From f9bfb4958d3dc7bed42e543dc37aa842ccfa30cd Mon Sep 17 00:00:00 2001 From: marvinYu Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 20/35] [Layer] update mlp for CB. (#384) --- src/layers/decoder_block.h | 4 ++-- src/layers/decoder_layer.h | 4 ++-- src/layers/mlp_llama.h | 5 +++-- src/layers/mlp_standard.h | 5 +++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/layers/decoder_block.h b/src/layers/decoder_block.h index a3463e5c..c310a585 100644 --- a/src/layers/decoder_block.h +++ b/src/layers/decoder_block.h @@ -114,10 +114,10 @@ class DecoderBlock { std::exit(-1); } else { if (messenger.getSize() > 1) { - this->decoders[i]->forwardFFN(ctx, attnOut, output, ctx->hiddenSize, ctx->hiddenSize, true); + this->decoders[i]->forwardFFN(ctx, attnOut, output, ctx->hiddenSize, ctx->hiddenSize, true, totInSeqLen); messenger.reduceAdd(output, output, totInSeqLen * ctx->hiddenSize); } else { - this->decoders[i]->forwardFFN(ctx, attnOut, output, ctx->hiddenSize, ctx->hiddenSize, true); + this->decoders[i]->forwardFFN(ctx, attnOut, output, ctx->hiddenSize, ctx->hiddenSize, true, totInSeqLen); } } } diff --git a/src/layers/decoder_layer.h b/src/layers/decoder_layer.h index 8134c594..9a44b13f 100644 --- a/src/layers/decoder_layer.h +++ b/src/layers/decoder_layer.h @@ -114,9 +114,9 @@ class Decoder { } template - void forwardFFN(DecoderContext *ctx, InT *input, OutT *output, int iStride, int oStride, bool doLnBefore = true) { + void forwardFFN(DecoderContext *ctx, InT *input, OutT *output, int iStride, int oStride, bool doLnBefore = true, int totInSeqLen = 0) { TimeLine t("Decoder.forwardFFN"); - mlp.forward(ctx, input, output, iStride, oStride, doLnBefore); + mlp.forward(ctx, input, output, iStride, oStride, doLnBefore, totInSeqLen); } private: diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 06c7a1ef..8a3bda34 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -112,9 +112,10 @@ class LlamaMLP : public SingletonBase> { // Forward for FFN (Feed Forward Network) void forward(DecoderContext *ctx, InT *input, OutT *output, int iStride, int oStride, - bool doLnBefore = true /*not used*/) { + bool doLnBefore = true /*not used*/, int totInSeqLen = 0) { TimeLine t("LlamaMLP"); - const int M = ctx->batchSize * ctx->inputSeqLen; + + const int M = totInSeqLen == 0 ? 
ctx->batchSize * ctx->inputSeqLen : totInSeqLen; const int hiddenSize = ctx->hiddenSize; static_assert(sizeof(ctx->normBuf.Data()[0]) >= sizeof(ImT), "normBuff is not big enough!"); diff --git a/src/layers/mlp_standard.h b/src/layers/mlp_standard.h index 5f26b069..0e9c6123 100644 --- a/src/layers/mlp_standard.h +++ b/src/layers/mlp_standard.h @@ -76,9 +76,10 @@ class MLP { #endif // Forward for FFN (Feed Forward Network) - void forward(DecoderContext *ctx, float *input, float *output, int iStride, int oStride, bool doLnBefore) { + void forward(DecoderContext *ctx, float *input, float *output, int iStride, int oStride, bool doLnBefore, + int totInSeqLen = 0) { TimeLine t("StandardMLP"); - int M = ctx->batchSize * ctx->inputSeqLen; + int M = totInSeqLen == 0 ? ctx->batchSize * ctx->inputSeqLen : totInSeqLen; xft::Matrix outBuffer(output, M, ctx->hiddenSize, ctx->hiddenSize); auto &resultBuffer1 = outBuffer; From 3f232c59a4b7304bc1dd178248f5ab2ccc6d9900 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 15 May 2024 08:45:07 +0800 Subject: [PATCH 21/35] [Framework] Update set_input for cb. (#381) --- include/models.h | 14 ++++- src/common/kvcache_mgr.h | 5 ++ src/common/sequence.h | 20 ++++++ src/models/models.cpp | 133 +++++++++++++++++++++++++++++++-------- 4 files changed, 145 insertions(+), 27 deletions(-) diff --git a/include/models.h b/include/models.h index f63512b4..56cf2121 100644 --- a/include/models.h +++ b/include/models.h @@ -53,12 +53,17 @@ class Model { int padTokenId_ = -1, bool doSample_ = false, float temperature_ = 1.0, int topK_ = 50, float topP_ = 1.0, float repetitionPenalty_ = 1.0, const std::vector> &stopWordsList_ = {}); + std::vector set_input(std::vector> &inputIds_, std::vector seqIDs, + SearcherConfig &config_, const std::vector> &stopWordsList_ = {}); + + // Only used for model.forward() + std::vector set_input( + std::vector> &inputIds_, std::vector seqIDs = {}, int maxLen = -1); + bool isDone(); std::tuple forward(bool logits_all = true); - std::tuple forward(const std::vector &seqIDs, bool logits_all = true); - std::vector generate(); void createSearcher(SearcherConfig &config_); @@ -75,6 +80,10 @@ class Model { int getVocabSize() { return this->vocabSize; } + void initMaxSeqLen(); + + int getMaxSeqLen() { return maxSeqLen; } + SearcherConfig getConfig() { return configuration; } void setDecoder(AbstractDecoder *dec); @@ -96,6 +105,7 @@ class Model { int batchSize; int seqLen; int vocabSize; + int maxSeqLen; SearcherConfig configuration; bool isNewInput; std::vector workingGroup; diff --git a/src/common/kvcache_mgr.h b/src/common/kvcache_mgr.h index 15af5b4e..c2fa271b 100644 --- a/src/common/kvcache_mgr.h +++ b/src/common/kvcache_mgr.h @@ -28,6 +28,7 @@ class KVCacheMgrImplBase { virtual bool reorderCache(const std::vector &seqIDs, const std::vector &prevSeqIDs) = 0; virtual bool addPrefix(int prefixId, int seqID) = 0; virtual bool prepareCache(const std::vector &seqIDs) = 0; + virtual bool exist(int seqID) const = 0; virtual std::vector getKey(int layerId) = 0; virtual std::vector getValue(int layerId) = 0; }; @@ -160,6 +161,8 @@ class KVCacheMgrImpl : public KVCacheMgrImplBase { return valueCaches; } + bool exist(int seqID) const override { return sequenceCaches.find(seqID) != sequenceCaches.end(); } + private: // seqID -> pointer to an array of caches (each element is a KVCacheTensor, size=2*layers) // Layout of each array is: @@ -220,6 +223,8 @@ class KVCacheMgr { std::vector getValue(int layerId) { return cacheMgrImpl->getValue(layerId); } + bool 
+
 private:
     KVCacheMgrImplBase *cacheMgrImpl;

diff --git a/src/common/sequence.h b/src/common/sequence.h
index f34bb08f..211b69ed 100644
--- a/src/common/sequence.h
+++ b/src/common/sequence.h
@@ -199,6 +199,14 @@ class SequenceGroupMeta {
         groupID = sequences[0].getSequenceID();
     }

+    SequenceGroupMeta(int32_t _inputSeqLen) {
+        sequences.reserve(samplingMeta.config.numBeams);
+        for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
+            sequences.emplace_back(SequenceMeta(_inputSeqLen));
+        }
+        groupID = sequences[0].getSequenceID();
+    }
+
     int32_t getGroupID() { return groupID; }

     int32_t getGroupSize() { return samplingMeta.config.numBeams; }
@@ -252,6 +260,18 @@ class SequencePool {
         return group;
     }

+    SequenceGroupMeta *newGroupMeta(std::vector<int32_t> &inputTokens) {
+        auto *group = new SequenceGroupMeta(inputTokens);
+        this->add(group);
+        return group;
+    }
+
+    SequenceGroupMeta *newGroupMeta(int32_t inputSeqLen) {
+        auto *group = new SequenceGroupMeta(inputSeqLen);
+        this->add(group);
+        return group;
+    }
+
     bool add(SequenceGroupMeta *sequenceGroup, bool force = false) {
         int32_t groupID = sequenceGroup->getGroupID();
         bool isSuccess = false;
diff --git a/src/models/models.cpp b/src/models/models.cpp
index 326b96c2..40366993 100644
--- a/src/models/models.cpp
+++ b/src/models/models.cpp
@@ -66,6 +66,11 @@ Model::~Model() {
     if (searcher != nullptr) { delete searcher; }
 }

+void Model::initMaxSeqLen() {
+    DecoderContext *ctx = decoder->getContext();
+    this->maxSeqLen = ctx->maxSeqLength > 0 ? ctx->maxSeqLength : ctx->maxPositions;
+}
+
 void Model::exitSlaves() {
     if (decoder->getRank() == 0) {
         configuration.numBeams = 0;
@@ -316,6 +321,107 @@ std::vector<int> Model::set_input(std::vector<std::vector<int32_t>> &inputIds_,
     return seqIDs;
 }

+std::vector<int> Model::set_input(std::vector<std::vector<int32_t>> &inputIds_, std::vector<int> seqIDs,
+        SearcherConfig &config_, const std::vector<std::vector<int>> &stopWordsList_) {
+    if (config_.eosTokenId == -1) { config_.eosTokenId = decoder->getEndId(); }
+    if (config_.padTokenId == -1) { config_.padTokenId = config_.eosTokenId; }
+    config_.maxLen = std::min(config_.maxLen, this->maxSeqLen);
+
+    SamplingMeta samplingMeta(config_, stopWordsList_);
+
+    Messenger &messenger = Messenger::getInstance();
+
+    batchSize = inputIds_.size();
+
+    SequencePool &seqPool = SequencePool::getInstance();
+    KVCacheMgr &kvCacheMgr = KVCacheMgr::instance();
+    workingGroup.clear();
+
+    // Sync input and sampling param in distributed mode.
+    if (messenger.getSize() > 1) {
+        // TODO: Sync
+    }
+
+    if (seqIDs.empty()) {
+        // Prompt (1st token)
+        // Create seq meta for inputs and return seq IDs
+        for (int i = 0; i < batchSize; i++) {
+            auto group = seqPool.newGroupMeta(inputIds_[i], samplingMeta);
+            workingGroup.push_back(group);
+            seqIDs.push_back(group->getGroupID());
+            kvCacheMgr.addSequence(group->getGroupID(), config_.maxLen);
+        }
+    } else {
+        // Decode (next token)
+        // Update seq meta with inputs and return seq IDs
+        if (inputIds_.size() != seqIDs.size()) {
+            printf("[ERROR] Input size and seqIDs size mismatch.\n");
+            exit(-1);
+        }
+        for (int i = 0; i < batchSize; i++) {
+            auto group = seqPool.get(seqIDs[i]);
+            if (group == nullptr) {
+                // TODO: Address beam search case.
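+                // Note on the TODO above: a group is keyed by the ID of its first
+                // sequence (groupID = sequences[0].getSequenceID()), so under beam
+                // search (numBeams > 1) the ID of a non-first beam member would not
+                // resolve here; only the group ID is a valid lookup handle.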
+ printf("[ERROR] Sequence ID %d not found.\n", seqIDs[i]); + exit(-1); + } + workingGroup.push_back(group); + if (!kvCacheMgr.exist(seqIDs[i])) { + printf("[ERROR] Sequence ID %d not found in KVCache.\n", seqIDs[i]); + exit(-1); + } + } + } + return seqIDs; +} + +std::vector Model::set_input(std::vector> &inputIds_, std::vector seqIDs, int maxLen) { + Messenger &messenger = Messenger::getInstance(); + SequencePool &seqPool = SequencePool::getInstance(); + KVCacheMgr &kvCacheMgr = KVCacheMgr::instance(); + workingGroup.clear(); + batchSize = inputIds_.size(); + + maxLen = std::min(maxLen, this->maxSeqLen); + + if (messenger.getSize() > 1) { + // TODO: Sync input and sampling param in distributed mode. + // [batch_size, total_length, seqID_size, maxLen] + } + if (seqIDs.empty()) { + // Prompt(1st token) + // Create seq meta for inputs and return seq IDs + for (int i = 0; i < batchSize; i++) { + auto group = seqPool.newGroupMeta(inputIds_[i]); + workingGroup.push_back(group); + seqIDs.push_back(group->getGroupID()); + kvCacheMgr.addSequence(group->getGroupID(), maxLen); + } + } else { + // Decode(next token) + // Update seq meta with inputs and return seq IDs + if (inputIds_.size() != seqIDs.size()) { + printf("[ERROR] Input size and seqIDs size mismatch.\n"); + exit(-1); + } + for (int i = 0; i < batchSize; i++) { + auto group = seqPool.get(seqIDs[i]); + if (group == nullptr) { + // TODO: Address beam search case. + printf("[ERROR] Sequence ID %d not found.\n", seqIDs[i]); + exit(-1); + } + group->get(0)->stepForward(inputIds_[i][0]); + workingGroup.push_back(group); + if (!kvCacheMgr.exist(seqIDs[i])) { + printf("[ERROR] Sequence ID %d not found in KVCache.\n", seqIDs[i]); + exit(-1); + } + } + } + return seqIDs; +} + // TODO: Deprecate the following function void Model::config(SearcherConfig &config_, const std::vector> &stopWordsList_) { isNewInput = true; @@ -393,31 +499,6 @@ std::tuple Model::forward(bool logits_all) { return decoder->forward(workingSeqs, logits_all); } -std::tuple Model::forward(const std::vector &seqIDs, bool logits_all) { - // TODO:Sync IDs in distributed mode. - // Assume that all sequences in the group are all prompts or all decodes. - // Prepare input data for the decoder. 
- SequencePool &seqPool = SequencePool::getInstance(); - std::vector workingSeqs; - for (auto &x : seqIDs) { - SequenceGroupMeta *group = seqPool.get(x); - if (group == nullptr) { - // TODO: Address error - printf("Sequence ID %d not found.\n", x); - continue; - } - - workingSeqs.push_back(group->get(0)); - if (group->getGroupSize() > 1 && group->getStep() > 1) { - for (int32_t i = 1; i < group->getGroupSize(); i++) { - workingSeqs.push_back(group->get(i)); - } - } - } - - return decoder->forward(workingSeqs, logits_all); -} - // We assume all gen kwargs in the batch are the same // and all sequences are all prompts(step==0) or all decodes(step>0) std::vector Model::generate() { @@ -607,5 +688,7 @@ AutoModel::AutoModel(std::string modelPath, xft::DataType dataType, xft::DataTyp printf("Unsupported data type or KV cache data type.\n"); exit(-1); } + + initMaxSeqLen(); } } // namespace xft From 6625b015349b30afffa969608678f75c6c386058 Mon Sep 17 00:00:00 2001 From: "Meng,Chen" Date: Sat, 11 May 2024 11:09:35 +0800 Subject: [PATCH 22/35] [Layers] Added RotaryEmbedding forward for cb mode & Fixed rope uts (#383) --- ci_build | 7 - src/kernels/rotary_embedding_kernels.cpp | 135 ++++++++++++++++++++ src/kernels/rotary_embedding_kernels.h | 30 +++++ src/layers/attention.h | 33 ++--- src/layers/rope_2d.cpp | 6 + src/layers/rope_2d.h | 2 + src/layers/rotary_embedding.cpp | 61 ++++----- src/layers/rotary_embedding.h | 8 +- src/layers/rotary_embedding_chatglm2.cpp | 12 ++ src/layers/rotary_embedding_chatglm2.h | 5 + src/layers/rotary_embedding_qwen.cpp | 12 ++ src/layers/rotary_embedding_qwen.h | 5 + src/layers/yarn_scaled_rotary_embedding.cpp | 87 ++++--------- src/layers/yarn_scaled_rotary_embedding.h | 12 ++ src/models/common_decoder.h | 10 +- tests/ut/rotary_embedding_test.cpp | 61 ++++++--- 16 files changed, 342 insertions(+), 144 deletions(-) create mode 100644 src/kernels/rotary_embedding_kernels.cpp create mode 100644 src/kernels/rotary_embedding_kernels.h diff --git a/ci_build b/ci_build index 921feb33..073f9fbb 100755 --- a/ci_build +++ b/ci_build @@ -84,13 +84,6 @@ ut() { for file in ./*; do if [ -x "$file" ]; then - #Todo(marvin): delete me when the case is ready. - if [[ "$file" == "./rotary_embedding_test" ]]; then - Warning "Bypass the fail case of $file." - continue - fi - ################################################## - if [[ "$file" != *_test ]]; then Warning "$file is not ending with '_test', skip current loop." continue diff --git a/src/kernels/rotary_embedding_kernels.cpp b/src/kernels/rotary_embedding_kernels.cpp new file mode 100644 index 00000000..08c516c5 --- /dev/null +++ b/src/kernels/rotary_embedding_kernels.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================ +#include "rotary_embedding_kernels.h" +#include "intrinsics_util.h" + +namespace xft { + +void llamaSetCosSinCache( + const float *invFreq, float *embCos, float *embSin, int invFreqSize, int maxPositionEmbeddings, float scale) { + +#pragma omp parallel for + for (size_t i = 0; i < maxPositionEmbeddings; i++) { + float *pcos = embCos + i * invFreqSize; + float *psin = embSin + i * invFreqSize; + + for (size_t j = 0; j < invFreqSize; j++) { + float tmp = i * invFreq[j]; + float cosTmp = std::cos(tmp) * scale; + float sinTmp = std::sin(tmp) * scale; + + pcos[j] = cosTmp; + psin[j] = sinTmp; + } + } +} + +// def rotate_half(x): +// """Rotates half the hidden dims of the input.""" +// x1 = x[..., : x.shape[-1] // 2] +// x2 = x[..., x.shape[-1] // 2 :] +// return torch.cat((-x2, x1), dim=-1) +// def apply_rotary_pos_emb(q, k, cos, sin, position_ids): +// # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. +// cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] +// sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] +// cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// q_embed = (q * cos) + (rotate_half(q) * sin) +// k_embed = (k * cos) + (rotate_half(k) * sin) +// return q_embed, k_embed +// + +void llamaApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, + int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { + const int half = (dim + 1) / 2; + const int heads = std::max(qHeads, kHeads); + +#pragma omp parallel for collapse(2) + for (int head = 0; head < heads; ++head) { + for (int seq = 0; seq < totSeqLen; ++seq) { + int pos = positionIds[seq]; + float *pcos = emb_cos + pos * half; + float *psin = emb_sin + pos * half; + + float *q = query + seq * qStride + head * dim; + float *k = key + seq * kStride + head * dim; + +#pragma omp simd + for (int i = 0; i < half; ++i) { + if (head < qHeads) { + auto q1 = q[i]; + q[i] = q1 * pcos[i] - q[i + half] * psin[i]; + q[i + half] = q[i + half] * pcos[i] + q1 * psin[i]; + } + if (head < kHeads) { + auto k1 = k[i]; + k[i] = k1 * pcos[i] - k[i + half] * psin[i]; + k[i + half] = k[i + half] * pcos[i] + k1 * psin[i]; + } + } + } + } +} + +void llamaApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride, + int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { + const int half = (dim + 1) / 2; + const int heads = std::max(qHeads, kHeads); + +#pragma omp parallel for collapse(2) + for (int head = 0; head < heads; ++head) { + for (int seq = 0; seq < totSeqLen; ++seq) { + int pos = positionIds[seq]; + float *pcos = emb_cos + pos * half; + float *psin = emb_sin + pos * half; + + bfloat16_t *q = query + seq * qStride + head * dim; + bfloat16_t *k = key + seq * kStride + head * dim; + + // Process chunks of 16 elements at a time + for (int i = 0; i < half; i += 16) { + int remain = half - i; + __mmask16 mask = (remain >= 16 ? 
0xffff : (1 << remain) - 1); + + __m512 pCosVec = _mm512_maskz_loadu_ps(mask, &pcos[i]); + __m512 pSinVec = _mm512_maskz_loadu_ps(mask, &psin[i]); + + // Compute something like: + // q[i] = q[i] * pcos[i] - q[i + half] * psin[i]; + // q[i + half] = q[i + half] * pcos[i] + q[i] * psin[i]; + if (head < qHeads) { + __m512 qVec = xft::load_avx512(mask, &q[i]); + __m512 qHalfVec = xft::load_avx512(mask, &q[i + half]); + __m512 qNew = _mm512_fmsub_ps(qVec, pCosVec, _mm512_mul_ps(qHalfVec, pSinVec)); + __m512 qHalfNew = _mm512_fmadd_ps(qHalfVec, pCosVec, _mm512_mul_ps(qVec, pSinVec)); + xft::store_avx512(&q[i], mask, qNew); + xft::store_avx512(&q[i + half], mask, qHalfNew); + } + + if (head < kHeads) { + __m512 kVec = xft::load_avx512(mask, &k[i]); + __m512 kHalfVec = xft::load_avx512(mask, &k[i + half]); + __m512 kNew = _mm512_fmsub_ps(kVec, pCosVec, _mm512_mul_ps(kHalfVec, pSinVec)); + __m512 kHalfNew = _mm512_fmadd_ps(kHalfVec, pCosVec, _mm512_mul_ps(kVec, pSinVec)); + xft::store_avx512(&k[i], mask, kNew); + xft::store_avx512(&k[i + half], mask, kHalfNew); + } + } + } + } +} + +} // namespace xft diff --git a/src/kernels/rotary_embedding_kernels.h b/src/kernels/rotary_embedding_kernels.h new file mode 100644 index 00000000..70872ebb --- /dev/null +++ b/src/kernels/rotary_embedding_kernels.h @@ -0,0 +1,30 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================ +#pragma once + +#include "bfloat16.h" + +namespace xft { + +void llamaSetCosSinCache(const float *invFreq, float *embCos, float *embSin, int invFreqSize, + int max_position_embeddings = 2048, float scale = 1.0); + +void llamaApplyRotaryPosEmbed(float *query, float *key, float *embCos, float *embSin, int qStride, int kStride, int dim, + int totSeqLen, int qHeads, int kHeads, const int *positionIds); + +void llamaApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride, + int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); + +} // namespace xft diff --git a/src/layers/attention.h b/src/layers/attention.h index 7484aa36..d0c690d4 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -470,22 +470,23 @@ class Attention { // Apply post operations on query and key TimeLine t3("QKPO"); - // TODO: call into rotary embedding - // int qheads = this->endQHead - this->startQHead; - // int kheads = this->endKVHead - this->startKVHead; - // int qkShape[7] = {ctx->batchSize, ctx->inputSeqLen, qheads, headSize, kheads, ctx->maxSeqLength, pastSeqLen}; - // if (positionIds != nullptr) { - // qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, positionIds); - // } else if (ctx->maxPosEmbed > 0) { - // // Use the default position ids - // std::vector posIds(ctx->inputSeqLen); - // if (inputSeqLen == 1) { - // posIds[0] = pastSeqLen; - // } else { - // std::iota(posIds.begin(), posIds.end(), pastSeqLen); - // } - // qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, posIds.data()); - // } + if (ctx->maxPosEmbed > 0) { + int qheads = this->endQHead - this->startQHead; + int kheads = this->endKVHead - this->startKVHead; + int totInputSeqLen = 0; + for (auto seq : seqs) { + totInputSeqLen += seq->getInputSeqLen(); + } + // Use the default position ids + std::vector posIds(totInputSeqLen); + int loc = 0; + for (auto seq : seqs) { + std::iota(posIds.begin() + loc, posIds.begin() + loc + seq->getInputSeqLen(), seq->getPastSeqLen()); + loc += seq->getInputSeqLen(); + } + qkpo.forward(query.Data(), key.Data(), totInputSeqLen, query.Stride(), key.Stride(), qheads, kheads, + posIds.data()); + } t3.release(); #ifdef DEBUG diff --git a/src/layers/rope_2d.cpp b/src/layers/rope_2d.cpp index 92bada1e..9dbe241c 100644 --- a/src/layers/rope_2d.cpp +++ b/src/layers/rope_2d.cpp @@ -185,3 +185,9 @@ void RotaryEmbedding2D::forward( } // end bs } // end head } + +void RotaryEmbedding2D::forward( + float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { + printf("Unsupported RotaryEmbedding2D in cb mode!\n"); + exit(1); +} diff --git a/src/layers/rope_2d.h b/src/layers/rope_2d.h index 0b8c63c6..c12c70bc 100644 --- a/src/layers/rope_2d.h +++ b/src/layers/rope_2d.h @@ -26,6 +26,8 @@ class RotaryEmbedding2D { ~RotaryEmbedding2D() {} void forward(float *query, float *key, int qStride, int kStride, const int *qk_shape, const int *positions); + void forward(float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int *positionIds); private: void prepareEmbedding(); diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index 7241ccd3..63a35bbe 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -27,8 +27,7 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) 
{ ctx->GetAttr("rope_theta", &this->base, 10000); ctx->GetAttr("rope_type", &this->rope_type, std::to_string(-1)); - if (this->rope_type == "linear") - ctx->GetAttr("scaling_factor", &this->scaling_factor, 1.0f); + if (this->rope_type == "linear") ctx->GetAttr("scaling_factor", &this->scaling_factor, 1.0f); inv_freq_size = (dim + 1) / 2; @@ -37,10 +36,12 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { if (!ctx->cached(inv_freq_str)) { inv_freq = ctx->getBuffer(inv_freq_str, inv_freq_size); + for (size_t i = 0; i < inv_freq_size; i++) { inv_freq[i] = 1.0 / pow(base, float(i * 2) / dim); + inv_freq[i] /= this->scaling_factor; } - llamaCalEmb(inv_freq, max_position_embeddings); + xft::llamaSetCosSinCache(inv_freq, emb_cos, emb_sin, inv_freq_size, max_position_embeddings); } else if (dim != inv_freq_size * 2) { printf("Incorrect dim=%d, inv_freq_size=%d\n", dim, inv_freq_size); exit(-1); @@ -48,42 +49,20 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { } // This API is deprecated, will delete after all rotary embed code refactor. -LlamaRotaryEmbedding::LlamaRotaryEmbedding(const int dim, const int max_position_embeddings, const float base) {} - -void LlamaRotaryEmbedding::llamaCalEmb(const float *inv_freq, const int max_position_embeddings) { -#pragma omp parallel for - for (size_t i = 0; i < max_position_embeddings; i++) { - float *pcos = emb_cos + i * inv_freq_size; - float *psin = emb_sin + i * inv_freq_size; - - for (size_t j = 0; j < inv_freq_size; j++) { - float tmp = i * inv_freq[j] / this->scaling_factor; - float cos_tmp = std::cos(tmp); - float sin_tmp = std::sin(tmp); +LlamaRotaryEmbedding::LlamaRotaryEmbedding(const int dim, const int max_position_embeddings, const float base) { + this->dim = dim; + inv_freq_size = (dim + 1) / 2; - pcos[j] = cos_tmp; - psin[j] = sin_tmp; - } + inv_freq = (float *)malloc(inv_freq_size * sizeof(float)); + emb_cos = (float *)xft::alloc(max_position_embeddings * inv_freq_size * sizeof(float)); + emb_sin = (float *)xft::alloc(max_position_embeddings * inv_freq_size * sizeof(float)); + for (size_t i = 0; i < inv_freq_size; i++) { + inv_freq[i] = 1.0 / pow(base, float(i * 2) / dim); } + + xft::llamaSetCosSinCache(inv_freq, emb_cos, emb_sin, inv_freq_size, max_position_embeddings); } -// def rotate_half(x): -// """Rotates half the hidden dims of the input.""" -// x1 = x[..., : x.shape[-1] // 2] -// x2 = x[..., x.shape[-1] // 2 :] -// return torch.cat((-x2, x1), dim=-1) -// def apply_rotary_pos_emb(q, k, cos, sin, position_ids): -// # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
-// cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] -// sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] -// cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] -// sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] -// q_embed = (q * cos) + (rotate_half(q) * sin) -// k_embed = (k * cos) + (rotate_half(k) * sin) -// return q_embed, k_embed -// -// qk_shape: 4 values of [batch_size, seq_len, head_num, head_size] -// position_ids: an array in the size of seq_len // query and key is the matrix like below: // // |<------------------------------ head_num * head_size --------------------------------->| @@ -213,3 +192,15 @@ void LlamaRotaryEmbedding::forward( } } } + +void LlamaRotaryEmbedding::forward( + float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { + xft::llamaApplyRotaryPosEmbed( + query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); +} + +void LlamaRotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride, + int qHeads, int kHeads, int *positionIds) { + xft::llamaApplyRotaryPosEmbed( + query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); +} diff --git a/src/layers/rotary_embedding.h b/src/layers/rotary_embedding.h index b488a5f6..eac5910b 100644 --- a/src/layers/rotary_embedding.h +++ b/src/layers/rotary_embedding.h @@ -18,6 +18,7 @@ #include #include "bfloat16.h" +#include "rotary_embedding_kernels.h" #include "transformer_ctx.h" /* Sample: @@ -40,12 +41,13 @@ class LlamaRotaryEmbedding { ~LlamaRotaryEmbedding() {} void forward(float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds); - void forward( bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds); -private: - void llamaCalEmb(const float *inv_freq, const int max_position_embeddings); + void forward(float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int *positionIds); + void forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int *positionIds); private: bool initialized = false; diff --git a/src/layers/rotary_embedding_chatglm2.cpp b/src/layers/rotary_embedding_chatglm2.cpp index 900b72bb..c360b941 100644 --- a/src/layers/rotary_embedding_chatglm2.cpp +++ b/src/layers/rotary_embedding_chatglm2.cpp @@ -198,3 +198,15 @@ void ChatGLM2RotaryEmbedding::forward( } } } + +void ChatGLM2RotaryEmbedding::forward( + float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { + printf("Unsupported ChatGLM2RotaryEmbedding in cb mode !\n"); + exit(1); +} + +void ChatGLM2RotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride, + int qHeads, int kHeads, int *positionIds) { + printf("Unsupported ChatGLM2RotaryEmbedding in cb mode !\n"); + exit(1); +} diff --git a/src/layers/rotary_embedding_chatglm2.h b/src/layers/rotary_embedding_chatglm2.h index 5d85dded..3eca94a3 100644 --- a/src/layers/rotary_embedding_chatglm2.h +++ b/src/layers/rotary_embedding_chatglm2.h @@ -40,6 +40,11 @@ class ChatGLM2RotaryEmbedding { void forward( bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qk_shape, const int *position_ids); + void forward(float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int 
*positionIds); + void forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int *positionIds); + private: void glm2CalEmb(); void interleave_qk(__m512 a, __m512 b, __m512 *result0, __m512 *result1); diff --git a/src/layers/rotary_embedding_qwen.cpp b/src/layers/rotary_embedding_qwen.cpp index 9716b949..5b8f8e72 100644 --- a/src/layers/rotary_embedding_qwen.cpp +++ b/src/layers/rotary_embedding_qwen.cpp @@ -326,3 +326,15 @@ void QwenRotaryEmbedding::forward( } } } + +void QwenRotaryEmbedding::forward( + float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { + printf("Unsupported QwenRotaryEmbedding in cb mode !\n"); + exit(1); +} + +void QwenRotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride, + int qHeads, int kHeads, int *positionIds) { + printf("Unsupported QwenRotaryEmbedding in cb mode !\n"); + exit(1); +} diff --git a/src/layers/rotary_embedding_qwen.h b/src/layers/rotary_embedding_qwen.h index 9e8e4850..5c3d2c44 100644 --- a/src/layers/rotary_embedding_qwen.h +++ b/src/layers/rotary_embedding_qwen.h @@ -44,6 +44,11 @@ class QwenRotaryEmbedding { void forward( bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds); + void forward(float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int *positionIds); + void forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int *positionIds); + void init_logn(int max_length = 2048, bool use_logn = true, bool use_ntk = true); private: diff --git a/src/layers/yarn_scaled_rotary_embedding.cpp b/src/layers/yarn_scaled_rotary_embedding.cpp index 44dbad8e..ed48e341 100644 --- a/src/layers/yarn_scaled_rotary_embedding.cpp +++ b/src/layers/yarn_scaled_rotary_embedding.cpp @@ -17,12 +17,6 @@ #include "allocator.h" #include "compile_util.h" -static int maxSeqLenCached = -1; -static int invFreqSize = -1; -static float *invFreq; -static float *embCos = nullptr; -static float *embSin = nullptr; - bool LlamaYaRNScaledRotaryEmbedding::initialized = false; // dim: equals to head size @@ -35,6 +29,7 @@ LlamaYaRNScaledRotaryEmbedding::LlamaYaRNScaledRotaryEmbedding( maxSeqLenCached = maxPosEmbed; invFreqSize = (dim + 1) / 2; + this->dim = dim; // assert ropeParam in Context assert(ropeParamsPtr->type == "yarn"); @@ -46,13 +41,17 @@ LlamaYaRNScaledRotaryEmbedding::LlamaYaRNScaledRotaryEmbedding( yarnLinearRampMask(invFreqMask, low, high, invFreqSize, ropeParamsPtr->extraPolFactor); invFreq = (float *)malloc(invFreqSize * sizeof(float)); + embCos = (float *)xft::alloc(maxSeqLenCached * invFreqSize * sizeof(float)); + embSin = (float *)xft::alloc(maxSeqLenCached * invFreqSize * sizeof(float)); for (size_t i = 0; i < invFreqSize; i++) { invFreq[i] = 1.0 / pow(ropeParamsPtr->base, float(i * 2) / dim); invFreq[i] = invFreq[i] / ropeParamsPtr->scale * (1 - invFreqMask[i]) + invFreq[i] * invFreqMask[i]; } free(invFreqMask); - yarnLlamaCalEmb(ropeParamsPtr->scale, ropeParamsPtr->attnFactor); + float scale = ropeParamsPtr->scale <= 1 ? 
1.0 : (0.1 * std::log(ropeParamsPtr->scale) + 1.0); + scale *= ropeParamsPtr->attnFactor; + xft::llamaSetCosSinCache(invFreq, embCos, embSin, invFreqSize, maxSeqLenCached, scale); } else if (dim != invFreqSize * 2) { printf("Incorrect dim=%d, inv_freq_size=%d\n", dim, invFreqSize); exit(-1); @@ -81,52 +80,6 @@ void LlamaYaRNScaledRotaryEmbedding::yarnLinearRampMask( } } -void LlamaYaRNScaledRotaryEmbedding::yarnLlamaCalEmb(float scale, float attnFactor) { - float mscale; - if (scale <= 1) - mscale = 1.0; - else - mscale = 0.1 * std::log(scale) + 1.0; - mscale *= attnFactor; - - embCos = (float *)xft::alloc(maxSeqLenCached * (invFreqSize * 2) * sizeof(float)); - embSin = (float *)xft::alloc(maxSeqLenCached * (invFreqSize * 2) * sizeof(float)); - -#pragma omp parallel for - for (size_t i = 0; i < maxSeqLenCached; i++) { - float *pcos = embCos + i * invFreqSize * 2; - float *psin = embSin + i * invFreqSize * 2; - - for (size_t j = 0; j < invFreqSize; j++) { - float tmp = i * invFreq[j]; - float cosTmp = std::cos(tmp) * mscale; - float sinTmp = std::sin(tmp) * mscale; - - pcos[j] = cosTmp; - pcos[j + invFreqSize] = cosTmp; - psin[j] = sinTmp; - psin[j + invFreqSize] = sinTmp; - } - } -} - -// def rotate_half(x): -// """Rotates half the hidden dims of the input.""" -// x1 = x[..., : x.shape[-1] // 2] -// x2 = x[..., x.shape[-1] // 2 :] -// return torch.cat((-x2, x1), dim=-1) -// def apply_rotary_pos_emb(q, k, cos, sin, position_ids): -// # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. -// cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] -// sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] -// cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] -// sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] -// q_embed = (q * cos) + (rotate_half(q) * sin) -// k_embed = (k * cos) + (rotate_half(k) * sin) -// return q_embed, k_embed -// -// qk_shape: 4 values of [batch_size, seq_len, head_num, head_size] -// position_ids: an array in the size of seq_len // query and key is the matrix like below: // // |<------------------------------ head_size * head_num --------------------------------->| @@ -160,8 +113,8 @@ void LlamaYaRNScaledRotaryEmbedding::forward( for (int bs = 0; bs < batchSize; ++bs) { for (int seq = 0; seq < seqLen; ++seq) { int pos = positionIds[seq]; - float *pcos = embCos + pos * dim; - float *psin = embSin + pos * dim; + float *pcos = embCos + pos * half; + float *psin = embSin + pos * half; float *q = query + bs * seqLen * qStride + seq * qStride + head * dim; float *k = key + bs * seqLen * kStride + seq * kStride + head * dim; @@ -169,13 +122,13 @@ void LlamaYaRNScaledRotaryEmbedding::forward( for (int i = 0; i < half; ++i) { if (head < qHeads) { auto q1 = q[i]; - q[i] = q[i] * pcos[i] - q[i + half] * psin[i]; - q[i + half] = q[i + half] * pcos[i + half] + q1 * psin[i + half]; + q[i] = q1 * pcos[i] - q[i + half] * psin[i]; + q[i + half] = q[i + half] * pcos[i] + q1 * psin[i]; } if (head < kHeads) { auto k1 = k[i]; - k[i] = k[i] * pcos[i] - k[i + half] * psin[i]; - k[i + half] = k[i + half] * pcos[i + half] + k1 * psin[i + half]; + k[i] = k1 * pcos[i] - k[i + half] * psin[i]; + k[i + half] = k[i + half] * pcos[i] + k1 * psin[i]; } } } @@ -200,8 +153,8 @@ void LlamaYaRNScaledRotaryEmbedding::forward( for (int bs = 0; bs < batchSize; ++bs) { for (int seq = 0; seq < seqLen; ++seq) { int pos = positionIds[seq]; - float *pcos = embCos + pos * dim; - float *psin = embSin + pos * dim; + float *pcos = embCos + pos * half; + float *psin = embSin + 
pos * half; bfloat16_t *q = query + bs * seqLen * qStride + seq * qStride + head * dim; bfloat16_t *k = key + bs * seqLen * kStride + seq * kStride + head * dim; @@ -239,3 +192,15 @@ void LlamaYaRNScaledRotaryEmbedding::forward( } } } + +void LlamaYaRNScaledRotaryEmbedding::forward( + float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { + xft::llamaApplyRotaryPosEmbed( + query, key, embCos, embSin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); +} + +void LlamaYaRNScaledRotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, + int kStride, int qHeads, int kHeads, int *positionIds) { + xft::llamaApplyRotaryPosEmbed( + query, key, embCos, embSin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); +} diff --git a/src/layers/yarn_scaled_rotary_embedding.h b/src/layers/yarn_scaled_rotary_embedding.h index c00a8635..1c9e2662 100644 --- a/src/layers/yarn_scaled_rotary_embedding.h +++ b/src/layers/yarn_scaled_rotary_embedding.h @@ -18,6 +18,7 @@ #include #include #include "bfloat16.h" +#include "rotary_embedding_kernels.h" #include "transformer_ctx.h" /* Sample: @@ -43,6 +44,11 @@ class LlamaYaRNScaledRotaryEmbedding { void forward( bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds); + void forward(float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int *positionIds); + void forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int *positionIds); + private: void yarnFindRange(int &low, int &high, int betaFast, int betaSlow, int dim, float base, int orgMaxPosEmbed); void yarnLinearRampMask(float *invFreqMask, int low, int high, int dim, float extraFactor); @@ -50,4 +56,10 @@ class LlamaYaRNScaledRotaryEmbedding { private: static bool initialized; + int dim = -1; + int maxSeqLenCached = -1; + int invFreqSize = -1; + float *invFreq; + float *embCos = nullptr; + float *embSin = nullptr; }; diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 4f1fb59a..02edcc0c 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -38,13 +38,14 @@ #include "transformer_ctx.h" #include "transpose_util.h" #include "weight_util.h" -#include "sequence.h" using namespace xft; struct QKPO_Dummy { QKPO_Dummy(int dim, int maxPos) {} void forward(float *query, float *key, int qStride, int kStride, const int *qk_shape, const int *position_ids) {} + void forward(float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, + int *positionIds) {}; }; // To get data types in MLP class @@ -316,8 +317,7 @@ class CommonDecoder : public AbstractDecoder { dbg.debugPrint("---- embedding.forward ----\n"); dbg.debugPrint("ids:\n"); dbg.dumpMatrix(ids, batchSize, inputSeqLen, inputSeqLen); - dbg.debugPrint( - "embBuf(rows: %d, cols: %d, stride: %d):\n", batchSize * inputSeqLen, hiddenSize, hiddenSize); + dbg.debugPrint("embBuf(rows: %d, cols: %d, stride: %d):\n", batchSize * inputSeqLen, hiddenSize, hiddenSize); dbg.dumpMatrix(embBuf, batchSize * inputSeqLen, hiddenSize, hiddenSize); #endif @@ -360,7 +360,7 @@ class CommonDecoder : public AbstractDecoder { } } - while(TaskWaitingQueue::getInstance().empty()); + while (TaskWaitingQueue::getInstance().empty()); SequenceGroupMeta *runningTask = nullptr; int32_t sequenceID = -1; @@ -581,7 +581,7 @@ class CommonDecoder : public 
AbstractDecoder { #ifdef DEBUG auto splitSize = this->predictor->getSplitSize(); - dbg.debugPrint("finalOut:\n"); + dbg.debugPrint("finalOut:\n"); dbg.dumpMatrix(finalOut, logitRows, splitSize, splitSize); #endif diff --git a/tests/ut/rotary_embedding_test.cpp b/tests/ut/rotary_embedding_test.cpp index e3fe8f32..c339e6ec 100644 --- a/tests/ut/rotary_embedding_test.cpp +++ b/tests/ut/rotary_embedding_test.cpp @@ -28,51 +28,78 @@ static bool compare(const float *result, const float *ground_truth, const int si TEST(RotrayEmbedding, RotrayEmbeddingTest) { int bs = 2, seq = 2, headnum = 2, dim = 2; int max_len = 10; - int qkshape[5] = {bs, seq, headnum, dim, headnum}; - int pos_ids[2] = {1, 0}; int stride = bs * seq, size = bs * seq * headnum * dim; - LlamaRotaryEmbedding RotrayEmbeddingTest(dim, max_len); - float q[16] = {4, 4, 4, 4, 3, 2, 1, 1, 4, 4, 2, 1, 4, 1, 3, 0}; - float k[16] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + float q_input[16] = {4, 4, 4, 4, 3, 2, 1, 1, 4, 4, 2, 1, 4, 1, 3, 0}; + float k_input[16] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; float q_groundtruth[16] = {-1.20467, 5.52709, -1.20467, 5.52709, 3, 2, 1, 1, -1.20467, 5.52709, 0.239134, 2.22324, 4, 1, 3, 0}; float k_groundtruth[16] = {-0.301169, 1.38177, -0.301169, 1.38177, 1, 1, 1, 1, -0.301169, 1.38177, -0.301169, 1.38177, 1, 1, 1, 1}; + float q[16], k[16]; + + LlamaRotaryEmbedding RotrayEmbeddingTest(dim, max_len); + + memcpy(q, q_input, sizeof(float) * 16); + memcpy(k, k_input, sizeof(float) * 16); + int qkshape[5] = {bs, seq, headnum, dim, headnum}; + int pos_ids[2] = {1, 0}; RotrayEmbeddingTest.forward(q, k, stride, stride, qkshape, pos_ids); EXPECT_TRUE(compare(q, q_groundtruth, size)); EXPECT_TRUE(compare(k, k_groundtruth, size)); + + memcpy(q, q_input, sizeof(float) * 16); + memcpy(k, k_input, sizeof(float) * 16); + int posIds[bs * seq] = {1, 0, 1, 0}; + RotrayEmbeddingTest.forward(q, k, bs * seq, stride, stride, headnum, headnum, posIds); + EXPECT_TRUE(compare(q, q_groundtruth, size)); + EXPECT_TRUE(compare(k, k_groundtruth, size)); } TEST(RotrayEmbedding, BF16Test) { int bs = 2, seq = 2, headnum = 2, dim = 2; int max_len = 10; - int qkshape[5] = {bs, seq, headnum, dim, headnum}; - int pos_ids[2] = {1, 0}; int stride = bs * seq, size = bs * seq * headnum * dim; - float q_fp32[16] = {4, 4, 4, 4, 3, 2, 1, 1, 4, 4, 2, 1, 4, 1, 3, 0}; - float k_fp32[16] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + float q_input[16] = {4, 4, 4, 4, 3, 2, 1, 1, 4, 4, 2, 1, 4, 1, 3, 0}; + float k_input[16] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; float q_groundtruth[16] = {-1.20467, 5.52709, -1.20467, 5.52709, 3, 2, 1, 1, -1.20467, 5.52709, 0.239134, 2.22324, 4, 1, 3, 0}; float k_groundtruth[16] = {-0.301169, 1.38177, -0.301169, 1.38177, 1, 1, 1, 1, -0.301169, 1.38177, -0.301169, 1.38177, 1, 1, 1, 1}; + float q_output[16], k_output[16]; + LlamaRotaryEmbedding RotrayEmbeddingTest(dim, max_len); bfloat16_t q[16]; bfloat16_t k[16]; - bfloat16_t::cvt_float_to_bfloat16(q_fp32, q, 16); - bfloat16_t::cvt_float_to_bfloat16(k_fp32, k, 16); - LlamaRotaryEmbedding RotrayEmbeddingTest(dim, max_len); + // forward 1 + bfloat16_t::cvt_float_to_bfloat16(q_input, q, 16); + bfloat16_t::cvt_float_to_bfloat16(k_input, k, 16); + int qkshape[5] = {bs, seq, headnum, dim, headnum}; + int pos_ids[2] = {1, 0}; RotrayEmbeddingTest.forward(q, k, stride, stride, qkshape, pos_ids); - bfloat16_t::cvt_bfloat16_to_float(q, q_fp32, 16); - bfloat16_t::cvt_bfloat16_to_float(k, k_fp32, 16); + bfloat16_t::cvt_bfloat16_to_float(q, 
q_output, 16); + bfloat16_t::cvt_bfloat16_to_float(k, k_output, 16); + + EXPECT_TRUE(compare(q_output, q_groundtruth, size, 0.01)); + EXPECT_TRUE(compare(k_output, k_groundtruth, size, 0.01)); + + // forward 2 + bfloat16_t::cvt_float_to_bfloat16(q_input, q, 16); + bfloat16_t::cvt_float_to_bfloat16(k_input, k, 16); + int posIds[bs * seq] = {1, 0, 1, 0}; + RotrayEmbeddingTest.forward(q, k, bs * seq, stride, stride, headnum, headnum, posIds); - EXPECT_TRUE(compare(q_fp32, q_groundtruth, size, 0.01)); - EXPECT_TRUE(compare(k_fp32, k_groundtruth, size, 0.01)); + bfloat16_t::cvt_bfloat16_to_float(q, q_output, 16); + bfloat16_t::cvt_bfloat16_to_float(k, k_output, 16); + + EXPECT_TRUE(compare(q_output, q_groundtruth, size, 0.01)); + EXPECT_TRUE(compare(k_output, k_groundtruth, size, 0.01)); } int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} \ No newline at end of file +} From f220fe0600bca18d518a2c3ba6b503f9eaa97dea Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Sat, 11 May 2024 11:11:36 +0800 Subject: [PATCH 23/35] [Layer] Cross attention impl. for CB (#382) --- src/kernels/attention_kernels.h | 2 +- src/layers/attention.h | 28 ++++++++++++++++++++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/kernels/attention_kernels.h b/src/kernels/attention_kernels.h index d247ee89..a8136f79 100644 --- a/src/kernels/attention_kernels.h +++ b/src/kernels/attention_kernels.h @@ -93,7 +93,7 @@ void gemmSV(T1 *score, const std::tuple &valueMat, T4 *output, int auto [B, ldv, scale] = valueMat; auto C = output; const int N = headSize; - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { xft::small_gemm(A, B, scale, C, M, N, K, lds, ldv, ldo); } else { xft::small_gemm(A, B, C, M, N, K, lds, ldv, ldo); diff --git a/src/layers/attention.h b/src/layers/attention.h index d0c690d4..cc999370 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -613,14 +613,15 @@ class Attention { int responsibleQHeads = this->endQHead - this->startQHead; int responsibleKVHeads = this->endKVHead - this->startKVHead; - int tokenSizes[ctx->batchSize]; - for (int i = 0; i < ctx->batchSize; ++i) { + int batchSize = seqs.size(); + int tokenSizes[batchSize]; + for (int i = 0; i < batchSize; ++i) { tokenSizes[i] = seqs[i]->getInputSeqLen(); } xft::selfAttention( result.Data(), query.Data(), key.Data(), value.Data(), responsibleQHeads, responsibleKVHeads, - ctx->attHeadSize, result.Stride(), query.Stride(), key.Stride(), ctx->batchSize, tokenSizes, + ctx->attHeadSize, result.Stride(), query.Stride(), key.Stride(), batchSize, tokenSizes, ctx->attFactor, ctx->numThreads, [&](int b, int headIdx, int seqIdx) { return keyCaches[b]->getSequence(seqIdx, 0, headIdx); }, [&](int b, int headIdx, int seqIdx) { return valueCaches[b]->getSequence(seqIdx, 0, headIdx); }); @@ -630,7 +631,26 @@ class Attention { void fusedAttention(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, xft::Matrix &value, xft::Matrix &result, std::vector *> &keyCaches, std::vector *> &valueCaches, std::vector &seqs) { - // TODO: implement fusedAttention + int responsibleQHeads = this->endQHead - this->startQHead; + int responsibleKVHeads = this->endKVHead - this->startKVHead; + + int batchSize = seqs.size(); + + // TODO: move to AttentionBlock + int inputSeqLens[batchSize]; + int pastSeqLens[batchSize]; + for (int i = 0; i < batchSize; ++i) { + inputSeqLens[i] = seqs[i]->getInputSeqLen(); + pastSeqLens[i] = seqs[i]->getPastSeqLen(); + } + + // TODO: non-causal case 
handle + xft::crossAttnByHead( + result.Data(), query.Data(), key.Data(), value.Data(), responsibleQHeads, responsibleKVHeads, + ctx->attHeadSize, result.Stride(), query.Stride(), key.Stride(), batchSize, inputSeqLens, pastSeqLens, + true, nullptr, ctx->attFactor, ctx->numThreads, + [&](int b, int headIdx) { return keyCaches[b]->getHead(0, headIdx); }, + [&](int b, int headIdx) { return valueCaches[b]->getHead(0, headIdx); }); } int getMBlockSize(int inputSeqLen, int headSize, int minVal = 6) { From eb417af3ef41ad74d1b1cb51ef074da57e120a34 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Sat, 11 May 2024 13:34:34 +0800 Subject: [PATCH 24/35] [Build] Fix namespace build issue. (#388) --- src/layers/attention.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/attention.cpp b/src/layers/attention.cpp index d6ac2470..0e9d0669 100644 --- a/src/layers/attention.cpp +++ b/src/layers/attention.cpp @@ -110,7 +110,7 @@ void AttentionLLaMAImpl(DataType dt, int batchSize, int inputSeqLen, int attHead } ctx->resize(batchSize, inputSeqLen, pastSeqLen); - hpj::Matrix actBuffers; + xft::Matrix actBuffers; actBuffers.Resize(batchSize * inputSeqLen, hiddenSize); float *attnMask = prepareAttnMask(ctx, step); From b5bda0ce69af814132b2145dc52e3c201311892d Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Sat, 11 May 2024 13:34:53 +0800 Subject: [PATCH 25/35] [Common] DecoderContext::resize bug fix (#387) --- src/common/transformer_ctx.h | 17 ++++++++--------- src/layers/attention.h | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index f99cd486..0faa075d 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -274,11 +274,10 @@ struct DecoderContext { auto range = SplitUtil::getTaskRange(intermediateSize, numSplit, splitIdx); int imCols = range.second - range.first; - uint64_t normSize = (uint64_t)batchSize * inputSeqLen * hiddenSize; - uint64_t qkvSize = (uint64_t)batchSize * inputSeqLen * qkvCols; - uint64_t imOutSize = (uint64_t)batchSize * inputSeqLen * imCols * mlpFactor; - - uint64_t tmpBufSize = (uint64_t)batchSize * inputSeqLen * hiddenSize; + uint64_t normSize = (uint64_t)totalInSeqLen * hiddenSize; + uint64_t qkvSize = (uint64_t)totalInSeqLen * qkvCols; + uint64_t imOutSize = (uint64_t)totalInSeqLen * imCols * mlpFactor; + uint64_t tmpBufSize = (uint64_t)totalInSeqLen * hiddenSize; size1 = normSize; size2 = qkvSize < imOutSize ? 
imOutSize : qkvSize; @@ -294,10 +293,10 @@ struct DecoderContext { } // Assign the buffer - normBuf.Assign(this->rawBuffer, batchSize * inputSeqLen, hiddenSize, hiddenSize); - tmpBuf.Assign(this->rawBuffer + size1 + size2, batchSize * inputSeqLen, hiddenSize, hiddenSize); - imOut.Assign(this->rawBuffer + size1, batchSize * inputSeqLen, imCols, imCols); - qkvMatMul.Assign(this->rawBuffer + size1, batchSize * inputSeqLen, qkvCols, qkvCols); + normBuf.Assign(this->rawBuffer, totalInSeqLen, hiddenSize, hiddenSize); + tmpBuf.Assign(this->rawBuffer + size1 + size2, totalInSeqLen, hiddenSize, hiddenSize); + imOut.Assign(this->rawBuffer + size1, totalInSeqLen, imCols, imCols); + qkvMatMul.Assign(this->rawBuffer + size1, totalInSeqLen, qkvCols, qkvCols); } // TODO: deprecate it diff --git a/src/layers/attention.h b/src/layers/attention.h index cc999370..69c64a64 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -425,7 +425,7 @@ class Attention { xft::Matrix qkvGroupMatMul((ImT *)qkvMatMul.Data(), qkvRows, qkvCols, qkvStride); #ifdef DEBUG - dbg.debugPrint("---- DecoderLayer.forward (useSelfAttn=%d) ----\n", useSelfAttn); + dbg.debugPrint("---- DecoderLayer.forward ----\n"); dbg.debugPrint("input:\n"); dbg.dumpMatrix(inputBuffer); #endif From 35562c02ac90d9bd35527b7fa3853ca4df6c04aa Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Sat, 11 May 2024 15:13:13 +0800 Subject: [PATCH 26/35] [Model][Layer] Correct output of the new forward (#389) --- src/layers/decoder_block.h | 19 +++++++++++++++---- src/models/common_decoder.h | 12 ++++++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/layers/decoder_block.h b/src/layers/decoder_block.h index c310a585..0cfff71c 100644 --- a/src/layers/decoder_block.h +++ b/src/layers/decoder_block.h @@ -61,8 +61,8 @@ class DecoderBlock { int size() const { return this->decoders.size(); } - template - void forward(DecoderContext *ctx, std::vector &seqs, InT *input, OutT *output) { + template + void forward(DecoderContext *ctx, std::vector &seqs, T *inputBuf, T *outputBuf) { using AttnOutT = typename AttnTypeExtractor::Tout; Messenger &messenger = Messenger::getInstance(); @@ -81,6 +81,10 @@ class DecoderBlock { // All layers forward int layersOnDuty = this->decoders.size(); + auto input = inputBuf; + auto output = outputBuf; + AttnOutT *attnOut = (AttnOutT *)(ctx->tmpBuf.Data()); + for (int i = 0; i < layersOnDuty; ++i) { int workers = messenger.getSize(); @@ -90,6 +94,7 @@ class DecoderBlock { std::vector *> keyCachesVec(keyCaches.size()); std::vector *> valueCachesVec(valueCaches.size()); + // TODO: better method? 
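+        // Background for the casts below: the KV cache manager hands back
+        // type-erased (void *) tensor pointers, which are narrowed to the
+        // concrete tensor type before use; a sketch of the cast (KVCacheT
+        // denoting this block's KV element type):
+        //   static_cast<KVCacheTensor<KVCacheT> *>(keyCaches[j])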
for (int j = 0; j < keyCaches.size(); ++j) { keyCachesVec[j] = static_cast *>(keyCaches[j]); } @@ -98,8 +103,6 @@ class DecoderBlock { valueCachesVec[j] = static_cast *>(valueCaches[j]); } - AttnOutT *attnOut = (AttnOutT *)(ctx->tmpBuf.Data()); - this->decoders[i]->forwardAttention(ctx, seqs, input, attnOut, totInSeqLen, keyCachesVec, valueCachesVec); // Merge the result of attention @@ -120,6 +123,14 @@ class DecoderBlock { this->decoders[i]->forwardFFN(ctx, attnOut, output, ctx->hiddenSize, ctx->hiddenSize, true, totInSeqLen); } } + + // Update the input/output for the next layer + std::swap(input, output); + } + + // Copy final result to the output buffer + if (inputBuf != outputBuf && layersOnDuty % 2 == 0) { + std::memcpy(outputBuf, inputBuf, totInSeqLen * ctx->hiddenSize * sizeof(T)); } } diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 02edcc0c..2e973397 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -549,13 +549,17 @@ class CommonDecoder : public AbstractDecoder { this->embeddingForward(allInputIds.data(), embBuf, totInputSeqLen); // Decoder block (all layers) - decoderBlock->forward(ctx, seqs, embBuf, outBuf); + decoderBlock->forward(ctx, seqs, embBuf, embBuf); // Prepare input for final Layer Norm (only care about the last row of the result) - // Shape of embBuf: (bs, seqLen, hiddenSize) + // Shape of embBuf: (total_input_seqlen, hiddenSize) MlpOutT *lnIn = embBuf; - if (!logitsAll) { - // TODO: copy needed data + if (logitRows != totInputSeqLen) { + int offset = -1; + for (int b = 0; b < batchSize; ++b) { + offset += seqs[b]->getInputSeqLen(); + memcpy(lnIn + b * hiddenSize, embBuf + offset * hiddenSize, hiddenSize * sizeof(MlpOutT)); + } } #ifdef DEBUG From e67e4552f9b09750d90de098a0d381aed9fe6fe7 Mon Sep 17 00:00:00 2001 From: pujiang2018 Date: Mon, 13 May 2024 09:57:10 +0800 Subject: [PATCH 27/35] [Example] add cb_check example (#390) --- examples/cpp/CMakeLists.txt | 16 +- examples/cpp/cb_check.cpp | 434 ++++++++++++++++++++++++++++++++++++ 2 files changed, 447 insertions(+), 3 deletions(-) create mode 100644 examples/cpp/cb_check.cpp diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 86f778fc..9332179e 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -14,26 +14,36 @@ # ============================================================================ cmake_minimum_required(VERSION 3.15.1) -aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} EXAMPLE_SCR) - include(${CMAKE_SOURCE_DIR}/cmake/cmdline.cmake) include(${CMAKE_SOURCE_DIR}/cmake/sentencepiece.cmake) -add_executable(example ${EXAMPLE_SCR}) +set(EXAMPLE_SRCS "example.cpp" "vocab_opt.cpp" "vocab_qwen.cpp") +set(CB_CHECK_SRCS "cb_check.cpp" "vocab_opt.cpp" "vocab_qwen.cpp") + +add_executable(example ${EXAMPLE_SRCS}) +add_executable(cb_check ${CB_CHECK_SRCS}) target_include_directories(example PRIVATE ${CMAKE_SOURCE_DIR}/3rdparty/cmdline) target_include_directories(example PRIVATE ${CMAKE_SOURCE_DIR}/3rdparty/sentencepiece/include) +target_include_directories(cb_check PRIVATE ${CMAKE_SOURCE_DIR}/3rdparty/cmdline) +target_include_directories(cb_check PRIVATE ${CMAKE_SOURCE_DIR}/3rdparty/sentencepiece/include) target_link_directories(example PRIVATE ${CMAKE_SOURCE_DIR}/3rdparty/sentencepiece/${CMAKE_INSTALL_LIBDIR}) +target_link_directories(cb_check PRIVATE ${CMAKE_SOURCE_DIR}/3rdparty/sentencepiece/${CMAKE_INSTALL_LIBDIR}) if(BUILD_WITH_SHARED_LIBS) target_link_libraries(example PRIVATE xfastertransformer) + 
target_link_libraries(cb_check PRIVATE xfastertransformer) else() target_link_libraries(example PRIVATE xfastertransformer_static) + target_link_libraries(cb_check PRIVATE xfastertransformer_static) endif() target_link_libraries(example PRIVATE sentencepiece -lstdc++fs) +target_link_libraries(cb_check PRIVATE sentencepiece -lstdc++fs) if(WITH_GPU) target_link_libraries(example PRIVATE -fsycl -fsycl-device-code-split=per_kernel -lOpenCL) + target_link_libraries(cb_check PRIVATE -fsycl -fsycl-device-code-split=per_kernel -lOpenCL) endif() add_dependencies(example cmdline sentencepiece_lib) +add_dependencies(cb_check cmdline sentencepiece_lib) diff --git a/examples/cpp/cb_check.cpp b/examples/cpp/cb_check.cpp new file mode 100644 index 00000000..5ecfc093 --- /dev/null +++ b/examples/cpp/cb_check.cpp @@ -0,0 +1,434 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +#include +#include +#include +#include +#include +#include + +#include "INIReader.h" +#include "cmdline.h" +#include "sentencepiece_processor.h" +#include "timer.h" +#include "xfastertransformer.h" + +extern const char *vocab_opt[]; +extern const char *vocab_qwen[]; + +class TokenizerBase { +public: + TokenizerBase() {} + TokenizerBase(std::string &tokenPath) { + std::filesystem::path filePath(tokenPath); + + if (!(std::filesystem::exists(filePath) && std::filesystem::is_regular_file(filePath))) { + std::cout << "[ERROR] " << filePath << " isn't a file or not existed." << std::endl; + exit(-1); + } + + const auto status = processor.Load(tokenPath); + if (!status.ok()) { + std::cout << status.ToString() << std::endl; + std::cout << "[ERROR] Fail to load tokenizer file." << std::endl; + exit(-1); + } + vocabSize = processor.GetPieceSize(); + }; + + virtual std::vector encode(std::string &input) { + std::vector output; + processor.Encode(input, &output); + addSpecialTokenIds(output); + return output; + } + + void addSpecialTokenIds(std::vector &input) { + input.insert(input.begin(), prefixTokenIds.begin(), prefixTokenIds.end()); + input.insert(input.end(), suffixTokenIds.begin(), suffixTokenIds.end()); + } + + virtual std::string decode(std::vector &ids) { + std::string text; + processor.Decode(ids, &text); + return text; + } + virtual std::string decode(int id) { return processor.IdToPiece(id); } + + void printResult(std::vector &ids, int batchSize, int numBeams) { + if (batchSize * numBeams > 2) { + printf("[%d]%s [%d]%s ... 
[%d]%s\n", ids[0], decode(ids[0]).c_str(), ids[1], decode(ids[1]).c_str(), + ids[batchSize * numBeams - 1], decode(ids[batchSize * numBeams - 1]).c_str()); + } else if (batchSize * numBeams > 1) { + printf("[%d]%s [%d]%s\n", ids[0], decode(ids[0]).c_str(), ids[batchSize * numBeams - 1], + decode(ids[batchSize * numBeams - 1]).c_str()); + } else { + printf("[%d]%s ", ids[0], decode(ids[0]).c_str()); + } + } + + std::vector batchDecode(std::vector &input, int batchSize) { + int seqLen = input.size() / batchSize; + std::vector ret; + for (int i = 0; i < batchSize; ++i) { + std::vector tokens(input.begin() + i * seqLen, input.begin() + (i + 1) * seqLen); + ret.emplace_back(decode(tokens)); + } + return ret; + } + +protected: + std::vector prefixTokens; + std::vector suffixTokens; + std::vector prefixTokenIds; + std::vector suffixTokenIds; + + sentencepiece::SentencePieceProcessor processor; + int vocabSize; +}; + +class ChatGLMTokenizer : public TokenizerBase { +public: + ChatGLMTokenizer(std::string &tokenPath) : TokenizerBase(tokenPath) { + suffixTokens = {"[gMASK]", ""}; + suffixTokenIds = {processor.PieceToId("[gMASK]"), processor.PieceToId("")}; + } +}; + +class ChatGLM2Tokenizer : public TokenizerBase { +public: + ChatGLM2Tokenizer(std::string &tokenPath) : TokenizerBase(tokenPath) { + // ChatGLM2's special tokens is not included in sentencepiece. ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + prefixTokens = {"[gMASK]", "sop"}; + prefixTokenIds = {vocabSize + 1, vocabSize + 3}; + } + std::string decode(std::vector &ids) override { + ids.erase(std::remove_if(ids.begin(), ids.end(), [this](int value) { return value >= vocabSize; }), ids.end()); + std::string text; + processor.Decode(ids, &text); + return text; + } + + std::string decode(int id) override { + if (id > vocabSize) { + return ""; + } else { + return processor.IdToPiece(id); + } + } +}; + +class LlamaTokenizer : public TokenizerBase { +public: + LlamaTokenizer(std::string &tokenPath) : TokenizerBase(tokenPath) { processor.SetEncodeExtraOptions("bos"); } +}; + +class BaichuanTokenizer : public TokenizerBase { +public: + BaichuanTokenizer(std::string &tokenPath) : TokenizerBase(tokenPath) { + // 195: user_id 196: assistant_id + prefixTokenIds = {195}; + suffixTokenIds = {196}; + } +}; + +class YaRNLlamaTokenizer : public TokenizerBase { +public: + YaRNLlamaTokenizer(std::string &tokenPath) { vocabSize = 106963; } + + // TODO: Need to achieve actual encode function + std::vector encode(std::string &input) override { + // only for Test + return std::vector( + {7454, 2402, 257, 640, 11, 612, 11196, 257, 1310, 2576, 508, 8288, 284, 423, 17545, 13}); + } + + std::string decode(std::vector &ids) override { + if (ids.size() == 1) { return decode(ids[0]); } + + std::string text(""); + text.reserve(ids.size()); + + for (int id : ids) { + if (id < vocabSize) { + if (vocab_list == nullptr) + text += "[" + std::to_string(id) + "] "; + else + text += vocab_list[id]; + } else { + text += "(null) "; + } + } + + return text; + } + std::string decode(int id) override { + if (id < vocabSize) { + return vocab_list[id]; + } else { + return "(null)"; + } + } + +private: + const char **vocab_list = nullptr; +}; + +class OptTokenizer : public TokenizerBase { +public: + OptTokenizer(std::string &tokenPath) { vocabSize = 50265; } + + std::vector encode(std::string &input) override { + return std::vector({2, 11475, 2115, 10, 86, 6, 89, 13412, 10, 410, 1816, 54, 6640, 7, 33, 18848, 4}); + } + + std::string decode(std::vector &ids) override { + if 
(ids.size() == 1) { return decode(ids[0]); } + std::string text(""); + for (int id : ids) { + if (id < vocabSize) { + text += vocab_list[id]; + text += " "; + } else { + text += "(null) "; + } + } + return text; + } + std::string decode(int id) override { + if (id < vocabSize) { + return vocab_list[id]; + } else { + return "(null)"; + } + } + +private: + const char **vocab_list = vocab_opt; +}; + +class QwenTokenizer : public TokenizerBase { +public: + QwenTokenizer(std::string &tokenPath) { vocabSize = 151851; } + + // TODO: Need to achieve actual encode function + std::vector encode(std::string &input) override { + // only for Test + return std::vector( + {12522, 5193, 264, 882, 11, 1052, 24295, 264, 2632, 3743, 879, 14915, 311, 614, 30978, 13}); + } + + std::string decode(std::vector &ids) override { + if (ids.size() == 1) { return decode(ids[0]); } + + std::string text(""); + text.reserve(ids.size()); + + for (int id : ids) { + if (id < vocabSize) { + text += vocab_list[id]; + } else { + text += "(null) "; + } + } + + return text; + } + std::string decode(int id) override { + if (id < vocabSize) { + return vocab_list[id]; + } else { + return "(null)"; + } + } + +private: + const char **vocab_list = vocab_qwen; +}; + +class GemmaTokenizer : public TokenizerBase { +public: + GemmaTokenizer(std::string &tokenPath) : TokenizerBase(tokenPath) { + vocabSize = 256000; + prefixTokenIds = {2, 2, 106, 1645, 108}; + suffixTokenIds = {107, 108, 106, 2516, 108}; + } +}; + +TokenizerBase *getTokenizer(std::string &modeltype, std::string &tokenPath) { + if (modeltype == "gpt") { + return new OptTokenizer(tokenPath); + } else if (modeltype == "llama") { + return new LlamaTokenizer(tokenPath); + } else if (modeltype == "yarn_llama") { + return new YaRNLlamaTokenizer(tokenPath); + } else if (modeltype == "baichuan") { + return new BaichuanTokenizer(tokenPath); + } else if (modeltype == "chatglm") { + return new ChatGLMTokenizer(tokenPath); + } else if (modeltype == "chatglm2" or modeltype == "chatglm3") { + return new ChatGLM2Tokenizer(tokenPath); + } else if (modeltype == "qwen") { + return new QwenTokenizer(tokenPath); + } else if (modeltype == "gemma") { + return new GemmaTokenizer(tokenPath); + } else { + std::cout << "[Error] Token list of loaded model is unsupported yet.\n" << std::endl; + exit(-1); + } +} + +std::map dataTypeMap = {{"fp16", xft::DataType::fp16}, {"bf16", xft::DataType::bf16}, + {"int8", xft::DataType::int8}, {"w8a8", xft::DataType::w8a8}, {"int4", xft::DataType::int4}, + {"nf4", xft::DataType::nf4}, {"bf16_fp16", xft::DataType::bf16_fp16}, {"bf16_int8", xft::DataType::bf16_int8}, + {"bf16_w8a8", xft::DataType::bf16_w8a8}, {"bf16_int4", xft::DataType::bf16_int4}, + {"bf16_nf4", xft::DataType::bf16_nf4}, {"w8a8_int8", xft::DataType::w8a8_int8}, + {"w8a8_int4", xft::DataType::w8a8_int4}, {"w8a8_nf4", xft::DataType::w8a8_nf4}}; + +std::map kvCacheDataTypeMap + = {{"fp32", xft::DataType::fp32}, {"fp16", xft::DataType::fp16}, {"int8", xft::DataType::int8}}; + +std::string getModelType(std::string &modelPath) { + std::string configPath = modelPath + "/config.ini"; + INIReader reader = INIReader(configPath); + if (reader.ParseError() < 0) { + printf("[Error] Could not load model config.ini.\n"); + exit(-1); + } + std::string modeltype = *reader.Sections().begin(); + return modeltype; +} + +int main(int argc, char **argv) { + cmdline::parser args; + + args.add("model", 'm', "path of xft format model", true); + args.add("token", 't', "path of tokenizer", true); + args.add("input", 'i', 
"input prompt.", false, + "Once upon a time, there existed a little girl who liked to have adventures."); + args.add("dtype", 'd', "weight data type", false, "fp16"); + args.add("kv_cache_dtype", '\0', "kv cache data type", false, "fp16"); + + args.parse_check(argc, argv); + + std::string modelPath = args.get("model"); + std::string tokenPath = args.get("token"); + + std::string dtype_name = args.get("dtype"); + xft::DataType dtype = xft::DataType::fp16; + std::string kv_cache_dtype_name = args.get("kv_cache_dtype"); + xft::DataType kvCacheDataType = xft::DataType::fp16; + + // Check data type + auto it = dataTypeMap.find(dtype_name); + if (it != dataTypeMap.end()) { + dtype = it->second; + } else { + std::cout << "[Error] Unsupport dtype index: " << dtype_name << std::endl; + return 0; + } + + it = kvCacheDataTypeMap.find(kv_cache_dtype_name); + if (it != kvCacheDataTypeMap.end()) { + kvCacheDataType = it->second; + } else { + std::cout << "[Error] Unsupport KV cache dtype index: " << kv_cache_dtype_name << std::endl; + return 0; + } + + std::string modeltype = getModelType(modelPath); + + auto *tokenizer = getTokenizer(modeltype, tokenPath); + std::string inputPrompt = args.get("input"); + std::vector input = tokenizer->encode(inputPrompt); + + xft::AutoModel model(modelPath, dtype, kvCacheDataType); + bool isMaster = model.isMaster(); + + if (isMaster) { + std::cout << "[INFO] Model path is " << modelPath << std::endl; + std::cout << "[INFO] Token path is " << tokenPath << std::endl; + std::cout << "[INFO] Data type is " << dtype_name << std::endl; + std::cout << "[INFO] KV cache data type is " << kv_cache_dtype_name << std::endl; + std::cout << "[INFO] Input prompt: " << inputPrompt << std::endl; + std::cout << "[INFO] Input Token Ids: "; + for (auto x : input) { + std::cout << x << " "; + } + std::cout << std::endl; + } + + SearcherConfig config; + std::vector> generatedTokens(3); + int seqIDs[3]; + + // 1st sequence: generate some tokens + auto ret = model.set_input(input, 1, config); + seqIDs[0] = ret[0]; + ret = model.generate(); // 1st token + for (auto id : ret) { + generatedTokens[0].emplace_back(id); + } + + for (int i = 0; i < 2; ++i) { // some next tokens + std::vector> inputIDs = {{generatedTokens[0].at(generatedTokens[0].size() - 1)}}; + std::vector seqs = {seqIDs[0]}; + model.set_input(inputIDs, seqs, config); + auto ret = model.generate(); + for (auto id : ret) { + generatedTokens[0].emplace_back(id); + } + } + + // 2nd sequence: first token generation + ret = model.set_input(input, 1, config); + seqIDs[1] = ret[0]; + ret = model.generate(); + for (auto id : ret) { + generatedTokens[1].emplace_back(id); + } + + // Batching together to generate some tokens for both sequences + for (int i = 0; i < 2; ++i) { + std::vector> inputIDs = {{generatedTokens[0].at(generatedTokens[0].size() - 1)}, + {generatedTokens[1].at(generatedTokens[1].size() - 1)}}; + std::vector seqs = {seqIDs[0], seqIDs[1]}; + + model.set_input(inputIDs, seqs, config); + auto ret = model.generate(); + assert(ret.size() == 2); + for (int j = 0; j < 2; ++j) { + generatedTokens[j].emplace_back(ret[j]); + } + } + + // Print out values inside generatedTokens[0] + std::cout << "Generated Tokens [0]: "; + for (auto id : generatedTokens[0]) { + std::cout << id << " "; + } + std::cout << std::endl; + std::vector strs = tokenizer->batchDecode(generatedTokens[0], 1); + std::cout << strs[0] << std::endl; + + // Print out values inside generatedTokens[1] + std::cout << "Generated Tokens [1]: "; + for (auto id : 
+    for (auto id : generatedTokens[1]) {
+        std::cout << id << " ";
+    }
+    std::cout << std::endl;
+
+    return 0;
+}

From c576aff179122d338a2dff44eef158068b86a1ce Mon Sep 17 00:00:00 2001
From: pujiang2018
Date: Mon, 13 May 2024 11:01:12 +0800
Subject: [PATCH 28/35] [Bug] Fix incorrect buffer size calculation (#391)

---
 src/models/common_decoder.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h
index 2e973397..e1b75f99 100644
--- a/src/models/common_decoder.h
+++ b/src/models/common_decoder.h
@@ -1014,7 +1014,7 @@ class CommonDecoder : public AbstractDecoder {
         int vocabSize = ctx->vocabSize;
 
         // Convert final output buffer size into units of hiddenSize
-        int outRows = std::ceil(logitRows * vocabSize / hiddenSize);
+        int outRows = std::ceil(1.0f * logitRows * vocabSize / hiddenSize);
 
         this->actBuffers->Resize(totInputSeqLen + outRows, hiddenSize);
     }

From eff6a75dafdc2ed6d6ae45df24414c56ab27edf6 Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Mon, 13 May 2024 14:09:33 +0800
Subject: [PATCH 29/35] [Example] Fix continuous batching C++ example. (#392)

---
 examples/cpp/cb_check.cpp | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/examples/cpp/cb_check.cpp b/examples/cpp/cb_check.cpp
index 5ecfc093..decb4632 100644
--- a/examples/cpp/cb_check.cpp
+++ b/examples/cpp/cb_check.cpp
@@ -371,11 +371,15 @@ int main(int argc, char **argv) {
     }
 
     SearcherConfig config;
+    config.maxLen = 128;
     std::vector<std::vector<int>> generatedTokens(3);
     int seqIDs[3];
+    std::vector<std::vector<int>> inputIDs;
+    std::vector<int> seqs;
 
     // 1st sequence: generate some tokens
-    auto ret = model.set_input(input, 1, config);
+    inputIDs = {input};
+    auto ret = model.set_input(inputIDs, seqs, config);
     seqIDs[0] = ret[0];
     ret = model.generate(); // 1st token
     for (auto id : ret) {
@@ -383,8 +387,8 @@ int main(int argc, char **argv) {
     }
 
     for (int i = 0; i < 2; ++i) { // some next tokens
-        std::vector<std::vector<int>> inputIDs = {{generatedTokens[0].at(generatedTokens[0].size() - 1)}};
-        std::vector<int> seqs = {seqIDs[0]};
+        inputIDs = {{generatedTokens[0].at(generatedTokens[0].size() - 1)}};
+        seqs = {seqIDs[0]};
         model.set_input(inputIDs, seqs, config);
         auto ret = model.generate();
         for (auto id : ret) {
@@ -393,7 +397,9 @@ int main(int argc, char **argv) {
     }
 
     // 2nd sequence: first token generation
-    ret = model.set_input(input, 1, config);
+    inputIDs = {input};
+    seqs.clear();
+    ret = model.set_input(inputIDs, seqs, config);
     seqIDs[1] = ret[0];
     ret = model.generate();
     for (auto id : ret) {
@@ -402,9 +408,9 @@ int main(int argc, char **argv) {
 
     // Batching together to generate some tokens for both sequences
     for (int i = 0; i < 2; ++i) {
-        std::vector<std::vector<int>> inputIDs = {{generatedTokens[0].at(generatedTokens[0].size() - 1)},
+        inputIDs = {{generatedTokens[0].at(generatedTokens[0].size() - 1)},
                 {generatedTokens[1].at(generatedTokens[1].size() - 1)}};
-        std::vector<int> seqs = {seqIDs[0], seqIDs[1]};
+        seqs = {seqIDs[0], seqIDs[1]};
 
         model.set_input(inputIDs, seqs, config);
         auto ret = model.generate();
@@ -429,6 +435,8 @@ int main(int argc, char **argv) {
         std::cout << id << " ";
     }
     std::cout << std::endl;
+    strs = tokenizer->batchDecode(generatedTokens[1], 1);
+    std::cout << strs[0] << std::endl;
 
     return 0;
 }

From 2b374ffec0831d2af6d2547c22f95a15e50cfd56 Mon Sep 17 00:00:00 2001
From: pujiang2018
Date: Mon, 13 May 2024 14:40:29 +0800
Subject: [PATCH 30/35] [Example] More checks in C++ continuous batching example (#393)
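
This extends cb_check from two to three concurrently batched sequences, and
additionally simulates one sequence (sequence 0) finishing early while the
remaining two keep decoding, which is exactly the situation continuous
batching has to handle. All three token streams are decoded and printed at
the end so mismatches are easy to spot.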

---
 examples/cpp/cb_check.cpp | 63 +++++++++++++++++++++++++++------------
 1 file changed, 44 insertions(+), 19 deletions(-)

diff --git a/examples/cpp/cb_check.cpp b/examples/cpp/cb_check.cpp
index decb4632..7daba370 100644
--- a/examples/cpp/cb_check.cpp
+++ b/examples/cpp/cb_check.cpp
@@ -388,8 +388,7 @@ int main(int argc, char **argv) {
 
     for (int i = 0; i < 2; ++i) { // some next tokens
         inputIDs = {{generatedTokens[0].at(generatedTokens[0].size() - 1)}};
-        seqs = {seqIDs[0]};
-        model.set_input(inputIDs, seqs, config);
+        model.set_input(inputIDs, {seqIDs[0]}, config);
         auto ret = model.generate();
         for (auto id : ret) {
             generatedTokens[0].emplace_back(id);
@@ -410,9 +409,7 @@ int main(int argc, char **argv) {
     for (int i = 0; i < 2; ++i) {
         inputIDs = {{generatedTokens[0].at(generatedTokens[0].size() - 1)},
                 {generatedTokens[1].at(generatedTokens[1].size() - 1)}};
-        seqs = {seqIDs[0], seqIDs[1]};
-
-        model.set_input(inputIDs, seqs, config);
+        model.set_input(inputIDs, {seqIDs[0], seqIDs[1]}, config);
         auto ret = model.generate();
         assert(ret.size() == 2);
         for (int j = 0; j < 2; ++j) {
@@ -420,23 +417,51 @@ int main(int argc, char **argv) {
         }
     }
 
-    // Print out values inside generatedTokens[0]
-    std::cout << "Generated Tokens [0]: ";
-    for (auto id : generatedTokens[0]) {
-        std::cout << id << " ";
+    // 3rd sequence: first token generation
+    inputIDs = {input};
+    seqs.clear();
+    ret = model.set_input(inputIDs, seqs, config);
+    seqIDs[2] = ret[0];
+    ret = model.generate();
+    for (auto id : ret) {
+        generatedTokens[2].emplace_back(id);
+    }
+
+    // Batching together to generate some tokens for 3 sequences
+    for (int i = 0; i < 2; ++i) {
+        inputIDs = {{generatedTokens[0].at(generatedTokens[0].size() - 1)},
+                {generatedTokens[1].at(generatedTokens[1].size() - 1)},
+                {generatedTokens[2].at(generatedTokens[2].size() - 1)}};
+        model.set_input(inputIDs, {seqIDs[0], seqIDs[1], seqIDs[2]}, config);
+        auto ret = model.generate();
+        assert(ret.size() == 3);
+        for (int j = 0; j < 3; ++j) {
+            generatedTokens[j].emplace_back(ret[j]);
+        }
+    }
+
+    // Suppose sequence 0 finished
+    for (int i = 0; i < 2; ++i) {
+        inputIDs = {{generatedTokens[1].at(generatedTokens[1].size() - 1)},
+                {generatedTokens[2].at(generatedTokens[2].size() - 1)}};
+        model.set_input(inputIDs, {seqIDs[1], seqIDs[2]}, config);
+        auto ret = model.generate();
+        assert(ret.size() == 2);
+        for (int j = 0; j < 2; ++j) {
+            generatedTokens[j + 1].emplace_back(ret[j]);
+        }
     }
-    std::cout << std::endl;
-    std::vector<std::string> strs = tokenizer->batchDecode(generatedTokens[0], 1);
-    std::cout << strs[0] << std::endl;
 
-    // Print out values inside generatedTokens[1]
-    std::cout << "Generated Tokens [1]: ";
-    for (auto id : generatedTokens[1]) {
-        std::cout << id << " ";
+    // Print out values inside generatedTokens
+    for (int i = 0; i < 3; ++i) {
+        std::cout << "Generated Tokens [" << i << "]: ";
+        for (auto id : generatedTokens[i]) {
+            std::cout << id << " ";
+        }
+        std::cout << std::endl;
+        std::vector<std::string> strs = tokenizer->batchDecode(generatedTokens[i], 1);
+        std::cout << strs[0] << std::endl;
     }
-    std::cout << std::endl;
-    strs = tokenizer->batchDecode(generatedTokens[1], 1);
-    std::cout << strs[0] << std::endl;
 
     return 0;
 }

From 7e9d7316d2f6f38d23527d4035afa6fe027511d7 Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Mon, 13 May 2024 15:45:20 +0800
Subject: [PATCH 31/35] [Model] Clamp maxLen to [input len, model max len]. (#394)
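
Without a lower bound, a maxLen left at a negative default was passed
straight through to the KV cache allocator, and without an upper bound it
could exceed what the model supports. The intent is to clamp the effective
generation length to [longest input length, model max length]. A standalone
sketch of the clamping rule (not part of the diff; names mirror the code
below, and maxSeqLen is assumed to be the model limit):

    #include <algorithm>
    #include <vector>

    // Clamp the generation limit to [longest input length, model max length].
    int clampMaxLen(int maxLen, int maxSeqLen, const std::vector<std::vector<int>> &inputIds) {
        maxLen = maxLen < 0 ? maxSeqLen : std::min(maxLen, maxSeqLen);
        for (const auto &ids : inputIds)
            maxLen = std::max(maxLen, (int)ids.size()); // never below the prompt length
        return maxLen;
    }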

---
 src/models/models.cpp | 31 ++++++++++++++++++++++++-------
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/src/models/models.cpp b/src/models/models.cpp
index 40366993..708c030c 100644
--- a/src/models/models.cpp
+++ b/src/models/models.cpp
@@ -187,6 +187,7 @@ std::vector<int> Model::set_input(std::vector<int> &inputIds_, int batchSize
         const std::vector<std::vector<int>> &stopWordsList_) {
     if (config_.eosTokenId == -1) { config_.eosTokenId = decoder->getEndId(); }
     if (config_.padTokenId == -1) { config_.padTokenId = config_.eosTokenId; }
+    if (config_.maxLen < 0) { config_.maxLen = this->maxSeqLen; }
     SamplingMeta samplingMeta(config_, stopWordsList_);
 
     Messenger &messenger = Messenger::getInstance();
@@ -215,6 +216,7 @@ std::vector<int> Model::set_input(std::vector<int> &inputIds_, int batchSize
         seqLen = inputIds_.size() / batchSize_;
     }
 
+    samplingMeta.config.maxLen = std::max(samplingMeta.config.maxLen, seqLen);
     std::vector<int> seqIDs;
     SequencePool &seqPool = SequencePool::getInstance();
@@ -225,7 +227,7 @@ std::vector<int> Model::set_input(std::vector<int> &inputIds_, int batchSize
         workingGroup.push_back(group);
         seqIDs.push_back(group->getGroupID());
         // TODO: init KVCache for beam search
-        kvCacheMgr.addSequence(group->getGroupID());
+        kvCacheMgr.addSequence(group->getGroupID(), samplingMeta.config.maxLen);
     }
 
     return seqIDs;
@@ -255,6 +257,7 @@ std::vector<int> Model::set_input(std::vector<std::vector<int>> &inputIds_,
         const std::vector<std::vector<int>> &stopWordsList_) {
     if (config_.eosTokenId == -1) { config_.eosTokenId = decoder->getEndId(); }
     if (config_.padTokenId == -1) { config_.padTokenId = config_.eosTokenId; }
+    if (config_.maxLen < 0) { config_.maxLen = this->maxSeqLen; }
     SamplingMeta samplingMeta(config_, stopWordsList_);
 
     Messenger &messenger = Messenger::getInstance();
@@ -265,16 +268,21 @@ std::vector<int> Model::set_input(std::vector<std::vector<int>> &inputIds_,
     SequencePool &seqPool = SequencePool::getInstance();
     KVCacheMgr &kvCacheMgr = KVCacheMgr::instance();
     workingGroup.clear();
+    std::vector<int> seqLens;
+    if (isMaster()) {
+        for (auto &ids : inputIds_) {
+            seqLens.push_back(ids.size());
+            samplingMeta.config.maxLen = std::max(samplingMeta.config.maxLen, (int)ids.size());
+        }
+    }
 
     // Sync input and sampling param in distributed mode.
     if (messenger.getSize() > 1) {
         // [batch size, inputIds size]
-        std::vector<int> seqLens;
         int dims[2];
         if (isMaster()) {
             inputIds.clear();
             for (auto &ids : inputIds_) {
-                seqLens.push_back(ids.size());
                 inputIds.insert(inputIds.end(), ids.begin(), ids.end());
             }
             dims[0] = batchSize;
@@ -301,7 +309,7 @@ std::vector<int> Model::set_input(std::vector<std::vector<int>> &inputIds_,
             workingGroup.push_back(group);
             seqIDs.push_back(group->getGroupID());
             // TODO: init KVCache for beam search
-            kvCacheMgr.addSequence(group->getGroupID());
+            kvCacheMgr.addSequence(group->getGroupID(), samplingMeta.config.maxLen);
 
             it += seqLens[i];
         }
@@ -315,7 +323,7 @@ std::vector<int> Model::set_input(std::vector<std::vector<int>> &inputIds_,
             workingGroup.push_back(group);
             seqIDs.push_back(group->getGroupID());
             // TODO: init KVCache for beam search
-            kvCacheMgr.addSequence(group->getGroupID());
+            kvCacheMgr.addSequence(group->getGroupID(), samplingMeta.config.maxLen);
         }
 
     return seqIDs;
@@ -325,6 +333,7 @@ std::vector<int> Model::set_input(std::vector<std::vector<int>> &inputIds_,
         SearcherConfig &config_, const std::vector<std::vector<int>> &stopWordsList_) {
     if (config_.eosTokenId == -1) { config_.eosTokenId = decoder->getEndId(); }
     if (config_.padTokenId == -1) { config_.padTokenId = config_.eosTokenId; }
+    if (config_.maxLen < 0) { config_.maxLen = this->maxSeqLen; }
     config_.maxLen = std::min(config_.maxLen, this->maxSeqLen);
     SamplingMeta samplingMeta(config_, stopWordsList_);
 
@@ -345,11 +354,15 @@ std::vector<int> Model::set_input(std::vector<std::vector<int>> &inputIds_,
 
     if (seqIDs.empty()) {
         // Prompt(1st token)
         // Create seq meta for inputs and return seq IDs
+        for (auto &ids : inputIds_) {
+            samplingMeta.config.maxLen = std::max(samplingMeta.config.maxLen, (int)ids.size());
+        }
+
         for (int i = 0; i < batchSize; i++) {
             auto group = seqPool.newGroupMeta(inputIds_[i], samplingMeta);
             workingGroup.push_back(group);
             seqIDs.push_back(group->getGroupID());
-            kvCacheMgr.addSequence(group->getGroupID(), config_.maxLen);
+            kvCacheMgr.addSequence(group->getGroupID(), samplingMeta.config.maxLen);
         }
     } else {
         // Decode(next token)
@@ -382,7 +395,7 @@ std::vector<int> Model::set_input(std::vector<std::vector<int>> &inputIds_,
     workingGroup.clear();
     batchSize = inputIds_.size();
 
-    maxLen = std::min(maxLen, this->maxSeqLen);
+    maxLen = maxLen < 0 ? this->maxSeqLen : std::min(maxLen, this->maxSeqLen);
 
     if (messenger.getSize() > 1) {
         // TODO: Sync input and sampling param in distributed mode.
@@ -391,6 +404,10 @@ std::vector<int> Model::set_input(std::vector<std::vector<int>> &inputIds_,
     if (seqIDs.empty()) {
         // Prompt(1st token)
         // Create seq meta for inputs and return seq IDs
+        for (auto &ids : inputIds_) {
+            maxLen = std::max(maxLen, (int)ids.size());
+        }
+
         for (int i = 0; i < batchSize; i++) {
             auto group = seqPool.newGroupMeta(inputIds_[i]);
             workingGroup.push_back(group);

From 524bf3286728f654dadab89f39b334b6bc42ac1b Mon Sep 17 00:00:00 2001
From: pujiang2018
Date: Mon, 13 May 2024 16:13:33 +0800
Subject: [PATCH 32/35] [Layer] Better method to reinterpret KV cache (#397)

* [Common] Add sequenceMeta, sequenceGroup and sequencePool. (#343)

* merge batchSize and seqLen into one in TokenEmbedding

* merge batchSize and seqLen into one in TokenEmbedding (#350)

* [Common] Move Matrix into xft namespace. (#351)

* remove unused function in DecoderLayer

* [Layer] Remove unused functions in Decoder layer (#353)

* fix compile error of embeddingForward

* [Model] Fix compile error of embeddingForward in YaRNLlama (#358)

* [Common] Add sampling params into group seq. (#356)

* remove DecoderContext in computeSoftmax

* [Util] Remove DecoderContext in computeSoftmax (#362)

* [Common] Refactor sequence.h. (#363)

* [kernels] refactor flash attention for continuous batching (#361)

* [models] Add attnMeta for continuous batching (#364)

* [Layers] fix build error (#365)

* [Model] add interface for seq meta. (#366)

* refactor resize function in DecoderContext to support CB, and qkScores member removed

* [Common] Modify resize() in DecoderContext to support (#367)

* add some code to CommonDecoder::forward()

* SequenceMeta refactor

* [Model] New CommonDecoder::forward impl. skeleton (#369)

* new KVCacheMgr supporting CB

* fix typo & set default prefixId to -1 in addSequence()

* [Common] New KVCacheMgr to support CB (#371)

* [Sampling] Add repetition penalty for new seq type. (#373)

* New forward to support CB (CommonDecoder->DecoderBlock->DecoderLayer->Attention/MLP)

* add todo

* [Sampling] Add greedy search for cb path. (#376)

* logic issue fix

* code fix to make new forward work

* add maxSeqLen limitation

* cross attention impl. for CB

* DecoderContext::resize fix

* correct the output of the new forward

* add cb_check

* fix incorrect buffer size calculation

* 2 sequences -> 3 sequences

* better method to prepare KV cache

---------

Co-authored-by: Changqing Li
Co-authored-by: Duyi-Wang
Co-authored-by: Meng,Chen
---
 src/layers/decoder_block.h | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/src/layers/decoder_block.h b/src/layers/decoder_block.h
index 0cfff71c..4a18c13d 100644
--- a/src/layers/decoder_block.h
+++ b/src/layers/decoder_block.h
@@ -91,19 +91,10 @@ class DecoderBlock {
         std::vector<void *> keyCaches = kvCacheMgr.getKey(i);
         std::vector<void *> valueCaches = kvCacheMgr.getValue(i);
 
-        std::vector<KVCacheTensor<KVCacheT> *> keyCachesVec(keyCaches.size());
-        std::vector<KVCacheTensor<KVCacheT> *> valueCachesVec(valueCaches.size());
-
-        // TODO: better method?
-        for (int j = 0; j < keyCaches.size(); ++j) {
-            keyCachesVec[j] = static_cast<KVCacheTensor<KVCacheT> *>(keyCaches[j]);
-        }
-
-        for (int j = 0; j < valueCaches.size(); ++j) {
-            valueCachesVec[j] = static_cast<KVCacheTensor<KVCacheT> *>(valueCaches[j]);
-        }
-
-        this->decoders[i]->forwardAttention(ctx, seqs, input, attnOut, totInSeqLen, keyCachesVec, valueCachesVec);
+        // Reinterpret the keyCaches and valueCaches to the correct type
+        this->decoders[i]->forwardAttention(ctx, seqs, input, attnOut, totInSeqLen,
+                *reinterpret_cast<std::vector<KVCacheTensor<KVCacheT> *> *>(&keyCaches),
+                *reinterpret_cast<std::vector<KVCacheTensor<KVCacheT> *> *>(&valueCaches));
 
         // Merge the result of attention
         // When attention and FFN/MLP are in parallel, do not need to reduce after attention

From 3865654756ea4679ed3b80feec99db46e7af7d40 Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Wed, 15 May 2024 08:53:04 +0800
Subject: [PATCH 33/35] [Interface] Add Python API for continuous batching. (#398)
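
The Python entry points simply forward to the new C++ overloads, so the C++
flow below is a rough usage sketch of what this patch adds (names come from
the diff; model setup and sampling of the returned logits are omitted, and
the token IDs are dummies):

    // Prompt phase: two prompts of equal length packed into one flat buffer.
    std::vector<int> flat = {101, 102, 103, 201, 202, 203}; // seq0 tokens then seq1 tokens
    std::vector<int> seqIDs = model.set_input(flat, /*batchSize_=*/2); // returns new seq IDs

    // Decode phase: one token per sequence, reusing the returned IDs.
    std::vector<int> next = {104, 204}; // latest sampled token of each sequence
    model.set_input(next, 2, seqIDs);
    auto result = model.forward(false); // (logits buffer, sample offset, sample size)

    // Once a sequence is finished, release its KV cache and sequence pool entry.
    model.freeSeqs(seqIDs);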

---
 include/models.h                    |  6 +++
 src/models/models.cpp               | 70 +++++++++++++++++++++++++++++
 src/pytorch/auto_model.h            | 55 +++++++++++++++++++++++
 src/pytorch/pytorch_warpper.cpp     |  3 ++
 src/xfastertransformer/automodel.py |  9 ++++
 5 files changed, 143 insertions(+)

diff --git a/include/models.h b/include/models.h
index 56cf2121..0bb64efb 100644
--- a/include/models.h
+++ b/include/models.h
@@ -60,6 +60,10 @@ class Model {
     std::vector<int> set_input(
             std::vector<std::vector<int>> &inputIds_, std::vector<int> seqIDs = {}, int maxLen = -1);
 
+    // Only used for model.forward()
+    std::vector<int> set_input(
+            std::vector<int> &inputIds_, int batchSize_, std::vector<int> seqIDs = {}, int maxLen = -1);
+
     bool isDone();
 
     std::tuple<float *, int, int> forward(bool logits_all = true);
@@ -98,6 +102,8 @@ class Model {
 
     bool setStopWords(std::vector<std::vector<int>> stopWordsList);
 
+    bool freeSeqs(std::vector<int> &seqIDs);
+
 private:
     AbstractDecoder *decoder;
     AbstractSearcher *searcher;
diff --git a/src/models/models.cpp b/src/models/models.cpp
index 708c030c..cf182798 100644
--- a/src/models/models.cpp
+++ b/src/models/models.cpp
@@ -439,6 +439,74 @@ std::vector<int> Model::set_input(std::vector<std::vector<int>> &inputIds_,
     return seqIDs;
 }
 
+std::vector<int> Model::set_input(
+        std::vector<int> &inputIds_, int batchSize_, std::vector<int> seqIDs, int maxLen) {
+    Messenger &messenger = Messenger::getInstance();
+    SequencePool &seqPool = SequencePool::getInstance();
+    KVCacheMgr &kvCacheMgr = KVCacheMgr::instance();
+    workingGroup.clear();
+    batchSize = batchSize_;
+    seqLen = inputIds_.size() / batchSize;
+
+    maxLen = maxLen < 0 ? this->maxSeqLen : std::min(maxLen, this->maxSeqLen);
+    maxLen = std::max(maxLen, seqLen);
+
+    if (messenger.getSize() > 1) {
+        // TODO: Sync input and sampling param in distributed mode.
+        // [batch_size, total_length, seqID_size, maxLen]
+    }
+    if (seqIDs.empty()) {
+        // Prompt(1st token)
+        // Create seq meta for inputs and return seq IDs
+        for (int i = 0; i < batchSize; i++) {
+            std::vector<int> inputTokens(inputIds_.begin() + i * seqLen, inputIds_.begin() + (i + 1) * seqLen);
+            auto group = seqPool.newGroupMeta(inputTokens);
+            workingGroup.push_back(group);
+            seqIDs.push_back(group->getGroupID());
+            kvCacheMgr.addSequence(group->getGroupID(), maxLen);
+        }
+    } else {
+        // Decode(next token)
+        // Update seq meta with inputs and return seq IDs
+        if (inputIds_.size() != seqIDs.size()) {
+            printf("[ERROR] Input size and seqIDs size mismatch.\n");
+            exit(-1);
+        }
+        if (inputIds_.size() != batchSize_) {
+            printf("[ERROR] Input size and batch size mismatch.\n");
+            exit(-1);
+        }
+        for (int i = 0; i < batchSize; i++) {
+            auto group = seqPool.get(seqIDs[i]);
+            if (group == nullptr) {
+                // TODO: Address beam search case.
+ printf("[ERROR] Sequence ID %d not found.\n", seqIDs[i]); + exit(-1); + } + group->get(0)->stepForward(inputIds_[i]); + workingGroup.push_back(group); + if (!kvCacheMgr.exist(seqIDs[i])) { + printf("[ERROR] Sequence ID %d not found in KVCache.\n", seqIDs[i]); + exit(-1); + } + } + } + + return seqIDs; +} + +bool Model::freeSeqs(std::vector &seqIDs) { + // TODO: Sync + KVCacheMgr &kvCacheMgr = KVCacheMgr::instance(); + SequencePool &seqPool = SequencePool::getInstance(); + bool ret = true; + for (auto &id : seqIDs) { + ret = ret && kvCacheMgr.delSequence(id); + ret = ret && seqPool.remove(id); + } + return ret; +} + // TODO: Deprecate the following function void Model::config(SearcherConfig &config_, const std::vector> &stopWordsList_) { isNewInput = true; @@ -481,8 +549,10 @@ std::vector Model::finalize() { } // Clear KVCache KVCacheMgr &kvCacheMgr = KVCacheMgr::instance(); + SequencePool &seqPool = SequencePool::getInstance(); for (auto x : workingGroup) { kvCacheMgr.delSequence(x->getGroupID()); + seqPool.remove(x->getGroupID()); } workingGroup.clear(); diff --git a/src/pytorch/auto_model.h b/src/pytorch/auto_model.h index 8e213d87..0c467e33 100644 --- a/src/pytorch/auto_model.h +++ b/src/pytorch/auto_model.h @@ -212,6 +212,61 @@ struct TorchAutoModel : torch::CustomClassHolder { void unsetPrefix() { model->unsetPrefix(); }; + torch::Tensor forwardCB() { + // Forward for continuous batching + int batchSize = model->getBatchSize(); + int vocabSize = model->getVocabSize(); + + std::tuple result = model->forward(false); + float *outBuf = std::get<0>(result); + int sampleOffset = std::get<1>(result); + int sampleSize = std::get<2>(result); + + // Create a torch::Tensor from the C array + int64_t tdims[3] = {batchSize, 1, vocabSize}; + torch::Tensor ret = torch::from_blob(outBuf, tdims, torch::kFloat32); + return ret; + } + + torch::Tensor setInputCB(torch::optional inputIds_, torch::optional seqIDs_, + torch::optional maxLength) { + int batchSize = 0; + std::vector seqIDs; + if (model->getRank() == 0) { + TORCH_CHECK(inputIds_.has_value(), "Make sure master's input is not None.") + + batchSize = inputIds_.value().size(0); + int seqLen = inputIds_.value().size(1); + + torch::Tensor inputTensor = inputIds_.value().to(torch::kInt32); + + tokenIds.resize(batchSize * seqLen); + memcpy(tokenIds.data(), inputTensor.data_ptr(), batchSize * seqLen * sizeof(int)); + + if (seqIDs_.has_value()) { + torch::Tensor seqIDsTensor = seqIDs_.value().to(torch::kInt32); + TORCH_CHECK(batchSize == seqIDsTensor.size(0), "seqIDs size[0] must equal to inputIds size[0].") + seqIDs.resize(batchSize); + memcpy(seqIDs.data(), seqIDsTensor.data_ptr(), batchSize * sizeof(int)); + } + } + + seqIDs = model->set_input(tokenIds, batchSize, seqIDs, maxLength.value()); + torch::Tensor ret = torch::from_blob(seqIDs.data(), {batchSize}, torch::kInt32).to(torch::kInt64); + return ret; + } + + bool freeSeqs(torch::optional seqIDs_) { + std::vector seqIDs; + if (model->getRank() == 0) { + TORCH_CHECK(seqIDs_.has_value(), "Make sure master's input is not None.") + torch::Tensor seqIDsTensor = seqIDs_.value().to(torch::kInt32); + seqIDs.resize(seqIDsTensor.size(0)); + memcpy(seqIDs.data(), seqIDsTensor.data_ptr(), seqIDsTensor.size(0) * sizeof(int)); + } + return model->freeSeqs(seqIDs); + } + private: xft::Model *model; std::vector tokenIds; diff --git a/src/pytorch/pytorch_warpper.cpp b/src/pytorch/pytorch_warpper.cpp index 6803bd9d..a3358c9f 100644 --- a/src/pytorch/pytorch_warpper.cpp +++ b/src/pytorch/pytorch_warpper.cpp @@ 
@@ -21,10 +21,13 @@ TORCH_LIBRARY(xfastertransformer, m) {
         .def("get_rank", &TorchAutoModel::getRank)
         .def("input", &TorchAutoModel::input)
         .def("config", &TorchAutoModel::config)
+        .def("set_input_cb", &TorchAutoModel::setInputCB)
         .def("is_done", &TorchAutoModel::isDone)
         .def("forward", &TorchAutoModel::forward)
+        .def("forward_cb", &TorchAutoModel::forwardCB)
        .def("generate", &TorchAutoModel::generate)
         .def("finalize", &TorchAutoModel::finalize)
+        .def("free_seqs", &TorchAutoModel::freeSeqs)
         .def("set_prefix", &TorchAutoModel::setPrefix)
         .def("unset_prefix", &TorchAutoModel::unsetPrefix);
 }
diff --git a/src/xfastertransformer/automodel.py b/src/xfastertransformer/automodel.py
index 1df9912c..73c90af2 100644
--- a/src/xfastertransformer/automodel.py
+++ b/src/xfastertransformer/automodel.py
@@ -48,6 +48,15 @@ def __init__(self, path, dtype: str = "fp16", kv_cache_dtype: str = "fp16"):
     def __call__(self, inputs, **kwargs):
         return self.model.forward(inputs)
 
+    def set_input_cb(self, input_ids, seq_ids, max_length):
+        return self.model.set_input_cb(input_ids, seq_ids, max_length)
+
+    def forward_cb(self):
+        return self.model.forward_cb()
+
+    def free_seqs(self, seq_ids):
+        return self.model.free_seqs(seq_ids)
+
     @classmethod
     def from_pretrained(cls, path, dtype: str = "fp16", kv_cache_dtype: str = "fp16"):
         return cls(path, dtype, kv_cache_dtype)

From af0aae879cb1f13fe440b0db71e2bf8ddcbb0386 Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Wed, 15 May 2024 09:08:44 +0800
Subject: [PATCH 34/35] [Example] Reactivate the old path.

---
 examples/cpp/example.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/cpp/example.cpp b/examples/cpp/example.cpp
index ce8bd1f9..613b15f0 100644
--- a/examples/cpp/example.cpp
+++ b/examples/cpp/example.cpp
@@ -445,19 +445,19 @@ int main(int argc, char **argv) {
 
     for (int i = 0; i < loop; ++i) {
         secondIdCount = 0;
 
-        model.set_input(input, batchSize, /*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1,
-                /*lenPenalty*/ 1.0,
+        // TODO: Deprecate this old path
+        model.config(/*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1, /*lenPenalty*/ 1.0,
                 /*doEarlyStopping*/ false, /*eosTokenId*/ -1, /*padTokenId*/ -1,
                 /*doSample*/ doSample, /*temperature*/ temperature,
                 /*topK*/ topK, /*topP*/ topP, /*repetitionPenalty*/ repetitionPenalty);
-
-        // TODO: Deprecated
-        // Old Path
-        // model.config(/*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1, /*lenPenalty*/ 1.0,
+        model.input(input, batchSize);
+
+        // New path
+        // model.set_input(input, batchSize, /*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1,
+        //         /*lenPenalty*/ 1.0,
         //         /*doEarlyStopping*/ false, /*eosTokenId*/ -1, /*padTokenId*/ -1,
         //         /*doSample*/ doSample, /*temperature*/ temperature,
         //         /*topK*/ topK, /*topP*/ topP, /*repetitionPenalty*/ repetitionPenalty);
-        // model.input(input, batchSize);
 
         std::vector<int> firstIds;
         std::vector<int> secondIds;

From cef27bc44887d9695231ec3941d327d4aabc2f45 Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Wed, 15 May 2024 09:20:24 +0800
Subject: [PATCH 35/35] [Build] Fix build issue.
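
The Matrix class was moved from the hpj namespace into xft earlier in this
series ("[Common] Move Matrix into xft namespace." (#351)), but this call
site in LayerLLaMAImpl still used the old name, which breaks the build with
an undeclared-identifier error. The fix is a one-line rename; assuming the
Resize interface is unchanged by the namespace move, the corrected usage is:

    xft::Matrix<float> actBuffers;                               // was hpj::Matrix<float>
    actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);  // same buffer shape as before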

---
 src/layers/decoder_layer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/layers/decoder_layer.cpp b/src/layers/decoder_layer.cpp
index b69ab3f0..d1017648 100644
--- a/src/layers/decoder_layer.cpp
+++ b/src/layers/decoder_layer.cpp
@@ -134,7 +134,7 @@ void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize,
     }
 
     ctx->resize(batchSize, inputSeqLen, pastSeqLen);
-    hpj::Matrix<float> actBuffers;
+    xft::Matrix<float> actBuffers;
     actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
 
     float *attnMask = prepareAttnMask(ctx, step);