Skip to content

Commit

Permalink
Ensure integer arrays used for indexing are NumPy arrays (#485)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jun 20, 2024
1 parent a01a576 commit e9de1ae
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from toolz import accumulate, map

from cubed import config
from cubed.backend_array_api import backend_array_to_numpy_array
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.core.array import CoreArray, check_array_specs, compute, gensym
Expand Down Expand Up @@ -408,8 +409,11 @@ def index(x, key):
key = (key,)

# Replace Cubed arrays with NumPy arrays - note that this may trigger a computation!
# Note that NumPy arrays are needed for ndindex.
key = tuple(
dim_sel.compute() if isinstance(dim_sel, CoreArray) else dim_sel
backend_array_to_numpy_array(dim_sel.compute())
if isinstance(dim_sel, CoreArray)
else dim_sel
for dim_sel in key
)

Expand Down

0 comments on commit e9de1ae

Please sign in to comment.