Skip to content

Commit

Permalink
Additional dtypes for labels. (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Nov 14, 2023
1 parent a9a5784 commit c243e93
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
23 changes: 10 additions & 13 deletions numbagg/decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import abc
import itertools
from collections.abc import Iterable
from functools import cache, cached_property
from typing import Any, Callable, TypeVar
Expand Down Expand Up @@ -378,21 +379,17 @@ def __init__(
self.func = func

if signature is None:
signature = [
(numba.float32, numba.int32, numba.float32),
(numba.float32, numba.int64, numba.float32),
(numba.float64, numba.int32, numba.float64),
(numba.float64, numba.int64, numba.float64),
]

values_dtypes: tuple[numba.dtype, ...] = (numba.float32, numba.float64)
labels_dtypes = (numba.int8, numba.int16, numba.int32, numba.int64)
if supports_ints:
signature += [
(numba.int32, numba.int32, numba.int32),
(numba.int32, numba.int64, numba.int32),
(numba.int64, numba.int32, numba.int64),
(numba.int64, numba.int64, numba.int64),
]
values_dtypes += (numba.int32, numba.int64)

signature = [
(value_type, label_type, value_type)
for value_type, label_type in itertools.product(
values_dtypes, labels_dtypes
)
]
for sig in signature:
if not isinstance(sig, tuple):
raise TypeError(f"signatures for ndmoving must be tuples: {signature}")
Expand Down
5 changes: 3 additions & 2 deletions numbagg/test/test_grouped.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,14 @@ def test_additional_dim_equivalence(func, values, labels, dtype):
assert_almost_equal(result, expected)


@pytest.mark.parametrize("labels_type", [np.int8, np.int16, np.int32, np.int64])
@pytest.mark.parametrize("func, _, npfunc", [f for f in FUNCTIONS_CONSTANT])
def test_group_func_axis_1d_labels(func, _, npfunc):
def test_group_func_axis_1d_labels(func, _, npfunc, labels_type):
if npfunc is None:
pytest.skip("No numpy equivalent")

values = np.arange(5.0)
labels = np.arange(5)
labels = np.arange(5, dtype=labels_type)
result = func(values, labels)
assert_almost_equal(result, values)

Expand Down

0 comments on commit c243e93

Please sign in to comment.