Skip to content

Commit

Permalink
Merge pull request #90 from sashafrey/master
Browse files Browse the repository at this point in the history
Issue #51 and #46
  • Loading branch information
bigartm committed Jan 2, 2015
2 parents fba66f6 + a7da200 commit d51ffff
Show file tree
Hide file tree
Showing 15 changed files with 440 additions and 471 deletions.
9 changes: 3 additions & 6 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
language: cpp
language:
- objective-c

compiler:
- gcc
Expand All @@ -11,10 +12,6 @@ before_install:
- cd ..
- cp datasets/* src/python/examples/

install:
- sudo apt-get update
- sudo apt-get install build-essential g++ python-dev autotools-dev libicu-dev build-essential libbz2-dev libboost-all-dev

before_script:
- mkdir build && cd build && cmake .. && cd ..

Expand All @@ -24,6 +21,6 @@ script:
- cd 3rdparty/protobuf/python && python setup.py build && sudo python setup.py install && cd ../../..
- cd build/src/artm_tests && ./artm_tests && cd ../../..
- export PYTHONPATH=`pwd`/src/python:$PYTHONPATH
- export ARTM_SHARED_LIBRARY=`pwd`/build/src/artm/libartm.so
- export ARTM_SHARED_LIBRARY=`pwd`/build/src/artm/libartm.dylib
- cd src/python/tests && python tests.py && cd ../../..
- cd src/python/examples && for f in *.py ; do echo "==== $f ====" && python "$f" ; done && cd ../../..
33 changes: 31 additions & 2 deletions docs/ref/messages.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,19 @@ items in one batch are always processed sequentially.
repeated string class_id = 3;
}

.. attribute:: Batch.token

A set of values that defines all tokens that may appear in the batch.

.. attribute:: Batch.item

A set of items of the batch.

.. attribute:: Batch.class_id

A set of values that define the classes (modalities) of tokens.
This repeated field must have the same length as :attr:`token`.
This value is optional; use an empty list to indicate that all tokens belong to the default class.

.. _Stream:

Expand Down Expand Up @@ -474,6 +487,7 @@ Represents a configuration of a topic model.

A set of values that define the weights of the corresponding classes (modalities).
This repeated field must have the same length as :attr:`class_id`.
This value is optional, use an empty list to set equal weights for all classes.

.. attribute:: ModelConfig.use_sparse_bow

Expand All @@ -483,6 +497,8 @@ Represents a configuration of a topic model.
Dense representation (*use_sparse_bow = false*) better fits for non-textual collections
(for example for matrix factorization).

Note that :attr:`class_weight` and :attr:`class_id` must not be used together with *use_sparse_bow=false*.

.. attribute:: ModelConfig.use_random_theta

A flag indicating whether to initialize ``p(t|d)`` distribution with random uniform distribution.
Expand Down Expand Up @@ -1525,8 +1541,21 @@ Represents a configuration of a collection parser.

| The file must be sorted on docID.
| Values of wordID must be unity-based (not zero-based).
| The format of the vocab.*.txt file is line
| contains wordID=n.
| The format of the vocab.*.txt file is that line n contains the word with wordID=n.
| Note that words must not have spaces or tabs.
| In vocab.*.txt file it is also possible to specify
| :attr:`Batch.class_id` for tokens, as it is shown in this example:

.. code-block:: bash

token1 @default_class
token2 custom_class
token3 @default_class
token4

| Use space or tab to separate token from its class.
| Tokens that are not followed by a class label automatically
| get ``@default_class`` as a label (see ``token4`` in the example).

``MatrixMarket`` | See the description at http://math.nist.gov/MatrixMarket/formats.html
| In this mode parameter :attr:`docword_file_path` must refer to a file
Expand Down
32 changes: 23 additions & 9 deletions src/artm/core/collection_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ std::shared_ptr<DictionaryConfig> CollectionParser::ParseDocwordBagOfWordsUci(To

if (token_map->empty()) {
// Autogenerate some tokens
for (int i = 0; i < num_tokens; ++i) {
for (int i = 0; i < num_unique_tokens; ++i) {
std::string token_keyword = boost::lexical_cast<std::string>(i);
token_map->insert(std::make_pair(i, CollectionParserTokenInfo(token_keyword)));
token_map->insert(std::make_pair(i, CollectionParserTokenInfo(token_keyword, DefaultClass)));
}
}

Expand Down Expand Up @@ -208,6 +208,7 @@ std::shared_ptr<DictionaryConfig> CollectionParser::ParseDocwordBagOfWordsUci(To
if (iter == batch_dictionary.end()) {
batch_dictionary.insert(std::make_pair(token_id, batch_dictionary.size()));
batch.add_token((*token_map)[token_id].keyword);
batch.add_class_id((*token_map)[token_id].class_id);
iter = batch_dictionary.find(token_id);
}

Expand All @@ -234,7 +235,7 @@ std::shared_ptr<DictionaryConfig> CollectionParser::ParseDocwordBagOfWordsUci(To
for (auto& key_value : (*token_map)) {
artm::DictionaryEntry* entry = retval->add_entry();
entry->set_key_token(key_value.second.keyword);
entry->set_class_id(DefaultClass);
entry->set_class_id(key_value.second.class_id);
entry->set_token_count(key_value.second.token_count);
entry->set_items_count(key_value.second.items_count);
entry->set_value(static_cast<double>(key_value.second.token_count) /
Expand Down Expand Up @@ -268,7 +269,7 @@ CollectionParser::TokenMap CollectionParser::ParseVocabBagOfWordsUci() {

boost::iostreams::stream<mapped_file_source> vocab(config_.vocab_file_path());

std::map<std::string, int> token_to_token_id;
std::map<Token, int> token_to_token_id;

TokenMap token_info;
std::string str;
Expand All @@ -285,15 +286,28 @@ CollectionParser::TokenMap CollectionParser::ParseVocabBagOfWordsUci() {
BOOST_THROW_EXCEPTION(InvalidOperation(ss.str()));
}

if (token_to_token_id.find(str) != token_to_token_id.end()) {
std::vector<std::string> strs;
boost::split(strs, str, boost::is_any_of("\t "));
if ((strs.size() == 0) || (strs.size() > 2)) {
std::stringstream ss;
ss << "Token '" << str << "' found twice, lines " << (token_to_token_id.find(str)->second + 1)
ss << "Error at line " << (token_id + 1) << ", file " << config_.vocab_file_path()
<< ". Expected format: <token> [<class_id>]";
BOOST_THROW_EXCEPTION(InvalidOperation(ss.str()));
}

ClassId class_id = (strs.size() == 2) ? strs[1] : DefaultClass;
Token token(class_id, strs[0]);

if (token_to_token_id.find(token) != token_to_token_id.end()) {
std::stringstream ss;
ss << "Token (" << token.keyword << ", " << token.class_id << "' found twice, lines "
<< (token_to_token_id.find(token)->second + 1)
<< " and " << (token_id + 1) << ", file " << config_.vocab_file_path();
BOOST_THROW_EXCEPTION(InvalidOperation(ss.str()));
}

token_info.insert(std::make_pair(token_id, CollectionParserTokenInfo(str)));
token_to_token_id.insert(std::make_pair(str, token_id));
token_info.insert(std::make_pair(token_id, CollectionParserTokenInfo(token.keyword, token.class_id)));
token_to_token_id.insert(std::make_pair(token, token_id));
token_id++;
}

Expand All @@ -313,7 +327,7 @@ CollectionParser::TokenMap CollectionParser::ParseVocabMatrixMarket() {
int token_id, token_count;
for (std::string token; vocab >> token_id >> token >> token_count;) {
// token_count is ignored --- it will be re-calculated based on the docword file.
token_info.insert(std::make_pair(token_id, CollectionParserTokenInfo(token)));
token_info.insert(std::make_pair(token_id, CollectionParserTokenInfo(token, DefaultClass)));
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/artm/core/collection_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "boost/utility.hpp"

#include "artm/messages.pb.h"
#include "artm/core/common.h"

namespace artm {
namespace core {
Expand All @@ -31,10 +32,11 @@ class CollectionParser : boost::noncopyable {
private:
struct CollectionParserTokenInfo {
explicit CollectionParserTokenInfo()
: keyword(), token_count(), items_count() {}
explicit CollectionParserTokenInfo(std::string keyword_)
: keyword(keyword_), token_count(0), items_count(0) {}
: keyword(), class_id(DefaultClass), token_count(), items_count() {}
explicit CollectionParserTokenInfo(std::string keyword_, ClassId class_id_)
: keyword(keyword_), class_id(class_id_), token_count(0), items_count(0) {}
std::string keyword;
ClassId class_id;
int token_count;
int items_count;
};
Expand Down
6 changes: 6 additions & 0 deletions src/artm/core/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ void BatchHelpers::CompactBatch(const Batch& batch, Batch* compacted_batch) {
std::vector<int> orig_to_compacted_id_map(batch.token_size(), -1);
int compacted_dictionary_size = 0;

bool has_class_id = (batch.class_id_size() > 0);
for (int item_index = 0; item_index < batch.item_size(); ++item_index) {
auto item = batch.item(item_index);
auto compacted_item = compacted_batch->add_item();
Expand All @@ -147,10 +148,15 @@ void BatchHelpers::CompactBatch(const Batch& batch, Batch* compacted_batch) {
int token_id = field.token_id(token_index);
if (token_id < 0 || token_id >= batch.token_size())
BOOST_THROW_EXCEPTION(ArgumentOutOfRangeException("field.token_id", token_id));
if (has_class_id && (token_id >= batch.class_id_size()))
BOOST_THROW_EXCEPTION(ArgumentOutOfRangeException(
"field.token_id", token_id, "Too few entries in batch.class_id field"));

if (orig_to_compacted_id_map[token_id] == -1) {
orig_to_compacted_id_map[token_id] = compacted_dictionary_size++;
compacted_batch->add_token(batch.token(token_id));
if (has_class_id)
compacted_batch->add_class_id(batch.class_id(token_id));
}

compacted_field->set_token_id(token_index, orig_to_compacted_id_map[token_id]);
Expand Down
8 changes: 8 additions & 0 deletions src/artm/core/master_component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <vector>
#include <set>
#include <sstream>

#include "glog/logging.h"
#include "zmq.hpp"
Expand Down Expand Up @@ -51,6 +52,13 @@ bool MasterComponent::isInNetworkModusOperandi() const {
}

void MasterComponent::CreateOrReconfigureModel(const ModelConfig& config) {
if ((config.class_weight_size() != 0 || config.class_id_size() != 0) && !config.use_sparse_bow()) {
std::stringstream ss;
ss << "You have configured use_sparse_bow=false. "
<< "Fields ModelConfig.class_id and ModelConfig.class_weight not supported in this mode.";
BOOST_THROW_EXCEPTION(InvalidOperation(ss.str()));
}

instance_->CreateOrReconfigureModel(config);
network_client_interface_->CreateOrReconfigureModel(config);
}
Expand Down

0 comments on commit d51ffff

Please sign in to comment.