Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix index_tricks issue #445

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions numpy/core/src/multiarray/nditer_api.c
Expand Up @@ -844,6 +844,15 @@ NpyIter_IsGrowInner(NpyIter *iter)
return (NIT_ITFLAGS(iter)&NPY_ITFLAG_GROWINNER) != 0;
}

/*
* Whether the iterator output is scalar
*/
NPY_NO_EXPORT npy_bool
NpyIter_IsScalar(NpyIter *iter)
{
return (NIT_ITFLAGS(iter)&NPY_ITFLAG_SCALAR) != 0;
}

/*NUMPY_API
* Gets the size of the buffer, or 0 if buffering is not enabled
*/
Expand Down
32 changes: 13 additions & 19 deletions numpy/core/src/multiarray/nditer_constr.c
Expand Up @@ -54,8 +54,7 @@ static int
npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itflags,
char **op_dataptr,
npy_uint32 *op_flags, int **op_axes,
npy_intp *itershape,
int output_scalars);
npy_intp *itershape);
static void
npyiter_replace_axisdata(NpyIter *iter, int iop,
PyArrayObject *op,
Expand All @@ -74,8 +73,7 @@ npyiter_find_best_axis_ordering(NpyIter *iter);
static PyArray_Descr *
npyiter_get_common_dtype(int nop, PyArrayObject **op,
npyiter_opitflags *op_itflags, PyArray_Descr **op_dtype,
PyArray_Descr **op_request_dtypes,
int only_inputs, int output_scalars);
PyArray_Descr **op_request_dtypes, int only_inputs);
static PyArrayObject *
npyiter_new_temp_array(NpyIter *iter, PyTypeObject *subtype,
npy_uint32 flags, npyiter_opitflags *op_itflags,
Expand All @@ -86,7 +84,7 @@ npyiter_allocate_arrays(NpyIter *iter,
npy_uint32 flags,
PyArray_Descr **op_dtype, PyTypeObject *subtype,
npy_uint32 *op_flags, npyiter_opitflags *op_itflags,
int **op_axes, int output_scalars);
int **op_axes);
static void
npyiter_get_priority_subtype(int nop, PyArrayObject **op,
npyiter_opitflags *op_itflags,
Expand Down Expand Up @@ -123,7 +121,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,
npy_int8 *perm;
NpyIter_BufferData *bufferdata = NULL;
int any_allocate = 0, any_missing_dtypes = 0,
output_scalars = 0, need_subtype = 0;
need_subtype = 0;

/* The subtype for automatically allocated outputs */
double subtype_priority = NPY_PRIORITY;
Expand Down Expand Up @@ -177,7 +175,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,

/* If 'ndim' is zero, any outputs should be scalars */
if (ndim == 0) {
output_scalars = 1;
itflags |= NPY_ITFLAG_SCALAR;
ndim = 1;
}

Expand Down Expand Up @@ -231,8 +229,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,

/* Fill in the AXISDATA arrays and set the ITERSIZE field */
if (!npyiter_fill_axisdata(iter, flags, op_itflags, op_dataptr,
op_flags, op_axes, itershape,
output_scalars)) {
op_flags, op_axes, itershape)) {
NpyIter_Deallocate(iter);
return NULL;
}
Expand Down Expand Up @@ -338,8 +335,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,
dtype = npyiter_get_common_dtype(nop, op,
op_itflags, op_dtype,
op_request_dtypes,
only_inputs,
output_scalars);
only_inputs);
if (dtype == NULL) {
NpyIter_Deallocate(iter);
return NULL;
Expand Down Expand Up @@ -389,7 +385,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,
* done now using a memory layout matching the iterator.
*/
if (!npyiter_allocate_arrays(iter, flags, op_dtype, subtype, op_flags,
op_itflags, op_axes, output_scalars)) {
op_itflags, op_axes)) {
NpyIter_Deallocate(iter);
return NULL;
}
Expand Down Expand Up @@ -1437,8 +1433,7 @@ static int
npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itflags,
char **op_dataptr,
npy_uint32 *op_flags, int **op_axes,
npy_intp *itershape,
int output_scalars)
npy_intp *itershape)
{
npy_uint32 itflags = NIT_ITFLAGS(iter);
int idim, ndim = NIT_NDIM(iter);
Expand Down Expand Up @@ -1558,7 +1553,7 @@ npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itf
ondim = PyArray_NDIM(op_cur);
if (bshape == 1) {
strides[iop] = 0;
if (idim >= ondim && !output_scalars &&
if (idim >= ondim && !(itflags & NPY_ITFLAG_SCALAR) &&
(op_flags[iop] & NPY_ITER_NO_BROADCAST)) {
goto operand_different_than_broadcast;
}
Expand Down Expand Up @@ -2393,8 +2388,7 @@ npyiter_find_best_axis_ordering(NpyIter *iter)
static PyArray_Descr *
npyiter_get_common_dtype(int nop, PyArrayObject **op,
npyiter_opitflags *op_itflags, PyArray_Descr **op_dtype,
PyArray_Descr **op_request_dtypes,
int only_inputs, int output_scalars)
PyArray_Descr **op_request_dtypes, int only_inputs)
{
int iop;
npy_intp narrs = 0, ndtypes = 0;
Expand Down Expand Up @@ -2693,7 +2687,7 @@ npyiter_allocate_arrays(NpyIter *iter,
npy_uint32 flags,
PyArray_Descr **op_dtype, PyTypeObject *subtype,
npy_uint32 *op_flags, npyiter_opitflags *op_itflags,
int **op_axes, int output_scalars)
int **op_axes)
{
npy_uint32 itflags = NIT_ITFLAGS(iter);
int idim, ndim = NIT_NDIM(iter);
Expand Down Expand Up @@ -2724,7 +2718,7 @@ npyiter_allocate_arrays(NpyIter *iter,
if (op[iop] == NULL) {
PyArrayObject *out;
PyTypeObject *op_subtype;
int ondim = output_scalars ? 0 : ndim;
int ondim = (itflags & NPY_ITFLAG_SCALAR) ? 0 : ndim;

/* Check whether the subtype was disabled */
op_subtype = (op_flags[iop] & NPY_ITER_NO_SUBTYPE) ?
Expand Down
2 changes: 2 additions & 0 deletions numpy/core/src/multiarray/nditer_impl.h
Expand Up @@ -101,6 +101,8 @@
#define NPY_ITFLAG_REDUCE 0x1000
/* Reduce iteration doesn't need to recalculate reduce loops next time */
#define NPY_ITFLAG_REUSE_REDUCE_LOOPS 0x2000
/* The iterator output is scalar */
#define NPY_ITFLAG_SCALAR 0x4000

/* Internal iterator per-operand iterator flags */

Expand Down
4 changes: 3 additions & 1 deletion numpy/core/src/multiarray/nditer_pywrap.c
Expand Up @@ -1542,6 +1542,9 @@ static PyObject *npyiter_multi_index_get(NewNpyArrayIterObject *self)
}

if (self->get_multi_index != NULL) {
if (NpyIter_IsScalar(self->iter)) {
return PyTuple_New(0);
}
ndim = NpyIter_GetNDim(self->iter);
self->get_multi_index(self->iter, multi_index);
ret = PyTuple_New(ndim);
Expand Down Expand Up @@ -1809,7 +1812,6 @@ static PyObject *npyiter_has_delayed_bufalloc_get(NewNpyArrayIterObject *self)
"Iterator is invalid");
return NULL;
}

if (NpyIter_HasDelayedBufAlloc(self->iter)) {
Py_RETURN_TRUE;
}
Expand Down
3 changes: 3 additions & 0 deletions numpy/lib/index_tricks.py
Expand Up @@ -533,6 +533,9 @@ class ndindex(object):

"""
def __init__(self, *shape):
# Accept shapes in the form f(x, y, ..) as well as f((x ,y, ..))
if len(shape) == 1 and isinstance(shape[0], tuple):
shape = shape[0]
x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape))
self._it = _nx.nditer(x, flags=['multi_index'], order='C')

Expand Down
6 changes: 6 additions & 0 deletions numpy/lib/tests/test_index_tricks.py
Expand Up @@ -241,6 +241,12 @@ def test_ndindex():
x = list(np.ndindex(1, 2, 3))
expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))]
assert_array_equal(x, expected)
# Packed as well as unpacked tuple are acceptable
y = list(np.ndindex((1, 2, 3)))
assert_array_equal(x, y)
# Empty shape gives empty index
z = list(np.ndindex(()))
assert_equal(z, [()])


if __name__ == "__main__":
Expand Down