Skip to content

Commit

Permalink
Merge pull request #989 from bouthilx/monitoring_suffix
Browse files Browse the repository at this point in the history
Add suffix option to MonitoringExtension
  • Loading branch information
rizar committed Mar 7, 2016
2 parents 786be9f + 5a41d8f commit 4297cea
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
26 changes: 19 additions & 7 deletions blocks/extensions/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from blocks.monitoring.evaluators import (
AggregationBuffer, MonitoredQuantityBuffer, DatasetEvaluator)

PREFIX_SEPARATOR = '_'
SEPARATOR = '_'
logger = logging.getLogger(__name__)


Expand All @@ -20,17 +20,31 @@ class MonitoringExtension(TrainingExtension):
----------
prefix : str, optional
The prefix for the log records done by the extension. It is
appended to the variable names with an underscore as a separator.
If not given, the names of the observed variables are used as is.
prepended to the variable names with an underscore as a separator.
If not given, no prefix is added to the names of the observed
variables.
suffix : str, optional
The suffix for the log records done by the extension. It is
appended to the end of variable names with an underscore as a
separator. If not given, no suffix is added the names of the
observed variables.
"""
def __init__(self, prefix=None, **kwargs):
SEPARATOR = SEPARATOR

def __init__(self, prefix=None, suffix=None, **kwargs):
super(MonitoringExtension, self).__init__(**kwargs)
self.prefix = prefix
self.suffix = suffix

def _record_name(self, name):
"""The record name for a variable name."""
return self.prefix + PREFIX_SEPARATOR + name if self.prefix else name
if not isinstance(name, str):
raise ValueError("record name must be a string")

return self.SEPARATOR.join(
[morpheme for morpheme in [self.prefix, name, self.suffix]
if morpheme is not None])

def record_name(self, variable):
"""The record name for a variable."""
Expand Down Expand Up @@ -68,8 +82,6 @@ class DataStreamMonitoring(SimpleExtension, MonitoringExtension):
each time monitoring is done.
"""
PREFIX_SEPARATOR = '_'

def __init__(self, variables, data_stream, updates=None, **kwargs):
kwargs.setdefault("after_epoch", True)
kwargs.setdefault("before_first_epoch", True)
Expand Down
29 changes: 28 additions & 1 deletion tests/extensions/test_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from theano import tensor

from blocks.extensions import TrainingExtension, FinishAfter
from blocks.extensions.monitoring import TrainingDataMonitoring
from blocks.extensions.monitoring import (
MonitoringExtension,
TrainingDataMonitoring)
from blocks.monitoring import aggregation
from blocks.algorithms import GradientDescent, Scale
from blocks.utils import shared_floatx
Expand All @@ -26,6 +28,31 @@ def get_aggregated_value(self):
return self._aggregated / self._num_batches


def test_monitoring_extension__record_name():
test_name = "test-test"

monitor = MonitoringExtension()
assert monitor._record_name(test_name) == test_name

monitor = MonitoringExtension(prefix="abc")
assert (monitor._record_name(test_name) ==
"abc" + monitor.SEPARATOR + test_name)

monitor = MonitoringExtension(suffix="abc")
assert (monitor._record_name(test_name) ==
test_name + monitor.SEPARATOR + "abc")

monitor = MonitoringExtension(prefix="abc", suffix="def")
assert (monitor._record_name(test_name) ==
"abc" + monitor.SEPARATOR + test_name + monitor.SEPARATOR + "def")

try:
monitor = MonitoringExtension(prefix="abc", suffix="def")
monitor._record_name(None)
except ValueError as e:
assert str(e) == "record name must be a string"


def test_training_data_monitoring():
weights = numpy.array([-1, 1], dtype=theano.config.floatX)
features = [numpy.array(f, dtype=theano.config.floatX)
Expand Down

0 comments on commit 4297cea

Please sign in to comment.