Skip to content

Commit

Permalink
Perplexity refactoring and fixing (#814)
Browse files Browse the repository at this point in the history
Perplexity refactoring for dealing with modalities.
  • Loading branch information
MelLain committed Jun 20, 2017
1 parent 14d0e2b commit b8b5e68
Show file tree
Hide file tree
Showing 15 changed files with 356 additions and 121 deletions.
20 changes: 16 additions & 4 deletions python/artm/score_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def __create_dict(keys, values):
zip(score_arrays, topics)} for score_arrays, topics in score_topic_list_list]) # noqa
return score_list_dict[-1] if last else score_list_dict

elif field_attrs[1] == 'repeated' and field_attrs[2] == 'struct':
score_list_dict = [{__getattr(s, field_attrs[3]): s for s in score_list}
for score_list in score_list_list] # noqa
return score_list_dict[-1] if last else score_list_dict
else:
raise ValueError('Unknown type of score tracker field')


def _set_properties(class_ref, attr_data):
for name, params in iteritems(attr_data):
Expand Down Expand Up @@ -128,16 +135,21 @@ def __init__(self, score):
:Properties:
* Note: every field is a list of info about score on all synchronizations.
* value - values of perplexity.
* raw - raw values in formula for perplexity.
* normalizer - normalizer values in formula for perplexity.
* zero_tokens - number of zero p(w|d) = sum_t p(w|t) p(t|d).
* raw - raw values in formula for perplexity (in case of one class id).
* normalizer - normalizer values in formula for perplexity (in case of one class id).
* zero_tokens - number of zero p(w|d) = sum_t p(w|t) p(t|d) (in case of one class id).
* class_id_info - array of structures, each structure contains raw, normalizer,\
zero_tokens and class_id name (in case of several class ids).
* Note: every field has a version with prefix 'last_', means retrieving only\
info about the last synchronization.
"""
BaseScoreTracker.__init__(self, score)

_set_properties(PerplexityScoreTracker, {'value': {}, 'raw': {}, 'normalizer': {},
'zero_tokens': {'proto_name': 'zero_words'}})
'zero_tokens': {'proto_name': 'zero_words'},
'class_id_info': {'proto_qualifier': 'repeated',
'proto_type': 'struct',
'key_field_name': 'class_id'}})


class ItemsProcessedScoreTracker(BaseScoreTracker):
Expand Down
2 changes: 1 addition & 1 deletion python/tests/wrapper/test_02_regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_func():
num_document_passes = 10
num_outer_iterations = 8

perplexity_tol = 0.001
perplexity_tol = 1.0
expected_perplexity_value_on_iteration = {
0: 6703.161,
1: 2426.277,
Expand Down
2 changes: 1 addition & 1 deletion python/tests/wrapper/test_03_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_func():
num_document_passes = 10
num_outer_iterations = 5

perplexity_tol = 0.001
perplexity_tol = 1.0
expected_perplexity_value_on_iteration = {
0: 6710.208,
1: 2434.135,
Expand Down
4 changes: 2 additions & 2 deletions python/tests/wrapper/test_04_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def test_func():
smsp_phi_tau = -20.0
smsp_theta_tau = -3.0

perplexity_tol = 0.1
perplexity_tol = 1.0
expected_perp_col_value_on_iteration = {
0: 6649.1,
0: 6650.1,
1: 2300.2,
2: 1996.8,
3: 1786.1,
Expand Down
2 changes: 1 addition & 1 deletion python/tests/wrapper/test_11_master_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_func():
num_document_passes = 10
num_outer_iterations = 8

perplexity_tol = 0.001
perplexity_tol = 1.0
expected_perplexity_value_on_iteration = {
0: 6703.161,
1: 2426.277,
Expand Down
14 changes: 9 additions & 5 deletions src/artm/core/dictionary_operations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ std::shared_ptr<Dictionary> DictionaryOperations::Gather(const GatherDictionaryA
}

int total_items_count = 0;
double sum_w_tf = 0.0;
std::unordered_map<ClassId, double> sum_w_tf;
for (const std::string& batch_file : batches) {
std::shared_ptr<Batch> batch_ptr = mem_batches.get(batch_file);
try {
Expand Down Expand Up @@ -279,15 +279,19 @@ std::shared_ptr<Dictionary> DictionaryOperations::Gather(const GatherDictionaryA

for (int index = 0; index < batch.token_size(); ++index) {
// unordered_map.operator[] creates element using default constructor if the key doesn't exist
TokenValues& token_info = token_freq_map[Token(batch.class_id(index), batch.token(index))];
ClassId token_class_id = batch.class_id(index);
TokenValues& token_info = token_freq_map[Token(token_class_id, batch.token(index))];
token_info.token_tf += token_n_w[index];
sum_w_tf += token_n_w[index];
token_info.token_df += token_df[index];

sum_w_tf[token_class_id] += token_n_w[index];
}
}

for (auto iter = token_freq_map.begin(); iter != token_freq_map.end(); ++iter)
iter->second.token_value = static_cast<float>(iter->second.token_tf / sum_w_tf);
for (auto& token_freq : token_freq_map) {
token_freq.second.token_value = static_cast<float>(token_freq.second.token_tf /
sum_w_tf[token_freq.first.class_id]);
}

LOG(INFO) << "Find " << token_freq_map.size()
<< " unique tokens in " << total_items_count << " items";
Expand Down
13 changes: 13 additions & 0 deletions src/artm/messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,23 @@ message PerplexityScoreConfig {

// Represents a result of calculation of a perplexity score.
// The final value is derived from the raw/normalizer accumulators as
// exp(-raw / normalizer) (summed per class id in the multi-modality case);
// accumulation and merging are performed in src/artm/score/perplexity.cc.
message PerplexityScore {
// Per-modality accumulators, filled only when the score config requests
// one or more custom class ids.
message ClassIdInfo {
// Name of the modality (class id) these accumulators belong to.
optional string class_id = 1;
// Sum of token_weight * log p(w|d) over tokens of this class id.
optional double raw = 2;
// Sum of token weights (the perplexity normalizer) for this class id.
optional double normalizer = 3;
// Number of tokens of this class id with zero p(w|d) = sum_t p(w|t) p(t|d).
optional int64 zero_words = 4;
}

// general perplexity value, filled in all cases (single- and multi-modality)
optional double value = 1;

// accumulators for the case of all class ids treated together
// (no custom class ids requested in the score config)
optional double raw = 2;
optional double normalizer = 3;
// number of tokens with zero p(w|d) = sum_t p(w|t) p(t|d)
optional int64 zero_words = 4;

// per-class-id accumulators for the case of custom class ids
repeated ClassIdInfo class_id_info = 5;
}

// Represents a configuration of a theta sparsity score
Expand Down
184 changes: 134 additions & 50 deletions src/artm/score/perplexity.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright 2017, Additive Regularization of Topic Models.

// Author: Marina Suvorova (m.dudarenko@gmail.com)
// Authors: Marina Suvorova (m.dudarenko@gmail.com)
// Murat Apishev (great-mel@yandex.ru)

#include <cmath>
#include <map>
Expand Down Expand Up @@ -33,52 +34,68 @@ void Perplexity::AppendScore(
Score* score) {
int topic_size = p_wt.topic_size();

// the following code counts perplexity
bool use_classes_from_model = false;
if (config_.class_id_size() == 0) use_classes_from_model = true;
// fields of proto messages for all classes
std::unordered_map<::artm::core::ClassId, float> class_weight_map;
std::unordered_map<::artm::core::ClassId, double> normalizer_map;
std::unordered_map<::artm::core::ClassId, double> raw_map;
std::unordered_map<::artm::core::ClassId, ::google::protobuf::int64> zero_words_map;

std::map< ::artm::core::ClassId, float> class_weights;
if (use_classes_from_model) {
for (int i = 0; (i < args.class_id_size()) && (i < args.class_weight_size()); ++i)
class_weights.insert(std::make_pair(args.class_id(i), args.class_weight(i)));
double normalizer = 0.0;
double raw = 0.0;
::google::protobuf::int64 zero_words = 0;

// choose class_ids policy
if (config_.class_id_size() == 0) {
for (int i = 0; (i < args.class_id_size()) && (i < args.class_weight_size()); ++i) {
class_weight_map.insert(std::make_pair(args.class_id(i), args.class_weight(i)));
normalizer_map.insert(std::make_pair(args.class_id(i), 0.0));
raw_map.insert(std::make_pair(args.class_id(i), 0.0));
zero_words_map.insert(std::make_pair(args.class_id(i), 0));
}
} else {
for (auto& class_id : config_.class_id()) {
for (int i = 0; (i < args.class_id_size()) && (i < args.class_weight_size()); ++i)
for (int i = 0; (i < args.class_id_size()) && (i < args.class_weight_size()); ++i) {
if (class_id == args.class_id(i)) {
class_weights.insert(std::make_pair(args.class_id(i), args.class_weight(i)));
class_weight_map.insert(std::make_pair(args.class_id(i), args.class_weight(i)));
normalizer_map.insert(std::make_pair(args.class_id(i), 0.0));
raw_map.insert(std::make_pair(args.class_id(i), 0.0));
zero_words_map.insert(std::make_pair(args.class_id(i), 0));
break;
}
}
}
if (class_weight_map.empty()) {
LOG(ERROR) << "None of requested classes were presented in model. Score calculation will be skipped";
return;
}
}
bool use_class_id = !class_weights.empty();
bool use_class_ids = !class_weight_map.empty();

float n_d = 0;
// count perplexity normalizer n_d
for (int token_index = 0; token_index < item.token_weight_size(); ++token_index) {
float class_weight = 1.0f;
if (use_class_id) {
if (use_class_ids) {
::artm::core::ClassId class_id = token_dict[item.token_id(token_index)].class_id;
auto iter = class_weights.find(class_id);
if (iter == class_weights.end())
auto class_weight_iter = class_weight_map.find(class_id);
if (class_weight_iter == class_weight_map.end()) {
// we should not take tokens without class id weight into consideration
continue;
class_weight = iter->second;
}
}

n_d += class_weight * item.token_weight(token_index);
normalizer_map[class_id] += item.token_weight(token_index);
} else {
normalizer += item.token_weight(token_index);
}
}

::google::protobuf::int64 zero_words = 0;
double normalizer = 0;
double raw = 0;

// check dictionary existence for replacing zero pwt sums
std::shared_ptr<core::Dictionary> dictionary_ptr = nullptr;
if (config_.has_dictionary_name())
dictionary_ptr = dictionary(config_.dictionary_name());
bool has_dictionary = dictionary_ptr != nullptr;

bool use_document_unigram_model = true;
if (config_.has_model_type()) {
if (config_.model_type() == PerplexityScoreConfig_Type_UnigramCollectionModel) {
if (has_dictionary) {
if (dictionary_ptr) {
use_document_unigram_model = false;
} else {
LOG(ERROR) << "Perplexity was configured to use UnigramCollectionModel with dictionary " <<
Expand All @@ -88,21 +105,24 @@ void Perplexity::AppendScore(
}
}

// count raw values
std::vector<float> helper_vector(topic_size, 0.0f);
for (int token_index = 0; token_index < item.token_weight_size(); ++token_index) {
double sum = 0.0;
const artm::core::Token& token = token_dict[item.token_id(token_index)];

float class_weight = 1.0f;
if (use_class_id) {
auto iter = class_weights.find(token.class_id);
if (iter == class_weights.end())
if (use_class_ids) {
auto class_weight_iter = class_weight_map.find(token.class_id);
if (class_weight_iter == class_weight_map.end())
continue;
class_weight = iter->second;
class_weight = class_weight_iter->second;
}

float token_weight = class_weight * item.token_weight(token_index);
if (token_weight == 0.0f) continue;
if (token_weight == 0.0f)
continue;


int p_wt_token_index = p_wt.token_index(token);
if (p_wt_token_index != ::artm::core::PhiMatrix::kUndefIndex) {
Expand All @@ -113,7 +133,7 @@ void Perplexity::AppendScore(
}
if (sum == 0.0) {
if (use_document_unigram_model) {
sum = token_weight / n_d;
sum = token_weight / (use_class_ids ? normalizer_map[token.class_id] : normalizer);
} else {
auto entry_ptr = dictionary_ptr->entry(token);
bool failed = true;
Expand All @@ -128,21 +148,32 @@ void Perplexity::AppendScore(
<< ". Verify that the token exists in the dictionary and it's value > 0. "
<< "Document unigram model will be used for this token "
<< "(and for all other tokens under the same conditions).";
sum = token_weight / n_d;
sum = token_weight / (use_class_ids ? normalizer_map[token.class_id] : normalizer);
}
}
zero_words++;
// the presence of class_id in the maps here and below is guaranteed
++(use_class_ids ? zero_words_map[token.class_id] : zero_words);
}

normalizer += token_weight;
raw += token_weight * log(sum);
(use_class_ids ? raw_map[token.class_id] : raw) += token_weight * log(sum);
}

// prepare results
PerplexityScore perplexity_score;
perplexity_score.set_normalizer(normalizer);
perplexity_score.set_raw(raw);
perplexity_score.set_zero_words(zero_words);
if (use_class_ids) {
for (auto iter = normalizer_map.begin(); iter != normalizer_map.end(); ++iter) {
auto class_id_info = perplexity_score.add_class_id_info();

class_id_info->set_class_id(iter->first);
class_id_info->set_normalizer(iter->second);
class_id_info->set_raw(raw_map[iter->first]);
class_id_info->set_zero_words(zero_words_map[iter->first]);
}
} else {
perplexity_score.set_normalizer(normalizer);
perplexity_score.set_raw(raw);
perplexity_score.set_zero_words(zero_words);
}

AppendScore(perplexity_score, score);
}

Expand All @@ -163,17 +194,70 @@ void Perplexity::AppendScore(const Score& score, Score* target) {
BOOST_THROW_EXCEPTION(::artm::core::InternalError(error_message));
}

perplexity_target->set_normalizer(perplexity_target->normalizer() +
perplexity_score->normalizer());
perplexity_target->set_raw(perplexity_target->raw() +
perplexity_score->raw());
perplexity_target->set_zero_words(perplexity_target->zero_words() +
perplexity_score->zero_words());
perplexity_target->set_value(exp(- perplexity_target->raw() / perplexity_target->normalizer()));

VLOG(1) << "normalizer=" << perplexity_target->normalizer()
<< ", raw=" << perplexity_target->raw()
<< ", zero_words=" << perplexity_target->zero_words();
bool empty_target = !perplexity_target->class_id_info_size() && !perplexity_target->normalizer();
bool score_has_class_ids = perplexity_score->class_id_info_size();
bool target_has_class_ids = empty_target ? score_has_class_ids : perplexity_target->class_id_info_size();
if (target_has_class_ids != score_has_class_ids) {
std::stringstream ss;
ss <<"Inconsistent new content of perplexity score. Old content uses class ids: " << target_has_class_ids;
BOOST_THROW_EXCEPTION(::artm::core::InternalError(ss.str()));
}

double pre_value = 0.0;
if (target_has_class_ids) {
for (size_t i = 0; i < perplexity_score->class_id_info_size(); ++i) {
auto src = perplexity_score->class_id_info(i);

bool was_added = false;
for (size_t j = 0; j < perplexity_target->class_id_info_size(); ++j) {
if (perplexity_score->class_id_info(i).class_id() == perplexity_target->class_id_info(j).class_id()) {
// update existing class_id info
auto dst = perplexity_target->mutable_class_id_info(j);
dst->set_normalizer(dst->normalizer() + src.normalizer());
dst->set_raw(dst->raw() + src.raw());
dst->set_zero_words(dst->zero_words() + src.zero_words());

was_added = true;
break;
}
}

if (!was_added) {
// add new class_id info
auto dst = perplexity_target->add_class_id_info();
dst->set_class_id(src.class_id());
dst->set_normalizer(src.normalizer());
dst->set_raw(src.raw());
dst->set_zero_words(src.zero_words());
}
}

for (size_t j = 0; j < perplexity_target->class_id_info_size(); ++j) {
auto score = perplexity_target->class_id_info(j);
pre_value += score.raw() / score.normalizer();

VLOG(1) << "class_id=" << score.class_id()
<< ", normalizer=" << score.normalizer()
<< ", raw=" << score.raw()
<< ", zero_words=" << score.zero_words();
}
} else {
auto src = perplexity_score;
auto dst = perplexity_target;

dst->set_normalizer(dst->normalizer() + src->normalizer());
dst->set_raw(dst->raw() + src->raw());
dst->set_zero_words(dst->zero_words() + src->zero_words());

pre_value = dst->raw() / dst->normalizer();

VLOG(1) << "use all class_ids"
<< ", normalizer=" << dst->normalizer()
<< ", raw=" << dst->raw()
<< ", zero_words=" << dst->zero_words();
}

perplexity_target->set_value(exp(- pre_value));
}

} // namespace score
Expand Down

0 comments on commit b8b5e68

Please sign in to comment.