-
Notifications
You must be signed in to change notification settings - Fork 117
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New regularizer for smoothing time in topics was added
- Loading branch information
Showing
10 changed files
with
262 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2017, Additive Regularization of Topic Models. | ||
|
||
import shutil | ||
import glob | ||
import tempfile | ||
import os | ||
import pytest | ||
|
||
from six.moves import range, zip | ||
|
||
import artm | ||
import pandas as pd | ||
|
||
|
||
def test_func():
    """Regression test for SmoothTimeInTopicsPhiRegularizer on the 'kos' collection.

    Builds a 20-topic model, applies the regularizer (tau=1000.0) to the first
    ten topics only, runs 20 offline passes and checks that the final Phi
    sparsity of the first ten topics and of the remaining topics each lie
    within `tolerance` of the reference values below.

    The collection is read from the folder named by the BIGARTM_UNITTEST_DATA
    environment variable; batches are written to a temporary folder that is
    removed when the test finishes, whether it passes or fails.
    """
    topic_count = 20
    tolerance = 0.01
    # (score name, reference sparsity) pairs, checked in this order.
    expected_sparsity = (
        ('sp_phi_one', 0.189),
        ('sp_phi_two', 0.251),
    )

    data_path = os.environ.get('BIGARTM_UNITTEST_DATA')
    batches_folder = tempfile.mkdtemp()

    try:
        batch_vectorizer = artm.BatchVectorizer(
            data_path=data_path,
            data_format='bow_uci',
            collection_name='kos',
            target_folder=batches_folder,
        )

        model = artm.ARTM(num_topics=topic_count, dictionary=batch_vectorizer.dictionary)

        # One sparsity score per topic group: the regularized first ten topics
        # and the remaining, unregularized ones.
        model.scores.add(artm.SparsityPhiScore(name='sp_phi_one',
                                               topic_names=model.topic_names[:10]))
        model.scores.add(artm.SparsityPhiScore(name='sp_phi_two',
                                               topic_names=model.topic_names[10:]))

        model.regularizers.add(
            artm.SmoothTimeInTopicsPhiRegularizer(tau=1000.0,
                                                  topic_names=model.topic_names[:10]))

        model.fit_offline(batch_vectorizer, 20)

        for score_name, reference in expected_sparsity:
            observed = model.score_tracker[score_name].last_value
            assert abs(observed - reference) < tolerance
    finally:
        shutil.rmtree(batches_folder)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright 2017, Additive Regularization of Topic Models. | ||
