Skip to content

Commit

Permalink
Mel lain transaction artm (#856)
Browse files Browse the repository at this point in the history
Add trans-artm models support (beta, inefficient)
  • Loading branch information
MelLain committed Feb 25, 2018
1 parent 4c591cf commit 839b927
Show file tree
Hide file tree
Showing 69 changed files with 2,587 additions and 618 deletions.
109 changes: 76 additions & 33 deletions python/artm/artm_model.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion python/artm/batches_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ def __reset_batch():
batch_vocab[token] = len(batch.token)
batch.token.append(token)

item.token_id.append(batch_vocab[token])
item.transaction_token_id.append(batch_vocab[token])
item.transaction_start_index.append(len(item.transaction_start_index))
item.token_weight.append(float(value))

if ((item_id + 1) % self._batch_size == 0 and item_id != 0) or ((item_id + 1) == n_wd.shape[1]):
Expand Down
105 changes: 78 additions & 27 deletions python/artm/master_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ def _score_data_func(score_data_type):
return mfunc


def _prepare_config(topic_names=None, class_ids=None, scores=None, regularizers=None, num_processors=None,
pwt_name=None, nwt_name=None, num_document_passes=None, reuse_theta=None, cache_theta=None,
args=None):
def _prepare_config(topic_names=None, class_ids=None, transaction_types=None,
scores=None, regularizers=None, num_processors=None,
pwt_name=None, nwt_name=None, num_document_passes=None,
reuse_theta=None, cache_theta=None, args=None):
master_config = messages.MasterModelConfig()

if args is not None:
Expand All @@ -158,12 +159,18 @@ def _prepare_config(topic_names=None, class_ids=None, scores=None, regularizers=
for topic_name in topic_names:
master_config.topic_name.append(topic_name)

if class_ids is not None:
master_config.ClearField('class_id')
master_config.ClearField('class_weight')
for class_id, class_weight in iteritems(class_ids):
master_config.class_id.append(class_id)
master_config.class_weight.append(class_weight)
if transaction_types is not None:
master_config.ClearField('transaction_type')
master_config.ClearField('transaction_weight')
for transaction_type, transaction_weight in iteritems(transaction_types):
master_config.transaction_type.append(transaction_type)
master_config.transaction_weight.append(transaction_weight)
elif class_ids is not None:
master_config.ClearField('transaction_type')
master_config.ClearField('transaction_weight')
for transaction_type, transaction_weight in iteritems(transaction_types):
master_config.transaction_type.append(transaction_type)
master_config.transaction_weight.append(transaction_weight)

if scores is not None:
master_config.ClearField('score_config')
Expand Down Expand Up @@ -206,15 +213,19 @@ def _prepare_config(topic_names=None, class_ids=None, scores=None, regularizers=


class MasterComponent(object):
def __init__(self, library=None, topic_names=None, class_ids=None, scores=None, regularizers=None,
num_processors=None, pwt_name=None, nwt_name=None, num_document_passes=None,
reuse_theta=None, cache_theta=False, config=None, master_id=None):
def __init__(self, library=None, topic_names=None, class_ids=None, transaction_types=None,
scores=None, regularizers=None, num_processors=None, pwt_name=None,
nwt_name=None, num_document_passes=None, reuse_theta=None,
cache_theta=False, config=None, master_id=None):
"""
:param library: an instance of LibArtm
:param topic_names: list of topic names to use in model
:type topic_names: list of str
:param dict class_ids: key - class_id, value - class_weight
:param dict class_ids: key - class_id, value - class_weight,\
use either class_ids or transaction_types
:param dict transaction_types: key - transaction_type, value - transaction_weight,\
use either class_ids or transaction_types
:param dict scores: key - score name, value - config
:param dict regularizers: key - regularizer name, value - tuple (config, tau)\
or triple (config, tau, gamma)
Expand All @@ -229,6 +240,7 @@ def __init__(self, library=None, topic_names=None, class_ids=None, scores=None,

master_config = _prepare_config(topic_names=topic_names,
class_ids=class_ids,
transaction_types=transaction_types,
scores=scores,
regularizers=regularizers,
num_processors=num_processors,
Expand All @@ -247,11 +259,12 @@ def __deepcopy__(self, memo):
self.master_id, messages.DuplicateMasterComponentArgs())
return MasterComponent(self._lib, config=self._config, master_id=new_master_id)

def reconfigure(self, topic_names=None, class_ids=None, scores=None, regularizers=None,
num_processors=None, pwt_name=None, nwt_name=None, num_document_passes=None,
reuse_theta=None, cache_theta=None):
def reconfigure(self, topic_names=None, class_ids=None, transaction_types=None,
scores=None, regularizers=None, num_processors=None, pwt_name=None,
nwt_name=None, num_document_passes=None, reuse_theta=None, cache_theta=None):
master_config = _prepare_config(topic_names=topic_names,
class_ids=class_ids,
transaction_types=transaction_types,
scores=scores,
regularizers=regularizers,
num_processors=num_processors,
Expand Down Expand Up @@ -436,8 +449,9 @@ def clear_score_array_cache(self):
def process_batches(self, pwt, nwt=None, num_document_passes=None, batches_folder=None,
batches=None, regularizer_name=None, regularizer_tau=None,
class_ids=None, class_weights=None, find_theta=False,
transaction_types=None, transaction_weights=None,
reuse_theta=False, find_ptdw=False,
predict_class_id=None):
predict_class_id=None, predict_transaction_type=None):
"""
:param str pwt: name of pwt matrix in BigARTM
:param str nwt: name of nwt matrix in BigARTM
Expand All @@ -449,16 +463,29 @@ def process_batches(self, pwt, nwt=None, num_document_passes=None, batches_folde
:type regularizer_name: list of str
:param regularizer_tau: list of tau coefficients for Theta regularizers
:type regularizer_tau: list of float
:param class_ids: list of class ids to use during processing
:param class_ids: list of class ids to use during processing.\
Use either transaction_types or class_ids parameter-weight pairs
:type class_ids: list of str
:param class_weights: list of corresponding weights of class ids
:param class_weights: list of corresponding weights of class ids.\
Use either transaction_types or class_ids parameter-weight pairs
:type class_weights: list of float
:param transaction_types: list of transaction types to use during processing.\
Use either transaction_types or class_ids parameter-weight pairs
:type transaction_types: list of str
:param transaction_weights: list of corresponding weights of transaction types.\
Use either transaction_types or class_ids parameter-weight pairs
:type transaction_weights: list of float
:param bool find_theta: find theta matrix for 'batches' (if alternative 2)
:param bool reuse_theta: initialize by theta from previous collection pass
:param bool find_ptdw: calculate and return Ptdw matrix or not\
(works if find_theta == False)
:param predict_class_id: class_id of a target modality to predict
:type predict_class_id: str, default None
:param predict_transaction_type: transaction type to predict (if predict_class_id is None,\
all class_ids in the transaction will be predicted, which is\
invalid behavior, so predict_transaction_type should\
always be used together with predict_class_id)
:type predict_transaction_type: str, default None
:return:
* tuple (messages.ThetaMatrix, numpy.ndarray) --- the info about Theta\
(if find_theta == True)
Expand All @@ -485,14 +512,21 @@ def process_batches(self, pwt, nwt=None, num_document_passes=None, batches_folde
args.regularizer_name.append(name)
args.regularizer_tau.append(tau)

if class_ids is not None and class_weights is not None:
if transaction_types is not None and transaction_weights is not None:
for transaction_type, weight in zip(transaction_types, transaction_weights):
args.transaction_type.append(transaction_type)
args.transaction_weight.append(weight)
elif class_ids is not None and class_weights is not None:
for class_id, weight in zip(class_ids, class_weights):
args.class_id.append(class_id)
args.class_weight.append(weight)
args.transaction_type.append(class_id)
args.transaction_weight.append(weight)

if predict_class_id is not None:
args.predict_class_id = predict_class_id

if predict_transaction_type is not None:
args.predict_transaction_type = predict_transaction_type

func = None
if find_theta or find_ptdw:
args.theta_matrix_type = constants.ThetaMatrixType_Dense
Expand Down Expand Up @@ -584,7 +618,7 @@ def attach_model(self, model):
"""
:param str model: name of matrix in BigARTM
:return:
* messahes.TopicModel() object with info about Phi matrix
* messages.TopicModel() object with info about Phi matrix
* numpy.ndarray with Phi data (i.e., p(w|t) values)
"""
topic_model = self.get_phi_info(model)
Expand Down Expand Up @@ -741,13 +775,17 @@ def get_phi_info(self, model):

return phi_matrix_info

def get_phi_matrix(self, model, topic_names=None, class_ids=None, use_sparse_format=None):
def get_phi_matrix(self, model, topic_names=None, class_ids=None,
transaction_types=None, use_sparse_format=None):
"""
:param str model: name of matrix in BigARTM
:param topic_names: list of topics to retrieve (None means all topics)
:type topic_names: list of str or None
:param class_ids: list of class ids to retrieve (None means all class ids)
:type class_ids: list of str or None
:param transaction_types: list of transaction types to retrieve\
(None means all transaction types)
:type transaction_types: list of str or None
:param bool use_sparse_format: use sparse/dense layout
:return: numpy.ndarray with Phi data (i.e., p(w|t) values)
"""
Expand All @@ -760,6 +798,10 @@ def get_phi_matrix(self, model, topic_names=None, class_ids=None, use_sparse_for
args.ClearField('class_id')
for class_id in class_ids:
args.class_id.append(class_id)
if transaction_types is not None:
args.ClearField('transaction_type')
for transaction_type in transaction_types:
args.transaction_type.append(transaction_type)
if use_sparse_format is not None:
args.matrix_layout = constants.MatrixLayout_Sparse

Expand Down Expand Up @@ -876,14 +918,20 @@ def fit_online(self, batch_filenames=None, batch_weights=None, update_after=None

self._lib.ArtmFitOnlineMasterModel(self.master_id, args)

def transform(self, batches=None, batch_filenames=None,
theta_matrix_type=None, predict_class_id=None):
def transform(self, batches=None, batch_filenames=None, theta_matrix_type=None,
predict_class_id=None, predict_transaction_type=None):
"""
:param batches: list of Batch instances
:param batch_weights: weights of batches to transform
:type batch_weights: list of float
:param int theta_matrix_type: type of matrix to be returned
:param int predict_class_id: type of matrix to be returned
:param predict_class_id: class_id of a target modality to predict
:type predict_class_id: str, default None
:param predict_transaction_type: transaction type to predict (if predict_class_id is None,\
all class_ids in the transaction will be predicted, which is\
invalid behavior, so predict_transaction_type should\
always be used together with predict_class_id)
:type predict_transaction_type: str, default None
:return: messages.ThetaMatrix object
"""
args = messages.TransformMasterModelArgs()
Expand All @@ -904,6 +952,9 @@ def transform(self, batches=None, batch_filenames=None,
if predict_class_id is not None:
args.predict_class_id = predict_class_id

if predict_transaction_type is not None:
args.predict_transaction_type = predict_transaction_type

if theta_matrix_type not in [constants.ThetaMatrixType_None, constants.ThetaMatrixType_Cache]:
theta_matrix_info = self._lib.ArtmRequestTransformMasterModelExternal(self.master_id, args)

Expand Down

0 comments on commit 839b927

Please sign in to comment.