Skip to content

Commit

Permalink
Merge pull request #2678 from chainer/improve-docs-of-flatten
Browse files Browse the repository at this point in the history
Improve docs of flatten, reshape
  • Loading branch information
delta2323 committed Jun 19, 2017
2 parents 30fd10c + 7bc25fd commit c2e948d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 7 deletions.
33 changes: 30 additions & 3 deletions chainer/functions/array/flatten.py
Expand Up @@ -15,12 +15,39 @@ def backward(self, inputs, grads):


def flatten(x):
"""Flatten a given array.
"""Flatten a given array into one dimension.
Args:
x (~chainer.Varaiable): Input variable.
x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable.
Returns:
~chainer.Variable: Output variable.
~chainer.Variable: Output variable flatten to one dimension.
.. note::
When you input a scalar array (i.e. the shape is ``()``),
you can also get the one dimension array whose shape is ``(1,)``.
.. admonition:: Example
>>> x = np.array([[1, 2], [3, 4]])
>>> x.shape
(2, 2)
>>> y = F.flatten(x)
>>> y.shape
(4,)
>>> y.data
array([1, 2, 3, 4])
>>> x = np.arange(8).reshape(2, 2, 2)
>>> x.shape
(2, 2, 2)
>>> y = F.flatten(x)
>>> y.shape
(8,)
>>> y.data
array([0, 1, 2, 3, 4, 5, 6, 7])
"""
return Flatten()(x)
40 changes: 36 additions & 4 deletions chainer/functions/array/reshape.py
Expand Up @@ -53,12 +53,44 @@ def reshape(x, shape):
"""Reshapes an input variable without copy.
Args:
x (~chainer.Variable): Input variable.
shape (tuple of ints): Target shape.
x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable.
shape (:class:`tuple` of :class:`int` s):
Expected shape of the output array. The number of elements which
the array of ``shape`` contains must be equal to that of input
array. One shape dimension can be -1. In this case, the value is
inferred from the length of the array and remaining dimensions.
Returns:
~chainer.Variable: Variable that holds a reshaped version of the input
variable.
~chainer.Variable:
Variable that holds a reshaped version of the input variable.
.. seealso:: :func:`numpy.reshape`, :func:`cupy.reshape`
.. admonition:: Example
>>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> y = F.reshape(x, (8,))
>>> y.shape
(8,)
>>> y.data
array([1, 2, 3, 4, 5, 6, 7, 8])
>>> y = F.reshape(x, (4, -1)) # the shape of output is inferred
>>> y.shape
(4, 2)
>>> y.data
array([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
>>> y = F.reshape(x, (4, 3)) \
# the shape of input and output are not consistent
Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType:
Invalid operation is performed in: Reshape (Forward)
Expect: prod(in_types[0].shape) == prod((4, 3))
Actual: 8 != 12
"""
return Reshape(shape)(x)

0 comments on commit c2e948d

Please sign in to comment.