Skip to content

Commit

Permalink
ARTM.reshape() in python API
Browse files Browse the repository at this point in the history
  • Loading branch information
ofrei committed Mar 16, 2018
1 parent aab2984 commit e09634d
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion python/artm/artm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,7 @@ def reshape_topics(self, topic_names):
"""
:Description: update topic names of the model.
Adds, removes, and reorders columns of phi matrices
Adds, removes, or reorders columns of phi matrices
according to the new set of topic names.
New topics are initialized with zeros.
"""
Expand All @@ -1057,6 +1057,46 @@ def reshape_topics(self, topic_names):
self.master.reconfigure_topic_name(topic_names=topic_names)
self._topic_names = topic_names

def reshape_tokens(self, dictionary):
    """
    :Description: update the set of tokens of the model.
      Tokens are added, removed, or reordered so that they
      follow the given dictionary.
      The n_wt matrix is changed by this operation, while the p_wt matrix
      is not affected immediately; call ARTM.fit_offline() afterwards
      to re-calculate p_wt for the new set of tokens.

    :param dictionary: dictionary that defines the new set of tokens
    :type dictionary: str or reference to Dictionary object
    :raises IOError: if dictionary is None (or otherwise falsy)
    """
    if not dictionary:
        raise IOError('Dictionary must not be None')

    # A plain string is taken as the dictionary name itself;
    # any other object is expected to expose it via its ``name`` attribute.
    if isinstance(dictionary, str):
        target_name = dictionary
    else:
        target_name = dictionary.name

    self.master.initialize_model(
        model_name=self.model_nwt,
        dictionary_name=target_name,
    )

def reshape(self, topic_names=None, dictionary=None):
    """
    :Description: change the shape of the model,
      i.e. add/remove topics, or add/remove tokens.

    :param topic_names: names of topics in model
    :type topic_names: list of str
    :param dictionary: dictionary that defines the new set of tokens
    :type dictionary: str or reference to Dictionary object
    :raises IOError: if both topic_names and dictionary are specified

    Only one of the arguments (topic_names or dictionary) can be specified
    at a time. When neither is given the call does nothing.
    For further description see methods
    ARTM.reshape_topics() and ARTM.reshape_tokens().
    """
    if topic_names and dictionary:
        raise IOError('Only one of the arguments should be specified (topic_names or dictionary)')

    # Dispatch to the specialised method for whichever argument was supplied.
    if topic_names:
        self.reshape_topics(topic_names)
    elif dictionary:
        self.reshape_tokens(dictionary)

def __repr__(self):
num_tokens = next((x.num_tokens for x in self.info.model if x.name == self._model_pwt), None)
class_ids = ', class_ids={0}'.format(list(self.class_ids.keys())) if self.class_ids else ''
Expand Down

0 comments on commit e09634d

Please sign in to comment.