Skip to content

Commit

Permalink
Refactoring (#824)
Browse files Browse the repository at this point in the history
Code refactoring:
- all unnecessary doubles were replaced with floats
- everywhere if and for statements are using with brackets
- use const qualifiers for references in range based for
- replace guard headers with #pragma once (update cpplint)
- some code decoration in regularizers and scores
  • Loading branch information
MelLain committed Jul 27, 2017
1 parent 762b9d9 commit a456174
Show file tree
Hide file tree
Showing 120 changed files with 2,546 additions and 1,566 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ src/Win32/*
*.log.ERROR.*
*.batch

*\#*\#*

# Binaries
/src/artm/unittests/
/src/cpp_client/cpp_client
Expand Down
4 changes: 2 additions & 2 deletions python/artm/master_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ def regularize_model(self, pwt, nwt, rwt, regularizer_name, regularizer_tau, reg
:param regularizer_name: list of names of Phi regularizers to use
:type regularizer_name: list of str
:param regularizer_tau: list of tau coefficients for Phi regularizers
:type regularizer_tau: list of double
:type regularizer_tau: list of double
:type regularizer_tau: list of floats
:type regularizer_tau: list of floats
"""
args = messages.RegularizeModelArgs(pwt_source_name=pwt,
nwt_source_name=nwt,
Expand Down
8 changes: 4 additions & 4 deletions python/artm/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class KlFunctionInfo(object):
def __init__(self, function_type='log', power_value=2.0):
"""
:param str function_type: the type of function, 'log' (logarithm) or 'pol' (polynomial)
:param float power_value: the double power of polynomial, ignored if type = 'log'
:param float power_value: the float power of polynomial, ignored if type = 'log'
"""
if function_type not in ['log', 'pol']:
raise ValueError('Function type can be only "log" or "pol"')
Expand Down Expand Up @@ -390,16 +390,16 @@ def __init__(self, name=None, tau=1.0, topic_names=None, alpha_iter=None,
User should guarantee the existence and correctness of\
document titles in batches (e.g. in src files with data, like WV).
:type doc_titles: list of strings
:param doc_topic_coef: Two cases: 1) list of doubles with length equal to num of topics.\
:param doc_topic_coef: Two cases: 1) list of floats with length equal to num of topics.\
Means additional multiplier in M-step formula besides alpha and\
tau, unique for each topic, but general for all processing documents.\
2) list of lists of doubles with outer list length equal to length\
2) list of lists of floats with outer list length equal to length\
of doc_titles, and each inner list length equal to num of topics.\
Means case 1 with unique list of additional multipliers for each\
document from doc_titles. Other documents will not be regularized\
according to description of doc_titles parameter.\
Note, that doc_topic_coef and topic_names are both using.
:type doc_topic_coef: list of doubles or list of lists of doubles
:type doc_topic_coef: list of floats or list of lists of floats
:param config: the low-level config of this regularizer
:type config: protobuf object
"""
Expand Down
60 changes: 44 additions & 16 deletions src/artm/c_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,53 @@ static void set_last_error(const std::string& error) {
static void EnableLogging(artm::ConfigureLoggingArgs* args) {
static bool logging_enabled = false;

if (logging_enabled && args != nullptr && args->has_log_dir() && (FLAGS_log_dir != args->log_dir()))
if (logging_enabled && args != nullptr && args->has_log_dir() && (FLAGS_log_dir != args->log_dir())) {
BOOST_THROW_EXCEPTION(::artm::core::InvalidOperation(
"Logging directory can't be change after the logging started."));
if (!logging_enabled && args != nullptr && args->has_log_dir())
if (!boost::filesystem::exists(args->log_dir()) || !boost::filesystem::is_directory(args->log_dir()))
}
if (!logging_enabled && args != nullptr && args->has_log_dir()) {
if (!boost::filesystem::exists(args->log_dir()) || !boost::filesystem::is_directory(args->log_dir())) {
BOOST_THROW_EXCEPTION(::artm::core::InvalidOperation(
"Can not enable logging to " + args->log_dir() + ", check that the folder exist"));
}
}

// Setting all other flags except log_dir
if (args != nullptr) {
if (args->has_minloglevel()) FLAGS_minloglevel = args->minloglevel();
if (args->has_stderrthreshold()) FLAGS_stderrthreshold = args->stderrthreshold();
if (args->has_logtostderr()) FLAGS_logtostderr = args->logtostderr();
if (args->has_minloglevel()) {
FLAGS_minloglevel = args->minloglevel();
}

if (args->has_stderrthreshold()) {
FLAGS_stderrthreshold = args->stderrthreshold();
}

if (args->has_logtostderr()) {
FLAGS_logtostderr = args->logtostderr();
}

if (args->has_colorlogtostderr()) {
FLAGS_colorlogtostderr = args->colorlogtostderr();
}

if (args->has_colorlogtostderr()) FLAGS_colorlogtostderr = args->colorlogtostderr();
if (args->has_alsologtostderr()) FLAGS_alsologtostderr = args->alsologtostderr();
if (args->has_alsologtostderr()) {
FLAGS_alsologtostderr = args->alsologtostderr();
}

if (args->has_logbufsecs()) FLAGS_logbufsecs = args->logbufsecs();
if (args->has_logbuflevel()) FLAGS_logbuflevel = args->logbuflevel();
if (args->has_logbufsecs()) {
FLAGS_logbufsecs = args->logbufsecs();
}

if (args->has_max_log_size()) FLAGS_max_log_size = args->max_log_size();
if (args->has_stop_logging_if_full_disk()) FLAGS_stop_logging_if_full_disk = args->stop_logging_if_full_disk();
if (args->has_logbuflevel()) {
FLAGS_logbuflevel = args->logbuflevel();
}

if (args->has_max_log_size()) {
FLAGS_max_log_size = args->max_log_size();
}
if (args->has_stop_logging_if_full_disk()) {
FLAGS_stop_logging_if_full_disk = args->stop_logging_if_full_disk();
}

// ::google::SetVLOGLevel() is not supported in non-gcc compilers
// https://groups.google.com/forum/#!topic/google-glog/f8D7qpXLWXw
Expand Down Expand Up @@ -277,14 +302,16 @@ int64_t ArtmAwaitOperation(int operation_id, int64_t length, const char* await_o
const int timeout = args.timeout_milliseconds();
auto time_start = boost::posix_time::microsec_clock::local_time();
for (;;) {
if (batch_manager->IsEverythingProcessed())
if (batch_manager->IsEverythingProcessed()) {
return ARTM_SUCCESS;
}

boost::this_thread::sleep(boost::posix_time::milliseconds(::artm::core::kIdleLoopFrequency));
auto time_end = boost::posix_time::microsec_clock::local_time();
if (timeout >= 0) {
if ((time_end - time_start).total_milliseconds() >= timeout)
if ((time_end - time_start).total_milliseconds() >= timeout) {
break;
}
}
}

Expand Down Expand Up @@ -369,10 +396,11 @@ int64_t ArtmExecute(int master_id, int64_t length, const char* args_blob, const
ArgsT args;
ParseFromArray(args_blob, length, &args);

if (name != nullptr)
if (name != nullptr) {
args.set_name(name);
else
} else {
args.clear_name();
}

::artm::core::FixAndValidateMessage(&args, /* throw_error =*/ true);
std::string description = ::artm::core::DescribeMessage(args);
Expand Down
5 changes: 1 addition & 4 deletions src/artm/c_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
// All methods must be inside "extern "C"" scope. All complex data structures should be passed in
// as Google Protobuf Messages, defined in messages.proto.

