diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index de4dab3b..2d4152e3 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -757,6 +757,7 @@ def shuffle( shuffle_method, options, index_shuffle=on_index, + _branch_id=expr.BranchId(0), ) ) @@ -4780,6 +4781,7 @@ def merge( shuffle_method=shuffle_method, _npartitions=npartitions, broadcast=broadcast, + _branch_id=expr.BranchId(0), ) ) @@ -4866,7 +4868,7 @@ def merge_asof( from dask_expr._merge_asof import MergeAsof - return new_collection(MergeAsof(left, right, **kwargs)) + return new_collection(MergeAsof(left, right, **kwargs, _branch_id=expr.BranchId(0))) def from_map( diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 5b4a1e21..7080caed 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -5,7 +5,7 @@ import weakref from collections import defaultdict from collections.abc import Generator -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, NamedTuple import dask import pandas as pd @@ -29,6 +29,10 @@ ] +class BranchId(NamedTuple): + branch_id: int + + def _unpack_collections(o): if isinstance(o, Expr): return o @@ -43,9 +47,17 @@ class Expr: _parameters = [] _defaults = {} _instances = weakref.WeakValueDictionary() + _branch_id_required = False + _reuse_consumer = False - def __new__(cls, *args, **kwargs): + def __new__(cls, *args, _branch_id=None, **kwargs): + cls._check_branch_id_given(args, _branch_id) operands = list(args) + if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId): + _branch_id = operands.pop(-1) + elif _branch_id is None: + _branch_id = BranchId(0) + for parameter in cls._parameters[len(operands) :]: try: operands.append(kwargs.pop(parameter)) @@ -54,6 +66,7 @@ def __new__(cls, *args, **kwargs): assert not kwargs, kwargs inst = object.__new__(cls) inst.operands = [_unpack_collections(o) for o in operands] + inst._branch_id = _branch_id _name = inst._name if _name in Expr._instances: return Expr._instances[_name] @@ -61,6 +74,15 @@ def __new__(cls, *args, **kwargs): Expr._instances[_name] = inst return inst + @classmethod + def _check_branch_id_given(cls, args, _branch_id): + if not cls._branch_id_required: + return + operands = list(args) + if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId): + _branch_id = operands.pop(-1) + assert _branch_id is not None, "BranchId not found" + def _tune_down(self): return None @@ -116,7 +138,10 @@ def _tree_repr_lines(self, indent=0, recursive=True): elif is_arraylike(op): op = "" header = self._tree_repr_argument_construction(i, op, header) - + if self._branch_id.branch_id != 0: + header = self._tree_repr_argument_construction( + i + 1, f" branch_id={self._branch_id.branch_id}", header + ) lines = [header] + lines lines = [" " * indent + line for line in lines] @@ -218,7 +243,7 @@ def _layer(self) -> dict: return {(self._name, i): self._task(i) for i in range(self.npartitions)} - def rewrite(self, kind: str): + def rewrite(self, kind: str, cache): """Rewrite an expression This leverages the ``._{kind}_down`` and ``._{kind}_up`` @@ -231,6 +256,9 @@ def rewrite(self, kind: str): changed: whether or not any change occured """ + if self._name in cache: + return cache[self._name] + expr = self down_name = f"_{kind}_down" up_name = f"_{kind}_up" @@ -267,7 +295,8 @@ def rewrite(self, kind: str): changed = False for operand in expr.operands: if isinstance(operand, Expr): - new = operand.rewrite(kind=kind) + new = operand.rewrite(kind=kind, cache=cache) + 
cache[operand._name] = new if new._name != operand._name: changed = True else: @@ -275,13 +304,37 @@ def rewrite(self, kind: str): new_operands.append(new) if changed: - expr = type(expr)(*new_operands) + expr = type(expr)(*new_operands, _branch_id=expr._branch_id) continue else: break return expr + def _reuse_up(self, parent): + return + + def _reuse_down(self): + if not self.dependencies(): + return + return self._bubble_branch_id_down() + + def _bubble_branch_id_down(self): + b_id = self._branch_id + if b_id.branch_id <= 0: + return + if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()): + ops = [ + op._substitute_branch_id(b_id) if isinstance(op, Expr) else op + for op in self.operands + ] + return type(self)(*ops) + + def _substitute_branch_id(self, branch_id): + if self._branch_id.branch_id != 0: + return self + return type(self)(*self.operands, branch_id) + def simplify_once(self, dependents: defaultdict, simplified: dict): """Simplify an expression @@ -346,7 +399,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict): new_operands.append(new) if changed: - expr = type(expr)(*new_operands) + expr = type(expr)(*new_operands, _branch_id=expr._branch_id) break @@ -391,7 +444,7 @@ def lower_once(self): new_operands.append(new) if changed: - out = type(out)(*new_operands) + out = type(out)(*new_operands, _branch_id=out._branch_id) return out @@ -426,6 +479,23 @@ def _lower(self): @functools.cached_property def _name(self): + return ( + funcname(type(self)).lower() + + "-" + + _tokenize_deterministic(*self.operands, self._branch_id) + ) + + @functools.cached_property + def _dep_name(self): + # The name identifies every expression uniquely. The dependents name + # is used during optimization to capture the dependents of any given + # expression. A reuse consumer will have the same dependents independently + # of the branch_id parameter, since we want to reuse everything that comes + # before us and split branches up everything that is processed after + # us. So we have to ignore the branch_id from tokenization for those + # nodes. 
+ if not self._reuse_consumer: + return self._name return ( funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands) ) @@ -554,7 +624,7 @@ def _substitute(self, old, new, _seen): new_exprs.append(operand) if update: # Only recreate if something changed - return type(self)(*new_exprs) + return type(self)(*new_exprs, _branch_id=self._branch_id) else: _seen.add(self._name) return self @@ -580,7 +650,7 @@ def substitute_parameters(self, substitutions: dict) -> Expr: else: new_operands.append(operand) if changed: - return type(self)(*new_operands) + return type(self)(*new_operands, _branch_id=self._branch_id) return self def _node_label_args(self): @@ -741,5 +811,5 @@ def collect_dependents(expr) -> defaultdict: for dep in node.dependencies(): stack.append(dep) - dependents[dep._name].append(weakref.ref(node)) + dependents[dep._dep_name].append(weakref.ref(node)) return dependents diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 83682f19..f975dfd5 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -53,6 +53,7 @@ from tlz import merge_sorted, partition, unique from dask_expr import _core as core +from dask_expr._core import BranchId from dask_expr._util import ( _calc_maybe_new_divisions, _convert_to_list, @@ -421,7 +422,9 @@ def dtypes(self): def _filter_simplification(self, parent, predicate=None): if predicate is None: predicate = parent.predicate.substitute(self, self.frame) - return type(self)(self.frame[predicate], *self.operands[1:]) + return type(self)( + self.frame[predicate], *self.operands[1:], _branch_id=self._branch_id + ) class Literal(Expr): @@ -502,7 +505,7 @@ def _name(self): head = funcname(self.operation) else: head = funcname(type(self)).lower() - return head + "-" + _tokenize_deterministic(*self.operands) + return head + "-" + _tokenize_deterministic(*self.operands, self._branch_id) def _blockwise_arg(self, arg, i): """Return a Blockwise-task argument""" @@ -1821,7 +1824,7 @@ class Filter(Blockwise): def _simplify_up(self, parent, dependents): if isinstance(self.predicate, Or): result = rewrite_filters(self.predicate) - if result._name != self.predicate._name: + if result._dep_name != self.predicate._dep_name: return type(parent)( type(self)(self.frame, result), *parent.operands[1:] ) @@ -2087,7 +2090,7 @@ def _simplify_up(self, parent, dependents): ): parents = [ p().columns - for p in dependents[self._name] + for p in dependents[self._dep_name] if p() is not None and not isinstance(p(), Filter) ] predicate = None @@ -2118,7 +2121,7 @@ def _simplify_up(self, parent, dependents): return if all( isinstance(d(), Projection) and d().operand("columns") == col - for d in dependents[self._name] + for d in dependents[self._dep_name] ): return type(self)(self.frame, True, self.name) return @@ -2728,8 +2731,11 @@ class _DelayedExpr(Expr): # TODO _parameters = ["obj"] - def __init__(self, obj): + def __init__(self, obj, _branch_id=None): self.obj = obj + if _branch_id is None: + _branch_id = BranchId(0) + self._branch_id = _branch_id self.operands = [obj] def __str__(self): @@ -2758,18 +2764,29 @@ def normalize_expression(expr): return expr._name -def optimize_until(expr: Expr, stage: core.OptimizerStage) -> Expr: +def optimize_until( + expr: Expr, stage: core.OptimizerStage, common_subplan_elimination: bool = False +) -> Expr: result = expr if stage == "logical": return result - # Simplify - expr = result.simplify() + while True: + if not common_subplan_elimination: + out = result.rewrite("reuse", cache={}) + else: + out = result + out = 
out.simplify() + if out._name == result._name or common_subplan_elimination: + break + result = out + + expr = out if stage == "simplified-logical": return expr # Manipulate Expression to make it more efficient - expr = expr.rewrite(kind="tune") + expr = expr.rewrite(kind="tune", cache={}) if stage == "tuned-logical": return expr @@ -2791,7 +2808,9 @@ def optimize_until(expr: Expr, stage: core.OptimizerStage) -> Expr: raise ValueError(f"Stage {stage!r} not supported.") -def optimize(expr: Expr, fuse: bool = True) -> Expr: +def optimize( + expr: Expr, fuse: bool = True, common_subplan_elimination: bool = False +) -> Expr: """High level query optimization This leverages three optimization passes: @@ -2805,6 +2824,10 @@ def optimize(expr: Expr, fuse: bool = True) -> Expr: Input expression to optimize fuse: whether or not to turn on blockwise fusion + common_subplan_elimination : bool, default False + whether we want to reuse common subplans that are found in the graph and + are used in self-joins or similar which require all data be held in memory + at some point. Only set this to true if your dataset fits into memory. See Also -------- @@ -2813,7 +2836,7 @@ def optimize(expr: Expr, fuse: bool = True) -> Expr: """ stage: core.OptimizerStage = "fused" if fuse else "simplified-physical" - return optimize_until(expr, stage) + return optimize_until(expr, stage, common_subplan_elimination) def is_broadcastable(dfs, s): @@ -3195,7 +3218,13 @@ def _lower(self): from dask_expr._shuffle import RearrangeByColumn args = [ - RearrangeByColumn(df, None, npartitions, index_shuffle=True) + RearrangeByColumn( + df, + None, + npartitions, + index_shuffle=True, + _branch_id=self._branch_id, + ) if isinstance(df, Expr) else df for df in self.operands @@ -3462,7 +3491,7 @@ def __str__(self): @functools.cached_property def _name(self): - return f"{str(self)}-{_tokenize_deterministic(self.exprs)}" + return f"{str(self)}-{_tokenize_deterministic(self.exprs, self._branch_id)}" def _divisions(self): return self.exprs[0]._divisions() @@ -3513,13 +3542,13 @@ def determine_column_projection(expr, parent, dependents, additional_columns=Non column_union = [] else: column_union = parent.columns.copy() - parents = [x() for x in dependents[expr._name] if x() is not None] + parents = [x() for x in dependents[expr._dep_name] if x() is not None] seen = set() for p in parents: - if p._name in seen: + if p._dep_name in seen: continue - seen.add(p._name) + seen.add(p._dep_name) column_union.extend(p._projection_columns) @@ -3576,8 +3605,8 @@ def plain_column_projection(expr, parent, dependents, additional_columns=None): def is_filter_pushdown_available(expr, parent, dependents, allow_reduction=True): - parents = [x() for x in dependents[expr._name] if x() is not None] - filters = {e._name for e in parents if isinstance(e, Filter)} + parents = [x() for x in dependents[expr._dep_name] if x() is not None] + filters = {e._dep_name for e in parents if isinstance(e, Filter)} if len(filters) != 1: # Don't push down if not exactly one Filter return False @@ -3585,7 +3614,7 @@ def is_filter_pushdown_available(expr, parent, dependents, allow_reduction=True) return True # We have to see if the non-filter ops are all exclusively part of the predicates - others = {e._name for e in parents if not isinstance(e, Filter)} + others = {e._dep_name for e in parents if not isinstance(e, Filter)} return _check_dependents_are_predicates( expr, others, parent, dependents, allow_reduction ) @@ -3620,7 +3649,7 @@ def _get_predicate_components(predicate, 
components, type_=Or): def _convert_mapping(components): - return dict(zip([e._name for e in components], components)) + return dict(zip([e._dep_name for e in components], components)) def _replace_common_or_components(expr, or_components): @@ -3671,19 +3700,21 @@ def _check_dependents_are_predicates( # Walk down the predicate side from the filter to see if we can arrive at # other_names without hitting an expression that has other dependents that # are not part of the predicate, see test_filter_pushdown_unavailable - allowed_expressions = {parent._name} + allowed_expressions = {parent._dep_name} stack = parent.dependencies() seen = set() while stack: e = stack.pop() - if expr._name == e._name: + if expr._dep_name == e._dep_name: continue - if e._name in seen: + if e._dep_name in seen: continue - seen.add(e._name) + seen.add(e._dep_name) - e_dependents = {x()._name for x in dependents[e._name] if x() is not None} + e_dependents = { + x()._dep_name for x in dependents[e._dep_name] if x() is not None + } if not allow_reduction: if isinstance(e, (ApplyConcatApply, TreeReduce, ShuffleReduce)): @@ -3694,7 +3725,7 @@ def _check_dependents_are_predicates( continue return False - allowed_expressions.add(e._name) + allowed_expressions.add(e._dep_name) stack.extend(e.dependencies()) return other_names.issubset(allowed_expressions) diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index 11c61818..5661b603 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -52,6 +52,7 @@ from dask.utils import M, apply, derived_from, is_index_like from dask_expr._collection import FrameBase, Index, Series, new_collection +from dask_expr._core import BranchId from dask_expr._expr import ( Assign, Blockwise, @@ -867,6 +868,8 @@ class GroupByApply(Expr, GroupByBase): "group_keys": True, "shuffle_method": None, } + _branch_id_required = True + _reuse_consumer = True @functools.cached_property def grp_func(self): @@ -878,6 +881,9 @@ def _meta(self): return make_meta(self.operand("meta"), parent_meta=self.frame._meta) return _meta_apply_transform(self, self.grp_func) + def _reuse_down(self): + return + def _divisions(self): if self.need_to_shuffle: return (None,) * (self.frame.npartitions + 1) @@ -925,6 +931,7 @@ def get_map_columns(df): [map_columns.get(c, c) for c in cols], df.npartitions, method=self.shuffle_method, + _branch_id=self._branch_id, ) if unmap_columns: @@ -951,6 +958,7 @@ def get_map_columns(df): map_columns.get(self.by[0], self.by[0]), self.npartitions, method=self.shuffle_method, + _branch_id=self._branch_id, ) if unmap_columns: @@ -1252,7 +1260,9 @@ def groupby_projection(expr, parent, dependents): if columns == expr.frame.columns: return return type(parent)( - type(expr)(expr.frame[columns], *expr.operands[1:]), + type(expr)( + expr.frame[columns], *expr.operands[1:], _branch_id=expr._branch_id + ), *parent.operands[1:], ) return @@ -1945,6 +1955,7 @@ def apply(self, func, *args, meta=no_default, shuffle_method=None, **kwargs): kwargs, shuffle_method, *self.by, + BranchId(0), ) ) @@ -1964,6 +1975,7 @@ def _transform_like_op( kwargs, shuffle_method, *self.by, + BranchId(0), ) ) @@ -2060,6 +2072,7 @@ def median( shuffle_method, split_every, *self.by, + BranchId(0), ) ) if split_out is not True: diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index 3e47e104..39ec8d40 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -82,6 +82,8 @@ class Merge(Expr): "_npartitions": None, "broadcast": None, } + _branch_id_required = True + _reuse_consumer = True @property def 
_filter_passthrough(self): @@ -133,19 +135,19 @@ def _get_original_predicate_columns(self, predicate): seen = set() while stack: e = stack.pop() - if self._name == e._name: + if self._dep_name == e._dep_name: continue - if e._name in seen: + if e._dep_name in seen: continue - seen.add(e._name) + seen.add(e._dep_name) if isinstance(e, _DelayedExpr): continue dependencies = e.dependencies() stack.extend(dependencies) - if any(d._name == self._name for d in dependencies): + if any(d._dep_name == self._dep_name for d in dependencies): predicate_columns.update(e.columns) return predicate_columns @@ -309,6 +311,9 @@ def merge_indexed_right(self): self.right_index or _contains_index_name(self.right, self.right_on) ) and self.right.known_divisions + def _reuse_down(self): + return + def _lower(self): # Lower from an abstract expression left = self.left @@ -366,12 +371,14 @@ def _lower(self): left, shuffle_left_on, npartitions_out=left.npartitions, + _branch_id=self._branch_id, ) else: right = RearrangeByColumn( right, shuffle_right_on, npartitions_out=right.npartitions, + _branch_id=self._branch_id, ) return BroadcastJoin( @@ -404,6 +411,7 @@ def _lower(self): shuffle_left_on=shuffle_left_on, shuffle_right_on=shuffle_right_on, _npartitions=self.operand("_npartitions"), + _branch_id=self._branch_id, ) if shuffle_left_on: @@ -414,6 +422,7 @@ def _lower(self): npartitions_out=self._npartitions, method=shuffle_method, index_shuffle=left_index, + _branch_id=self._branch_id, ) if shuffle_right_on: @@ -424,6 +433,7 @@ def _lower(self): npartitions_out=self._npartitions, method=shuffle_method, index_shuffle=right_index, + _branch_id=self._branch_id, ) # Blockwise merge @@ -466,7 +476,9 @@ def _simplify_up(self, parent, dependents): if new_right is self.right and new_left is self.left: # don't drop the filter return - return type(self)(new_left, new_right, *self.operands[2:]) + return type(self)( + new_left, new_right, *self.operands[2:], _branch_id=self._branch_id + ) if isinstance(parent, (Projection, Index)): # Reorder the column projection to # occur before the Merge @@ -518,7 +530,10 @@ def _simplify_up(self, parent, dependents): right.columns ): result = type(self)( - left[project_left], right[project_right], *self.operands[2:] + left[project_left], + right[project_right], + *self.operands[2:], + _branch_id=self._branch_id, ) if parent_columns is None: return type(parent)(result) @@ -569,7 +584,7 @@ def _layer(self) -> dict: # Include self._name to ensure that shuffle IDs are unique for individual # merge operations. Reusing shuffles between merges is dangerous because of # required coordination and complexity introduced through dynamic clusters. - self._name, + self._dep_name, self.left._name, self.shuffle_left_on, self.left_index, @@ -578,7 +593,7 @@ def _layer(self) -> dict: # Include self._name to ensure that shuffle IDs are unique for individual # merge operations. Reusing shuffles between merges is dangerous because of # required coordination and complexity introduced through dynamic clusters. 
- self._name, + self._dep_name, self.right._name, self.shuffle_right_on, self.right_index, @@ -672,6 +687,7 @@ class BroadcastJoin(Merge, PartitionsFiltered): "indicator": False, "_partitions": None, } + _branch_id_required = False def _divisions(self): if self.broadcast_side == "left": @@ -806,6 +822,7 @@ class BlockwiseMerge(Merge, Blockwise): """ is_broadcast_join = False + _branch_id_required = False def _divisions(self): if self.left.npartitions == self.right.npartitions: @@ -863,6 +880,7 @@ def _lower(self): how=self.how, left_index=True, right_index=True, + _branch_id=self._branch_id, ) return self._recursive_join(self.frames) @@ -878,6 +896,7 @@ def _recursive_join(self, frames): how="outer", left_index=True, right_index=True, + _branch_id=self._branch_id, ) midx = len(frames) // 2 diff --git a/dask_expr/_merge_asof.py b/dask_expr/_merge_asof.py index 7efcc13a..2820ef26 100644 --- a/dask_expr/_merge_asof.py +++ b/dask_expr/_merge_asof.py @@ -40,6 +40,7 @@ class MergeAsof(Merge): "allow_exact_matches": True, "direction": "backward", } + _branch_id_required = False @functools.cached_property def _kwargs(self): diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 27f75d19..b41e624a 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -26,6 +26,7 @@ from dask.utils import M, apply, funcname from dask_expr._concat import Concat +from dask_expr._core import BranchId from dask_expr._expr import ( Blockwise, Expr, @@ -130,6 +131,9 @@ class ShuffleReduce(Expr): ApplyConcatApply """ + _branch_id_required = True + _reuse_consumer = True + _parameters = [ "frame", "kind", @@ -224,6 +228,7 @@ def _lower(self): ignore_index=ignore_index, index_shuffle=not split_by_index and self.shuffle_by_index, method=self.shuffle_method, + _branch_id=self._branch_id, ) # Unmap column names if necessary @@ -300,7 +305,7 @@ def _name(self): name = funcname(self.combine.__self__).lower() + "-tree" else: name = funcname(self.combine) - return name + "-" + _tokenize_deterministic(*self.operands) + return name + "-" + _tokenize_deterministic(*self.operands, self._branch_id) def __dask_postcompute__(self): return toolz.first, () @@ -411,6 +416,10 @@ def split_out(self): else: return 1 + @functools.cached_property + def _reuse_consumer(self): + return self.should_shuffle + def _layer(self): # This is an abstract expression raise NotImplementedError() @@ -505,8 +514,60 @@ def _lower(self): shuffle_by_index=getattr(self, "shuffle_by_index", None), shuffle_method=getattr(self, "shuffle_method", None), ignore_index=getattr(self, "ignore_index", True), + _branch_id=self._branch_id, ) + def _substitute_branch_id(self, branch_id): + if self._reuse_consumer: + # We are lowering into a Shuffle, so we are a consumer ourselves and + # we have to consume the branch_id of our parents + return super()._substitute_branch_id(branch_id) + return self + + def _reuse_down(self): + if self._branch_id.branch_id != 0: + return + + if self._reuse_consumer: + # We are lowering into a Shuffle, so we are a consumer ourselves + return + + from dask_expr.io import IO + + seen = set() + stack = self.dependencies() + counter, found_consumer = 1, False + + while stack: + node = stack.pop() + + if node._dep_name in seen: + continue + seen.add(node._dep_name) + + if isinstance(node, IO) or node._reuse_consumer: + found_consumer = True + continue + + if isinstance(node, IO): + found_consumer = True + continue + + if isinstance(node, ApplyConcatApply): + counter += 1 + continue + + stack.extend(node.dependencies()) 
+ + if not found_consumer: + return + b_id = BranchId(counter) + result = type(self)(*self.operands, b_id) + out = result._bubble_branch_id_down() + if out is None: + return result + return type(out)(*out.operands, _branch_id=b_id) + class Unique(ApplyConcatApply): _parameters = ["frame", "split_every", "split_out", "shuffle_method"] @@ -591,7 +652,9 @@ def _simplify_up(self, parent, dependents): columns = [col for col in self.frame.columns if col in columns] return type(parent)( - type(self)(self.frame[columns], *self.operands[1:]), + type(self)( + self.frame[columns], *self.operands[1:], _branch_id=self._branch_id + ), *parent.operands[1:], ) diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py index 710adc3f..07cfd7e1 100644 --- a/dask_expr/_shuffle.py +++ b/dask_expr/_shuffle.py @@ -84,6 +84,8 @@ class ShuffleBase(Expr): } _is_length_preserving = True _filter_passthrough = True + _branch_id_required = True + _reuse_consumer = True def __str__(self): return f"Shuffle({self._name[-7:]})" @@ -112,9 +114,11 @@ def _simplify_up(self, parent, dependents): if (col in partitioning_index or col in projection) ] if set(new_projection) < set(target.columns): - return type(self)(target[new_projection], *self.operands[1:])[ - parent.operand("columns") - ] + return type(self)( + target[new_projection], + *self.operands[1:], + _branch_id=self._branch_id, + )[parent.operand("columns")] if isinstance( parent, @@ -139,7 +143,8 @@ def _simplify_up(self, parent, dependents): MemoryUsage, ), ): - return type(parent)(self.frame, *parent.operands[1:]) + branch_id = None if not parent.should_shuffle else parent._branch_id + return type(parent)(self.frame, *parent.operands[1:], _branch_id=branch_id) def _layer(self): raise NotImplementedError( @@ -157,6 +162,10 @@ def _meta(self): def _divisions(self): return (None,) * (self.npartitions_out + 1) + def _reuse_down(self): + # TODO: What to do with task based shuffle? 
+ return + class Shuffle(ShuffleBase): """Abstract shuffle class @@ -196,6 +205,7 @@ def _lower(self): self.npartitions_out, self.ignore_index, self.options, + self._branch_id, ] if method == "p2p": return P2PShuffle(frame, *ops) @@ -293,6 +303,7 @@ def _lower(self): ignore_index, self.method, options, + _branch_id=self._branch_id, ) if frame.ndim == 1: # Reduce back to series @@ -503,6 +514,11 @@ def _layer(self): class DiskShuffle(SimpleShuffle): """Disk-based shuffle implementation""" + @functools.cached_property + def _name(self): + # This is only used locally anyway, so don't bother with pipeline breakers + return self._dep_name + @staticmethod def _shuffle_group(df, col, _filter, p): with ensure_cleanup_on_exception(p): @@ -555,7 +571,8 @@ def _layer(self): ) dsk = {} - token = self._name.split("-")[-1] + # Ensure that shuffles with different branch_ids have the same barrier + token = self._dep_name.split("-")[-1] _barrier_key = barrier_key(ShuffleId(token)) name = "shuffle-transfer-" + token transfer_keys = list() @@ -1026,6 +1043,7 @@ def _lower(self): ignore_index=self.ignore_index, method=self.shuffle_method, options=self.options, + _branch_id=self._branch_id, ) shuffled = Projection(shuffled, self.frame.columns) return SortValuesBlockwise( @@ -1113,6 +1131,7 @@ def _lower(self): ignore_index=True, method=self.shuffle_method, options=self.options, + _branch_id=self._branch_id, ) shuffled = Projection( shuffled, [c for c in assigned.columns if c != "_partitions"] diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 1b6b34fe..e06682e6 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -48,7 +48,9 @@ def _divisions(self): @functools.cached_property def _name(self): return ( - self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands) + self.operand("name_prefix") + + "-" + + _tokenize_deterministic(*self.operands, self._branch_id) ) def _layer(self): @@ -103,7 +105,7 @@ def _name(self): return ( funcname(type(self.operand("_expr"))).lower() + "-fused-" - + _tokenize_deterministic(*self.operands) + + _tokenize_deterministic(*self.operands, self._expr._branch_id) ) @functools.cached_property @@ -173,10 +175,14 @@ def _name(self): return ( funcname(self.func).lower() + "-" - + _tokenize_deterministic(*self.operands) + + _tokenize_deterministic(*self.operands, self._branch_id) ) else: - return self.label + "-" + _tokenize_deterministic(*self.operands) + return ( + self.label + + "-" + + _tokenize_deterministic(*self.operands, self._branch_id) + ) @functools.cached_property def _meta(self): @@ -448,7 +454,11 @@ class FromPandasDivisions(FromPandas): @functools.cached_property def _name(self): - return "from_pd_divs" + "-" + _tokenize_deterministic(*self.operands) + return ( + "from_pd_divs" + + "-" + + _tokenize_deterministic(*self.operands, self._branch_id) + ) @property def _divisions_and_locations(self): diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 8cd9e8c3..a1905331 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -501,7 +501,7 @@ def _name(self): return ( funcname(type(self)).lower() + "-" - + _tokenize_deterministic(self.checksum, *self.operands) + + _tokenize_deterministic(self.checksum, *self.operands, self._branch_id) ) @property diff --git a/dask_expr/tests/_util.py b/dask_expr/tests/_util.py index 1f24bfda..b04a93af 100644 --- a/dask_expr/tests/_util.py +++ b/dask_expr/tests/_util.py @@ -5,6 +5,8 @@ from dask import config from dask.dataframe.utils import assert_eq as dd_assert_eq +from dask_expr.io 
import IO + def _backend_name() -> str: return config.get("dataframe.backend", "pandas") @@ -39,3 +41,12 @@ def assert_eq(a, b, *args, serialize_graph=True, **kwargs): # Use `dask.dataframe.assert_eq` return dd_assert_eq(a, b, *args, **kwargs) + + +def _check_consumer_node(expr, expected, consumer_node=IO, branch_id_counter=None): + if branch_id_counter is None: + branch_id_counter = expected + expr = expr.optimize(fuse=False) + io_nodes = list(expr.find_operations(consumer_node)) + assert len(io_nodes) == expected + assert len({node._branch_id.branch_id for node in io_nodes}) == branch_id_counter diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index c113432a..82178660 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -507,7 +507,7 @@ def test_diff(pdf, df, axis, periods): if axis in ("columns", 1): assert actual._name == actual.simplify()._name else: - assert actual.simplify()._name == expected.simplify()._name + assert actual.optimize()._name == expected.optimize()._name @pytest.mark.parametrize( @@ -942,7 +942,7 @@ def test_repr(df): s = (df["x"] + 1).sum(skipna=False).expr assert '["x"]' in str(s) or "['x']" in str(s) assert "+ 1" in str(s) - assert "sum(skipna=False)" in str(s) + assert "sum(skipna=False" in str(s) @xfail_gpu("combine_first not supported by cudf") @@ -1163,8 +1163,8 @@ def test_tail_repartition(df): def test_projection_stacking(df): result = df[["x", "y"]]["x"] - optimized = result.simplify() - expected = df["x"].simplify() + optimized = result.optimize() + expected = df["x"].optimize() assert optimized._name == expected._name @@ -1885,8 +1885,8 @@ def test_assign_simplify(pdf): df = from_pandas(pdf) df2 = from_pandas(pdf) df["new"] = df.x > 1 - result = df[["x", "new"]].simplify() - expected = df2[["x"]].assign(new=df2.x > 1).simplify() + result = df[["x", "new"]].optimize() + expected = df2[["x"]].assign(new=df2.x > 1).optimize() assert result._name == expected._name pdf["new"] = pdf.x > 1 @@ -1897,8 +1897,8 @@ def test_assign_simplify_new_column_not_needed(pdf): df = from_pandas(pdf) df2 = from_pandas(pdf) df["new"] = df.x > 1 - result = df[["x"]].simplify() - expected = df2[["x"]].simplify() + result = df[["x"]].optimize() + expected = df2[["x"]].optimize() assert result._name == expected._name pdf["new"] = pdf.x > 1 @@ -1909,8 +1909,8 @@ def test_assign_simplify_series(pdf): df = from_pandas(pdf) df2 = from_pandas(pdf) df["new"] = df.x > 1 - result = df.new.simplify() - expected = df2[[]].assign(new=df2.x > 1).new.simplify() + result = df.new.optimize() + expected = df2[[]].assign(new=df2.x > 1).new.optimize() assert result._name == expected._name @@ -1928,7 +1928,16 @@ def test_assign_squash_together(df, pdf): df["a"] = 1 df["b"] = 2 result = df.simplify() - assert len([x for x in list(result.expr.walk()) if isinstance(x, expr.Assign)]) == 1 + assert ( + len( + [ + x + for x in list(df.optimize(fuse=False).expr.walk()) + if isinstance(x, expr.Assign) + ] + ) + == 1 + ) pdf["a"] = 1 pdf["b"] = 2 assert_eq(df, pdf) @@ -1973,10 +1982,10 @@ def test_astype_categories(df): assert_eq(result.y._meta.cat.categories, pd.Index([UNKNOWN_CATEGORIES])) -def test_drop_simplify(df): +def test_drop_optimize(df): q = df.drop(columns=["x"])[["y"]] - result = q.simplify() - expected = df[["y"]].simplify() + result = q.optimize() + expected = df[["y"]].optimize() assert result._name == expected._name @@ -2064,6 +2073,7 @@ def test_filter_pushdown_unavailable(df): result = df[df.x > 5] + df.x.sum() result 
= result[["x"]] expected = df[["x"]][df.x > 5] + df.x.sum() + assert result.optimize()._name == expected.optimize()._name assert result.simplify()._name == expected.simplify()._name @@ -2076,6 +2086,7 @@ def test_filter_pushdown(df, pdf): df = df.rename_axis(index="hello") result = df[df.x > 5].simplify() assert result._name == expected._name + assert result.optimize()._name == expected.optimize()._name pdf["z"] = 1 df = from_pandas(pdf, npartitions=10) @@ -2084,6 +2095,7 @@ def test_filter_pushdown(df, pdf): df_opt = df[["x", "y"]] expected = df_opt[df_opt.x > 5].rename_axis(index="hello").simplify() assert result._name == expected._name + assert result.optimize()._name == expected.optimize()._name def test_shape(df, pdf): @@ -2433,13 +2445,13 @@ def test_reset_index_filter_pushdown(df): result = q[q > 5] expected = df["x"] expected = expected[expected > 5].reset_index(drop=True) - assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name q = df.x.reset_index() result = q[q.x > 5] expected = df["x"] expected = expected[expected > 5].reset_index() - assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name def test_astype_filter_pushdown(df, pdf): diff --git a/dask_expr/tests/test_distributed.py b/dask_expr/tests/test_distributed.py index 2cf9aa97..0ab6b00d 100644 --- a/dask_expr/tests/test_distributed.py +++ b/dask_expr/tests/test_distributed.py @@ -3,8 +3,9 @@ import pytest from dask_expr import from_pandas, map_partitions, merge -from dask_expr._merge import BroadcastJoin -from dask_expr.tests._util import _backend_library +from dask_expr._merge import BroadcastJoin, HashJoinP2P +from dask_expr._shuffle import P2PShuffle +from dask_expr.tests._util import _backend_library, _check_consumer_node distributed = pytest.importorskip("distributed") @@ -354,3 +355,186 @@ def test_func(n): result = await result expected = pd.DataFrame({"a": [4951, 4952, 4953, 4954]}) pd.testing.assert_frame_equal(result, expected) + + +@gen_cluster(client=True) +async def test_p2p_shuffle_reuse(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + q = df.shuffle("a") + q = q.fillna(100) + q = q.a + q.a.sum() + # Only one IO node since shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 2, consumer_node=P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.fillna(100) + expected = expected.a + expected.a.sum() + pd.testing.assert_series_equal(x.sort_index(), expected) + + # Check that we have 1 shuffle barrier but 20 p2pshuffle tasks for the output + dsk = q.optimize(fuse=False).dask + keys = list(dsk.keys()) + assert ( + len( + list( + key for key in keys if isinstance(key, str) and "shuffle-barrier" in key + ) + ) + == 1 + ) + assert ( + len( + list( + key for key in keys if isinstance(key, tuple) and "p2pshuffle" in key[0] + ) + ) + == 20 + ) + + +@gen_cluster(client=True) +async def test_groupby_apply_reuse(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + q = df.groupby("a").apply(lambda x: x) + q = q.fillna(100) + q = q.a + q.a.sum() + # Only one IO node since shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 2, consumer_node=P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.groupby("a").apply(lambda x: x) + expected = expected.fillna(100) + expected = expected.a + 
expected.a.sum() + pd.testing.assert_series_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_groupby_sum_reuse_split_out(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + q = df.groupby("a").sum(split_out=True) + q = df + q.b.sum() + # Only one IO node since groupby-shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 1, consumer_node=P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.groupby("a").sum() + expected = pdf + expected.b.sum() + pd.testing.assert_frame_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_groupby_sum_no_reuse(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + # no split_out, so we can't reuse the groupby operation + q = df.groupby("a").sum() + q = df + q.b.sum() + # 2 IO Nodes, one for the groupby branch and one for the main branch + _check_consumer_node(q, 2) + x = c.compute(q) + x = await x + + expected = pdf.groupby("a").sum() + expected = pdf + expected.b.sum() + pd.testing.assert_frame_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_drop_duplicates_reuse(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + # no split_out, so we can't reuse the groupby operation + q = df.drop_duplicates(subset="a") + q = df + q.b.sum() + # Only one IO node since drop duplicates-shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 1, P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.drop_duplicates(subset="a") + expected = pdf + expected.b.sum() + pd.testing.assert_frame_equal(x.sort_index(), expected) + + q = df.drop_duplicates(subset="a", split_out=1) + q = df + q.b.sum() + # 2 IO nodes since reducer can't consume + _check_consumer_node(q, 2) + x = c.compute(q) + x = await x + + expected = pdf.drop_duplicates(subset="a") + expected = pdf + expected.b.sum() + pd.testing.assert_frame_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_groupby_ffill_reuse(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + q = df.groupby("a").ffill() + q = q.fillna(100) + q = q.b + q.b.sum() + # Only one IO node since shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 2, consumer_node=P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.groupby("a").ffill() + expected = expected.fillna(100) + expected = expected.b + expected.b.sum() + pd.testing.assert_series_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_merge_reuse(c, s, a, b): + pdf1 = pd.DataFrame({"a": [1, 2, 3, 4, 1, 2, 3, 4], "b": 1, "c": 1}) + pdf2 = pd.DataFrame({"a": [1, 2, 3, 4, 1, 2, 3, 4], "e": 1, "f": 1}) + + df1 = from_pandas(pdf1, npartitions=3) + df2 = from_pandas(pdf2, npartitions=3) + q = df1.merge(df2) + q = q.fillna(100) + q = q.b + q.b.sum() + _check_consumer_node(q, 2, HashJoinP2P) + # One on either side + _check_consumer_node(q, 2, branch_id_counter=1) + x = c.compute(q) + x = await x + expected = pdf1.merge(pdf2) + expected = expected.fillna(100) + expected = expected.b + expected.b.sum() + pd.testing.assert_series_equal(x.reset_index(drop=True), expected) + + # Check that we have 2 shuffle barriers (one for either side) for both merges but 6 + # hashjoinp2p tasks for the output + dsk = 
q.optimize(fuse=False).dask + keys = list(dsk.keys()) + assert ( + len( + list( + key for key in keys if isinstance(key, str) and "shuffle-barrier" in key + ) + ) + == 2 + ) + assert ( + len( + list( + key + for key in keys + if isinstance(key, tuple) and "hashjoinp2p" in key[0] + ) + ) + == 6 + ) diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py new file mode 100644 index 00000000..c92ccea1 --- /dev/null +++ b/dask_expr/tests/test_reuse.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import pytest + +from dask_expr import from_pandas +from dask_expr._merge import BlockwiseMerge +from dask_expr._shuffle import DiskShuffle +from dask_expr.tests._util import _backend_library, _check_consumer_node, assert_eq + +# Set DataFrame backend for this module +pd = _backend_library() + + +@pytest.fixture +def pdf(): + pdf = pd.DataFrame({"x": range(100), "a": 1, "b": 1, "c": 1}) + pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions + yield pdf + + +@pytest.fixture +def df(pdf): + yield from_pandas(pdf, npartitions=10) + + +def test_reuse_everything_scalar_and_series(df, pdf): + df["new"] = 1 + df["new2"] = df["x"] + 1 + df["new3"] = df.x[df.x > 1] + df.x[df.x > 2] + + pdf["new"] = 1 + pdf["new2"] = pdf["x"] + 1 + pdf["new3"] = pdf.x[pdf.x > 1] + pdf.x[pdf.x > 2] + assert_eq(df, pdf) + _check_consumer_node(df, 1) + + +def test_dont_reuse_reducer(df, pdf): + result = df.replace(1, 5) + result["new"] = result.x + result.y.sum() + expected = pdf.replace(1, 5) + expected["new"] = expected.x + expected.y.sum() + assert_eq(result, expected) + _check_consumer_node(result, 2) + + result = df + df.sum() + expected = pdf + pdf.sum() + assert_eq(result, expected, check_names=False) # pandas 2.2 bug + _check_consumer_node(result, 2) + + result = df.replace(1, 5) + rhs_1 = result.x + result.y.sum() + rhs_2 = result.b + result.a.sum() + result["new"] = rhs_1 + result["new2"] = rhs_2 + expected = pdf.replace(1, 5) + expected["new"] = expected.x + expected.y.sum() + expected["new2"] = expected.b + expected.a.sum() + assert_eq(result, expected) + _check_consumer_node(result, 2) + + result = df.replace(1, 5) + result["new"] = result.x + result.y.sum() + result["new2"] = result.b + result.a.sum() + expected = pdf.replace(1, 5) + expected["new"] = expected.x + expected.y.sum() + expected["new2"] = expected.b + expected.a.sum() + assert_eq(result, expected) + _check_consumer_node(result, 3) + + result = df.replace(1, 5) + result["new"] = result.x + result.sum().dropna().prod() + expected = pdf.replace(1, 5) + expected["new"] = expected.x + expected.sum().dropna().prod() + assert_eq(result, expected) + _check_consumer_node(result, 2) + + +def test_disk_shuffle(df, pdf): + q = df.shuffle("a") + q = q.fillna(100) + q = q.a + q.a.sum() + q.optimize(fuse=False).pprint() + # Disk shuffle is not utilizing pipeline breakers + _check_consumer_node(q, 1, consumer_node=DiskShuffle) + _check_consumer_node(q, 1) + expected = pdf.fillna(100) + expected = expected.a + expected.a.sum() + assert_eq(q, expected) + + +def test_groupb_apply_disk_shuffle_reuse(df, pdf): + q = df.groupby("a").apply(lambda x: x) + q = q.fillna(100) + q = q.a + q.a.sum() + # Disk shuffle is not utilizing pipeline breakers + _check_consumer_node(q, 1, consumer_node=DiskShuffle) + _check_consumer_node(q, 1) + expected = pdf.groupby("a").apply(lambda x: x) + expected = expected.fillna(100) + expected = expected.a + expected.a.sum() + assert_eq(q, expected) + + +def test_groupb_ffill_disk_shuffle_reuse(df, 
pdf): + q = df.groupby("a").ffill() + q = q.fillna(100) + q = q.b + q.b.sum() + # Disk shuffle is not utilizing pipeline breakers + _check_consumer_node(q, 1, consumer_node=DiskShuffle) + _check_consumer_node(q, 1) + expected = pdf.groupby("a").ffill() + expected = expected.fillna(100) + expected = expected.b + expected.b.sum() + assert_eq(q, expected) + + +def test_merge_reuse(): + pdf1 = pd.DataFrame({"a": [1, 2, 3, 4, 1, 2, 3, 4], "b": 1, "c": 1}) + pdf2 = pd.DataFrame({"a": [1, 2, 3, 4, 1, 2, 3, 4], "e": 1, "f": 1}) + + df1 = from_pandas(pdf1, npartitions=3) + df2 = from_pandas(pdf2, npartitions=3) + q = df1.merge(df2) + q = q.fillna(100) + q = q.b + q.b.sum() + _check_consumer_node(q, 1, BlockwiseMerge) + # One on either side + _check_consumer_node(q, 2, DiskShuffle, branch_id_counter=1) + _check_consumer_node(q, 2, branch_id_counter=1) + + expected = pdf1.merge(pdf2) + expected = expected.fillna(100) + expected = expected.b + expected.b.sum() + assert_eq(q, expected, check_index=False) diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 36277622..40592e54 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -137,7 +137,7 @@ def test_shuffle_column_projection(df): def test_shuffle_reductions(df): - assert df.shuffle("x").sum().simplify()._name == df.sum()._name + assert df.shuffle("x").sum().optimize()._name == df.sum().optimize()._name @pytest.mark.xfail(reason="Shuffle can't see the reduction through the Projection") @@ -264,7 +264,7 @@ def test_set_index_repartition(df, pdf): assert_eq(result, pdf.set_index("x")) -def test_set_index_simplify(df, pdf): +def test_set_index_optimize(df, pdf): q = df.set_index("x")["y"].optimize(fuse=False) expected = df[["x", "y"]].set_index("x")["y"].optimize(fuse=False) assert q._name == expected._name @@ -697,18 +697,21 @@ def test_shuffle_filter_pushdown(pdf, meth): result = result[result.x > 5.0] expected = getattr(df[df.x > 5.0], meth)("x") assert result.simplify()._name == expected._name + assert result.optimize()._name == expected.optimize()._name result = getattr(df, meth)("x") result = result[result.x > 5.0][["x", "y"]] expected = df[["x", "y"]] expected = getattr(expected[expected.x > 5.0], meth)("x") assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name result = getattr(df, meth)("x")[["x", "y"]] result = result[result.x > 5.0] expected = df[["x", "y"]] expected = getattr(expected[expected.x > 5.0], meth)("x") assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name @pytest.mark.parametrize("meth", ["set_index", "sort_values"]) @@ -716,7 +719,7 @@ def test_sort_values_avoid_overeager_filter_pushdown(meth): pdf1 = pd.DataFrame({"a": [4, 2, 3], "b": [1, 2, 3]}) df = from_pandas(pdf1, npartitions=2) df = getattr(df, meth)("a") - df = df[df.b > 2] + df.b.sum() + df = df[df.b > 2] + df[df.b > 1] result = df.simplify() assert isinstance(result.expr.left, Filter) assert isinstance(result.expr.left.frame, BaseSetIndexSortValues) @@ -729,18 +732,21 @@ def test_set_index_filter_pushdown(): result = result[result.y == 1] expected = df[df.y == 1].set_index("x") assert result.simplify()._name == expected._name + assert result.optimize()._name == expected.optimize()._name result = df.set_index("x") result = result[result.y == 1][["y"]] expected = df[["x", "y"]] expected = expected[expected.y == 1].set_index("x") assert result.simplify()._name == 
expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name result = df.set_index("x")[["y"]] result = result[result.y == 1] expected = df[["x", "y"]] expected = expected[expected.y == 1].set_index("x") assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name def test_shuffle_index_shuffle(df):
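
The sketch below is not part of the patch; it is a minimal illustration of the reuse behaviour the new tests assert, assuming the pandas backend and the _check_consumer_node helper added in dask_expr/tests/_util.py. The expected node counts are taken from test_reuse_everything_scalar_and_series and test_dont_reuse_reducer above; the final call uses the expression-level optimize from dask_expr/_expr.py with the new common_subplan_elimination flag, and the "expected: 1" comment is an illustration of the intended sharing, not an assertion lifted from a test.

# Minimal sketch (illustrative, not part of the patch) of the BranchId-based
# reuse logic. Node counts mirror the new tests in dask_expr/tests/test_reuse.py.
import pandas as pd

from dask_expr import from_pandas
from dask_expr._expr import optimize
from dask_expr.io import IO
from dask_expr.tests._util import _check_consumer_node

pdf = pd.DataFrame({"x": range(100), "a": 1, "b": 1, "c": 1})
df = from_pandas(pdf, npartitions=10)

# Purely blockwise branches keep the default BranchId(0), so the optimizer
# still collapses them onto a single FromPandas node (one IO node, one branch id).
blockwise = df.x[df.x > 1] + df.x[df.x > 2]
_check_consumer_node(blockwise, 1)

# A reduction is a pipeline breaker: ApplyConcatApply._reuse_down assigns a
# fresh BranchId to its subtree, so the IO node is duplicated instead of
# keeping every partition alive for both the reduction and the main branch.
reducer = df + df.sum()
_check_consumer_node(reducer, 2)

# Opting out: common_subplan_elimination=True skips the "reuse" rewrite, so
# the reduction branch is expected to share its FromPandas node again.
shared = optimize(reducer.expr, fuse=False, common_subplan_elimination=True)
print(len(list(shared.find_operations(IO))))  # expected: 1 (shared subplan)

The flag defaults to False because duplicating the consumer branches trades extra IO against having to hold every partition of the shared subplan in memory at once, which is why the optimize docstring recommends enabling it only when the dataset fits into memory.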