Skip to content

Commit

Permalink
Fix variable repr and str when data is None
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Jun 1, 2017
1 parent c6adfeb commit 444679d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
19 changes: 14 additions & 5 deletions chainer/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,13 @@ def variable_repr(var):
else:
prefix = 'variable'

if arr.size > 0 or arr.shape == (0,):
if arr is None:
lst = 'None'
elif arr.size > 0 or arr.shape == (0,):
lst = numpy.array2string(arr, None, None, None, ', ', prefix + '(')
else: # show zero-length shape unless it is (0,)
lst = '[], shape=%s' % (repr(arr.shape),)

return '%s(%s)' % (prefix, lst)


Expand All @@ -94,12 +97,18 @@ def variable_str(var):
arr = var.data
else:
arr = var.data.get()

if var.name:
prefix = 'variable ' + var.name + '('
prefix = 'variable ' + var.name
else:
prefix = 'variable'

if arr is None:
lst = 'None'
else:
prefix = 'variable('
return (prefix + numpy.array2string(arr, None, None, None, ' ', prefix) +
')')
lst = numpy.array2string(arr, None, None, None, ' ', prefix + '(')

return '%s(%s)' % (prefix, lst)


class VariableNode(object):
Expand Down
26 changes: 18 additions & 8 deletions tests/chainer_tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,8 @@ def test_backward_gpu(self):


@testing.parameterize(
{'x_shape': None, 'dtype': None, 'repr': 'variable(None)',
'str': 'variable(None)'},
{'x_shape': (2, 2,), 'dtype': np.float16,
'repr': 'variable([[ 0., 1.],\n [ 2., 3.]])',
'str': 'variable([[ 0. 1.]\n [ 2. 3.]])'},
Expand All @@ -1152,10 +1154,13 @@ def test_backward_gpu(self):
class TestUnnamedVariableToString(unittest.TestCase):

def setUp(self):
x = np.empty(self.x_shape)
x = np.arange(x.size).reshape(self.x_shape)
x = x.astype(self.dtype)
self.x = chainer.Variable(x)
if self.x_shape is None:
self.x = chainer.Variable()
else:
x = np.empty(self.x_shape)
x = np.arange(x.size).reshape(self.x_shape)
x = x.astype(self.dtype)
self.x = chainer.Variable(x)

def test_repr_cpu(self):
self.assertEqual(repr(self.x), self.repr)
Expand Down Expand Up @@ -1205,6 +1210,8 @@ def test_str_gpu(self):


@testing.parameterize(
{'x_shape': None, 'dtype': None, 'repr': 'variable x(None)',
'str': 'variable x(None)'},
{'x_shape': (2, 2,), 'dtype': np.float32,
'repr': 'variable x([[ 0., 1.],\n [ 2., 3.]])',
'str': 'variable x([[ 0. 1.]\n [ 2. 3.]])'},
Expand All @@ -1214,10 +1221,13 @@ def test_str_gpu(self):
class TestNamedVariableToString(unittest.TestCase):

def setUp(self):
x = np.empty(self.x_shape)
x = np.arange(x.size).reshape(self.x_shape)
x = x.astype(self.dtype)
self.x = chainer.Variable(x, name='x')
if self.x_shape is None:
self.x = chainer.Variable(name='x')
else:
x = np.empty(self.x_shape)
x = np.arange(x.size).reshape(self.x_shape)
x = x.astype(self.dtype)
self.x = chainer.Variable(x, name='x')

def test_named_repr(self):
self.assertEqual(repr(self.x), self.repr)
Expand Down

0 comments on commit 444679d

Please sign in to comment.