Skip to content

Commit

Permalink
Moving co-occurrence collector in core (#888)
Browse files Browse the repository at this point in the history
* moved co-occurrence collector in core and changed cooc output file format
  • Loading branch information
MichaelSolotky committed Mar 7, 2018
1 parent 519a047 commit 0262269
Show file tree
Hide file tree
Showing 14 changed files with 1,477 additions and 1,002 deletions.
9 changes: 3 additions & 6 deletions python/tests/artm/test_regularizer_biterms.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,9 @@ def test_func():
fout.write('{0}\n'.format(e))

with open(cooc_file_name, 'w') as fout:
fout.write('0 3 5.0\n')
fout.write('0 1 4.0\n')
fout.write('0 2 5.0\n')
fout.write('1 3 2.0\n')
fout.write('1 2 2.0\n')
fout.write('2 3 2.0\n')
fout.write('A D:5.0 B:4.0 C:5.0\n')
fout.write('B D:2.0 C:2.0\n')
fout.write('C D 2.0\n')

dictionary = artm.Dictionary()
dictionary.gather(data_path=batches_folder, vocab_file_path=vocab_file_name, cooc_file_path=cooc_file_name)
Expand Down
2 changes: 2 additions & 0 deletions src/artm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ set(SRC_LIST
core/check_messages.h
core/collection_parser.cc
core/collection_parser.h
core/cooccurrence_collector.cc
core/cooccurrence_collector.h
core/common.h
core/cuckoo_watch.cc
core/cuckoo_watch.h
Expand Down
139 changes: 128 additions & 11 deletions src/artm/core/collection_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <unordered_map>
#include <iostream> // NOLINT
#include <future> // NOLINT
#include <atomic>

#include "boost/algorithm/string.hpp"
#include "boost/algorithm/string/predicate.hpp"
Expand All @@ -29,6 +30,7 @@
#include "artm/core/exceptions.h"
#include "artm/core/helpers.h"
#include "artm/core/protobuf_helpers.h"
#include "artm/core/cooccurrence_collector.h"

using ::artm::utility::ifstream_or_cin;

Expand Down Expand Up @@ -351,7 +353,6 @@ CollectionParser::TokenMap CollectionParser::ParseVocabMatrixMarket() {
return token_info; // empty if no input file had been provided
}

// ToDo: Collect token cooccurrence in BatchCollector, and export it in ParseVowpalWabbit().
class CollectionParser::BatchCollector {
private:
Item *item_;
Expand Down Expand Up @@ -432,6 +433,7 @@ class CollectionParser::BatchCollector {
const Batch& batch() { return batch_; }
};

// ToDo (MichaelSolotky): split this func into several
CollectionParserInfo CollectionParser::ParseVowpalWabbit() {
BatchNameGenerator batch_name_generator(kBatchNameLength,
config_.name_type() == CollectionParserConfig_BatchNameType_Guid);
Expand All @@ -441,21 +443,32 @@ CollectionParserInfo CollectionParser::ParseVowpalWabbit() {

auto config = config_;

std::mutex lock;
std::mutex read_access;
std::mutex cooc_config_access;
std::mutex token_map_access;
std::mutex token_statistics_access;

int global_line_no = 0;

std::unordered_map<Token, bool, TokenHasher> token_map;
CollectionParserInfo parser_info;

::artm::core::CooccurrenceCollector cooc_collector(config);
int64_t total_num_of_pairs = 0;
std::atomic_bool gather_transaction_cooc(false);

// The function defined below works as follows:
// 1. Acquire lock for reading from docword file
// 2. Read num_items_per_batch lines from docword file, and store them in a local buffer (vector<string>)
// 3. Release the lock
// 4. Parse strings, form a batch, and save it to disk
// During parsing it also gathers co-occurrence counters for pairs of tokens (if the corresponding flag is set)
// Steps 1-4 are repeated in a while loop until there is no content left in docword file.
// Multiple copies of the function can work in parallel.
auto func = [&docword, &global_line_no, &progress, &batch_name_generator, &lock,
&parser_info, &token_map, config]() {
auto func = [&docword, &global_line_no, &progress, &batch_name_generator, &read_access, &cooc_config_access,
&token_map_access, &token_statistics_access, &parser_info, &token_map, &total_num_of_pairs,
&cooc_collector, &gather_transaction_cooc, config]() {
int64_t local_num_of_pairs = 0; // statistics for future ppmi calculation
while (true) {
// The following variable remembers at which line the batch has started.
// It helps to create informative error message (including line number)
Expand All @@ -467,11 +480,11 @@ CollectionParserInfo CollectionParser::ParseVowpalWabbit() {
BatchCollector batch_collector;
std::unordered_set<TransactionType, TransactionHasher> transaction_types;

{
std::lock_guard<std::mutex> guard(lock);
{ // Read portion of documents
std::lock_guard<std::mutex> guard(read_access);
first_line_no_for_batch = global_line_no;
if (docword.eof()) {
return;
break;
}

while ((int64_t) all_strs_for_batch.size() < config.num_items_per_batch()) {
Expand All @@ -491,6 +504,16 @@ CollectionParserInfo CollectionParser::ParseVowpalWabbit() {
}
}

// It will hold tf and df of pairs of tokens
// Every pair of valid tokens (both exist in vocab) is saved in this storage
// After walking through a portion of documents all the statistics are dumped to disk
// and then this storage is destroyed
CooccurrenceStatisticsHolder cooc_stat_holder;
// For every token from vocab keep the information about the last document this token occurred in
// std::unordered_map<int> num_of_last_document_token_occured; // ToDo: think how to add elements here
// ToDo (MichaelSolotky): consider the case if there is no vocab
std::vector<int> num_of_last_document_token_occured(cooc_collector.vocab_.token_map_.size(), -1);

for (int str_index = 0; str_index < (int64_t) all_strs_for_batch.size(); ++str_index) {
std::string str = all_strs_for_batch[str_index];
const int line_no = first_line_no_for_batch + str_index;
Expand Down Expand Up @@ -534,7 +557,7 @@ CollectionParserInfo CollectionParser::ParseVowpalWabbit() {
std::string temp = split_index != std::string::npos ? elem.substr(0, split_index) : elem;
boost::split(tokens, temp, boost::is_any_of(TransactionSeparator));

if (class_ids.size() != tokens.size()) {
if (class_ids.size() != tokens.size()) {
std::stringstream ss;
ss << "Error in " << config.docword_file_path() << ":" << line_no
<< ", transaction type size is " << class_ids.size() << " and transaction size is "
Expand Down Expand Up @@ -563,15 +586,92 @@ CollectionParserInfo CollectionParser::ParseVowpalWabbit() {

transaction_types.emplace(TransactionType(class_ids));
batch_collector.Record(class_ids, tokens, transaction_weight);
}

if (config.gather_cooc()) {
if (class_ids.size() > 1) {
gather_transaction_cooc = true;
return;
}
const ClassId first_token_class_id = class_ids[0];

int first_token_id = -1;
if (config.has_vocab_file_path()) {
first_token_id = cooc_collector.vocab_.FindTokenId(elem, first_token_class_id);
if (first_token_id == TOKEN_NOT_FOUND) {
continue;
}
} else { // ToDo (MichaelSolotky): continue the case if there is no vocab
BOOST_THROW_EXCEPTION(InvalidOperation("No vocab file specified. Can't gather co-occurrences"));
}

if (num_of_last_document_token_occured[first_token_id] != str_index) {
num_of_last_document_token_occured[first_token_id] = str_index;
std::unique_lock<std::mutex> lock(token_statistics_access);
++cooc_collector.num_of_documents_token_occurred_in_[first_token_id];
}
// Take window_width tokens (parameter) to the right of the current one
// If there are some words beginning with '|' in the text, the window should be extended
// and it's extended using not_a_word_counter
ClassId second_token_class_id = first_token_class_id;
unsigned not_a_word_counter = 0;
// Loop through tokens in the window
for (unsigned neigh_index = 1; neigh_index <= cooc_collector.config_.cooc_window_width() +
not_a_word_counter &&
elem_index + neigh_index < strs.size();
++neigh_index) {
if (strs[elem_index + neigh_index].empty()) {
continue;
}
if (strs[elem_index + neigh_index][0] == '|') {
second_token_class_id = strs[elem_index + neigh_index].substr(1);
++not_a_word_counter;
continue;
}
// Take into consideration only tokens from the same modality
if (second_token_class_id != first_token_class_id) {
continue;
}
int second_token_id = -1;
const std::string neigh = strs[elem_index + neigh_index];
if (config.has_vocab_file_path()) {
second_token_id = cooc_collector.vocab_.FindTokenId(neigh, second_token_class_id);
if (second_token_id == TOKEN_NOT_FOUND) {
continue;
}
} else { // ToDo (MichaelSolotky): continue the case if there is no vocab
BOOST_THROW_EXCEPTION(InvalidOperation("No vocab file specified. Can't gather co-occurrences"));
}

if (cooc_collector.config_.use_symetric_cooc()) {
if (first_token_id < second_token_id) {
cooc_stat_holder.SavePairOfTokens(first_token_id, second_token_id, str_index);
} else if (first_token_id > second_token_id) {
cooc_stat_holder.SavePairOfTokens(second_token_id, first_token_id, str_index);
} else {
cooc_stat_holder.SavePairOfTokens(first_token_id, first_token_id, str_index, 2);
}
} else {
cooc_stat_holder.SavePairOfTokens(first_token_id, second_token_id, str_index);
cooc_stat_holder.SavePairOfTokens(second_token_id, first_token_id, str_index);
}
local_num_of_pairs += 2;
} // End of token's neghbors parsing
} // End of token parsing
} // End of item parsing
batch_collector.FinishItem(line_no, item_title);
} // End of items of 1 batch parsing
if (config.gather_cooc() && !cooc_stat_holder.Empty()) {
// This function saves the gathered statistics on disk
// After saving on disk, the statistics from all the batches need to be merged
// This is implemented in ReadAndMergeCooccurrenceBatches(), so the next step is to call this method
// Sorting is needed before storing all pairs of tokens on disk (it's for future aggregation)
cooc_collector.UploadOnDisk(cooc_stat_holder);
}

if (all_strs_for_batch.size() > 0) {
artm::Batch batch;
{
std::lock_guard<std::mutex> guard(lock);
std::lock_guard<std::mutex> guard(token_map_access);
batch = batch_collector.FinishBatch(&parser_info);
for (int token_id = 0; token_id < batch.token_size(); ++token_id) {
token_map[artm::core::Token(batch.class_id(token_id), batch.token(token_id))] = true;
Expand All @@ -583,6 +683,10 @@ CollectionParserInfo CollectionParser::ParseVowpalWabbit() {
}
::artm::core::Helpers::SaveBatch(batch, config.target_folder(), batch_name);
}
} // End of collection parsing
{ // Save number of pairs (needed for ppmi)
std::unique_lock<std::mutex> lock(cooc_config_access);
total_num_of_pairs += local_num_of_pairs;
}
};

Expand All @@ -607,11 +711,24 @@ CollectionParserInfo CollectionParser::ParseVowpalWabbit() {
for (int i = 0; i < num_threads; i++) {
tasks.push_back(std::move(std::async(std::launch::async, func)));
}

for (int i = 0; i < num_threads; i++) {
tasks[i].get();
}

if (gather_transaction_cooc) {
BOOST_THROW_EXCEPTION(InvalidOperation("Parser can't gather co-occurrences on transaction data yet"));
}

cooc_collector.config_.set_total_num_of_pairs(total_num_of_pairs);
cooc_collector.config_.set_total_num_of_documents(parser_info.num_items());

// Launch merging of co-occurrence batches and ppmi calculation
if (config.gather_cooc() && cooc_collector.VocabSize() >= 2) {
if (cooc_collector.CooccurrenceBatchesQuantity() != 0) {
cooc_collector.ReadAndMergeCooccurrenceBatches();
}
}

parser_info.set_dictionary_size(token_map.size());
return parser_info;
}
Expand Down
1 change: 1 addition & 0 deletions src/artm/core/collection_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "artm/core/common.h"
#include "artm/core/token.h"
#include "artm/core/cooccurrence_collector.h"

namespace artm {
namespace core {
Expand Down
4 changes: 4 additions & 0 deletions src/artm/core/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ const int64_t kProtobufCodedStreamTotalBytesLimit = 2147483647ULL;

// Separator placed between the tokens of a single transaction.
static const std::string TransactionSeparator = "^";

// Labels for the two kinds of co-occurrence counters.
// NOTE(review): presumably "tf" is the token-pair frequency and "df" is the number of
// documents a pair occurs in - confirm against the co-occurrence collector.
const std::string TokenCoocFrequency = "tf";
const std::string DocumentCoocFrequency = "df";

const std::string kParentPhiMatrixBatch = "__parent_phi_matrix_batch__";


template <typename T>
std::string to_string(T value) {
return boost::lexical_cast<std::string>(value);
Expand Down

0 comments on commit 0262269

Please sign in to comment.