Skip to content

Commit

Permalink
[Sampling] Add greedy search for cb path. (#376)
Browse files Browse the repository at this point in the history
  • Loading branch information
Duyi-Wang committed May 8, 2024
1 parent e79c9c2 commit 9377f83
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 3 deletions.
13 changes: 11 additions & 2 deletions examples/cpp/example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,11 +444,20 @@ int main(int argc, char **argv) {

for (int i = 0; i < loop; ++i) {
secondIdCount = 0;
model.config(/*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1, /*lenPenalty*/ 1.0,

model.set_input(input, batchSize, /*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1,
/*lenPenalty*/ 1.0,
/*doEarlyStopping*/ false, /*eosTokenId*/ -1, /*padTokenId*/ -1,
/*doSample*/ doSample, /*temperature*/ temperature,
/*topK*/ topK, /*topP*/ topP, /*repetitionPenalty*/ repetitionPenalty);
model.input(input, batchSize);

// TODO: Deprecated
// Old Path
// model.config(/*maxLen*/ maxLen, /*numBeams*/ numBeams, /*numBeamHypsToKeep*/ 1, /*lenPenalty*/ 1.0,
// /*doEarlyStopping*/ false, /*eosTokenId*/ -1, /*padTokenId*/ -1,
// /*doSample*/ doSample, /*temperature*/ temperature,
// /*topK*/ topK, /*topP*/ topP, /*repetitionPenalty*/ repetitionPenalty);
// model.input(input, batchSize);

std::vector<int> firstIds;
std::vector<int> secondIds;
Expand Down
2 changes: 1 addition & 1 deletion include/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Model {

void setDecoder(AbstractDecoder *dec);

std::vector<int32_t> finalize() { return searcher->finalize(); }
std::vector<int32_t> finalize();

void exitSlaves();

Expand Down
2 changes: 2 additions & 0 deletions src/common/sequence.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class SequenceMeta {
}
}

int getTotalLen() const { return promptTokens.size() + generatedTokens.size(); }

int32_t getPastSeqLen() const { return pastSeqLen; }

void setPastSeqLen(int32_t _pastSeqLen) { pastSeqLen = _pastSeqLen; }
Expand Down
55 changes: 55 additions & 0 deletions src/models/models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include "opt_decoder.h"
#include "qwen.h"
#include "qwen2.h"
#include "sampling.h"
#include "search_utils.h"
#include "searcher.h"
#include "sequence.h"
#include "timeline.h"
Expand Down Expand Up @@ -249,6 +251,21 @@ bool Model::isDone() {
return true;
}

// Collect the final generated token IDs for the whole batch.
// The legacy path delegates to the searcher; the continuous-batching path
// concatenates each working group's full (prompt + generated) token sequence.
std::vector<int32_t> Model::finalize() {
    // TODO: Deprecate the following Path
    if (searcher != nullptr) { return searcher->finalize(); }

    // TODO: Unequal-length input & output
    std::vector<int32_t> output;
    for (auto *group : workingGroup) {
        const std::vector<int32_t> tokens = group->get(0)->getTotalTokens();
        output.insert(output.end(), tokens.begin(), tokens.end());
    }
    return output;
}

std::tuple<float *, int, int> Model::forward(bool logits_all) {
// TODO: Deprecate the following Path
if (searcher != nullptr) {
Expand All @@ -275,6 +292,8 @@ std::tuple<float *, int, int> Model::forward(bool logits_all) {
return decoder->forward(workingSeqs, logits_all);
}

// We assume all gen kwargs in the batch are the same
// and all sequences are all prompts(step==0) or all decodes(step>0)
std::vector<int32_t> Model::generate() {
// TODO: Deprecate the following Path
if (searcher != nullptr) {
Expand All @@ -294,11 +313,47 @@ std::vector<int32_t> Model::generate() {
float *outBuf = std::get<0>(result);
int sampleOffset = std::get<1>(result);
int sampleSize = std::get<2>(result);

// Assume all gen kwargs in the batch are the same
auto &config = workingGroup[0]->getSamplingMeta()->config;

if (config.numBeams != 1) {
// TODO: BeamSearch
throw std::logic_error("Beam Search Method not implemented");
} else {

// Logits processor
// Repetition penalty
if (config.repetitionPenalty != 1.0) {
repetitionPenaltyLogitsProcess(outBuf, sampleOffset, sampleSize, workingGroup);
}

std::vector<int> result;

if (config.doSample) {
//TODO: samling
throw std::logic_error("Sampling Method not implemented");
} else {
// Greedy search
result = greedySearch(outBuf, sampleOffset, sampleSize, batchSize);
}

// Check stop status
stopCheck(result, workingGroup);

// Step forward on all seqs
for (int i = 0; i < workingGroup.size(); i++) {
workingGroup[i]->get(0)->stepForward(result[i]);
}

return result;
}
throw std::logic_error("Method not implemented");
return {};
}
}

// TODO: Deprecate the following function
void Model::createSearcher(SearcherConfig &config_) {
if (searcher != nullptr) { delete searcher; }

Expand Down
155 changes: 155 additions & 0 deletions src/searchers/sampling.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright (c) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#include "sampling.h"

#include <omp.h>

#include <limits>
#include <vector>

#include "timeline.h"

namespace xft {
// Greedy (argmax) token selection over the raw logits of one decoding step.
// Assume all samples have the same sampling params.
//
// logits:       [batchSize x sampleSize] row-major scores for this rank's
//               vocabulary slice.
// sampleOffset: global vocabulary offset of this rank's slice; used to turn a
//               local argmax into a global token ID for the multi-rank reduce.
// sampleSize:   number of vocabulary entries held by this rank.
// batchSize:    number of samples (rows) in logits.
//
// Returns one selected token ID per sample.
std::vector<int> greedySearch(float *logits, int sampleOffset, int sampleSize, int batchSize) {
    TimeLine t("GreedySearch");

    Messenger &messenger = Messenger::getInstance();
    int numThreads = 0;
#pragma omp parallel
    {
        int tid = omp_get_thread_num();
        if (tid == 0) { numThreads = omp_get_num_threads(); }
    }

    auto msgerSize = messenger.getSize();

    // Max ID and value for each sample.
    // (std::vector instead of runtime-sized arrays: VLAs are a non-standard
    // C++ extension and can overflow the stack for large batches.)
    std::vector<int> maxIds(batchSize);
    std::vector<float> maxVals(batchSize);

    // Small batch size (each sample can have at least 2 threads)
    if (numThreads / batchSize >= 2) {
        int thrPerSample = numThreads / batchSize;
        int sizePerThr = (sampleSize + thrPerSample - 1) / thrPerSample;
        std::vector<int> maxIndices(batchSize * thrPerSample);
        std::vector<float> maxValues(batchSize * thrPerSample);

#pragma omp parallel for collapse(2)
        for (int b = 0; b < batchSize; ++b) {
            for (int t = 0; t < thrPerSample; ++t) { // thread index inside the sample
                int start = t * sizePerThr;
                int end = (start + sizePerThr) > sampleSize ? sampleSize : (start + sizePerThr);

                // When sampleSize is small, trailing partitions can be empty;
                // store a sentinel so the reduction below ignores them instead
                // of reading logits out of bounds.
                if (start >= end) {
                    maxIndices[b * thrPerSample + t] = 0;
                    maxValues[b * thrPerSample + t] = std::numeric_limits<float>::lowest();
                    continue;
                }

                float *p = logits + b * sampleSize;

                int maxIdx = start;
                float maxVal = p[start];
                for (int off = start + 1; off < end; ++off) {
                    if (p[off] > maxVal) {
                        maxVal = p[off];
                        maxIdx = off;
                    }
                }

                // False sharing happens, but since only one time, not avoided
                maxIndices[b * thrPerSample + t] = maxIdx;
                maxValues[b * thrPerSample + t] = maxVal;
            }
        }

        // Local reduction: best of each sample's per-thread partial maxima
        for (int i = 0; i < batchSize; ++i) {
            int *pIndices = maxIndices.data() + i * thrPerSample;
            float *pValues = maxValues.data() + i * thrPerSample;
            int maxIdx = pIndices[0];
            float maxVal = pValues[0];
            for (int j = 1; j < thrPerSample; ++j) {
                if (pValues[j] > maxVal) {
                    maxVal = pValues[j];
                    maxIdx = pIndices[j];
                }
            }
            maxIds[i] = maxIdx;
            maxVals[i] = maxVal;
        }
    }

    // Each thread handles one sample (one row)
    else {
#pragma omp parallel for
        for (int i = 0; i < batchSize; ++i) {
            int maxId = 0;
            float *p = logits + i * sampleSize;
            float maxVal = p[0];
            for (int j = 1; j < sampleSize; ++j) {
                if (p[j] > maxVal) {
                    maxVal = p[j];
                    maxId = j;
                }
            }
            maxIds[i] = maxId;
            maxVals[i] = maxVal;
        }
    }

    // Multi-rank reduce: each rank found the best token of its own vocab
    // slice; gather all (globalId, value) pairs and keep the overall winner.
    // NOTE(review): IDs are round-tripped through float, exact only while the
    // global vocab fits in float's 24-bit mantissa — holds for typical vocabs.
    if (msgerSize > 1) {
        std::vector<float> sendBuf(2 * batchSize);
        std::vector<float> recvBuf(2 * batchSize * msgerSize);

        for (int i = 0; i < batchSize; ++i) {
            sendBuf[2 * i] = (float)(maxIds[i] + sampleOffset);
            sendBuf[2 * i + 1] = maxVals[i];
        }

        std::vector<long unsigned int> recvCount(msgerSize, static_cast<long unsigned int>(2 * batchSize));
        messenger.allgatherv(sendBuf.data(), 2 * batchSize, recvBuf.data(), recvCount);

        for (int i = 0; i < batchSize; ++i) {
            int maxId = (int)(recvBuf[2 * i] + 0.5f);
            float maxVal = recvBuf[2 * i + 1];
            for (int j = 1; j < msgerSize; ++j) {
                if (recvBuf[2 * j * batchSize + 2 * i + 1] > maxVal) {
                    maxVal = recvBuf[2 * j * batchSize + 2 * i + 1];
                    maxId = (int)(recvBuf[2 * j * batchSize + 2 * i] + 0.5f);
                }
            }
            maxIds[i] = maxId;
        }
    }

    return maxIds;
}

// Per-sample stop-state update for greedy search and sampling (beam search
// keeps its own bookkeeping and does not use this).
//
// For every sample: an already-finished sequence has its slot padded with the
// EOS token; otherwise the sequence is marked done when it just produced EOS
// or when appending this token reaches the configured max length.
void stopCheck(std::vector<int> &generatedIds, std::vector<SequenceGroupMeta *> &seqGroups) {
    int batchSize = generatedIds.size();
#pragma omp parallel for
    for (int b = 0; b < batchSize; b++) {
        auto *samplingMeta = seqGroups[b]->getSamplingMeta();

        // TODO: Deprecate this check, since no need for unequal-length output
        if (samplingMeta->done) {
            generatedIds[b] = samplingMeta->config.eosTokenId;
            continue;
        }

        if (generatedIds[b] == samplingMeta->config.eosTokenId) {
            // The newly generated token is EOS: mark the sequence as done.
            samplingMeta->done = true;
        } else if (seqGroups[b]->get(0)->getTotalLen() + 1 >= samplingMeta->config.maxLen) {
            // Appending this token hits the max length: mark as done too.
            samplingMeta->done = true;
        }
        // TODO: stop words check
    }
}
} // namespace xft
25 changes: 25 additions & 0 deletions src/searchers/sampling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#pragma once

#include "messenger.h"
#include "sampling_params.h"
#include "sequence.h"

namespace xft {
// Pick the argmax token per sample from a [batchSize x sampleSize] logits
// buffer; sampleOffset is this rank's vocabulary offset (multi-rank runs
// reduce to a global winner internally). Returns one token ID per sample.
std::vector<int> greedySearch(float *logits, int sampleOffset, int sampleSize, int batchSize);

// Update per-sequence done flags after one generation step (greedy search /
// sampling only, not beam search); finished sequences have their generated
// ID replaced with the EOS token.
void stopCheck(std::vector<int> &generatedIds, std::vector<SequenceGroupMeta *> &seqGroups);
} // namespace xft
2 changes: 2 additions & 0 deletions src/searchers/search_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <vector>
#include "messenger.h"
#include "search_utils.h"
#include "timeline.h"

// Insert an element into a sorted vector while maintaining the order
void insertAndSort(std::vector<int> &targetVector, int num) {
Expand Down Expand Up @@ -102,6 +103,7 @@ namespace xft {
// TODO: support num_beams > 1 (beam search)
void repetitionPenaltyLogitsProcess(
float *logits, int sampleOffset, int sampleSize, std::vector<SequenceGroupMeta *> &seqGroups) {
TimeLine t("RepetitionPenaltyLogitsProcess");
bool multiRank = Messenger::getInstance().getSize() > 1;

std::vector<int> groupIndex;
Expand Down

0 comments on commit 9377f83

Please sign in to comment.