|
||
// Author: Murat Apishev (great-mel@yandex.ru) | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "artm/core/protobuf_helpers.h" | ||
#include "artm/core/phi_matrix.h" | ||
#include "artm/regularizer/smooth_time_in_topics_phi.h" | ||
|
||
namespace artm { | ||
namespace regularizer { | ||
|
||
// Computes the regularizer's contribution for every "interior" token of the
// configured class: a token that has both a preceding and a following token of
// the same class (in p_wt's token order).
//
//   p_wt   - current Phi matrix; the only input that is read.
//   n_wt   - not used by this regularizer.
//   result - receives, for each interior token w and each selected topic t,
//            p_wt * (s(p_{w-1,t} - p_wt) + s(p_{w+1,t} - p_wt)).
//
// Cells for other classes, for the first and last token of the class, and for
// topics that are not selected are not written. No tau factor is applied here.
// Always returns true.
//
// NOTE(review): s(x) below is +1 for x > 0 and -1 otherwise, so a zero
// difference counts as -1 rather than 0 as a mathematical sign() would.
// Confirm whether that is intended.
bool SmoothTimeInTopicsPhi::RegularizePhi(const ::artm::core::PhiMatrix& p_wt,
                                          const ::artm::core::PhiMatrix& n_wt,
                                          ::artm::core::PhiMatrix* result) {
  const int num_topics = p_wt.topic_size();
  const int num_tokens = p_wt.token_size();

  // An empty topic list in the config selects every topic.
  std::vector<bool> topic_selected;
  if (config_.topic_name().size() == 0) {
    topic_selected.assign(num_topics, true);
  } else {
    topic_selected = core::is_member(p_wt.topic_name(), config_.topic_name());
  }

  // Sliding window over tokens of the configured class:
  // `left` and `center` are the two most recent matching tokens, `right` is
  // the current one. `center` is the token being regularized.
  int left = -1;
  int center = -1;
  for (int right = 0; right < num_tokens; ++right) {
    if (p_wt.token(right).class_id != config_.class_id()) {
      continue;
    }

    // First matching token: it only seeds the window.
    if (left < 0) {
      left = right;
      continue;
    }

    // From the third matching token on, the window is full.
    if (center >= 0) {
      for (int topic_id = 0; topic_id < num_topics; ++topic_id) {
        if (!topic_selected[topic_id]) {
          continue;
        }

        const double current = p_wt.get(center, topic_id);
        const double toward_left = (p_wt.get(left, topic_id) - current) > 0.0 ? 1.0 : -1.0;
        const double toward_right = (p_wt.get(right, topic_id) - current) > 0.0 ? 1.0 : -1.0;

        result->set(center, topic_id, current * (toward_left + toward_right));
      }
      left = center;
    }

    center = right;
  }

  return true;
}
|
||
// Returns a copy of the topic names listed in the regularizer's config.
google::protobuf::RepeatedPtrField<std::string> SmoothTimeInTopicsPhi::topics_to_regularize() {
  const auto& configured_topics = config_.topic_name();
  return configured_topics;
}
|
||
// Returns a one-element list holding the single class id from the config.
google::protobuf::RepeatedPtrField<std::string> SmoothTimeInTopicsPhi::class_ids_to_regularize() {
  google::protobuf::RepeatedPtrField<std::string> class_ids;
  *class_ids.Add() = config_.class_id();
  return class_ids;
}
|
||
bool SmoothTimeInTopicsPhi::Reconfigure(const RegularizerConfig& config) { | ||
std::string config_blob = config.config(); | ||
SmoothTimeInTopicsPhiConfig regularizer_config; | ||
if (!regularizer_config.ParseFromString(config_blob)) { | ||
BOOST_THROW_EXCEPTION(::artm::core::CorruptedMessageException( | ||
"Unable to parse SmoothSparsePhiConfig from RegularizerConfig.config")); | ||
} | ||
|
||
config_.CopyFrom(regularizer_config); | ||
return true; | ||
} | ||
|
||
} // namespace regularizer | ||
} // namespace artm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/* Copyright 2017, Additive Regularization of Topic Models.
   Author: Murat Apishev (great-mel@yandex.ru)
   This class performs smoothing of tokens in Phi using the values of neighbouring tokens.
   It is intended primarily for time-stamp tokens and must be used with a sorted dictionary
   (i.e. the tokens should follow some order, chronological, for instance).
   The formula of the M-step is
   p_wt \propto n_wt + tau * p_wt * (sign(p_{w-1,t} - p_wt) + sign(p_{w+1,t} - p_wt)).
   The parameters of the regularizer:
   - topic_names (the names of topics to regularize, empty == all)
   - class_id (class id to regularize, required)
   Note: the regularizer ignores the first and the last token of the given modality.
*/
|
||
#ifndef SRC_ARTM_REGULARIZER_SMOOTH_TIME_IN_TOPICS_PHI_H_ | ||
#define SRC_ARTM_REGULARIZER_SMOOTH_TIME_IN_TOPICS_PHI_H_ | ||
|
||
#include <memory> | ||
#include <string> | ||
|
||
#include "artm/regularizer_interface.h" | ||
|
||
namespace artm { | ||
namespace regularizer { | ||
|
||
class SmoothTimeInTopicsPhi : public RegularizerInterface { | ||
public: | ||
explicit SmoothTimeInTopicsPhi(const SmoothTimeInTopicsPhiConfig& config) : config_(config) { } | ||
|
||
virtual bool RegularizePhi(const ::artm::core::PhiMatrix& p_wt, | ||
const ::artm::core::PhiMatrix& n_wt, | ||
::artm::core::PhiMatrix* result); | ||
|
||
virtual google::protobuf::RepeatedPtrField<std::string> topics_to_regularize(); | ||
virtual google::protobuf::RepeatedPtrField<std::string> class_ids_to_regularize(); | ||
|
||
virtual bool Reconfigure(const RegularizerConfig& config); | ||
|
||
private: | ||
SmoothTimeInTopicsPhiConfig config_; | ||
}; | ||
|
||
} // namespace regularizer | ||
} // namespace artm | ||
|
||
#endif // SRC_ARTM_REGULARIZER_SMOOTH_TIME_IN_TOPICS_PHI_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters