Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DNM: Full branch_id implementation #896

Open
wants to merge 41 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
fae5c6e
Implement branch_id to limit reuse
phofl Feb 15, 2024
d598734
Update
phofl Feb 15, 2024
2345cd4
Merge remote-tracking branch 'upstream/main' into test
phofl Feb 16, 2024
d88270d
Fix delayed
phofl Feb 16, 2024
948cd83
Update
phofl Feb 16, 2024
045bbef
Update
phofl Feb 16, 2024
93e0d28
Add cache
phofl Feb 16, 2024
7ddda99
Enhance tests
phofl Feb 16, 2024
8c2d977
Add tests
phofl Feb 16, 2024
7184bcf
Update
phofl Feb 16, 2024
5ac9394
Update
phofl Feb 16, 2024
fb2aa9f
Update
phofl Feb 16, 2024
061de6f
Update
phofl Feb 16, 2024
e486590
Update
phofl Feb 16, 2024
d28f906
Update _core.py
phofl Feb 19, 2024
366415a
Update test_groupby.py
phofl Feb 19, 2024
ee523ea
Update
phofl Feb 19, 2024
369c142
Update
phofl Feb 19, 2024
5ee43dd
Merge branch 'main' into branch_id_implementation
phofl Feb 19, 2024
7379a01
Remove argument_operands
phofl Feb 20, 2024
68e048c
Update
phofl Feb 20, 2024
f79155a
Update
phofl Feb 20, 2024
4801a93
Implement shuffles as consumer
phofl Feb 20, 2024
9fcc246
Tighten test
phofl Feb 20, 2024
391d8f6
Update
phofl Feb 21, 2024
4326a25
Update
phofl Feb 21, 2024
1b6b090
Merge remote-tracking branch 'upstream/main' into branch_id_implement…
phofl Feb 21, 2024
c9e0384
Update
phofl Feb 21, 2024
5531985
Merge remote-tracking branch 'origin/branch_id_implementation' into b…
phofl Feb 24, 2024
6891478
Merge remote-tracking branch 'upstream/main' into branch_id_implement…
phofl Feb 24, 2024
d70ba0f
Implement shuffle methods as consumers for branch_id
phofl Feb 24, 2024
cc120ee
Update
phofl Feb 24, 2024
8ba433c
Merge branch 'branch_id_implementation' into branch_id_implementation…
phofl Feb 24, 2024
9257b72
Merge remote-tracking branch 'upstream/main' into branch_id_implement…
phofl Feb 24, 2024
b665ce1
Merge branch 'branch_id_implementation' into branch_id_implementation…
phofl Feb 24, 2024
2bbeb2e
Remove unnecessary changes
phofl Feb 24, 2024
451dca0
Simplify variable
phofl Feb 24, 2024
150f99c
Make reuse step easier
phofl Feb 24, 2024
12432da
Make reuse step easier
phofl Feb 24, 2024
9cbcec3
Merge branch 'branch_id_implementation' into branch_id_implementation…
phofl Feb 24, 2024
8bc45b9
Update
phofl Feb 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,7 @@ def shuffle(
shuffle_method,
options,
index_shuffle=on_index,
_branch_id=expr.BranchId(0),
)
)

Expand Down Expand Up @@ -4780,6 +4781,7 @@ def merge(
shuffle_method=shuffle_method,
_npartitions=npartitions,
broadcast=broadcast,
_branch_id=expr.BranchId(0),
)
)

Expand Down Expand Up @@ -4866,7 +4868,7 @@ def merge_asof(

from dask_expr._merge_asof import MergeAsof

return new_collection(MergeAsof(left, right, **kwargs))
return new_collection(MergeAsof(left, right, **kwargs, _branch_id=expr.BranchId(0)))


def from_map(
Expand Down
92 changes: 81 additions & 11 deletions dask_expr/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, NamedTuple

import dask
import pandas as pd
Expand All @@ -29,6 +29,10 @@
]


class BranchId(NamedTuple):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you prefer a NamedTuple over a NewType('BranchId', int) here?

branch_id: int


def _unpack_collections(o):
if isinstance(o, Expr):
return o
Expand All @@ -43,9 +47,17 @@ class Expr:
_parameters = []
_defaults = {}
_instances = weakref.WeakValueDictionary()
_branch_id_required = False
_reuse_consumer = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is ambiguous. Can we come up with something more descriptive?


def __new__(cls, *args, **kwargs):
def __new__(cls, *args, _branch_id=None, **kwargs):
cls._check_branch_id_given(args, _branch_id)
operands = list(args)
if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId):
_branch_id = operands.pop(-1)
elif _branch_id is None:
_branch_id = BranchId(0)

for parameter in cls._parameters[len(operands) :]:
try:
operands.append(kwargs.pop(parameter))
Expand All @@ -54,13 +66,23 @@ def __new__(cls, *args, **kwargs):
assert not kwargs, kwargs
inst = object.__new__(cls)
inst.operands = [_unpack_collections(o) for o in operands]
inst._branch_id = _branch_id
_name = inst._name
if _name in Expr._instances:
return Expr._instances[_name]

Expr._instances[_name] = inst
return inst

@classmethod
def _check_branch_id_given(cls, args, _branch_id):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
def _check_branch_id_given(cls, args, _branch_id):
def _maybe_check_branch_id_given(cls, args, _branch_id):

