Skip to content

Commit

Permalink
Implemented new version of cast
Browse files Browse the repository at this point in the history
  • Loading branch information
unnonouno committed Aug 15, 2017
1 parent a1320b9 commit 2c468f9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
10 changes: 5 additions & 5 deletions chainer/functions/array/cast.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from chainer import function
from chainer import function_node
from chainer.utils import type_check


class Cast(function.Function):
class Cast(function_node.FunctionNode):

"""Cast function."""

Expand All @@ -20,8 +20,8 @@ def forward(self, x):
self._in_type = x[0].dtype.type
return x[0].astype(self.type, copy=False),

def backward(self, x, g):
return g[0].astype(self._in_type, copy=False),
def backward(self, indexes, g):
return cast(g[0], self._in_type),


def cast(x, typ):
Expand Down Expand Up @@ -51,4 +51,4 @@ def cast(x, typ):
dtype('float16')
"""
return Cast(typ)(x)
return Cast(typ).apply((x,))[0]
4 changes: 3 additions & 1 deletion tests/chainer_tests/functions_tests/array_tests/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def test_forward_gpu(self):
self.check_forward(cuda.to_gpu(self.x))

def check_backward(self, x_data, g_data):
func = functions.Cast(self.out_type)
def func(x):
return functions.cast(x, self.out_type)

gradient_check.check_backward(
func, x_data, g_data, eps=2.0 ** -2, atol=1e-3, rtol=1e-3)

Expand Down

0 comments on commit 2c468f9

Please sign in to comment.