Skip to content

Commit

Permalink
Merge pull request #3618 from okuta/fix-diff-dtype-histogram
Browse files Browse the repository at this point in the history
Add different dtype input test in histogram
  • Loading branch information
takagi committed Sep 17, 2020
2 parents bc8245c + 306ed6f commit 2351f05
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/cupy_tests/statistics_tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,40 @@ def test_histogram_range_float(self, xp, dtype):
assert int(h.sum()) == 10
return h, b

@testing.for_all_dtypes_combination(['dtype_a', 'dtype_b'],
no_bool=True, no_complex=True)
@testing.numpy_cupy_array_equal()
def test_histogram_with_bins(self, xp, dtype_a, dtype_b):
x = testing.shaped_arange((10,), xp, dtype_a)
bins = testing.shaped_arange((4,), xp, dtype_b)

if xp is numpy:
return xp.histogram(x, bins)[0]

# xp is cupy, first ensure we really use CUB
cub_func = 'cupy._statistics.histogram.cub.device_histogram'
with testing.AssertFunctionIsCalled(cub_func):
xp.histogram(x, bins)
# ...then perform the actual computation
return xp.histogram(x, bins)[0]

@testing.for_all_dtypes_combination(['dtype_a', 'dtype_b'],
no_bool=True, no_complex=True)
@testing.numpy_cupy_array_equal()
def test_histogram_with_bins2(self, xp, dtype_a, dtype_b):
x = testing.shaped_arange((10,), xp, dtype_a)
bins = testing.shaped_arange((4,), xp, dtype_b)

if xp is numpy:
return xp.histogram(x, bins)[1]

# xp is cupy, first ensure we really use CUB
cub_func = 'cupy._statistics.histogram.cub.device_histogram'
with testing.AssertFunctionIsCalled(cub_func):
xp.histogram(x, bins)
# ...then perform the actual computation
return xp.histogram(x, bins)[1]


@testing.gpu
@testing.parameterize(*testing.product(
Expand Down

0 comments on commit 2351f05

Please sign in to comment.