
Commit

Merge commit '5d4673bcc3589ee07d0185571629675581740a31' into fix-split_axis
ken-nakanishi committed Jan 27, 2018
2 parents fdc0bec + 5d4673b commit f6b5163
Showing 102 changed files with 2,954 additions and 1,517 deletions.
2 changes: 1 addition & 1 deletion chainer/_version.py
@@ -1 +1 @@
__version__ = '4.0.0b2'
__version__ = '4.0.0b3'
2 changes: 1 addition & 1 deletion chainer/dataset/convert.py
@@ -156,7 +156,7 @@ class ConcatWithAsyncTransfer(object):
from chainer.dataset import convert
...
updater = chainer.training.StandardUpdater(
updater = chainer.training.updaters.StandardUpdater(
...,
converter=convert.ConcatWithAsyncTransfer(),
...)
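The hunk above only updates the documented path to `chainer.training.updaters.StandardUpdater`. For context, a minimal sketch of how the converter documented here might be wired into a trainer; the synthetic dataset, model, optimizer, and GPU device 0 are assumptions for illustration, and `ConcatWithAsyncTransfer` needs a CUDA device to be useful:

import numpy as np
import chainer
import chainer.links as L
from chainer.dataset import convert

# Synthetic (feature, label) pairs; values are placeholders.
train = [(np.random.rand(4).astype(np.float32), np.int32(i % 2))
         for i in range(16)]
train_iter = chainer.iterators.SerialIterator(train, batch_size=4)

model = L.Classifier(L.Linear(4, 2))
model.to_gpu(0)  # ConcatWithAsyncTransfer overlaps host-to-GPU copies
optimizer = chainer.optimizers.SGD()
optimizer.setup(model)

updater = chainer.training.updaters.StandardUpdater(
    train_iter, optimizer,
    converter=convert.ConcatWithAsyncTransfer(),
    device=0)
trainer = chainer.training.Trainer(updater, (1, 'epoch'))
trainer.run()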
2 changes: 1 addition & 1 deletion chainer/function.py
@@ -59,7 +59,7 @@ def force_backprop_mode():
... y = x + 1
>>> y.backward()
>>> x.grad
array([ 1.], dtype=float32)
array([1.], dtype=float32)
.. seealso::
18 changes: 15 additions & 3 deletions chainer/function_node.py
@@ -652,7 +652,7 @@ def delete_hook(self, name):


def grad(outputs, inputs, grad_outputs=None, grad_inputs=None, set_grad=False,
retain_grad=False, enable_double_backprop=False):
retain_grad=False, enable_double_backprop=False, loss_scale=None):
"""Computes the gradient of output variables w.r.t.\\ the input variables.
This function implements the backpropagation algorithm. While
@@ -701,6 +701,14 @@ def grad(outputs, inputs, grad_outputs=None, grad_inputs=None, set_grad=False,
the memory consumption (and possibly the computational time) to
remember the intermediate gradient values for the second
backpropagation.
loss_scale (float): Loss scaling factor. Loss scaling is a useful
technique to mitigate the vanishing gradient issue that tends to
happen when a low-precision data type such as float16 is used
during training. If a loss scaling factor is set, the gradients of
the loss values are multiplied by the factor before backprop starts.
The factor is propagated to all gradients in the computational
graph along the backprop. The gradients of the parameters are
divided by the factor just before the parameters are updated.
Returns:
A list of gradient variables w.r.t. the inputs.
@@ -777,6 +785,8 @@ def grad(outputs, inputs, grad_outputs=None, grad_inputs=None, set_grad=False,
else:
gy_data = cuda.cupy.ones_like(y.data)
gy = variable.Variable(gy_data, requires_grad=False)
if loss_scale is not None:
gy.data *= loss_scale
grads[y.node] = gy

if grad_inputs is not None:
@@ -787,7 +797,8 @@ def grad(outputs, inputs, grad_outputs=None, grad_inputs=None, set_grad=False,
# Backprop implementation. It edits grads which will only contain the
# gradients w.r.t. the inputs.
with chainer.using_config('enable_backprop', enable_double_backprop):
_backprop(outputs, inputs, grad_required, retain_grad, grads)
_backprop(outputs, inputs, grad_required, retain_grad, grads,
loss_scale)

# Extract the gradients w.r.t. the inputs and return them.
ret = [grads.get(x.node, None) for x in inputs]
@@ -798,7 +809,7 @@ def grad(outputs, inputs, grad_outputs=None, grad_inputs=None, set_grad=False,
return ret


def _backprop(outputs, inputs, grad_required, retain_grad, grads):
def _backprop(outputs, inputs, grad_required, retain_grad, grads, loss_scale):
candidate_funcs, push_candidate, pop_candidate = _get_ordered_func_heap()

for y in outputs:
@@ -875,6 +886,7 @@ def _backprop(outputs, inputs, grad_required, retain_grad, grads):
v = node.get_variable_or_none()
if v is not None:
v.grad_var = g
v._loss_scale = loss_scale

creator = node.creator_node
if creator is not None:
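To illustrate the added `loss_scale` argument, a minimal sketch of calling the updated `chainer.grad` with toy values; the manual division at the end reflects that `grad` itself leaves the scaling in place (only the `gy.data *= loss_scale` multiplication appears in the hunks above), while parameter updates in a trainer would divide it out as the docstring describes:

import numpy as np
import chainer
import chainer.functions as F

# Toy leaf variables; float16 is where loss scaling matters most, but the
# mechanism itself is dtype-agnostic.
x = chainer.Variable(np.array([1., 2., 3.], dtype=np.float32))
w = chainer.Variable(np.array([0.5, 0.25, 0.125], dtype=np.float32))
loss = F.sum(x * w)

# loss_scale=128 multiplies the initial gradient by 128 before backprop
# starts, then the factor is propagated through the graph.
gw, = chainer.grad([loss], [w], loss_scale=128.0)

# chainer.grad does not undo the scaling, so divide manually to recover
# dloss/dw, which equals x here.
print(gw.data / 128.0)  # -> [1. 2. 3.]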
1 change: 1 addition & 0 deletions chainer/functions/__init__.py
@@ -124,6 +124,7 @@
from chainer.functions.connection.n_step_rnn import NStepBiRNNTanh # NOQA
from chainer.functions.connection.n_step_rnn import NStepRNNReLU # NOQA
from chainer.functions.connection.n_step_rnn import NStepRNNTanh # NOQA
from chainer.functions.connection.shift import shift # NOQA

from chainer.functions.evaluation.accuracy import accuracy # NOQA
from chainer.functions.evaluation.accuracy import Accuracy # NOQA
4 changes: 2 additions & 2 deletions chainer/functions/activation/crelu.py
@@ -76,8 +76,8 @@ def crelu(x, axis=1):
[ 2., -3.]], dtype=float32)
>>> y = F.crelu(x, axis=1)
>>> y.data
array([[ 0., 0., 1., 0.],
[ 2., 0., 0., 3.]], dtype=float32)
array([[0., 0., 1., 0.],
[2., 0., 0., 3.]], dtype=float32)
"""
return CReLU(axis=axis).apply((x,))[0]
2 changes: 1 addition & 1 deletion chainer/functions/activation/hard_sigmoid.py
@@ -97,7 +97,7 @@ def hard_sigmoid(x):
>>> x
array([-2.6, -1. , 0. , 1. , 2.6])
>>> F.hard_sigmoid(x).data
array([ 0. , 0.3, 0.5, 0.7, 1. ])
array([0. , 0.3, 0.5, 0.7, 1. ])
"""
return HardSigmoid().apply((x,))[0]
6 changes: 3 additions & 3 deletions chainer/functions/activation/leaky_relu.py
@@ -116,9 +116,9 @@ def leaky_relu(x, slope=0.2):
[ 2., -3.],
[-2., 1.]], dtype=float32)
>>> F.leaky_relu(x, slope=0.2).data
array([[-0.2 , 0. ],
[ 2. , -0.60000002],
[-0.40000001, 1. ]], dtype=float32)
array([[-0.2, 0. ],
[ 2. , -0.6],
[-0.4, 1. ]], dtype=float32)
"""
return LeakyReLU(slope).apply((x,))[0]
8 changes: 4 additions & 4 deletions chainer/functions/activation/log_softmax.py
@@ -156,11 +156,11 @@ def log_softmax(x):
>>> x = np.array([[0, 1, 2], [0, 2, 4]], 'f')
>>> x
array([[ 0., 1., 2.],
[ 0., 2., 4.]], dtype=float32)
array([[0., 1., 2.],
[0., 2., 4.]], dtype=float32)
>>> F.log_softmax(x).data
array([[-2.40760589, -1.40760589, -0.40760589],
[-4.14293146, -2.14293146, -0.14293146]], dtype=float32)
array([[-2.407606 , -1.4076059 , -0.4076059 ],
[-4.1429315 , -2.1429315 , -0.14293146]], dtype=float32)
>>> np.allclose(F.log_softmax(x).data, F.log(F.softmax(x)).data)
True
22 changes: 11 additions & 11 deletions chainer/functions/activation/maxout.py
@@ -48,19 +48,19 @@ def maxout(x, pool_size, axis=1):
>>> y.shape
(1, 10)
>>> x.reshape((out_size, pool_size)).data
array([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
[ 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
[ 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
[ 30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],
[ 40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],
[ 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],
[ 60., 61., 62., 63., 64., 65., 66., 67., 68., 69.],
[ 70., 71., 72., 73., 74., 75., 76., 77., 78., 79.],
[ 80., 81., 82., 83., 84., 85., 86., 87., 88., 89.],
[ 90., 91., 92., 93., 94., 95., 96., 97., 98., 99.]], \
array([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
[20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
[30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],
[40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],
[50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],
[60., 61., 62., 63., 64., 65., 66., 67., 68., 69.],
[70., 71., 72., 73., 74., 75., 76., 77., 78., 79.],
[80., 81., 82., 83., 84., 85., 86., 87., 88., 89.],
[90., 91., 92., 93., 94., 95., 96., 97., 98., 99.]], \
dtype=float32)
>>> y.data
array([[ 9., 19., 29., 39., 49., 59., 69., 79., 89., 99.]], \
array([[ 9., 19., 29., 39., 49., 59., 69., 79., 89., 99.]], \
dtype=float32)
"""
2 changes: 1 addition & 1 deletion chainer/functions/activation/sigmoid.py
@@ -113,7 +113,7 @@ def sigmoid(x):
>>> x
array([-2., 0., 2.], dtype=float32)
>>> F.sigmoid(x)
variable([ 0.11920291, 0.5 , 0.88079709])
variable([0.11920291, 0.5 , 0.8807971 ])
"""
y, = Sigmoid().apply((x,))
10 changes: 5 additions & 5 deletions chainer/functions/activation/softmax.py
@@ -138,14 +138,14 @@ def softmax(x, axis=1):
>>> x = np.array([[0, 1, 2], [0, 2, 4]], 'f')
>>> x
array([[ 0., 1., 2.],
[ 0., 2., 4.]], dtype=float32)
array([[0., 1., 2.],
[0., 2., 4.]], dtype=float32)
>>> y = F.softmax(x, axis=1)
>>> y.data
array([[ 0.09003057, 0.24472848, 0.66524094],
[ 0.01587624, 0.11731043, 0.86681336]], dtype=float32)
array([[0.09003057, 0.24472848, 0.66524094],
[0.01587624, 0.11731043, 0.86681336]], dtype=float32)
>>> F.sum(y, axis=1).data
array([ 1., 1.], dtype=float32)
array([1., 1.], dtype=float32)
"""
return Softmax(axis=axis).apply((x,))[0]
2 changes: 1 addition & 1 deletion chainer/functions/activation/softplus.py
@@ -113,7 +113,7 @@ def softplus(x, beta=1.0):
>>> x
array([-2., 0., 2.], dtype=float32)
>>> F.softplus(x, beta=1.0).data
array([ 0.126928 , 0.69314718, 2.12692809], dtype=float32)
array([0.126928 , 0.6931472, 2.126928 ], dtype=float32)
"""
y, = Softplus(beta=beta).apply((x,))
2 changes: 1 addition & 1 deletion chainer/functions/activation/tanh.py
@@ -108,7 +108,7 @@ def tanh(x):
>>> x
array([-1., 1., 3.], dtype=float32)
>>> F.tanh(x).data
array([-0.76159418, 0.76159418, 0.99505478], dtype=float32)
array([-0.7615942, 0.7615942, 0.9950548], dtype=float32)
"""
return Tanh().apply((x,))[0]
24 changes: 12 additions & 12 deletions chainer/functions/array/depth2space.py
@@ -71,25 +71,25 @@ def depth2space(X, r):
>>> X.shape
(1, 4, 2, 3)
>>> X
array([[[[ 0., 1., 2.],
[ 3., 4., 5.]],
array([[[[ 0., 1., 2.],
[ 3., 4., 5.]],
<BLANKLINE>
[[ 6., 7., 8.],
[ 9., 10., 11.]],
[[ 6., 7., 8.],
[ 9., 10., 11.]],
<BLANKLINE>
[[ 12., 13., 14.],
[ 15., 16., 17.]],
[[12., 13., 14.],
[15., 16., 17.]],
<BLANKLINE>
[[ 18., 19., 20.],
[ 21., 22., 23.]]]], dtype=float32)
[[18., 19., 20.],
[21., 22., 23.]]]], dtype=float32)
>>> y = F.depth2space(X, 2)
>>> y.shape
(1, 1, 4, 6)
>>> y.data
array([[[[ 0., 6., 1., 7., 2., 8.],
[ 12., 18., 13., 19., 14., 20.],
[ 3., 9., 4., 10., 5., 11.],
[ 15., 21., 16., 22., 17., 23.]]]], dtype=float32)
array([[[[ 0., 6., 1., 7., 2., 8.],
[12., 18., 13., 19., 14., 20.],
[ 3., 9., 4., 10., 5., 11.],
[15., 21., 16., 22., 17., 23.]]]], dtype=float32)
"""
return Depth2Space(r).apply((X,))[0]
24 changes: 12 additions & 12 deletions chainer/functions/array/permutate.py
@@ -109,26 +109,26 @@ def permutate(x, indices, axis=0, inv=False):
>>> x = np.arange(6).reshape((3, 2)).astype('f')
>>> x
array([[ 0., 1.],
[ 2., 3.],
[ 4., 5.]], dtype=float32)
array([[0., 1.],
[2., 3.],
[4., 5.]], dtype=float32)
>>> indices = np.array([2, 0, 1], 'i')
>>> y = F.permutate(x, indices)
>>> y.data
array([[ 4., 5.],
[ 0., 1.],
[ 2., 3.]], dtype=float32)
array([[4., 5.],
[0., 1.],
[2., 3.]], dtype=float32)
>>> y = F.permutate(x, indices, inv=True)
>>> y.data
array([[ 2., 3.],
[ 4., 5.],
[ 0., 1.]], dtype=float32)
array([[2., 3.],
[4., 5.],
[0., 1.]], dtype=float32)
>>> indices = np.array([1, 0], 'i')
>>> y = F.permutate(x, indices, axis=1)
>>> y.data
array([[ 1., 0.],
[ 3., 2.],
[ 5., 4.]], dtype=float32)
array([[1., 0.],
[3., 2.],
[5., 4.]], dtype=float32)
"""
y, = Permutate(axis, inv).apply((x, indices))
2 changes: 1 addition & 1 deletion chainer/functions/array/select_item.py
@@ -110,7 +110,7 @@ def select_item(x, t):
>>> y.shape
(2,)
>>> y.data
array([ 0., 5.], dtype=float32)
array([0., 5.], dtype=float32)
"""
return SelectItem().apply((x, t))[0]
50 changes: 40 additions & 10 deletions chainer/functions/array/separate.py
@@ -1,5 +1,38 @@
from chainer.functions.array import reshape
from chainer.functions.array import split_axis
from chainer import cuda
from chainer import function_node
from chainer.functions.array import stack
from chainer.utils import type_check


class Separate(function_node.FunctionNode):

"""Function that separates a given array."""

def __init__(self, axis):
self.axis = axis

def check_type_forward(self, in_types):
type_check.expect(in_types.size() == 1)
x_type = in_types[0]
if self.axis >= 0:
type_check.expect(self.axis < x_type.ndim)
else:
type_check.expect(-self.axis <= x_type.ndim)

def forward(self, inputs):
x, = inputs
self._xp = cuda.get_array_module(x)
xs = self._xp.split(x, x.shape[self.axis], self.axis)
ys = [self._xp.squeeze(y, self.axis) for y in xs]
self._shape = ys[0].shape
self._dtype = x.dtype
return tuple(ys)

def backward(self, indexes, grad_outputs):
grad_outputs = [
self._xp.zeros(self._shape, dtype=self._dtype)
if g is None else g for g in grad_outputs]
return stack.stack(grad_outputs, self.axis),


def separate(x, axis=0):
@@ -27,8 +60,8 @@ def separate(x, axis=0):
>>> x = np.arange(6).reshape((2, 3)).astype('f')
>>> x
array([[ 0., 1., 2.],
[ 3., 4., 5.]], dtype=float32)
array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32)
>>> x.shape
(2, 3)
>>> y = F.separate(x) # split along axis=0
@@ -39,17 +72,14 @@ def separate(x, axis=0):
>>> y[0].shape
(3,)
>>> y[0].data
array([ 0., 1., 2.], dtype=float32)
array([0., 1., 2.], dtype=float32)
>>> y = F.separate(x, axis=1)
>>> len(y)
3
>>> y[0].shape
(2,)
>>> y[0].data
array([ 0., 3.], dtype=float32)
array([0., 3.], dtype=float32)
"""
shape = list(x.shape)
del shape[axis]
ys = split_axis.split_axis(x, x.shape[axis], axis, force_tuple=True)
return tuple(reshape.reshape(y, shape) for y in ys)
return Separate(axis).apply((x,))
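Since `separate` is now backed by the `Separate` node above, whose `backward` fills missing output gradients with zeros before stacking them, here is a minimal sketch (toy values) of that behaviour when only one separated slice receives a gradient:

import numpy as np
import chainer
import chainer.functions as F

x = chainer.Variable(np.arange(6).reshape(2, 3).astype(np.float32))
y0, y1 = F.separate(x, axis=0)  # two Variables of shape (3,)

# Backprop through only one of the separated outputs: the gradient for
# the untouched slice is treated as zeros (see Separate.backward above)
# before the slices are stacked back into x.grad.
F.sum(y0).backward()
print(x.grad)
# [[1. 1. 1.]
#  [0. 0. 0.]]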
