Skip to content

Commit

Permalink
Merge pull request #161 from sashafrey/master
Browse files Browse the repository at this point in the history
Vowpal Wabbit parser
  • Loading branch information
bigartm committed Mar 15, 2015
2 parents c7e9e04 + cd05d7d commit 002727c
Show file tree
Hide file tree
Showing 10 changed files with 330 additions and 67 deletions.
170 changes: 170 additions & 0 deletions src/artm/core/collection_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,174 @@ CollectionParser::TokenMap CollectionParser::ParseVocabMatrixMarket() {
return token_info; // empty if no input file had been provided
}

// Accumulates (token, count) records into items and batches, and keeps
// collection-wide token statistics from which a dictionary can be exported.
//
// Usage, as in ParseVowpalWabbit(): call Record() for every token entry of an item,
// FinishItem() once per item, and FinishBatch() whenever the pending batch should be
// handed off. Collection-wide statistics survive FinishBatch().
//
// ToDo: Collect token cooccurrence in BatchCollector, and export it in ParseVowpalWabbit().
class CollectionParser::BatchCollector {
 public:
  BatchCollector() : item_(nullptr), total_token_count_(0), total_items_count_(0) {}

  // Appends one (token, token_count) entry to the currently open item; a new item is
  // opened first if none is open. Also updates the collection-wide statistics.
  void Record(const Token& token, int token_count) {
    // One find() per map; insert() only on a miss, and its iterator is reused.
    auto global_it = global_map_.find(token);
    if (global_it == global_map_.end()) {
      global_it = global_map_.insert(std::make_pair(token, GlobalTokenStats(token))).first;
    }

    auto local_it = local_map_.find(token);
    if (local_it == local_map_.end()) {
      // The index of a token inside the batch is its position in batch_.token(),
      // so the size must be captured before the token is appended.
      local_it = local_map_.insert(std::make_pair(token, batch_.token_size())).first;
      batch_.add_token(token.keyword);
      batch_.add_class_id(token.class_id);
    }
    const int local_token_id = local_it->second;

    if (item_ == nullptr) StartNewItem();

    Field* field = item_->mutable_field(0);
    field->add_token_id(local_token_id);
    field->add_token_count(token_count);

    GlobalTokenStats& stats = global_it->second;
    // total_items_count_ is the 1-based ordinal of the open item (StartNewItem() bumps it),
    // so comparing against it counts each item at most once per token even when the
    // same token is recorded several times within that item.
    if (stats.last_counted_item != total_items_count_) {
      stats.last_counted_item = total_items_count_;
      stats.info.items_count++;
    }
    stats.info.token_count += token_count;
    total_token_count_ += token_count;
  }

  // Assigns id and title to the open item and closes it. If nothing was recorded
  // since the previous FinishItem(), an empty item is created and closed.
  void FinishItem(int item_id, const std::string& item_title) {
    if (item_ == nullptr) StartNewItem();  // this item will be empty;

    item_->set_id(item_id);
    item_->set_title(item_title);

    LOG_IF(INFO, total_items_count_ % 100000 == 0) << total_items_count_ << " documents parsed.";

    // Item is already included in the batch;
    // Set item_ to nullptr to finish it; then next Record() will create a new item;
    item_ = nullptr;
  }

  // Returns the pending batch and leaves the collector with an empty one.
  // Batch-local token indices are reset; collection-wide statistics are kept.
  Batch FinishBatch() {
    Batch batch;
    batch.Swap(&batch_);
    local_map_.clear();
    return batch;
  }

  // Read-only view of the pending (not yet finished) batch.
  const Batch& batch() const { return batch_; }

