Skip to content

Commit

Permalink
Merge pull request #1102 from dwf/evaluators
Browse files Browse the repository at this point in the history
Catch None variable names, refactor name uniqueness checks.
  • Loading branch information
dmitriy-serdyuk committed Jun 1, 2016
2 parents 387d159 + 66520c2 commit 57458e3
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 22 deletions.
40 changes: 23 additions & 17 deletions blocks/monitoring/evaluators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import OrderedDict
from collections import OrderedDict, Counter
import logging

from picklable_itertools.extras import equizip
Expand All @@ -14,6 +14,21 @@
logger = logging.getLogger(__name__)


def _validate_variable_names(variables):
"""Check for missing and duplicate variable names."""
variable_names = [v.name for v in variables]
name_counts = Counter(variable_names)
if None in name_counts:
none_names = [v for v in variables if v.name is None]
raise ValueError('Variables must have names: {}'.format(none_names))

if any(v > 1 for v in name_counts.values()):
raise ValueError("Variables should have unique names."
" Duplicates: {}"
.format(', '.join(k for k, v in name_counts.items()
if v > 1)))


class MonitoredQuantityBuffer(object):
"""Intermediate results of aggregating values of monitored-quantity.
Expand All @@ -26,7 +41,7 @@ class MonitoredQuantityBuffer(object):
----------
quantities : list of :class:`MonitoredQuantity`
The quantity names are used as record names in the logs. Hence, all
the quantity names must be different.
the quantity names must be unique.
Attributes
----------
Expand Down Expand Up @@ -89,7 +104,7 @@ class AggregationBuffer(object):
----------
variables : list of :class:`~tensor.TensorVariable`
The variable names are used as record names in the logs. Hence, all
the variable names must be different.
the variable names must be unique.
use_take_last : bool
When ``True``, the :class:`TakeLast` aggregation scheme is used
instead of :class:`_DataIndependent` for those variables that
Expand All @@ -109,17 +124,10 @@ class AggregationBuffer(object):
"""
def __init__(self, variables, use_take_last=False):
_validate_variable_names(variables)
self.variables = variables
self.use_take_last = use_take_last

self.variable_names = [v.name for v in self.variables]
if len(set(self.variable_names)) < len(self.variables):
duplicates = []
for vname in set(self.variable_names):
if self.variable_names.count(vname) > 1:
duplicates.append(vname)
raise ValueError("variables should have different names!"
" Duplicates: {}".format(', '.join(duplicates)))
self.use_take_last = use_take_last
self._computation_graph = ComputationGraph(self.variables)
self.inputs = self._computation_graph.inputs

Expand Down Expand Up @@ -190,7 +198,7 @@ def initialize_aggregators(self):
def get_aggregated_values(self):
"""Readout the aggregated values."""
if not self._initialized:
raise Exception("To readout you must first initialize, then"
raise Exception("To readout you must first initialize, then "
"process batches!")
ret_vals = self._readout_fun()
return OrderedDict(equizip(self.variable_names, ret_vals))
Expand All @@ -217,7 +225,7 @@ class DatasetEvaluator(object):
variables : list of :class:`~tensor.TensorVariable` and
:class:`MonitoredQuantity`
The variable names are used as record names in the logs. Hence, all
the names must be different.
the names must be unique.
Each variable can be tagged with an :class:`AggregationScheme` that
specifies how the value can be computed for a data set by
Expand All @@ -233,6 +241,7 @@ class DatasetEvaluator(object):
"""
def __init__(self, variables, updates=None):
_validate_variable_names(variables)
theano_variables = []
monitored_quantities = []
for variable in variables:
Expand All @@ -242,9 +251,6 @@ def __init__(self, variables, updates=None):
theano_variables.append(variable)
self.theano_variables = theano_variables
self.monitored_quantities = monitored_quantities
variable_names = [v.name for v in variables]
if len(set(variable_names)) < len(variables):
raise ValueError("variables should have different names")
self.theano_buffer = AggregationBuffer(theano_variables)
self.monitored_quantities_buffer = MonitoredQuantityBuffer(
monitored_quantities)
Expand Down
15 changes: 10 additions & 5 deletions tests/monitoring/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy
import theano
from numpy.testing import assert_allclose, assert_raises
from numpy.testing import assert_allclose, assert_raises_regex
from theano import tensor

from blocks import bricks
Expand Down Expand Up @@ -91,7 +91,12 @@ def test_mean_aggregator():
numpy.array([35], dtype=theano.config.floatX))


def test_aggregation_buffer():
x1 = tensor.matrix('x')
x2 = tensor.matrix('x')
assert_raises(ValueError, AggregationBuffer, [x1, x2])
def test_aggregation_buffer_name_uniqueness():
x1 = tensor.scalar('x')
x2 = tensor.scalar('x')
assert_raises_regex(ValueError, 'unique', AggregationBuffer, [x1, x2])


def test_aggregation_buffer_name_none():
assert_raises_regex(ValueError, 'must have names',
AggregationBuffer, [theano.tensor.scalar()])
12 changes: 12 additions & 0 deletions tests/monitoring/test_monitored_quantity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from numpy.testing import assert_raises_regex
import numpy
import theano
from fuel.datasets import IterableDataset
Expand Down Expand Up @@ -44,3 +45,14 @@ def test_dataset_evaluators():
numpy.testing.assert_allclose(
values['monitored_cross_entropy1'],
values['categoricalcrossentropy_apply_cost'])


def test_dataset_evaluator_name_none():
assert_raises_regex(ValueError, 'must have names',
DatasetEvaluator, [theano.tensor.scalar()])


def test_dataset_evaluator_name_uniqueness():
assert_raises_regex(ValueError, 'unique',
DatasetEvaluator, [theano.tensor.scalar('A'),
theano.tensor.scalar('A')])

0 comments on commit 57458e3

Please sign in to comment.