Skip to content

Commit

Permalink
Merge pull request #106 from sashafrey/master
Browse files Browse the repository at this point in the history
Calculate scores in GetThetaMatrix
  • Loading branch information
bigartm committed Jan 18, 2015
2 parents 239a854 + caefa58 commit d4bf649
Show file tree
Hide file tree
Showing 32 changed files with 1,008 additions and 844 deletions.
8 changes: 5 additions & 3 deletions docs/ref/c_interface.txt
Original file line number Diff line number Diff line change
Expand Up @@ -522,16 +522,18 @@ ArtmRequestRegularizerState
ArtmRequestScore
----------------

.. c:function:: int ArtmRequestScore(int master_id, const char* model_name, const char* score_name)
.. c:function:: int ArtmRequestScore(int master_id, int length, const char* get_score_args)

Request the result of score calculation.

:param int master_id: The ID of a master component or a master proxy,
returned by either :c:func:`ArtmCreateMasterComponent` or :c:func:`ArtmCreateMasterProxy` method.

:param const_char* model_name: A string identified of the model.
:param const_char*: get_score_args:
Serialized :ref:`GetScoreValueArgs` message,
describing the arguments of this operation.

:param const_char* score_name: A string identified of the score.
:param int length: The length in bytes of the *get_score_args* message.

:return: In case of success, returns the length in bytes of a buffer that should be allocated on callers site
and then passed to :c:func:`ArtmCopyRequestResult` method.
Expand Down
30 changes: 30 additions & 0 deletions docs/ref/messages.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1812,3 +1812,33 @@ Represents an argument of get theta matrix operation.
Note that this field acts similar to :attr:`GetThetaMatrixArgs.topic_name`.
It is not allowed to specify both *topic_index* and *topic_name* at the same time.
The recommendation is to use *topic_name*.


.. _GetScoreValueArgs:

GetScoreValueArgs
=================

Represents an argument of get score operation.

.. code-block:: bash

message GetScoreValueArgs {
optional string model_name = 1;
optional string score_name = 2;
optional Batch batch = 3;
}

.. attribute:: GetScore ValueArgs.model_name

The name of the model to retrieved score for.

.. attribute:: GetScoreValueArgs.score_name

The name of the score to retrieved.

.. attribute:: GetScoreValueArgs.batch

The :ref:`Batch` to calculate the score.
This option is only applicable to cumulative scores.
When not provided the score will be reported for all batches processed since last :c:func:`ArtmInvokeIteration`.
3 changes: 2 additions & 1 deletion docs/ref/python_interface.txt
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,10 @@ Score

Returns the string name of the score.

.. py:method:: GetValue(model)
.. py:method:: GetValue(model = None, batch = None)

Retrieves the score for a specific model.
For cumulative scores such as Perplexity of ThetaSparsity score it is possible to use *batch* argument.

Dictionary
==========
Expand Down
6 changes: 4 additions & 2 deletions src/artm/c_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,12 @@ int ArtmRequestRegularizerState(int master_id, const char* regularizer_name) {
} CATCH_EXCEPTIONS;
}

int ArtmRequestScore(int master_id, const char* model_name, const char* score_name) {
int ArtmRequestScore(int master_id, int length, const char* get_score_args) {
try {
::artm::ScoreData score_data;
master_component(master_id)->RequestScore(model_name, score_name, &score_data);
artm::GetScoreValueArgs args;
ParseFromArray(get_score_args, length, &args);
master_component(master_id)->RequestScore(args, &score_data);
score_data.SerializeToString(last_message());
return last_message()->size();
} CATCH_EXCEPTIONS;
Expand Down
2 changes: 1 addition & 1 deletion src/artm/c_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ extern "C" {
DLL_PUBLIC int ArtmRequestThetaMatrix(int master_id, int length, const char* get_theta_args);
DLL_PUBLIC int ArtmRequestTopicModel(int master_id, int length, const char* get_model_args);
DLL_PUBLIC int ArtmRequestRegularizerState(int master_id, const char* regularizer_name);
DLL_PUBLIC int ArtmRequestScore(int master_id, const char* model_name, const char* score_name);
DLL_PUBLIC int ArtmRequestScore(int master_id, int length, const char* get_score_args);
DLL_PUBLIC int ArtmRequestParseCollection(int length, const char* collection_parser_config);
DLL_PUBLIC int ArtmRequestLoadDictionary(const char* filename);
DLL_PUBLIC int ArtmRequestLoadBatch(const char* filename);
Expand Down
2 changes: 1 addition & 1 deletion src/artm/core/instance_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void InstanceSchema::clear_score_calculator(const ScoreName& name) {
}

std::shared_ptr<ScoreCalculatorInterface> InstanceSchema::score_calculator(
const ScoreName& name) {
const ScoreName& name) const {
auto iter = score_calculators_.find(name);
if (iter != score_calculators_.end()) {
return iter->second;
Expand Down
2 changes: 1 addition & 1 deletion src/artm/core/instance_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class InstanceSchema {
bool has_regularizer(const std::string& name) const;
void clear_regularizer(const std::string& name);

std::shared_ptr<ScoreCalculatorInterface> score_calculator(const ScoreName& name);
std::shared_ptr<ScoreCalculatorInterface> score_calculator(const ScoreName& name) const;
void set_score_calculator(const ScoreName& name,
const std::shared_ptr<ScoreCalculatorInterface>& score_calculator);
bool has_score_calculator(const ScoreName& name) const;
Expand Down

0 comments on commit d4bf649

Please sign in to comment.