Skip to content

Commit

Permalink
max-length per sentence in batch
Browse files Browse the repository at this point in the history
  • Loading branch information
hieuhoang committed Jun 8, 2018
1 parent 7c41702 commit f429d4a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/amun/common/history.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class History {
unsigned GetLineNum() const
{ return lineNo_; }

unsigned GetMaxLength() const
{ return maxLength_; }

void SetActive(bool active);
bool GetActive() const;

Expand Down
12 changes: 9 additions & 3 deletions src/amun/common/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ std::shared_ptr<Histories> Search::Translate(const Sentences& sentences) {
}
//cerr << "beamSizes=" << Debug(beamSizes, 1) << endl;

bool hasSurvivors = CalcBeam(histories, beamSizes, prevHyps, states, nextStates);
bool hasSurvivors = CalcBeam(histories, beamSizes, prevHyps, states, nextStates, decoderStep);
if (!hasSurvivors) {
break;
}
Expand Down Expand Up @@ -134,18 +134,24 @@ bool Search::CalcBeam(
std::vector<unsigned>& beamSizes,
Beam& prevHyps,
States& states,
States& nextStates)
States& nextStates,
unsigned decoderStep)
{
unsigned batchSize = beamSizes.size();
Beams beams(batchSize);
bestHyps_->CalcBeam(prevHyps, scorers_, filterIndices_, beams, beamSizes);
histories->Add(beams);

//cerr << "batchSize=" << batchSize << endl;
histories->SetActive(false);
Beam survivors;
for (unsigned batchId = 0; batchId < batchSize; ++batchId) {
const History &hist = *histories->at(batchId);
unsigned maxLength = hist.GetMaxLength();

//cerr << "beamSizes[batchId]=" << batchId << " " << beamSizes[batchId] << " " << maxLength << endl;
for (auto& h : beams[batchId]) {
if (h->GetWord() != EOS_ID) {
if (decoderStep < maxLength && h->GetWord() != EOS_ID) {
survivors.push_back(h);

histories->SetActive(batchId, true);
Expand Down
3 changes: 2 additions & 1 deletion src/amun/common/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class Search {
std::vector<unsigned>& beamSizes,
Beam& prevHyps,
States& states,
States& nextStates);
States& nextStates,
unsigned decoderStep);

Search(const Search&) = delete;

Expand Down

0 comments on commit f429d4a

Please sign in to comment.