Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Achieve whole pipeline parallel. #355

Draft
wants to merge 33 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
3c509ab
[Build] Fix build issue.
changqi1 Apr 15, 2024
24de368
[Model] Init pipeline parallel.
changqi1 Apr 18, 2024
7f23d07
could run input 3 and 4 in mpi 2 and 4.
changqi1 Apr 19, 2024
4f67f28
use threadpool singleton
changqi1 Apr 19, 2024
d7a4304
format code
changqi1 Apr 19, 2024
70c17cd
format code
changqi1 Apr 19, 2024
c1417ee
Merge branch 'main' into changqing/bug/fix_build_context
changqi1 Apr 19, 2024
3eca8ef
remove non-Master thread code
changqi1 Apr 19, 2024
2216a40
format code
changqi1 Apr 19, 2024
143294b
format code
changqi1 Apr 19, 2024
700d0c8
move prompt.h
changqi1 Apr 19, 2024
8295d0b
Add comments
changqi1 Apr 19, 2024
93cbade
input 2 request
changqi1 Apr 22, 2024
dadc692
format code
changqi1 Apr 22, 2024
0d0c74b
modify promptID to sampleID.
changqi1 Apr 22, 2024
291bd28
rename filename
changqi1 Apr 22, 2024
6ffe206
format code
changqi1 Apr 22, 2024
6051c84
format code
changqi1 Apr 22, 2024
d4bd7dc
format code
changqi1 Apr 22, 2024
f7708af
update samplePool
changqi1 Apr 22, 2024
cf4ff41
update samplePool
changqi1 Apr 22, 2024
e96c1a8
update code
changqi1 Apr 22, 2024
dc1a9e5
sampleID to seqenceID
changqi1 Apr 22, 2024
171f451
update
changqi1 Apr 22, 2024
1df416e
udpate
changqi1 Apr 22, 2024
3b287df
udpate
changqi1 Apr 22, 2024
8d5dfe1
update
changqi1 Apr 22, 2024
d77e7c1
update
changqi1 Apr 22, 2024
fffd7b6
add PIPELINE_PARALLEL macro
changqi1 Apr 23, 2024
dcff843
Update pp inputs
changqi1 Apr 25, 2024
90cabf8
run good
changqi1 Apr 28, 2024
e9d5437
run good
changqi1 Apr 28, 2024
9e1b770
could run
changqi1 Apr 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/cpp/example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ int main(int argc, char **argv) {
}
auto result = model.finalize();

if (isMaster) {
if (true) {
std::cout << "\n[INFO] Final output is: " << std::endl;
std::vector<std::string> sent = tokenizer->batchDecode(result, batchSize);
for (auto str : sent) {
Expand Down
313 changes: 313 additions & 0 deletions src/common/sequence.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
// Copyright (c) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#pragma once

#include <cstdint>
#include <queue>
#include <unordered_map>

/*
SequencePool
┌──────┬──────┬──────┐
│ │ │ ◄───┼──┬─ SequenceMeta
├──────┼──────┼──────┤ │
BatchInputs │ │ │ ◄───┼──┘
│ └▲─┬─▲─┴──────┴──────┘
│ │ │ └───────────────────────────────────┐
▼ ┌──┬──┬──┬──┐ │ │ ┌──┬──┬──┬──┬──┬──┬──┬──┬──┐ │
Input ─►│ │ │ │ ├──┘ └─────►│ │ │ │ │ │ │ │ │ ├─┐ │
└──┴──┴──┴──┘ └──┴──┴──┴──┴──┴──┴──┴──┴──┘ │ │
InputQueue TaskWaitingQueue0 │ │
┌───────────────────────────────┘ │
│ ┌──┬──┬──┬──┬──┬──┬──┬──┬──┐ │
└─►│ │ │ │ │ │ │ │ │ ├───┘
└──┴──┴──┴──┴──┴──┴──┴──┴──┘
TaskWaitingQueue1
*/

namespace xft {

// The SequenceMeta is one sequence of batch inputs and includes the generated tokens.
class SequenceMeta {
public:
SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen, std::vector<int32_t> &_inputTokens)
: sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), pastSeqLen(0), step(0) {
inputTokens.resize(_inputSeqLen);
inputTokens.assign(_inputTokens.begin(), _inputTokens.end());
nextTokens.resize(_inputSeqLen);
setPastSeqLen(getPastSeqLen());
}

SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen)
: sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), inputTokens(_inputSeqLen, 0), pastSeqLen(0), step(0) {
nextTokens.resize(_inputSeqLen);
}

~SequenceMeta() {}

int32_t getSequenceID() const { return sequenceID; }

// For first tokens
void stepForward() {
if (getStep() == 0) {
setPastSeqLen(inputTokens.size());
setStep(getStep() + 1);
}
}

// For next token
void stepForward(int32_t token) {
// addNextToken(token);
setPastSeqLen(getPastSeqLen() + 1);
setStep(getStep() + 1);
}

// Get the input tokens in sequence
int32_t getInputSeqLen() const { return inputSeqLen; }

const int32_t *getInputTokens() const { return inputTokens.data(); }

int32_t getPastSeqLen() const { return pastSeqLen; }

void setPastSeqLen(int32_t _pastSeqLen) { pastSeqLen = _pastSeqLen; }

// For next tokens
void addNextToken(int32_t token) {
nextTokens.clear();
nextTokens.push_back(token);
inputTokens.push_back(token);
}

int32_t getLatestToken() const { return nextTokens.back(); }

const int32_t *getTotalTokens() const { return getInputTokens(); }

int32_t getStep() const { return step; }

void setStep(int32_t _step) { step = _step; }

private:
int32_t sequenceID;
int32_t inputSeqLen;
int32_t pastSeqLen;
std::vector<int32_t> inputTokens; // input tokens + next tokens
std::vector<int32_t> nextTokens; // next tokens
int32_t step;

#ifdef PIPELINE_PARALLEL
public:
template <typename T>
void allocBuffer(int32_t hiddenSize, void *_hiddenStates) {
hiddenStates = xft::alloc(sizeof(T) * getInputSeqLen() * hiddenSize);
memcpy(hiddenStates, _hiddenStates, sizeof(T) * getInputSeqLen() * hiddenSize);
}

int32_t getHiddenStatesSize() const { return hiddenStatesSize; }

void setHiddenStatesSize(int32_t _hiddenStatesSize) { hiddenStatesSize = _hiddenStatesSize; }

private:
int32_t hiddenSize;
int64_t hiddenStatesSize;
void *hiddenStates;
#endif
};

