From 05497c6f282762bdf4788a0266a56d4829f5649a Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 6 Nov 2025 23:24:18 +0000 Subject: [PATCH 1/4] feat: Support mixed scalar-analytic expressions --- bigframes/core/agg_expressions.py | 19 ++- bigframes/core/array_value.py | 28 +++- bigframes/core/block_transforms.py | 17 +- bigframes/core/blocks.py | 20 +++ bigframes/core/expression.py | 60 ++++++- bigframes/core/expression_factoring.py | 219 +++++++++++++++++++++++++ 6 files changed, 351 insertions(+), 12 deletions(-) create mode 100644 bigframes/core/expression_factoring.py diff --git a/bigframes/core/agg_expressions.py b/bigframes/core/agg_expressions.py index e65718bdc4..278116c4bc 100644 --- a/bigframes/core/agg_expressions.py +++ b/bigframes/core/agg_expressions.py @@ -19,7 +19,7 @@ import functools import itertools import typing -from typing import Callable, Mapping, TypeVar +from typing import Callable, Mapping, Tuple, TypeVar from bigframes import dtypes from bigframes.core import expression, window_spec @@ -63,6 +63,10 @@ def inputs( ) -> typing.Tuple[expression.Expression, ...]: ... + @property + def children(self) -> Tuple[expression.Expression, ...]: + return self.inputs + @property def free_variables(self) -> typing.Tuple[str, ...]: return tuple( @@ -73,6 +77,10 @@ def free_variables(self) -> typing.Tuple[str, ...]: def is_const(self) -> bool: return all(child.is_const for child in self.inputs) + @functools.cached_property + def is_scalar_expr(self) -> bool: + return False + @abc.abstractmethod def replace_args(self: TExpression, *arg) -> TExpression: ... @@ -176,8 +184,13 @@ def output_type(self) -> dtypes.ExpressionType: def inputs( self, ) -> typing.Tuple[expression.Expression, ...]: + # TODO: Maybe make the window spec itself an expression? return (self.analytic_expr, *self.window.expressions) + @property + def children(self) -> Tuple[expression.Expression, ...]: + return self.inputs + @property def free_variables(self) -> typing.Tuple[str, ...]: return tuple( @@ -188,6 +201,10 @@ def free_variables(self) -> typing.Tuple[str, ...]: def is_const(self) -> bool: return all(child.is_const for child in self.inputs) + @functools.cached_property + def is_scalar_expr(self) -> bool: + return False + def transform_children( self: WindowExpression, t: Callable[[expression.Expression], expression.Expression], diff --git a/bigframes/core/array_value.py b/bigframes/core/array_value.py index e2948cdd05..c3d71d19af 100644 --- a/bigframes/core/array_value.py +++ b/bigframes/core/array_value.py @@ -16,6 +16,7 @@ from dataclasses import dataclass import datetime import functools +import itertools import typing from typing import Iterable, List, Mapping, Optional, Sequence, Tuple @@ -23,12 +24,16 @@ import pandas import pyarrow as pa -from bigframes.core import agg_expressions, bq_data +from bigframes.core import ( + agg_expressions, + bq_data, + expression_factoring, + join_def, + local_data, +) import bigframes.core.expression as ex import bigframes.core.guid import bigframes.core.identifiers as ids -import bigframes.core.join_def as join_def -import bigframes.core.local_data as local_data import bigframes.core.nodes as nodes from bigframes.core.ordering import OrderingExpression import bigframes.core.ordering as orderings @@ -261,6 +266,23 @@ def compute_values(self, assignments: Sequence[ex.Expression]): col_ids, ) + def compute_general_expression(self, assignments: Sequence[ex.Expression]): + named_exprs = [ + expression_factoring.NamedExpression(expr, ids.ColumnId.unique()) + for expr in assignments + ] + # TODO: Push this to rewrite later to go from block expression to planning form + # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions + fragments = tuple( + itertools.chain.from_iterable( + expression_factoring.fragmentize_expression(expr) + for expr in named_exprs + ) + ) + target_ids = tuple(named_expr.name for named_expr in named_exprs) + new_root = expression_factoring.push_into_tree(self.node, fragments, target_ids) + return (ArrayValue(new_root), target_ids) + def project_to_id(self, expression: ex.Expression): array_val, ids = self.compute_values( [expression], diff --git a/bigframes/core/block_transforms.py b/bigframes/core/block_transforms.py index 2ee3dc38b3..1b95fbdfe1 100644 --- a/bigframes/core/block_transforms.py +++ b/bigframes/core/block_transforms.py @@ -399,15 +399,18 @@ def pct_change(block: blocks.Block, periods: int = 1) -> blocks.Block: window_spec = windows.unbound() original_columns = block.value_columns - block, shift_columns = block.multi_apply_window_op( - original_columns, agg_ops.ShiftOp(periods), window_spec=window_spec - ) exprs = [] - for original_col, shifted_col in zip(original_columns, shift_columns): - change_expr = ops.sub_op.as_expr(original_col, shifted_col) - pct_change_expr = ops.div_op.as_expr(change_expr, shifted_col) + for original_col in original_columns: + shift_expr = agg_expressions.WindowExpression( + agg_expressions.UnaryAggregation( + agg_ops.ShiftOp(periods), ex.deref(original_col) + ), + window_spec, + ) + change_expr = ops.sub_op.as_expr(original_col, shift_expr) + pct_change_expr = ops.div_op.as_expr(change_expr, shift_expr) exprs.append(pct_change_expr) - return block.project_exprs(exprs, labels=column_labels, drop=True) + return block.project_block_exprs(exprs, labels=column_labels, drop=True) def rank( diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index e968172c76..e34d4e5bf9 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -1154,6 +1154,26 @@ def project_exprs( index_labels=self._index_labels, ) + # This is a new experimental version of the project_exprs that supports mixing analytic and scalar expressions + def project_block_exprs( + self, + exprs: Sequence[ex.Expression], + labels: Union[Sequence[Label], pd.Index], + drop=False, + ) -> Block: + new_array, _ = self.expr.compute_general_expression(exprs) + if drop: + new_array = new_array.drop_columns(self.value_columns) + + return Block( + new_array, + index_columns=self.index_columns, + column_labels=labels + if drop + else self.column_labels.append(pd.Index(labels)), + index_labels=self._index_labels, + ) + def apply_window_op( self, column: str, diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index 59679f1bc4..22b566f3ac 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -15,11 +15,12 @@ from __future__ import annotations import abc +import collections import dataclasses import functools import itertools import typing -from typing import Callable, Generator, Mapping, TypeVar, Union +from typing import Callable, Dict, Generator, Mapping, Tuple, TypeVar, Union import pandas as pd @@ -43,6 +44,7 @@ def free_var(id: str) -> UnboundVariableExpression: return UnboundVariableExpression(id) +T = TypeVar("T") TExpression = TypeVar("TExpression", bound="Expression") @@ -136,6 +138,11 @@ def is_identity(self) -> bool: """True for identity operation that does not transform input.""" return False + @functools.cached_property + def is_scalar_expr(self) -> bool: + """True if expression represents scalar value or expression over scalar values (no windows or aggregations)""" + return all(expr.is_scalar_expr for expr in self.children) + @abc.abstractmethod def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: ... @@ -150,6 +157,57 @@ def walk(self) -> Generator[Expression, None, None]: for child in self.children: yield from child.children + def unique_nodes( + self: Expression, + ) -> Generator[Expression, None, None]: + """Walks the tree for unique nodes""" + seen = set() + stack: list[Expression] = [self] + while stack: + item = stack.pop() + if item not in seen: + yield item + seen.add(item) + stack.extend(item.children) + + def iter_nodes_topo( + self: Expression, + ) -> Generator[Expression, None, None]: + """Returns nodes in reverse topological order, using Kahn's algorithm.""" + child_to_parents: Dict[Expression, list[Expression]] = collections.defaultdict( + list + ) + out_degree: Dict[Expression, int] = collections.defaultdict(int) + + queue: collections.deque["Expression"] = collections.deque() + for node in list(self.unique_nodes()): + num_children = len(node.children) + out_degree[node] = num_children + if num_children == 0: + queue.append(node) + for child in node.children: + child_to_parents[child].append(node) + + while queue: + item = queue.popleft() + yield item + parents = child_to_parents.get(item, []) + for parent in parents: + out_degree[parent] -= 1 + if out_degree[parent] == 0: + queue.append(parent) + + def reduce_up(self, reduction: Callable[[Expression, Tuple[T, ...]], T]) -> T: + """Apply a bottom-up reduction to the tree.""" + results: dict[Expression, T] = {} + for node in list(self.iter_nodes_topo()): + # child nodes have already been transformed + child_results = tuple(results[child] for child in node.children) + result = reduction(node, child_results) + results[node] = result + + return results[self] + @dataclasses.dataclass(frozen=True) class ScalarConstantExpression(Expression): diff --git a/bigframes/core/expression_factoring.py b/bigframes/core/expression_factoring.py new file mode 100644 index 0000000000..046881fdcf --- /dev/null +++ b/bigframes/core/expression_factoring.py @@ -0,0 +1,219 @@ +import collections +import dataclasses +import functools +from typing import Generic, Hashable, Iterable, Optional, Sequence, Tuple, TypeVar + +from bigframes.core import agg_expressions, expression, identifiers, nodes + +_MAX_INLINE_COMPLEXITY = 10 + + +@dataclasses.dataclass(frozen=True, eq=False) +class NamedExpression: + expr: expression.Expression + name: identifiers.ColumnId + + +@dataclasses.dataclass(frozen=True, eq=False) +class FactoredExpression: + root_expr: expression.Expression + sub_exprs: Tuple[NamedExpression, ...] + + +@dataclasses.dataclass(frozen=True, eq=False) +class ExpressionGroup: + exprs: Tuple[NamedExpression, ...] + + +def fragmentize_expression(root: NamedExpression) -> Sequence[NamedExpression]: + """ + The goal of this functions is to factor out an expression into multiple sub-expressions. + """ + + factored_expr = root.expr.reduce_up(gather_fragments) + root_expr = NamedExpression(factored_expr.root_expr, root.name) + return (root_expr, *factored_expr.sub_exprs) + + +def gather_fragments( + root: expression.Expression, fragmentized_children: Sequence[FactoredExpression] +) -> FactoredExpression: + replacements: list[expression.Expression] = [] + named_exprs = [] # root -> leaf dependency order + for child_result in fragmentized_children: + child_expr = child_result.root_expr + is_leaf = isinstance( + child_expr, (expression.DerefOp, expression.ScalarConstantExpression) + ) + is_window_agg = isinstance( + root, agg_expressions.WindowExpression + ) and isinstance(child_expr, agg_expressions.Aggregation) + do_inline = is_leaf | is_window_agg + if not do_inline: + id = identifiers.ColumnId.unique() + replacements.append( + expression.DerefOp(id) + ) # TODO: Determinism, maybe hash-based? + named_exprs.append(NamedExpression(child_result.root_expr, id)) + named_exprs.extend(child_result.sub_exprs) + else: + replacements.append(child_result.root_expr) + named_exprs.extend(child_result.sub_exprs) + new_root = replace_children(root, replacements) + return FactoredExpression(new_root, tuple(named_exprs)) + + +def replace_children( + root: expression.Expression, new_children: Sequence[expression.Expression] +): + mapping = {root.children[i]: new_children[i] for i in range(len(root.children))} + return root.transform_children(lambda x: mapping.get(x, x)) + + +T = TypeVar("T", bound=Hashable) + + +class DiGraph(Generic[T]): + def __init__(self, edges: Iterable[Tuple[T, T]]): + self._nodes = set() + self._parents = collections.defaultdict(set) + self._children = collections.defaultdict(set) # specifically, unpushed ones + # dict repr of graph + self._sinks = set() + for src, dst in edges: + self._children[src].add(dst) + self._parents[dst].add(src) + self._nodes.add(src) + self._nodes.add(dst) + # sinks have no children + if not self._children[dst]: + self._sinks.add(dst) + self._sinks.discard(src) + + @property + def nodes(self): + return self._nodes + + @property + def sinks(self) -> set[T]: + return self._sinks + + @property + def empty(self): + return len(self._nodes) == 0 + + def parents(self, node: T) -> set[T]: + return self._parents[node] + + def children(self, node: T) -> set[T]: + return self._children[node] + + def remove_node(self, node: T) -> None: + for child in self._children[node]: + self._parents[child].remove(node) + for parent in self._parents[node]: + self._children[parent].remove(node) + if len(self._children[parent]) == 0: + self._sinks.add(parent) + del self._children[node] + del self._parents[node] + self._nodes.remove(node) + self._sinks.discard(node) + + +def push_into_tree( + root: nodes.BigFrameNode, + exprs: Sequence[NamedExpression], + target_ids: Sequence[identifiers.ColumnId], +) -> nodes.BigFrameNode: + curr_root = root + by_id = {expr.name: expr for expr in exprs} + # id -> id + graph = DiGraph( + (expr.name, child_id) + for expr in exprs + for child_id in expr.expr.column_references + if child_id in by_id.keys() + ) + # be careful about merging multi-parent ids + # TODO: Also prevent inlining expensive or non-deterministic + multi_parent_ids = set(id for id in graph.nodes if len(graph.parents(id)) > 2) + scalar_ids = set(expr.name for expr in exprs if expr.expr.is_scalar_expr) + post_ids = (*root.ids, *target_ids) + + def graph_extract_scalar_exprs() -> Sequence[NamedExpression]: + results: dict[identifiers.ColumnId, expression.Expression] = dict() + while ( + True + ): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks + candidate_ids = graph.sinks.intersection(scalar_ids) + bad_inline = set( + id + for id in candidate_ids + if any( + ( + child in multi_parent_ids + and id in results.keys() + and not is_simple(results[id]) + ) + for child in graph.children(id) + ) + ) + candidate_ids = candidate_ids.difference(bad_inline) + if len(candidate_ids) == 0: + break + for id in candidate_ids: + graph.remove_node(id) + new_exprs = { + id: by_id[id].expr.bind_refs(results, allow_partial_bindings=True) + } + results.update(new_exprs) + return tuple( + NamedExpression(expr, id) + for id, expr in results.items() + if id in set([*graph.sinks, *target_ids]) + ) + + def graph_extract_window_expr() -> Optional[ + Tuple[identifiers.ColumnId, agg_expressions.WindowExpression] + ]: + candidate_ids = graph.sinks.difference(scalar_ids) + if not candidate_ids: + return None + else: + id = next(iter(candidate_ids)) + graph.remove_node(id) + result_expr = by_id[id].expr + assert isinstance(result_expr, agg_expressions.WindowExpression) + return (id, result_expr) + + while not graph.empty: + scalar_exprs = graph_extract_scalar_exprs() + if scalar_exprs: + curr_root = nodes.ProjectionNode( + curr_root, tuple((x.expr, x.name) for x in scalar_exprs) + ) + curr_root._validate() + while result := graph_extract_window_expr(): + id, window_expr = result + curr_root = nodes.WindowOpNode( + curr_root, window_expr.analytic_expr, window_expr.window, output_name=id + ) + curr_root._validate() + # TODO: Try to get the ordering right earlier, so can avoid this extra node. + if tuple(curr_root.ids) != post_ids: + curr_root = nodes.SelectionNode( + curr_root, tuple(nodes.AliasedRef.identity(id) for id in post_ids) + ) + curr_root._validate() + return curr_root + + +@functools.cache +def is_simple(expr: expression.Expression) -> bool: + count = 0 + for part in expr.walk(): + count += 1 + if count > _MAX_INLINE_COMPLEXITY: + return False + return True From 616eccf0b5d562865411010adece21ab17e92ed4 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 7 Nov 2025 21:00:49 +0000 Subject: [PATCH 2/4] fix various problems, migrate rank to new api --- bigframes/core/agg_expressions.py | 2 +- bigframes/core/block_transforms.py | 136 ++++++++++--------------- bigframes/core/blocks.py | 1 + bigframes/core/expression_factoring.py | 14 +-- bigframes/core/nodes.py | 6 ++ 5 files changed, 67 insertions(+), 92 deletions(-) diff --git a/bigframes/core/agg_expressions.py b/bigframes/core/agg_expressions.py index 278116c4bc..125e3fef63 100644 --- a/bigframes/core/agg_expressions.py +++ b/bigframes/core/agg_expressions.py @@ -210,7 +210,7 @@ def transform_children( t: Callable[[expression.Expression], expression.Expression], ) -> WindowExpression: return WindowExpression( - self.analytic_expr.transform_children(t), + t(self.analytic_expr), # type: ignore self.window.transform_exprs(t), ) diff --git a/bigframes/core/block_transforms.py b/bigframes/core/block_transforms.py index 1b95fbdfe1..4e7abb1104 100644 --- a/bigframes/core/block_transforms.py +++ b/bigframes/core/block_transforms.py @@ -431,16 +431,11 @@ def rank( columns = columns or tuple(col for col in block.value_columns) labels = [block.col_id_to_label[id] for id in columns] - # Step 1: Calculate row numbers for each row - # Identify null values to be treated according to na_option param - rownum_col_ids = [] - nullity_col_ids = [] + + result_exprs = [] for col in columns: - block, nullity_col_id = block.apply_unary_op( - col, - ops.isnull_op, - ) - nullity_col_ids.append(nullity_col_id) + # Step 1: Calculate row numbers for each row + # Identify null values to be treated according to na_option param window_ordering = ( ordering.OrderingExpression( ex.deref(col), @@ -451,87 +446,66 @@ def rank( ), ) # Count_op ignores nulls, so if na_option is "top" or "bottom", we instead count the nullity columns, where nulls have been mapped to bools - block, rownum_id = block.apply_window_op( - col if na_option == "keep" else nullity_col_id, - agg_ops.dense_rank_op if method == "dense" else agg_ops.count_op, - window_spec=windows.unbound( - grouping_keys=grouping_cols, ordering=window_ordering - ) + target_expr = ( + ex.deref(col) if na_option == "keep" else ops.isnull_op.as_expr(col) + ) + window_op = agg_ops.dense_rank_op if method == "dense" else agg_ops.count_op + window_spec = ( + windows.unbound(grouping_keys=grouping_cols, ordering=window_ordering) if method == "dense" else windows.rows( end=0, ordering=window_ordering, grouping_keys=grouping_cols - ), - skip_reproject_unsafe=(col != columns[-1]), + ) + ) + result_expr: ex.Expression = agg_expressions.WindowExpression( + agg_expressions.UnaryAggregation(window_op, target_expr), window_spec ) if pct: - block, max_id = block.apply_window_op( - rownum_id, agg_ops.max_op, windows.unbound(grouping_keys=grouping_cols) + result_expr = ops.div_op.as_expr( + result_expr, + agg_expressions.WindowExpression( + agg_expressions.UnaryAggregation(agg_ops.max_op, result_expr), + windows.unbound(grouping_keys=grouping_cols), + ), ) - block, rownum_id = block.project_expr(ops.div_op.as_expr(rownum_id, max_id)) - - rownum_col_ids.append(rownum_id) - - # Step 2: Apply aggregate to groups of like input values. - # This step is skipped for method=='first' or 'dense' - if method in ["average", "min", "max"]: - agg_op = { - "average": agg_ops.mean_op, - "min": agg_ops.min_op, - "max": agg_ops.max_op, - }[method] - post_agg_rownum_col_ids = [] - for i in range(len(columns)): - block, result_id = block.apply_window_op( - rownum_col_ids[i], - agg_op, - window_spec=windows.unbound(grouping_keys=(columns[i], *grouping_cols)), - skip_reproject_unsafe=(i < (len(columns) - 1)), + # Step 2: Apply aggregate to groups of like input values. + # This step is skipped for method=='first' or 'dense' + if method in ["average", "min", "max"]: + agg_op = { + "average": agg_ops.mean_op, + "min": agg_ops.min_op, + "max": agg_ops.max_op, + }[method] + result_expr = agg_expressions.WindowExpression( + agg_expressions.UnaryAggregation(agg_op, result_expr), + windows.unbound(grouping_keys=(col, *grouping_cols)), ) - post_agg_rownum_col_ids.append(result_id) - rownum_col_ids = post_agg_rownum_col_ids - - # Pandas masks all values where any grouping column is null - # Note: we use pd.NA instead of float('nan') - if grouping_cols: - predicate = functools.reduce( - ops.and_op.as_expr, - [ops.notnull_op.as_expr(column_id) for column_id in grouping_cols], - ) - block = block.project_exprs( - [ - ops.where_op.as_expr( - ex.deref(col), - predicate, - ex.const(None), - ) - for col in rownum_col_ids - ], - labels=labels, - ) - rownum_col_ids = list(block.value_columns[-len(rownum_col_ids) :]) - - # Step 3: post processing: mask null values and cast to float - if method in ["min", "max", "first", "dense"]: - # Pandas rank always produces Float64, so must cast for aggregation types that produce ints - return ( - block.select_columns(rownum_col_ids) - .multi_apply_unary_op(ops.AsTypeOp(pd.Float64Dtype())) - .with_column_labels(labels) - ) - if na_option == "keep": - # For na_option "keep", null inputs must produce null outputs - exprs = [] - for i in range(len(columns)): - exprs.append( - ops.where_op.as_expr( - ex.const(pd.NA, dtype=pd.Float64Dtype()), - nullity_col_ids[i], - rownum_col_ids[i], - ) + # Pandas masks all values where any grouping column is null + # Note: we use pd.NA instead of float('nan') + if grouping_cols: + predicate = functools.reduce( + ops.and_op.as_expr, + [ops.notnull_op.as_expr(column_id) for column_id in grouping_cols], + ) + result_expr = ops.where_op.as_expr( + result_expr, + predicate, + ex.const(None), ) - return block.project_exprs(exprs, labels=labels, drop=True) - return block.select_columns(rownum_col_ids).with_column_labels(labels) + # Step 3: post processing: mask null values and cast to float + if method in ["min", "max", "first", "dense"]: + # Pandas rank always produces Float64, so must cast for aggregation types that produce ints + result_expr = ops.AsTypeOp(pd.Float64Dtype()).as_expr(result_expr) + elif na_option == "keep": + # For na_option "keep", null inputs must produce null outputs + result_expr = ops.where_op.as_expr( + ex.const(pd.NA, dtype=pd.Float64Dtype()), + ops.isnull_op.as_expr(col), + result_expr, + ) + result_exprs.append(result_expr) + return block.project_block_exprs(result_exprs, labels=labels, drop=True) def dropna( diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index e34d4e5bf9..36a2ad6acf 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -1165,6 +1165,7 @@ def project_block_exprs( if drop: new_array = new_array.drop_columns(self.value_columns) + new_array.node.validate_tree() return Block( new_array, index_columns=self.index_columns, diff --git a/bigframes/core/expression_factoring.py b/bigframes/core/expression_factoring.py index 046881fdcf..2c3c008db1 100644 --- a/bigframes/core/expression_factoring.py +++ b/bigframes/core/expression_factoring.py @@ -135,11 +135,10 @@ def push_into_tree( for child_id in expr.expr.column_references if child_id in by_id.keys() ) - # be careful about merging multi-parent ids # TODO: Also prevent inlining expensive or non-deterministic + # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size multi_parent_ids = set(id for id in graph.nodes if len(graph.parents(id)) > 2) scalar_ids = set(expr.name for expr in exprs if expr.expr.is_scalar_expr) - post_ids = (*root.ids, *target_ids) def graph_extract_scalar_exprs() -> Sequence[NamedExpression]: results: dict[identifiers.ColumnId, expression.Expression] = dict() @@ -168,11 +167,8 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]: id: by_id[id].expr.bind_refs(results, allow_partial_bindings=True) } results.update(new_exprs) - return tuple( - NamedExpression(expr, id) - for id, expr in results.items() - if id in set([*graph.sinks, *target_ids]) - ) + # TODO: We can prune expressions that won't be reused here, + return tuple(NamedExpression(expr, id) for id, expr in results.items()) def graph_extract_window_expr() -> Optional[ Tuple[identifiers.ColumnId, agg_expressions.WindowExpression] @@ -193,19 +189,17 @@ def graph_extract_window_expr() -> Optional[ curr_root = nodes.ProjectionNode( curr_root, tuple((x.expr, x.name) for x in scalar_exprs) ) - curr_root._validate() while result := graph_extract_window_expr(): id, window_expr = result curr_root = nodes.WindowOpNode( curr_root, window_expr.analytic_expr, window_expr.window, output_name=id ) - curr_root._validate() # TODO: Try to get the ordering right earlier, so can avoid this extra node. + post_ids = (*root.ids, *target_ids) if tuple(curr_root.ids) != post_ids: curr_root = nodes.SelectionNode( curr_root, tuple(nodes.AliasedRef.identity(id) for id in post_ids) ) - curr_root._validate() return curr_root diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 553b41a631..a8457d383b 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -1199,6 +1199,7 @@ def _validate(self): for expression, _ in self.assignments: # throws TypeError if invalid _ = ex.bind_schema_fields(expression, self.child.field_by_id).output_type + assert expression.is_scalar_expr # Cannot assign to existing variables - append only! assert all(name not in self.child.schema.names for _, name in self.assignments) @@ -1404,6 +1405,11 @@ def _validate(self): not self.window_spec.is_row_bounded ) or self.expression.op.implicitly_inherits_order assert all(ref in self.child.ids for ref in self.expression.column_references) + assert self.added_field.dtype is not None + for agg_child in self.expression.children: + assert agg_child.is_scalar_expr + for window_expr in self.window_spec.expressions: + assert window_expr.is_scalar_expr @property def non_local(self) -> bool: From 7a1c53a5d92a38efb969c7e829b8300d42f28a6b Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Mon, 10 Nov 2025 19:40:48 +0000 Subject: [PATCH 3/4] make more deterministic --- bigframes/core/expression_factoring.py | 52 +++++++++++++------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/bigframes/core/expression_factoring.py b/bigframes/core/expression_factoring.py index 2c3c008db1..aa04737a7f 100644 --- a/bigframes/core/expression_factoring.py +++ b/bigframes/core/expression_factoring.py @@ -1,6 +1,7 @@ import collections import dataclasses import functools +import itertools from typing import Generic, Hashable, Iterable, Optional, Sequence, Tuple, TypeVar from bigframes.core import agg_expressions, expression, identifiers, nodes @@ -51,9 +52,7 @@ def gather_fragments( do_inline = is_leaf | is_window_agg if not do_inline: id = identifiers.ColumnId.unique() - replacements.append( - expression.DerefOp(id) - ) # TODO: Determinism, maybe hash-based? + replacements.append(expression.DerefOp(id)) named_exprs.append(NamedExpression(child_result.root_expr, id)) named_exprs.extend(child_result.sub_exprs) else: @@ -75,32 +74,31 @@ def replace_children( class DiGraph(Generic[T]): def __init__(self, edges: Iterable[Tuple[T, T]]): - self._nodes = set() self._parents = collections.defaultdict(set) self._children = collections.defaultdict(set) # specifically, unpushed ones - # dict repr of graph - self._sinks = set() + # use dict for stable ordering, which grants determinism + self._sinks: dict[T, None] = dict() for src, dst in edges: self._children[src].add(dst) self._parents[dst].add(src) - self._nodes.add(src) - self._nodes.add(dst) # sinks have no children if not self._children[dst]: - self._sinks.add(dst) - self._sinks.discard(src) + self._sinks[dst] = None + if src in self._sinks: + del self._sinks[src] @property def nodes(self): - return self._nodes + # should be the same set of ids as self._parents + return self._children.keys() @property - def sinks(self) -> set[T]: - return self._sinks + def sinks(self) -> Iterable[T]: + return self._sinks.keys() @property def empty(self): - return len(self._nodes) == 0 + return len(self.nodes) == 0 def parents(self, node: T) -> set[T]: return self._parents[node] @@ -114,11 +112,11 @@ def remove_node(self, node: T) -> None: for parent in self._parents[node]: self._children[parent].remove(node) if len(self._children[parent]) == 0: - self._sinks.add(parent) + self._sinks[parent] = None del self._children[node] del self._parents[node] - self._nodes.remove(node) - self._sinks.discard(node) + if node in self._sinks: + del self._sinks[node] def push_into_tree( @@ -145,11 +143,11 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]: while ( True ): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks - candidate_ids = graph.sinks.intersection(scalar_ids) - bad_inline = set( + candidate_ids = list( id - for id in candidate_ids - if any( + for id in graph.sinks + if (id in scalar_ids) + and not any( ( child in multi_parent_ids and id in results.keys() @@ -158,7 +156,6 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]: for child in graph.children(id) ) ) - candidate_ids = candidate_ids.difference(bad_inline) if len(candidate_ids) == 0: break for id in candidate_ids: @@ -173,17 +170,20 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]: def graph_extract_window_expr() -> Optional[ Tuple[identifiers.ColumnId, agg_expressions.WindowExpression] ]: - candidate_ids = graph.sinks.difference(scalar_ids) - if not candidate_ids: + candidate = list( + itertools.islice((id for id in graph.sinks if id not in scalar_ids), 1) + ) + if not candidate: return None else: - id = next(iter(candidate_ids)) + id = next(iter(candidate)) graph.remove_node(id) result_expr = by_id[id].expr assert isinstance(result_expr, agg_expressions.WindowExpression) return (id, result_expr) while not graph.empty: + pre_size = len(graph.nodes) scalar_exprs = graph_extract_scalar_exprs() if scalar_exprs: curr_root = nodes.ProjectionNode( @@ -194,6 +194,8 @@ def graph_extract_window_expr() -> Optional[ curr_root = nodes.WindowOpNode( curr_root, window_expr.analytic_expr, window_expr.window, output_name=id ) + if len(graph.nodes) >= pre_size: + raise ValueError("graph didn't shrink") # TODO: Try to get the ordering right earlier, so can avoid this extra node. post_ids = (*root.ids, *target_ids) if tuple(curr_root.ids) != post_ids: From e5bef69cee45e7b35d70edd9acf482bfc4ab60de Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Mon, 10 Nov 2025 20:53:39 +0000 Subject: [PATCH 4/4] cleanup dead code --- bigframes/core/expression_factoring.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/bigframes/core/expression_factoring.py b/bigframes/core/expression_factoring.py index aa04737a7f..07d5591bc5 100644 --- a/bigframes/core/expression_factoring.py +++ b/bigframes/core/expression_factoring.py @@ -21,11 +21,6 @@ class FactoredExpression: sub_exprs: Tuple[NamedExpression, ...] -@dataclasses.dataclass(frozen=True, eq=False) -class ExpressionGroup: - exprs: Tuple[NamedExpression, ...] - - def fragmentize_expression(root: NamedExpression) -> Sequence[NamedExpression]: """ The goal of this functions is to factor out an expression into multiple sub-expressions.