diff --git a/bigframes/core/rewrite/pruning.py b/bigframes/core/rewrite/pruning.py index 8a07f0b87e..41664e1c47 100644 --- a/bigframes/core/rewrite/pruning.py +++ b/bigframes/core/rewrite/pruning.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses import functools +import itertools import typing from bigframes.core import identifiers, nodes @@ -51,17 +52,9 @@ def prune_columns(node: nodes.BigFrameNode): if isinstance(node, nodes.SelectionNode): result = prune_selection_child(node) elif isinstance(node, nodes.ResultNode): - result = node.replace_child( - prune_node( - node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1]) - ) - ) + result = node.replace_child(prune_node(node.child, node.consumed_ids)) elif isinstance(node, nodes.AggregateNode): - result = node.replace_child( - prune_node( - node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1]) - ) - ) + result = node.replace_child(prune_node(node.child, node.consumed_ids)) elif isinstance(node, nodes.InNode): result = dataclasses.replace( node, @@ -149,9 +142,13 @@ def prune_node( if not (set(node.ids) - ids): return node else: + # If no child ids are needed, probably a size op or numbering op above, keep a single column always + ids_to_keep = tuple(id for id in node.ids if id in ids) or tuple( + itertools.islice(node.ids, 0, 1) + ) return nodes.SelectionNode( node, - tuple(nodes.AliasedRef.identity(id) for id in node.ids if id in ids), + tuple(nodes.AliasedRef.identity(id) for id in ids_to_keep), ) diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py index 9b4efe8cbe..a25c167f71 100644 --- a/tests/system/small/engines/test_aggregation.py +++ b/tests/system/small/engines/test_aggregation.py @@ -48,6 +48,25 @@ def apply_agg_to_all_valid( return new_arr +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +def test_engines_aggregate_post_filter_size( + scalars_array_value: array_value.ArrayValue, + engine, +): + w_offsets, offsets_id = ( + scalars_array_value.select_columns(("bool_col", "string_col")) + .filter(expression.deref("bool_col")) + .promote_offsets() + ) + plan = ( + w_offsets.select_columns((offsets_id, "bool_col", "string_col")) + .row_count() + .node + ) + + assert_equivalence_execution(plan, REFERENCE_ENGINE, engine) + + @pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_aggregate_size( scalars_array_value: array_value.ArrayValue, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql index d20a635e3d..b48dcfa01b 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql @@ -1,13 +1,27 @@ WITH `bfcte_0` AS ( SELECT - `bool_col` AS `bfcol_0` + `bool_col` AS `bfcol_0`, + `bytes_col` AS `bfcol_1`, + `date_col` AS `bfcol_2`, + `datetime_col` AS `bfcol_3`, + `geography_col` AS `bfcol_4`, + `int64_col` AS `bfcol_5`, + `int64_too` AS `bfcol_6`, + `numeric_col` AS `bfcol_7`, + `float64_col` AS `bfcol_8`, + `rowindex` AS `bfcol_9`, + `rowindex_2` AS `bfcol_10`, + `string_col` AS `bfcol_11`, + `time_col` AS `bfcol_12`, + `timestamp_col` AS `bfcol_13`, + `duration_col` AS `bfcol_14` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT *, - ROW_NUMBER() OVER () AS `bfcol_1` + ROW_NUMBER() OVER () AS `bfcol_32` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `row_number` + `bfcol_32` AS `row_number` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql index 19ae8aa3fd..8cda9a3d80 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql @@ -1,12 +1,26 @@ WITH `bfcte_0` AS ( SELECT - `rowindex` AS `bfcol_0` + `bool_col` AS `bfcol_0`, + `bytes_col` AS `bfcol_1`, + `date_col` AS `bfcol_2`, + `datetime_col` AS `bfcol_3`, + `geography_col` AS `bfcol_4`, + `int64_col` AS `bfcol_5`, + `int64_too` AS `bfcol_6`, + `numeric_col` AS `bfcol_7`, + `float64_col` AS `bfcol_8`, + `rowindex` AS `bfcol_9`, + `rowindex_2` AS `bfcol_10`, + `string_col` AS `bfcol_11`, + `time_col` AS `bfcol_12`, + `timestamp_col` AS `bfcol_13`, + `duration_col` AS `bfcol_14` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT - COUNT(1) AS `bfcol_2` + COUNT(1) AS `bfcol_32` FROM `bfcte_0` ) SELECT - `bfcol_2` AS `size` + `bfcol_32` AS `size` FROM `bfcte_1` \ No newline at end of file