Skip to content

Commit

Permalink
Implement hypergraph output for chart moses
Browse files Browse the repository at this point in the history
  • Loading branch information
bhaddow committed Aug 7, 2014
1 parent fbe73dd commit b5a1f02
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 39 deletions.
2 changes: 1 addition & 1 deletion contrib/server/Jamfile
Expand Up @@ -35,7 +35,7 @@ if $(build-moses-server) = true
xmlrpc-linkflags = [ shell_or_die "$(xmlrpc-command) c++2 abyss-server --libs" ] ;
xmlrpc-cxxflags = [ shell_or_die "$(xmlrpc-command) c++2 abyss-server --cflags" ] ;

exe mosesserver : mosesserver.cpp ../../moses//moses ../../OnDiskPt//OnDiskPt ../../moses-cmd/IOWrapper.cpp : <linkflags>$(xmlrpc-linkflags) <cxxflags>$(xmlrpc-cxxflags) ;
exe mosesserver : mosesserver.cpp ../../moses//moses ../../OnDiskPt//OnDiskPt ../../moses-cmd/IOWrapper.cpp ../..//boost_filesystem : <linkflags>$(xmlrpc-linkflags) <cxxflags>$(xmlrpc-cxxflags) ;
} else {
alias mosesserver ;
}
2 changes: 1 addition & 1 deletion contrib/server/mosesserver.cpp
Expand Up @@ -283,7 +283,7 @@ class Translator : public xmlrpc_c::method
if (addGraphInfo) {
const size_t translationId = tinput.GetTranslationId();
std::ostringstream sgstream;
manager.GetSearchGraph(translationId,sgstream);
manager.OutputSearchGraphMoses(sgstream);
retData.insert(pair<string, xmlrpc_c::value>("sg", xmlrpc_c::value_string(sgstream.str())));
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion mira/Jamfile
Expand Up @@ -2,7 +2,7 @@ lib mira_lib :
[ glob *.cpp : *Test.cpp Main.cpp ]
../mert//mert_lib ../moses//moses ../OnDiskPt//OnDiskPt ..//boost_program_options ;

exe mira : Main.cpp mira_lib ../mert//mert_lib ../moses//moses ../OnDiskPt//OnDiskPt ..//boost_program_options ;
exe mira : Main.cpp mira_lib ../mert//mert_lib ../moses//moses ../OnDiskPt//OnDiskPt ..//boost_program_options ..//boost_filesystem ;

alias programs : mira ;

Expand Down
2 changes: 1 addition & 1 deletion moses-chart-cmd/Main.cpp
Expand Up @@ -177,7 +177,7 @@ class TranslationTask : public Task

if (staticData.GetOutputSearchGraph()) {
std::ostringstream out;
manager.GetSearchGraph(translationId, out);
manager.OutputSearchGraphMoses( out);
OutputCollector *oc = m_ioWrapper.GetSearchGraphOutputCollector();
UTIL_THROW_IF2(oc == NULL, "File for search graph output not specified");
oc->Write(translationId, out.str());
Expand Down
7 changes: 4 additions & 3 deletions moses/ChartCell.cpp
Expand Up @@ -22,6 +22,7 @@
#include <algorithm>
#include "ChartCell.h"
#include "ChartCellCollection.h"
#include "HypergraphOutput.h"
#include "RuleCubeQueue.h"
#include "RuleCube.h"
#include "WordsRange.h"
Expand Down Expand Up @@ -195,13 +196,13 @@ const HypoList *ChartCell::GetAllSortedHypotheses() const
return ret;
}

//! call GetSearchGraph() for each hypo collection
void ChartCell::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const std::map<unsigned, bool> &reachable) const
//! call WriteSearchGraph() for each hypo collection
void ChartCell::WriteSearchGraph(const ChartSearchGraphWriter& writer, const std::map<unsigned, bool> &reachable) const
{
MapType::const_iterator iterOutside;
for (iterOutside = m_hypoColl.begin(); iterOutside != m_hypoColl.end(); ++iterOutside) {
const ChartHypothesisCollection &coll = iterOutside->second;
coll.GetSearchGraph(translationId, outputSearchGraphStream, reachable);
coll.WriteSearchGraph(writer, reachable);
}
}