  // Builds a dictionary with one entry per distinct token recorded so far.
  // Each entry's value is the token's share of the total token count.
  std::shared_ptr<DictionaryConfig> ExportDictionaryConfig() const {
    // Craft the dictionary
    auto retval = std::make_shared<DictionaryConfig>();
    retval->set_total_items_count(total_items_count_);
    retval->set_total_token_count(total_token_count_);

    const double total = static_cast<double>(total_token_count_);
    for (const auto& key_value : global_map_) {
      const CollectionParserTokenInfo& info = key_value.second.info;
      artm::DictionaryEntry* entry = retval->add_entry();
      entry->set_key_token(info.keyword);
      entry->set_class_id(info.class_id);
      entry->set_token_count(info.token_count);
      entry->set_items_count(info.items_count);
      // Guard against 0/0 (NaN) when every recorded count was zero.
      entry->set_value(total_token_count_ != 0
                           ? static_cast<double>(info.token_count) / total
                           : 0.0);
    }

    return retval;
  }

 private:
  // Collection-wide statistics of one token, plus the bookkeeping needed to count
  // every item at most once in info.items_count.
  struct GlobalTokenStats {
    explicit GlobalTokenStats(const Token& token)
        : info(token.keyword, token.class_id), last_counted_item(0) {}

    CollectionParserTokenInfo info;
    int64_t last_counted_item;  // 1-based ordinal of the last item counted; 0 = none yet
  };

  // Opens a new item with a single field inside the pending batch.
  void StartNewItem() {
    item_ = batch_.add_item();
    item_->add_field();
    total_items_count_++;
  }

  Item* item_;  // currently open item inside batch_, or nullptr if none is open
  Batch batch_;
  std::map<Token, int> local_map_;  // token -> index of the token within batch_
  std::map<Token, GlobalTokenStats> global_map_;
  int64_t total_token_count_;
  int64_t total_items_count_;
};

std::shared_ptr<DictionaryConfig> CollectionParser::ParseVowpalWabbit() {
BatchCollector batch_collector;

if (!boost::filesystem::exists(config_.docword_file_path()))
BOOST_THROW_EXCEPTION(DiskReadException(
"File " + config_.docword_file_path() + " does not exist."));

boost::iostreams::stream<mapped_file_source> docword(config_.docword_file_path());
std::string str;
int line_no = 0;
while (!docword.eof()) {
std::getline(docword, str);
line_no++;
if (docword.eof())
break;

std::vector<std::string> strs;
boost::split(strs, str, boost::is_any_of(" \t\r"));

if (strs.size() <= 1) {
std::stringstream ss;
ss << "Error in " << config_.docword_file_path() << ":" << line_no << ", too few entries: " << str;
BOOST_THROW_EXCEPTION(InvalidOperation(ss.str()));
}

std::string item_title = strs[0];

ClassId class_id = DefaultClass;
for (int elem_index = 1; elem_index < strs.size(); ++elem_index) {
std::string elem = strs[elem_index];
if (elem.size() == 0)
continue;
if (elem[0] == '|') {
class_id = elem.substr(1);
continue;
}

int token_count = 1;
std::string token = elem;
size_t split_index = elem.find(':');
if (split_index != std::string::npos) {
if (split_index == 0 || split_index == (elem.size() - 1)) {
std::stringstream ss;
ss << "Error in " << config_.docword_file_path() << ":" << line_no
<< ", entries can not start or end with colon: " << elem;
BOOST_THROW_EXCEPTION(InvalidOperation(ss.str()));
}
token = elem.substr(0, split_index);
std::string token_occurences_string = elem.substr(split_index + 1);
try {
token_count = boost::lexical_cast<int>(token_occurences_string);
}
catch (boost::bad_lexical_cast &) {
std::stringstream ss;
ss << "Error in " << config_.docword_file_path() << ":" << line_no
<< ", can not parse integer number of occurences: " << elem;
BOOST_THROW_EXCEPTION(InvalidOperation(ss.str()));
}
}

batch_collector.Record(artm::core::Token(class_id, token), token_count);
}

batch_collector.FinishItem(line_no, item_title);
if (batch_collector.batch().item_size() >= config_.num_items_per_batch()) {
::artm::core::BatchHelpers::SaveBatch(batch_collector.FinishBatch(), config_.target_folder());
}
}

if (batch_collector.batch().item_size() > 0) {
::artm::core::BatchHelpers::SaveBatch(batch_collector.FinishBatch(), config_.target_folder());
}

std::shared_ptr<DictionaryConfig> retval = batch_collector.ExportDictionaryConfig();

if (config_.has_dictionary_file_name()) {
::artm::core::BatchHelpers::SaveMessage(config_.dictionary_file_name(),
config_.target_folder(), *retval);
}

return retval;
}

