Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TopicSegmentationRegularizer implementation in biagartm9.0 #917

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
16 changes: 8 additions & 8 deletions python/artm/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,8 @@ class TopicSegmentationPtdwRegularizer(BaseRegularizer):
_config_message = messages.TopicSegmentationPtdwConfig
_type = const.RegularizerType_TopicSegmentationPtdw

def __init__(self, name=None, window=None, threshold=None, background_topic_names=None, config=None):
def __init__(self, name=None, tau=1.0, window=1, threshold=None, background_topic_names=None,
merge_into_segments=False, merge_threshold=0.5, config=None):
"""
:param str name: the identifier of regularizer, will be auto-generated if not specified
:param int window: a number of words to the one side over which smoothing will be performed
Expand All @@ -973,28 +974,27 @@ def __init__(self, name=None, window=None, threshold=None, background_topic_name

BaseRegularizer.__init__(self,
name=name,
tau=1.0,
tau=tau,
gamma=None,
config=config)
if window is not None:
self._config.window = window
self._window = window
elif config is not None and config.HasField('window'):
self._window = config.window

if threshold is not None:
self._config.threshold = threshold
self._threshold = threshold
elif config is not None and config.HasField('threshold'):
self._threshold = config.threshold

if background_topic_names is not None:
if isinstance(background_topic_names, string_types):
background_topic_names = [background_topic_names]
for topic_name in background_topic_names:
self._config.background_topic_names.append(topic_name)
elif config is not None and len(config.background_topic_names):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove the loading of parameters from config? You should also add class fields for the new parameters, add getters/setters for them, and load them from config as well.

self._background_topic_names = [name for name in config.background_topic_names]

self._config.merge_threshold = 0.0
if merge_into_segments:
self._config.merge_into_segments=True
self._config.merge_threshold = merge_threshold


class SmoothTimeInTopicsPhiRegularizer(BaseRegularizerPhi):
Expand Down
6 changes: 4 additions & 2 deletions src/artm/messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,10 @@ message HierarchySparsingThetaConfig {
// Represents a configuration of a Topic Segmentation regularizer
message TopicSegmentationPtdwConfig {
repeated string background_topic_names = 1;
optional int32 window = 3 [default = 10];
optional float threshold = 4 [default = 0.5];
optional int32 window = 3 [default = 1];
optional double threshold = 4 [default = 0.5];
optional bool merge_into_segments = 5 [default = false];
optional float merge_threshold = 6 [default = 0.5];
}

// Represents a configuration of a SmoothTimeInTopics Phi regularizer
Expand Down
Loading