Skip to content
Browse files

BUG: ticket #1149, make argmin work for datetime.

Backport of 6fc0737 'ENH: Explicitly coded argmin for timedelta'
  • Loading branch information...
1 parent 74b9f5e commit ff72395969c2c2a6abe39de8d78bfd03c4b0705d @WeatherGod WeatherGod committed with charris Sep 15, 2011
View
7 numpy/core/include/numpy/ndarraytypes.h
@@ -532,6 +532,13 @@ typedef struct {
PyArray_FastClipFunc *fastclip;
PyArray_FastPutmaskFunc *fastputmask;
PyArray_FastTakeFunc *fasttake;
+
+ /*
+ * Function to select smallest
+ * Can be NULL
+ */
+ PyArray_ArgFunc *argmin;
+
} PyArray_ArrFuncs;
/* The item must be reference counted when it is inserted or extracted. */
View
136 numpy/core/src/multiarray/arraytypes.c.src
@@ -2924,6 +2924,79 @@ static int
/**end repeat**/
+/**begin repeat
+ *
+ * #fname = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
+ * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA#
+ * #type = Bool, byte, ubyte, short, ushort, int, uint, long, ulong,
+ * longlong, ulonglong, npy_half, float, double, longdouble,
+ * float, double, longdouble, datetime, timedelta#
+ * #isfloat = 0*11, 1*7, 0*2#
+ * #isnan = nop*11, npy_half_isnan, npy_isnan*6, nop*2#
+ * #le = _LESS_THAN_OR_EQUAL*11, npy_half_le, _LESS_THAN_OR_EQUAL*8#
+ * #iscomplex = 0*15, 1*3, 0*2#
+ * #incr = ip++*15, ip+=2*3, ip++*2#
+ */
+static int
+@fname@_argmin(@type@ *ip, intp n, intp *min_ind, PyArrayObject *NPY_UNUSED(aip))
+{
+ intp i;
+ @type@ mp = *ip;
+#if @iscomplex@
+ @type@ mp_im = ip[1];
+#endif
+
+ *min_ind = 0;
+
+#if @isfloat@
+ if (@isnan@(mp)) {
+ /* nan encountered; it's minimal */
+ return 0;
+ }
+#endif
+#if @iscomplex@
+ if (@isnan@(mp_im)) {
+ /* nan encountered; it's minimal */
+ return 0;
+ }
+#endif
+
+ for (i = 1; i < n; i++) {
+ @incr@;
+ /*
+ * Propagate nans, similarly as max() and min()
+ */
+#if @iscomplex@
+ /* Lexical order for complex numbers */
+ if ((mp > ip[0]) || ((ip[0] == mp) && (mp_im > ip[1]))
+ || @isnan@(ip[0]) || @isnan@(ip[1])) {
+ mp = ip[0];
+ mp_im = ip[1];
+ *min_ind = i;
+ if (@isnan@(mp) || @isnan@(mp_im)) {
+ /* nan encountered, it's minimal */
+ break;
+ }
+ }
+#else
+ if (!@le@(mp, *ip)) { /* negated, for correct nan handling */
+ mp = *ip;
+ *min_ind = i;
+#if @isfloat@
+ if (@isnan@(mp)) {
+ /* nan encountered, it's minimal */
+ break;
+ }
+#endif
+ }
+#endif
+ }
+ return 0;
+}
+
+/**end repeat**/
+
#undef _LESS_THAN_OR_EQUAL
static int
@@ -2982,6 +3055,63 @@ static int
#define VOID_argmax NULL
+static int
+OBJECT_argmin(PyObject **ip, intp n, intp *min_ind, PyArrayObject *NPY_UNUSED(aip))
+{
+ intp i;
+ PyObject *mp = ip[0];
+
+ *min_ind = 0;
+ i = 1;
+ while (i < n && mp == NULL) {
+ mp = ip[i];
+ i++;
+ }
+ for (; i < n; i++) {
+ ip++;
+#if defined(NPY_PY3K)
+ if (*ip != NULL && PyObject_RichCompareBool(mp, *ip, Py_GT) == 1) {
+#else
+ if (*ip != NULL && PyObject_Compare(mp, *ip) > 0) {
+#endif
+ mp = *ip;
+ *min_ind = i;
+ }
+ }
+ return 0;
+}
+
+/**begin repeat
+ *
+ * #fname = STRING, UNICODE#
+ * #type = char, PyArray_UCS4#
+ */
+static int
+@fname@_argmin(@type@ *ip, intp n, intp *min_ind, PyArrayObject *aip)
+{
+ intp i;
+ int elsize = PyArray_DESCR(aip)->elsize;
+ @type@ *mp = (@type@ *)PyArray_malloc(elsize);
+
+ if (mp==NULL) return 0;
+ memcpy(mp, ip, elsize);
+ *min_ind = 0;
+ for(i=1; i<n; i++) {
+ ip += elsize;
+ if (@fname@_compare(mp,ip,aip) > 0) {
+ memcpy(mp, ip, elsize);
+ *min_ind=i;
+ }
+ }
+ PyArray_free(mp);
+ return 0;
+}
+
+/**end repeat**/
+
+
+#define VOID_argmin NULL
+
/*
*****************************************************************************
@@ -3626,7 +3756,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
NULL,
(PyArray_FastClipFunc *)NULL,
(PyArray_FastPutmaskFunc *)NULL,
- (PyArray_FastTakeFunc *)NULL
+ (PyArray_FastTakeFunc *)NULL,
+ (PyArray_ArgFunc*)@from@_argmin
};
/*
@@ -3717,7 +3848,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
NULL,
(PyArray_FastClipFunc*)@from@_fastclip,
(PyArray_FastPutmaskFunc*)@from@_fastputmask,
- (PyArray_FastTakeFunc*)@from@_fasttake
+ (PyArray_FastTakeFunc*)@from@_fasttake,
+ (PyArray_ArgFunc*)@from@_argmin
};
/*
View
111 numpy/core/src/multiarray/calculation.c
@@ -156,32 +156,109 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
* ArgMin
*/
NPY_NO_EXPORT PyObject *
-PyArray_ArgMin(PyArrayObject *ap, int axis, PyArrayObject *out)
+PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
{
- PyObject *obj, *new, *ret;
+ PyArrayObject *ap = NULL, *rp = NULL;
+ PyArray_ArgFunc* arg_func;
+ char *ip;
+ intp *rptr;
+ intp i, n, m;
+ int elsize;
+ NPY_BEGIN_THREADS_DEF;
- if (PyArray_ISFLEXIBLE(ap)) {
- PyErr_SetString(PyExc_TypeError,
- "argmax is unsupported for this type");
+ if ((ap=(PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
return NULL;
}
- else if (PyArray_ISUNSIGNED(ap)) {
- obj = PyInt_FromLong((long) -1);
- }
- else if (PyArray_TYPE(ap) == PyArray_BOOL) {
- obj = PyInt_FromLong((long) 1);
+ /*
+ * We need to permute the array so that axis is placed at the end.
+ * And all other dimensions are shifted left.
+ */
+ if (axis != PyArray_NDIM(ap)-1) {
+ PyArray_Dims newaxes;
+ intp dims[MAX_DIMS];
+ int i;
+
+ newaxes.ptr = dims;
+ newaxes.len = PyArray_NDIM(ap);
+ for (i = 0; i < axis; i++) dims[i] = i;
+ for (i = axis; i < PyArray_NDIM(ap) - 1; i++) dims[i] = i + 1;
+ dims[PyArray_NDIM(ap) - 1] = axis;
+ op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
+ Py_DECREF(ap);
+ if (op == NULL) {
+ return NULL;
+ }
}
else {
- obj = PyInt_FromLong((long) 0);
+ op = ap;
}
- new = PyArray_EnsureAnyArray(PyNumber_Subtract(obj, (PyObject *)ap));
- Py_DECREF(obj);
- if (new == NULL) {
+
+ /* Will get native-byte order contiguous copy. */
+ ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op,
+ PyArray_DESCR(op)->type_num, 1, 0);
+ Py_DECREF(op);
+ if (ap == NULL) {
return NULL;
}
- ret = PyArray_ArgMax((PyArrayObject *)new, axis, out);
- Py_DECREF(new);
- return ret;
+ arg_func = PyArray_DESCR(ap)->f->argmin;
+ if (arg_func == NULL) {
+ PyErr_SetString(PyExc_TypeError, "data type not ordered");
+ goto fail;
+ }
+ elsize = PyArray_DESCR(ap)->elsize;
+ m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1];
+ if (m == 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "attempt to get argmax/argmin "\
+ "of an empty sequence");
+ goto fail;
+ }
+
+ if (!out) {
+ rp = (PyArrayObject *)PyArray_New(Py_TYPE(ap), PyArray_NDIM(ap)-1,
+ PyArray_DIMS(ap), PyArray_INTP,
+ NULL, NULL, 0, 0,
+ (PyObject *)ap);
+ if (rp == NULL) {
+ goto fail;
+ }
+ }
+ else {
+ if (PyArray_SIZE(out) !=
+ PyArray_MultiplyList(PyArray_DIMS(ap), PyArray_NDIM(ap) - 1)) {
+ PyErr_SetString(PyExc_TypeError,
+ "invalid shape for output array.");
+ }
+ rp = (PyArrayObject *)PyArray_FromArray(out,
+ PyArray_DescrFromType(PyArray_INTP),
+ NPY_CARRAY | NPY_UPDATEIFCOPY);
+ if (rp == NULL) {
+ goto fail;
+ }
+ }
+
+ NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap));
+ n = PyArray_SIZE(ap)/m;
+ rptr = (intp *)PyArray_DATA(rp);
+ for (ip = PyArray_DATA(ap), i = 0; i < n; i++, ip += elsize*m) {
+ arg_func(ip, m, rptr, ap);
+ rptr += 1;
+ }
+ NPY_END_THREADS_DESCR(PyArray_DESCR(ap));
+
+ Py_DECREF(ap);
+ /* Trigger the UPDATEIFCOPY if necessary */
+ if (out != NULL && out != rp) {
+ Py_DECREF(rp);
+ rp = out;
+ Py_INCREF(rp);
+ }
+ return (PyObject *)rp;
+
+ fail:
+ Py_DECREF(ap);
+ Py_XDECREF(rp);
+ return NULL;
}
/*NUMPY_API
View
1 numpy/core/src/multiarray/usertypes.c
@@ -103,6 +103,7 @@ PyArray_InitArrFuncs(PyArray_ArrFuncs *f)
f->copyswap = NULL;
f->compare = NULL;
f->argmax = NULL;
+ f->argmin = NULL;
f->dotfunc = NULL;
f->scanfunc = NULL;
f->fromstr = NULL;
View
105 numpy/core/tests/test_multiarray.py
@@ -6,6 +6,9 @@
from numpy.core import *
from numpy.core.multiarray_tests import test_neighborhood_iterator, test_neighborhood_iterator_oob
+# Need to test an object that does not fully implement math interface
+from datetime import timedelta
+
from numpy.compat import asbytes, getexception, strchar
from test_print import in_foreign_locale
@@ -921,6 +924,38 @@ class TestArgmax(TestCase):
([complex(0, 0), complex(0, 2), complex(0, 1)], 1),
([complex(1, 0), complex(0, 2), complex(0, 1)], 0),
([complex(1, 0), complex(0, 2), complex(1, 1)], 2),
+
+ # Fails on 32-bit systems (haven't tested 64-bit) due to y2.038k bug
+ #([np.datetime64('1923-04-14T12:43:12'),
+ # np.datetime64('1994-06-21T14:43:15'),
+ # np.datetime64('2001-10-15T04:10:32'),
+ # np.datetime64('1995-11-25T16:02:16'),
+ # np.datetime64('2005-01-04T03:14:12'),
+ # np.datetime64('2041-12-03T14:05:03')], 5),
+ ([np.datetime64('1935-09-14T04:40:11'),
+ np.datetime64('1949-10-12T12:32:11'),
+ np.datetime64('2010-01-03T05:14:12'),
+ np.datetime64('2015-11-20T12:20:59'),
+ np.datetime64('1932-09-23T10:10:13'),
+ np.datetime64('2014-10-10T03:50:30')], 3),
+ #([np.datetime64('2059-03-14T12:43:12'),
+ # np.datetime64('1996-09-21T14:43:15'),
+ # np.datetime64('2001-10-15T04:10:32'),
+ # np.datetime64('2022-12-25T16:02:16'),
+ # np.datetime64('1963-10-04T03:14:12'),
+ # np.datetime64('2013-05-08T18:15:23')], 0),
+
+ ([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35),
+ timedelta(days=-1, seconds=23)], 0),
+ ([timedelta(days=1, seconds=43), timedelta(days=10, seconds=5),
+ timedelta(days=5, seconds=14)], 1),
+ ([timedelta(days=10, seconds=24), timedelta(days=10, seconds=5),
+ timedelta(days=10, seconds=43)], 2),
+
+ # Can't reduce a "flexible type"
+ #(['a', 'z', 'aa', 'zz'], 3),
+ #(['zz', 'a', 'aa', 'a'], 0),
+ #(['aa', 'z', 'zz', 'a'], 2),
]
def test_all(self):
@@ -938,6 +973,76 @@ def test_combinations(self):
assert_equal(arr[np.argmax(arr)], np.max(arr), err_msg="%r"%arr)
+class TestArgmin(TestCase):
+
+ nan_arr = [
+ ([0, 1, 2, 3, np.nan], 4),
+ ([0, 1, 2, np.nan, 3], 3),
+ ([np.nan, 0, 1, 2, 3], 0),
+ ([np.nan, 0, np.nan, 2, 3], 0),
+ ([0, 1, 2, 3, complex(0,np.nan)], 4),
+ ([0, 1, 2, 3, complex(np.nan,0)], 4),
+ ([0, 1, 2, complex(np.nan,0), 3], 3),
+ ([0, 1, 2, complex(0,np.nan), 3], 3),
+ ([complex(0,np.nan), 0, 1, 2, 3], 0),
+ ([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
+ ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
+ ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
+ ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
+
+ ([complex(0, 0), complex(0, 2), complex(0, 1)], 0),
+ ([complex(1, 0), complex(0, 2), complex(0, 1)], 2),
+ ([complex(1, 0), complex(0, 2), complex(1, 1)], 1),
+
+ # Fails on 32-bit systems (haven't tested 64-bit) due to y2.038k bug
+ #([np.datetime64('1923-04-14T12:43:12'),
+ # np.datetime64('1994-06-21T14:43:15'),
+ # np.datetime64('2001-10-15T04:10:32'),
+ # np.datetime64('1995-11-25T16:02:16'),
+ # np.datetime64('2005-01-04T03:14:12'),
+ # np.datetime64('2041-12-03T14:05:03')], 0),
+ ([np.datetime64('1935-09-14T04:40:11'),
+ np.datetime64('1949-10-12T12:32:11'),
+ np.datetime64('2010-01-03T05:14:12'),
+ np.datetime64('2014-11-20T12:20:59'),
+ np.datetime64('2015-09-23T10:10:13'),
+ np.datetime64('1932-10-10T03:50:30')], 5),
+ #([np.datetime64('2059-03-14T12:43:12'),
+ # np.datetime64('1996-09-21T14:43:15'),
+ # np.datetime64('2001-10-15T04:10:32'),
+ # np.datetime64('2022-12-25T16:02:16'),
+ # np.datetime64('1963-10-04T03:14:12'),
+ # np.datetime64('2013-05-08T18:15:23')], 4),
+
+ ([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35),
+ timedelta(days=-1, seconds=23)], 2),
+ ([timedelta(days=1, seconds=43), timedelta(days=10, seconds=5),
+ timedelta(days=5, seconds=14)], 0),
+ ([timedelta(days=10, seconds=24), timedelta(days=10, seconds=5),
+ timedelta(days=10, seconds=43)], 1),
+
+ # Can't reduce a "flexible type"
+ #(['a', 'z', 'aa', 'zz'], 0),
+ #(['zz', 'a', 'aa', 'a'], 1),
+ #(['aa', 'z', 'zz', 'a'], 3),
+ ]
+
+ def test_all(self):
+ a = np.random.normal(0,1,(4,5,6,7,8))
+ for i in xrange(a.ndim):
+ amin = a.min(i)
+ aargmin = a.argmin(i)
+ axes = range(a.ndim)
+ axes.remove(i)
+ assert_(all(amin == aargmin.choose(*a.transpose(i,*axes))))
+
+ def test_combinations(self):
+ for arr, pos in self.nan_arr:
+ assert_equal(np.argmin(arr), pos, err_msg="%r"%arr)
+ assert_equal(arr[np.argmin(arr)], np.min(arr), err_msg="%r"%arr)
+
+
+
class TestMinMax(TestCase):
def test_scalar(self):
assert_raises(ValueError, np.amax, 1, 1)

0 comments on commit ff72395

Please sign in to comment.
Something went wrong with that request. Please try again.