Skip to content

Commit

Permalink
Sanitize Request constructor and document properly
Browse files Browse the repository at this point in the history
Towards #77.

Reduces Request to require only (Id, Segments, ResponseBuilder). Request is
an internal structure for Service and *will not be* exposed. However, the
constructor was in fact dreadful and can now be cleaned up, since
ResponseBuilder nicely abstracts away the transition process from
Request -> Response.
  • Loading branch information
Jerin Philip committed Apr 2, 2021
1 parent cb48c79 commit 62d022c
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 50 deletions.
43 changes: 26 additions & 17 deletions src/translator/batch_translator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,27 @@ namespace bergamot {
BatchTranslator::BatchTranslator(DeviceId const device,
std::vector<Ptr<Vocab const>> &vocabs,
Ptr<Options> options,
const AlignedMemory* modelMemory,
const AlignedMemory* shortlistMemory)
const AlignedMemory *modelMemory,
const AlignedMemory *shortlistMemory)
: device_(device), options_(options), vocabs_(&vocabs),
modelMemory_(modelMemory), shortlistMemory_(shortlistMemory) {}
modelMemory_(modelMemory), shortlistMemory_(shortlistMemory) {}

void BatchTranslator::initialize() {
// Initializes the graph.
if (options_->hasAndNotEmpty("shortlist")) {
int srcIdx = 0, trgIdx = 1;
bool shared_vcb = vocabs_->front() == vocabs_->back();
if (shortlistMemory_->size() > 0 && shortlistMemory_->begin() != nullptr) {
bool check = options_->get<bool>("check-bytearray",true);
slgen_ = New<data::BinaryShortlistGenerator>(shortlistMemory_->begin(), shortlistMemory_->size(),
vocabs_->front(), vocabs_->back(),
srcIdx, trgIdx, shared_vcb, check);
}
else {
// Changed to BinaryShortlistGenerator to enable loading binary shortlist file
// This class also supports text shortlist file
bool check = options_->get<bool>("check-bytearray", true);
slgen_ = New<data::BinaryShortlistGenerator>(
shortlistMemory_->begin(), shortlistMemory_->size(), vocabs_->front(),
vocabs_->back(), srcIdx, trgIdx, shared_vcb, check);
} else {
// Changed to BinaryShortlistGenerator to enable loading a binary shortlist
// file. This class also supports text shortlist files.
slgen_ = New<data::BinaryShortlistGenerator>(options_, vocabs_->front(),
vocabs_->back(), srcIdx,
trgIdx, shared_vcb);
vocabs_->back(), srcIdx,
trgIdx, shared_vcb);
}
}

