Skip to content

Commit

Permalink
support passing int as shape to broadcast_to
Browse files Browse the repository at this point in the history
  • Loading branch information
kmaehashi committed Dec 30, 2022
1 parent 309d002 commit 3efff6b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions cupy/_core/_routines_manipulation.pyx
Expand Up @@ -476,6 +476,7 @@ cpdef _ndarray_base broadcast_to(_ndarray_base array, shape):
:meth:`numpy.broadcast_to`
"""
shape = tuple(shape) if numpy.iterable(shape) else (shape,)
cdef int i, j, ndim = array._shape.size(), length = len(shape)
cdef Py_ssize_t sh, a_sh
if ndim > length:
Expand Down
6 changes: 6 additions & 0 deletions tests/cupy_tests/manipulation_tests/test_dims.py
Expand Up @@ -55,6 +55,12 @@ def test_broadcast_to(self, xp, dtype):
b = xp.broadcast_to(a, (2, 3, 3, 4))
return b

@testing.numpy_cupy_array_equal()
def test_broadcast_to_int(self, xp):
a = testing.shaped_arange((10,), xp, xp.float32)
b = xp.broadcast_to(a, 10)
return b

@testing.for_all_dtypes()
def test_broadcast_to_fail(self, dtype):
for xp in (numpy, cupy):
Expand Down

0 comments on commit 3efff6b

Please sign in to comment.