117 changes: 42 additions & 75 deletions bigframes/core/nodes.py
@@ -51,8 +51,8 @@ class Field:
dtype: bigframes.dtypes.Dtype


@dataclass(frozen=True)
class BigFrameNode:
@dataclass(eq=False, frozen=True)
class BigFrameNode(abc.ABC):
"""
Immutable node for representing a 2D typed array as a tree of operators.

@@ -95,12 +95,30 @@ def session(self):
return sessions[0]
return None

def _as_tuple(self) -> Tuple:
"""Get all fields as tuple."""
return tuple(getattr(self, field.name) for field in fields(self))

def __hash__(self) -> int:
# Custom hash that uses cache to avoid costly recomputation
return self._cached_hash

def __eq__(self, other) -> bool:
# Custom eq that tries to short-circuit full structural comparison
if not isinstance(other, self.__class__):
return False
if self is other:
return True
if hash(self) != hash(other):
return False
return self._as_tuple() == other._as_tuple()

# BigFrameNode trees can be very deep so it's important to avoid recalculating the hash from scratch
# Each subclass of BigFrameNode should use this property to implement __hash__
# The default dataclass-generated __hash__ method is not cached
@functools.cached_property
def _node_hash(self):
return hash(tuple(hash(getattr(self, field.name)) for field in fields(self)))
def _cached_hash(self):
return hash(self._as_tuple())

@property
def roots(self) -> typing.Set[BigFrameNode]:
@@ -226,7 +244,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
return self.transform_children(lambda x: x.prune(used_cols))


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class UnaryNode(BigFrameNode):
child: BigFrameNode

@@ -252,7 +270,7 @@ def order_ambiguous(self) -> bool:
return self.child.order_ambiguous


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinNode(BigFrameNode):
left_child: BigFrameNode
right_child: BigFrameNode
@@ -285,9 +303,6 @@ def explicitly_ordered(self) -> bool:
# Do not consider user pre-join ordering intent - they need to re-order post-join in unordered mode.
return False

def __hash__(self):
return self._node_hash

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return tuple(itertools.chain(self.left_child.fields, self.right_child.fields))
@@ -320,7 +335,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
return self.transform_children(lambda x: x.prune(new_used))


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ConcatNode(BigFrameNode):
# TODO: Explicitly map column ids from each child
children: Tuple[BigFrameNode, ...]
@@ -345,9 +360,6 @@ def explicitly_ordered(self) -> bool:
# Consider concat as an ordered operation (even though input frames may not be ordered)
return True

def __hash__(self):
return self._node_hash

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
# TODO: Output names should probably be aligned beforehand or be part of concat definition
@@ -371,16 +383,13 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
return self


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class FromRangeNode(BigFrameNode):
# TODO: Enforce single-row, single column constraint
start: BigFrameNode
end: BigFrameNode
step: int

def __hash__(self):
return self._node_hash

@property
def roots(self) -> typing.Set[BigFrameNode]:
return {self}
@@ -419,7 +428,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
# Input Nodes
# TODO: Most leaf nodes produce fixed column names based on the datasource
# They should support renaming
@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class LeafNode(BigFrameNode):
@property
def roots(self) -> typing.Set[BigFrameNode]:
@@ -451,7 +460,7 @@ class ScanList:
items: typing.Tuple[ScanItem, ...]


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ReadLocalNode(LeafNode):
feather_bytes: bytes
data_schema: schemata.ArraySchema
@@ -460,9 +469,6 @@ class ReadLocalNode(LeafNode):
scan_list: ScanList
session: typing.Optional[bigframes.session.Session] = None

def __hash__(self):
return self._node_hash

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
@@ -545,7 +551,7 @@ class BigqueryDataSource:


## Put ordering in here or just add order_by node above?
@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ReadTableNode(LeafNode):
source: BigqueryDataSource
# Subset of physical schema column
@@ -568,9 +574,6 @@ def __post_init__(self):
def session(self):
return self.table_session

def __hash__(self):
return self._node_hash

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
@@ -614,15 +617,12 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
return ReadTableNode(self.source, new_scan_list, self.table_session)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class CachedTableNode(ReadTableNode):
# The original BFET subtree that was cached
# note: this isn't a "child" node.
original_node: BigFrameNode = field()

def __hash__(self):
return self._node_hash

def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
new_scan_list = ScanList(
tuple(item for item in self.scan_list.items if item.id in used_cols)
@@ -633,13 +633,10 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:


# Unary nodes
@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class PromoteOffsetsNode(UnaryNode):
col_id: bigframes.core.identifiers.ColumnId

def __hash__(self):
return self._node_hash

@property
def non_local(self) -> bool:
return True
@@ -664,17 +661,14 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
return self.transform_children(lambda x: x.prune(new_used))


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class FilterNode(UnaryNode):
predicate: ex.Expression

@property
def row_preserving(self) -> bool:
return False

def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 1
@@ -685,13 +679,10 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
return FilterNode(pruned_child, self.predicate)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class OrderByNode(UnaryNode):
by: Tuple[OrderingExpression, ...]

def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 0
@@ -714,14 +705,11 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
return OrderByNode(pruned_child, self.by)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ReversedNode(UnaryNode):
# useless field to make sure this node has a distinct hash
reversed: bool = True

def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 0
@@ -732,15 +720,12 @@ def relation_ops_created(self) -> int:
return 0


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SelectionNode(UnaryNode):
input_output_pairs: typing.Tuple[
typing.Tuple[ex.DerefOp, bigframes.core.identifiers.ColumnId], ...
]

def __hash__(self):
return self._node_hash

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return tuple(
@@ -770,7 +755,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
return SelectionNode(pruned_child, pruned_selections)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ProjectionNode(UnaryNode):
"""Assigns new variables (without modifying existing ones)"""

@@ -786,9 +771,6 @@ def __post_init__(self):
# Cannot assign to existing variables - append only!
assert all(name not in self.child.schema.names for _, name in self.assignments)

def __hash__(self):
return self._node_hash

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
input_types = self.child._dtype_lookup
@@ -817,7 +799,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:

# TODO: Merge RowCount into Aggregate Node?
# Row count can be computed from table metadata sometimes, so it is a bit special.
@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class RowCountNode(UnaryNode):
@property
def row_preserving(self) -> bool:
@@ -840,7 +822,7 @@ def defines_namespace(self) -> bool:
return True


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class AggregateNode(UnaryNode):
aggregations: typing.Tuple[
typing.Tuple[ex.Aggregation, bigframes.core.identifiers.ColumnId], ...
@@ -852,9 +834,6 @@ class AggregateNode(UnaryNode):
def row_preserving(self) -> bool:
return False

def __hash__(self):
return self._node_hash

@property
def non_local(self) -> bool:
return True
@@ -902,7 +881,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
return AggregateNode(pruned_child, pruned_aggs, self.by_column_ids, self.dropna)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class WindowOpNode(UnaryNode):
column_name: ex.DerefOp
op: agg_ops.UnaryWindowOp
@@ -911,9 +890,6 @@ class WindowOpNode(UnaryNode):
never_skip_nulls: bool = False
skip_reproject_unsafe: bool = False

def __hash__(self):
return self._node_hash

@property
def non_local(self) -> bool:
return True
@@ -943,11 +919,8 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:


# TODO: Remove this op
@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ReprojectOpNode(UnaryNode):
def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 0
@@ -958,7 +931,7 @@ def relation_ops_created(self) -> int:
return 0


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class RandomSampleNode(UnaryNode):
fraction: float

@@ -970,26 +943,20 @@ def deterministic(self) -> bool:
def row_preserving(self) -> bool:
return False

def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 1


# TODO: Explode should create a new column instead of overriding the existing one
@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ExplodeNode(UnaryNode):
column_ids: typing.Tuple[ex.DerefOp, ...]

@property
def row_preserving(self) -> bool:
return False

def __hash__(self):
return self._node_hash

@functools.cached_property
def fields(self) -> Tuple[Field, ...]:
return tuple(
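Note on the pattern this diff adopts: the sketch below is a minimal, self-contained illustration, not bigframes code; the class names Node, Leaf, and Binary are invented for the example. It shows the same idea as the change above: the base class owns __eq__ and a cached __hash__, and each subclass is declared with eq=False so the dataclass machinery does not generate its own (uncached) equality and hash.

import functools
from dataclasses import dataclass, fields
from typing import Tuple


@dataclass(frozen=True, eq=False)
class Node:
    """Base class owns equality and a cached hash; subclasses inherit both."""

    def _as_tuple(self) -> Tuple:
        # All dataclass fields, used for both hashing and structural equality.
        return tuple(getattr(self, f.name) for f in fields(self))

    def __hash__(self) -> int:
        return self._cached_hash

    def __eq__(self, other) -> bool:
        # Short-circuit before the potentially deep structural comparison.
        if not isinstance(other, self.__class__):
            return False
        if self is other:
            return True
        if hash(self) != hash(other):
            return False
        return self._as_tuple() == other._as_tuple()

    @functools.cached_property
    def _cached_hash(self) -> int:
        # cached_property writes to instance.__dict__ directly, so it works
        # on a frozen dataclass (as long as __slots__ is not used).
        return hash(self._as_tuple())


@dataclass(frozen=True, eq=False)  # eq=False keeps the inherited __eq__/__hash__
class Leaf(Node):
    name: str


@dataclass(frozen=True, eq=False)
class Binary(Node):
    left: Node
    right: Node


a = Binary(Leaf("x"), Binary(Leaf("y"), Leaf("z")))
b = Binary(Leaf("x"), Binary(Leaf("y"), Leaf("z")))
assert a == b and hash(a) == hash(b)  # structural equality with a cached hash
assert len({a, b}) == 1               # nodes behave as set/dict keys

With eq=True and frozen=True, a dataclass would generate a fresh __eq__ and __hash__ per subclass, and that generated hash is recomputed on every call, which is why the old code re-exported the cached value through a per-class def __hash__. Declaring eq=False instead leaves both methods untouched, so every node inherits the base implementation: hashing hits the cache, and equality can short-circuit on type, identity, and hash before falling back to a full structural comparison of deep trees.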