Merge a6b5946 into 289428e
unnonouno committed Aug 24, 2017
2 parents 289428e + a6b5946 commit fb80fbd
Showing 4 changed files with 111 additions and 19 deletions.
2 changes: 1 addition & 1 deletion chainer/function_node.py
@@ -550,7 +550,7 @@ def get_retained_outputs(self):
if outputs_modified:
self.outputs = tuple(new_outputs)

return ret
return tuple(ret)

def unchain(self):
"""Purges in/out nodes and this function node itself from the graph."""
54 changes: 43 additions & 11 deletions chainer/functions/activation/tanh.py
@@ -3,6 +3,7 @@
import chainer
from chainer import cuda
from chainer import function
from chainer import function_node
from chainer import utils
from chainer.utils import type_check

@@ -12,7 +13,7 @@
_mode = libcudnn.CUDNN_ACTIVATION_TANH


class Tanh(function.Function):
class Tanh(function_node.FunctionNode):

"""Hyperbolic tangent function."""

@@ -24,37 +25,68 @@ def forward_cpu(self, x):
y = utils.force_array(numpy.tanh(x[0]))
self.retain_inputs(())
self.retain_outputs((0,))
self._use_cudnn = False
return y,

def forward_gpu(self, x):
if chainer.should_use_cudnn('==always') and x[0].flags.c_contiguous:
y = cudnn.activation_forward(x[0], _mode)
self.retain_inputs((0,))
self._use_cudnn = True
else:
y = cuda.cupy.empty_like(x[0])
cuda.cupy.tanh(x[0], out=y)
self.retain_inputs(())
self._use_cudnn = False

self.retain_outputs((0,))
return y,

def backward_cpu(self, x, gy):
y = self.output_data[0]
def backward(self, indexes, grad_outputs):
if self._use_cudnn:
x = self.get_retained_inputs()[0].data
else:
x = None
y = self.get_retained_outputs()[0]
gy = grad_outputs[0]
return TanhGrad(x).apply((y, gy))


class TanhGrad(function_node.FunctionNode):

def __init__(self, x):
super(TanhGrad, self).__init__()
# The original input `x` is only required for cuDNN.
# If it is None, this class does not use cuDNN.
self.x = x

def forward_cpu(self, inputs):
self.retain_inputs((0, 1))
y, gy = inputs
one = y.dtype.type(1)
return utils.force_array(gy[0] * (one - y * y)),
return utils.force_array(gy * (one - y * y)),

def backward_gpu(self, x, gy):
y = self.output_data[0]
def forward_gpu(self, inputs):
self.retain_inputs((0, 1))
y, gy = inputs
if (chainer.should_use_cudnn('==always') and
x[0] is not None and x[0].flags.c_contiguous and
gy[0].flags.c_contiguous):
gx = cudnn.activation_backward(x[0], y, gy[0], _mode)
self.x is not None and self.x.flags.c_contiguous and
gy.flags.c_contiguous):
gx = cudnn.activation_backward(self.x, y, gy, _mode)
else:
gx = cuda.elementwise(
'T y, T gy', 'T gx',
'gx = gy * (1 - y * y)',
'tanh_bwd')(y, gy[0])
'tanh_bwd')(y, gy)
return gx,

def backward(self, indexes, grad_outputs):
y, gy = self.get_retained_inputs()
g = grad_outputs[0]
grad_y = -2 * g * gy * y
ggy = g * (1 - y * y)
return grad_y, ggy


def tanh(x):
"""Elementwise hyperbolic tangent function.
@@ -79,4 +111,4 @@ def tanh(x):
array([-0.76159418, 0.76159418, 0.99505478], dtype=float32)
"""
return Tanh()(x)
return Tanh().apply((x,))[0]
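
For reference, the second-order formulas returned by TanhGrad.backward above follow from differentiating its forward output gx = gy * (1 - y * y) with respect to y and gy. Below is a minimal NumPy sanity check of those formulas; the shapes, names, and tolerances are illustrative and are not part of this commit.

import numpy

rng = numpy.random.RandomState(0)
y = numpy.tanh(rng.uniform(-0.5, 0.5, (3, 2)))  # y = tanh(x)
gy = rng.uniform(-1, 1, (3, 2))                 # upstream gradient of tanh
g = rng.uniform(-1, 1, (3, 2))                  # gradient flowing into gx

# Formulas from TanhGrad.backward: gx = gy * (1 - y^2), so
grad_y = -2 * g * gy * y        # g contracted with d gx / d y
ggy = g * (1 - y * y)           # g contracted with d gx / d gy

# Central-difference check of grad_y (tanh's backward is elementwise,
# so its Jacobian with respect to y is diagonal).
eps = 1e-6
numeric = g * (gy * (1 - (y + eps) ** 2) - gy * (1 - (y - eps) ** 2)) / (2 * eps)
assert numpy.allclose(grad_y, numeric, atol=1e-4)
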
52 changes: 50 additions & 2 deletions tests/chainer_tests/functions_tests/activation_tests/test_tanh.py
@@ -6,6 +6,7 @@
import chainer
from chainer import cuda
from chainer import functions
from chainer.functions.activation import tanh
from chainer import gradient_check
from chainer import testing
from chainer.testing import attr
@@ -20,7 +21,8 @@ class TestTanh(unittest.TestCase):

def setUp(self):
self.x = numpy.random.uniform(-.5, .5, self.shape).astype(self.dtype)
self.gy = numpy.random.uniform(-.1, .1, self.shape).astype(self.dtype)
self.gy = numpy.random.uniform(-.5, .5, self.shape).astype(self.dtype)
self.ggx = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype)
self.check_backward_options = {}
if self.dtype == numpy.float16:
self.check_backward_options = {
@@ -53,7 +55,7 @@ def test_forward_gpu_no_cudnn(self):
def check_backward(self, x_data, gy_data, use_cudnn='always'):
with chainer.using_config('use_cudnn', use_cudnn):
gradient_check.check_backward(
functions.Tanh(), x_data, gy_data,
functions.tanh, x_data, gy_data,
**self.check_backward_options)

@condition.retry(3)
@@ -76,6 +78,20 @@ def test_backward_gpu_non_contiguous(self):
def test_backward_gpu_no_cudnn(self):
self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy), 'never')

def check_double_backward(self, x_data, gy_data, ggx_data):
gradient_check.check_double_backward(
chainer.functions.tanh, x_data, gy_data, ggx_data, dtype='d')

@condition.retry(3)
def test_double_backward_cpu(self):
self.check_double_backward(self.x, self.gy, self.ggx)

@attr.gpu
@condition.retry(3)
def test_double_backward_gpu(self):
self.check_double_backward(
cuda.to_gpu(self.x), cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx))


@testing.parameterize(*testing.product({
'use_cudnn': ['always', 'auto', 'never'],
@@ -113,4 +129,36 @@ def test_call_cudnn_backward(self):
self.assertEqual(func.called, self.expect)


@testing.parameterize(*testing.product({
'shape': [(3, 2), ()],
'dtype': [numpy.float16, numpy.float32, numpy.float64],
}))
class TestTanhGrad(unittest.TestCase):

def setUp(self):
self.x = numpy.random.uniform(-.5, .5, self.shape).astype(self.dtype)
self.gy = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype)
self.ggx = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype)

def check_backward(self, x_data, y_data, gy_data, ggx_data):
def f(y, gy):
return tanh.TanhGrad(x_data).apply((y, gy))[0]

gradient_check.check_backward(
f, (y_data, gy_data), ggx_data, dtype='d', atol=1e-4, rtol=1e-4)

@condition.retry(3)
def test_backward_cpu(self):
y = numpy.array(numpy.tanh(self.x))
self.check_backward(self.x, y, self.gy, self.ggx)

@attr.gpu
@condition.retry(3)
    def test_backward_gpu(self):
y = numpy.array(numpy.tanh(self.x))
self.check_backward(
cuda.to_gpu(self.x), cuda.to_gpu(y), cuda.to_gpu(self.gy),
cuda.to_gpu(self.ggx))


testing.run_module(__name__, __file__)
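
With Tanh implemented as a FunctionNode, the gradient of tanh is itself differentiable, which is what the double-backward tests above exercise. A minimal usage sketch, assuming a Chainer version with FunctionNode-based double backprop (the v3 series this commit targets); the names and values below are illustrative.

import numpy
import chainer
import chainer.functions as F

x = chainer.Variable(numpy.linspace(-1, 1, 5).astype(numpy.float32))
y = F.tanh(x)
y.grad = numpy.ones_like(y.data)
y.backward(enable_double_backprop=True)  # keep the graph of the backward pass

gx = x.grad_var       # dy/dx as a Variable still connected to the graph
x.cleargrad()         # clear the first-order gradient before backpropagating again
F.sum(gx).backward()  # second-order backward runs through TanhGrad
# x.grad now holds d(sum(dy/dx))/dx, i.e. -2 * tanh(x) * (1 - tanh(x)**2)
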
@@ -7,13 +7,25 @@

import chainer
from chainer import configuration
from chainer import functions
from chainer import function_node
from chainer import links
from chainer import testing
from chainer import training
from chainer.training.extensions import computational_graph as c


class Function1(function_node.FunctionNode):

def forward(self, inputs):
return inputs[0],


class Function2(function_node.FunctionNode):

def forward(self, inputs):
return inputs[0],


class Dataset(chainer.dataset.DatasetMixin):
def __init__(self, values):
self.values = values
@@ -39,9 +51,9 @@ def __call__(self, x):

h = self.l1(x)
if self.i == 0:
h = functions.Sigmoid()(h)
h, = Function1().apply((h,))
else:
h = functions.Tanh()(h)
h, = Function2().apply((h,))
h = self.l2(h)

self.i += 1
@@ -83,8 +95,8 @@ def _run_test(self, tempdir, initial_flag):
graph_dot = f.read()

# Check that only the first iteration is dumped
self.assertIn('Sigmoid', graph_dot)
self.assertNotIn('Tanh', graph_dot)
self.assertIn('Function1', graph_dot)
self.assertNotIn('Function2', graph_dot)

def _check(self, initial_flag):
tempdir = tempfile.mkdtemp()
