Skip to content

Commit

Permalink
Add a few tiny tests and update a docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
bartvm committed Jan 15, 2015
1 parent 99e10a6 commit fbf56bf
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
5 changes: 3 additions & 2 deletions blocks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ def unpack(arg, singleton=False):
will be cast to a list before returning. Any other variable
will be returned as is.
singleton : bool
If ``True``, `arg` is expected to be a singleton and an exception
is raised if this is not the case. ``False`` by default.
If ``True``, `arg` is expected to be a singleton (a list or tuple
with exactly one element) and an exception is raised if this is not
the case. ``False`` by default.
Returns
-------
Expand Down
32 changes: 32 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from numpy.testing import assert_raises
from theano import tensor

from blocks.utils import (check_theano_variable, graph_inputs, shared_floatx,
unpack)


def test_unpack():
assert unpack((1, 2)) == [1, 2]
assert unpack([1, 2]) == [1, 2]
assert unpack([1]) == 1
test = object()
assert unpack(test) is test
assert_raises(ValueError, unpack, [1, 2], True)


def test_check_theano_variable():
check_theano_variable(None, 3, 'float')
check_theano_variable([[1, 2]], 2, 'int')
assert_raises(ValueError, check_theano_variable,
tensor.vector(), 2, 'float')
assert_raises(ValueError, check_theano_variable,
tensor.vector(), 1, 'int')


def test_graph_inputs():
a = tensor.matrix('a')
b = shared_floatx(0, 'b')
c = 3

d = a + b + c
assert graph_inputs([d]) == [a]

0 comments on commit fbf56bf

Please sign in to comment.