Skip to content

Commit

Permalink
Merge pull request #142 from cupy/atleast-numpy-array
Browse files Browse the repository at this point in the history
Refactor atleast_nd
  • Loading branch information
okuta committed Jun 18, 2017
2 parents 97d9aa4 + c404ce4 commit 320fad3
Showing 1 changed file with 33 additions and 36 deletions.
69 changes: 33 additions & 36 deletions cupy/manipulation/dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,36 @@
six_zip = six.moves.zip


# Shape map for atleast_nd functions
# (minimum dimension, input dimension) -> (output shape)
_atleast_nd_shape_map = {
(1, 0): lambda shape: (1,),
(2, 0): lambda shape: (1, 1),
(2, 1): lambda shape: (1,) + shape,
(3, 0): lambda shape: (1, 1, 1),
(3, 1): lambda shape: (1,) + shape + (1,),
(3, 2): lambda shape: shape + (1,),
}


def _atleast_nd_helper(n, arys):
"""Helper function for atleast_nd functions."""

res = []
for a in arys:
if isinstance(a, cupy.ndarray):
if a.ndim < n:
new_shape = _atleast_nd_shape_map[(n, a.ndim)](a.shape)
a = a.reshape(*new_shape)
else:
raise TypeError('Unsupported type {}'.format(type(a)))
res.append(a)

if len(res) == 1:
res, = res
return res


def atleast_1d(*arys):
"""Converts arrays to arrays with dimensions >= 1.
Expand All @@ -23,16 +53,7 @@ def atleast_1d(*arys):
.. seealso:: :func:`numpy.atleast_1d`
"""
res = []
for a in arys:
if not isinstance(a, cupy.ndarray):
raise TypeError('Only cupy arrays can be atleast_1d')
if a.ndim == 0:
a = a.reshape(1)
res.append(a)
if len(res) == 1:
res = res[0]
return res
return _atleast_nd_helper(1, arys)


def atleast_2d(*arys):
Expand All @@ -52,18 +73,7 @@ def atleast_2d(*arys):
.. seealso:: :func:`numpy.atleast_2d`
"""
res = []
for a in arys:
if not isinstance(a, cupy.ndarray):
raise TypeError('Only cupy arrays can be atleast_2d')
if a.ndim == 0:
a = a.reshape(1, 1)
elif a.ndim == 1:
a = a[None, :]
res.append(a)
if len(res) == 1:
res = res[0]
return res
return _atleast_nd_helper(2, arys)


def atleast_3d(*arys):
Expand All @@ -89,20 +99,7 @@ def atleast_3d(*arys):
.. seealso:: :func:`numpy.atleast_3d`
"""
res = []
for a in arys:
if not isinstance(a, cupy.ndarray):
raise TypeError('Only cupy arrays can be atleast_3d')
if a.ndim == 0:
a = a.reshape(1, 1, 1)
elif a.ndim == 1:
a = a[None, :, None]
elif a.ndim == 2:
a = a[:, :, None]
res.append(a)
if len(res) == 1:
res = res[0]
return res
return _atleast_nd_helper(3, arys)


broadcast = core.broadcast
Expand Down

0 comments on commit 320fad3

Please sign in to comment.