Skip to content

Commit

Permalink
Add CuPy support
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 9, 2022
1 parent 71f119f commit 8f14bbf
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions numpy_groupies/utils_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _ravel_group_idx(group_idx, a, axis, size, order, method="ravel"):
size = []
for ii, s in enumerate(a.shape):
if method == "ravel":
ii_idx = group_idx_in if ii == axis else np.arange(s)
ii_idx = group_idx_in if ii == axis else np.arange(s, like=group_idx_in)
ii_shape = [1] * ndim_a
ii_shape[ii] = s
group_idx.append(ii_idx.reshape(ii_shape))
Expand Down Expand Up @@ -249,10 +249,8 @@ def offset_labels(group_idx, inshape, axis, order, size):
group_idx = np.moveaxis(group_idx, axis, -1)
newshape = group_idx.shape[:-1] + (-1,)

group_idx = (group_idx +
np.arange(np.prod(newshape[:-1]), dtype=int).reshape(newshape)
* size
)
offset_ = np.arange(np.prod(newshape[:-1]), dtype=int, like=group_idx).reshape(newshape)
group_idx = group_idx + offset_ * size
if axis not in (-1, len(inshape) - 1):
return np.moveaxis(group_idx, -1, axis)
else:
Expand Down

0 comments on commit 8f14bbf

Please sign in to comment.