Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model/Layer] New forward to support CB (CommonDecoder->DecoderBlock->DecoderLayer->Attention/MLP) #375

Merged
merged 40 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
8f351ba
[Common] Add sequenceMeta, sequenceGroup and sequenecePool. (#343)
changqi1 Apr 25, 2024
5ab63df
Merge commit 'd2b8df0c85ba57b62169d74c88192de3bf6e4820' into cb_dev
pujiang2018 Apr 26, 2024
dbcb267
merge batchSize and seqLen into one in TokenEembedding
pujiang2018 Apr 26, 2024
3949abd
merge batchSize and seqLen into one in TokenEembedding (#350)
pujiang2018 Apr 26, 2024
9a53fb2
[Common] Move Martix into xft namespace. (#351)
Duyi-Wang Apr 26, 2024
25ee312
Merge commit '9a53fb2ea6b9141ba7c045bc0d135c1809e8f22c' into pujiang/…
pujiang2018 Apr 26, 2024
376b2bc
remove unsed function in DecoderLayer
pujiang2018 Apr 26, 2024
7112b33
[Layer] Remove unused functions in Decoder layer (#353)
pujiang2018 Apr 26, 2024
4ff4707
Merge commit '819ecccfa06662bddb16e9e1cd1ec8775d1c0180' into cb_dev
pujiang2018 Apr 29, 2024
d281a54
Merge commit '4ff47074fc85a27e13251c3fb618f36e338c456f' into pujiang/…
pujiang2018 Apr 29, 2024
b5b225a
fix compile error of embeddingForward
pujiang2018 Apr 29, 2024
f8f8571
[Model] Fix compile error of embeddingForward in YaRNLlama (#358)
pujiang2018 Apr 29, 2024
b95dac1
[Common] Add sampling params into group seq. (#356)
Duyi-Wang Apr 29, 2024
d5c9407
remove DecoderContext in computeSoftmax
pujiang2018 Apr 29, 2024
be615b2
Merge commit 'f8f85714331c0df2ce4a8344e06972316770ec11' into pujiang/…
pujiang2018 Apr 29, 2024
e48ea1f
[Util] Remove DecoderContext in computeSoftmax (#362)
pujiang2018 Apr 29, 2024
7ad311e
[Common] Refactor sequence.h. (#363)
Duyi-Wang Apr 30, 2024
9e7bdca
[kernels] refactor flash attention for continuous batching (#361)
abenmao Apr 30, 2024
cfab63a
[models] Add attnMeta for continuous batching (#364)
abenmao Apr 30, 2024
deabd33
[Layers] fix build error (#365)
abenmao Apr 30, 2024
2499f60
[Model] add interface for seq meta. (#366)
Duyi-Wang Apr 30, 2024
5833d41
Merge commit '2499f602c22184ca5afaa2f013ae0ff4e3bd4263' into pujiang/…
pujiang2018 May 3, 2024
0514833
refactor resize function in DecoderContext to support CB, and qkScore…
pujiang2018 May 6, 2024
c792aff
[Common] Modify resize() in DecoderContext to support (#367)
pujiang2018 May 6, 2024
5315a5a
Merge commit 'c792aff5f7cc8554afa0399f4b6d241333b0b56c' into pujiang/…
pujiang2018 May 6, 2024
63e895a
add some code to CommonDecoder::forward()
pujiang2018 May 6, 2024
9cfc7c6
SequenceMeta refactor
pujiang2018 May 7, 2024
2e05716
[Model] New CommonDecoder::forward impl. skeleton (#369)
pujiang2018 May 7, 2024
41e692c
new KVCacheMgr supporting CB
pujiang2018 May 7, 2024
32c845c
Merge commit '2e057165b456ef9b88591a880403ffe47e7500c3' into pujiang/…
pujiang2018 May 7, 2024
c2ac8d2
fix typo & set default prefixId to -1 in addSequence()
pujiang2018 May 7, 2024
dfa1d0e
[Common] New KVCacheMgr to support CB (#371)
pujiang2018 May 7, 2024
e79c9c2
[Sampling] Add repetition penalty for new seq type. (#373)
Duyi-Wang May 8, 2024
1e93ef0
Merge commit 'e79c9c21825ad51116f49a290baa7eea80e25b07' into pujiang/…
pujiang2018 May 8, 2024
689b41a
New forward to support CB (CommonDecoder->DecoderBlock->DecoderLayer->…
pujiang2018 May 8, 2024
3ce8d06
add todo
pujiang2018 May 8, 2024
9377f83
[Sampling] Add greedy search for cb path. (#376)
Duyi-Wang May 8, 2024
b71ec05
Merge commit '9377f8371ebfe4999c2171eceab106a52ff2618a' into pujiang/…
pujiang2018 May 8, 2024
3144675
logic issue fix
pujiang2018 May 8, 2024
fcee295
Merge branch 'cb_dev' into pujiang/feature/cb_dev
pujiang2018 May 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/abstract_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once
#include <cstdint>
#include <tuple>
#include "sequence.h"

class DecoderContext;
class Messenger;
Expand All @@ -35,6 +36,8 @@ class AbstractDecoder {
// |<----------------------- vocabSize ----------------------------->|
virtual std::tuple<float *, int, int> forward(int *ids, int64_t *dims, int step, bool logits_all = false) = 0;

virtual std::tuple<float *, int, int> forward(std::vector<xft::SequenceMeta *> &seq, bool logits_all = false) = 0;

// Reorder cached keys and values, size=batchSize*beamSize
virtual void reorderCache(int *idx, int size) = 0;

Expand Down
11 changes: 10 additions & 1 deletion include/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,17 @@ class Model {

void config(SearcherConfig &config_, const std::vector<std::vector<int>> &stopWordsList_ = {});

void set_input(std::vector<int32_t> &inputIds_, int batchSize_, int maxLen_ = -1, int numBeams_ = 1,
int numBeamHypsToKeep_ = 1, float lenPenalty_ = 1.0, bool doEarlyStopping_ = false, int eosTokenId_ = -1,
int padTokenId_ = -1, bool doSample_ = false, float temperature_ = 1.0, int topK_ = 50, float topP_ = 1.0,
float repetitionPenalty_ = 1.0, const std::vector<std::vector<int>> &stopWordsList_ = {});

void set_input(std::vector<int32_t> &inputIds_, int batchSize_, SearcherConfig &config_,
const std::vector<std::vector<int>> &stopWordsList_ = {});

bool isDone();

std::tuple<float *, int, int> forward();
std::tuple<float *, int, int> forward(bool logits_all = true);

std::vector<int32_t> generate();

Expand Down Expand Up @@ -79,6 +87,7 @@ class Model {
int vocabSize;
SearcherConfig configuration;
bool isNewInput;
std::vector<SequenceGroupMeta *> workingGroup;
};

class AutoModel : public Model {
Expand Down
70 changes: 70 additions & 0 deletions src/common/attn_metadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#pragma once

#include <cassert>
#include <vector>
// Per-forward metadata describing a (possibly continuous) batch for attention:
// how many new tokens and how many cached (past) tokens each sequence has,
// plus the start offset of each sequence in the flattened token dimension.
class AttnMetaData {

public:
    AttnMetaData() : isPrompt(false), isCausal(true), batchSize(0), attnMask(nullptr) {}

    // inputTokenSizes[i]: number of new tokens for sequence i
    // pastTokenSizes[i]:  number of already-cached tokens for sequence i
    // isCausal=true means causal attention (no explicit mask needed);
    // isCausal=false requires a caller-provided attnMask.
    AttnMetaData(int batchSize, int *inputTokenSizes, int *pastTokenSizes, bool isPrompt, bool isCausal,
            float *attnMask = nullptr)
        : isPrompt(isPrompt), isCausal(isCausal), batchSize(batchSize), attnMask(attnMask) {
        // causal=True -> no mask expected; causal=False -> mask required
        assert(!isCausal || attnMask == nullptr);
        assert(isCausal || attnMask != nullptr);

        // fill inputSeqLens, pastSeqLens, seqStartLoc
        inputSeqLens.resize(batchSize);
        pastSeqLens.resize(batchSize);
        seqStartLoc.resize(batchSize + 1);

        // seqStartLoc is the exclusive prefix sum of inputSeqLens
        seqStartLoc[0] = 0;
        for (int i = 0; i < batchSize; i++) {
            inputSeqLens[i] = inputTokenSizes[i];
            pastSeqLens[i] = pastTokenSizes[i];
            seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
        }
    }

    // Vector overload: inputTokens/pastTokens are copied into the members.
    AttnMetaData(std::vector<int> &inputTokens, std::vector<int> &pastTokens, bool isPrompt, bool isCausal,
            float *attnMask = nullptr)
        : isPrompt(isPrompt), isCausal(isCausal), batchSize(static_cast<int>(inputTokens.size())),
          inputSeqLens(inputTokens), pastSeqLens(pastTokens), attnMask(attnMask) {
        // causal=True -> no mask expected; causal=False -> mask required
        assert(!isCausal || attnMask == nullptr);
        assert(isCausal || attnMask != nullptr);

        // fill seqStartLoc (exclusive prefix sum of inputSeqLens)
        seqStartLoc.resize(batchSize + 1);

        seqStartLoc[0] = 0;
        for (int i = 0; i < batchSize; i++) {
            seqStartLoc[i + 1] = seqStartLoc[i] + inputSeqLens[i];
        }
    }

private:
    bool isPrompt; // true when processing the prompt (first token) phase
    bool isCausal; // true -> causal attention, no explicit mask

    int batchSize;
    std::vector<int> inputSeqLens; // new tokens per sequence
    std::vector<int> pastSeqLens; // cached tokens per sequence
    std::vector<int> seqStartLoc; // size batchSize+1; start offset of each seq in flattened input

    float *attnMask; // non-owning; required iff !isCausal

};
216 changes: 216 additions & 0 deletions src/common/kvcache_mgr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Copyright (c) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#pragma once

#include <vector>
#include "kvcache_tensor.h"
#include <unordered_map>

namespace xft {

// Type-erased interface over the templated KV-cache manager implementation,
// so the non-template facade (KVCacheMgr) can hold one pointer regardless of
// the cache element type selected at configure() time.
class KVCacheMgrImplBase {
public:
    virtual ~KVCacheMgrImplBase() = default;
    // Release the cache owned by sequence seqID; returns false if unknown.
    virtual bool delSequence(int seqID) = 0;
    // Create (or reuse) a cache for a new sequence; returns false if seqID already exists.
    // prefixId: optional prefix cache to start from (-1 = none).
    virtual bool addSequence(int seqID, int prefixId = -1) = 0;
    // Reorder caches from prevSeqIDs to seqIDs (beam search expansion).
    virtual bool reorderCache(const std::vector<int> &seqIDs, const std::vector<int> &prevSeqIDs) = 0;
    // Snapshot sequence seqID's cache as a shared prefix under prefixId.
    virtual bool addPrefix(int prefixId, int seqID) = 0;
    // Arrange caches in the order given by seqIDs for the upcoming forward pass.
    virtual bool prepareCache(const std::vector<int> &seqIDs) = 0;
    // Per-layer key/value caches of the prepared sequences (one void* per sequence).
    virtual std::vector<void *> getKey(int layerId) = 0;
    virtual std::vector<void *> getValue(int layerId) = 0;
};

// KV-cache manager for element type T. Each sequence owns an array of
// 2 * layers KVCacheTensor objects laid out as K0, V0, K1, V1, ...
template <typename T>
class KVCacheMgrImpl : public KVCacheMgrImplBase {
public:
    // layers: number of decoder layers managed per sequence
    KVCacheMgrImpl(int layers) { this->layers = layers; }

    ~KVCacheMgrImpl() {
        // Free resource in cachePool (readyCaches are in cachePool too).
        // Cache arrays come from new[], so delete[] is required — scalar
        // delete on a new[] array is undefined behavior.
        for (auto &it : sequenceCaches) {
            delete[] it.second;
        }
        // Free resource in prefixCaches
        for (auto &it : prefixCaches) {
            delete[] it.second;
        }
        // Free resource in freeCaches
        for (auto &it : freeCaches) {
            delete[] it;
        }
    }

    // Free KVCache by sample ID; the cache array is pooled for reuse.
    bool delSequence(int seqID) override {
        auto it = sequenceCaches.find(seqID);

        // Fail if not exist
        if (it == sequenceCaches.end()) { return false; }

        // Move from sequenceCaches to freeCaches
        freeCaches.push_back(it->second);

        sequenceCaches.erase(it);

        return true;
    }

    // Register a new sequence, reusing a pooled cache array when available.
    // NOTE(review): prefixId is accepted but not yet applied here — prefix
    // contents are not copied into the new cache. TODO: honor prefixId.
    bool addSequence(int seqID, int prefixId = -1) override {
        // Fail if already exist
        if (sequenceCaches.find(seqID) != sequenceCaches.end()) { return false; }

        // Get a free cache or create a new one
        KVCacheTensor<T> *cache = nullptr;
        if (!freeCaches.empty()) {
            cache = freeCaches.back();
            freeCaches.pop_back();
        } else {
            cache = new KVCacheTensor<T>[2 * layers];
        }

        sequenceCaches.insert({seqID, cache});

        return true;
    }

    // Reorder cache based on prevSeqIDs for beam search (caches reordered from prevSeqIDs to seqIDs)
    // For example, if seqIDs = {1, 2, 3, 4} and prevSeqIDs = {1, 1, 1, 1}, then means to expand cache for sample 1
    bool reorderCache(const std::vector<int> &seqIDs, const std::vector<int> &prevSeqIDs) override {
        // TODO: implement reorderCache
        return false;
    }

    // Create KVCache for prefix sharing
    bool addPrefix(int prefixId, int seqID) override {
        // Fail if already exist
        if (prefixCaches.find(prefixId) != prefixCaches.end()) { return false; }

        // Cannot find the sample cache
        if (sequenceCaches.find(seqID) == sequenceCaches.end()) { return false; }

        // Create a new one
        KVCacheTensor<T> *cache = new KVCacheTensor<T>[2 * layers];

        for (int i = 0; i < 2 * layers; i++) {
            // TODO: add from method in KVCacheTensor
            //cache[i].from(sequenceCaches[seqID][i]);
        }

        prefixCaches.insert({prefixId, cache});

        return true;
    }

    // Set cache to be ready for this order of sampleIds; fails (and leaves
    // readyCaches untouched) if any seqID is unknown.
    bool prepareCache(const std::vector<int> &seqIDs) override {
        std::vector<KVCacheTensor<T> *> readyList;
        readyList.reserve(seqIDs.size());

        for (auto seqID : seqIDs) {
            auto it = sequenceCaches.find(seqID);
            if (it == sequenceCaches.end()) { return false; }
            readyList.push_back(it->second);
        }

        readyCaches = std::move(readyList);

        return true;
    }

    // Get key caches for a layer (one pointer per prepared sequence)
    std::vector<void *> getKey(int layerId) override {
        std::vector<void *> keyCaches;
        keyCaches.reserve(readyCaches.size());
        for (auto cache : readyCaches) {
            keyCaches.push_back(&cache[2 * layerId]);
        }
        return keyCaches;
    }

    // Get value caches for a layer (one pointer per prepared sequence)
    std::vector<void *> getValue(int layerId) override {
        std::vector<void *> valueCaches;
        valueCaches.reserve(readyCaches.size());
        for (auto cache : readyCaches) {
            valueCaches.push_back(&cache[2 * layerId + 1]);
        }
        return valueCaches;
    }

private:
    // seqID -> pointer to an array of caches (each element is a KVCacheTensor, size=2*layers)
    // Layout of each array is:
    //     <Key cache for layer 0>
    //     <Value cache for layer 0>
    //     <Key cache for layer 1>
    //     <Value cache for layer 1>
    //     ...
    std::unordered_map<int, KVCacheTensor<T> *> sequenceCaches;

    // prefixID -> pointer to an array of caches (each element is a KVCacheTensor, size=2*layers)
    std::unordered_map<int, KVCacheTensor<T> *> prefixCaches;

    // List of ready caches, each element is for a sample; subset of sequenceCaches
    std::vector<KVCacheTensor<T> *> readyCaches;

    // List of pending free caches, each element is for a sample
    std::vector<KVCacheTensor<T> *> freeCaches;

    int layers;
};

// Process-wide facade over KVCacheMgrImpl (Meyers singleton). configure()
// must be called before any other member; calls forward to the impl.
class KVCacheMgr {
public:
    static KVCacheMgr &instance() {
        static KVCacheMgr inst;
        return inst;
    }

    // Select the cache element type and layer count. Safe to call again:
    // the previous implementation is released before being replaced.
    void configure(int layers, DataType dataType) {
        delete cacheMgrImpl; // avoid leaking a previously configured impl
        switch (dataType) {
            case DataType::int8: cacheMgrImpl = new KVCacheMgrImpl<int8_t>(layers); break;
            case DataType::fp16: cacheMgrImpl = new KVCacheMgrImpl<float16_t>(layers); break;
            // NOTE(review): unrecognized types silently fall back to fp16 — confirm intended
            default: cacheMgrImpl = new KVCacheMgrImpl<float16_t>(layers); break;
        }
    }

    bool delSequence(int seqID) { return cacheMgrImpl->delSequence(seqID); }

    bool addSequence(int seqID, int prefixId = -1) { return cacheMgrImpl->addSequence(seqID, prefixId); }

    bool reorderCache(const std::vector<int> &seqIDs, const std::vector<int> &prevSeqIDs) {
        return cacheMgrImpl->reorderCache(seqIDs, prevSeqIDs);
    }

    bool addPrefix(int prefixId, int seqID) { return cacheMgrImpl->addPrefix(prefixId, seqID); }

    bool prepareCache(const std::vector<int> &seqIDs) { return cacheMgrImpl->prepareCache(seqIDs); }

    std::vector<void *> getKey(int layerId) { return cacheMgrImpl->getKey(layerId); }

    std::vector<void *> getValue(int layerId) { return cacheMgrImpl->getValue(layerId); }

private:
    KVCacheMgrImplBase *cacheMgrImpl; // owned; null until configure() is called

    KVCacheMgr() : cacheMgrImpl(nullptr) {}

    ~KVCacheMgr() { delete cacheMgrImpl; }

    KVCacheMgr(const KVCacheMgr &) = delete;
    KVCacheMgr &operator=(const KVCacheMgr &) = delete;
};

} // namespace xft
4 changes: 2 additions & 2 deletions src/common/my_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void *xft_numa_alloc(size_t size);
void xft_numa_free(void *start, size_t size);
}

namespace hpj {
namespace xft {

template <typename T>
struct is_quantization_type {
Expand Down Expand Up @@ -366,4 +366,4 @@ class Vector {
}
uint64_t Size() { return size; }
};
} // namespace hpj
} // namespace xft
Loading