From 609e56edb3232a908117b4aaffd937f2162f84c7 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Sat, 5 Oct 2024 01:16:39 +0000 Subject: [PATCH 1/2] perf: Reduce schema tracking overhead --- bigframes/core/nodes.py | 78 ++++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index e65040686e..8e48422654 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -109,10 +109,10 @@ def roots(self) -> typing.Set[BigFrameNode]: ) return set(roots) - # TODO: For deep trees, this can create a lot of overhead, maybe use zero-copy persistent datastructure? + # TODO: Store some local data lazily for select, aggregate nodes. @property @abc.abstractmethod - def fields(self) -> Tuple[Field, ...]: + def fields(self) -> Iterable[Field]: ... @property @@ -234,8 +234,8 @@ class UnaryNode(BigFrameNode): def child_nodes(self) -> typing.Sequence[BigFrameNode]: return (self.child,) - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: + @property + def fields(self) -> Iterable[Field]: return self.child.fields @property @@ -288,9 +288,9 @@ def explicitly_ordered(self) -> bool: def __hash__(self): return self._node_hash - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: - return tuple(itertools.chain(self.left_child.fields, self.right_child.fields)) + @property + def fields(self) -> Iterable[Field]: + return itertools.chain(self.left_child.fields, self.right_child.fields) @functools.cached_property def variables_introduced(self) -> int: @@ -348,10 +348,10 @@ def explicitly_ordered(self) -> bool: def __hash__(self): return self._node_hash - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: + @property + def fields(self) -> Iterable[Field]: # TODO: Output names should probably be aligned beforehand or be part of concat definition - return tuple( + return ( Field(bfet_ids.ColumnId(f"column_{i}"), field.dtype) for i, field in enumerate(self.children[0].fields) ) @@ -398,8 +398,10 @@ def explicitly_ordered(self) -> bool: return True @functools.cached_property - def fields(self) -> Tuple[Field, ...]: - return (Field(bfet_ids.ColumnId("labels"), self.start.fields[0].dtype),) + def fields(self) -> Iterable[Field]: + return ( + Field(bfet_ids.ColumnId("labels"), next(iter(self.start.fields)).dtype), + ) @functools.cached_property def variables_introduced(self) -> int: @@ -463,11 +465,11 @@ class ReadLocalNode(LeafNode): def __hash__(self): return self._node_hash - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: - return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items) + @property + def fields(self) -> Iterable[Field]: + return (Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items) - @functools.cached_property + @property def variables_introduced(self) -> int: """Defines the number of variables generated by the current node. Used to estimate query planning complexity.""" return len(self.scan_list.items) + 1 @@ -571,9 +573,9 @@ def session(self): def __hash__(self): return self._node_hash - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: - return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items) + @property + def fields(self) -> Iterable[Field]: + return (Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items) @property def relation_ops_created(self) -> int: @@ -645,8 +647,10 @@ def non_local(self) -> bool: return True @property - def fields(self) -> Tuple[Field, ...]: - return (*self.child.fields, Field(self.col_id, bigframes.dtypes.INT_DTYPE)) + def fields(self) -> Iterable[Field]: + return itertools.chain( + self.child.fields, [Field(self.col_id, bigframes.dtypes.INT_DTYPE)] + ) @property def relation_ops_created(self) -> int: @@ -741,9 +745,9 @@ class SelectionNode(UnaryNode): def __hash__(self): return self._node_hash - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: - return tuple( + @property + def fields(self) -> Iterable[Field]: + return ( Field(output, self.child.get_type(input.id)) for input, output in self.input_output_pairs ) @@ -789,14 +793,14 @@ def __post_init__(self): def __hash__(self): return self._node_hash - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: + @property + def fields(self) -> Iterable[Field]: input_types = self.child._dtype_lookup new_fields = ( Field(id, bigframes.dtypes.dtype_for_etype(ex.output_type(input_types))) for ex, id in self.assignments ) - return (*self.child.fields, *new_fields) + return itertools.chain(self.child.fields, new_fields) @property def variables_introduced(self) -> int: @@ -827,8 +831,8 @@ def row_preserving(self) -> bool: def non_local(self) -> bool: return True - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: + @property + def fields(self) -> Iterable[Field]: return (Field(bfet_ids.ColumnId("count"), bigframes.dtypes.INT_DTYPE),) @property @@ -859,8 +863,8 @@ def __hash__(self): def non_local(self) -> bool: return True - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: + @property + def fields(self) -> Iterable[Field]: by_items = ( Field(ref.id, self.child.get_type(ref.id)) for ref in self.by_column_ids ) @@ -873,7 +877,7 @@ def fields(self) -> Tuple[Field, ...]: ) for agg, id in self.aggregations ) - return (*by_items, *agg_items) + return itertools.chain(by_items, agg_items) @property def variables_introduced(self) -> int: @@ -918,8 +922,8 @@ def __hash__(self): def non_local(self) -> bool: return True - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: + @property + def fields(self) -> Iterable[Field]: input_type = self.child.get_type(self.column_name.id) new_item_dtype = self.op.output_type(input_type) return (*self.child.fields, Field(self.output_name, new_item_dtype)) @@ -990,9 +994,9 @@ def row_preserving(self) -> bool: def __hash__(self): return self._node_hash - @functools.cached_property - def fields(self) -> Tuple[Field, ...]: - return tuple( + @property + def fields(self) -> Iterable[Field]: + return ( Field( field.id, bigframes.dtypes.arrow_dtype_to_bigframes_dtype( From 6924530fd760805b00c6f990e049ad2901d1b095 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Mon, 7 Oct 2024 19:39:51 +0000 Subject: [PATCH 2/2] cache new fields --- bigframes/core/nodes.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 8e48422654..1fbee77be5 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -745,9 +745,9 @@ class SelectionNode(UnaryNode): def __hash__(self): return self._node_hash - @property + @functools.cached_property def fields(self) -> Iterable[Field]: - return ( + return tuple( Field(output, self.child.get_type(input.id)) for input, output in self.input_output_pairs ) @@ -793,14 +793,17 @@ def __post_init__(self): def __hash__(self): return self._node_hash - @property - def fields(self) -> Iterable[Field]: + @functools.cached_property + def added_fields(self) -> Tuple[Field, ...]: input_types = self.child._dtype_lookup - new_fields = ( + return tuple( Field(id, bigframes.dtypes.dtype_for_etype(ex.output_type(input_types))) for ex, id in self.assignments ) - return itertools.chain(self.child.fields, new_fields) + + @property + def fields(self) -> Iterable[Field]: + return itertools.chain(self.child.fields, self.added_fields) @property def variables_introduced(self) -> int: @@ -863,7 +866,7 @@ def __hash__(self): def non_local(self) -> bool: return True - @property + @functools.cached_property def fields(self) -> Iterable[Field]: by_items = ( Field(ref.id, self.child.get_type(ref.id)) for ref in self.by_column_ids @@ -877,7 +880,7 @@ def fields(self) -> Iterable[Field]: ) for agg, id in self.aggregations ) - return itertools.chain(by_items, agg_items) + return tuple(itertools.chain(by_items, agg_items)) @property def variables_introduced(self) -> int: @@ -924,9 +927,7 @@ def non_local(self) -> bool: @property def fields(self) -> Iterable[Field]: - input_type = self.child.get_type(self.column_name.id) - new_item_dtype = self.op.output_type(input_type) - return (*self.child.fields, Field(self.output_name, new_item_dtype)) + return itertools.chain(self.child.fields, [self.added_field]) @property def variables_introduced(self) -> int: @@ -937,6 +938,12 @@ def relation_ops_created(self) -> int: # Assume that if not reprojecting, that there is a sequence of window operations sharing the same window return 0 if self.skip_reproject_unsafe else 4 + @functools.cached_property + def added_field(self) -> Field: + input_type = self.child.get_type(self.column_name.id) + new_item_dtype = self.op.output_type(input_type) + return Field(self.output_name, new_item_dtype) + def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: if self.output_name not in used_cols: return self.child