Skip to content

Commit

Permalink
New style concat and split_axis
Browse files Browse the repository at this point in the history
  • Loading branch information
okuta committed Aug 20, 2017
1 parent 2de0c34 commit a2cadc1
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 38 deletions.
21 changes: 11 additions & 10 deletions chainer/functions/array/concat.py
Expand Up @@ -2,11 +2,11 @@
import six

from chainer import cuda
from chainer import function
from chainer import function_node
from chainer.utils import type_check


class Concat(function.Function):
class Concat(function_node.FunctionNode):

"""Concatenate multiple tensors towards specified axis."""

Expand Down Expand Up @@ -39,19 +39,20 @@ def check_type_forward(self, in_types):
type_check.expect(in_types[0].shape[d] == in_types[i].shape[d])

def forward(self, xs):
self.retain_inputs(())
self._xp = cuda.get_array_module(*xs)
xp = cuda.get_array_module(*xs)
self._x_shapes = [x.shape for x in xs]
return self._xp.concatenate(xs, axis=self.axis),
return xp.concatenate(xs, self.axis),

def backward(self, xs, gy):
if len(xs) == 1:
return gy
def backward(self, indexes, grad_outputs):
if len(self._x_shapes) == 1:
return grad_outputs

sizes = numpy.array(
[shape[self.axis] for shape in self._x_shapes[:-1]]
).cumsum()
return self._xp.split(gy[0], sizes, axis=self.axis)
# Imported locally to avoid a circular import: split_axis itself imports
# concat at module level.
from chainer.functions.array import split_axis
return split_axis.SplitAxis(sizes, self.axis).apply(grad_outputs)


def concat(xs, axis=1):
Expand Down Expand Up @@ -87,4 +88,4 @@ def concat(xs, axis=1):
[ 8, 9, 10, 11, 2]])
"""
return Concat(axis=axis)(*xs)
return Concat(axis).apply(xs)[0]
48 changes: 22 additions & 26 deletions chainer/functions/array/split_axis.py
Expand Up @@ -4,12 +4,12 @@

import chainer
from chainer import cuda
from chainer import function
from chainer import function_node
from chainer.functions.array import concat
from chainer.utils import type_check
from chainer import variable


class SplitAxis(function.Function):
class SplitAxis(function_node.FunctionNode):

"""Function that splits an array into multiple arrays along the specified axis."""

Expand Down Expand Up @@ -41,28 +41,24 @@ def check_type_forward(self, in_types):
self.indices_or_sections, 'sections')
type_check.expect(in_types[0].shape[self.axis] % sections == 0)

def forward(self, x):
self.retain_inputs(())
def forward(self, inputs):
x, = inputs
if isinstance(self.indices_or_sections, collections.Iterable):
cdimx = x[0].shape[self.axis]
cdimx = x.shape[self.axis]
ind = list(self.indices_or_sections)
ind.append(cdimx)
self._xp = cuda.get_array_module(*x)
self._x_shape = x[0].shape
self._x_dtype = x[0].dtype
return tuple(self._xp.split(x[0], self.indices_or_sections, self.axis))

def backward(self, x, gys):
if any(gy is None for gy in gys):
gx = self._xp.zeros(self._x_shape, dtype=self._x_dtype)
gxs = self._xp.split(gx, self.indices_or_sections, self.axis)
for gxi, gy in six.moves.zip(gxs, gys):
if gy is None:
continue
gxi[:] = gy
return gx,
else:
return self._xp.concatenate(gys, axis=self.axis),
self._xp = cuda.get_array_module(x)
self._x_shape = x.shape
self._x_dtype = x.dtype
ret = tuple(self._xp.split(x, self.indices_or_sections, self.axis))
self._shapes = [r.shape for r in ret]
return ret

def backward(self, indexes, grad_outputs):
grads = [
self._xp.zeros(shape, dtype=self._x_dtype) if gy is None else gy
for gy, shape in six.moves.zip(grad_outputs, self._shapes)]
return concat.Concat(self.axis).apply(grads)


def split_axis(x, indices_or_sections, axis, force_tuple=True):
Expand Down Expand Up @@ -93,7 +89,7 @@ def split_axis(x, indices_or_sections, axis, force_tuple=True):
(i.e. ``axis``-th value of its shape is zero).
"""
res = SplitAxis(indices_or_sections, axis)(x)
if force_tuple and isinstance(res, variable.Variable):
res = (res,)
return res
res = SplitAxis(indices_or_sections, axis).apply((x,))
if force_tuple or len(res) != 1:
return res
return res[0]
4 changes: 2 additions & 2 deletions chainer/functions/array/stack.py
Expand Up @@ -87,5 +87,5 @@ def stack(xs, axis=0):
(3, 4, 2)
"""
xs = [expand_dims.expand_dims(x, axis=axis) for x in xs]
return concat.concat(xs, axis=axis)
xs = [expand_dims.expand_dims(x, axis) for x in xs]
return concat.concat(xs, axis)

0 comments on commit a2cadc1

Please sign in to comment.