Skip to content

Commit

Permalink
Merge pull request #96 from sashafrey/master
Browse files Browse the repository at this point in the history
Introduce SynchronizeModelArgs.apply_weight parameter
  • Loading branch information
bigartm committed Jan 2, 2015
2 parents d51ffff + 4dfe897 commit d5b4d32
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 43 deletions.
32 changes: 27 additions & 5 deletions docs/ref/messages.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1657,6 +1657,7 @@ Represents an argument of synchronize model operation.
optional string model_name = 1;
optional float decay_weight = 2 [default = 1.0];
optional bool invoke_regularizers = 3 [default = true];
optional float apply_weight = 4 [default = 1.0];
}

.. attribute:: SynchronizeModelArgs.model_name
Expand All @@ -1666,18 +1667,39 @@ Represents an argument of synchronize model operation.

.. attribute:: SynchronizeModelArgs.decay_weight

The decay weight to apply to current version of the topic model.
Expected values of this parameter are between 0.0 and 1.0.
The decay weight and :attr:`apply_weight` define how to combine the existing topic model with all the increments
calculated since the last :c:func:`ArtmSynchronizeModel`.
This is best described by the following formula:

Decay weight 0.0 states that the previous Phi matrix of the topic model will be disregarded completely,
and the new Phi matrix will be formed based on new increments gathered since last model synchronize.
``n_wt_new = n_wt_old * decay_weight + n_wt_inc * apply_weight``,

Decay weight 1.0 states that new increments will be appended to the current Phi matrix without any decay.
where
``n_wt_old`` describes the current topic model,
``n_wt_inc`` describes the increment calculated since the last :c:func:`ArtmSynchronizeModel`, and
``n_wt_new`` defines the resulting topic model.

Expected values of both parameters are between 0.0 and 1.0. Here are some examples:

* Combination of *decay_weight=0.0* and *apply_weight=1.0* states that the previous Phi matrix of the topic model will be disregarded completely,
and the new Phi matrix will be formed based on new increments gathered since last model synchronize.

* Combination of *decay_weight=1.0* and *apply_weight=1.0* states that new increments will be appended to the current Phi matrix without any decay.

* Combination of *decay_weight=1.0* and *apply_weight=0.0* states that new increments will be disregarded, and current Phi matrix will stay unchanged.

* To reproduce the Online variational Bayes algorithm for LDA by Matthew D. Hoffman, set
*decay_weight = 1 - rho* and *apply_weight = rho*, where the parameter rho is defined as *rho = pow(tau + t, -kappa)*.
See `Online Learning for Latent Dirichlet Allocation`_ for further details.

.. attribute:: SynchronizeModelArgs.apply_weight

See :attr:`decay_weight` for the description.

.. attribute:: SynchronizeModelArgs.invoke_regularizers

A flag indicating whether to invoke all phi-regularizers.

.. _Online Learning for Latent Dirichlet Allocation: https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf

.. _InitializeModelArgs:

