Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions xrspatial/tests/test_zonal.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -56,18 +56,14 @@ def check_results(df_np, df_da, expected_results_dict):
df_np[col], expected_results_dict[col], equal_nan=True
).all()

# dask case
assert isinstance(df_da, dd.DataFrame)
df_da = df_da.compute()
assert isinstance(df_da, pd.DataFrame)

# numpy results equal dask results
# zone column
assert (df_np['zone'] == df_da['zone']).all()
if df_da is not None:
# dask case
assert isinstance(df_da, dd.DataFrame)
df_da = df_da.compute()
assert isinstance(df_da, pd.DataFrame)

assert (df_np.columns == df_da.columns).all()
for col in df_np.columns[1:]:
assert np.isclose(df_np[col], df_da[col], equal_nan=True).all()
# numpy results equal dask results, ignoring their indexes
assert np.array_equal(df_np.values, df_da.values, equal_nan=True)


def test_stats():
Expand All @@ -93,7 +89,27 @@ def test_stats():
df_da = stats(zones=zones_da, values=values_da)
check_results(df_np, df_da, default_stats_results)

# ---- custom stats ----
# expected results
stats_results_zone_0_3 = {
'zone': [0, 3],
'mean': [0, 2.4],
'max': [0, 3],
'min': [0, 0],
'sum': [0, 12],
'std': [0, 1.2],
'var': [0, 1.44],
'count': [5, 5]
}

# numpy case
df_np_zone_0_3 = stats(zones=zones_np, values=values_np, zone_ids=[0, 3])

# dask case
df_da_zone_0_3 = stats(zones=zones_da, values=values_da, zone_ids=[0, 3])

check_results(df_np_zone_0_3, df_da_zone_0_3, stats_results_zone_0_3)

# ---- custom stats (NumPy only) ----
# expected results
custom_stats_results = {
'zone': [1, 2],
Expand All @@ -115,13 +131,10 @@ def _range(values):
# numpy case
df_np = stats(
zones=zones_np, values=values_np, stats_funcs=custom_stats,
zone_ids=[1, 2], nodata_zones=0, nodata_values=0
zone_ids=[1, 2], nodata_values=0
)
# dask case
df_da = stats(
zones=zones_da, values=values_da, stats_funcs=custom_stats,
zone_ids=[1, 2], nodata_zones=0, nodata_values=0
)
df_da = None
check_results(df_np, df_da, custom_stats_results)


Expand Down
Loading