#ifndef SRC_ARTM_C_INTERFACE_H_
#define SRC_ARTM_C_INTERFACE_H_
#pragma once

#include <stdint.h>

Expand Down Expand Up @@ -104,5 +103,3 @@ extern "C" {
DLL_PUBLIC int64_t ArtmSetProtobufMessageFormatToBinary();
DLL_PUBLIC int64_t ArtmProtobufMessageFormatIsJson();
}

#endif // SRC_ARTM_C_INTERFACE_H_
2 changes: 1 addition & 1 deletion src/artm/core/batch_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace artm {
namespace core {

BatchManager::BatchManager() : lock_(), in_progress_() {}
BatchManager::BatchManager() : lock_(), in_progress_() { }

void BatchManager::Add(const boost::uuids::uuid& task_id) {
boost::lock_guard<boost::mutex> guard(lock_);
Expand Down
5 changes: 1 addition & 4 deletions src/artm/core/batch_manager.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2017, Additive Regularization of Topic Models.

#ifndef SRC_ARTM_CORE_BATCH_MANAGER_H_
#define SRC_ARTM_CORE_BATCH_MANAGER_H_
#pragma once

#include <set>
#include <string>
Expand Down Expand Up @@ -40,5 +39,3 @@ class BatchManager : boost::noncopyable {

} // namespace core
} // namespace artm

#endif // SRC_ARTM_CORE_BATCH_MANAGER_H_
78 changes: 54 additions & 24 deletions src/artm/core/cache_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,21 @@ namespace artm {
namespace core {

ThetaCacheEntry::ThetaCacheEntry()
: theta_matrix_(std::make_shared<ThetaMatrix>()), filename_() {}
: theta_matrix_(std::make_shared<ThetaMatrix>())
, filename_() { }

ThetaCacheEntry::~ThetaCacheEntry() {
if (!filename_.empty()) {
try { fs::remove(fs::path(filename_)); }
catch (...) {}
catch (...) { }
}
}

CacheManager::CacheManager(const std::string& disk_path, Instance* instance)
: lock_(), disk_path_(disk_path), instance_(instance), cache_() {
: lock_()
, disk_path_(disk_path)
, instance_(instance)
, cache_() {
Clear();
}

Expand All @@ -46,10 +50,11 @@ void CacheManager::Clear() {
}

void CacheManager::RequestMasterComponentInfo(MasterComponentInfo* master_info) const {
for (auto& key : cache_.keys()) {
for (const auto& key : cache_.keys()) {
std::shared_ptr<ThetaCacheEntry> entry = cache_.get(key);
if (entry == nullptr)
if (entry == nullptr) {
continue;
}

MasterComponentInfo::CacheEntryInfo* info = master_info->add_cache_entry();
info->set_key(boost::lexical_cast<std::string>(key));
Expand Down Expand Up @@ -83,36 +88,44 @@ static bool PopulateThetaMatrixFromCacheEntry(const ThetaMatrix& cache,
}
} else { // use all topics
assert(cache.topic_name_size() > 0);
for (int i = 0; i < cache.topic_name_size(); ++i)
for (int i = 0; i < cache.topic_name_size(); ++i) {
topics_to_use.push_back(i);
}
use_all_topics = true;
}

// Populate num_topics and topic_name fields in the resulting message
::google::protobuf::RepeatedPtrField< ::std::string> result_topic_name;
for (int topic_index : topics_to_use)
for (int topic_index : topics_to_use) {
result_topic_name.Add()->assign(cache.topic_name(topic_index));
}

if (theta_matrix->topic_name_size() == 0) {
// Assign
theta_matrix->set_num_topics(result_topic_name.size());
assert(theta_matrix->topic_name_size() == 0);
for (const TopicName& topic_name : result_topic_name)
for (const TopicName& topic_name : result_topic_name) {
theta_matrix->add_topic_name(topic_name);
}
} else {
// Verify
if (theta_matrix->num_topics() != result_topic_name.size())
if (theta_matrix->num_topics() != result_topic_name.size()) {
BOOST_THROW_EXCEPTION(artm::core::InternalError("theta_matrix->num_topics() != result_topic_name.size()"));
}

for (int i = 0; i < theta_matrix->topic_name_size(); ++i) {
if (theta_matrix->topic_name(i) != result_topic_name.Get(i))
if (theta_matrix->topic_name(i) != result_topic_name.Get(i)) {
BOOST_THROW_EXCEPTION(artm::core::InternalError("theta_matrix->topic_name(i) != result_topic_name.Get(i)"));
}
}
}

bool has_title = (cache.item_title_size() == cache.item_id_size());
for (int item_index = 0; item_index < cache.item_id_size(); ++item_index) {
theta_matrix->add_item_id(cache.item_id(item_index));
if (has_title) theta_matrix->add_item_title(cache.item_title(item_index));
if (has_title) {
theta_matrix->add_item_title(cache.item_title(item_index));
}
::artm::FloatArray* theta_vec = theta_matrix->add_item_weights();

const artm::FloatArray& item_theta = cache.item_weights(item_index);
Expand All @@ -125,8 +138,9 @@ static bool PopulateThetaMatrixFromCacheEntry(const ThetaMatrix& cache,
}
} else {
// dense output -- dense cache
for (int topic_index : topics_to_use)
for (int topic_index : topics_to_use) {
theta_vec->add_value(item_theta.value(topic_index));
}
}
} else {
::artm::IntArray* sparse_topic_indices = theta_matrix->add_topic_indices();
Expand Down Expand Up @@ -179,19 +193,21 @@ void CacheManager::RequestThetaMatrix(const GetThetaMatrixArgs& get_theta_args,
cached_theta.add_item_id(-1); // not available
::artm::FloatArray* item_weights = cached_theta.add_item_weights();
phi_matrix->get(token_id, &values);
for (int topic_index = 0; topic_index < phi_matrix->topic_size(); topic_index++)
for (int topic_index = 0; topic_index < phi_matrix->topic_size(); topic_index++) {
item_weights->add_value(values[topic_index]);
}
}

PopulateThetaMatrixFromCacheEntry(cached_theta, get_theta_args, theta_matrix);
return;
}

auto keys = cache_.keys();
for (auto &key : keys) {
for (const auto &key : keys) {
std::shared_ptr<ThetaMatrix> cached_theta = FindCacheEntry(key);
if (cached_theta != nullptr)
if (cached_theta != nullptr) {
PopulateThetaMatrixFromCacheEntry(*cached_theta, get_theta_args, theta_matrix);
}
}
}

Expand All @@ -205,15 +221,23 @@ std::shared_ptr<ThetaMatrix> CacheManager::FindCacheEntry(const Batch& batch) co
std::vector<float> values; values.resize(phi_matrix->topic_size());
for (int item_id = 0; item_id < batch.item_size(); item_id++) {
Token token(DocumentsClass, batch.item(item_id).title());
if (token.keyword.empty()) continue;

if (token.keyword.empty()) {
continue;
}

int token_index = phi_matrix->token_index(token);
if (token_index < 0) continue;
if (token_index < 0) {
continue;
}

cached_theta->add_item_title(batch.item(item_id).title());
cached_theta->add_item_id(batch.item(item_id).id());
::artm::FloatArray* item_weights = cached_theta->add_item_weights();
phi_matrix->get(token_index, &values);
for (int topic_index = 0; topic_index < phi_matrix->topic_size(); topic_index++)
for (int topic_index = 0; topic_index < phi_matrix->topic_size(); topic_index++) {
item_weights->add_value(values[topic_index]);
}
}

return cached_theta;
Expand All @@ -224,16 +248,18 @@ std::shared_ptr<ThetaMatrix> CacheManager::FindCacheEntry(const Batch& batch) co

std::shared_ptr<ThetaMatrix> CacheManager::FindCacheEntry(const std::string& batch_id) const {
std::shared_ptr<ThetaCacheEntry> retval = cache_.get(batch_id);
if (retval == nullptr)
if (retval == nullptr) {
return nullptr;
if (retval->filename().empty())
}
if (retval->filename().empty()) {
return retval->theta_matrix();
}

try {
std::shared_ptr<ThetaMatrix> copy(std::make_shared<ThetaMatrix>());
Helpers::LoadMessage(retval->filename(), copy.get());
return copy;
} catch(...) {
} catch (...) {
LOG(ERROR) << "Unable to reload cache for " << retval->filename();
}

Expand All @@ -249,9 +275,12 @@ void CacheManager::UpdateCacheEntry(const std::string& batch_id, const ThetaMatr
for (int i = 0; i < theta_matrix.item_title_size(); i++) {
Token token(DocumentsClass, theta_matrix.item_title(i));
int token_id = phi_matrix->token_index(token);
if (token_id < 0) token_id = mutable_phi_matrix->AddToken(token);
for (int topic_index = 0; topic_index < theta_matrix.topic_name_size(); topic_index++)
if (token_id < 0) {
token_id = mutable_phi_matrix->AddToken(token);
}
for (int topic_index = 0; topic_index < theta_matrix.topic_name_size(); topic_index++) {
mutable_phi_matrix->set(token_id, topic_index, theta_matrix.item_weights(i).value(topic_index));
}
}
return;
}
Expand All @@ -278,8 +307,9 @@ void CacheManager::UpdateCacheEntry(const std::string& batch_id, const ThetaMatr
void CacheManager::CopyFrom(const CacheManager& cache_manager) {
disk_path_ = cache_manager.disk_path_;
auto keys = cache_manager.cache_.keys();
for (auto key : keys)
for (const auto& key : keys) {
cache_.set(key, cache_manager.cache_.get(key));
}
}

} // namespace core
Expand Down

0 comments on commit a456174

Please sign in to comment.