Skip to content

Commit

Permalink
BUG: ticket #1149, make argmin work for datetime.
Browse files Browse the repository at this point in the history
Backport of 6fc0737 'ENH: Explicitly coded argmin for timedelta'
  • Loading branch information
WeatherGod authored and charris committed Apr 27, 2012
1 parent 74b9f5e commit ff72395
Show file tree
Hide file tree
Showing 5 changed files with 341 additions and 19 deletions.
7 changes: 7 additions & 0 deletions numpy/core/include/numpy/ndarraytypes.h
Expand Up @@ -532,6 +532,13 @@ typedef struct {
PyArray_FastClipFunc *fastclip;
PyArray_FastPutmaskFunc *fastputmask;
PyArray_FastTakeFunc *fasttake;

/*
* Function to select smallest
* Can be NULL
*/
PyArray_ArgFunc *argmin;

} PyArray_ArrFuncs;

/* The item must be reference counted when it is inserted or extracted. */
Expand Down
136 changes: 134 additions & 2 deletions numpy/core/src/multiarray/arraytypes.c.src
Expand Up @@ -2924,6 +2924,79 @@ static int

/**end repeat**/

/**begin repeat
*
* #fname = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
* LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA#
* #type = Bool, byte, ubyte, short, ushort, int, uint, long, ulong,
* longlong, ulonglong, npy_half, float, double, longdouble,
* float, double, longdouble, datetime, timedelta#
* #isfloat = 0*11, 1*7, 0*2#
* #isnan = nop*11, npy_half_isnan, npy_isnan*6, nop*2#
* #le = _LESS_THAN_OR_EQUAL*11, npy_half_le, _LESS_THAN_OR_EQUAL*8#
* #iscomplex = 0*15, 1*3, 0*2#
* #incr = ip++*15, ip+=2*3, ip++*2#
*/
static int
@fname@_argmin(@type@ *ip, intp n, intp *min_ind, PyArrayObject *NPY_UNUSED(aip))
{
intp i;
@type@ mp = *ip;
#if @iscomplex@
@type@ mp_im = ip[1];
#endif

*min_ind = 0;

#if @isfloat@
if (@isnan@(mp)) {
/* nan encountered; it's minimal */
return 0;
}
#endif
#if @iscomplex@
if (@isnan@(mp_im)) {
/* nan encountered; it's minimal */
return 0;
}
#endif

for (i = 1; i < n; i++) {
@incr@;
/*
* Propagate nans, similarly as max() and min()
*/
#if @iscomplex@
/* Lexical order for complex numbers */
if ((mp > ip[0]) || ((ip[0] == mp) && (mp_im > ip[1]))
|| @isnan@(ip[0]) || @isnan@(ip[1])) {
mp = ip[0];
mp_im = ip[1];
*min_ind = i;
if (@isnan@(mp) || @isnan@(mp_im)) {
/* nan encountered, it's minimal */
break;
}
}
#else
if (!@le@(mp, *ip)) { /* negated, for correct nan handling */
mp = *ip;
*min_ind = i;
#if @isfloat@
if (@isnan@(mp)) {
/* nan encountered, it's minimal */
break;
}
#endif
}
#endif
}
return 0;
}

/**end repeat**/

#undef _LESS_THAN_OR_EQUAL

static int
Expand Down Expand Up @@ -2982,6 +3055,63 @@ static int

#define VOID_argmax NULL