Expand Down
3 changes: 2 additions & 1 deletion moses/ChartCell.h
Expand Up @@ -40,6 +40,7 @@

namespace Moses
{
class ChartSearchGraphWriter;
class ChartTranslationOptionList;
class ChartCellCollection;
class ChartManager;
Expand Down Expand Up @@ -124,7 +125,7 @@ class ChartCell : public ChartCellBase
return m_coverage < compare.m_coverage;
}

void GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const std::map<unsigned,bool> &reachable) const;
void WriteSearchGraph(const ChartSearchGraphWriter& writer, const std::map<unsigned,bool> &reachable) const;

};

Expand Down
23 changes: 3 additions & 20 deletions moses/ChartHypothesisCollection.cpp
Expand Up @@ -24,6 +24,7 @@
#include "ChartHypothesisCollection.h"
#include "ChartHypothesis.h"
#include "ChartManager.h"
#include "HypergraphOutput.h"
#include "util/exception.hh"

using namespace std;
Expand Down Expand Up @@ -293,27 +294,9 @@ void ChartHypothesisCollection::CleanupArcList()
* \param outputSearchGraphStream stream to output the info to
* \param reachable @todo don't know
*/
void ChartHypothesisCollection::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const std::map<unsigned, bool> &reachable) const
void ChartHypothesisCollection::WriteSearchGraph(const ChartSearchGraphWriter& writer, const std::map<unsigned, bool> &reachable) const
{
HCType::const_iterator iter;
for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
ChartHypothesis &mainHypo = **iter;
if (StaticData::Instance().GetUnprunedSearchGraph() ||
reachable.find(mainHypo.GetId()) != reachable.end()) {
outputSearchGraphStream << translationId << " " << mainHypo << endl;
}

const ChartArcList *arcList = mainHypo.GetArcList();
if (arcList) {
ChartArcList::const_iterator iterArc;
for (iterArc = arcList->begin(); iterArc != arcList->end(); ++iterArc) {
const ChartHypothesis &arc = **iterArc;
if (reachable.find(arc.GetId()) != reachable.end()) {
outputSearchGraphStream << translationId << " " << arc << endl;
}
}
}
}
writer.WriteHypos(*this,reachable);
}

std::ostream& operator<<(std::ostream &out, const ChartHypothesisCollection &coll)
Expand Down
4 changes: 3 additions & 1 deletion moses/ChartHypothesisCollection.h
Expand Up @@ -28,6 +28,8 @@
namespace Moses
{

class ChartSearchGraphWriter;

//! functor to compare (chart) hypotheses by (descending) score
class ChartHypothesisScoreOrderer
{
Expand Down Expand Up @@ -117,7 +119,7 @@ class ChartHypothesisCollection
return m_bestScore;
}

void GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const std::map<unsigned,bool> &reachable) const;
void WriteSearchGraph(const ChartSearchGraphWriter& writer, const std::map<unsigned,bool> &reachable) const;

};

Expand Down
31 changes: 24 additions & 7 deletions moses/ChartManager.cpp
Expand Up @@ -25,6 +25,7 @@
#include "ChartHypothesis.h"
#include "ChartKBestExtractor.h"
#include "ChartTranslationOptions.h"
#include "HypergraphOutput.h"
#include "StaticData.h"
#include "DecodeStep.h"
#include "TreeInput.h"
Expand Down Expand Up @@ -223,8 +224,9 @@ void ChartManager::CalcNBest(
}
}

