diff --git a/chainer/functions/math/sqrt.py b/chainer/functions/math/sqrt.py index fd2531c130bc..36db23bf89b2 100644 --- a/chainer/functions/math/sqrt.py +++ b/chainer/functions/math/sqrt.py @@ -29,7 +29,7 @@ def backward(self, indexes, grad_outputs): return gy / (gx * 2.0), -class Rsqrt(function_node.FunctionNode): +class RsqrtGPU(function_node.FunctionNode): @property def label(self): @@ -41,16 +41,10 @@ def check_type_forward(self, in_types): in_types[0].dtype.kind == 'f', ) - def forward(self, inputs): + def forward_gpu(self, inputs): self.retain_inputs((0,)) x, = inputs - xp = cuda.get_array_module(x) - dtype = x.dtype - if xp is numpy: - out = xp.reciprocal(xp.sqrt(x, dtype=dtype), dtype=dtype) - else: - # CuPy provides `rsqrt` which is faster than `1.0 / sqrt(x)`. - out = cuda.cupyx.rsqrt(x, dtype=dtype) + out = cuda.cupyx.rsqrt(x, dtype=x.dtype) return utils.force_array(out), def backward(self, indexes, grad_outputs): @@ -91,4 +85,9 @@ def rsqrt(x): .. seealso:: :func:`~chainer.functions.sqrt` """ - return Rsqrt().apply((x,))[0] + xp = cuda.get_array_module(x) + if xp is numpy: + return 1.0 / sqrt(x) + + # CuPy provides `rsqrt` which is faster than `1.0 / sqrt(x)`. + return RsqrtGPU().apply((x,))[0]