Skip to content

Commit

Permalink
add new time regularizer (#818)
Browse files Browse the repository at this point in the history
New regularizer for smoothing time in topics was added
  • Loading branch information
MelLain committed Jul 19, 2017
1 parent b8b5e68 commit 82ac025
Show file tree
Hide file tree
Showing 10 changed files with 262 additions and 1 deletion.
3 changes: 3 additions & 0 deletions python/artm/master_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
), (
messages.TopicSegmentationPtdwConfig,
constants.RegularizerType_TopicSegmentationPtdw
), (
messages.SmoothTimeInTopicsPhiConfig,
constants.RegularizerType_SmoothTimeInTopicsPhi
),
)

Expand Down
58 changes: 57 additions & 1 deletion python/artm/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
'BitermsPhiRegularizer',
'HierarchySparsingThetaRegularizer',
'TopicSegmentationPtdwRegularizer',
'SmoothTimeInTopicsPhiRegularizer',
]


Expand Down Expand Up @@ -534,7 +535,7 @@ def __init__(self, name=None, tau=1.0, gamma=None, topic_names=None, class_id=No
:param str name: the identifier of regularizer, will be auto-generated if not specified
:param float tau: the coefficient of regularization for this regularizer
:param float gamma: the coefficient of relative regularization for this regularizer
:param class_id: class_id to regularize
:param str class_id: class_id to regularize
:param topic_names: list of names or single name of topic to regularize,\
will regularize all topics if empty or None
:type topic_names: list of str or single str or None
Expand Down Expand Up @@ -814,3 +815,58 @@ def __init__(self, name=None, window=None, threshold=None, background_topic_name
background_topic_names = [background_topic_names]
for topic_name in background_topic_names:
self._config.background_topic_names.append(topic_name)


class SmoothTimeInTopicsPhiRegularizer(BaseRegularizerPhi):
    """Python wrapper of the SmoothTimeInTopicsPhi regularizer.

    The regularizer is configured with a single ``class_id``; the
    ``class_ids`` and ``dictionary`` parameters of the base class are
    not available and raise ``KeyError`` on any access.
    """
    _config_message = messages.SmoothTimeInTopicsPhiConfig
    _type = const.RegularizerType_SmoothTimeInTopicsPhi

    def __init__(self, name=None, tau=1.0, gamma=None, class_id=None, topic_names=None, config=None):
        """
        :param str name: the identifier of regularizer, will be auto-generated if not specified
        :param float tau: the coefficient of regularization for this regularizer
        :param float gamma: the coefficient of relative regularization for this regularizer
        :param str class_id: class_id to regularize, '@default_class' is used if None
        :param topic_names: list of names or single name of topic to regularize,\
                            will regularize all topics if empty or None
        :type topic_names: list of str or single str or None
        :param config: the low-level config of this regularizer
        :type config: protobuf object
        """
        # class_ids and dictionary are not supported, so they are always passed as None
        BaseRegularizerPhi.__init__(self,
                                    name=name,
                                    tau=tau,
                                    gamma=gamma,
                                    topic_names=topic_names,
                                    class_ids=None,
                                    dictionary=None,
                                    config=config)

        if class_id is None:
            # the low-level config is left untouched in this case
            self._class_id = '@default_class'
        else:
            self._config.class_id = class_id
            self._class_id = class_id

    # ----- class_id: the only modality parameter of this regularizer -----

    @property
    def class_id(self):
        return self._class_id

    @class_id.setter
    def class_id(self, class_id):
        _reconfigure_field(self, class_id, 'class_id')

    # ----- parameters of the base class that are not available here -----

    @property
    def class_ids(self):
        raise KeyError('No class_ids parameter')

    @class_ids.setter
    def class_ids(self, class_ids):
        raise KeyError('No class_ids parameter')

    @property
    def dictionary(self):
        raise KeyError('No dictionary parameter')

    @dictionary.setter
    def dictionary(self, dictionary):
        raise KeyError('No dictionary parameter')
1 change: 1 addition & 0 deletions python/artm/wrapper/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
RegularizerType_BitermsPhi = 9
RegularizerType_HierarchySparsingTheta = 10
RegularizerType_TopicSegmentationPtdw = 11
RegularizerType_SmoothTimeInTopicsPhi = 12
RegularizerType_Unknown = 9999
SpecifiedSparsePhiConfig_SparseMode_SparseTopics = 0
SpecifiedSparsePhiConfig_SparseMode_SparseTokens = 1
Expand Down
42 changes: 42 additions & 0 deletions python/tests/artm/test_time_regularizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2017, Additive Regularization of Topic Models.

import shutil
import glob
import tempfile
import os
import pytest

from six.moves import range, zip

import artm
import pandas as pd


def test_func():
    """Fit a model with SmoothTimeInTopicsPhiRegularizer on the first ten topics
    and compare the Phi sparsity of both topic groups with reference values."""
    num_topics = 20
    num_passes = 20
    tolerance = 0.01
    # (score name, reference sparsity) pairs, checked in this order
    reference_sparsity = (
        ('sp_phi_one', 0.189),
        ('sp_phi_two', 0.251),
    )

    data_path = os.environ.get('BIGARTM_UNITTEST_DATA')
    batches_folder = tempfile.mkdtemp()

    try:
        batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
                                                data_format='bow_uci',
                                                collection_name='kos',
                                                target_folder=batches_folder)

        model = artm.ARTM(num_topics=num_topics, dictionary=batch_vectorizer.dictionary)

        # one sparsity score per half of the topics
        first_half = model.topic_names[0: 10]
        model.scores.add(artm.SparsityPhiScore(name='sp_phi_one', topic_names=first_half))
        second_half = model.topic_names[10:]
        model.scores.add(artm.SparsityPhiScore(name='sp_phi_two', topic_names=second_half))

        # only the first half of the topics is regularized
        regularizer = artm.SmoothTimeInTopicsPhiRegularizer(tau=1000.0,
                                                            topic_names=model.topic_names[0: 10])
        model.regularizers.add(regularizer)

        model.fit_offline(batch_vectorizer, num_passes)

        for score_name, expected_value in reference_sparsity:
            assert abs(model.score_tracker[score_name].last_value - expected_value) < tolerance
    finally:
        # the temporary batches are removed even if the test fails
        shutil.rmtree(batches_folder)
2 changes: 2 additions & 0 deletions src/artm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ set(SRC_LIST
regularizer/hierarchy_sparsing_theta.h
regularizer/topic_segmentation_ptdw.cc
regularizer/topic_segmentation_ptdw.h
regularizer/smooth_time_in_topics_phi.cc
regularizer/smooth_time_in_topics_phi.h
score/class_precision.cc
score/class_precision.h
score/items_processed.cc
Expand Down
7 changes: 7 additions & 0 deletions src/artm/core/instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "artm/regularizer/biterms_phi.h"
#include "artm/regularizer/hierarchy_sparsing_theta.h"
#include "artm/regularizer/topic_segmentation_ptdw.h"
#include "artm/regularizer/smooth_time_in_topics_phi.h"

#include "artm/score/items_processed.h"
#include "artm/score/sparsity_theta.h"
Expand Down Expand Up @@ -284,6 +285,12 @@ void Instance::CreateOrReconfigureRegularizer(const RegularizerConfig& config) {
break;
}

case artm::RegularizerType_SmoothTimeInTopicsPhi: {
CREATE_OR_RECONFIGURE_REGULARIZER(::artm::SmoothTimeInTopicsPhiConfig,
::artm::regularizer::SmoothTimeInTopicsPhi);
break;
}

default:
BOOST_THROW_EXCEPTION(ArgumentOutOfRangeException(
"RegularizerConfig.type", regularizer_type));
Expand Down
7 changes: 7 additions & 0 deletions src/artm/messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ enum RegularizerType {
RegularizerType_BitermsPhi = 9;
RegularizerType_HierarchySparsingTheta = 10;
RegularizerType_TopicSegmentationPtdw = 11;
RegularizerType_SmoothTimeInTopicsPhi = 12;
RegularizerType_Unknown = 9999;
}

Expand Down Expand Up @@ -194,6 +195,12 @@ message TopicSegmentationPtdwConfig {
optional double threshold = 4 [default = 0.5];
}

// Represents a configuration of a SmoothTimeInTopics Phi regularizer
message SmoothTimeInTopicsPhiConfig {
// Names of the topics to regularize; all topics are regularized when the list is empty
repeated string topic_name = 1;
// Class id (modality) of the tokens to regularize; tokens of other class ids are skipped
optional string class_id = 2 [default = "@default_class"];
}

// Represents the transform functions
message TransformConfig {
enum TransformType {
Expand Down
89 changes: 89 additions & 0 deletions src/artm/regularizer/smooth_time_in_topics_phi.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright 2017, Additive Regularization of Topic Models.

// Author: Murat Apishev (great-mel@yandex.ru)

#include <string>
#include <vector>

#include "artm/core/protobuf_helpers.h"
#include "artm/core/phi_matrix.h"
#include "artm/regularizer/smooth_time_in_topics_phi.h"

namespace artm {
namespace regularizer {

// Fills 'result' with the regularization values for tokens of the configured modality.
// For every such token that has both a previous and a next token of the same modality
// (in the order of p_wt), and for every selected topic, the stored value is
//   p_wt * (s(p_prev - p_wt) + s(p_next - p_wt)),
// where s(x) is 1.0 for x > 0 and -1.0 otherwise (so a zero difference counts as -1.0).
// n_wt is not used. Always returns true.
bool SmoothTimeInTopicsPhi::RegularizePhi(const ::artm::core::PhiMatrix& p_wt,
                                          const ::artm::core::PhiMatrix& n_wt,
                                          ::artm::core::PhiMatrix* result) {
  const int num_topics = p_wt.topic_size();
  const int num_tokens = p_wt.token_size();

  // an empty list of topic names in the config selects every topic
  std::vector<bool> is_selected_topic;
  if (config_.topic_name().size() != 0) {
    is_selected_topic = core::is_member(p_wt.topic_name(), config_.topic_name());
  } else {
    is_selected_topic.assign(num_topics, true);
  }

  // indices of the tokens of the configured modality, in the order they appear in p_wt
  std::vector<int> modality_tokens;
  for (int token_id = 0; token_id < num_tokens; ++token_id) {
    const ::artm::core::Token& token = p_wt.token(token_id);
    if (token.class_id != config_.class_id()) {
      continue;
    }
    modality_tokens.push_back(token_id);
  }

  // slide a window of three consecutive modality tokens and update the middle one;
  // the first and the last token of the modality are therefore never updated
  for (size_t pos = 2; pos < modality_tokens.size(); ++pos) {
    const int left_id = modality_tokens[pos - 2];
    const int center_id = modality_tokens[pos - 1];
    const int right_id = modality_tokens[pos];

    for (int topic_id = 0; topic_id < num_topics; ++topic_id) {
      if (!is_selected_topic[topic_id]) {
        continue;
      }

      const double center_value = p_wt.get(center_id, topic_id);
      const double left_sign = (p_wt.get(left_id, topic_id) - center_value) > 0.0 ? 1.0 : -1.0;
      const double right_sign = (p_wt.get(right_id, topic_id) - center_value) > 0.0 ? 1.0 : -1.0;

      result->set(center_id, topic_id, center_value * (left_sign + right_sign));
    }
  }

  return true;
}

// Returns a copy of the topic names listed in the config.
google::protobuf::RepeatedPtrField<std::string> SmoothTimeInTopicsPhi::topics_to_regularize() {
  const google::protobuf::RepeatedPtrField<std::string>& configured_topics = config_.topic_name();
  return configured_topics;
}

// Returns a one-element list that holds the class id from the config.
google::protobuf::RepeatedPtrField<std::string> SmoothTimeInTopicsPhi::class_ids_to_regularize() {
  google::protobuf::RepeatedPtrField<std::string> class_ids;
  *class_ids.Add() = config_.class_id();
  return class_ids;
}

// Replaces the current configuration with the one serialized in config.config().
// Throws ::artm::core::CorruptedMessageException when the blob cannot be parsed as
// a SmoothTimeInTopicsPhiConfig; in that case the current configuration is left unchanged.
// Returns true on success.
bool SmoothTimeInTopicsPhi::Reconfigure(const RegularizerConfig& config) {
  // bind by reference: the serialized blob is only read, no copy is needed
  const std::string& config_blob = config.config();

  // parse into a temporary first, so that config_ is modified only after a successful parse
  SmoothTimeInTopicsPhiConfig regularizer_config;
  if (!regularizer_config.ParseFromString(config_blob)) {
    BOOST_THROW_EXCEPTION(::artm::core::CorruptedMessageException(
      "Unable to parse SmoothTimeInTopicsPhiConfig from RegularizerConfig.config"));
  }

  config_.CopyFrom(regularizer_config);
  return true;
}

} // namespace regularizer
} // namespace artm
52 changes: 52 additions & 0 deletions src/artm/regularizer/smooth_time_in_topics_phi.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/* Copyright 2017, Additive Regularization of Topic Models.
Author: Murat Apishev (great-mel@yandex.ru)
This class performs smoothing of tokens in Phi using the values of the nearest tokens.
It is intended mainly for time stamp tokens. It must be used with a sorted dictionary
(i.e. the tokens should follow in some order, chronological, for instance).
The formula of M-step is
p_wt \propto n_wt + tau * p_wt * (sign(p_{w-1,t} - p_wt) + sign(p_{w+1,t} - p_wt)),
The parameters of the regularizer:
- topic_names (the names of topics to regularize, empty == all)
- class_id (class id to regularize, optional, defaults to "@default_class")
Note: regularizer ignores first and last tokens of given modality.
*/

#ifndef SRC_ARTM_REGULARIZER_SMOOTH_TIME_IN_TOPICS_PHI_H_
#define SRC_ARTM_REGULARIZER_SMOOTH_TIME_IN_TOPICS_PHI_H_

#include <memory>
#include <string>

#include "artm/regularizer_interface.h"

namespace artm {
namespace regularizer {

// Phi regularizer that smooths the values of tokens of one modality
// using the values of the previous and the next token of the same modality.
class SmoothTimeInTopicsPhi : public RegularizerInterface {
public:
// Stores a copy of the given config.
explicit SmoothTimeInTopicsPhi(const SmoothTimeInTopicsPhiConfig& config) : config_(config) { }

// Writes the regularization values, computed from p_wt, into result.
// The implementation does not read n_wt and always returns true.
virtual bool RegularizePhi(const ::artm::core::PhiMatrix& p_wt,
const ::artm::core::PhiMatrix& n_wt,
::artm::core::PhiMatrix* result);

// Topic names from the config (an empty list means all topics).
virtual google::protobuf::RepeatedPtrField<std::string> topics_to_regularize();
// One-element list with the class id from the config.
virtual google::protobuf::RepeatedPtrField<std::string> class_ids_to_regularize();

// Parses a SmoothTimeInTopicsPhiConfig from config.config() and replaces the stored config;
// throws CorruptedMessageException if the parsing fails.
virtual bool Reconfigure(const RegularizerConfig& config);

private:
// Current configuration: topic names and class id to regularize.
SmoothTimeInTopicsPhiConfig config_;
};

} // namespace regularizer
} // namespace artm

#endif // SRC_ARTM_REGULARIZER_SMOOTH_TIME_IN_TOPICS_PHI_H_
2 changes: 2 additions & 0 deletions utils/cpplint_files.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ src/artm/regularizer/smooth_ptdw.cc
src/artm/regularizer/topic_selection_theta.cc
src/artm/regularizer/biterms_phi.cc
src/artm/regularizer/topic_segmentation_ptdw.cc
src/artm/regularizer/smooth_time_in_topics_phi.cc
src/artm/score/class_precision.cc
src/artm/score/peak_memory.cc
src/artm/score/perplexity.cc
Expand Down Expand Up @@ -102,6 +103,7 @@ src/artm/regularizer/smooth_ptdw.h
src/artm/regularizer/topic_selection_theta.h
src/artm/regularizer/biterms_phi.h
src/artm/regularizer/topic_segmentation_ptdw.h
src/artm/regularizer/smooth_time_in_topics_phi.h
src/artm/score/class_precision.h
src/artm/score/peak_memory.h
src/artm/score/perplexity.h
Expand Down

0 comments on commit 82ac025

Please sign in to comment.