
Commit

Merge d826a9a into ea6ab85
delta2323 committed Aug 28, 2017
2 parents ea6ab85 + d826a9a · commit 0665ac8
Showing 2 changed files with 53 additions and 12 deletions.
31 changes: 22 additions & 9 deletions chainer/functions/array/get_item.py
@@ -2,13 +2,13 @@

 import chainer
 from chainer import cuda
-from chainer import function
+from chainer import function_node
 from chainer import utils
 from chainer.utils import type_check
 from chainer import variable


-class GetItem(function.Function):
+class GetItem(function_node.FunctionNode):

     """Function that slices array and extract elements."""

@@ -37,22 +37,35 @@ def check_type_forward(self, in_types):
         type_check.expect(in_types[0].ndim >= valid_slice)

     def forward(self, xs):
+        self.retain_inputs(())
         ary = xs[0]
         self._in_shape = ary.shape
         self._in_dtype = ary.dtype
         return utils.force_array(ary[self.slices]),

-    def backward(self, xs, gys):
-        xp = cuda.get_array_module(*gys)
-        gy = gys[0]
+    def backward(self, indexes, gy):
+        return GetItemGrad(
+            self.slices, self._in_shape, self._in_dtype).apply(gy)
+
+
+class GetItemGrad(function_node.FunctionNode):
+
+    def __init__(self, slices, in_shape, in_dtype):
+        self.slices = slices
+        self._in_shape = in_shape
+        self._in_dtype = in_dtype
+
+    def forward(self, inputs):
+        xp = cuda.get_array_module(*inputs)
         gx = xp.zeros(self._in_shape, self._in_dtype)
         if xp is numpy:
-            numpy.add.at(gx, self.slices, gy)
+            numpy.add.at(gx, self.slices, inputs[0])
         else:
-            gx.scatter_add(self.slices, gy)
+            gx.scatter_add(self.slices, inputs[0])
         return gx,

+    def backward(self, indexes, ggx):
+        return GetItem(self.slices).apply(ggx)
+

 def get_item(x, slices):
     """Extract elements from array with specified shape, axes and offsets.
@@ -86,7 +99,7 @@ def get_item(x, slices):
        <http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html>`_.
     """
-    return GetItem(slices)(x)
+    return GetItem(slices).apply((x,))[0]


 def install_variable_get_item():
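Note on the gradient kernel above: GetItemGrad scatters the incoming gradient into a zero-initialized array of the input's shape with numpy.add.at on CPU (or the CuPy array's scatter_add on GPU) rather than plain fancy-index assignment, because advanced indexing may select the same element several times and those contributions must accumulate. A minimal NumPy-only sketch of that difference (illustrative, not part of this commit):

import numpy as np

gx = np.zeros(4)
idx = np.array([0, 0, 2])
gy = np.array([1.0, 1.0, 1.0])

gx[idx] = gy              # fancy assignment overwrites: gx == [1., 0., 1., 0.]

gx[:] = 0
np.add.at(gx, idx, gy)    # unbuffered in-place add accumulates: gx == [2., 0., 1., 0.]
print(gx)                 # [2. 0. 1. 0.]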
34 changes: 31 additions & 3 deletions tests/chainer_tests/functions_tests/array_tests/test_get_item.py
@@ -31,6 +31,7 @@ def setUp(self):
         self.x_data = numpy.random.uniform(-1, 1, (4, 3, 2))
         self.shape = (4, 2, 1)
         self.gy_data = numpy.random.uniform(-1, 1, self.sliced_shape)
+        self.ggx_data = numpy.random.uniform(-1, 1, (4, 3, 2))

         if not hasattr(self, 'slices'):
             # Convert axes, offsets and shape to slices
@@ -64,8 +65,11 @@ def test_forward_gpu(self):
         self.check_forward(cuda.to_gpu(self.x_data))

     def check_backward(self, x_data, y_grad):
+        def f(x):
+            return functions.get_item(x, self.slices)
+
         gradient_check.check_backward(
-            functions.GetItem(self.slices), (x_data,), y_grad, dtype='d')
+            f, (x_data,), y_grad, dtype='d')

     def test_backward_cpu(self):
         self.check_backward(self.x_data, self.gy_data)
@@ -75,6 +79,23 @@ def test_backward_gpu(self):
         self.check_backward(cuda.to_gpu(self.x_data),
                             cuda.to_gpu(self.gy_data))

+    def check_double_backward(self, x_data, y_grad, ggx_data):
+        def f(x):
+            y = functions.get_item(x, self.slices)
+            return y * y
+
+        gradient_check.check_double_backward(
+            f, (x_data,), y_grad, ggx_data, dtype='d')
+
+    def test_double_backward_cpu(self):
+        self.check_double_backward(self.x_data, self.gy_data, self.ggx_data)
+
+    @attr.gpu
+    def test_double_backward_gpu(self):
+        self.check_double_backward(cuda.to_gpu(self.x_data),
+                                   cuda.to_gpu(self.gy_data),
+                                   cuda.to_gpu(self.ggx_data))
+

 @testing.parameterize(*testing.product_dict(
     [
@@ -136,8 +157,11 @@ def test_forward_gpu(self):
         self.check_forward(cuda.to_gpu(self.x_data))

     def check_backward(self, x_data, y_grad):
+        def f(x):
+            return functions.get_item(x, self.slices)
+
         gradient_check.check_backward(
-            functions.GetItem(self.slices), (x_data,), y_grad, dtype='d')
+            f, (x_data,), y_grad, dtype='d')

     def test_backward_cpu(self):
         self.check_backward(self.x_data, self.gy_data)
@@ -192,8 +216,12 @@ def check_backward(self, x_data, y_grad):
                 s = chainer.cuda.cupy.array(s, dtype=numpy.int32)
             slices.append(s)
         slices = tuple(slices)

+        def f(x):
+            return functions.get_item(x, slices)
+
         gradient_check.check_backward(
-            functions.GetItem(slices), (x_data,), y_grad, dtype='d')
+            f, (x_data,), y_grad, dtype='d')

     @attr.gpu
     def test_backward_gpu(self):
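For context on the new double-backward tests: with both GetItem and GetItemGrad implemented as FunctionNode, the backward pass itself builds a differentiable graph, which is exactly what gradient_check.check_double_backward exercises above. A rough end-to-end sketch of what this enables, assuming a Chainer release that provides chainer.grad (this snippet is illustrative and not part of the commit):

import numpy as np
import chainer
import chainer.functions as F

x = chainer.Variable(np.random.uniform(-1, 1, (4, 3, 2)).astype(np.float32))
y = F.get_item(x, (slice(None), slice(0, 2)))   # same as x[:, 0:2]
z = F.sum(y * y)

# First-order gradient; keep the graph so it can be differentiated again.
gx, = chainer.grad([z], [x], enable_double_backprop=True)

# Second-order quantity: gradient of a scalar function of gx with respect to x.
ggx, = chainer.grad([F.sum(gx * gx)], [x])
print(ggx.shape)   # (4, 3, 2)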