Expand Down
19 changes: 10 additions & 9 deletions src/artm/core/merger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Merger::~Merger() {

// Erases the model's Phi matrix and posts an asynchronous kDisposeModel task
// so the merger thread can release any remaining per-model state.
// No sync_event is supplied, so this call does not block.
// Fix: dropped the stale pre-refactoring push that used the removed
// 5-argument MergerTask constructor (no apply_weight).
void Merger::DisposeModel(ModelName model_name) {
  topic_model_.erase(model_name);
  internal_task_queue_.push(MergerTask(kDisposeModel, model_name, 0.0f, 0.0f, false, nullptr));
}

void Merger::CreateOrReconfigureModel(const ModelConfig& model) {
Expand Down Expand Up @@ -124,25 +124,26 @@ void Merger::OverwriteTopicModel(const ::artm::TopicModel& topic_model) {
// Posts a kForceSynchronizeTopicModel task carrying the caller's decay_weight,
// apply_weight and invoke_regularizers settings, and blocks on sync_event
// until the merger thread has finished synchronizing the model.
// Fix: removed the leftover pre-refactoring argument line that lacked
// apply_weight and no longer matches the MergerTask constructor.
void Merger::ForceSynchronizeModel(const SynchronizeModelArgs& args) {
  rpcz::sync_event sync_event;
  internal_task_queue_.push(MergerTask(kForceSynchronizeTopicModel, args.model_name(),
                                       args.decay_weight(), args.apply_weight(),
                                       args.invoke_regularizers(), &sync_event));
  sync_event.wait();
}

// Posts a kForceResetScores task and blocks until the merger thread has
// processed it. The weight parameters are irrelevant for this task type
// and are passed as 0.0f placeholders.
// Fix: dropped the stale 5-argument push left over from before the
// apply_weight parameter was added to MergerTask.
void Merger::ForceResetScores(ModelName model_name) {
  rpcz::sync_event sync_event;
  internal_task_queue_.push(MergerTask(kForceResetScores, model_name, 0.0f, 0.0f, false, &sync_event));
  sync_event.wait();
}

// Posts a kForcePullTopicModel task (empty ModelName — applies to all models)
// and blocks until the merger thread has processed it. Weight parameters are
// unused for this task type and passed as 0.0f placeholders.
// Fix: removed the stale 5-argument push that predates apply_weight.
void Merger::ForcePullTopicModel() {
  rpcz::sync_event sync_event;
  internal_task_queue_.push(MergerTask(kForcePullTopicModel, ModelName(), 0.0f, 0.0f, false, &sync_event));
  sync_event.wait();
}

// Posts a kForcePushTopicModelIncrement task (empty ModelName — applies to all
// models) and blocks until the merger thread has processed it. Weight
// parameters are unused for this task type and passed as 0.0f placeholders.
// Fix: removed the stale 5-argument push that predates apply_weight.
void Merger::ForcePushTopicModelIncrement() {
  rpcz::sync_event sync_event;
  internal_task_queue_.push(MergerTask(kForcePushTopicModelIncrement, ModelName(), 0.0f, 0.0f, false, &sync_event));
  sync_event.wait();
}

Expand Down Expand Up @@ -213,7 +214,7 @@ void Merger::ThreadFunction() {
break;
case kForceSynchronizeTopicModel:
SynchronizeModel(merger_task.model_name, merger_task.decay_weight,
merger_task.invoke_regularizers);
merger_task.apply_weight, merger_task.invoke_regularizers);
break;
case kForceResetScores:
ResetScores(merger_task.model_name);
Expand Down Expand Up @@ -253,7 +254,7 @@ void Merger::ThreadFunction() {
iter = topic_model_inc_.find(model_name);
}

iter->second->ApplyDiff(*model_increment);
iter->second->ApplyDiff(*model_increment, 1.0f);
for (int score_index = 0;
score_index < model_increment->score_name_size();
++score_index) {
Expand Down Expand Up @@ -452,7 +453,7 @@ bool Merger::RequestScore(const ModelName& model_name, const ScoreName& score_na
}

void Merger::SynchronizeModel(const ModelName& model_name, float decay_weight,
bool invoke_regularizers) {
float apply_weight, bool invoke_regularizers) {
if (master_component_service_ != nullptr) {
return; // no-op in network modus operandi
}
Expand Down Expand Up @@ -489,7 +490,7 @@ void Merger::SynchronizeModel(const ModelName& model_name, float decay_weight,
target_model_config_.set(name, nullptr);
// Apply increment
if (inc_ttm != topic_model_inc_.end())
new_ttm->ApplyDiff(*inc_ttm->second);
new_ttm->ApplyDiff(*inc_ttm->second, apply_weight);

if (invoke_regularizers)
InvokePhiRegularizers(new_ttm.get());
Expand Down
9 changes: 6 additions & 3 deletions src/artm/core/merger.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,15 @@ class Merger : boost::noncopyable {
MergerTask() {}

MergerTask(MergerTaskType _task_type, ModelName _model_name, float _decay_weight,
bool _invoke_regularizers, rpcz::sync_event* _sync_event)
float _apply_weight, bool _invoke_regularizers, rpcz::sync_event* _sync_event)
: task_type(_task_type), model_name(_model_name), decay_weight(_decay_weight),
invoke_regularizers(_invoke_regularizers), sync_event(_sync_event) {}
apply_weight(_apply_weight), invoke_regularizers(_invoke_regularizers),
sync_event(_sync_event) {}

MergerTaskType task_type;
ModelName model_name;
float decay_weight;
float apply_weight;
bool invoke_regularizers;
rpcz::sync_event* sync_event;
};
Expand All @@ -127,7 +129,8 @@ class Merger : boost::noncopyable {
boost::thread thread_;
void ThreadFunction();

void SynchronizeModel(const ModelName& model_name, float decay_weight, bool invoke_regularizers);
void SynchronizeModel(const ModelName& model_name, float decay_weight, float apply_weight,
bool invoke_regularizers);
void PullTopicModel();
void PushTopicModelIncrement();
void InvokePhiRegularizers(::artm::core::TopicModel* topic_model);
Expand Down
10 changes: 5 additions & 5 deletions src/artm/core/topic_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ TopicModel::TopicModel(const ::artm::core::ModelIncrement& model_increment)
for (auto iter = topic_name.begin(); iter != topic_name.end(); ++iter) {
topic_name_.push_back(*iter);
}
ApplyDiff(model_increment);
ApplyDiff(model_increment, 1.0f);
}

TopicModel::~TopicModel() {
Expand Down Expand Up @@ -221,7 +221,7 @@ void TopicModel::RetrieveModelIncrement(::artm::core::ModelIncrement* diff) cons
}
}

void TopicModel::ApplyDiff(const ::artm::core::ModelIncrement& diff) {
void TopicModel::ApplyDiff(const ::artm::core::ModelIncrement& diff, float apply_weight) {
int diff_token_size = diff.token_size();
if ((diff.class_id_size() != diff_token_size) ||
(diff.operation_type_size() != diff_token_size) ||
Expand Down Expand Up @@ -255,7 +255,7 @@ void TopicModel::ApplyDiff(const ::artm::core::ModelIncrement& diff) {
current_token_id = this->AddToken(token, false);
target = n_wt_[current_token_id];
for (int topic_index = 0; topic_index < topics_count; ++topic_index)
target[topic_index] += counters.value(topic_index);
target[topic_index] += apply_weight * counters.value(topic_index);
break;

case ModelIncrement_OperationType_OverwriteValue:
Expand All @@ -282,7 +282,7 @@ void TopicModel::ApplyDiff(const ::artm::core::ModelIncrement& diff) {
}
}

void TopicModel::ApplyDiff(const ::artm::core::TopicModel& diff) {
void TopicModel::ApplyDiff(const ::artm::core::TopicModel& diff, float apply_weight) {
int topics_count = this->topic_size();

for (int token_index = 0;
Expand All @@ -295,7 +295,7 @@ void TopicModel::ApplyDiff(const ::artm::core::TopicModel& diff) {
}

for (int topic_index = 0; topic_index < topics_count; ++topic_index) {
this->IncreaseTokenWeight(current_token, topic_index, counters[topic_index]);
this->IncreaseTokenWeight(current_token, topic_index, apply_weight * counters[topic_index]);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/artm/core/topic_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ class TopicModel : public Regularizable {
void RetrieveModelIncrement(::artm::core::ModelIncrement* diff) const;

// Applies model increment to this TopicModel.
void ApplyDiff(const ::artm::core::ModelIncrement& diff);
void ApplyDiff(const ::artm::core::TopicModel& diff);
void ApplyDiff(const ::artm::core::ModelIncrement& diff, float apply_weight);
void ApplyDiff(const ::artm::core::TopicModel& diff, float apply_weight);

void RemoveToken(const Token& token);
int AddToken(const Token& token, bool random_init = true);
Expand Down
60 changes: 50 additions & 10 deletions src/artm/messages.pb.cc

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d5b4d32

Please sign in to comment.