static int
OBJECT_argmin(PyObject **ip, intp n, intp *min_ind, PyArrayObject *NPY_UNUSED(aip))
{
intp i;
PyObject *mp = ip[0];

*min_ind = 0;
i = 1;
while (i < n && mp == NULL) {
mp = ip[i];
i++;
}
for (; i < n; i++) {
ip++;
#if defined(NPY_PY3K)
if (*ip != NULL && PyObject_RichCompareBool(mp, *ip, Py_GT) == 1) {
#else
if (*ip != NULL && PyObject_Compare(mp, *ip) > 0) {
#endif
mp = *ip;
*min_ind = i;
}
}
return 0;
}

/**begin repeat
*
* #fname = STRING, UNICODE#
* #type = char, PyArray_UCS4#
*/
static int
@fname@_argmin(@type@ *ip, intp n, intp *min_ind, PyArrayObject *aip)
{
intp i;
int elsize = PyArray_DESCR(aip)->elsize;
@type@ *mp = (@type@ *)PyArray_malloc(elsize);

if (mp==NULL) return 0;
memcpy(mp, ip, elsize);
*min_ind = 0;
for(i=1; i<n; i++) {
ip += elsize;
if (@fname@_compare(mp,ip,aip) > 0) {
memcpy(mp, ip, elsize);
*min_ind=i;
}
}
PyArray_free(mp);
return 0;
}

/**end repeat**/


#define VOID_argmin NULL


/*
*****************************************************************************
Expand Down Expand Up @@ -3626,7 +3756,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
NULL,
(PyArray_FastClipFunc *)NULL,
(PyArray_FastPutmaskFunc *)NULL,
(PyArray_FastTakeFunc *)NULL
(PyArray_FastTakeFunc *)NULL,
(PyArray_ArgFunc*)@from@_argmin
};

/*
Expand Down Expand Up @@ -3717,7 +3848,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
NULL,
(PyArray_FastClipFunc*)@from@_fastclip,
(PyArray_FastPutmaskFunc*)@from@_fastputmask,
(PyArray_FastTakeFunc*)@from@_fasttake
(PyArray_FastTakeFunc*)@from@_fasttake,
(PyArray_ArgFunc*)@from@_argmin
};

/*
Expand Down
111 changes: 94 additions & 17 deletions numpy/core/src/multiarray/calculation.c
Expand Up @@ -156,32 +156,109 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
* ArgMin
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgMin(PyArrayObject *ap, int axis, PyArrayObject *out)
PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
{
PyObject *obj, *new, *ret;
PyArrayObject *ap = NULL, *rp = NULL;
PyArray_ArgFunc* arg_func;
char *ip;
intp *rptr;
intp i, n, m;
int elsize;
NPY_BEGIN_THREADS_DEF;

if (PyArray_ISFLEXIBLE(ap)) {
PyErr_SetString(PyExc_TypeError,
"argmax is unsupported for this type");
if ((ap=(PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
return NULL;
}
else if (PyArray_ISUNSIGNED(ap)) {
obj = PyInt_FromLong((long) -1);
}
else if (PyArray_TYPE(ap) == PyArray_BOOL) {
obj = PyInt_FromLong((long) 1);
/*
* We need to permute the array so that axis is placed at the end.
* And all other dimensions are shifted left.
*/
if (axis != PyArray_NDIM(ap)-1) {
PyArray_Dims newaxes;
intp dims[MAX_DIMS];
int i;

newaxes.ptr = dims;
newaxes.len = PyArray_NDIM(ap);
for (i = 0; i < axis; i++) dims[i] = i;
for (i = axis; i < PyArray_NDIM(ap) - 1; i++) dims[i] = i + 1;
dims[PyArray_NDIM(ap) - 1] = axis;
op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
Py_DECREF(ap);
if (op == NULL) {
return NULL;
}
}
else {
obj = PyInt_FromLong((long) 0);
op = ap;
}
new = PyArray_EnsureAnyArray(PyNumber_Subtract(obj, (PyObject *)ap));
Py_DECREF(obj);
if (new == NULL) {

/* Will get native-byte order contiguous copy. */
ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op,
PyArray_DESCR(op)->type_num, 1, 0);
Py_DECREF(op);
if (ap == NULL) {
return NULL;
}
ret = PyArray_ArgMax((PyArrayObject *)new, axis, out);
Py_DECREF(new);
return ret;
arg_func = PyArray_DESCR(ap)->f->argmin;
if (arg_func == NULL) {
PyErr_SetString(PyExc_TypeError, "data type not ordered");
goto fail;
}
elsize = PyArray_DESCR(ap)->elsize;
m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1];
if (m == 0) {
PyErr_SetString(PyExc_ValueError,
"attempt to get argmax/argmin "\
"of an empty sequence");
goto fail;
}

if (!out) {
rp = (PyArrayObject *)PyArray_New(Py_TYPE(ap), PyArray_NDIM(ap)-1,
PyArray_DIMS(ap), PyArray_INTP,
NULL, NULL, 0, 0,
(PyObject *)ap);
if (rp == NULL) {
goto fail;
}
}
else {
if (PyArray_SIZE(out) !=
PyArray_MultiplyList(PyArray_DIMS(ap), PyArray_NDIM(ap) - 1)) {
PyErr_SetString(PyExc_TypeError,
"invalid shape for output array.");
}
rp = (PyArrayObject *)PyArray_FromArray(out,
PyArray_DescrFromType(PyArray_INTP),
NPY_CARRAY | NPY_UPDATEIFCOPY);
if (rp == NULL) {
goto fail;
}
}

NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap));
n = PyArray_SIZE(ap)/m;
rptr = (intp *)PyArray_DATA(rp);
for (ip = PyArray_DATA(ap), i = 0; i < n; i++, ip += elsize*m) {
arg_func(ip, m, rptr, ap);
rptr += 1;
}
NPY_END_THREADS_DESCR(PyArray_DESCR(ap));

Py_DECREF(ap);
/* Trigger the UPDATEIFCOPY if necessary */
if (out != NULL && out != rp) {
Py_DECREF(rp);
rp = out;
Py_INCREF(rp);
}
return (PyObject *)rp;

fail:
Py_DECREF(ap);
Py_XDECREF(rp);
return NULL;
}

/*NUMPY_API
Expand Down
1 change: 1 addition & 0 deletions numpy/core/src/multiarray/usertypes.c
Expand Up @@ -103,6 +103,7 @@ PyArray_InitArrFuncs(PyArray_ArrFuncs *f)
f->copyswap = NULL;
f->compare = NULL;
f->argmax = NULL;
f->argmin = NULL;
f->dotfunc = NULL;
f->scanfunc = NULL;
f->fromstr = NULL;
Expand Down

0 comments on commit ff72395

Please sign in to comment.