Skip to content

Commit

Permalink
Model Saving From Python API (#816)
Browse files Browse the repository at this point in the history
Model Saving From Python API
+ add score_tracker import/export in cpp code
+ fix config issues in scores anr regularizers in python api
  • Loading branch information
MelLain committed Jul 22, 2017
1 parent 82ac025 commit 00c0a43
Show file tree
Hide file tree
Showing 19 changed files with 891 additions and 96 deletions.
16 changes: 6 additions & 10 deletions docs/tutorials/python_userguide/different.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,23 @@ This call has one feature, it rewrites the old dictionary with new one. So if yo

* **Saving/loading model**:

Now let's study saving the model to disk.

It's important to understand that the model contains two matrices: :math:`\Phi` (or :math:`p_{wt}`) and :math:`n_{wt}`. To make model be loadable without loses you need to save both these matrices. The current library version can save only one matrix per method call, so you will need two calls:
Now let's study saving the model to disk. To save your model on disk from Python you can use ``artm.ARTM.dump_artm_model`` method:

.. code-block:: python

model.save(filename='saved_p_wt', model_name='p_wt')
model.save(filename='saved_n_wt', model_name='n_wt')
model.dump_artm_model('my_model_folder')

The model will be saved in binary format. To use it later you need to load it's matrices back:
The model will be saved in binary format, its parameters will be duplicated also in json file. To use it later you need to load it back via ``artm.load_artm_model`` function:

.. code-block:: python

model.load(filename='saved_p_wt', model_name='p_wt')
model.load(filename='saved_n_wt', model_name='n_wt')
model = artm.load_artm_model('my_model_folder')

.. note::

The model after loading will only contain :math:`\Phi` and :math:`n_{wt}` matrices and some associated information (like number of topics, their names, the names of the modalities (without weights!) and some other data). So you need to restore all necessary scores, regularizers, modality weights and all important parameters, like ``cache_theta``.
To use these methods correctly you should either set ``cache_theta`` flag to False (and don't use Theta matrix) or set ``theta_name`` parameter (that will store Theta as Phi-like object).

You can use ``save/load`` methods pair in case of long fitting, when restoring parameters is much more easier than model re-fitting.
You can use pair of ``dump_artm_model/load_artm_model`` functions in case of long fitting, when restoring parameters is much more easier than model re-fitting.

* **Creating batches manually**:

Expand Down
2 changes: 1 addition & 1 deletion python/artm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright 2017, Additive Regularization of Topic Models.

from .artm_model import ARTM, version
from .artm_model import ARTM, version, load_artm_model
from .lda_model import LDA
from .hierarchy_utils import hARTM
from .dictionary import *
Expand Down
201 changes: 196 additions & 5 deletions python/artm/artm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import shutil
import tempfile
import numpy
import datetime
import json
import pickle

from pandas import DataFrame
from six import iteritems, string_types
Expand All @@ -21,7 +24,9 @@
from . import master_component as mc

from .regularizers import Regularizers
from .scores import Scores, TopicMassPhiScore # temp
from .regularizers import *
from .scores import Scores
from .scores import *
from . import score_tracker

SCORE_TRACKER = {
Expand All @@ -37,6 +42,13 @@
const.ScoreType_BackgroundTokensRatio: score_tracker.BackgroundTokensRatioScoreTracker,
}

SCORE_TRACKER_FILENAME = 'score_tracker.bin'
PWT_FILENAME = 'p_wt.bin'
NWT_FILENAME = 'n_wt.bin'
PTD_FILENAME = 'p_td.bin'
PARAMETERS_FILENAME_JSON = 'parameters.json'
PARAMETERS_FILENAME_BIN = 'parameters.bin'


def _run_from_ipython():
try:
Expand Down Expand Up @@ -146,7 +158,8 @@ def __init__(self, num_topics=None, topic_names=None, num_processors=None, class
Later you can retrieve this matix with ARTM.get_phi(model_name=ARTM.theta_name),\
change its values with ARTM.master.attach_model(model=ARTM.theta_name),\
export/import this matrix with ARTM.master.export_model('ptd', filename) and\
ARTM.master.import_model('ptd', file_name).
ARTM.master.import_model('ptd', file_name). In this case you are also able to work\
with theta matrix when using 'dump_artm_model' method and 'load_artm_model' function.
"""
self._num_processors = None
self._cache_theta = False
Expand All @@ -164,9 +177,8 @@ def __init__(self, num_topics=None, topic_names=None, num_processors=None, class
else:
raise ValueError('Either num_topics or topic_names parameter should be set')

if class_ids is None:
self._class_ids = {}
elif len(class_ids) > 0:
self._class_ids = {}
if class_ids is not None and isinstance(class_ids, dict) and len(class_ids) > 0:
self._class_ids = class_ids

if isinstance(num_processors, int) and num_processors > 0:
Expand Down Expand Up @@ -1011,6 +1023,185 @@ def __repr__(self):
return 'artm.ARTM(num_topics={0}, num_tokens={1}{2})'.format(
self.num_topics, num_tokens, class_ids)

def dump_artm_model(self, data_path):
"""
:Description: dump all necessary model files into given folder.
:param str data_path: full path to folder (should unexist)
"""
if os.path.exists(data_path):
raise IOError('Folder {} already exists'.format(data_path))

os.mkdir(data_path)
# save core score tracker
self._master.export_score_tracker(os.path.join(data_path, SCORE_TRACKER_FILENAME))
# save phi and n_wt matrices
self._master.export_model(self.model_pwt, os.path.join(data_path, PWT_FILENAME))
self._master.export_model(self.model_nwt, os.path.join(data_path, NWT_FILENAME))
# save theta if has theta_name
if self.theta_name is not None:
self._master.export_model(self.theta_name, os.path.join(data_path, PTD_FILENAME))

# save parameters in human-readable format
params = {}
params['version'] = self.library_version
params['creation_time'] = str(datetime.datetime.now())
params['num_processors'] = self._num_processors
params['cache_theta'] = self._cache_theta
params['num_document_passes'] = self._num_document_passes
params['reuse_theta'] = self._reuse_theta
params['theta_columns_naming'] = self._theta_columns_naming
params['seed'] = self._seed
params['show_progress_bars'] = self._show_progress_bars
params['topic_names'] = self._topic_names
params['class_ids'] = self._class_ids
params['model_pwt'] = self._model_pwt
params['model_nwt'] = self._model_nwt
params['theta_name'] = self._theta_name
params['synchronizations_processed'] = self._synchronizations_processed
params['num_online_processed_batches'] = self._num_online_processed_batches
params['initialized'] = self._initialized

regularizers = {}
for name, regularizer in iteritems(self._regularizers.data):
tau = None
gamma = None
try:
tau = regularizer.tau
gamma = regularizer.gamma
except KeyError:
pass
regularizers[name] = [str(regularizer.config), tau, gamma]
params['regularizers'] = regularizers

scores = {}
for name, score in iteritems(self._scores.data):
model_name = None
try:
model_name = score.model_name
except KeyError:
pass
scores[name] = [str(score.config), model_name]

params['scores'] = scores

with open(os.path.join(data_path, PARAMETERS_FILENAME_JSON), 'w') as fout:
json.dump(params, fout)

# save parameters in binary format
regularizers = {}
for name, regularizer in iteritems(self._regularizers._data):
regularizers[name] = [regularizer._config_message.__name__,
regularizer.config.SerializeToString()]

tau = None
gamma = None
try:
tau = regularizer.tau
gamma = regularizer.gamma
except KeyError:
pass

if tau is not None:
regularizers[name].append(tau)
if gamma is not None:
regularizers[name].append(gamma)

params['regularizers'] = regularizers

scores = {}
for name, score in iteritems(self._scores._data):
scores[name] = [score._config_message.__name__,
score.config.SerializeToString()]

model_name = None
try:
model_name = score.model_name
except KeyError:
pass
if model_name is not None:
scores[name].append(model_name)

params['scores'] = scores

with open(os.path.join(data_path, PARAMETERS_FILENAME_BIN), 'wb') as fout:
pickle.dump(params, fout)


def version():
return ARTM(num_topics=1).library_version


def load_artm_model(data_path):
"""
:Description: load all necessary files for model creation from given folder.
:param str data_path: full path to folder (should exist)
:return: artm.ARTM object, created using given dumped data
"""
# load parameters
with open(os.path.join(data_path, PARAMETERS_FILENAME_BIN), 'rb') as fin:
params = pickle.load(fin)

if params['version'] > version():
raise RuntimeError('File was generated with newer version of library ({}). '.format(params['version']) +
'Current library version is {}'.format(version()))

model = ARTM(topic_names=params['topic_names'],
num_processors=params['num_processors'],
class_ids=params['class_ids'],
num_document_passes=params['num_document_passes'],
reuse_theta=params['reuse_theta'],
cache_theta=params['cache_theta'],
theta_columns_naming=params['theta_columns_naming'],
seed=params['seed'],
show_progress_bars=params['show_progress_bars'],
theta_name=params['theta_name'])

model._model_pwt = params['model_pwt']
model._model_nwt = params['model_nwt']
model._synchronizations_processed = params['synchronizations_processed']
model._num_online_processed_batches = params['num_online_processed_batches']
model._initialized = params['initialized']

for name, type_config in iteritems(params['regularizers']):
config = None
func = None
for reg_info in mc.REGULARIZERS:
if reg_info[0].__name__ == type_config[0]:
config = reg_info[0]()
func = reg_info[2]
config.ParseFromString(type_config[1])

if len(type_config) == 3:
model.regularizers.add(func(name=name, config=config, tau=type_config[2]))
elif len(type_config) == 4:
model.regularizers.add(func(name=name, config=config, tau=type_config[2], gamma=type_config[3]))
else:
model.regularizers.add(func(name=name, config=config))

# load scores and configure python score_tracker
for name, type_config in iteritems(params['scores']):
config = None
func = None
for score_info in mc.SCORES:
if score_info[1].__name__ == type_config[0]:
config = score_info[1]()
func = score_info[3]
config.ParseFromString(type_config[1])
if len(type_config) == 3:
model.scores.add(func(name=name, config=config, model_name=type_config[2]))
else:
model.scores.add(func(name=name, config=config))
model.score_tracker[name] = SCORE_TRACKER[model.scores[name].type](model.scores[name])

# load core score tracker
model._master.import_score_tracker(os.path.join(data_path, SCORE_TRACKER_FILENAME))
# load phi and n_wt matrices
model._master.import_model(model.model_pwt, os.path.join(data_path, PWT_FILENAME))
model._master.import_model(model.model_nwt, os.path.join(data_path, NWT_FILENAME))
# load theta if has theta_name
if model.theta_name is not None:
model._master.import_model(model.theta_name, os.path.join(data_path, PTD_FILENAME))

return model

0 comments on commit 00c0a43

Please sign in to comment.