Merge pull request #4347 from toslunar/elu-bwd
Improve ELU.backward
okuta committed Feb 16, 2018
2 parents c3765bd + 269eefc commit 036ce47
Showing 1 changed file with 14 additions and 23 deletions.
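In brief (added context, not part of the original commit message): the patch makes two independent improvements. The forward pass now computes alpha * expm1(x) instead of alpha * (exp(x) - 1), which avoids cancellation for inputs near zero, and ELUGrad is reduced to a single-input function that returns the local derivative ELU'(x), with ELU.backward multiplying in the upstream gradient gy itself. For reference, ELU(x) = x for x >= 0 and alpha * (exp(x) - 1) otherwise, so ELU'(x) = 1 for x >= 0 and alpha * exp(x) otherwise.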
37 changes: 14 additions & 23 deletions chainer/functions/activation/elu.py
@@ -22,22 +22,22 @@ def forward_cpu(self, x):
         self.retain_inputs((0,))
         y = x[0].copy()
         neg_indices = x[0] < 0
-        y[neg_indices] = self.alpha * (numpy.exp(y[neg_indices]) - 1)
+        y[neg_indices] = self.alpha * (numpy.expm1(y[neg_indices]))
         return y,

     def forward_gpu(self, x):
         self.retain_inputs((0,))
         y = cuda.elementwise(
             'T x, T alpha', 'T y',
-            'y = x >= 0 ? x : (T)(alpha * (exp(x) - 1))',
+            'y = x >= 0 ? x : (T)(alpha * expm1(x))',
             'elu_fwd')(
             x[0], self.alpha)
         return y,

     def backward(self, indexes, grad_outputs):
         x, = self.get_retained_inputs()
         gy, = grad_outputs
-        return ELUGrad(self.alpha).apply((x, gy))
+        return ELUGrad(self.alpha).apply((x,))[0] * gy,


 class ELUGrad(function_node.FunctionNode):
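Why expm1 (an added note, not from the commit itself): for |x| near zero, exp(x) evaluates to a float extremely close to 1, and subtracting 1 cancels most of the significant digits; expm1(x) computes exp(x) - 1 directly without that cancellation. A minimal NumPy sketch:

    import numpy

    x = 1e-12
    print(numpy.exp(x) - 1.0)  # ~1.000089e-12: only about 4 digits survive the cancellation
    print(numpy.expm1(x))      # ~1.000000e-12: accurate to full double precision

The CUDA kernel gets the same fix, since expm1 is also available as a device math function.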
@@ -48,43 +48,34 @@ def __init__(self, alpha):
         self.alpha = alpha

     def check_type_forward(self, in_types):
-        type_check.expect(in_types.size() == 2)
+        type_check.expect(in_types.size() == 1)
         type_check.expect(in_types[0].dtype.kind == 'f')
-        type_check.expect(in_types[1].dtype.kind == 'f')

     def forward_cpu(self, inputs):
-        x, gy = inputs
-        gx = gy.copy()
+        x, = inputs
+        gx = numpy.ones_like(x)
         neg_indices = x < 0
         gx[neg_indices] *= self.alpha * numpy.exp(x[neg_indices])
-        self.retain_inputs((0, 1))
+        self.retain_inputs((0,))
         self.retain_outputs((0,))
         return gx,

     def forward_gpu(self, inputs):
-        x, gy = inputs
+        x, = inputs
         gx = cuda.elementwise(
-            'T x, T gy, T alpha', 'T gx',
-            'gx = x >= 0 ? gy : (T)(gy * alpha * exp(x))',
+            'T x, T alpha', 'T gx',
+            'gx = x >= 0 ? (T)1 : (T)(alpha * exp(x))',
             'elu_bwd')(
-            x, gy, self.alpha)
+            x, self.alpha)
-        self.retain_inputs((0, 1))
+        self.retain_inputs((0,))
         self.retain_outputs((0,))
         return gx,

     def backward(self, indexes, grad_outputs):
-        x, gy = self.get_retained_inputs()
+        x, = self.get_retained_inputs()
         gx, = self.get_retained_outputs()
         ggx, = grad_outputs
-        ggxgx = ggx * gx
-
-        ret = []
-        if 0 in indexes:
-            ret.append(ggxgx * (x.data < 0))
-        if 1 in indexes:
-            ret.append(ggxgx / gy)
-
-        return ret
+        return ggx * gx * (x.data < 0),


 def elu(x, alpha=1.0):
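How the reworked backward fits together (an added sketch, not text from the commit): ELUGrad.forward now computes only the local derivative gx = ELU'(x), and ELU.backward multiplies it by gy. For x < 0, ELU''(x) = alpha * exp(x), which is exactly gx again, so the second-order pass can reuse the retained output as ggx * gx * (x.data < 0); the old code instead had to divide by gy (ggxgx / gy), which is ill-defined wherever gy is zero. Below, a plain-NumPy check of the derivative that ELUGrad now returns, against central finite differences (the helper names are illustrative, not Chainer API):

    import numpy

    def elu(x, alpha=1.0):
        # forward pass, using the numerically stable expm1 form
        return numpy.where(x >= 0, x, alpha * numpy.expm1(x))

    def elu_grad(x, alpha=1.0):
        # what ELUGrad.forward now returns: d ELU / dx, with no gy factor
        return numpy.where(x >= 0, 1.0, alpha * numpy.exp(x))

    x = numpy.linspace(-3.0, 3.0, 13)
    eps = 1e-6
    numeric = (elu(x + eps) - elu(x - eps)) / (2.0 * eps)
    assert numpy.allclose(elu_grad(x), numeric, atol=1e-5)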
