Skip to content

Commit

Permalink
Fix bigartm#758 ARTM.get_phi_dense() to extract phi without pandas.DataFrame
Browse files Browse the repository at this point in the history
  • Loading branch information
ofrei committed Feb 13, 2017
1 parent db68c4a commit babee64
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
3 changes: 3 additions & 0 deletions docs/release_notes/python.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ v0.8.3

* Enable copying of ARTM, LDA, hARTM and ARTM_Level objects with ``clone()`` method and ``copy.deepcopy(obj)``.
* Experimental support for import/export/editing of theta matrices; for more details see python reference of ``ARTM.__init__(ptd_name='ptd')``.
* New ``ARTM.get_phi_dense()`` method to extract the phi matrix as a ``numpy.ndarray`` without building a ``pandas.DataFrame``; see `#758 <https://github.com/bigartm/bigartm/issues/758>`_.
* Bug fix in ``ARTM.get_phi_sparse()``: it now returns tokens as rows and topic names as columns (previously the two were swapped).


v0.8.2
------
Expand Down
38 changes: 30 additions & 8 deletions python/artm/artm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,10 +644,9 @@ def load(self, filename, model_name='p_wt'):
self._num_online_processed_batches = 0
self._phi_cached = None

def get_phi(self, topic_names=None, class_ids=None, model_name=None):
def get_phi_dense(self, topic_names=None, class_ids=None, model_name=None):
"""
:Description: get custom Phi matrix of model. The extraction of the\
whole Phi matrix expects ARTM.phi_ call.
:Description: get phi matrix in dense format
:param topic_names: list with topics or single topic to extract, None value means all topics
:type topic_names: list of str or str or None
Expand All @@ -657,10 +656,10 @@ def get_phi(self, topic_names=None, class_ids=None, model_name=None):
reasonable to extract unnormalized counters
:return:
* pandas.DataFrame: (data, columns, rows), where:
* columns --- the names of topics in topic model;
* a 3-tuple of (data, rows, columns), where
* data --- numpy.ndarray with Phi data (i.e., p(w|t) values)
* rows --- the tokens of topic model;
* data --- content of Phi matrix.
* columns --- the names of topics in topic model;
"""
if not self._initialized:
raise RuntimeError('Model does not exist yet. Use ARTM.initialize()/ARTM.fit_*()')
Expand All @@ -682,6 +681,29 @@ def get_phi(self, topic_names=None, class_ids=None, model_name=None):
if class_ids is None or class_id in class_ids]
topic_names = [topic_name for topic_name in topics_and_tokens_info.topic_name
if topic_names is None or topic_name in topic_names]
return nd_array, tokens, topic_names

def get_phi(self, topic_names=None, class_ids=None, model_name=None):
"""
:Description: get custom Phi matrix of model. The extraction of the\
whole Phi matrix expects ARTM.phi_ call.
:param topic_names: list with topics or single topic to extract, None value means all topics
:type topic_names: list of str or str or None
:param class_ids: list with class_ids or single class_id to extract, None means all class ids
:type class_ids: list of str or str or None
:param str model_name: self.model_pwt by default, self.model_nwt is also\
reasonable to extract unnormalized counters
:return:
* pandas.DataFrame: (data, columns, rows), where:
* columns --- the names of topics in topic model;
* rows --- the tokens of topic model;
* data --- content of Phi matrix.
"""
(nd_array, tokens, topic_names) = self.get_phi_dense(topic_names=topic_names,
class_ids=class_ids,
model_name=model_name)
phi_data_frame = DataFrame(data=nd_array,
columns=topic_names,
index=tokens)
Expand Down Expand Up @@ -744,8 +766,8 @@ def get_phi_sparse(self, topic_names=None, class_ids=None, model_name=None, eps=
# Columns correspond to topics; get topic names from topic_model.topic_name
data = sparse.csr_matrix((data, (row_ind, col_ind)),
shape=(len(topic_model.token), len(topic_model.topic_name)))
rows = list(topic_model.topic_name)
columns = list(topic_model.token)
columns = list(topic_model.topic_name)
rows = list(topic_model.token)
return data, rows, columns

def get_theta(self, topic_names=None):
Expand Down

0 comments on commit babee64

Please sign in to comment.