Skip to content

Commit

Permalink
fix loading new models (#897)
Browse files Browse the repository at this point in the history
  • Loading branch information
MelLain committed Mar 28, 2018
1 parent a818ff4 commit a0c3129
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions python/artm/artm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,8 +709,12 @@ def load(self, filename, model_name='p_wt'):
self._topic_names = [topic_name for topic_name in topics_and_tokens_info.topic_name]

transaction_types = {}
for transaction_type in topics_and_tokens_info.transaction_type:
transaction_types[transaction_type] = 1.0
if hasattr(topics_and_tokens_info, 'transaction_type'):
for transaction_type in topics_and_tokens_info.transaction_type:
transaction_types[transaction_type] = 1.0
else:
for class_id in topics_and_tokens_info.class_id:
transaction_types[class_id] = 1.0
self._transaction_type = transaction_type

# Remove all info about previous iterations
Expand Down Expand Up @@ -1297,9 +1301,10 @@ def load_artm_model(data_path):
raise RuntimeError('File was generated with newer version of library ({}). '.format(params['version']) +
'Current library version is {}'.format(version()))

tt = params['transaction_types'] if 'transaction_types' in params else params['class_ids']
model = ARTM(topic_names=params['topic_names'],
num_processors=params['num_processors'],
transaction_types=params['transaction_types'],
transaction_types=tt,
num_document_passes=params['num_document_passes'],
reuse_theta=params['reuse_theta'],
cache_theta=params['cache_theta'],
Expand Down

0 comments on commit a0c3129

Please sign in to comment.