Skip to content

Commit

Permalink
Add output arguments to a few more functions for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
teoliphant committed Aug 10, 2006
1 parent b772c97 commit c6f48c8
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 60 deletions.
2 changes: 1 addition & 1 deletion numpy/core/blasdot/_dotblas.c
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ static PyObject *dotblas_vdot(PyObject *dummy, PyObject *args) {
Py_DECREF(tmp2);
}
if (PyTypeNum_ISCOMPLEX(typenum)) {
op1 = PyArray_Conjugate(ap1);
op1 = PyArray_Conjugate(ap1, NULL);
if (op1==NULL) goto fail;
Py_DECREF(ap1);
ap1 = (PyArrayObject *)op1;
Expand Down
8 changes: 4 additions & 4 deletions numpy/core/defmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,14 @@ def all(self, axis=None, out=None):
def max(self, axis=None, out=None):
return N.ndarray.max(self, axis, out)._align(axis)

def argmax(self, axis=None):
return N.ndarray.argmax(self, axis)._align(axis)
def argmax(self, axis=None, out=None):
return N.ndarray.argmax(self, axis, out)._align(axis)

def min(self, axis=None, out=None):
return N.ndarray.min(self, axis, out)._align(axis)

def argmin(self, axis=None):
return N.ndarray.argmin(self, axis)._align(axis)
def argmin(self, axis=None, out=None):
return N.ndarray.argmin(self, axis, out)._align(axis)

def ptp(self, axis=None, out=None):
return N.ndarray.ptp(self, axis, out)._align(axis)
Expand Down
12 changes: 6 additions & 6 deletions numpy/core/fromnumeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def argsort(a, axis=-1, kind='quicksort'):
return _wrapit(a, 'argsort', axis, kind)
return argsort(axis, kind)

def argmax(a, axis=-1):
"""argmax(a,axis=-1) returns the indices to the maximum value of the
def argmax(a, axis=None):
"""argmax(a,axis=None) returns the indices to the maximum value of the
1-D arrays along the given axis.
"""
try:
Expand All @@ -151,8 +151,8 @@ def argmax(a, axis=-1):
return _wrapit(a, 'argmax', axis)
return argmax(axis)

def argmin(a, axis=-1):
"""argmin(a,axis=-1) returns the indices to the minimum value of the
def argmin(a, axis=None):
"""argmin(a,axis=None) returns the indices to the minimum value of the
1-D arrays along the given axis.
"""
try:
Expand Down Expand Up @@ -250,8 +250,8 @@ def shape(a):
result = asarray(a).shape
return result

def compress(condition, m, axis=-1, out=None):
"""compress(condition, x, axis=-1) = those elements of x corresponding
def compress(condition, m, axis=None, out=None):
"""compress(condition, x, axis=None) = those elements of x corresponding
to those elements of condition that are "true". condition must be the
same size as the given dimension of x."""
try:
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/include/numpy/arrayobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ extern "C" CONFUSE_EMACS
#define NPY_SUCCEED 1

/* Helpful to distinguish what is installed */
#define NPY_VERSION 0x01000001
#define NPY_VERSION 0x01000002

/* Some platforms don't define bool, long long, or long double.
Handle that here.
Expand Down
46 changes: 29 additions & 17 deletions numpy/core/src/arraymethods.c
Original file line number Diff line number Diff line change
Expand Up @@ -156,36 +156,42 @@ array_view(PyArrayObject *self, PyObject *args)
return PyArray_View(self, type, NULL);
}

static char doc_argmax[] = "a.argmax(axis=None)";
static char doc_argmax[] = "a.argmax(axis=None, out=None)";

static PyObject *
array_argmax(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
int axis=MAX_DIMS;
static char *kwlist[] = {"axis", NULL};
PyArrayObject *out=NULL;
static char *kwlist[] = {"axis", "out", NULL};

if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O&", kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O&O&", kwlist,
PyArray_AxisConverter,
&axis))
&axis,
PyArray_OutputConverter,
&out))
return NULL;

return _ARET(PyArray_ArgMax(self, axis));
return _ARET(PyArray_ArgMax(self, axis, out));
}

static char doc_argmin[] = "a.argmin(axis=None)";
static char doc_argmin[] = "a.argmin(axis=None, out=None)";

static PyObject *
array_argmin(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
int axis=MAX_DIMS;
static char *kwlist[] = {"axis", NULL};
PyArrayObject *out=NULL;
static char *kwlist[] = {"axis", "out", NULL};

if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O&", kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O&O&", kwlist,
PyArray_AxisConverter,
&axis))
&axis,
PyArray_OutputConverter,
&out))
return NULL;

return _ARET(PyArray_ArgMin(self, axis));
return _ARET(PyArray_ArgMin(self, axis, out));
}

static char doc_max[] = "a.max(axis=None)";
Expand Down Expand Up @@ -1550,19 +1556,22 @@ array_trace(PyArrayObject *self, PyObject *args, PyObject *kwds)
#undef _CHKTYPENUM


static char doc_clip[] = "a.clip(min=, max=)";
static char doc_clip[] = "a.clip(min=, max=, out=None)";

static PyObject *
array_clip(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
PyObject *min, *max;
static char *kwlist[] = {"min", "max", NULL};
PyArrayObject *out=NULL;
static char *kwlist[] = {"min", "max", "out", NULL};

if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist,
&min, &max))
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O&", kwlist,
&min, &max,
PyArray_OutputConverter,
&out))
return NULL;

return _ARET(PyArray_Clip(self, min, max));
return _ARET(PyArray_Clip(self, min, max, out));
}

static char doc_conj[] = "a.conj()";
Expand All @@ -1573,9 +1582,12 @@ static PyObject *
array_conjugate(PyArrayObject *self, PyObject *args)
{

if (!PyArg_ParseTuple(args, "")) return NULL;
PyArrayObject *out=NULL;
if (!PyArg_ParseTuple(args, "|O&",
PyArray_OutputConverter,
&out)) return NULL;

return PyArray_Conjugate(self);
return PyArray_Conjugate(self, out);
}


Expand Down
81 changes: 58 additions & 23 deletions numpy/core/src/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,9 @@ static PyObject *
PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out)
{
PyObject *f, *ret=NULL, *tmp, *op1, *op2;
if (out && (!PyArray_SAMESHAPE(out, a) ||
!PyArray_EquivTypes(a->descr, out->descr))) {
if (out && (PyArray_SIZE(out) != PyArray_SIZE(a))) {
PyErr_SetString(PyExc_ValueError,
"output array must have the same shape"
"and type");
"invalid output shape");
return NULL;
}
if (PyArray_ISCOMPLEX(a)) {
Expand Down Expand Up @@ -266,7 +264,7 @@ PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out)
if (decimals >= 0) {
if (PyArray_ISINTEGER(a)) {
if (out) {
if (PyArray_CopyInto(out, a) < 0) return NULL;
if (PyArray_CopyAnyInto(out, a) < 0) return NULL;
Py_INCREF(out);
return (PyObject *)out;
}
Expand Down Expand Up @@ -897,7 +895,7 @@ PyArray_Nonzero(PyArrayObject *self)
Clip
*/
static PyObject *
PyArray_Clip(PyArrayObject *self, PyObject *min, PyObject *max)
PyArray_Clip(PyArrayObject *self, PyObject *min, PyObject *max, PyArrayObject *out)
{
PyObject *selector=NULL, *newtup=NULL, *ret=NULL;
PyObject *res1=NULL, *res2=NULL, *res3=NULL;
Expand All @@ -924,7 +922,7 @@ PyArray_Clip(PyArrayObject *self, PyObject *min, PyObject *max)

newtup = Py_BuildValue("(OOO)", (PyObject *)self, min, max);
if (newtup == NULL) {Py_DECREF(selector); return NULL;}
ret = PyArray_Choose((PyAO *)selector, newtup, NULL, NPY_RAISE);
ret = PyArray_Choose((PyAO *)selector, newtup, out, NPY_RAISE);
Py_DECREF(selector);
Py_DECREF(newtup);
return ret;
Expand All @@ -934,14 +932,14 @@ PyArray_Clip(PyArrayObject *self, PyObject *min, PyObject *max)
Conjugate
*/
static PyObject *
PyArray_Conjugate(PyArrayObject *self)
PyArray_Conjugate(PyArrayObject *self, PyArrayObject *out)
{
if (PyArray_ISCOMPLEX(self)) {
PyObject *new;
intp size, i;
/* Make a copy */
new = PyArray_NewCopy(self, -1);
if (new==NULL) return NULL;
new = PyArray_NewCopy(self, -1);
if (new==NULL) return NULL;
size = PyArray_SIZE(new);
if (self->descr->type_num == PyArray_CFLOAT) {
cfloat *dptr = (cfloat *) PyArray_DATA(new);
Expand All @@ -964,11 +962,25 @@ PyArray_Conjugate(PyArrayObject *self)
dptr++;
}
}
if (out) {
if (PyArray_CopyAnyInto(out, (PyArrayObject *)new)<0)
return NULL;
Py_INCREF(out);
Py_DECREF(new);
return (PyObject *)out;
}
return new;
}
else {
Py_INCREF(self);
return (PyObject *) self;
PyArrayObject *ret;
if (out) {
if (PyArray_CopyAnyInto(out, self)< 0)
return NULL;
ret = out;
}
else ret = self;
Py_INCREF(ret);
return (PyObject *)ret;
}
}

Expand Down Expand Up @@ -1836,7 +1848,7 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret,
PyArrayObject *obj;
int flags = NPY_CARRAY | NPY_UPDATEIFCOPY;

if (!PyArray_SAMESHAPE(ap, ret)) {
if (PyArray_SIZE(ret) != PyArray_SIZE(ap)) {
PyErr_SetString(PyExc_TypeError,
"invalid shape for output array.");
ret = NULL;
Expand Down Expand Up @@ -3018,7 +3030,7 @@ PyArray_Correlate(PyObject *op1, PyObject *op2, int mode)
ArgMin
*/
static PyObject *
PyArray_ArgMin(PyArrayObject *ap, int axis)
PyArray_ArgMin(PyArrayObject *ap, int axis, PyArrayObject *out)
{
PyObject *obj, *new, *ret;

Expand All @@ -3039,7 +3051,7 @@ PyArray_ArgMin(PyArrayObject *ap, int axis)
new = PyArray_EnsureAnyArray(PyNumber_Subtract(obj, (PyObject *)ap));
Py_DECREF(obj);
if (new == NULL) return NULL;
ret = PyArray_ArgMax((PyArrayObject *)new, axis);
ret = PyArray_ArgMax((PyArrayObject *)new, axis, out);
Py_DECREF(new);
return ret;
}
Expand Down Expand Up @@ -3117,14 +3129,15 @@ PyArray_Ptp(PyArrayObject *ap, int axis, PyArrayObject *out)
ArgMax
*/
static PyObject *
PyArray_ArgMax(PyArrayObject *op, int axis)
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
{
PyArrayObject *ap=NULL, *rp=NULL;
PyArray_ArgFunc* arg_func;
char *ip;
intp *rptr;
intp i, n, m;
int elsize;
int copyret=0;

NPY_BEGIN_THREADS_DEF

Expand Down Expand Up @@ -3163,13 +3176,6 @@ PyArray_ArgMax(PyArrayObject *op, int axis)
goto fail;
}

rp = (PyArrayObject *)PyArray_New(ap->ob_type, ap->nd-1,
ap->dimensions, PyArray_INTP,
NULL, NULL, 0, 0,
(PyObject *)ap);
if (rp == NULL) goto fail;


elsize = ap->descr->elsize;
m = ap->dimensions[ap->nd-1];
if (m == 0) {
Expand All @@ -3178,6 +3184,28 @@ PyArray_ArgMax(PyArrayObject *op, int axis)
"of an empty sequence??");
goto fail;
}

if (!out) {
rp = (PyArrayObject *)PyArray_New(ap->ob_type, ap->nd-1,
ap->dimensions, PyArray_INTP,
NULL, NULL, 0, 0,
(PyObject *)ap);
if (rp == NULL) goto fail;
}
else {
if (PyArray_SIZE(out) != \
PyArray_MultiplyList(ap->dimensions, ap->nd-1)) {
PyErr_SetString(PyExc_TypeError,
"invalid shape for output array.");
}
rp = (PyArrayObject *)\
PyArray_FromArray(out,
PyArray_DescrFromType(PyArray_INTP),
NPY_CARRAY | NPY_UPDATEIFCOPY);
if (rp == NULL) goto fail;
if (rp != out) copyret = 1;
}

NPY_BEGIN_THREADS_DESCR(ap->descr)
n = PyArray_SIZE(ap)/m;
rptr = (intp *)rp->data;
Expand All @@ -3188,6 +3216,13 @@ PyArray_ArgMax(PyArrayObject *op, int axis)
NPY_END_THREADS_DESCR(ap->descr)

Py_DECREF(ap);
if (copyret) {
PyArrayObject *obj;
obj = (PyArrayObject *)rp->base;
Py_INCREF(obj);
Py_DECREF(rp);
rp = obj;
}
return (PyObject *)rp;

fail:
Expand Down
8 changes: 7 additions & 1 deletion numpy/numarray/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import os, sys, math, operator

from numpy import dot as matrixmultiply, dot, vdot, ravel, concatenate, all,\
allclose, any, argmax, argmin, around, argsort, array_equal, array_equiv,\
allclose, any, around, argsort, array_equal, array_equiv,\
array_str, array_repr, average, CLIP, RAISE, WRAP, clip, concatenate, \
diagonal, e, pi, fromfunction, indices, inner as innerproduct, nonzero, \
outer as outerproduct, kron as kroneckerproduct, lexsort, putmask, rank, \
Expand Down Expand Up @@ -438,3 +438,9 @@ def cumsum(a1, axis=0, out=None, type=None, dim=0):
def cumproduct(a1, axis=0, out=None, type=None, dim=0):
return N.asarray(a1).cumprod(axis,dtype=type,out=out)

def argmax(x, axis=-1):
return N.argmax(x, axis)

def argmin(x, axis=-1):
return N.argmin(x, axis)

Loading

0 comments on commit c6f48c8

Please sign in to comment.