Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge branch 'master' of github.com:kpu/lazy

Conflicts:
	alone/threading.cc
  • Loading branch information...
commit 8a4a19df6fcd8679df88e2a6d3e7d88b2fb6e362 2 parents fbb0aee + 90ea4d9
@kpu authored
View
2  Jamroot
@@ -1,6 +1,6 @@
path-constant TOP : . ;
include $(TOP)/sanity.jam ;
-boost 104000 ;
+boost 104200 ;
external-lib z ;
project : requirements $(requirements) <threading>multi:<define>WITH_THREADS ;
View
21 alone/assemble.cc
@@ -3,19 +3,32 @@
#include "search/final.hh"
#include "search/rule.hh"
+#include <iostream>
+
namespace alone {
std::ostream &operator<<(std::ostream &o, const search::Final &final) {
const search::Rule::ItemsRet &words = final.From().Items();
+ if (words.empty()) return o;
const search::Final *const *child = final.Children().data();
- for (search::Rule::ItemsRet::const_iterator i(words.begin()); i != words.end(); ++i) {
+ search::Rule::ItemsRet::const_iterator i(words.begin());
+ for (; i != words.end() - 1; ++i) {
if (i->Terminal()) {
o << i->String() << ' ';
} else {
- o << **child;
+ o << **child << ' ';
++child;
}
}
+
+ if (i->Terminal()) {
+ if (i->String() != "</s>") {
+ o << i->String();
+ }
+ } else {
+ o << **child;
+ }
+
return o;
}
@@ -56,4 +69,8 @@ void DetailedFinal(std::ostream &o, const search::Final &final, const char *inde
DetailedFinalInternal(o, final, indent_str, 0);
}
+void PrintFinal(const search::Final &final) {
+ std::cout << final << std::endl;
+}
+
} // namespace alone
View
3  alone/assemble.hh
@@ -13,6 +13,9 @@ std::ostream &operator<<(std::ostream &o, const search::Final &final);
void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = " ");
+// This isn't called anywhere but makes it easy to print from gdb.
+void PrintFinal(const search::Final &final);
+
} // namespace alone
#endif // ALONE_ASSEMBLE__
View
1  alone/graph.hh
@@ -5,6 +5,7 @@
#include "search/rule.hh"
#include "search/types.hh"
#include "search/vertex.hh"
+#include "util/exception.hh"
#include <boost/noncopyable.hpp>
#include <boost/pool/object_pool.hpp>
View
1  alone/threading.cc
@@ -9,6 +9,7 @@
#include <boost/ref.hpp>
#include <boost/scoped_ptr.hpp>
+#include <boost/utility/in_place_factory.hpp>
#include <sstream>
View
1  lm/Jamfile
@@ -4,6 +4,7 @@ import testing ;
run left_test.cc lm ..//boost_unit_test_framework : : test.arpa ;
run model_test.cc lm ..//boost_unit_test_framework : : test.arpa test_nounk.arpa ;
+run partial_test.cc lm ..//boost_unit_test_framework : : test.arpa ;
exe query : ngram_query.cc lm ;
exe build_binary : build_binary.cc lm ;
View
10 lm/left.hh
@@ -100,10 +100,6 @@ template <class M> class RuleScore {
out_.right = in.right;
if (left_done_) {
prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
- return;
- }
- if (out_.left.length) {
- left_done_ = true;
} else {
out_.left = in.left;
left_done_ = in.left.full;
@@ -131,12 +127,6 @@ template <class M> class RuleScore {
return;
}
- // Right state was minimized, so it's already independent of the new words to the left.
- if (in.right.length < in.left.length) {
- out_.right = in.right;
- return;
- }
-
// Shift exisiting words down.
for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) {
*(i + in.right.length) = *i;
View
167 lm/partial.hh
@@ -0,0 +1,167 @@
+#ifndef LM_PARTIAL__
+#define LM_PARTIAL__
+
+#include "lm/return.hh"
+#include "lm/state.hh"
+
+#include <algorithm>
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+
+struct ExtendReturn {
+ float adjust;
+ bool make_full;
+ unsigned char next_use;
+};
+
+template <class Model> ExtendReturn ExtendLoop(
+ const Model &model,
+ unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start,
+ const uint64_t *pointers, const uint64_t *pointers_end,
+ uint64_t *&pointers_write,
+ float *backoff_write) {
+ unsigned char add_length = add_rend - add_rbegin;
+
+ float backoff_buf[2][kMaxOrder - 1];
+ float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1];
+ std::copy(backoff_start, backoff_start + add_length, backoff_in);
+
+ ExtendReturn value;
+ value.make_full = false;
+ value.adjust = 0.0;
+ value.next_use = add_length;
+
+ unsigned char i = 0;
+ unsigned char length = pointers_end - pointers;
+ // pointers_write is NULL means that the existing left state is full, so we should use completed probabilities.
+ if (pointers_write) {
+ // Using full context, writing to new left state.
+ for (; i < length; ++i) {
+ FullScoreReturn ret(model.ExtendLeft(
+ add_rbegin, add_rbegin + value.next_use,
+ backoff_in,
+ pointers[i], i + seen + 1,
+ backoff_out,
+ value.next_use));
+ std::swap(backoff_in, backoff_out);
+ if (ret.independent_left) {
+ value.adjust += ret.prob;
+ value.make_full = true;
+ ++i;
+ break;
+ }
+ value.adjust += ret.rest;
+ *pointers_write++ = ret.extend_left;
+ if (value.next_use != add_length) {
+ value.make_full = true;
+ ++i;
+ break;
+ }
+ }
+ }
+ // Using some of the new context.
+ for (; i < length && value.next_use; ++i) {
+ FullScoreReturn ret(model.ExtendLeft(
+ add_rbegin, add_rbegin + value.next_use,
+ backoff_in,
+ pointers[i], i + seen + 1,
+ backoff_out,
+ value.next_use));
+ std::swap(backoff_in, backoff_out);
+ value.adjust += ret.prob;
+ }
+ float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1);
+ // Using none of the new context.
+ value.adjust += unrest;
+
+ std::copy(backoff_in, backoff_in + value.next_use, backoff_write);
+ return value;
+}
+
+template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) {
+ assert(seen < reveal.length || reveal_full);
+ uint64_t *pointers_write = reveal_full ? NULL : left.pointers;
+ float backoff_buffer[kMaxOrder - 1];
+ ExtendReturn value(ExtendLoop(
+ model,
+ seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen,
+ left.pointers, left.pointers + left.length,
+ pointers_write,
+ left.full ? backoff_buffer : (right.backoff + right.length)));
+ if (reveal_full) {
+ left.length = 0;
+ value.make_full = true;
+ } else {
+ left.length = pointers_write - left.pointers;
+ value.make_full |= (left.length == model.Order() - 1);
+ }
+ if (left.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
+ } else {
+ // If left wasn't full when it came in, put words into right state.
+ std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length);
+ right.length += value.next_use;
+ left.full = value.make_full || (right.length == model.Order() - 1);
+ }
+ return value.adjust;
+}
+
+template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) {
+ assert(seen < reveal.length || reveal.full);
+ uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length);
+ ExtendReturn value(ExtendLoop(
+ model,
+ seen, right.words, right.words + right.length, right.backoff,
+ reveal.pointers + seen, reveal.pointers + reveal.length,
+ pointers_write,
+ right.backoff));
+ if (reveal.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i];
+ right.length = 0;
+ value.make_full = true;
+ } else {
+ right.length = value.next_use;
+ value.make_full |= (right.length == model.Order() - 1);
+ }
+ if (!left.full) {
+ left.length = pointers_write - left.pointers;
+ left.full = value.make_full || (left.length == model.Order() - 1);
+ }
+ return value.adjust;
+}
+
+template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) {
+ assert(first_right.length < kMaxOrder);
+ assert(second_left.length < kMaxOrder);
+ assert(between_length < kMaxOrder - 1);
+ uint64_t *pointers_write = first_left.full ? NULL : (first_left.pointers + first_left.length);
+ float backoff_buffer[kMaxOrder - 1];
+ ExtendReturn value(ExtendLoop(
+ model,
+ between_length, first_right.words, first_right.words + first_right.length, first_right.backoff,
+ second_left.pointers, second_left.pointers + second_left.length,
+ pointers_write,
+ second_left.full ? backoff_buffer : (second_right.backoff + second_right.length)));
+ if (second_left.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
+ } else {
+ std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length);
+ second_right.length += value.next_use;
+ value.make_full |= (second_right.length == model.Order() - 1);
+ }
+ if (!first_left.full) {
+ first_left.length = pointers_write - first_left.pointers;
+ first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1);
+ }
+ assert(first_left.length < kMaxOrder);
+ assert(second_right.length < kMaxOrder);
+ return value.adjust;
+}
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_PARTIAL__
View
199 lm/partial_test.cc
@@ -0,0 +1,199 @@
+#include "lm/partial.hh"
+
+#include "lm/left.hh"
+#include "lm/model.hh"
+#include "util/tokenize_piece.hh"
+
+#define BOOST_TEST_MODULE PartialTest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+namespace lm {
+namespace ngram {
+namespace {
+
+const char *TestLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ return "test.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[1];
+}
+
+Config SilentConfig() {
+ Config config;
+ config.arpa_complain = Config::NONE;
+ config.messages = NULL;
+ return config;
+}
+
+struct ModelFixture {
+ ModelFixture() : m(TestLocation(), SilentConfig()) {}
+
+ RestProbingModel m;
+};
+
+BOOST_FIXTURE_TEST_SUITE(suite, ModelFixture)
+
+BOOST_AUTO_TEST_CASE(SimpleBefore) {
+ Left left;
+ left.full = false;
+ left.length = 0;
+ Right right;
+ right.length = 0;
+
+ Right reveal;
+ reveal.length = 1;
+ WordIndex period = m.GetVocabulary().Index(".");
+ reveal.words[0] = period;
+ reveal.backoff[0] = -0.845098;
+
+ BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 0, false, left, right), 0.001);
+ BOOST_CHECK_EQUAL(0, left.length);
+ BOOST_CHECK(!left.full);
+ BOOST_CHECK_EQUAL(1, right.length);
+ BOOST_CHECK_EQUAL(period, right.words[0]);
+ BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
+
+ WordIndex more = m.GetVocabulary().Index("more");
+ reveal.words[1] = more;
+ reveal.backoff[1] = -0.4771212;
+ reveal.length = 2;
+ BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 1, false, left, right), 0.001);
+ BOOST_CHECK_EQUAL(0, left.length);
+ BOOST_CHECK(!left.full);
+ BOOST_CHECK_EQUAL(2, right.length);
+ BOOST_CHECK_EQUAL(period, right.words[0]);
+ BOOST_CHECK_EQUAL(more, right.words[1]);
+ BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
+ BOOST_CHECK_CLOSE(-0.4771212, right.backoff[1], 0.001);
+}
+
+BOOST_AUTO_TEST_CASE(AlsoWouldConsider) {
+ WordIndex would = m.GetVocabulary().Index("would");
+ WordIndex consider = m.GetVocabulary().Index("consider");
+
+ ChartState current;
+ current.left.length = 1;
+ current.left.pointers[0] = would;
+ current.left.full = false;
+ current.right.length = 1;
+ current.right.words[0] = would;
+ current.right.backoff[0] = -0.30103;
+
+ Left after;
+ after.full = false;
+ after.length = 1;
+ after.pointers[0] = consider;
+
+ // adjustment for would consider
+ BOOST_CHECK_CLOSE(-1.687872 - -0.2922095 - 0.30103, RevealAfter(m, current.left, current.right, after, 0), 0.001);
+
+ BOOST_CHECK_EQUAL(2, current.left.length);
+ BOOST_CHECK_EQUAL(would, current.left.pointers[0]);
+ BOOST_CHECK_EQUAL(false, current.left.full);
+
+ WordIndex also = m.GetVocabulary().Index("also");
+ Right before;
+ before.length = 1;
+ before.words[0] = also;
+ before.backoff[0] = -0.30103;
+ // r(would) = -0.2922095 [i would], r(would -> consider) = -1.988902 [b(would) + p(consider)]
+ // p(also -> would) = -2, p(also would -> consider) = -3
+ BOOST_CHECK_CLOSE(-2 + 0.2922095 -3 + 1.988902, RevealBefore(m, before, 0, false, current.left, current.right), 0.001);
+ BOOST_CHECK_EQUAL(0, current.left.length);
+ BOOST_CHECK(current.left.full);
+ BOOST_CHECK_EQUAL(2, current.right.length);
+ BOOST_CHECK_EQUAL(would, current.right.words[0]);
+ BOOST_CHECK_EQUAL(also, current.right.words[1]);
+}
+
+BOOST_AUTO_TEST_CASE(EndSentence) {
+ WordIndex loin = m.GetVocabulary().Index("loin");
+ WordIndex period = m.GetVocabulary().Index(".");
+ WordIndex eos = m.GetVocabulary().EndSentence();
+
+ ChartState between;
+ between.left.length = 1;
+ between.left.pointers[0] = eos;
+ between.left.full = true;
+ between.right.length = 0;
+
+ Right before;
+ before.words[0] = period;
+ before.words[1] = loin;
+ before.backoff[0] = -0.845098;
+ before.backoff[1] = 0.0;
+
+ before.length = 1;
+ BOOST_CHECK_CLOSE(-0.0410707, RevealBefore(m, before, 0, true, between.left, between.right), 0.001);
+ BOOST_CHECK_EQUAL(0, between.left.length);
+}
+
+float ScoreFragment(const RestProbingModel &model, unsigned int *begin, unsigned int *end, ChartState &out) {
+ RuleScore<RestProbingModel> scorer(model, out);
+ for (unsigned int *i = begin; i < end; ++i) {
+ scorer.Terminal(*i);
+ }
+ return scorer.Finish();
+}
+
+void CheckAdjustment(const RestProbingModel &model, float expect, const Right &before_in, bool before_full, ChartState between, const Left &after_in) {
+ Right before(before_in);
+ Left after(after_in);
+ after.full = false;
+ float got = 0.0;
+ for (unsigned int i = 1; i < 5; ++i) {
+ if (before_in.length >= i) {
+ before.length = i;
+ got += RevealBefore(model, before, i - 1, false, between.left, between.right);
+ }
+ if (after_in.length >= i) {
+ after.length = i;
+ got += RevealAfter(model, between.left, between.right, after, i - 1);
+ }
+ }
+ if (after_in.full) {
+ after.full = true;
+ got += RevealAfter(model, between.left, between.right, after, after.length);
+ }
+ if (before_full) {
+ got += RevealBefore(model, before, before.length, true, between.left, between.right);
+ }
+ // Sometimes they're zero and BOOST_CHECK_CLOSE fails for this.
+ BOOST_CHECK(fabs(expect - got) < 0.001);
+}
+
+void FullDivide(const RestProbingModel &model, StringPiece str) {
+ std::vector<WordIndex> indices;
+ for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) {
+ indices.push_back(model.GetVocabulary().Index(*i));
+ }
+ ChartState full_state;
+ float full = ScoreFragment(model, &indices.front(), &indices.back() + 1, full_state);
+
+ ChartState before_state;
+ before_state.left.full = false;
+ RuleScore<RestProbingModel> before_scorer(model, before_state);
+ float before_score = 0.0;
+ for (unsigned int before = 0; before < indices.size(); ++before) {
+ for (unsigned int after = before; after <= indices.size(); ++after) {
+ ChartState after_state, between_state;
+ float after_score = ScoreFragment(model, &indices.front() + after, &indices.front() + indices.size(), after_state);
+ float between_score = ScoreFragment(model, &indices.front() + before, &indices.front() + after, between_state);
+ CheckAdjustment(model, full - before_score - after_score - between_score, before_state.right, before_state.left.full, between_state, after_state.left);
+ }
+ before_scorer.Terminal(indices[before]);
+ before_score = before_scorer.Finish();
+ }
+}
+
+BOOST_AUTO_TEST_CASE(Strings) {
+ FullDivide(m, "also would consider");
+ FullDivide(m, "looking on a little more loin . </s>");
+ FullDivide(m, "in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+} // namespace
+} // namespace ngram
+} // namespace lm
View
2  lm/state.hh
@@ -51,6 +51,8 @@ inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed);
}
+typedef State Right;
+
struct Left {
bool operator==(const Left &other) const {
return
View
4 search/edge.hh
@@ -1,6 +1,7 @@
#ifndef SEARCH_EDGE__
#define SEARCH_EDGE__
+#include "lm/state.hh"
#include "search/arity.hh"
#include "search/rule.hh"
#include "search/types.hh"
@@ -39,6 +40,9 @@ class Edge {
struct PartialEdge {
Score score;
+ // Terminals
+ lm::ngram::ChartState between[kMaxArity + 1];
+ // Non-terminals
PartialVertex nt[kMaxArity];
bool operator<(const PartialEdge &other) const {
View
131 search/edge_generator.cc
@@ -1,10 +1,13 @@
#include "search/edge_generator.hh"
#include "lm/left.hh"
+#include "lm/partial.hh"
#include "search/context.hh"
#include "search/vertex.hh"
#include "search/vertex_generator.hh"
+#include <numeric>
+
namespace search {
bool EdgeGenerator::Init(Edge &edge) {
@@ -19,6 +22,9 @@ bool EdgeGenerator::Init(Edge &edge) {
for (unsigned int i = GetRule().Arity(); i < 2; ++i) {
root.nt[i] = kBlankPartialVertex;
}
+ for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) {
+ root.between[i] = GetRule().Lexical(i);
+ }
// wtf no clear method?
generate_ = Generate();
generate_.push(root);
@@ -26,11 +32,50 @@ bool EdgeGenerator::Init(Edge &edge) {
return true;
}
-unsigned int EdgeGenerator::PickVictim(const PartialEdge &in) const {
- // TODO: better decision rule.
- return in.nt[0].Length() >= in.nt[1].Length();
+namespace {
+
+template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) {
+ memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1));
+
+ float ret = 0.0;
+ lm::ngram::ChartState *before, *after;
+ if (victim == 0) {
+ before = &update.between[0];
+ after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1];
+ } else {
+ assert(victim == 1);
+ assert(arity == 2);
+ before = &update.between[previous.nt[0].Complete() ? 0 : 1];
+ after = &update.between[2];
+ }
+ const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State();
+ const PartialVertex &update_nt = update.nt[victim];
+ const lm::ngram::ChartState &update_reveal = update_nt.State();
+ float just_after = 0.0;
+ if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) {
+ just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
+ }
+ if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) {
+ ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
+ }
+ if (update_nt.Complete()) {
+ if (update_reveal.left.full) {
+ before->left.full = true;
+ } else {
+ assert(update_reveal.left.length == update_reveal.right.length);
+ ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
+ }
+ if (victim == 0) {
+ update.between[0].right = after->right;
+ } else {
+ update.between[2].left = before->left;
+ }
+ }
+ return previous.score + (ret + just_after) * context.GetWeights().LM();
}
+} // namespace
+
template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGenerator &parent) {
assert(!generate_.empty());
const PartialEdge &top = generate_.top();
@@ -45,7 +90,8 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe
if (lowest_length == 255) {
// All states report complete.
lm::ngram::ChartState state;
- RecomputeFinal(context, top, state);
+ state.left = top.between[0].left;
+ state.right = top.between[GetRule().Arity()].right;
parent.NewHypothesis(state, *from_, top);
generate_.pop();
top_ = generate_.empty() ? -kScoreInf : generate_.top().score;
@@ -54,83 +100,36 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe
unsigned int stay = !victim;
PartialEdge continuation, alternate;
- continuation.nt[stay] = top.nt[stay];
- alternate.nt[stay] = top.nt[stay];
// The alternate's score will change because alternate.nt[victim] changes.
- alternate.score = top.score - top.nt[victim].Bound();
bool split = top.nt[victim].Split(continuation.nt[victim], alternate.nt[victim]);
- generate_.pop();
- // top is now a dangling reference.
+
+ continuation.nt[stay] = top.nt[stay];
+ continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation);
if (split) {
// We have an alternate.
- alternate.score += alternate.nt[victim].Bound();
+ alternate.score = top.score - top.nt[victim].Bound() + alternate.nt[victim].Bound();
+ memcpy(alternate.between, top.between, sizeof(lm::ngram::ChartState) * (GetRule().Arity() + 1));
+ alternate.nt[stay] = top.nt[stay];
+
+ generate_.pop();
+ // top is now a dangling reference.
+
// TODO: dedupe?
generate_.push(alternate);
+ } else {
+ generate_.pop();
+ // top is now a dangling reference.
}
- continuation.score = GetRule().Bound() + Adjustment(context, continuation) * context.GetWeights().LM();
- for (unsigned int i = 0; i < GetRule().Arity(); ++i) {
- continuation.score += continuation.nt[i].Bound();
- }
+
// TODO: dedupe?
generate_.push(continuation);
+
top_ = generate_.top().score;
return true;
}
-template <class Model> void EdgeGenerator::RecomputeFinal(Context<Model> &context, const PartialEdge &to, lm::ngram::ChartState &state) {
- if (GetRule().Arity() == 0) {
- state = GetRule().Lexical(0);
- return;
- }
- lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), state);
- if (GetRule().BeginSentence()) {
- scorer.BeginSentence();
- scorer.NonTerminal(GetRule().Lexical(0));
- } else {
- scorer.BeginNonTerminal(GetRule().Lexical(0));
- }
- scorer.NonTerminal(to.nt[0].State());
- scorer.NonTerminal(GetRule().Lexical(1));
- if (GetRule().Arity() == 2) {
- scorer.NonTerminal(to.nt[1].State());
- scorer.NonTerminal(GetRule().Lexical(2));
- }
- scorer.Finish();
- return;
-}
-
-// TODO: this can be done WAY more efficiently.
-template <class Model> Score EdgeGenerator::Adjustment(Context<Model> &context, const PartialEdge &to) const {
- if (GetRule().Arity() == 0) return 0.0;
- lm::ngram::ChartState state;
- lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), state);
- scorer.BeginNonTerminal(GetRule().Lexical(0));
- scorer.NonTerminal(to.nt[0].State());
- float total = 0.0;
- if (!to.nt[0].Complete()) {
- total += scorer.Finish();
- scorer.Reset();
- scorer.BeginNonTerminal(to.nt[0].State());
- }
- scorer.NonTerminal(GetRule().Lexical(1));
- if (GetRule().Arity() == 1) return total + scorer.Finish();
- assert(GetRule().Arity() == 2);
- scorer.NonTerminal(to.nt[1].State());
- if (!to.nt[1].Complete()) {
- total += scorer.Finish();
- scorer.Reset();
- scorer.BeginNonTerminal(to.nt[1].State());
- }
- scorer.NonTerminal(GetRule().Lexical(2));
- return total + scorer.Finish();
-}
-
template bool EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, VertexGenerator &parent);
template bool EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, VertexGenerator &parent);
-template void EdgeGenerator::RecomputeFinal(Context<lm::ngram::RestProbingModel> &context, const PartialEdge &to, lm::ngram::ChartState &state);
-template void EdgeGenerator::RecomputeFinal(Context<lm::ngram::ProbingModel> &context, const PartialEdge &to, lm::ngram::ChartState &state);
-template Score EdgeGenerator::Adjustment(Context<lm::ngram::RestProbingModel> &context, const PartialEdge &to) const;
-template Score EdgeGenerator::Adjustment(Context<lm::ngram::ProbingModel> &context, const PartialEdge &to) const;
} // namespace search
View
6 search/edge_generator.hh
@@ -30,12 +30,6 @@ class EdgeGenerator {
template <class Model> bool Pop(Context<Model> &context, VertexGenerator &parent);
private:
- unsigned int PickVictim(const PartialEdge &in) const;
-
- template <class Model> void RecomputeFinal(Context<Model> &context, const PartialEdge &to, lm::ngram::ChartState &state);
-
- template <class Model> Score Adjustment(Context<Model> &context, const PartialEdge &to) const;
-
const Rule &GetRule() const {
return from_->GetRule();
}
View
11 search/vertex.hh
@@ -27,13 +27,13 @@ class VertexNode {
state_.left.full = false;
state_.left.length = 0;
state_.right.length = 0;
+ right_full_ = false;
bound_ = -kScoreInf;
end_ = NULL;
}
- lm::ngram::ChartState &MutableState() {
- return state_;
- }
+ lm::ngram::ChartState &MutableState() { return state_; }
+ bool &MutableRightFull() { return right_full_; }
void AddExtend(VertexNode *next) {
extend_.push_back(next);
@@ -55,6 +55,7 @@ class VertexNode {
}
const lm::ngram::ChartState &State() const { return state_; }
+ bool RightFull() const { return right_full_; }
Score Bound() const {
return bound_;
@@ -77,7 +78,10 @@ class VertexNode {
private:
std::vector<VertexNode*> extend_;
+
lm::ngram::ChartState state_;
+ bool right_full_;
+
Score bound_;
Final *end_;
};
@@ -93,6 +97,7 @@ class PartialVertex {
bool Complete() const { return back_->Complete(); }
const lm::ngram::ChartState &State() const { return back_->State(); }
+ bool RightFull() const { return back_->RightFull(); }
Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); }
View
21 search/vertex_generator.cc
@@ -47,37 +47,38 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed
Trie *node = &root_;
while (true) {
if (left == state.left.length) {
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, right);
+ node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false);
for (; right < state.right.length; ++right) {
- node = &FindOrInsert(*node, state.right.words[right], state, left, right + 1);
+ node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false);
}
- node = &FindOrInsert(*node, kCompleteAdd, state, state.left.length, state.right.length);
break;
}
- node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, right);
+ node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false);
left++;
if (right == state.right.length) {
- node = &FindOrInsert(*node, kCompleteAdd, state, left, right);
+ node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true);
for (; left < state.left.length; ++left) {
- node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, right);
+ node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true);
}
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, state.right.length);
break;
}
- node = &FindOrInsert(*node, state.right.words[right], state, left, right + 1);
+ node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false);
right++;
}
+
+ node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
got.first->second = CompleteTransition(*node, state, from, partial);
--to_pop_;
}
-VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, unsigned char right) {
+VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
VertexGenerator::Trie &next = node.extend[added];
if (!next.under) {
next.under = context_.NewVertexNode();
lm::ngram::ChartState &writing = next.under->MutableState();
writing = state;
- writing.left.full &= (left == state.left.length);
+ writing.left.full &= left_full && state.left.full;
+ next.under->MutableRightFull() = right_full && state.left.full;
writing.left.length = left;
writing.right.length = right;
node.under->AddExtend(next.under);
View
2  search/vertex_generator.hh
@@ -34,7 +34,7 @@ class VertexGenerator {
boost::unordered_map<uint64_t, Trie> extend;
};
- Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, unsigned char right);
+ Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full);
Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial);
View
2  util/file.cc
@@ -14,6 +14,8 @@
#if defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#include <io.h>
+#else
+#include <unistd.h>
#endif
namespace util {
View
13 util/file_piece.cc
@@ -18,22 +18,18 @@
#include <sys/types.h>
#include <sys/stat.h>
-#ifdef HAVE_ZLIB
-#include <zlib.h>
-#endif
-
namespace util {
ParseNumberException::ParseNumberException(StringPiece value) throw() {
*this << "Could not parse \"" << value << "\" into a number";
}
-GZException::GZException(void *file) {
#ifdef HAVE_ZLIB
+GZException::GZException(gzFile file) {
int num;
- *this << gzerror(file, &num) << " from zlib";
-#endif // HAVE_ZLIB
+ *this << gzerror( file, &num) << " from zlib";
}
+#endif // HAVE_ZLIB
// Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale).
const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
@@ -153,9 +149,8 @@ template <class T> T FilePiece::ReadNumber() {
SkipSpaces();
while (last_space_ < position_) {
if (at_end_) {
- if (position_ >= position_end_) throw EndOfFileException();
// Hallucinate a null off the end of the file.
- std::string buffer(position_, position_end_ - position_);
+ std::string buffer(position_, position_end_);
char *end;
T ret;
ParseNumber(buffer.c_str(), end, ret);
View
10 util/file_piece.hh
@@ -13,6 +13,10 @@
#include <stdint.h>
+#ifdef HAVE_ZLIB
+#include <zlib.h>
+#endif
+
namespace util {
class ParseNumberException : public Exception {
@@ -23,7 +27,9 @@ class ParseNumberException : public Exception {
class GZException : public Exception {
public:
- explicit GZException(void *file);
+#ifdef HAVE_ZLIB
+ explicit GZException(gzFile file);
+#endif
GZException() throw() {}
~GZException() throw() {}
};
@@ -117,7 +123,7 @@ class FilePiece {
std::string file_name_;
#ifdef HAVE_ZLIB
- void *gz_file_;
+ gzFile gz_file_;
#endif // HAVE_ZLIB
};
View
1  util/mmap.cc
@@ -19,6 +19,7 @@
#include <windows.h>
#include <io.h>
#else
+#include <unistd.h>
#include <sys/mman.h>
#endif
Please sign in to comment.
Something went wrong with that request. Please try again.