Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add more mixins + tests. #10460

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions numpy/lib/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,34 @@ def _disables_array_ufunc(obj):

def _binary_method(ufunc, name):
"""Implement a forward binary method with a ufunc, e.g., __add__."""

def func(self, other):
if _disables_array_ufunc(other):
return NotImplemented
return ufunc(self, other)

func.__name__ = '__{}__'.format(name)
return func


def _reflected_binary_method(ufunc, name):
"""Implement a reflected binary method with a ufunc, e.g., __radd__."""

def func(self, other):
if _disables_array_ufunc(other):
return NotImplemented
return ufunc(other, self)

func.__name__ = '__r{}__'.format(name)
return func


def _inplace_binary_method(ufunc, name):
"""Implement an in-place binary method with a ufunc, e.g., __iadd__."""

def func(self, other):
return ufunc(self, other, out=(self,))

func.__name__ = '__i{}__'.format(name)
return func

Expand All @@ -54,9 +60,33 @@ def _numeric_methods(ufunc, name):

def _unary_method(ufunc, name):
"""Implement a unary special method with a ufunc."""

def func(self):
return ufunc(self)
func.__name__ = '__{}__'.format(name)

func.__name__ = name
return func


def _reduction_method(ufunc, name):
"""Implement a reduction method with a ufunc."""

def func(self, *args, **kwargs):
return ufunc.reduce(self, *args, **kwargs)

func.__name__ = name

return func


def _accumulation_method(ufunc, name):
"""Implement a reduction method with a ufunc."""

def func(self, *args, **kwargs):
return ufunc.accumulate(self, *args, **kwargs)

func.__name__ = name

return func


Expand All @@ -81,7 +111,9 @@ class NDArrayOperatorsMixin(object):
class that simply wraps a NumPy array and ensures that the result of any
arithmetic operation is also an ``ArrayLike`` object::

class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin,
np.lib.mixins.NDArrayReductionsMixin,
np.lib.mixins.NDArrayAccumulationsMixin):
def __init__(self, value):
self.value = np.asarray(value)

Expand Down Expand Up @@ -179,3 +211,45 @@ def __repr__(self):
__pos__ = _unary_method(um.positive, 'pos')
__abs__ = _unary_method(um.absolute, 'abs')
__invert__ = _unary_method(um.invert, 'invert')


class NDArrayReductionsMixin(object):
"""
Mixin defining all array reduction methods using __array_ufunc__.

This class implements methods for the reductions supported by ``ndarray``,
including ``sum``, ``min``, ``any``, etc.

It is useful for writing classes that do not inherit from `numpy.ndarray`,
but that should support reductions like arrays as described in :ref:`A
Mechanism for Overriding Ufuncs <neps.ufunc-overrides>`.
Copy link
Member

@eric-wieser eric-wieser Jan 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't really accurate - np.sum() and other functions will work just fine on any object defining __array_ufunc__ - all this is doing is adding alias methods that match ndarray.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, I'll change that.

"""
# Sum
sum = _reduction_method(um.add, 'sum')

# Product
prod = _reduction_method(um.multiply, 'prod')

# Min/max
min = _reduction_method(um.minimum, 'min')
max = _reduction_method(um.maximum, 'max')

# Any/all
any = _reduction_method(um.logical_or, 'any')
all = _reduction_method(um.logical_and, 'all')


class NDArrayAccumulationsMixin(object):
"""
Mixin defining all array accumulation methods using __array_ufunc__.

This class implements methods for the accumulations supported by ``ndarray``,
including ``cumsum`` and ``cumprod``.

It is useful for writing classes that do not inherit from `numpy.ndarray`,
but that should support accumulations like arrays as described in :ref:`A
Mechanism for Overriding Ufuncs <neps.ufunc-overrides>`.
"""
# Accumulations here.
cumsum = _accumulation_method(um.add, 'cumsum')
cumprod = _accumulation_method(um.multiply, 'cumprod')
83 changes: 78 additions & 5 deletions numpy/lib/tests/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
import numbers
import operator
import sys
import itertools

