Skip to content

Commit

Permalink
Merge pull request #128 from sashafrey/master
Browse files — browse the repository at this point in the history
Loop across all batch fields in perplexity.cc
  • Loading branch information
bigartm committed Feb 19, 2015
2 parents 928ff83 + 233c6ac commit 08ab5b2
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 59 deletions.
8 changes: 4 additions & 4 deletions src/artm/messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ message ModelConfig {
repeated string topic_name = 3;
optional bool enabled = 4 [default = true];
optional int32 inner_iterations_count = 5 [default = 10];
optional string field_name = 6 [default = "@body"];
optional string field_name = 6 [default = "@body"]; // obsolete in BigARTM v0.5.8
optional string stream_name = 7 [default = "@global"];
repeated string score_name = 8;
optional bool reuse_theta = 9 [default = false];
Expand Down Expand Up @@ -236,7 +236,7 @@ message PerplexityScoreConfig {
UnigramDocumentModel = 0;
UnigramCollectionModel = 1;
}
optional string field_name = 1 [default = "@body"];
optional string field_name = 1 [default = "@body"]; // obsolete in BigARTM v0.5.8
optional string stream_name = 2 [default = "@global"];
optional Type model_type = 3 [default = UnigramDocumentModel];
optional string dictionary_name = 4;
Expand All @@ -257,7 +257,7 @@ message PerplexityScore {

// Represents a configuration of a theta sparsity score
message SparsityThetaScoreConfig {
optional string field_name = 1 [default = "@body"];
optional string field_name = 1 [default = "@body"]; // obsolete in BigARTM v0.5.8
optional string stream_name = 2 [default = "@global"];
optional float eps = 3 [default = 1e-37];
repeated string topic_name = 4;
Expand Down Expand Up @@ -286,7 +286,7 @@ message SparsityPhiScore {

// Represents a configuration of an items processed score
message ItemsProcessedScoreConfig {
optional string field_name = 1 [default = "@body"];
optional string field_name = 1 [default = "@body"]; // obsolete in BigARTM v0.5.8
optional string stream_name = 2 [default = "@global"];
}

Expand Down
102 changes: 47 additions & 55 deletions src/artm/score_sandbox/perplexity.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,30 +59,20 @@ void Perplexity::AppendScore(
class_weights.insert(std::make_pair(model_config.class_id(i), model_config.class_weight(i)));
bool use_class_id = !class_weights.empty();

const Field* field = nullptr;
for (int field_index = 0; field_index < item.field_size(); field_index++) {
if (item.field(field_index).name() == config_.field_name()) {
field = &item.field(field_index);
}
}

if (field == nullptr) {
LOG(ERROR) << "Unable to find field " << config_.field_name() << " in item " << item.id();
return;
}

float n_d = 0;
for (int token_index = 0; token_index < field->token_count_size(); ++token_index) {
float class_weight = 1.0f;
if (use_class_id) {
::artm::core::ClassId class_id = token_dict[field->token_id(token_index)].class_id;
auto iter = class_weights.find(class_id);
if (iter == class_weights.end())
continue;
class_weight = iter->second;
}
for (auto& field : item.field()) {
for (int token_index = 0; token_index < field.token_count_size(); ++token_index) {
float class_weight = 1.0f;
if (use_class_id) {
::artm::core::ClassId class_id = token_dict[field.token_id(token_index)].class_id;
auto iter = class_weights.find(class_id);
if (iter == class_weights.end())
continue;
class_weight = iter->second;
}

n_d += class_weight * static_cast<float>(field->token_count(token_index));
n_d += class_weight * static_cast<float>(field.token_count(token_index));
}
}

int zero_words = 0;
Expand Down Expand Up @@ -112,47 +102,49 @@ void Perplexity::AppendScore(
}
}

for (int token_index = 0; token_index < field->token_count_size(); ++token_index) {
double sum = 0.0;
const artm::core::Token& token = token_dict[field->token_id(token_index)];

float class_weight = 1.0f;
if (use_class_id) {
auto iter = class_weights.find(token.class_id);
if (iter == class_weights.end())
continue;
class_weight = iter->second;
}
for (auto& field : item.field()) {
for (int token_index = 0; token_index < field.token_count_size(); ++token_index) {
double sum = 0.0;
const artm::core::Token& token = token_dict[field.token_id(token_index)];

float class_weight = 1.0f;
if (use_class_id) {
auto iter = class_weights.find(token.class_id);
if (iter == class_weights.end())
continue;
class_weight = iter->second;
}

int token_count_int = field->token_count(token_index);
if (token_count_int == 0) continue;
double token_count = class_weight * static_cast<double>(token_count_int);
int token_count_int = field.token_count(token_index);
if (token_count_int == 0) continue;
double token_count = class_weight * static_cast<double>(token_count_int);

if (topic_model.has_token(token)) {
::artm::core::TopicWeightIterator topic_iter = topic_model.GetTopicWeightIterator(token);
while (topic_iter.NextNonZeroTopic() < topics_count) {
sum += theta[topic_iter.TopicIndex()] * topic_iter.Weight();
if (topic_model.has_token(token)) {
::artm::core::TopicWeightIterator topic_iter = topic_model.GetTopicWeightIterator(token);
while (topic_iter.NextNonZeroTopic() < topics_count) {
sum += theta[topic_iter.TopicIndex()] * topic_iter.Weight();
}
}
}

if (sum == 0.0) {
if (use_document_unigram_model) {
sum = token_count / n_d;
} else {
if (dictionary_ptr->find(token) != dictionary_ptr->end()) {
float n_w = dictionary_ptr->find(token)->second.value();
sum = n_w / dictionary_ptr->size();
} else {
LOG(INFO) << "No token " << token.keyword << " from class " << token.class_id <<
"in dictionary, document unigram model will be used.";
if (sum == 0.0) {
if (use_document_unigram_model) {
sum = token_count / n_d;
} else {
if (dictionary_ptr->find(token) != dictionary_ptr->end()) {
float n_w = dictionary_ptr->find(token)->second.value();
sum = n_w / dictionary_ptr->size();
} else {
LOG(INFO) << "No token " << token.keyword << " from class " << token.class_id <<
"in dictionary, document unigram model will be used.";
sum = token_count / n_d;
}
}
zero_words++;
}
zero_words++;
}

normalizer += token_count;
raw += token_count * log(sum);
normalizer += token_count;
raw += token_count * log(sum);
}
}

// prepare results
Expand Down
1 change: 1 addition & 0 deletions src/artm_tests/multiple_classes_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ artm::Batch GenerateBatch(int nTokens, int nDocs, std::string class1, std::strin
artm::Item* item = batch.add_item();
item->set_id(iDoc);
artm::Field* field = item->add_field();
field->set_name("custom_field_name");
for (int iToken = 0; iToken < nTokens; ++iToken) {
field->add_token_id(iToken);
int background_count = (iToken > 40) ? (1 + rand() % 5) : 0; // NOLINT
Expand Down

0 comments on commit 08ab5b2

Please sign in to comment.