std::shared_ptr<DictionaryConfig> CollectionParser::Parse() {
TokenMap token_map;
switch (config_.format()) {
Expand All @@ -347,6 +515,8 @@ std::shared_ptr<DictionaryConfig> CollectionParser::Parse() {
token_map = ParseVocabMatrixMarket();
return ParseDocwordBagOfWordsUci(&token_map);

case CollectionParserConfig_Format_VowpalWabbit:
return ParseVowpalWabbit();
default:
BOOST_THROW_EXCEPTION(ArgumentOutOfRangeException(
"CollectionParserConfig.format", config_.format()));
Expand Down
2 changes: 2 additions & 0 deletions src/artm/core/collection_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ class CollectionParser : boost::noncopyable {
std::map<std::pair<int, int>, int> token_coocurrence_;
std::vector<int> item_tokens_;
};
class BatchCollector;

// ParseDocwordBagOfWordsUci is also used to parse MatrixMarket format, because
// the format of docword file is the same for both.
std::shared_ptr<DictionaryConfig> ParseDocwordBagOfWordsUci(TokenMap* token_map);
std::shared_ptr<DictionaryConfig> ParseVowpalWabbit();

TokenMap ParseVocabBagOfWordsUci();
TokenMap ParseVocabMatrixMarket();
Expand Down
49 changes: 26 additions & 23 deletions src/artm/messages.pb.cc

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions src/artm/messages.pb.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/artm/messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ message CollectionParserConfig {
enum Format {
BagOfWordsUci = 0;
MatrixMarket = 1;
VowpalWabbit = 2;
}

optional Format format = 1 [default = BagOfWordsUci];
Expand Down
31 changes: 31 additions & 0 deletions src/artm_tests/collection_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,34 @@ TEST(CollectionParser, Multiclass) {
try { boost::filesystem::remove_all(target_folder); }
catch (...) {}
}

// Parses test_data/vw_data.txt in Vowpal Wabbit format (one item per batch) and checks
// the token, class and total count of every entry of the returned dictionary.
// To run this particular test:
// artm_tests.exe --gtest_filter=CollectionParser.VowpalWabbit
TEST(CollectionParser, VowpalWabbit) {
  // Expected dictionary content, listed in the order in which entries are returned.
  struct ExpectedEntry {
    const char* key_token;
    const char* class_id;
    int token_count;
  };
  const ExpectedEntry expected_entries[] = {
    { "alex",   "author",         3 },
    { "hello",  "@default_class", 6 },
    { "noname", "author",         1 },
    { "world",  "@default_class", 2 },
  };
  const int expected_size = 4;

  const std::string target_folder = artm::test::Helpers::getUniqueString();

  ::artm::CollectionParserConfig config;
  config.set_format(::artm::CollectionParserConfig_Format_VowpalWabbit);
  config.set_target_folder(target_folder);
  config.set_dictionary_file_name("test_parser.dictionary");
  config.set_docword_file_path("../../../test_data/vw_data.txt");
  config.set_num_items_per_batch(1);

  std::shared_ptr< ::artm::DictionaryConfig> dictionary_parsed = ::artm::ParseCollection(config);

  // Stop here on a size mismatch so that the loop below never indexes a missing entry.
  ASSERT_EQ(dictionary_parsed->entry_size(), expected_size);
  for (int index = 0; index < expected_size; ++index) {
    const ExpectedEntry& want = expected_entries[index];
    EXPECT_EQ(dictionary_parsed->entry(index).key_token(), want.key_token);
    EXPECT_EQ(dictionary_parsed->entry(index).class_id(), want.class_id);
    EXPECT_EQ(dictionary_parsed->entry(index).token_count(), want.token_count);
  }

  // Best-effort cleanup of the generated batches; failures are deliberately ignored.
  try { boost::filesystem::remove_all(target_folder); }
  catch (...) {}
}

0 comments on commit 002727c

Please sign in to comment.