if not cls._branch_id_required:
return
operands = list(args)
if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId):
_branch_id = operands.pop(-1)
assert _branch_id is not None, "BranchId not found"

def _tune_down(self):
return None

Expand Down Expand Up @@ -116,7 +138,10 @@ def _tree_repr_lines(self, indent=0, recursive=True):
elif is_arraylike(op):
op = "<array>"
header = self._tree_repr_argument_construction(i, op, header)

if self._branch_id.branch_id != 0:
header = self._tree_repr_argument_construction(
i + 1, f" branch_id={self._branch_id.branch_id}", header
)
lines = [header] + lines
lines = [" " * indent + line for line in lines]

Expand Down Expand Up @@ -218,7 +243,7 @@ def _layer(self) -> dict:

return {(self._name, i): self._task(i) for i in range(self.npartitions)}

def rewrite(self, kind: str):
def rewrite(self, kind: str, cache):
"""Rewrite an expression

This leverages the ``._{kind}_down`` and ``._{kind}_up``
Expand All @@ -231,6 +256,9 @@ def rewrite(self, kind: str):
changed:
whether or not any change occured
"""
if self._name in cache:
return cache[self._name]

expr = self
down_name = f"_{kind}_down"
up_name = f"_{kind}_up"
Expand Down Expand Up @@ -267,21 +295,46 @@ def rewrite(self, kind: str):
changed = False
for operand in expr.operands:
if isinstance(operand, Expr):
new = operand.rewrite(kind=kind)
new = operand.rewrite(kind=kind, cache=cache)
cache[operand._name] = new
if new._name != operand._name:
changed = True
else:
new = operand
new_operands.append(new)

if changed:
expr = type(expr)(*new_operands)
expr = type(expr)(*new_operands, _branch_id=expr._branch_id)
continue
else:
break

return expr

def _reuse_up(self, parent):
return

def _reuse_down(self):
if not self.dependencies():
return
return self._bubble_branch_id_down()

def _bubble_branch_id_down(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
def _bubble_branch_id_down(self):
def _propagate_branch_id_down(self):

b_id = self._branch_id
if b_id.branch_id <= 0:
return
if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()):
ops = [
op._substitute_branch_id(b_id) if isinstance(op, Expr) else op
for op in self.operands
]
return type(self)(*ops)

def _substitute_branch_id(self, branch_id):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
def _substitute_branch_id(self, branch_id):
def _maybe_substitute_branch_id(self, branch_id):

or something else that highlights the conditionality.

if self._branch_id.branch_id != 0:
return self
return type(self)(*self.operands, branch_id)

def simplify_once(self, dependents: defaultdict, simplified: dict):
"""Simplify an expression

Expand Down Expand Up @@ -346,7 +399,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict):
new_operands.append(new)

if changed:
expr = type(expr)(*new_operands)
expr = type(expr)(*new_operands, _branch_id=expr._branch_id)

break

Expand Down Expand Up @@ -391,7 +444,7 @@ def lower_once(self):
new_operands.append(new)

if changed:
out = type(out)(*new_operands)
out = type(out)(*new_operands, _branch_id=out._branch_id)

return out

Expand Down Expand Up @@ -426,6 +479,23 @@ def _lower(self):

@functools.cached_property
def _name(self):
return (
funcname(type(self)).lower()
+ "-"
+ _tokenize_deterministic(*self.operands, self._branch_id)
)

@functools.cached_property
def _dep_name(self):
# The name identifies every expression uniquely. The dependents name
# is used during optimization to capture the dependents of any given
# expression. A reuse consumer will have the same dependents independently
# of the branch_id parameter, since we want to reuse everything that comes
# before us and split branches up everything that is processed after
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's a word missing, maybe for?

Suggested change
# before us and split branches up everything that is processed after
# before us and split branches up for everything that is processed after

# us. So we have to ignore the branch_id from tokenization for those
# nodes.
if not self._reuse_consumer:
return self._name
return (
funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands)
)
Comment on lines +482 to 501
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels prone to errors/inconsistencies when subclassing. Would it make sense to define a property _dep_name_tokens that could be overriden and a property _name_tokens that just always adds branch_id to the _dep_name_tokens? This could then feed into a common function used to generate the name using the tokens as input.

For example, FromGraph already implements a new _name but not a new _dep_name.

Expand Down Expand Up @@ -554,7 +624,7 @@ def _substitute(self, old, new, _seen):
new_exprs.append(operand)

if update: # Only recreate if something changed
return type(self)(*new_exprs)
return type(self)(*new_exprs, _branch_id=self._branch_id)
else:
_seen.add(self._name)
return self
Expand All @@ -580,7 +650,7 @@ def substitute_parameters(self, substitutions: dict) -> Expr:
else:
new_operands.append(operand)
if changed:
return type(self)(*new_operands)
return type(self)(*new_operands, _branch_id=self._branch_id)
return self

def _node_label_args(self):
Expand Down Expand Up @@ -741,5 +811,5 @@ def collect_dependents(expr) -> defaultdict:

for dep in node.dependencies():
stack.append(dep)
dependents[dep._name].append(weakref.ref(node))
dependents[dep._dep_name].append(weakref.ref(node))
return dependents
Loading
Loading