Skip to content

Commit

Permalink
Add dask-expr handling to test_sort_topk
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbluca committed Apr 9, 2024
1 parent 76b55e3 commit 876d282
Showing 1 changed file with 35 additions and 13 deletions.
48 changes: 35 additions & 13 deletions tests/integration/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,28 @@ def test_sort_by_old_alias(c, input_table_1, request):
]


def check_sort_topk(df, layer, contains=True):
if dd._dask_expr_enabled():
from dask_expr._reductions import NLargest, NSmallest

if layer == "nsmallest":
assert len(list(df.expr.find_operations(NSmallest))) == (
1 if contains else 0
)
elif layer == "nlargest":
assert len(list(df.expr.find_operations(NLargest))) == (
1 if contains else 0
)
else:
assert False
else:
assert (
any([layer in key for key in df.dask.layers.keys()])
if contains
else all([layer not in key for key in df.dask.layers.keys()])
)


@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_topk(gpu):
c = Context()
Expand All @@ -366,7 +388,7 @@ def test_sort_topk(gpu):
c.create_table("df", dd.from_pandas(df, npartitions=10), gpu=gpu)

df_result = c.sql("""SELECT * FROM df ORDER BY a LIMIT 10""")
assert any(["nsmallest" in key for key in df_result.dask.layers.keys()])
check_sort_topk(df_result, "nsmallest", True)
assert_eq(
df_result,
pd.DataFrame(
Expand All @@ -380,7 +402,7 @@ def test_sort_topk(gpu):
)

df_result = c.sql("""SELECT * FROM df ORDER BY a, b LIMIT 10""")
assert any(["nsmallest" in key for key in df_result.dask.layers.keys()])
check_sort_topk(df_result, "nsmallest", True)
assert_eq(
df_result,
pd.DataFrame({"a": [1.0] * 10, "b": [1] * 10, "c": ["a"] * 10}),
Expand All @@ -390,7 +412,7 @@ def test_sort_topk(gpu):
df_result = c.sql(
"""SELECT * FROM df ORDER BY a DESC NULLS LAST, b DESC NULLS LAST LIMIT 10"""
)
assert any(["nlargest" in key for key in df_result.dask.layers.keys()])
check_sort_topk(df_result, "nlargest", True)
assert_eq(
df_result,
pd.DataFrame({"a": [1.0] * 10, "b": [3] * 10, "c": ["c"] * 10}),
Expand All @@ -400,8 +422,8 @@ def test_sort_topk(gpu):
# String column nlargest/smallest not supported for pandas
df_result = c.sql("""SELECT * FROM df ORDER BY c LIMIT 10""")
if not gpu:
assert all(["nlargest" not in key for key in df_result.dask.layers.keys()])
assert all(["nsmallest" not in key for key in df_result.dask.layers.keys()])
check_sort_topk(df_result, "nsmallest", False)
check_sort_topk(df_result, "nlargest", False)
else:
assert_eq(
df_result,
Expand All @@ -413,24 +435,24 @@ def test_sort_topk(gpu):
df_result = c.sql(
"""SELECT * FROM df ORDER BY a DESC, b DESC NULLS LAST LIMIT 10"""
)
assert all(["nlargest" not in key for key in df_result.dask.layers.keys()])
assert all(["nsmallest" not in key for key in df_result.dask.layers.keys()])
check_sort_topk(df_result, "nlargest", False)
check_sort_topk(df_result, "nsmallest", False)

# Assert optimization isn't applied for mixed asc + desc sort
df_result = c.sql("""SELECT * FROM df ORDER BY a, b DESC NULLS LAST LIMIT 10""")
assert all(["nlargest" not in key for key in df_result.dask.layers.keys()])
assert all(["nsmallest" not in key for key in df_result.dask.layers.keys()])
check_sort_topk(df_result, "nlargest", False)
check_sort_topk(df_result, "nsmallest", False)

# Assert optimization isn't applied when the number of requested elements
# exceed topk-nelem-limit config value
# Default topk-nelem-limit is 1M and 334k*3columns takes it above this limit
df_result = c.sql("""SELECT * FROM df ORDER BY a, b LIMIT 333334""")
assert all(["nlargest" not in key for key in df_result.dask.layers.keys()])
assert all(["nsmallest" not in key for key in df_result.dask.layers.keys()])
check_sort_topk(df_result, "nlargest", False)
check_sort_topk(df_result, "nsmallest", False)

df_result = c.sql(
"""SELECT * FROM df ORDER BY a, b LIMIT 10""",
config_options={"sql.sort.topk-nelem-limit": 29},
)
assert all(["nlargest" not in key for key in df_result.dask.layers.keys()])
assert all(["nsmallest" not in key for key in df_result.dask.layers.keys()])
check_sort_topk(df_result, "nlargest", False)
check_sort_topk(df_result, "nsmallest", False)

0 comments on commit 876d282

Please sign in to comment.