Skip to content

Commit

Permalink
Merge pull request #3191 from beam2d/no-cast
Browse files Browse the repository at this point in the history
Make F.cast return the input as is if no cast is needed
  • Loading branch information
unnonouno committed Aug 28, 2017
2 parents c796529 + 2abe136 commit ea6ab85
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
3 changes: 3 additions & 0 deletions chainer/functions/array/cast.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import chainer
from chainer import function_node
from chainer.utils import type_check

Expand Down Expand Up @@ -50,4 +51,6 @@ def cast(x, typ):
dtype('float16')
"""
if x.dtype == typ:
return chainer.as_variable(x)
return Cast(typ).apply((x,))[0]
22 changes: 22 additions & 0 deletions tests/chainer_tests/functions_tests/array_tests/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,26 @@ def test_backward_cpu(self):
self.check_backward(self.x, self.g)


class TestNoCast(unittest.TestCase):

def setUp(self):
self.dtype = numpy.float32
self.x = numpy.empty(1, self.dtype)

def check_forward_no_cast(self, x_data):
y = functions.cast(x_data, self.dtype)
self.assertIsInstance(y, chainer.Variable)
self.assertIs(y.data, x_data)

def test_forward_no_cast_array(self):
y = functions.cast(self.x, self.dtype)
self.assertIsInstance(y, chainer.Variable)
self.assertIs(y.data, self.x)

def test_forward_no_cast_variable(self):
x = chainer.Variable(self.x)
y = functions.cast(x, self.dtype)
self.assertIs(y, x)


testing.run_module(__name__, __file__)

0 comments on commit ea6ab85

Please sign in to comment.