From 2c6d6ea219d73e91a318f50fe6d72a46d06a94d6 Mon Sep 17 00:00:00 2001
From: thuydotm
Date: Wed, 27 Oct 2021 16:45:38 +0700
Subject: [PATCH 1/6] safely removed nodata_zones arg

---
 xrspatial/tests/test_zonal.py |  4 ++--
 xrspatial/zonal.py            | 26 --------------------------
 2 files changed, 2 insertions(+), 28 deletions(-)

diff --git a/xrspatial/tests/test_zonal.py b/xrspatial/tests/test_zonal.py
index 31d6af12..1c9c516d 100755
--- a/xrspatial/tests/test_zonal.py
+++ b/xrspatial/tests/test_zonal.py
@@ -115,12 +115,12 @@ def _range(values):
     # numpy case
     df_np = stats(
         zones=zones_np, values=values_np, stats_funcs=custom_stats,
-        zone_ids=[1, 2], nodata_zones=0, nodata_values=0
+        zone_ids=[1, 2], nodata_values=0
     )
     # dask case
     df_da = stats(
         zones=zones_da, values=values_da, stats_funcs=custom_stats,
-        zone_ids=[1, 2], nodata_zones=0, nodata_values=0
+        zone_ids=[1, 2], nodata_values=0
     )
     check_results(df_np, df_da, custom_stats_results)

diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py
index 53f130b7..26ee7502 100755
--- a/xrspatial/zonal.py
+++ b/xrspatial/zonal.py
@@ -32,14 +32,6 @@ def _stats_count(data):
     )
 
 
-def _to_int(numeric_value):
-    # convert an integer in float type to integer type
-    # if not an integer, return the value itself
-    if float(numeric_value).is_integer():
-        return int(numeric_value)
-    return numeric_value
-
-
 def _zone_cat_data(
     zones,
     values,
@@ -106,7 +98,6 @@ def _stats(
                 raise ValueError(stats)
             stats_dict[stats].append(stats_func(zone_values))
 
-    unique_zones = list(map(_to_int, unique_zones))
     stats_dict["zone"] = unique_zones
     return stats_dict
 
@@ -117,7 +108,6 @@ def _stats_dask(
     values: xr.DataArray,
     zone_ids: List[Union[int, float]],
    stats_funcs: Dict,
-    nodata_zones: Union[int, float],
     nodata_values: Union[int, float],
 ) -> pd.DataFrame:
 
@@ -125,8 +115,6 @@ def _stats_dask(
         # no zone_ids provided, find ids for all zones
         # precompute unique zones
         unique_zones = da.unique(zones.data[da.isfinite(zones.data)]).compute()
-        # do not consider zone with nodata values
-        unique_zones = sorted(list(set(unique_zones) - set([nodata_zones])))
     else:
         unique_zones = np.array(zone_ids)
 
@@ -179,19 +167,15 @@ def _stats_numpy(
     values: xr.DataArray,
     zone_ids: List[Union[int, float]],
     stats_funcs: Dict,
-    nodata_zones: Union[int, float],
     nodata_values: Union[int, float],
 ) -> pd.DataFrame:
 
     if zone_ids is None:
         # no zone_ids provided, find ids for all zones
-        # do not consider zone with nodata values
         unique_zones = np.unique(zones.data[np.isfinite(zones.data)])
-        unique_zones = sorted(list(set(unique_zones) - set([nodata_zones])))
     else:
         unique_zones = zone_ids
 
-    unique_zones = list(map(_to_int, unique_zones))
     unique_zones = np.asarray(unique_zones)
 
     stats_dict = {}
@@ -244,7 +228,6 @@ def stats(
         "var",
         "count",
     ],
-    nodata_zones: Optional[Union[int, float]] = None,
    nodata_values: Union[int, float] = None,
 ) -> Union[pd.DataFrame, dd.DataFrame]:
     """
@@ -283,11 +266,6 @@ def stats(
         callable. Function takes only one argument that is the `values` raster.
         The key becomes the column name in the output DataFrame.
 
-    nodata_zones: int, float, default=None
-        Nodata value in `zones` raster.
-        Cells with `nodata_zones` do not belong to any zone,
-        and thus excluded from calculation.
-
     nodata_values: int, float, default=None
         Nodata value in `values` raster.
Cells with `nodata_values` do not belong to any zone, @@ -404,7 +382,6 @@ def stats( values, zone_ids, stats_funcs_dict, - nodata_zones, nodata_values ) else: @@ -414,7 +391,6 @@ def stats( values, zone_ids, stats_funcs_dict, - nodata_zones, nodata_values ) @@ -424,8 +400,6 @@ def stats( def _crosstab_dict(zones, values, unique_zones, cats, nodata_values, agg): crosstab_dict = {} - - unique_zones = list(map(_to_int, unique_zones)) crosstab_dict["zone"] = unique_zones for i in cats: From b2d10813126f36b6013c1612007fd0386018fb6f Mon Sep 17 00:00:00 2001 From: thuydotm Date: Wed, 10 Nov 2021 19:16:07 +0700 Subject: [PATCH 2/6] dask zonal stats --- xrspatial/tests/test_zonal.py | 26 ++-- xrspatial/zonal.py | 259 ++++++++++++++++++++++++---------- 2 files changed, 196 insertions(+), 89 deletions(-) diff --git a/xrspatial/tests/test_zonal.py b/xrspatial/tests/test_zonal.py index 1c9c516d..9bb2890b 100755 --- a/xrspatial/tests/test_zonal.py +++ b/xrspatial/tests/test_zonal.py @@ -56,18 +56,19 @@ def check_results(df_np, df_da, expected_results_dict): df_np[col], expected_results_dict[col], equal_nan=True ).all() - # dask case - assert isinstance(df_da, dd.DataFrame) - df_da = df_da.compute() - assert isinstance(df_da, pd.DataFrame) + if df_da is not None: + # dask case + assert isinstance(df_da, dd.DataFrame) + df_da = df_da.compute() + assert isinstance(df_da, pd.DataFrame) - # numpy results equal dask results - # zone column - assert (df_np['zone'] == df_da['zone']).all() + # numpy results equal dask results + # zone column + assert (df_np['zone'] == df_da['zone']).all() - assert (df_np.columns == df_da.columns).all() - for col in df_np.columns[1:]: - assert np.isclose(df_np[col], df_da[col], equal_nan=True).all() + assert (df_np.columns == df_da.columns).all() + for col in df_np.columns[1:]: + assert np.isclose(df_np[col], df_da[col], equal_nan=True).all() def test_stats(): @@ -118,10 +119,7 @@ def _range(values): zone_ids=[1, 2], nodata_values=0 ) # dask case - df_da = stats( - zones=zones_da, values=values_da, stats_funcs=custom_stats, - zone_ids=[1, 2], nodata_values=0 - ) + df_da = None check_results(df_np, df_da, custom_stats_results) diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py index 26ee7502..b75dc516 100755 --- a/xrspatial/zonal.py +++ b/xrspatial/zonal.py @@ -1,13 +1,15 @@ from math import sqrt from typing import Optional, Callable, Union, Dict, List -import dask.array as da -import dask.dataframe as dd import numpy as np import pandas as pd import xarray as xr from xarray import DataArray +import dask.array as da +import dask.dataframe as dd +from dask import delayed + from xrspatial.utils import ngjit @@ -32,6 +34,28 @@ def _stats_count(data): ) +_DASK_BLOCK_STATS = dict( + max=lambda z: z.max(), + min=lambda z: z.min(), + sum=lambda z: z.sum(), + count=lambda z: _stats_count(z), + sum_squares=lambda z: (z**2).sum() +) + + +_DASK_STATS = dict( + max=lambda block_maxes: np.nanmax(block_maxes, axis=0), + min=lambda block_mins: np.nanmin(block_mins, axis=0), + sum=lambda block_sums: np.nansum(block_sums, axis=0), + count=lambda block_counts: np.nansum(block_counts, axis=0), + sum_squares=lambda block_sum_squares: np.nansum(block_sum_squares, axis=0), + squared_sum=lambda block_sums: np.nansum(block_sums, axis=0)**2, +) +_dask_mean=lambda sums, counts: sums / counts +_dask_std=lambda sum_squares, squared_sum, n: np.sqrt((sum_squares - squared_sum/n) / n) # noqa +_dask_var=lambda sum_squares, squared_sum, n: (sum_squares - squared_sum/n) / n # noqa + + def _zone_cat_data( zones, 
     values,
@@ -103,92 +127,172 @@ def _stats(
     return stats_dict
 
 
-def _stats_dask(
-    zones: xr.DataArray,
-    values: xr.DataArray,
-    zone_ids: List[Union[int, float]],
-    stats_funcs: Dict,
-    nodata_values: Union[int, float],
+@ngjit
+def _strides(flatten_zones, unique_zones):
+    num_elements = flatten_zones.shape[0]
+    num_zones = len(unique_zones)
+    strides = np.zeros(len(unique_zones), dtype=np.int32)
+
+    count = 0
+    for i in range(num_zones):
+        while (count < num_elements) and (
+                flatten_zones[count] == unique_zones[i]):
+            count += 1
+        strides[i] = count
+
+    return strides
+
+
+@delayed
+def _stats_func_dask_numpy(
+    zones: np.array,
+    values: np.array,
+    unique_zones: np.array,
+    zone_ids: np.array,
+    func: callable,
+    nodata_values: Union[int, float] = None,
+) -> pd.DataFrame:
+
+    sorted_zones = np.sort(zones.flatten())
+    sored_indices = np.argsort(zones.flatten())
+    values_by_zones = values.flatten()[sored_indices]
+
+    # exclude nans from calculation
+    # flatten_zones is already sorted, NaN elements (if any) are at the end
+    # of the array, removing them will not affect data before them
+    sorted_zones = sorted_zones[np.isfinite(sorted_zones)]
+    zone_breaks = _strides(sorted_zones, unique_zones)
+
+    start = 0
+    results = np.zeros(unique_zones.shape) * np.nan
+    for i in range(len(unique_zones)):
+        end = zone_breaks[i]
+        if unique_zones[i] in zone_ids:
+            zone_values = values_by_zones[start: end]
+            # filter out non-finite and nodata_values
+            zone_values = zone_values[
+                np.isfinite(zone_values) & (zone_values != nodata_values)]
+            if len(zone_values) > 0:
+                results[i] = func(zone_values)
+        start = end
+
+    return results
+
+
+def _stats_dask_numpy(
+        zones: da.Array,
+        values: da.Array,
+        zone_ids: List[Union[int, float]],
+        stats_funcs: Dict,
+        nodata_values: Union[int, float],
 ) -> pd.DataFrame:
 
+    # find ids for all zones
+    unique_zones = np.unique(zones[np.isfinite(zones)])
+    # selected zones to do analysis
     if zone_ids is None:
-        # no zone_ids provided, find ids for all zones
-        # precompute unique zones
-        unique_zones = da.unique(zones.data[da.isfinite(zones.data)]).compute()
-    else:
-        unique_zones = np.array(zone_ids)
+        zone_ids = unique_zones
 
-    stats_dict = _stats(
-        zones,
-        values,
-        unique_zones,
-        stats_funcs,
-        nodata_values
-    )
+    # # remove zones that do not exist in `zones` raster
+    # zone_ids = [z for z in zone_ids if z in unique_zones]
 
-    stats_dict = {
-        stats: da.stack(zonal_stats, axis=0)
-        for stats, zonal_stats in stats_dict.items()
-    }
+    zones_blocks = zones.to_delayed().ravel()
+    values_blocks = values.to_delayed().ravel()
+
+    stats_dict = {}
+    stats_dict["zone"] = unique_zones  # zone column
+
+    compute_sum_squares = False
+    compute_sum = False
+    compute_count = False
+
+    if 'mean' in stats_funcs or 'std' in stats_funcs or 'var' in stats_funcs:
+        compute_sum = True
+        compute_count = True
+
+    if 'std' in stats_funcs or 'var' in stats_funcs:
+        compute_sum_squares = True
+
+    basis_stats = [s for s in _DASK_BLOCK_STATS if s in stats_funcs]
+    if compute_count and 'count' not in basis_stats:
+        basis_stats.append('count')
+    if compute_sum and 'sum' not in basis_stats:
+        basis_stats.append('sum')
+    if compute_sum_squares:
+        basis_stats.append('sum_squares')
+
+    for s in basis_stats:
+        if s == 'sum_squares' and not compute_sum_squares:
+            continue
+        stats_func = _DASK_BLOCK_STATS.get(s)
+        if not callable(stats_func):
+            raise ValueError(s)
+        stats_by_block = [
+            da.from_delayed(
+                delayed(_stats_func_dask_numpy)(
+                    z, v, unique_zones, zone_ids, stats_func, nodata_values
+                ), shape=(np.nan,), dtype=np.float64
+            )
+            for z, v in zip(zones_blocks, values_blocks)
+        ]
+        zonal_stats = da.stack(stats_by_block, allow_unknown_chunksizes=True)
+        stats_func_by_block = delayed(_DASK_STATS[s])
+        stats_dict[s] = da.from_delayed(
+            stats_func_by_block(zonal_stats), shape=(np.nan,), dtype=np.float64
+        )
+
+    if 'mean' in stats_funcs:
+        stats_dict['mean'] = _dask_mean(stats_dict['sum'], stats_dict['count'])
+    if 'std' in stats_funcs:
+        stats_dict['std'] = _dask_std(
+            stats_dict['sum_squares'], stats_dict['sum'] ** 2,
+            stats_dict['count']
+        )
+    if 'var' in stats_funcs:
+        stats_dict['var'] = _dask_var(
+            stats_dict['sum_squares'], stats_dict['sum'] ** 2,
+            stats_dict['count']
+        )
 
     # generate dask dataframe
     stats_df = dd.concat(
-        [dd.from_dask_array(stats) for stats in stats_dict.values()], axis=1
+        [dd.from_dask_array(s) for s in stats_dict.values()], axis=1
     )
     # name columns
     stats_df.columns = stats_dict.keys()
     stats_df.set_index("zone")
+    # select columns
+    stats_df = stats_df[["zone"] + list(stats_funcs.keys())]
     return stats_df
 
 
-@ngjit
-def _strides(flatten_zones, zone_ids):
-    num_elements = flatten_zones.shape[0]
-    strides = np.zeros(len(zone_ids), dtype=np.int32)
-
-    zone_count = 0
-    for i in range(num_elements - 1):
-        if (flatten_zones[i] != flatten_zones[i + 1]):
-            if flatten_zones[i] in zone_ids:
-                strides[zone_count] = i
-                zone_count += 1
-
-    # check last elements
-    if flatten_zones[num_elements - 1] != strides[zone_count - 1]:
-        if flatten_zones[num_elements - 1] in zone_ids:
-            strides[zone_count] = num_elements - 1
-
-    return strides
-
-
 def _stats_numpy(
-    zones: xr.DataArray,
-    values: xr.DataArray,
-    zone_ids: List[Union[int, float]],
-    stats_funcs: Dict,
-    nodata_values: Union[int, float],
+        zones: xr.DataArray,
+        values: xr.DataArray,
+        zone_ids: List[Union[int, float]],
+        stats_funcs: Dict,
+        nodata_values: Union[int, float],
 ) -> pd.DataFrame:
-
+    # find ids for all zones
+    unique_zones = np.unique(zones[np.isfinite(zones)])
+    # selected zones to do analysis
     if zone_ids is None:
-        # no zone_ids provided, find ids for all zones
-        unique_zones = np.unique(zones.data[np.isfinite(zones.data)])
-    else:
-        unique_zones = zone_ids
-
-    unique_zones = np.asarray(unique_zones)
+        zone_ids = unique_zones
+    # remove zones that do not exist in `zones` raster
+    zone_ids = [z for z in zone_ids if z in unique_zones]
 
     stats_dict = {}
     # zone column
-    stats_dict["zone"] = unique_zones
+    stats_dict["zone"] = zone_ids
     # stats columns
     for stats in stats_funcs:
         stats_dict[stats] = []
 
-    flatten_zones = zones.data.flatten()
+    flatten_zones = zones.flatten()
     sorted_indices = np.argsort(flatten_zones)
     sorted_zones = flatten_zones[sorted_indices]
-    values_by_zones = values.data.flatten()[sorted_indices]
+    values_by_zones = values.flatten()[sorted_indices]
 
     # exclude nans from calculation
     # flatten_zones is already sorted, NaN elements (if any) are at the end
@@ -198,20 +302,25 @@ def _stats_numpy(
 
     start = 0
     for i in range(len(unique_zones)):
-        end = zone_breaks[i] + 1
-        zone_values = values_by_zones[start:end]
-        # filter out non-finite and nodata_values
-        zone_values = zone_values[
-            np.isfinite(zone_values) & (zone_values != nodata_values)]
-        for stats in stats_funcs:
-            stats_func = stats_funcs.get(stats)
-            if not callable(stats_func):
-                raise ValueError(stats)
-            stats_dict[stats].append(stats_func(zone_values))
+        end = zone_breaks[i]
+        if unique_zones[i] in zone_ids:
+            zone_values = values_by_zones[start:end]
+            # filter out non-finite and nodata_values
+            zone_values = zone_values[
+                np.isfinite(zone_values) & (zone_values != nodata_values)]
+            for stats in stats_funcs:
+                stats_func = stats_funcs.get(stats)
+                if not callable(stats_func):
+                    raise ValueError(stats)
+                if len(zone_values) == 0:
+                    stats_dict[stats].append(np.nan)
+                else:
+                    stats_dict[stats].append(stats_func(zone_values))
         start = end
 
     stats_df = pd.DataFrame(stats_dict)
     stats_df.set_index("zone")
+
     return stats_df
 
 
@@ -378,17 +487,17 @@ def stats(
     if isinstance(values.data, np.ndarray):
         # numpy case
         stats_df = _stats_numpy(
-            zones,
-            values,
+            zones.data,
+            values.data,
             zone_ids,
             stats_funcs_dict,
             nodata_values
         )
     else:
         # dask case
-        stats_df = _stats_dask(
-            zones,
-            values,
+        stats_df = _stats_dask_numpy(
+            zones.data,
+            values.data,
             zone_ids,
             stats_funcs_dict,
             nodata_values

From 46b65050a2c0e01e2211178d5e7ab81ca2304b7b Mon Sep 17 00:00:00 2001
From: thuydotm
Date: Thu, 11 Nov 2021 00:06:52 +0700
Subject: [PATCH 3/6] dask case: support zone_ids

---
 xrspatial/tests/test_zonal.py | 31 +++++++++++++++++++++--------
 xrspatial/zonal.py            | 37 ++++++++++++++++++++++++-----------
 2 files changed, 49 insertions(+), 19 deletions(-)

diff --git a/xrspatial/tests/test_zonal.py b/xrspatial/tests/test_zonal.py
index 9bb2890b..2128325b 100755
--- a/xrspatial/tests/test_zonal.py
+++ b/xrspatial/tests/test_zonal.py
@@ -62,13 +62,8 @@ def check_results(df_np, df_da, expected_results_dict):
         df_da = df_da.compute()
         assert isinstance(df_da, pd.DataFrame)
 
-        # numpy results equal dask results
-        # zone column
-        assert (df_np['zone'] == df_da['zone']).all()
-
-        assert (df_np.columns == df_da.columns).all()
-        for col in df_np.columns[1:]:
-            assert np.isclose(df_np[col], df_da[col], equal_nan=True).all()
+        # numpy results equal dask results, ignoring their indexes
+        assert np.array_equal(df_np.values, df_da.values, equal_nan=True)
 
 
 def test_stats():
@@ -94,7 +89,27 @@ def test_stats():
     df_da = stats(zones=zones_da, values=values_da)
     check_results(df_np, df_da, default_stats_results)
 
-    # ---- custom stats ----
+    # expected results
+    stats_results_zone_0_3 = {
+        'zone': [0, 3],
+        'mean': [0, 2.4],
+        'max': [0, 3],
+        'min': [0, 0],
+        'sum': [0, 12],
+        'std': [0, 1.2],
+        'var': [0, 1.44],
+        'count': [5, 5]
+    }
+
+    # numpy case
+    df_np_zone_0_3 = stats(zones=zones_np, values=values_np, zone_ids=[0, 3])
+
+    # dask case
+    df_da_zone_0_3 = stats(zones=zones_da, values=values_da, zone_ids=[0, 3])
+
+    check_results(df_np_zone_0_3, df_da_zone_0_3, stats_results_zone_0_3)
+
+    # ---- custom stats (NumPy only) ----
     # expected results
     custom_stats_results = {
         'zone': [1, 2],

diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py
index b75dc516..ccbb096b 100755
--- a/xrspatial/zonal.py
+++ b/xrspatial/zonal.py
@@ -51,9 +51,9 @@ def _stats_count(data):
     sum_squares=lambda block_sum_squares: np.nansum(block_sum_squares, axis=0),
     squared_sum=lambda block_sums: np.nansum(block_sums, axis=0)**2,
 )
-_dask_mean=lambda sums, counts: sums / counts
-_dask_std=lambda sum_squares, squared_sum, n: np.sqrt((sum_squares - squared_sum/n) / n)  # noqa
-_dask_var=lambda sum_squares, squared_sum, n: (sum_squares - squared_sum/n) / n  # noqa
+_dask_mean = lambda sums, counts: sums / counts  # noqa
+_dask_std = lambda sum_squares, squared_sum, n: np.sqrt((sum_squares - squared_sum/n) / n)  # noqa
+_dask_var = lambda sum_squares, squared_sum, n: (sum_squares - squared_sum/n) / n  # noqa
 
 
 def _zone_cat_data(
     zones,
@@ -189,12 +189,12 @@ def _stats_dask_numpy(
     # find ids for all zones
     unique_zones = np.unique(zones[np.isfinite(zones)])
-    # selected zones to do analysis
+
+    select_all_zones = False
+    # selected zones to do analysis
     if zone_ids is None:
         zone_ids = unique_zones
-
-    # # remove zones that do not exist in `zones` raster
-    # zone_ids = [z for z in zone_ids if z in unique_zones]
+        select_all_zones = True
 
     zones_blocks = zones.to_delayed().ravel()
     values_blocks = values.to_delayed().ravel()
@@ -221,6 +221,15 @@ def _stats_dask_numpy(
     if compute_sum_squares:
         basis_stats.append('sum_squares')
 
+    dask_dtypes = dict(
+        max=values.dtype,
+        min=values.dtype,
+        sum=values.dtype,
+        count=np.int64,
+        sum_squares=values.dtype,
+        squared_sum=values.dtype,
+    )
+
     for s in basis_stats:
         if s == 'sum_squares' and not compute_sum_squares:
             continue
@@ -231,7 +240,7 @@ def _stats_dask_numpy(
             da.from_delayed(
                 delayed(_stats_func_dask_numpy)(
                     z, v, unique_zones, zone_ids, stats_func, nodata_values
-                ), shape=(np.nan,), dtype=np.float64
+                ), shape=(np.nan,), dtype=dask_dtypes[s]
             )
             for z, v in zip(zones_blocks, values_blocks)
         ]
@@ -260,9 +269,16 @@ def _stats_dask_numpy(
     )
     # name columns
     stats_df.columns = stats_dict.keys()
-    stats_df.set_index("zone")
     # select columns
-    stats_df = stats_df[["zone"] + list(stats_funcs.keys())]
+    stats_df = stats_df[['zone'] + list(stats_funcs.keys())]
+
+    if not select_all_zones:
+        # only return zones specified in `zone_ids`
+        selected_rows = []
+        for index, row in stats_df.iterrows():
+            if row['zone'] in zone_ids:
+                selected_rows.append(stats_df.loc[index])
+        stats_df = dd.concat(selected_rows)
 
     return stats_df
 
@@ -319,7 +335,6 @@ def _stats_numpy(
         start = end
 
     stats_df = pd.DataFrame(stats_dict)
-    stats_df.set_index("zone")
 
     return stats_df
 

From 8039d63b5ad6b6408d11c900765a102baa19c7aa Mon Sep 17 00:00:00 2001
From: thuydotm
Date: Mon, 15 Nov 2021 14:23:58 +0700
Subject: [PATCH 4/6] refactor

---
 xrspatial/zonal.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py
index ccbb096b..9b6b585b 100755
--- a/xrspatial/zonal.py
+++ b/xrspatial/zonal.py
@@ -145,17 +145,17 @@
 
 @delayed
 def _stats_func_dask_numpy(
-    zones: np.array,
-    values: np.array,
+    zones_block: np.array,
+    values_block: np.array,
     unique_zones: np.array,
     zone_ids: np.array,
     func: callable,
     nodata_values: Union[int, float] = None,
 ) -> pd.DataFrame:
 
-    sorted_zones = np.sort(zones.flatten())
-    sored_indices = np.argsort(zones.flatten())
-    values_by_zones = values.flatten()[sored_indices]
+    sorted_zones = np.sort(zones_block.flatten())
+    sored_indices = np.argsort(zones_block.flatten())
+    values_by_zones = values_block.flatten()[sored_indices]
 
     # exclude nans from calculation
     # flatten_zones is already sorted, NaN elements (if any) are at the end
@@ -180,11 +180,11 @@ def _stats_func_dask_numpy(
 
 def _stats_dask_numpy(
-        zones: da.Array,
-        values: da.Array,
-        zone_ids: List[Union[int, float]],
-        stats_funcs: Dict,
-        nodata_values: Union[int, float],
+    zones: da.Array,
+    values: da.Array,
+    zone_ids: List[Union[int, float]],
+    stats_funcs: Dict,
+    nodata_values: Union[int, float],
 ) -> pd.DataFrame:
 
     # find ids for all zones

From 5cd5fe42706ebdb7de68ae833d0d1181280ff62a Mon Sep 17 00:00:00 2001
From: thuydotm
Date: Mon, 15 Nov 2021 14:53:38 +0700
Subject: [PATCH 5/6] update docs

---
 xrspatial/zonal.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py
index 9b6b585b..fb685aaa 100755
--- a/xrspatial/zonal.py
+++ b/xrspatial/zonal.py
@@ -234,8 +234,6 @@ def _stats_dask_numpy(
         if s == 'sum_squares' and not compute_sum_squares:
             continue
         stats_func = _DASK_BLOCK_STATS.get(s)
-        if not callable(stats_func):
-            raise ValueError(s)
         stats_by_block = [
             da.from_delayed(
                 delayed(_stats_func_dask_numpy)(
@@ -355,14 +353,14 @@ def stats(
     nodata_values: Union[int, float] = None,
 ) -> Union[pd.DataFrame, dd.DataFrame]:
     """
-    Calculate summary statistics for each zone defined by a zone
-    dataset, based on values aggregate.
+    Calculate summary statistics for each zone defined by a `zones`
+    dataset, based on `values` aggregate.
 
-    A single output value is computed for every zone in the input zone
+    A single output value is computed for every zone in the input `zones`
     dataset.
 
     This function currently supports numpy backed, and dask with numpy backed
-    xarray DataArray.
+    xarray DataArrays.
 
     Parameters
     ----------
@@ -383,12 +381,15 @@ def stats(
         all zones will be used.
 
     stats_funcs : dict, or list of strings, default=['mean', 'max', 'min',
-        'sum', 'std', 'var', 'count'])
+        'sum', 'std', 'var', 'count']
         The statistics to calculate for each zone. If a list, possible
         choices are subsets of the default options.
         In the dictionary case, all of its values must be
         callable. Function takes only one argument that is the `values` raster.
         The key becomes the column name in the output DataFrame.
+        Note that if `zones` and `values` are dask-backed DataArrays,
+        `stats_funcs` must be provided as a list that is a subset of the
+        default supported stats.
 
     nodata_values: int, float, default=None
         Nodata value in `values` raster.
@@ -484,6 +485,14 @@ def stats(
     ):
         raise ValueError("`values` must be an array of integers or floats.")
 
+    # validate stats_funcs
+    if isinstance(values.data, da.Array) and not isinstance(stats_funcs, list):
+        raise ValueError(
+            "Got dask-backed DataArray as `values` aggregate. "
+            "`stats_funcs` must be a list that is a subset of the "
+            "default supported stats "
+            "`[\'mean\', \'max\', \'min\', \'sum\', \'std\', \'var\', \'count\']`"  # noqa
+        )
+
     if isinstance(stats_funcs, list):
         # create a dict of stats
         stats_funcs_dict = {}

From 632c54a7253386b85f39adaec3f0d1135d2c2d62 Mon Sep 17 00:00:00 2001
From: thuydotm
Date: Tue, 16 Nov 2021 12:08:52 +0700
Subject: [PATCH 6/6] clean code

---
 xrspatial/zonal.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py
index fb685aaa..f64f0f27 100755
--- a/xrspatial/zonal.py
+++ b/xrspatial/zonal.py
@@ -149,13 +149,13 @@ def _stats_func_dask_numpy(
     values_block: np.array,
     unique_zones: np.array,
     zone_ids: np.array,
-    func: callable,
+    func: Callable,
     nodata_values: Union[int, float] = None,
 ) -> pd.DataFrame:
 
     sorted_zones = np.sort(zones_block.flatten())
-    sored_indices = np.argsort(zones_block.flatten())
-    values_by_zones = values_block.flatten()[sored_indices]
+    sorted_indices = np.argsort(zones_block.flatten())
+    values_by_zones = values_block.flatten()[sorted_indices]
 
     # exclude nans from calculation
     # flatten_zones is already sorted, NaN elements (if any) are at the end
@@ -164,11 +164,11 @@ def _stats_func_dask_numpy(
     zone_breaks = _strides(sorted_zones, unique_zones)
 
     start = 0
-    results = np.zeros(unique_zones.shape) * np.nan
+    results = np.full(unique_zones.shape, np.nan)
     for i in range(len(unique_zones)):
         end = zone_breaks[i]
         if unique_zones[i] in zone_ids:
-            zone_values = values_by_zones[start: end]
+            zone_values = values_by_zones[start:end]
             # filter out non-finite and nodata_values
             zone_values = zone_values[
                 np.isfinite(zone_values) & (zone_values != nodata_values)]
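
A note on the math the dask path in PATCH 2 relies on: `_dask_std` and `_dask_var` recover per-zone variance from the block-combinable partials sum, count, and sum of squares via Var(x) = (Σx² − (Σx)²/n) / n, which is why each block only ships those partials instead of raw values. A quick stand-alone NumPy check of that identity (the sample array is illustrative, not from the test suite):

```python
import numpy as np

x = np.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
n = x.size
sum_x, sum_sq = x.sum(), (x ** 2).sum()

# population variance from combinable partials, as in _dask_var
var = (sum_sq - sum_x ** 2 / n) / n
assert np.isclose(var, x.var())           # 4.0, matches numpy's ddof=0 variance
assert np.isclose(np.sqrt(var), x.std())  # 2.0, as in _dask_std
```

Because sums, counts, and sums of squares add across blocks, the per-block results can be reduced with `np.nansum` before this formula is applied, whereas a per-block `std` could not be combined.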
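Taken together, the six patches replace the old `_stats_dask` with a block-wise `_stats_dask_numpy` that computes partial statistics per chunk and combines them lazily. A minimal usage sketch of the public API as it stands after this series — the synthetic rasters and chunk sizes below are illustrative assumptions, not fixtures from the PR's tests:

```python
import numpy as np
import dask.array as da
import xarray as xr
from xrspatial.zonal import stats

# Synthetic rasters: three zones (0, 1, 2) over a 6x6 grid.
zones_np = xr.DataArray(np.repeat(np.arange(3), 12).reshape(6, 6).astype(float))
values_np = xr.DataArray(np.arange(36, dtype=float).reshape(6, 6))

# NumPy-backed: stats_funcs may be a list or a dict of custom callables.
df_np = stats(zones=zones_np, values=values_np,
              stats_funcs=['mean', 'max', 'count'], zone_ids=[0, 2])

# Dask-backed: per PATCH 5, stats_funcs must be a list drawn from the
# default stats, and the result is a dask DataFrame to compute explicitly.
zones_da = xr.DataArray(da.from_array(zones_np.data, chunks=(3, 3)))
values_da = xr.DataArray(da.from_array(values_np.data, chunks=(3, 3)))
df_da = stats(zones=zones_da, values=values_da,
              stats_funcs=['mean', 'max', 'count'], zone_ids=[0, 2])

print(df_np)
print(df_da.compute())  # should agree with df_np, up to row index
```

This mirrors what `check_results` asserts in the updated tests: the NumPy and dask code paths produce the same values, ignoring their indexes.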