Skip to content

Commit

Permalink
Fix bug in hartm (seed must be integer), and add a test
Browse files (browse the repository at this point in the history)
  • Loading branch information
ofrei committed Mar 1, 2018
1 parent db576e7 commit fbb1470
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/artm/hierarchy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def tmp_files_path(self, tmp_files_path):
# ========== METHODS ==========
def _get_seed(self, level_idx):
np.random.seed(self._seed)
- return np.random.randint(10000, size=level_idx + 1)[-1]
+ return int(np.random.randint(10000, size=level_idx + 1)[-1])

def add_level(self, num_topics=None, topic_names=None, parent_level_weight=1):
"""
Expand Down
16 changes: 16 additions & 0 deletions python/tests/artm/test_hartm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ def test_func():

assert(level1.clone() is not None)
assert(hier.clone() is not None)

+ # test the same functionality with hARTM, and validate that resulting psi matrix is exactly the same
+ level1_plain = artm.ARTM(num_topics=num_topics_level0, num_document_passes=num_document_passes,
+ theta_columns_naming='title', seed=level0.seed)
+ level1_plain.initialize(dictionary=dictionary)
+ level1_plain.fit_offline(num_collection_passes=num_collection_passes, batch_vectorizer=batch_vectorizer)
+ level2_plain = artm.ARTM(num_topics=num_topics_level1, parent_model=level1_plain,
+ parent_model_weight=parent_level_weight, theta_columns_naming='title',
+ seed=level1.seed)
+ level2_plain.initialize(dictionary=dictionary)
+ level2_plain.regularizers.add(artm.HierarchySparsingThetaRegularizer(name='HierSp', tau=regularizer_tau))
+ level2_plain.fit_offline(num_collection_passes=num_collection_passes, batch_vectorizer=batch_vectorizer)
+ psi_plain = level2_plain.get_parent_psi()
+ max_diff = (psi_plain - psi).abs().max().max()
+ assert(max_diff < 1e-3)

finally:
shutil.rmtree(batches_folder)
shutil.rmtree(parent_batch_folder)
1 change: 1 addition & 0 deletions src/artm/core/check_messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ inline std::string DescribeMessage(const ::artm::InitializeModelArgs& message) {
ss << ", dictionary_name=" << message.dictionary_name();
}
ss << ", topic_name_size=" << message.topic_name_size();
+ ss << ", seed=" << message.seed();
return ss.str();
}

Expand Down

0 comments on commit fbb1470

Please sign in to comment.