NumPy zonal stats: return a data array of calculated stats (#685)
* zonal_stats returns a DataArray

* flake8
thuydotm committed Mar 31, 2022
1 parent b02d5f8 commit ea6a465
Showing 2 changed files with 195 additions and 42 deletions.
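In short: `zonal_stats` gains a `return_type` argument so that, for NumPy-backed inputs, the calculated statistics can come back as an `xarray.DataArray` with a leading `stats` dimension instead of a `pandas.DataFrame`. A minimal usage sketch, assuming small NumPy-backed rasters (the inputs below are illustrative, not taken from the test suite):

import numpy as np
import xarray as xr

from xrspatial import zonal_stats

# Illustrative inputs: a zone raster and a values raster of the same shape.
zones = xr.DataArray(np.array([[0., 0., 1., 1.],
                               [0., 0., 1., 1.]]))
values = xr.DataArray(np.arange(8, dtype='float64').reshape(2, 4))

# Unchanged default: a pandas.DataFrame with one row per zone.
df = zonal_stats(zones=zones, values=values)

# New in this commit: one per-pixel layer per statistic, stacked along a
# leading 'stats' dimension.
arr = zonal_stats(zones=zones, values=values, return_type='xarray.DataArray')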
133 changes: 132 additions & 1 deletion xrspatial/tests/test_zonal.py
@@ -11,7 +11,7 @@
from xrspatial import zonal_stats as stats
from xrspatial.zonal import regions

-from .general_checks import create_test_raster, has_cuda_and_cupy
from .general_checks import create_test_raster, general_output_checks, has_cuda_and_cupy


@pytest.fixture
@@ -60,6 +60,40 @@ def result_default_stats():
return expected_result


@pytest.fixture
def result_default_stats_dataarray():
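    # Layout of the expected result (an inference from the default
    # stats_funcs order): axis 0 holds one layer per stat (mean, max, min,
    # sum, std, var, count); the remaining axes match the 3x8 values raster,
    # with NaN where the input raster has no data.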
expected_result = np.array(
[[[0., 0., 1., 1., 2., 2., 2.4, 2.4],
[0., 0., 1., 1., 2., 2., 2.4, 2.4],
[0., 0., 1., 1., 2., np.nan, 2.4, 2.4]],

[[0., 0., 1., 1., 2., 2., 3., 3.],
[0., 0., 1., 1., 2., 2., 3., 3.],
[0., 0., 1., 1., 2., np.nan, 3., 3.]],

[[0., 0., 1., 1., 2., 2., 0., 0.],
[0., 0., 1., 1., 2., 2., 0., 0.],
[0., 0., 1., 1., 2., np.nan, 0., 0.]],

[[0., 0., 6., 6., 8., 8., 12., 12.],
[0., 0., 6., 6., 8., 8., 12., 12.],
[0., 0., 6., 6., 8., np.nan, 12., 12.]],

[[0., 0., 0., 0., 0., 0., 1.2, 1.2],
[0., 0., 0., 0., 0., 0., 1.2, 1.2],
[0., 0., 0., 0., 0., np.nan, 1.2, 1.2]],

[[0., 0., 0., 0., 0., 0., 1.44, 1.44],
[0., 0., 0., 0., 0., 0., 1.44, 1.44],
[0., 0., 0., 0., 0., np.nan, 1.44, 1.44]],

[[5., 5., 6., 6., 4., 4., 5., 5.],
[5., 5., 6., 6., 4., 4., 5., 5.],
[5., 5., 6., 6., 4., np.nan, 5., 5.]]]
)
return expected_result


@pytest.fixture
def result_zone_ids_stats():
zone_ids = [0, 3]
@@ -76,6 +110,41 @@ def result_zone_ids_stats():
return zone_ids, expected_result


@pytest.fixture
def result_zone_ids_stats_dataarray():
zone_ids = [0, 3]
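    # With an explicit zone_ids=[0, 3], pixels of the unselected zones
    # come back as NaN in every stats layer.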
expected_result = np.array(
[[[0., 0., np.nan, np.nan, np.nan, np.nan, 2.4, 2.4],
[0., 0., np.nan, np.nan, np.nan, np.nan, 2.4, 2.4],
[0., 0., np.nan, np.nan, np.nan, np.nan, 2.4, 2.4]],

[[0., 0., np.nan, np.nan, np.nan, np.nan, 3., 3.],
[0., 0., np.nan, np.nan, np.nan, np.nan, 3., 3.],
[0., 0., np.nan, np.nan, np.nan, np.nan, 3., 3.]],

[[0., 0., np.nan, np.nan, np.nan, np.nan, 0., 0.],
[0., 0., np.nan, np.nan, np.nan, np.nan, 0., 0.],
[0., 0., np.nan, np.nan, np.nan, np.nan, 0., 0.]],

[[0., 0., np.nan, np.nan, np.nan, np.nan, 12., 12.],
[0., 0., np.nan, np.nan, np.nan, np.nan, 12., 12.],
[0., 0., np.nan, np.nan, np.nan, np.nan, 12., 12.]],

[[0., 0., np.nan, np.nan, np.nan, np.nan, 1.2, 1.2],
[0., 0., np.nan, np.nan, np.nan, np.nan, 1.2, 1.2],
[0., 0., np.nan, np.nan, np.nan, np.nan, 1.2, 1.2]],

[[0., 0., np.nan, np.nan, np.nan, np.nan, 1.44, 1.44],
[0., 0., np.nan, np.nan, np.nan, np.nan, 1.44, 1.44],
[0., 0., np.nan, np.nan, np.nan, np.nan, 1.44, 1.44]],

[[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.],
[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.],
[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.]]])

return zone_ids, expected_result


def _double_sum(values):
return values.sum() * 2

@@ -96,6 +165,22 @@ def result_custom_stats():
return nodata_values, zone_ids, expected_result


@pytest.fixture
def result_custom_stats_dataarray():
zone_ids = [1, 2]
nodata_values = 0
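    # Axis 0 holds the two custom stats (double_sum, then range); pixels in
    # unselected zones and nodata cells come back as NaN.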
expected_result = np.array(
[[[np.nan, np.nan, 12., 12., 16., 16., np.nan, np.nan],
[np.nan, np.nan, 12., 12., 16., 16., np.nan, np.nan],
[np.nan, np.nan, 12., 12., 16., np.nan, np.nan, np.nan]],

[[np.nan, np.nan, 0., 0., 0., 0., np.nan, np.nan],
[np.nan, np.nan, 0., 0., 0., 0., np.nan, np.nan],
[np.nan, np.nan, 0., 0., 0., np.nan, np.nan, np.nan]]]
)
return nodata_values, zone_ids, expected_result


@pytest.fixture
def result_count_crosstab_2d():
zone_ids = [1, 2, 3]
@@ -174,6 +259,22 @@ def test_default_stats(backend, data_zones, data_values_2d, result_default_stats
check_results(backend, df_result, result_default_stats)


@pytest.mark.parametrize("backend", ['numpy'])
def test_default_stats_dataarray(
backend, data_zones, data_values_2d, result_default_stats_dataarray
):
dataarray_result = stats(
zones=data_zones, values=data_values_2d, return_type='xarray.DataArray'
)
general_output_checks(
data_values_2d,
dataarray_result,
result_default_stats_dataarray,
verify_dtype=False,
verify_attrs=False,
)


@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy'])
def test_zone_ids_stats(backend, data_zones, data_values_2d, result_zone_ids_stats):
if backend == 'cupy' and not has_cuda_and_cupy():
@@ -184,6 +285,19 @@ def test_zone_ids_stats(backend, data_zones, data_values_2d, result_zone_ids_sta
check_results(backend, df_result, expected_result)


@pytest.mark.parametrize("backend", ['numpy'])
def test_zone_ids_stats_dataarray(
backend, data_zones, data_values_2d, result_zone_ids_stats_dataarray
):
zone_ids, expected_result = result_zone_ids_stats_dataarray
dataarray_result = stats(
zones=data_zones, values=data_values_2d, zone_ids=zone_ids, return_type='xarray.DataArray'
)
general_output_checks(
data_values_2d, dataarray_result, expected_result, verify_dtype=False, verify_attrs=False
)


@pytest.mark.parametrize("backend", ['numpy', 'cupy'])
def test_custom_stats(backend, data_zones, data_values_2d, result_custom_stats):
# ---- custom stats (NumPy and CuPy only) ----
@@ -203,6 +317,23 @@ def test_custom_stats(backend, data_zones, data_values_2d, result_custom_stats):
check_results(backend, df_result, expected_result)


@pytest.mark.parametrize("backend", ['numpy'])
def test_custom_stats_dataarray(backend, data_zones, data_values_2d, result_custom_stats_dataarray):
    # ---- custom stats returning an xr.DataArray (NumPy only) ----
custom_stats = {
'double_sum': _double_sum,
'range': _range,
}
nodata_values, zone_ids, expected_result = result_custom_stats_dataarray
dataarray_result = stats(
zones=data_zones, values=data_values_2d, stats_funcs=custom_stats,
zone_ids=zone_ids, nodata_values=nodata_values, return_type='xarray.DataArray'
)
general_output_checks(
data_values_2d, dataarray_result, expected_result, verify_dtype=False, verify_attrs=False
)


@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
def test_count_crosstab_2d(backend, data_zones, data_values_2d, result_count_crosstab_2d):
zone_ids, cat_ids, expected_result = result_count_crosstab_2d
104 changes: 63 additions & 41 deletions xrspatial/zonal.py
@@ -105,7 +105,7 @@ def _sort_and_stride(zones, values, unique_zones):
sorted_zones = sorted_zones[np.isfinite(sorted_zones)]
zone_breaks = _strides(sorted_zones, unique_zones)

-    return values_by_zones, zone_breaks
return sorted_indices, values_by_zones, zone_breaks
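The helper now returns `sorted_indices` as well, so the NumPy code path can map per-zone results back onto pixel positions. A rough sketch of the idea, assuming an argsort-based implementation (the real helper also drops non-finite zone values and computes the breaks via `_strides`):

import numpy as np

def sort_and_stride_sketch(zones, values, unique_zones):
    # Argsort the flattened zones so pixels of one zone become contiguous.
    sorted_indices = np.argsort(zones.ravel(), kind='stable')
    values_by_zones = values.ravel()[sorted_indices]
    sorted_zones = zones.ravel()[sorted_indices]
    # One break per zone: the end of that zone's run in the sorted order.
    zone_breaks = np.searchsorted(sorted_zones, unique_zones, side='right')
    return sorted_indices, values_by_zones, zone_breaks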


def _calc_stats(
@@ -123,8 +123,7 @@
if unique_zones[i] in zone_ids:
zone_values = values_by_zones[start:end]
# filter out non-finite and nodata_values
-            zone_values = zone_values[
-                np.isfinite(zone_values) & (zone_values != nodata_values)]
zone_values = zone_values[np.isfinite(zone_values) & (zone_values != nodata_values)]
if len(zone_values) > 0:
results[i] = func(zone_values)
start = end
@@ -141,13 +140,8 @@ def _single_stats_func(
nodata_values: Union[int, float] = None,
) -> pd.DataFrame:

-    values_by_zones, zone_breaks = _sort_and_stride(
-        zones_block, values_block, unique_zones
-    )
-    results = _calc_stats(
-        values_by_zones, zone_breaks,
-        unique_zones, zone_ids, func, nodata_values
-    )
_, values_by_zones, zone_breaks = _sort_and_stride(zones_block, values_block, unique_zones)
results = _calc_stats(values_by_zones, zone_breaks, unique_zones, zone_ids, func, nodata_values)
return results


@@ -224,19 +218,15 @@ def _stats_dask_numpy(
stats_dict['mean'] = _dask_mean(stats_dict['sum'], stats_dict['count'])
if 'std' in stats_funcs:
stats_dict['std'] = _dask_std(
-            stats_dict['sum_squares'], stats_dict['sum'] ** 2,
-            stats_dict['count']
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
)
if 'var' in stats_funcs:
stats_dict['var'] = _dask_var(
-            stats_dict['sum_squares'], stats_dict['sum'] ** 2,
-            stats_dict['count']
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
)

# generate dask dataframe
-    stats_df = dd.concat(
-        [dd.from_dask_array(s) for s in stats_dict.values()], axis=1
-    )
stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()], axis=1)
# name columns
stats_df.columns = stats_dict.keys()
# select columns
@@ -259,7 +249,8 @@ def _stats_numpy(
zone_ids: List[Union[int, float]],
stats_funcs: Dict,
nodata_values: Union[int, float],
-) -> pd.DataFrame:
return_type: str,
) -> Union[pd.DataFrame, np.ndarray]:

# find ids for all zones
unique_zones = np.unique(zones[np.isfinite(zones)])
@@ -271,23 +262,40 @@
# remove zones that do not exist in `zones` raster
zone_ids = [z for z in zone_ids if z in unique_zones]

-    selected_indexes = [i for i, z in enumerate(unique_zones) if z in zone_ids]
-    values_by_zones, zone_breaks = _sort_and_stride(
-        zones, values, unique_zones
-    )

-    stats_dict = {}
-    stats_dict["zone"] = zone_ids
-    for stats in stats_funcs:
-        func = stats_funcs.get(stats)
-        stats_dict[stats] = _calc_stats(
-            values_by_zones, zone_breaks,
-            unique_zones, zone_ids, func, nodata_values
-        )
-        stats_dict[stats] = stats_dict[stats][selected_indexes]
sorted_indices, values_by_zones, zone_breaks = _sort_and_stride(zones, values, unique_zones)
if return_type == 'pandas.DataFrame':
stats_dict = {}
stats_dict["zone"] = zone_ids
selected_indexes = [i for i, z in enumerate(unique_zones) if z in zone_ids]
for stats in stats_funcs:
func = stats_funcs.get(stats)
stats_dict[stats] = _calc_stats(
values_by_zones, zone_breaks,
unique_zones, zone_ids, func, nodata_values
)
stats_dict[stats] = stats_dict[stats][selected_indexes]
result = pd.DataFrame(stats_dict)

-    stats_df = pd.DataFrame(stats_dict)
-    return stats_df
else:
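        # Per-pixel output: one flat layer per stat, initialized to NaN so
        # pixels outside the selected zones stay NaN.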
result = np.full((len(stats_funcs), values.size), np.nan)
zone_ids_map = {z: i for i, z in enumerate(unique_zones) if z in zone_ids}
stats_id = 0
for stats in stats_funcs:
func = stats_funcs.get(stats)
stats_results = _calc_stats(
values_by_zones, zone_breaks,
unique_zones, zone_ids, func, nodata_values
)
for zone in zone_ids:
iz = zone_ids_map[zone] # position of zone in unique_zones
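                # zone_breaks[iz] marks the end of this zone's run in the
                # sorted order; zs collects the flat indices of its pixels.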
if iz == 0:
zs = sorted_indices[: zone_breaks[iz]]
else:
zs = sorted_indices[zone_breaks[iz-1]: zone_breaks[iz]]
result[stats_id][zs] = stats_results[iz]
stats_id += 1
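        # Restore the raster shape, yielding (num_stats, *values.shape).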
result = result.reshape(len(stats_funcs), *values.shape)
return result


def _stats_cupy(
@@ -391,7 +399,8 @@ def stats(
"count",
],
nodata_values: Union[int, float] = None,
-) -> Union[pd.DataFrame, dd.DataFrame]:
return_type: str = 'pandas.DataFrame',
) -> Union[pd.DataFrame, dd.DataFrame, xr.DataArray]:
"""
Calculate summary statistics for each zone defined by a `zones`
dataset, based on `values` aggregate.
@@ -438,6 +447,11 @@
    Cells with `nodata_values` do not belong to any zone,
    and are thus excluded from the calculation.
    return_type: str, default='pandas.DataFrame'
        Format of the returned data. If `zones` and `values` are NumPy-backed
        xarray DataArrays, allowed values are 'pandas.DataFrame' and
        'xarray.DataArray'. Otherwise, only 'pandas.DataFrame' is supported.

    Returns
    -------
    stats_df : Union[pandas.DataFrame, dask.dataframe.DataFrame, xarray.DataArray]
@@ -568,17 +582,25 @@
stats_funcs_dict = stats_funcs.copy()

mapper = ArrayTypeFunctionMapping(
-        numpy_func=_stats_numpy,
numpy_func=lambda *args: _stats_numpy(*args, return_type=return_type),
dask_func=_stats_dask_numpy,
cupy_func=_stats_cupy,
dask_cupy_func=lambda *args: not_implemented_func(
*args, messages='stats() does not support dask with cupy backed DataArray' # noqa
),
)
-    stats_df = mapper(values)(
-        zones.data, values.data, zone_ids, stats_funcs_dict, nodata_values
result = mapper(values)(
zones.data, values.data, zone_ids, stats_funcs_dict, nodata_values,
)
-    return stats_df

if return_type == 'xarray.DataArray':
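        # Wrap the stacked result in a DataArray, prepending a 'stats'
        # coordinate and carrying over the coords, dims, and attrs of `values`.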
return xr.DataArray(
result,
coords={'stats': list(stats_funcs_dict.keys()), **values.coords},
dims=('stats', *values.dims),
attrs=values.attrs
)
return result
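With this layout, a single statistic can be pulled out of the returned DataArray by label, e.g. `result.sel(stats='mean')` for the default stats_funcs.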


def _find_cats(values, cat_ids, nodata_values):
@@ -680,7 +702,7 @@ def _crosstab_numpy(
for cat in cat_ids:
crosstab_dict[cat] = []

-    values_by_zones, zone_breaks = _sort_and_stride(
_, values_by_zones, zone_breaks = _sort_and_stride(
zones, values, unique_zones
)
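    # Crosstab needs only the sorted values and the zone breaks; the
    # pixel-order indices added in this commit are unused here, hence the
    # leading `_`.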

@@ -731,7 +753,7 @@ def _single_chunk_crosstab(
for cat in cat_ids:
results[cat] = []

-    values_by_zones, zone_breaks = _sort_and_stride(
_, values_by_zones, zone_breaks = _sort_and_stride(
zones_block, values_block, unique_zones
)
