Skip to content

Commit

Permalink
BUG, ENH: np._from_dlpack: export correct device information
Browse files Browse the repository at this point in the history
  • Loading branch information
tirthasheshpatel authored and charris committed Mar 2, 2022
1 parent 593cc07 commit 8f9951e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
6 changes: 6 additions & 0 deletions numpy/core/src/multiarray/dlpack.c
Expand Up @@ -88,6 +88,12 @@ array_get_dl_device(PyArrayObject *self) {
ret.device_type = kDLCPU;
ret.device_id = 0;
PyObject *base = PyArray_BASE(self);

// walk the bases (see gh-20340)
while (base != NULL && PyArray_Check(base)) {
base = PyArray_BASE((PyArrayObject *)base);
}

// The outer if is due to the fact that NumPy arrays are on the CPU
// by default (if not created from DLPack).
if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) {
Expand Down
5 changes: 4 additions & 1 deletion numpy/core/tests/test_dlpack.py
Expand Up @@ -91,7 +91,10 @@ def test_higher_dims(self, ndim):
def test_dlpack_device(self):
x = np.arange(5)
assert x.__dlpack_device__() == (1, 0)
assert np._from_dlpack(x).__dlpack_device__() == (1, 0)
y = np._from_dlpack(x)
assert y.__dlpack_device__() == (1, 0)
z = y[::2]
assert z.__dlpack_device__() == (1, 0)

def dlpack_deleter_exception(self):
x = np.arange(5)
Expand Down

0 comments on commit 8f9951e

Please sign in to comment.