import numpy as np
from numpy.testing import (
run_module_suite, assert_, assert_equal, assert_raises
)

)

PY2 = sys.version_info.major < 3


# NOTE: This class should be kept as an exact copy of the example from the
# docstring for NDArrayOperatorsMixin.

class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin,
np.lib.mixins.NDArrayReductionsMixin,
np.lib.mixins.NDArrayAccumulationsMixin):
def __init__(self, value):
self.value = np.asarray(value)

Expand Down Expand Up @@ -76,6 +78,7 @@ def _assert_equal_type_and_value(result, expected, err_msg=None):
getattr(expected.value, 'dtype', None), err_msg=err_msg)


# TODO: Test operator.div on Python 2.
_ALL_BINARY_OPERATORS = [
operator.lt,
operator.le,
Expand All @@ -88,7 +91,6 @@ def _assert_equal_type_and_value(result, expected, err_msg=None):
operator.mul,
operator.truediv,
operator.floordiv,
# TODO: test div on Python 2, only
operator.mod,
divmod,
pow,
Expand All @@ -99,9 +101,22 @@ def _assert_equal_type_and_value(result, expected, err_msg=None):
operator.or_,
]

_ALL_REDUCTIONS = [
'sum',
'prod',
'min',
'max',
'any',
'all',
]

class TestNDArrayOperatorsMixin(object):
_ALL_ACCUMULATIONS = [
'cumsum',
'cumprod',
]


class TestNDArrayOperatorsMixin(object):
def test_array_like_add(self):

def check(result):
Expand Down Expand Up @@ -215,5 +230,63 @@ def test_ufunc_two_outputs(self):
np.frexp(ArrayLike(np.array(2 ** -3))), expected)


class TestNDArrayReductionsMixin(object):
def test_reductions_simple(self):
array = np.array([-1, 0, 1, 2])
array_like = ArrayLike(array)

for reduction in _ALL_REDUCTIONS:
expected = wrap_array_like(getattr(array, reduction)())
actual = getattr(array_like, reduction)()

_assert_equal_type_and_value(actual, expected)

def test_reductions_with_axis(self):
array = np.array([[-1, 0, 1, 2],
[1, 0, 1, 0],
[3, 0, 1, 6]])
array_like = ArrayLike(array)

for reduction, axis in itertools.product(_ALL_REDUCTIONS, (0, 1, (0, 1))):
expected = wrap_array_like(getattr(array, reduction)(axis=axis))
actual = getattr(array_like, reduction)(axis=axis)

_assert_equal_type_and_value(actual, expected)

def test_reductions_with_keepdims(self):
array = np.array([-1, 0, 1, 2])
array_like = ArrayLike(array)

for reduction, keepdims in itertools.product(_ALL_REDUCTIONS, (True, False)):
expected = wrap_array_like(getattr(array, reduction)(keepdims=keepdims))
actual = getattr(array_like, reduction)(keepdims=keepdims)

_assert_equal_type_and_value(actual, expected)


class TestNDArrayAccumulationsMixin(object):
def test_accumulations_simple(self):
array = np.array([-1, 0, 1, 2])
array_like = ArrayLike(array)

for accumulation in _ALL_ACCUMULATIONS:
expected = wrap_array_like(getattr(array, accumulation)())
actual = getattr(array_like, accumulation)()

_assert_equal_type_and_value(actual, expected)

def test_accumulations_with_axis(self):
array = np.array([[-1, 0, 1, 2],
[1, 0, 1, 0],
[3, 0, 1, 6]])
array_like = ArrayLike(array)

for accumulation, axis in itertools.product(_ALL_ACCUMULATIONS, (0, 1)):
expected = wrap_array_like(getattr(array, accumulation)(axis=axis))
actual = getattr(array_like, accumulation)(axis=axis)

_assert_equal_type_and_value(actual, expected)


if __name__ == "__main__":
run_module_suite()