
Commit 350d943

Python 3.x support for hierarchical topic models (#833)
* Python 3.x support for hierarchical topic models
* Fix bug in hARTM.load(...) -- models saved with hARTM.save(...) didn't load before
* Re-structure imports in hierarchy_utils.py
AVBelyy authored and JeanPaulShapo committed Aug 11, 2017
1 parent 527d30e commit 350d943
Showing 1 changed file with 12 additions and 11 deletions.
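For context, the round trip restored by the hARTM.load(...) fix looks like this in user code. This is a minimal sketch, not part of the commit: it assumes batches for some collection already exist under a folder named "my_batches", and the topic counts, pass counts, and dump folder name are illustrative placeholders.

    import os
    import artm

    # Assumed setup (not part of the commit): batches prepared beforehand.
    batch_vectorizer = artm.BatchVectorizer(data_path="my_batches", data_format="batches")
    dictionary = artm.Dictionary()
    dictionary.gather(data_path="my_batches")

    hier = artm.hARTM(dictionary=dictionary)

    level0 = hier.add_level(num_topics=10)                         # coarse parent level
    level0.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=10)

    level1 = hier.add_level(num_topics=50, parent_level_weight=1)  # finer child level
    level1.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=10)

    # The round trip exercised by the fix: hARTM.save(...) then hARTM.load(...).
    os.makedirs("hier_dump", exist_ok=True)
    hier.save("hier_dump")

    restored = artm.hARTM(dictionary=dictionary)
    restored.load("hier_dump")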
23 changes: 12 additions & 11 deletions python/artm/hierarchy_utils.py
@@ -4,13 +4,14 @@
 import uuid
 import copy
 import numpy as np
-import pandas
 import os.path
 import os
 import pickle
 import glob
 import warnings
-from pandas import DataFrame
+import pandas
+
+from six.moves import range


 class hARTM(object):
@@ -287,7 +288,7 @@ def del_level(self, level_idx):
         if level_idx == -1:
             del self._levels[-1]
             return
-        for _ in xrange(level_idx, len(self._levels)):
+        for _ in range(level_idx, len(self._levels)):
             del self._levels[-1]

     def get_level(self, level_idx):
@@ -363,7 +364,7 @@ def save(self, path):
             level.save(os.path.join(path, "level" +
                                     str(level_idx) + "_pwt.model"), model_name="p_wt")
         info = {"parent_level_weight": [
-            level.phi_batch_weight for level in self._levels[1:]]}
+            level.parent_level_weight for level in self._levels[1:]]}
         with open(os.path.join(path, "info.dump"), "wb") as fout:
             pickle.dump(info, fout)

@@ -386,14 +387,14 @@ def load(self, path):
         info_filename = glob.glob(os.path.join(path, "info.dump"))
         if len(info_filename) != 1:
             raise ValueError("Given path is not hARTM safe")
-        with open(info_filename[0]) as fin:
+        with open(info_filename[0], "rb") as fin:
             info = pickle.load(fin)
         model_filenames = glob.glob(os.path.join(path, "*.model"))
         if len({len(info["parent_level_weight"]) + 1, len(model_filenames) / 2}) > 1:
             raise ValueError("Given path is not hARTM safe")
         self._levels = []
         sorted_model_filenames = sorted(model_filenames)
-        for level_idx in xrange(len(model_filenames) / 2):
+        for level_idx in range(len(model_filenames) // 2):
             if not len(self._levels):
                 model = artm.ARTM(num_topics=1,
                                   seed=self._get_seed(level_idx),
@@ -407,10 +408,10 @@ def load(self, path):
                                    num_topics=1,
                                    seed=self._get_seed(level_idx),
                                    **self._common_models_args)
-            filename = sorted_model_filenames[2 * level_idx]
-            model.load(filename, "n_wt")
             filename = sorted_model_filenames[2 * level_idx + 1]
             model.load(filename, "p_wt")
+            filename = sorted_model_filenames[2 * level_idx]
+            model.load(filename, "n_wt")
             config = model.master._config
             config.opt_for_avx = False
             model.master._lib.ArtmReconfigureMasterModel(
@@ -650,9 +651,9 @@ def get_theta(self, topic_names=None):
         _, nd_array = self.master.get_theta_matrix(topic_names=use_topic_names)

         titles_list = [item_title for item_title in theta_info.item_title]
-        theta_data_frame = DataFrame(data=nd_array.transpose(),
-                                     columns=titles_list,
-                                     index=use_topic_names)
+        theta_data_frame = pandas.DataFrame(data=nd_array.transpose(),
+                                            columns=titles_list,
+                                            index=use_topic_names)
         item_idxs = np.logical_not(
             theta_data_frame.columns.isin(self._parent_model.topic_names))
         theta_data_frame = theta_data_frame.drop(
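The Python 2 to 3 pitfalls fixed in load() can be reproduced in isolation. The snippet below is a standalone illustration unrelated to the BigARTM API: pickle needs binary file modes in Python 3, and "/" became true division, so it no longer yields an int that range() accepts.

    import pickle

    info = {"parent_level_weight": [1.0, 0.5]}

    # pickle reads and writes bytes, so Python 3 needs binary mode on both ends;
    # opening the dump in text mode makes pickle.load(...) fail.
    with open("info.dump", "wb") as fout:
        pickle.dump(info, fout)
    with open("info.dump", "rb") as fin:
        restored = pickle.load(fin)
    assert restored == info

    # "/" is true division in Python 3 and returns a float, which range() rejects;
    # "//" keeps the integer semantics the level-loading loop relies on.
    num_model_files = 6
    print(num_model_files / 2)    # 3.0
    print(num_model_files // 2)   # 3
    for level_idx in range(num_model_files // 2):
        print("would load level", level_idx)

    # xrange() no longer exists in Python 3; "from six.moves import range" gives one
    # name that behaves as a lazy range on both interpreter lines.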
