Skip to content

Commit

Permalink
Merge pull request #98 from sashafrey/master
Browse files Browse the repository at this point in the history
#93 Three operations required to set phi matrix, plus bug fix
  • Loading branch information
bigartm committed Jan 3, 2015
2 parents 8bae77f + 865e8fe commit 7d6c80c
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 20 deletions.
20 changes: 13 additions & 7 deletions docs/ref/cpp_interface.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,24 +108,30 @@ Model
Returns mutable configuration of the model.
Remember to call :cpp:func:`Reconfigure` to propagate your changes to the model.

.. cpp:function:: void Overwrite(const TopicModel& topic_model)
.. cpp:function:: void Overwrite(const TopicModel& topic_model, bool commit = true)

Updates the model with new Phi matrix.
This can be used to provide an explicit initial approximation, or to adjust
the model in between iterations.
Remember to call :cpp:func:`Synchronize` after overwrite to propagate your change.
Updates the model with new Phi matrix, defined by *topic_model*.
This operation can be used to provide an explicit initial approximation of the topic model, or to adjust the model in between iterations.

Depending on the *commit* flag the change can be applied immediately (*commit = true*) or queued (*commit = false*).
The default setting is to use *commit = true*.
You may want to use *commit = false* if your model is too big to be updated in a single protobuf message.
In this case you should split your model into parts, each part containing subset of all tokens,
and then submit each part in separate Overwrite operation with *commit = false*.
After that remember to call :cpp:func:`MasterComponent::WaitIdle` and :cpp:func:`Synchronize` to propagate your change.

.. cpp:function:: void Initialize(const Dictionary& dictionary)

Initialize topic model based on the :cpp:class:`Dictionary`.
Each token from the dictionary will be included in the model with randomly generated weight.

.. cpp:function:: void Synchronize(double decay, bool invoke_regularizers)
.. cpp:function:: void Synchronize(double decay_weight, double apply_weight, bool invoke_regularizers)

Synchronize the model.

This operation updates the Phi matrix of the topic model with all model increments, collected since the last call to :cpp:func:`Synchronize` method.
The weights in the Phi matrix are decreased according to *decay_weight* (refer to :attr:`SynchronizeModelArgs.decay_weight` for more details).
The weights in the Phi matrix are set according to *decay_weight* and *apply_weight* values
(refer to :attr:`SynchronizeModelArgs.decay_weight` for more details).
Depending on *invoke_regularizers* parameter this operation may also invoke all regularizers.

Remember to call :cpp:func:`Model::Synchronize` operation every time after calling :cpp:func:`MasterComponent::WaitIdle`.
Expand Down
13 changes: 10 additions & 3 deletions docs/ref/python_interface.txt
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,17 @@ Model

*dictionary* must be an instance of :py:class:`Dictionary` class.

.. py:method:: Overwrite(topic_model)
.. py:method:: Overwrite(topic_model, commit = True)

Schedules an update of the current Phi matrix with new values, defined by *topic_model* (:ref:`TopicModel`).
To apply the update you must call :py:meth:`MasterComponent.WaitIdle` and then :py:meth:`Model.Synchronize`.
Updates the model with new Phi matrix, defined by *topic_model* (:ref:`TopicModel`).
This operation can be used to provide an explicit initial approximation of the topic model, or to adjust the model in between iterations.

Depending on the *commit* flag the change can be applied immediately (*commit = true*) or queued (*commit = false*).
The default setting is to use *commit = true*.
You may want to use *commit = false* if your model is too big to be updated in a single protobuf message.
In this case you should split your model into parts, each part containing subset of all tokens,
and then submit each part in separate Overwrite operation with *commit = false*.
After that remember to call :py:meth:`MasterComponent.WaitIdle` and :py:meth:`Model.Synchronize` to propagate your change.

.. py:method:: Enable()

Expand Down
9 changes: 8 additions & 1 deletion src/artm/core/merger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,14 @@ void Merger::SynchronizeModel(const ModelName& model_name, float decay_weight,
return; // no-op in network modus operandi
}

CuckooWatch cuckoo("Merger::SynchronizeModel");
std::stringstream ss;
ss << "Merger::SynchronizeModel (" << model_name
<< ", decay_weight=" << decay_weight
<< ", apply_weight=" << apply_weight
<< ", invoke_regularizers=" << (invoke_regularizers ? "true" : "false")
<< ")";

CuckooWatch cuckoo(ss.str());
auto model_names = topic_model_.keys();
if (!model_name.empty()) {
model_names.clear();
Expand Down
2 changes: 1 addition & 1 deletion src/artm/core/topic_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ TopicModel::TopicModel(const TopicModel& rhs, float decay,
}
}