Expand All @@ -42,10 +41,18 @@ void BatchTranslator::initialize() {
graph_->setDevice(device_);
graph_->getBackend()->configureDevice(options_);
graph_->reserveWorkspaceMB(options_->get<size_t>("workspace"));
if (modelMemory_->size() > 0 && modelMemory_->begin() != nullptr) { // If we have provided a byte array that contains the model memory, we can initialise the model from there, as opposed to from reading in the config file
if (modelMemory_->size() > 0 &&
modelMemory_->begin() !=
nullptr) { // If we have provided a byte array that contains the model
// memory, we can initialise the model from there, as
// opposed to from reading in the config file
ABORT_IF((uintptr_t)modelMemory_->begin() % 256 != 0,
"The provided memory is not aligned to 256 bytes and will crash when vector instructions are used on it.");
const std::vector<const void *> container = {modelMemory_->begin()}; // Marian supports multiple models initialised in this manner hence std::vector. However we will only ever use 1 during decoding.
"The provided memory is not aligned to 256 bytes and will crash "
"when vector instructions are used on it.");
const std::vector<const void *> container = {
modelMemory_->begin()}; // Marian supports multiple models initialised
// in this manner hence std::vector. However we
// will only ever use 1 during decoding.
scorers_ = createScorers(options_, container);
} else {
scorers_ = createScorers(options_);
Expand All @@ -63,11 +70,13 @@ void BatchTranslator::translate(Batch &batch) {
std::vector<data::SentenceTuple> batchVector;

auto &sentences = batch.sentences();
size_t batchSequenceNumber{0};
for (auto &sentence : sentences) {
data::SentenceTuple sentence_tuple(sentence.lineNumber());
data::SentenceTuple sentence_tuple(batchSequenceNumber);
Segment segment = sentence.getUnderlyingSegment();
sentence_tuple.push_back(segment);
batchVector.push_back(sentence_tuple);
++batchSequenceNumber;
}

size_t batchSize = batchVector.size();
Expand Down
10 changes: 2 additions & 8 deletions src/translator/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ namespace marian {
namespace bergamot {

// -----------------------------------------------------------------
Request::Request(size_t Id, size_t lineNumberBegin, Segments &&segments,
Request::Request(size_t Id, Segments &&segments,
ResponseBuilder &&responseBuilder)
: Id_(Id), lineNumberBegin_(lineNumberBegin),
segments_(std::move(segments)),
: Id_(Id), segments_(std::move(segments)),
responseBuilder_(std::move(responseBuilder))

{
Expand All @@ -23,7 +22,6 @@ Request::Request(size_t Id, size_t lineNumberBegin, Segments &&segments,
histories_.resize(segments_.size(), nullptr);
}

size_t Request::lineNumberBegin() const { return lineNumberBegin_; }
size_t Request::numSegments() const { return segments_.size(); }

size_t Request::segmentTokens(size_t index) const {
Expand Down Expand Up @@ -58,10 +56,6 @@ size_t RequestSentence::numTokens() const {
return (request_->segmentTokens(index_));
}

size_t RequestSentence::lineNumber() const {
return (request_->lineNumberBegin() + index_);
}

void RequestSentence::completeSentence(Ptr<History> history) {
// Relays completeSentence into request's processHistory, using index
// information.
Expand Down
46 changes: 27 additions & 19 deletions src/translator/request.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,3 @@
//
// Defines:
//
// Request: holds the input text of a text, Segments (vector<Words>) which are
// to go to the batching mechanism and alignments between the processed
// segments and the input text (sourceTokenRanges). In addition, Request takes
// care of the barrier which fires when all the Segments in a request are done
// translating by the workers (BatchTranslator).
// TODO(jerinphilip): Extend Request with notions of Priority (sequence,
// user-given).
//
// RequestSentence: is a tuple of (index, Ptr<Request>). This provides the
// batching mechanism access to the segment within the request. The backref to
// Request allows event triggering the barrier upon completion of the last
// sentence by a worker.

#ifndef SRC_BERGAMOT_REQUEST_H_
#define SRC_BERGAMOT_REQUEST_H_

Expand All @@ -34,18 +18,42 @@
namespace marian {
namespace bergamot {

/// A Request is the internal representation of a request after it has been
/// processed by TextProcessor into sentences constituted by marian::Words.
///
/// The batching mechanism (Batcher) draws from multiple Requests and compiles
/// sentences into a batch. When a batch completes translation (at
/// BatchTranslator, intended in a different thread), backward propagation
/// happens through:
///
/// Batch::completeBatch(...)
/// -> RequestSentence::completeSentence(...)
/// -> Request::processHistory(...)
///
/// When all sentences in a Request are completed, responseBuilder is
/// triggered with the compiled Histories, to construct the Response
/// corresponding to the Request and set value of the promise which triggers the
/// future at Client.
class Request {
public:
Request(size_t Id, size_t lineNumberBegin, Segments &&segments,
ResponseBuilder &&responseBuilder);
/// Constructs an internal representation of the Request identified by Id,
/// processed Segments and accepts a callback (ResponseBuilder) which builds
/// the Response upon completion of the Request.
///
///
/// @param [in] Id: Identifier assigned to Request by Service.
/// @param [in] segments: Each segment is a unit to be translated.
/// @param [in] responseBuilder: Callback function (of ResponseBuilder type)
/// to be triggered upon the completion of translation of all units in a
/// Request.
Request(size_t Id, Segments &&segments, ResponseBuilder &&responseBuilder);

// Obtain the count of tokens in the segment corresponding to index. Used to
// insert sentences from multiple requests into the corresponding size bucket.
size_t segmentTokens(size_t index) const;

// Obtain number of segments in a request.
size_t numSegments() const;
size_t lineNumberBegin() const;

// Obtains segment corresponding to index to create a batch of segments among
// several requests.
Expand Down
14 changes: 8 additions & 6 deletions src/translator/service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ loadVocabularies(marian::Ptr<marian::Options> options) {
namespace marian {
namespace bergamot {

Service::Service(Ptr<Options> options, AlignedMemory modelMemory, AlignedMemory shortlistMemory)
Service::Service(Ptr<Options> options, AlignedMemory modelMemory,
AlignedMemory shortlistMemory)
: requestId_(0), vocabs_(std::move(loadVocabularies(options))),
text_processor_(vocabs_, options), batcher_(options),
numWorkers_(options->get<int>("cpu-threads")),
modelMemory_(std::move(modelMemory)), shortlistMemory_(std::move(shortlistMemory))
modelMemory_(std::move(modelMemory)),
shortlistMemory_(std::move(shortlistMemory))
#ifndef WASM_COMPATIBLE_SOURCE
// 0 elements in PCQueue is illegal and can lead to failures. Adding a
// guard to have at least one entry allocated. In the single-threaded
Expand All @@ -55,7 +57,8 @@ void Service::build_translators(Ptr<Options> options, size_t numTranslators) {
translators_.reserve(numTranslators);
for (size_t cpuId = 0; cpuId < numTranslators; cpuId++) {
marian::DeviceId deviceId(cpuId, DeviceType::cpu);
translators_.emplace_back(deviceId, vocabs_, options, &modelMemory_, &shortlistMemory_);
translators_.emplace_back(deviceId, vocabs_, options, &modelMemory_,
&shortlistMemory_);
}
}

Expand Down Expand Up @@ -122,9 +125,8 @@ std::future<Response> Service::translate(std::string &&input) {
RequestParams requestParams; // TODO(jerinphilip): Take this in as argument
ResponseBuilder responseBuilder(requestParams, std::move(source), vocabs_,
std::move(responsePromise));
Ptr<Request> request =
New<Request>(requestId_++, /* lineNumberBegin = */ 0, std::move(segments),
std::move(responseBuilder));
Ptr<Request> request = New<Request>(requestId_++, std::move(segments),
std::move(responseBuilder));

batcher_.addWholeRequest(request);
if (numWorkers_ == 0) {
Expand Down

0 comments on commit 62d022c

Please sign in to comment.