Skip to content

Commit

Permalink
MNT: refactor so itemsizes are correct
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed May 16, 2024
1 parent f7a20cc commit 0506c50
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions numpy/_core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -3283,17 +3283,20 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
NPY_ITER_READONLY | NPY_ITER_ALIGNED,
NPY_ITER_READONLY | NPY_ITER_ALIGNED
};

common_dt = PyArray_ResultType(2, &op_in[2], 0, NULL);
if (common_dt == NULL) {
goto fail;
}
PyArray_Descr *x_dt, *y_dt;
int has_ref = PyDataType_REFCHK(common_dt);
if (has_ref) {
x_dt = PyArray_DESCR(op_in[2]);
y_dt = PyArray_DESCR(op_in[3]);
}
else {

// If x and y don't have references, we ask the iterator to create buffers
// using the common data type of x and y and then do fast trivial copies
// in the loop below.
// Otherwise trivial copies aren't possible and we handle the cast item by item
// in the loop.
PyArray_Descr *x_dt = PyArray_DESCR(op_in[2]), *y_dt = PyArray_DESCR(op_in[3]);
int has_ref = (PyDataType_REFCHK(x_dt) || PyDataType_REFCHK(y_dt));
if (!has_ref) {
x_dt = common_dt;
y_dt = common_dt;
}
Expand All @@ -3313,27 +3316,28 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)

/* Get the result from the iterator object array */
ret = (PyObject*)NpyIter_GetOperandArray(iter)[0];

npy_intp itemsize = common_dt->elsize;
PyArray_Descr *ret_dt = PyArray_DESCR((PyArrayObject *)ret);
npy_intp itemsize = ret_dt->elsize;

NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;

npy_intp transfer_strides[2] = {itemsize, itemsize};
npy_intp x_strides[2] = {x_dt->elsize, itemsize};
npy_intp y_strides[2] = {y_dt->elsize, itemsize};
npy_intp one = 1;

if (has_ref || ((itemsize != 16) && (itemsize != 8) && (itemsize != 4) &&
(itemsize != 2) && (itemsize != 1))) {
// The iterator has NPY_ITER_ALIGNED flag so no need to check alignment
// of the input arrays.
if (PyArray_GetDTypeTransferFunction(
1, itemsize, itemsize,
PyArray_DESCR(op_in[2]), common_dt, 0,
1, x_strides[0], x_strides[1],
PyArray_DESCR(op_in[2]), ret_dt, 0,
&x_cast_info, &transfer_flags) != NPY_SUCCEED) {
goto fail;
}
if (PyArray_GetDTypeTransferFunction(
1, itemsize, itemsize,
PyArray_DESCR(op_in[3]), common_dt, 0,
1, y_strides[0], y_strides[1],
PyArray_DESCR(op_in[3]), ret_dt, 0,
&y_cast_info, &transfer_flags) != NPY_SUCCEED) {
goto fail;
}
Expand Down Expand Up @@ -3389,7 +3393,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)

if (x_cast_info.func(
&x_cast_info.context, args, &one,
transfer_strides, x_cast_info.auxdata) < 0) {
x_strides, x_cast_info.auxdata) < 0) {
goto fail;
}
}
Expand All @@ -3398,7 +3402,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)

if (y_cast_info.func(
&y_cast_info.context, args, &one,
transfer_strides, y_cast_info.auxdata) < 0) {
y_strides, y_cast_info.auxdata) < 0) {
goto fail;
}
}
Expand Down

0 comments on commit 0506c50

Please sign in to comment.