Skip to content

Commit

Permalink
Merge pull request #961 from rizar/better_error_for_duplicate_names
Browse files Browse the repository at this point in the history
Better error message and test for duplicate names
  • Loading branch information
rizar committed Jan 29, 2016
2 parents 5a964e4 + 2c59075 commit b7a7805
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
7 changes: 6 additions & 1 deletion blocks/monitoring/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ def __init__(self, variables, use_take_last=False):

self.variable_names = [v.name for v in self.variables]
if len(set(self.variable_names)) < len(self.variables):
raise ValueError("variables should have different names")
duplicates = []
for vname in set(self.variable_names):
if self.variable_names.count(vname) > 1:
duplicates.append(vname)
raise ValueError("variables should have different names!"
" Duplicates: {}".format(', '.join(duplicates)))
self._computation_graph = ComputationGraph(self.variables)
self.inputs = self._computation_graph.inputs

Expand Down
10 changes: 8 additions & 2 deletions tests/monitoring/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy
import theano
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_raises
from theano import tensor

from blocks import bricks
Expand All @@ -14,7 +14,7 @@
from fuel.streams import DataStream
from fuel.schemes import SequentialScheme

from blocks.monitoring.evaluators import DatasetEvaluator
from blocks.monitoring.evaluators import DatasetEvaluator, AggregationBuffer


class TestBrick(bricks.Brick):
Expand Down Expand Up @@ -89,3 +89,9 @@ def test_mean_aggregator():
numpy.array([8.25, 26.75], dtype=theano.config.floatX))
assert_allclose(DatasetEvaluator([z]).evaluate(data_stream)['z'],
numpy.array([35], dtype=theano.config.floatX))


def test_aggregation_buffer():
x1 = tensor.matrix('x')
x2 = tensor.matrix('x')
assert_raises(ValueError, AggregationBuffer, [x1, x2])

0 comments on commit b7a7805

Please sign in to comment.