Commit

Merge 5ee9fd5 into 1d00f3a
delta2323 committed Aug 23, 2017
2 parents 1d00f3a + 5ee9fd5 commit e849cfd
Showing 2 changed files with 118 additions and 19 deletions.
96 changes: 78 additions & 18 deletions chainer/functions/activation/relu.py
@@ -2,7 +2,7 @@

import chainer
from chainer import cuda
from chainer import function
from chainer import function_node
from chainer import utils
from chainer.utils import type_check

@@ -12,10 +12,11 @@
_mode = cudnn.cudnn.CUDNN_ACTIVATION_RELU


class ReLU(function.Function):
class ReLU(function_node.FunctionNode):

"""Rectified Linear Unit."""
# TODO(beam2d): Implement in-place version.

_use_cudnn = False

def check_type_forward(self, in_types):
type_check.expect(
@@ -24,36 +25,94 @@ def check_type_forward(self, in_types):
)

def forward_cpu(self, x):
self.retain_inputs(())
self.retain_outputs((0,))
return utils.force_array(numpy.maximum(x[0], 0, dtype=x[0].dtype)),

def forward_gpu(self, x):
if chainer.should_use_cudnn('==always') and x[0].flags.c_contiguous:
# cudnn.activation_backward requires the input,
# so we retain it for the backward computation.
self.retain_inputs((0,))
self._use_cudnn = True
y = cudnn.activation_forward(x[0], _mode)
else:
self.retain_inputs(())
self._use_cudnn = False
y = cuda.cupy.maximum(x[0], 0)
self.retain_outputs((0,))
return y,

def backward_cpu(self, x, gy):
y = self.output_data[0]
return utils.force_array(gy[0] * (y > 0)),

def backward_gpu(self, x, gy):
y = self.output_data[0]
def backward(self, indexes, gy):
y = self.get_retained_outputs()[0]
if chainer.should_use_cudnn('==always') and self._use_cudnn:
gx = cudnn.activation_backward(x[0], y, gy[0], _mode)
x = self.get_retained_inputs()[0]
return ReLUGrad3(x, y).apply((gy[0],))
else:
gx = cuda.elementwise(
'T y, T gy', 'T gx',
'gx = y > 0 ? gy : (T)0',
'relu_bwd')(y, gy[0])
return ReLUGrad2(y).apply((gy[0],))


def _heaviside(x):
return (x > 0).astype(x.dtype)


class ReLUGrad2(function_node.FunctionNode):
"""Computes the gradient of the ReLU function.
This function takes 2 variables b and c, and
computes f(b, c) = sign(b) * c with backpropagation
where operations are dones in elementwise manner
and sign(x) = 1 when x > 0 is positive and 0 otherwise.
As the gradient of f with respect to b is 0,
we do not backpropagate errors toward b for computational efficiency.
"""

def __init__(self, b):
super(ReLUGrad2, self).__init__()
self.b = b.data

def forward_cpu(self, inputs):
y = (self.b > 0) * inputs[0]
return utils.force_array(y, dtype=y.dtype),

def forward_gpu(self, inputs):
b = cuda.to_gpu(self.b)
gx = cuda.elementwise(
'T y, T gy', 'T gx',
'gx = y > 0 ? gy : (T)0',
'relu_bwd')(b, inputs[0])
return gx,

def backward(self, indexes, gy):
return gy[0] * _heaviside(self.b),
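
For reference, a minimal NumPy sketch (not part of this commit; array names are illustrative) of what ReLUGrad2 computes: the forward pass masks the incoming gradient with sign(b), and its own backward applies the same mask to the second-order gradient.

import numpy

b = numpy.array([0.0, 0.5, 2.0], dtype=numpy.float32)   # retained ReLU output y
gy = numpy.array([0.1, 0.2, 0.3], dtype=numpy.float32)  # incoming gradient
gx = (b > 0) * gy                    # forward: f(b, c) = sign(b) * c -> [0.0, 0.2, 0.3]
gg = numpy.ones_like(gy)             # gradient arriving at gx in double backward
ggy = gg * (b > 0).astype(b.dtype)   # backward: df/dc = sign(b) -> [0.0, 1.0, 1.0]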


class ReLUGrad3(function_node.FunctionNode):
"""Computes the gradient of the ReLU function.
This function takes 3 variables a, b, and c, and
computes f(a, b, c) = sign(b) * c with backpropagation
where operations are dones in elementwise manner
and sign(x) = 1 if x > 0 is positive and 0 otherwise.
As the gradient of f with respect to a and b are 0,
we do not backpropagate errors toward them for computational efficiency.
"""

def __init__(self, a, b):
self.a = a.data
self.b = b.data

def forward_cpu(self, inputs):
return (self.b > 0) * inputs[0],

def forward_gpu(self, inputs):
a = cuda.to_gpu(self.a)
b = cuda.to_gpu(self.b)
assert chainer.should_use_cudnn('==always')
return cudnn.activation_backward(a, b, inputs[0], _mode),

def backward(self, indexes, gy):
return gy[0] * _heaviside(self.b),
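
A small CPU-side sanity sketch (again not part of the commit, with illustrative array names): ReLUGrad3 produces the same values as ReLUGrad2; the extra argument a (the original input x) is carried only because cudnn.activation_backward needs it on the GPU path.

import numpy

a = numpy.array([-1.0, 0.5, 2.0], dtype=numpy.float32)  # original input x
b = numpy.maximum(a, 0)                                  # ReLU output y
c = numpy.array([0.1, 0.2, 0.3], dtype=numpy.float32)   # incoming gradient gy
grad3 = (b > 0) * c  # ReLUGrad3.forward_cpu ignores a entirely
grad2 = (b > 0) * c  # identical to ReLUGrad2.forward_cpu
assert numpy.array_equal(grad3, grad2)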


def relu(x):
"""Rectified Linear Unit function.
@@ -81,4 +140,5 @@ def relu(x):
(3, 2)
"""
return ReLU()(x)
y, = ReLU().apply((x,))
return y
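
With ReLU implemented as a FunctionNode, gradients of gradients can flow through relu. The following is a minimal usage sketch, not part of this commit, assuming the chainer.grad API with enable_double_backprop as provided around Chainer v3; variable names are illustrative.

import numpy
import chainer
from chainer import functions

x = chainer.Variable(numpy.array([-1.0, 0.5, 2.0], dtype=numpy.float32))
z = functions.sum(functions.relu(x) ** 2)   # scalar loss with nonzero curvature
gz = chainer.Variable(numpy.ones_like(z.data))
gx, = chainer.grad([z], [x], grad_outputs=[gz], enable_double_backprop=True)
# gx.data -> [0., 1., 4.]  (2 * x where x > 0, else 0); gx is still part of the
# graph, so it can be differentiated again through ReLUGrad2 / ReLUGrad3.
gs = functions.sum(gx)
ggx, = chainer.grad([gs], [x], grad_outputs=[chainer.Variable(numpy.ones_like(gs.data))])
# ggx.data -> [0., 2., 2.]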
@@ -25,6 +25,7 @@ def setUp(self):
if -0.1 < self.x[i] < 0.1:
self.x[i] = 0.5
self.gy = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype)
self.ggx = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype)
self.check_backward_options = {}
if self.dtype == numpy.float16:
self.check_backward_options = {'dtype': numpy.float64}
@@ -60,7 +61,7 @@ def test_forward_gpu_no_cudnn(self):
def check_backward(self, x_data, y_grad, use_cudnn='always'):
with chainer.using_config('use_cudnn', use_cudnn):
gradient_check.check_backward(
functions.ReLU(), x_data, y_grad,
functions.relu, x_data, y_grad,
**self.check_backward_options)

@condition.retry(3)
@@ -83,6 +84,44 @@ def test_backward_gpu_non_contiguous(self):
def test_backward_cpu_no_cudnn(self):
self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy), 'never')

def check_double_backward(self, x_data, y_grad, x_grad_grad,
use_cudnn='always'):
def f(x):
x = functions.relu(x)
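# Square the ReLU output so that the second-order gradient is nonzero;
# relu alone is piecewise linear, so its second derivative is zero.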
return x * x

with chainer.using_config('use_cudnn', use_cudnn):
gradient_check.check_double_backward(
f, x_data, y_grad, x_grad_grad,
**self.check_backward_options)

@condition.retry(1)
def test_double_backward_cpu(self):
self.check_double_backward(self.x, self.gy, self.ggx)

@attr.gpu
@condition.retry(1)
def test_double_backward_gpu(self):
self.check_double_backward(cuda.to_gpu(self.x),
cuda.to_gpu(self.gy),
cuda.to_gpu(self.ggx))

@attr.gpu
@condition.retry(3)
def test_double_backward_gpu_non_contiguous(self):
self.check_double_backward(
cuda.cupy.asfortranarray(cuda.to_gpu(self.x)),
cuda.cupy.asfortranarray(cuda.to_gpu(self.gy)),
cuda.cupy.asfortranarray(cuda.to_gpu(self.ggx)))

@attr.gpu
@condition.retry(3)
def test_double_backward_cpu_no_cudnn(self):
self.check_double_backward(cuda.to_gpu(self.x),
cuda.to_gpu(self.gy),
cuda.to_gpu(self.ggx),
'never')


@testing.parameterize(*testing.product({
'use_cudnn': ['always', 'auto', 'never'],
