Commit

Merge 37d8ad5 into 5849412
kmaehashi committed Nov 9, 2017
2 parents 5849412 + 37d8ad5 commit de26874
Showing 6 changed files with 305 additions and 22 deletions.
1 change: 1 addition & 0 deletions chainer/functions/__init__.py
@@ -64,6 +64,7 @@
 from chainer.functions.array.pad_sequence import PadSequence  # NOQA
 from chainer.functions.array.permutate import permutate  # NOQA
 from chainer.functions.array.permutate import Permutate  # NOQA
+from chainer.functions.array.repeat import repeat  # NOQA
 from chainer.functions.array.reshape import reshape  # NOQA
 from chainer.functions.array.reshape import Reshape  # NOQA
 from chainer.functions.array.resize_images import resize_images  # NOQA
127 changes: 127 additions & 0 deletions chainer/functions/array/repeat.py
@@ -0,0 +1,127 @@
import six

from chainer import cuda
from chainer import function_node
from chainer.utils import type_check


class Repeat(function_node.FunctionNode):

    """Repeat elements of an array."""

    def __init__(self, repeats, axis=None):
        if isinstance(repeats, six.integer_types):
            self.repeats = (repeats,)
        elif isinstance(repeats, tuple) and all(
                isinstance(x, six.integer_types) for x in repeats):
            self.repeats = repeats
        else:
            raise TypeError('repeats must be int or tuple of ints')

        if not all(x >= 0 for x in self.repeats):
            raise ValueError('all elements in repeats must be zero or larger')

        self.axis = axis

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 1)

    def forward(self, inputs):
        self.retain_inputs((0,))
        x, = inputs
        xp = cuda.get_array_module(x)
        repeats = self.repeats
        if self.axis is None or len(self.repeats) == 1:
            repeats = self.repeats[0]
        return xp.repeat(x, repeats, self.axis),

    def backward(self, indexes, grad_outputs):
        x, = self.get_retained_inputs()
        return RepeatGrad(self.repeats, self.axis, x.shape, x.dtype).apply(
            grad_outputs)


class RepeatGrad(function_node.FunctionNode):

    def __init__(self, repeats, axis, in_shape, in_dtype):
        self.repeats = repeats
        self.axis = axis
        self.in_shape = in_shape
        self.in_dtype = in_dtype

    def forward(self, inputs):
        gy, = inputs
        xp = cuda.get_array_module(gy)
        repeats = self.repeats
        axis = self.axis

        if len(gy) == 0:
            gx = xp.zeros(self.in_shape, self.in_dtype)
            return gx,
        elif axis is None:
            gx = gy.reshape(-1, repeats[0]).sum(axis=1).reshape(self.in_shape)
            return gx,
        elif len(repeats) == 1:
            shape = list(self.in_shape)
            shape[axis:axis + 1] = [-1, repeats[0]]
            gx = gy.reshape(shape).sum(axis=axis + 1)
            return gx,

        gx = xp.zeros(self.in_shape, self.in_dtype)
        slices = [slice(None) for _ in six.moves.range(self.axis)]
        pos = 0
        for (i, r) in enumerate(repeats):
            src = slices + [slice(pos, pos + r)]
            dst = slices + [slice(i, i + 1)]
            gx[dst] = gy[src].sum(axis=self.axis, keepdims=True)
            pos += r
        return gx,

    def backward(self, indexes, grad_outputs):
        return Repeat(self.repeats, self.axis).apply(grad_outputs)


def repeat(x, repeats, axis=None):
"""Construct an array by repeating a given array.
Args:
x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`):
Input variable.
repeats (:class:`int` or :class:`tuple` of :class:`int` s):
The number of times which each element of ``x`` is repeated.
axis (:class:`int`):
The axis along which to repeat values.
Returns:
~chainer.Variable: The repeated output Variable.
.. admonition:: Example
>>> x = np.array([0, 1, 2])
>>> x.shape
(3,)
>>> y = F.repeat(x, 2)
>>> y.shape
(6,)
>>> y.data
array([0, 0, 1, 1, 2, 2])
>>> x = np.array([[1,2], [3,4]])
>>> x.shape
(2, 2)
>>> y = F.repeat(x, 3, axis=1)
>>> y.shape
(2, 6)
>>> y.data
array([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
>>> y = F.repeat(x, (1, 2), axis=0)
>>> y.shape
(3, 2)
>>> y.data
array([[1, 2],
[3, 4],
[3, 4]])
"""
return Repeat(repeats, axis).apply((x,))[0]
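
The gradient of ``repeat`` routes each output gradient back to the input element it was copied from, so ``RepeatGrad`` reduces ``gy`` by summing over the repeated copies. Below is a minimal NumPy sketch of the reshape-and-sum rule used in the single-integer ``repeats`` branch above; it is not part of the committed file, and the array values are illustrative only.

import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
repeats, axis = 2, 1
gy = np.ones(np.repeat(x, repeats, axis).shape, dtype=np.float32)  # shape (2, 6)

# Group each element's copies together, then sum over that new axis,
# mirroring RepeatGrad.forward when `repeats` is a single integer.
shape = list(x.shape)
shape[axis:axis + 1] = [-1, repeats]       # (2, 6) -> (2, 3, 2)
gx = gy.reshape(shape).sum(axis=axis + 1)  # sum over the copies

assert gx.shape == x.shape
assert (gx == repeats).all()  # each element received `repeats` ones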
55 changes: 35 additions & 20 deletions chainer/functions/loss/contrastive.py
@@ -1,11 +1,12 @@
 import numpy
 
 from chainer import cuda
-from chainer import function
+from chainer import function_node
+import chainer.functions
 from chainer.utils import type_check
 
 
-class Contrastive(function.Function):
+class Contrastive(function_node.FunctionNode):
 
     """Contrastive loss function."""
 
@@ -38,37 +39,51 @@ def check_type_forward(self, in_types):

     def forward(self, inputs):
         xp = cuda.get_array_module(*inputs)
+        self.retain_inputs((0, 1, 2))
         x0, x1, y = inputs
 
-        self.diff = x0 - x1
-        self.dist_sq = xp.sum(self.diff ** 2, axis=1)
-        self.dist = xp.sqrt(self.dist_sq)
-        self.mdist = self.margin - self.dist
-        dist = xp.maximum(self.mdist, 0)
-        loss = (y * self.dist_sq + (1 - y) * dist * dist) * .5
+        diff = x0 - x1
+        dist_sq = xp.sum(diff ** 2, axis=1)
+        dist = xp.sqrt(dist_sq)
+        mdist = self.margin - dist
+        dist = xp.maximum(mdist, 0)
+        loss = (y * dist_sq + (1 - y) * dist * dist) * .5
         if self.reduce == 'mean':
             loss = xp.sum(loss) / x0.shape[0]
         return xp.array(loss, dtype=xp.float32),
 
-    def backward(self, inputs, gy):
-        xp = cuda.get_array_module(*inputs)
-        x0, x1, y = inputs
+    def backward(self, indexes, grad_outputs):
+        x0, x1, y = self.get_retained_inputs()
+        gy, = grad_outputs
+        xp = cuda.get_array_module(gy.data)
 
+        # Recompute intermediate variables as in forward.
+        diff = x0 - x1
+        dist_sq = chainer.functions.sum(diff ** 2, axis=1)
+        dist = chainer.functions.sqrt(dist_sq)
+        mdist = self.margin - dist
+
+        y = y.data
         x_dim = x0.shape[1]
         y = xp.repeat(y[:, None], x_dim, axis=1)
         if self.reduce == 'mean':
-            alpha = gy[0] / y.shape[0]
+            alpha = gy / y.shape[0]
         else:
-            alpha = gy[0][:, None]
-        dist = xp.repeat(self.dist[:, None], x_dim, axis=1)
+            alpha = gy[:, None]
+        alpha = chainer.functions.broadcast_to(alpha, y.shape)
+        dist = chainer.functions.repeat(dist[:, None], x_dim, axis=1)
         # avoid division by zero
-        dist = xp.maximum(dist, 1e-8)
+        dist = chainer.functions.maximum(
+            dist,
+            xp.broadcast_to(xp.array(1e-8, dtype=dist.dtype), dist.shape))
         # similar pair
-        gx0 = alpha * y * self.diff
+        gx0 = alpha * y.astype(alpha.dtype) * diff
         # dissimilar pair
-        mdist = xp.maximum(xp.repeat(self.mdist[:, None], x_dim, axis=1), 0)
-        gx0 += alpha * (1 - y) * mdist * -(self.diff / dist)
-        gx0 = gx0.astype(xp.float32)
+        d = chainer.functions.repeat(mdist[:, None], x_dim, axis=1)
+        mdist = chainer.functions.maximum(
+            d, xp.zeros(shape=d.shape, dtype=d.dtype))
+        gx0 += alpha * (1 - y) * mdist * -(diff / dist)
+        gx0 = chainer.functions.cast(gx0, xp.float32)
 
         return gx0, -gx0, None
 
@@ -144,4 +159,4 @@ def contrastive(x0, x1, y, margin=1, reduce='mean'):
         array([ 0.625, 0. ], dtype=float32)
     """
-    return Contrastive(margin, reduce)(x0, x1, y)
+    return Contrastive(margin, reduce).apply((x0, x1, y))[0]
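
With ``Contrastive`` rewritten as a ``FunctionNode`` whose backward pass is built from differentiable ``chainer.functions`` ops, second-order gradients of the per-pair loss ``0.5 * (y * d**2 + (1 - y) * max(margin - d, 0)**2)`` become available. The following is a minimal sketch (not part of the diff) of taking a gradient of a gradient through the new implementation using ``chainer.grad``; the shapes, labels, and variable names are illustrative assumptions.

import numpy as np
import chainer
import chainer.functions as F

x0 = chainer.Variable(np.random.randn(5, 10).astype(np.float32))
x1 = chainer.Variable(np.random.randn(5, 10).astype(np.float32))
t = np.random.randint(0, 2, (5,)).astype(np.int32)  # 1 = similar, 0 = dissimilar

loss = F.contrastive(x0, x1, t, margin=1, reduce='mean')

# First-order gradients, keeping the graph of the backward pass.
gx0, gx1 = chainer.grad([loss], [x0, x1], enable_double_backprop=True)

# Differentiate a scalar function of the first-order gradient again.
ggx0, = chainer.grad([F.sum(gx0 * gx0)], [x0])
print(ggx0.shape)  # (5, 10)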
1 change: 1 addition & 0 deletions docs/source/reference/functions.rst
@@ -74,6 +74,7 @@ Array manipulations
 chainer.functions.pad
 chainer.functions.pad_sequence
 chainer.functions.permutate
+chainer.functions.repeat
 chainer.functions.reshape
 chainer.functions.resize_images
 chainer.functions.rollaxis
113 changes: 113 additions & 0 deletions tests/chainer_tests/functions_tests/array_tests/test_repeat.py
@@ -0,0 +1,113 @@
import unittest

import numpy

from chainer import cuda
from chainer import functions
from chainer import gradient_check
from chainer import testing
from chainer.testing import attr


@testing.parameterize(*testing.product({
    'shape_repeats_axis': [
        (2, 0, None),
        (2, 1, None),
        (2, 2, None),
        (2, 2, 0),
        ((3, 2), (2,), 0),
        ((3, 2), 2, 0),
        ((3, 2), 2, 1),
        ((3, 2), (3, 4, 3), 0),
        ((3, 2), (3, 2), 1),
        ((3, 2, 3), (3, 2, 1), 0),
        ((3, 2, 3), (3, 4), 1),
        ((3, 2, 3), (3, 2, 1), 2),
        ((3, 4, 3, 2), 3, 1),
        ((3, 4, 3, 2), (2, 2, 3, 3), 1),
    ],
    'dtype': [numpy.float16, numpy.float32, numpy.float64],
}))
class TestRepeat(unittest.TestCase):

    def setUp(self):
        (self.in_shape, self.repeats, self.axis) = self.shape_repeats_axis
        self.x = numpy.random.uniform(-1, 1, self.in_shape).astype(self.dtype)
        out_shape = numpy.repeat(self.x, self.repeats, self.axis).shape
        self.gy = numpy.random.uniform(-1, 1, out_shape).astype(self.dtype)
        self.ggx = numpy.random.uniform(-1, 1, self.in_shape) \
            .astype(self.dtype)

        self.check_forward_options = {}
        self.check_backward_options = {'dtype': numpy.float64}
        if self.dtype == numpy.float16:
            self.check_forward_options = {'atol': 5e-4, 'rtol': 5e-3}
            self.check_backward_options = {
                'dtype': numpy.float64, 'atol': 2 ** -4, 'rtol': 2 ** -4}

    def check_forward(self, x_data):
        y = functions.repeat(x_data, self.repeats, self.axis)
        y_expected = numpy.repeat(self.x, self.repeats, self.axis)
        self.assertEqual(y.dtype, y_expected.dtype)
        testing.assert_allclose(
            y.data, y_expected, **self.check_forward_options)

    def test_forward_cpu(self):
        self.check_forward(self.x)

    @attr.gpu
    def test_forward_gpu(self):
        self.check_forward(cuda.to_gpu(self.x))

    def check_backward(self, x_data, y_grad):
        def f(x):
            return functions.repeat(x, self.repeats, self.axis)

        gradient_check.check_backward(
            f, x_data, y_grad, **self.check_backward_options)

    def test_backward_cpu(self):
        self.check_backward(self.x, self.gy)

    @attr.gpu
    def test_backward_gpu(self):
        self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy))

    def check_double_backward(self, x_data, y_grad, x_grad_grad):
        def f(x):
            y = functions.repeat(x, self.repeats, self.axis)
            return y * y

        gradient_check.check_double_backward(
            f, x_data, y_grad, x_grad_grad, **self.check_backward_options)

    def test_double_backward_cpu(self):
        self.check_double_backward(self.x, self.gy, self.ggx)

    @attr.gpu
    def test_double_backward_gpu(self):
        self.check_double_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy),
                                   cuda.to_gpu(self.ggx))


@testing.parameterize(*testing.product({
    'repeats': [-1, (-1, -1)],
    'axis': [-1],
}))
class TestRepeatValueError(unittest.TestCase):

    def test_value_error(self):
        x = numpy.random.uniform(-1, 1, (2,)).astype(numpy.float32)
        with self.assertRaises(ValueError):
            functions.repeat(x, self.repeats, self.axis)


class TestRepeatTypeError(unittest.TestCase):

    def test_type_error(self):
        x = numpy.random.uniform(-1, 1, (2,)).astype(numpy.float32)
        with self.assertRaises(TypeError):
            functions.repeat(x, 'a')


testing.run_module(__name__, __file__)
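
A note on the ``y * y`` wrapper in the double-backward tests above: ``repeat`` is linear in ``x``, so its first-order gradient is constant in ``x`` and a second-order check on the bare function would have little to verify; squaring the output makes the gradient depend on ``x`` and exercises ``RepeatGrad.backward``. A small sketch of the same idea with concrete numbers (not part of the test file; values are illustrative only):

import numpy as np
import chainer
import chainer.functions as F

x = chainer.Variable(np.arange(3, dtype=np.float32))  # [0., 1., 2.]
y = F.repeat(x, 2)                                     # [0., 0., 1., 1., 2., 2.]
loss = F.sum(y * y)                                    # 2 * sum(x ** 2)

gx, = chainer.grad([loss], [x], enable_double_backprop=True)
ggx, = chainer.grad([F.sum(gx)], [x])

print(gx.data)   # [0. 4. 8.]  == 4 * x
print(ggx.data)  # [4. 4. 4.]  constant second derivative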
30 changes: 28 additions & 2 deletions tests/chainer_tests/functions_tests/loss_tests/test_contrastive.py
@@ -31,6 +31,8 @@ def setUp(self):
         else:
             self.gy = numpy.random.uniform(
                 -1, 1, (self.batchsize,)).astype(numpy.float32)
+        self.gx0 = numpy.random.uniform(-1, 1, x_shape).astype(numpy.float32)
+        self.gx1 = numpy.random.uniform(-1, 1, x_shape).astype(numpy.float32)
 
         self.check_backward_options = {'rtol': 1e-2, 'atol': 1e-3}
 
@@ -77,9 +79,11 @@ def test_forward_gpu_no_cudnn(self):
                            cuda.to_gpu(self.t))
 
     def check_backward(self, x0_data, x1_data, t_data, gy_data):
+        def f(x0, x1, t):
+            return functions.contrastive(x0, x1, t, self.margin, self.reduce)
+
         gradient_check.check_backward(
-            functions.Contrastive(self.margin, self.reduce),
-            (x0_data, x1_data, t_data), gy_data, dtype='d',
+            f, (x0_data, x1_data, t_data), gy_data, dtype='d',
             **self.check_backward_options)
 
     def test_backward_cpu(self):
@@ -98,6 +102,28 @@ def test_backward_zero_dist_gpu_no_cudnn(self):
         self.check_backward(cuda.to_gpu(self.x0), cuda.to_gpu(self.x0),
                             cuda.to_gpu(self.t), cuda.to_gpu(self.gy))
 
+    def check_double_backward(
+            self, x0_data, x1_data, t_data, gy_data, gx0_data, gx1_data):
+        def f(x0, x1):
+            y = functions.contrastive(x0, x1, t_data, self.margin, self.reduce)
+            return y * y
+
+        gradient_check.check_double_backward(
+            f, (x0_data, x1_data), gy_data,
+            (gx0_data, gx1_data),
+            dtype='f', rtol=1e-2, atol=1e-3)
+
+    def test_double_backward_cpu(self):
+        self.check_double_backward(
+            self.x0, self.x1, self.t, self.gy, self.gx0, self.gx1)
+
+    @attr.gpu
+    def test_double_backward_gpu(self):
+        self.check_double_backward(
+            cuda.to_gpu(self.x0), cuda.to_gpu(self.x1),
+            cuda.to_gpu(self.t), cuda.to_gpu(self.gy),
+            cuda.to_gpu(self.gx0), cuda.to_gpu(self.gx1))
+

class TestContrastiveInvalidReductionOption(unittest.TestCase):