for (size_t token_id = 0; token_id < token_size(); token_id++) {
for (size_t token_id = 0; token_id < rhs.token_size(); token_id++) {
AddToken(rhs.token(token_id), false);
auto iter = rhs.GetTopicWeightIterator(token_id);
int topic_index = 0;
Expand Down
13 changes: 11 additions & 2 deletions src/artm/cpp_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,19 @@ void Model::Reconfigure(const ModelConfig& config) {
}

void Model::Overwrite(const TopicModel& topic_model) {
Overwrite(topic_model, true);
}

void Model::Overwrite(const TopicModel& topic_model, bool commit) {
std::string blob;
TopicModel topic_model_copy(topic_model);
topic_model_copy.set_name(name());
topic_model_copy.SerializeToString(&blob);
HandleErrorCode(ArtmOverwriteTopicModel(master_id(), blob.size(), StringAsArray(&blob)));
if (commit) {
HandleErrorCode(ArtmWaitIdle(master_id(), -1));
Synchronize(0.0, 1.0, false);
}
}

void Model::Enable() {
Expand All @@ -244,13 +252,14 @@ void Model::Initialize(const Dictionary& dictionary) {
}

void Model::Synchronize(double decay) {
Synchronize(decay, true);
Synchronize(decay, 1.0, true);
}

void Model::Synchronize(double decay, bool invoke_regularizers) {
void Model::Synchronize(double decay, double apply, bool invoke_regularizers) {
SynchronizeModelArgs args;
args.set_model_name(this->name());
args.set_decay_weight(static_cast<float>(decay));
args.set_apply_weight(static_cast<float>(apply));
args.set_invoke_regularizers(invoke_regularizers);
Synchronize(args);
}
Expand Down
3 changes: 2 additions & 1 deletion src/artm/cpp_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,12 @@ class Model {

void Reconfigure(const ModelConfig& config);
void Overwrite(const TopicModel& topic_model);
void Overwrite(const TopicModel& topic_model, bool commit);
void Initialize(const Dictionary& dictionary);
void Enable();
void Disable();
void Synchronize(double decay);
void Synchronize(double decay, bool invoke_regularizers);
void Synchronize(double decay, double apply, bool invoke_regularizers);
void Synchronize(const SynchronizeModelArgs& args);

int master_id() const { return master_id_; }
Expand Down
2 changes: 0 additions & 2 deletions src/artm_tests/cpp_interface_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,6 @@ void BasicTest(bool is_network_mode, bool is_proxy_mode, bool online_processing)
}

model2.Overwrite(new_topic_model);
master_component->WaitIdle();
model2.Synchronize(0.0, false);

artm::GetTopicModelArgs args;
args.set_model_name(model2.name());
Expand Down
2 changes: 0 additions & 2 deletions src/artm_tests/multiple_classes_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ TEST(MultipleClasses, BasicTest) {
artm::Model model2(master_component, model_config2);
artm::Model model3(master_component, model_config3);
model1.Overwrite(initial_model); model2.Overwrite(initial_model); model3.Overwrite(initial_model);
master_component.WaitIdle();
model1.Synchronize(0.0); model2.Synchronize(0.0); model3.Synchronize(0.0);

for (int iDoc = 0; iDoc < nDocs; iDoc++) {
artm::Item* item = batch.add_item();
Expand Down
7 changes: 6 additions & 1 deletion src/python/artm/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def Initialize(self, dictionary):
HandleErrorCode(self.lib_,
self.lib_.ArtmInitializeModel(self.master_id_, len(blob), blob_p))

def Overwrite(self, topic_model):
def Overwrite(self, topic_model, commit = True):
copy_ = messages_pb2.TopicModel()
copy_.CopyFrom(topic_model)
copy_.name = self.name()
Expand All @@ -482,6 +482,11 @@ def Overwrite(self, topic_model):
HandleErrorCode(self.lib_,
self.lib_.ArtmOverwriteTopicModel(self.master_id_, len(blob), blob_p))

if commit:
timeout = -1
HandleErrorCode(self.lib_, self.lib_.ArtmWaitIdle(self.master_id_, timeout))
self.Synchronize(decay_weight=0.0, apply_weight=1.0, invoke_regularizers=False)

def Enable(self):
config_copy_ = messages_pb2.ModelConfig()
config_copy_.CopyFrom(self.config_)
Expand Down

0 comments on commit 7d6c80c

Please sign in to comment.