// For beam searcher
class SequenceGroupMeta {
public:
SequenceGroupMeta(int32_t _num_beams, std::vector<SequenceMeta *> &seq) {
num_beams = _num_beams;
sequences = seq;
}

private:
int32_t num_beams;
std::vector<SequenceMeta *> sequences;
};

// SequencePool
// ┌──────┬──────┬──────┐
// │ │ │ ◄───┼──┬─ SequenceMeta
// ├──────┼──────┼──────┤ │
// │ │ │ ◄───┼──┘
// └──────┴──────┴──────┘
class SequencePool {
public:
static SequencePool &getInstance() {
static SequencePool instance;
return instance;
}

int32_t createSequenceID() {
int32_t id = globalSequenceID++;
if (id >= 10 * 1024) {
globalSequenceID = 0;
id = globalSequenceID++;
}
return id;
}

SequenceMeta *createMeta(int32_t sequenceID, int32_t inputSeqLen, std::vector<int32_t> &inputTokens) {
auto *sequenceMeta = new SequenceMeta(sequenceID, inputSeqLen, inputTokens);
return sequenceMeta;
}

SequenceMeta *createMeta(int32_t sequenceID, int32_t inputSeqLen) {
auto *sequenceMeta = new SequenceMeta(sequenceID, inputSeqLen);
return sequenceMeta;
}

bool add(int32_t sequenceID, SequenceMeta *sequence, bool force = false) {
bool isSuccess = false;
if (force) {
auto it = hub.find(sequenceID);
if (it != hub.end()) { remove(it->first, true); }

hub[sequenceID] = sequence;
isSuccess = true;
} else {
bool exist = has(sequenceID);
if (!exist) {
hub[sequenceID] = sequence;
isSuccess = true;
}
}

return isSuccess;
}

bool has(int32_t sequenceID) const { return hub.find(sequenceID) != hub.end(); }

SequenceMeta *get(int32_t sequenceID) const {
auto it = hub.find(sequenceID);
if (it != hub.end()) {
return it->second;
} else {
return nullptr;
}
}

bool remove(int32_t sequenceID, bool deep = false) {
bool isSuccess = false;
if (has(sequenceID)) {
if (deep == true) {
auto it = hub.find(sequenceID);
if (it != hub.end()) { delete it->second; }
}
isSuccess = hub.erase(sequenceID);
}

return isSuccess;
}

bool replace(int32_t sequenceID, SequenceMeta *newSequenceMeta) {
bool isSuccess = false;
auto it = hub.find(sequenceID);
if (it != hub.end()) {
remove(it->first, true);
hub[sequenceID] = newSequenceMeta;
isSuccess = true;
}

return isSuccess;
}

void clear() {
for (auto &it : hub) {
delete it.second;
}
hub.clear();
globalSequenceID = 0;
}

private:
SequencePool() {}

int32_t globalSequenceID = 0;
std::unordered_map<int32_t, SequenceMeta *> hub;
};

// Manage input sequenceMeta
class InputQueue {
public:
static InputQueue &getInstance() {
static InputQueue instance;
return instance;
}

bool empty() { return queue.empty(); }

SequenceMeta *pop() {
auto seq = queue.front();
queue.pop();
return seq;
}

void push(SequenceMeta *seq) { queue.push(seq); }

void clear() {
while (!queue.empty()) {
queue.pop();
}
}

private:
InputQueue() {}

std::queue<SequenceMeta *> queue;
};

// Manage executive sequenceMeta
class TaskWaitingQueue {
public:
static TaskWaitingQueue &getInstance() {
static TaskWaitingQueue instance;
return instance;
}

bool empty() { return queue.empty(); }

int32_t size() { return queue.size(); }

bool isFull() {
bool full = false;
if (this->size() >= Env::getInstance().getMaxRequestNum()) { full = true; }
return full;
}

SequenceMeta *front() { return queue.front(); }

SequenceMeta *pop() {
auto seq = queue.front();
queue.pop();
return seq;
}

void push(SequenceMeta *seq) { queue.push(seq); }

void clear() {
while (!queue.empty()) {
queue.pop();
}
}

private:
TaskWaitingQueue() {}

std::queue<SequenceMeta *> queue;
};

} // namespace xft
6 changes: 5 additions & 1 deletion src/common/transformer_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ struct DecoderContext {
// For custom usage
int reserved1;

#ifdef PIPELINE_PARALLEL
int32_t sequenceID;
#endif

// Model structure configuration
int vocabSize;
int embeddingSize;
Expand Down Expand Up @@ -319,4 +323,4 @@ struct DecoderContext {
}

~DecoderContext() { free(this->rawBuffer); }
};
};
Loading