Skip to content

Commit

Permalink
Apply as_variable to no-cast result
Browse files Browse the repository at this point in the history
  • Loading branch information
beam2d committed Aug 28, 2017
1 parent 9f7e13e commit 2abe136
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
3 changes: 2 additions & 1 deletion chainer/functions/array/cast.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import chainer
from chainer import function_node
from chainer.utils import type_check

Expand Down Expand Up @@ -51,5 +52,5 @@ def cast(x, typ):
"""
if x.dtype == typ:
return x
return chainer.as_variable(x)
return Cast(typ).apply((x,))[0]
11 changes: 8 additions & 3 deletions tests/chainer_tests/functions_tests/array_tests/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,18 @@ def setUp(self):

def check_forward_no_cast(self, x_data):
y = functions.cast(x_data, self.dtype)
self.assertIs(y, x_data)
self.assertIsInstance(y, chainer.Variable)
self.assertIs(y.data, x_data)

def test_forward_no_cast_array(self):
self.check_forward_no_cast(self.x)
y = functions.cast(self.x, self.dtype)
self.assertIsInstance(y, chainer.Variable)
self.assertIs(y.data, self.x)

def test_forward_no_cast_variable(self):
self.check_forward_no_cast(chainer.Variable(self.x))
x = chainer.Variable(self.x)
y = functions.cast(x, self.dtype)
self.assertIs(y, x)


testing.run_module(__name__, __file__)

0 comments on commit 2abe136

Please sign in to comment.