Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Less skipped tests #88

Merged
merged 1 commit into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 37 additions & 11 deletions numpy_groupies/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,58 @@ def _impl_name(impl):
return impl.__name__.rsplit("aggregate_", 1)[1].rsplit("_", 1)[-1]


_not_implemented_by_impl_name = {
"numpy": ("cumprod", "cummax", "cummin"),
"purepy": ("cumsum", "cumprod", "cummax", "cummin", "sumofsquares"),
"numba": ("array", "list", "sort"),
"pandas": ("array", "list", "sort", "sumofsquares", "nansumofsquares"),
"ufunc": "NO_CHECK",
_implemented_by_impl_name = {
"numpy": {"not_implemented": ("cumprod", "cummax", "cummin")},
"purepy": {
"not_implemented": ("cumsum", "cumprod", "cummax", "cummin", "sumofsquares")
},
"numba": {"not_implemented": ("array", "list", "sort")},
"pandas": {
"not_implemented": ("array", "list", "sort", "sumofsquares", "nansumofsquares")
},
"ufunc": {
"implemented": (
"sum",
"prod",
"min",
"max",
"len",
"all",
"any",
"anynan",
"allnan",
)
},
}


def _is_implemented(impl_name, funcname):
func_description = _implemented_by_impl_name[impl_name]
not_implemented = func_description.get("not_implemented", [])
implemented = func_description.get("implemented", [])
if impl_name == "purepy" and funcname.startswith("nan"):
return False
if funcname in not_implemented:
return False
if implemented and funcname not in implemented:
return False
return True


def _wrap_notimplemented_skip(impl, name=None):
"""Some implementations lack some functionality. That's ok, let's skip that instead of raising errors."""

@wraps(impl)
def try_skip(*args, **kwargs):
try:
return impl(*args, **kwargs)
except NotImplementedError as e:
except NotImplementedError:
impl_name = impl.__module__.split("_")[-1]
func = kwargs.pop("func", None)
if callable(func):
func = func.__name__
not_implemented_ok = _not_implemented_by_impl_name.get(impl_name, [])
if not_implemented_ok == "NO_CHECK" or func in not_implemented_ok:
if not _is_implemented(impl_name, func):
pytest.skip("Functionality not implemented")
else:
raise e

if name:
try_skip.__name__ = name
Expand Down
22 changes: 14 additions & 8 deletions numpy_groupies/tests/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from . import (
_impl_name,
_wrap_notimplemented_skip,
_is_implemented,
aggregate_numba,
aggregate_numpy,
aggregate_numpy_ufunc,
Expand All @@ -34,7 +34,7 @@ def aggregate_cmp(request, seed=100):
test_pair = request.param
if test_pair == "np/py":
# Some functions in purepy are not implemented
func_ref = _wrap_notimplemented_skip(aggregate_purepy.aggregate)
func_ref = aggregate_purepy.aggregate
func = aggregate_numpy.aggregate
group_cnt = 100
else:
Expand All @@ -52,7 +52,7 @@ def aggregate_cmp(request, seed=100):
if not impl:
pytest.skip("Implementation not available")
name = _impl_name(impl)
func = _wrap_notimplemented_skip(impl.aggregate, "aggregate_" + name)
func = impl.aggregate

rnd = np.random.RandomState(seed=seed)

Expand All @@ -77,10 +77,12 @@ def _deselect_purepy(aggregate_cmp, *args, **kwargs):
return aggregate_cmp.endswith("py")


def _deselect_purepy_nanfuncs(aggregate_cmp, func, *args, **kwargs):
# purepy implementation does not handle nan values correctly
# This is a won't fix and should be deselected instead of skipped
return "nan" in getattr(func, "__name__", func) and aggregate_cmp.endswith("py")
def _deselect_not_implemented(aggregate_cmp, func, fill_value, *args, **kwargs):
impl_name = (
"purepy" if aggregate_cmp.endswith("py") else aggregate_cmp.split("/", 1)[0]
)
funcname = getattr(func, "__name__", func)
return not _is_implemented(impl_name, funcname)


def func_arbitrary(iterator):
Expand All @@ -97,8 +99,9 @@ def func_preserve_order(iterator):
return tmp


@pytest.mark.filterwarnings("ignore::FutureWarning") # handled pandas deprecation
@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed")
@pytest.mark.deselect_if(func=_deselect_purepy_nanfuncs)
@pytest.mark.deselect_if(func=_deselect_not_implemented)
@pytest.mark.parametrize("fill_value", [0, 1, np.nan])
@pytest.mark.parametrize("func", func_list, ids=lambda x: getattr(x, "__name__", x))
def test_cmp(aggregate_cmp, func, fill_value, decimal=10):
Expand All @@ -123,6 +126,9 @@ def test_cmp(aggregate_cmp, func, fill_value, decimal=10):
pytest.skip(
"pure python version uses lists and does not raise ValueErrors when inserting nan into integers"
)
elif aggregate_cmp.test_pair.startswith("pandas"):
pytest.skip("pandas now raises ValueError on all-nan arrays")

else:
raise
if isinstance(ref, np.ndarray):
Expand Down
25 changes: 20 additions & 5 deletions numpy_groupies/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
import numpy as np
import pytest

from . import _impl_name, _implementations, _wrap_notimplemented_skip, func_list
from . import (
_impl_name,
_implementations,
_wrap_notimplemented_skip,
func_list,
_is_implemented,
)


@pytest.fixture(params=_implementations, ids=_impl_name)
Expand All @@ -30,12 +36,21 @@ def _deselect_purepy_and_pandas(aggregate_all, *args, **kwargs):
return aggregate_all.__name__.endswith(("pandas", "purepy"))


def _deselect_purepy_and_invalid_axis(aggregate_all, size, axis, *args, **kwargs):
def _deselect_purepy_and_invalid_axis(aggregate_all, func, size, axis):
impl_name = aggregate_all.__name__.split("_")[-1]
if impl_name == "purepy":
# purepy does not handle axis parameter
return True
if axis >= len(size):
return True
if aggregate_all.__name__.endswith("purepy"):
# purepy does not handle axis parameter
if not _is_implemented(impl_name, func):
return True
return False


def _deselect_not_implemented(aggregate_all, func, *args, **kwargs):
impl_name = aggregate_all.__name__.split("_")[-1]
return not _is_implemented(impl_name, func)


def test_preserve_missing(aggregate_all):
Expand Down Expand Up @@ -533,7 +548,7 @@ def test_argreduction_negative_fill_value(aggregate_all):
np.testing.assert_equal(actual, expected)


@pytest.mark.deselect_if(func=_deselect_purepy)
@pytest.mark.deselect_if(func=_deselect_not_implemented)
@pytest.mark.parametrize(
"nan_inds", (None, tuple([[1, 4, 5], Ellipsis]), tuple((1, (0, 1, 2, 3))))
)
Expand Down