void ChartManager::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const
void ChartManager::WriteSearchGraph(const ChartSearchGraphWriter& writer) const
{

size_t size = m_source.GetSize();

// which hypotheses are reachable?
Expand All @@ -237,7 +239,11 @@ void ChartManager::GetSearchGraph(long translationId, std::ostream &outputSearch
// no hypothesis
return;
}
FindReachableHypotheses( hypo, reachable);
size_t winners = 0;
size_t losers = 0;

FindReachableHypotheses( hypo, reachable, &winners, &losers);
writer.WriteHeader(winners, losers);

for (size_t width = 1; width <= size; ++width) {
for (size_t startPos = 0; startPos <= size-width; ++startPos) {
Expand All @@ -246,12 +252,13 @@ void ChartManager::GetSearchGraph(long translationId, std::ostream &outputSearch
TRACE_ERR(" " << range << "=");

const ChartCell &cell = m_hypoStackColl.Get(range);
cell.GetSearchGraph(translationId, outputSearchGraphStream, reachable);
cell.WriteSearchGraph(writer, reachable);
}
}
}

void ChartManager::FindReachableHypotheses( const ChartHypothesis *hypo, std::map<unsigned,bool> &reachable ) const
void ChartManager::FindReachableHypotheses(
const ChartHypothesis *hypo, std::map<unsigned,bool> &reachable, size_t* winners, size_t* losers) const
{
// do not recurse, if already visited
if (reachable.find(hypo->GetId()) != reachable.end()) {
Expand All @@ -260,9 +267,14 @@ void ChartManager::FindReachableHypotheses( const ChartHypothesis *hypo, std::ma

// recurse
reachable[ hypo->GetId() ] = true;
if (hypo->GetWinningHypothesis() == hypo) {
(*winners)++;
} else {
(*losers)++;
}
const std::vector<const ChartHypothesis*> &previous = hypo->GetPrevHypos();
for(std::vector<const ChartHypothesis*>::const_iterator i = previous.begin(); i != previous.end(); ++i) {
FindReachableHypotheses( *i, reachable );
FindReachableHypotheses( *i, reachable, winners, losers );
}

// also loop over recombined hypotheses (arcs)
Expand All @@ -271,14 +283,19 @@ void ChartManager::FindReachableHypotheses( const ChartHypothesis *hypo, std::ma
ChartArcList::const_iterator iterArc;
for (iterArc = arcList->begin(); iterArc != arcList->end(); ++iterArc) {
const ChartHypothesis &arc = **iterArc;
FindReachableHypotheses( &arc, reachable );
FindReachableHypotheses( &arc, reachable, winners, losers );
}
}
}

void ChartManager::OutputSearchGraphAsHypergraph(std::ostream &outputSearchGraphStream) const {
//TODO
ChartSearchGraphWriterHypergraph writer(&outputSearchGraphStream);
WriteSearchGraph(writer);
}

void ChartManager::OutputSearchGraphMoses(std::ostream &outputSearchGraphStream) const {
ChartSearchGraphWriterMoses writer(&outputSearchGraphStream, m_lineNumber);
WriteSearchGraph(writer);
}

} // namespace Moses
12 changes: 10 additions & 2 deletions moses/ChartManager.h
Expand Up @@ -38,6 +38,7 @@ namespace Moses
{

class ChartHypothesis;
class ChartSearchGraphWriter;

/** Holds everything you need to decode 1 sentence with the hierachical/syntax decoder
*/
Expand All @@ -55,6 +56,11 @@ class ChartManager

ChartTranslationOptionList m_translationOptionList; /**< pre-computed list of translation options for the phrases in this sentence */

/* auxilliary functions for SearchGraphs */
void FindReachableHypotheses(
const ChartHypothesis *hypo, std::map<unsigned,bool> &reachable , size_t* winners, size_t* losers) const;
void WriteSearchGraph(const ChartSearchGraphWriter& writer) const;

public:
ChartManager(size_t lineNumber, InputType const& source);
~ChartManager();
Expand All @@ -63,11 +69,13 @@ class ChartManager
const ChartHypothesis *GetBestHypothesis() const;
void CalcNBest(size_t n, std::vector<boost::shared_ptr<ChartKBestExtractor::Derivation> > &nBestList, bool onlyDistinct=false) const;

void GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const;
void FindReachableHypotheses( const ChartHypothesis *hypo, std::map<unsigned,bool> &reachable ) const; /* auxilliary function for GetSearchGraph */
/** "Moses" (osg) type format */
void OutputSearchGraphMoses(std::ostream &outputSearchGraphStream) const;

/** Output in (modified) Kenneth hypergraph format */
void OutputSearchGraphAsHypergraph(std::ostream &outputSearchGraphStream) const;


//! the input sentence being decoded
const InputType& GetSource() const {
return m_source;
Expand Down
90 changes: 90 additions & 0 deletions moses/HypergraphOutput.cpp
Expand Up @@ -34,6 +34,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

#include <util/exception.hh>

#include "ChartHypothesisCollection.h"
#include "ChartManager.h"
#include "HypergraphOutput.h"
#include "Manager.h"
Expand Down Expand Up @@ -154,5 +155,94 @@ void HypergraphOutput<M>::Write(const M& manager) const {
template class HypergraphOutput<Manager>;
template class HypergraphOutput<ChartManager>;


void ChartSearchGraphWriterMoses::WriteHypos
(const ChartHypothesisCollection& hypos, const map<unsigned, bool> &reachable) const {

ChartHypothesisCollection::const_iterator iter;
for (iter = hypos.begin() ; iter != hypos.end() ; ++iter) {
ChartHypothesis &mainHypo = **iter;
if (StaticData::Instance().GetUnprunedSearchGraph() ||
reachable.find(mainHypo.GetId()) != reachable.end()) {
(*m_out) << m_lineNumber << " " << mainHypo << endl;
}

const ChartArcList *arcList = mainHypo.GetArcList();
if (arcList) {
ChartArcList::const_iterator iterArc;
for (iterArc = arcList->begin(); iterArc != arcList->end(); ++iterArc) {
const ChartHypothesis &arc = **iterArc;
if (reachable.find(arc.GetId()) != reachable.end()) {
(*m_out) << m_lineNumber << " " << arc << endl;
}
}
}
}

}
void ChartSearchGraphWriterHypergraph::WriteHeader(size_t winners, size_t losers) const {

(*m_out) << "# target ||| features ||| source-covered" << endl;
(*m_out) << winners << " " << (winners+losers) << endl;

}

void ChartSearchGraphWriterHypergraph::WriteHypos(const ChartHypothesisCollection& hypos,
const map<unsigned, bool> &reachable) const {

ChartHypothesisCollection::const_iterator iter;
for (iter = hypos.begin() ; iter != hypos.end() ; ++iter) {
const ChartHypothesis* mainHypo = *iter;
if (!StaticData::Instance().GetUnprunedSearchGraph() &&
reachable.find(mainHypo->GetId()) == reachable.end()) {
//Ignore non reachable nodes
continue;
}
(*m_out) << "# node " << m_nodeId << endl;
m_hypoIdToNodeId[mainHypo->GetId()] = m_nodeId;
++m_nodeId;
vector<const ChartHypothesis*> edges;
edges.push_back(mainHypo);
const ChartArcList *arcList = (*iter)->GetArcList();
if (arcList) {
ChartArcList::const_iterator iterArc;
for (iterArc = arcList->begin(); iterArc != arcList->end(); ++iterArc) {
const ChartHypothesis* arc = *iterArc;
if (reachable.find(arc->GetId()) != reachable.end()) {
edges.push_back(arc);
}
}
}
(*m_out) << edges.size() << endl;
for (vector<const ChartHypothesis*>::const_iterator ei = edges.begin(); ei != edges.end(); ++ei) {
const ChartHypothesis* hypo = *ei;
const TargetPhrase& target = hypo->GetCurrTargetPhrase();
size_t ntIndex = 0;
for (size_t i = 0; i < target.GetSize(); ++i) {
const Word& word = target.GetWord(i);
if (word.IsNonTerminal()) {
size_t hypoId = hypo->GetPrevHypos()[ntIndex++]->GetId();
(*m_out) << "[" << m_hypoIdToNodeId[hypoId] << "]";
} else {
(*m_out) << word.GetFactor(0)->GetString();
}
(*m_out) << " ";
}
(*m_out) << " ||| ";
ScoreComponentCollection scores = hypo->GetScoreBreakdown();
HypoList::const_iterator hi;
for (hi = hypo->GetPrevHypos().begin(); hi != hypo->GetPrevHypos().end(); ++hi) {
scores.MinusEquals((*hi)->GetScoreBreakdown());
}
scores.Save(*m_out, false);
(*m_out) << " ||| ";
(*m_out) << hypo->GetCurrSourceRange().GetNumWordsCovered();
(*m_out) << endl;

}
}
}


} //namespace Moses

0 comments on commit b5a1f02

Please sign in to comment.