Skip to content

Commit

Permalink
fix in topic selection regularizer (#821)
Browse files Browse the repository at this point in the history
  • Loading branch information
MelLain committed Jul 24, 2017
1 parent 436b29d commit 3f9a639
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 19 deletions.
1 change: 0 additions & 1 deletion python/artm/artm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def _topic_selection_regularizer_func(self, regularizers):
if no_score:
self._internal_topic_mass_score_name = 'ITMScore_{}'.format(str(uuid.uuid4()))
self.scores.add(TopicMassPhiScore(name=self._internal_topic_mass_score_name,
class_id='@default_class',
model_name=self.model_nwt)) # ugly hack!

if not self._synchronizations_processed or no_score:
Expand Down
32 changes: 29 additions & 3 deletions python/artm/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,11 @@ class TopicMassPhiScore(BaseScore):
_config_message = messages.TopicMassPhiScoreConfig
_type = const.ScoreType_TopicMassPhi

def __init__(self, name=None, class_id=None, topic_names=None, model_name=None, eps=None, config=None):
def __init__(self, name=None, class_ids=None, topic_names=None, model_name=None, eps=None, config=None):
"""
:param str name: the identifier of score, will be auto-generated if not specified
:param str class_id: class_id to score
:param class_ids: list of class_ids to score, None or empty means that tokens of all class_ids will be used
:type class_ids: list of str
:param topic_names: list of names or single name of topic to regularize, will\
score all topics if empty or None
:type topic_names: list of str or str or None
Expand All @@ -668,7 +669,7 @@ def __init__(self, name=None, class_id=None, topic_names=None, model_name=None,
"""
BaseScore.__init__(self,
name=name,
class_id=class_id,
class_id=None,
topic_names=topic_names,
model_name=model_name,
config=config)
Expand All @@ -680,14 +681,39 @@ def __init__(self, name=None, class_id=None, topic_names=None, model_name=None,
elif config is not None and config.HasField('eps'):
self._eps = config.eps

self._class_ids = []
if class_ids is not None:
self._config.ClearField('class_id')
for class_id in class_ids:
self._config.class_id.append(class_id)
self._class_ids.append(class_id)
elif config is not None and len(config.class_id):
self._class_ids = [class_id for class_id in config.class_id]

@property
def eps(self):
return self._eps

@property
def class_ids(self):
return self._class_ids

@property
def class_id(self):
raise KeyError('No class_id parameter')

@eps.setter
def eps(self, eps):
_reconfigure_field(self, eps, 'eps')

@class_ids.setter
def class_ids(self, class_ids):
_reconfigure_field(self, class_ids, 'class_ids', 'class_id')

@class_id.setter
def class_id(self, class_id):
raise KeyError('No class_id parameter')


class ClassPrecisionScore(BaseScore):
_config_message = messages.ClassPrecisionScoreConfig
Expand Down
2 changes: 1 addition & 1 deletion src/artm/messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ message TopicKernelScore {
// Represents a configuration of a topic part in nwt score
message TopicMassPhiScoreConfig {
optional float eps = 1 [default = 1e-37];
optional string class_id = 2 [default = "@default_class"];
repeated string class_id = 2;
repeated string topic_name = 3;
}

Expand Down
28 changes: 15 additions & 13 deletions src/artm/score/topic_mass_phi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,28 @@ std::shared_ptr<Score> TopicMassPhi::CalculateScore(const artm::core::PhiMatrix&
topics_to_score_size = config_.topic_name_size();
}

::artm::core::ClassId class_id = ::artm::core::DefaultClass;
if (config_.has_class_id())
class_id = config_.class_id();
bool use_all_classes = false;
if (config_.class_id_size() == 0) {
use_all_classes = true;
}

std::vector<double> topic_mass;
topic_mass.assign(topics_to_score_size, 0.0);
double denominator = 0.0;
double numerator = 0.0;

for (int token_index = 0; token_index < token_size; token_index++) {
if (p_wt.token(token_index).class_id == class_id) {
int real_topic_index = 0;
for (int topic_index = 0; topic_index < topic_size; ++topic_index) {
double value = p_wt.get(token_index, topic_index);
denominator += value;

if (topics_to_score[topic_index]) {
numerator += value;
topic_mass[real_topic_index++] += value;
}
if (!use_all_classes && !core::is_member(p_wt.token(token_index).class_id, config_.class_id()))
continue;

int real_topic_index = 0;
for (int topic_index = 0; topic_index < topic_size; ++topic_index) {
double value = p_wt.get(token_index, topic_index);
denominator += value;

if (topics_to_score[topic_index]) {
numerator += value;
topic_mass[real_topic_index++] += value;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/artm/score/topic_mass_phi.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Parameters:
- topic_name (names of topics from which top tokens need to be extracted)
- class_id (class_id to use, empty == DefaultClass)
- class_id (class_id to use, empty == all modalities)
- eps
*/
Expand Down

0 comments on commit 3f9a639

Please sign in to comment.