Skip to content

Commit 456432f

Browse files
authored
updated test hotspots gpu (#692)
1 parent 379e75e commit 456432f

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

xrspatial/focal.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -728,8 +728,6 @@ def hotspots(raster, kernel):
728728
Dimensions without coordinates: dim_0, dim_1
729729
"""
730730

731-
# TODO: edit unit of output raster to percent (%)
732-
733731
# validate raster
734732
if not isinstance(raster, DataArray):
735733
raise TypeError("`raster` must be instance of DataArray")

xrspatial/tests/test_focal.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,15 @@ def test_hotspot_gpu(data_hotspots):
452452
data, kernel, expected_result = data_hotspots
453453
cupy_agg = create_test_raster(data, backend='cupy')
454454
cupy_hotspots = hotspots(cupy_agg, kernel)
455-
general_output_checks(cupy_agg, cupy_hotspots, expected_result)
455+
general_output_checks(cupy_agg, cupy_hotspots, expected_result, verify_attrs=False)
456+
# validate attrs
457+
assert cupy_hotspots.shape == cupy_agg.shape
458+
assert cupy_hotspots.dims == cupy_agg.dims
459+
for coord in cupy_agg.coords:
460+
np.testing.assert_allclose(
461+
cupy_hotspots[coord].data, cupy_agg[coord].data, equal_nan=True
462+
)
463+
assert cupy_hotspots.attrs['unit'] == '%'
456464

457465
# dask + cupy case not implemented
458466
dask_cupy_agg = create_test_raster(data, backend='dask+cupy')

0 commit comments

Comments
 (0)