Skip to content

Commit c3c01ed

Browse files
committed
get rid of hard coded keys
1 parent 4f8e0e3 commit c3c01ed

File tree

13 files changed

+115
-95
lines changed

13 files changed

+115
-95
lines changed

dask_expr/_concat.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ class StackPartition(Concat):
271271
_defaults = {"join": "outer", "ignore_order": False, "_kwargs": {}, "axis": 0}
272272

273273
def _layer(self):
274-
dsk, i = {}, 0
274+
dsk = {}
275275
kwargs = self._kwargs.copy()
276276
kwargs["ignore_order"] = self.ignore_order
277277
ctr = 0
@@ -282,16 +282,15 @@ def _layer(self):
282282
match = True
283283
except (ValueError, TypeError):
284284
match = False
285-
286-
for i in range(df.npartitions):
285+
for key in df.__dask_keys__():
287286
if match:
288-
dsk[(self._name, ctr)] = df._name, i
287+
dsk[(self._name, ctr)] = key
289288
else:
290289
dsk[(self._name, ctr)] = (
291290
apply,
292291
methods.concat,
293292
[
294-
[meta, (df._name, i)],
293+
[meta, key],
295294
self.axis,
296295
self.join,
297296
False,
@@ -321,7 +320,7 @@ def _layer(self):
321320
apply,
322321
methods.concat,
323322
[
324-
[(df._name, i) for df in dfs],
323+
[df.__dask_keys__()[i] for df in dfs],
325324
self.axis,
326325
self.join,
327326
False,

dask_expr/_cumulative.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,23 +76,25 @@ def _meta(self):
7676
def _layer(self) -> dict:
7777
dsk = {}
7878
frame, previous_partitions = self.frame, self.previous_partitions
79-
dsk[(self._name, 0)] = (frame._name, 0)
79+
frame_keys = frame.__dask_keys__()
80+
previous_partitions_keys = previous_partitions.__dask_keys__()
81+
dsk[(self._name, 0)] = frame_keys[0]
8082

8183
intermediate_name = self._name + "-intermediate"
8284
for i in range(1, self.frame.npartitions):
8385
if i == 1:
84-
dsk[(intermediate_name, i)] = (previous_partitions._name, i - 1)
86+
dsk[(intermediate_name, i)] = previous_partitions_keys[i - 1]
8587
else:
8688
# aggregate with previous cumulation results
8789
dsk[(intermediate_name, i)] = (
8890
methods._cum_aggregate_apply,
8991
self.aggregator,
9092
(intermediate_name, i - 1),
91-
(previous_partitions._name, i - 1),
93+
previous_partitions_keys[i - 1],
9294
)
9395
dsk[(self._name, i)] = (
9496
self.aggregator,
95-
(self.frame._name, i),
97+
frame_keys[i],
9698
(intermediate_name, i),
9799
)
98100
return dsk

dask_expr/_expr.py

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -517,9 +517,9 @@ def _blockwise_arg(self, arg, i):
517517
if isinstance(arg, Expr):
518518
# Make key for Expr-based argument
519519
if self._broadcast_dep(arg):
520-
return (arg._name, 0)
520+
return arg.__dask_keys__()[0]
521521
else:
522-
return (arg._name, i)
522+
return arg.__dask_keys__()[i]
523523

524524
else:
525525
return arg
@@ -636,11 +636,15 @@ def _get_meta_ufunc(dfs, args, func):
636636
raise NotImplementedError(msg)
637637
# For broadcastable series, use no rows.
638638
parts = [
639-
d._meta
640-
if d.ndim == 0
641-
else np.empty((), dtype=d.dtype)
642-
if isinstance(d, Array)
643-
else meta_nonempty(d._meta)
639+
(
640+
d._meta
641+
if d.ndim == 0
642+
else (
643+
np.empty((), dtype=d.dtype)
644+
if isinstance(d, Array)
645+
else meta_nonempty(d._meta)
646+
)
647+
)
644648
for d in dasks
645649
]
646650

@@ -895,12 +899,13 @@ def _layer(self) -> dict:
895899
dsk, prevs, nexts = {}, [], []
896900

897901
name_prepend = "overlap-prepend" + self.frame._name
902+
frame_keys = self.frame.__dask_keys__()
898903
if self.before:
899904
prevs.append(None)
900905
if isinstance(self.before, numbers.Integral):
901906
before = self.before
902907
for i in range(self.frame.npartitions - 1):
903-
dsk[(name_prepend, i)] = (M.tail, (self.frame._name, i), before)
908+
dsk[(name_prepend, i)] = (M.tail, frame_keys[i], before)
904909
prevs.append((name_prepend, i))
905910
elif isinstance(self.before, datetime.timedelta):
906911
# Assumes monotonic (increasing?) index
@@ -928,17 +933,17 @@ def _layer(self) -> dict:
928933

929934
dsk[(name_prepend, i)] = (
930935
_tail_timedelta,
931-
(self.frame._name, i + 1),
932-
[(self.frame._name, k) for k in range(j, i + 1)],
936+
frame_keys[i + 1],
937+
[frame_keys[k] for k in range(j, i + 1)],
933938
self.before,
934939
)
935940
prevs.append((name_prepend, i))
936941
else:
937942
for i in range(self.frame.npartitions - 1):
938943
dsk[(name_prepend, i)] = (
939944
_tail_timedelta,
940-
(self.frame._name, i + 1),
941-
[(self.frame._name, i)],
945+
frame_keys[i + 1],
946+
[frame_keys[i]],
942947
self.before,
943948
)
944949
prevs.append((name_prepend, i))
@@ -950,7 +955,7 @@ def _layer(self) -> dict:
950955
if isinstance(self.after, numbers.Integral):
951956
after = self.after
952957
for i in range(1, self.frame.npartitions):
953-
dsk[(name_append, i)] = (M.head, (self.frame._name, i), after)
958+
dsk[(name_append, i)] = (M.head, frame_keys[i], after)
954959
nexts.append((name_append, i))
955960
else:
956961
# We don't want to look at the divisions, so take twice the step and
@@ -959,8 +964,8 @@ def _layer(self) -> dict:
959964
for i in range(1, self.frame.npartitions):
960965
dsk[(name_append, i)] = (
961966
_head_timedelta,
962-
(self.frame._name, i - 1),
963-
(self.frame._name, i),
967+
frame_keys[i - 1],
968+
frame_keys[i],
964969
after,
965970
)
966971
nexts.append((name_append, i))
@@ -974,7 +979,7 @@ def _layer(self) -> dict:
974979
dsk[(self._name, i)] = (
975980
_combined_parts,
976981
prev,
977-
(self.frame._name, i),
982+
frame_keys[i],
978983
next,
979984
self.before,
980985
self.after,
@@ -1605,7 +1610,7 @@ def _task(self, index: int):
16051610
apply,
16061611
M.apply,
16071612
[
1608-
(self.frame._name, index),
1613+
self.frame.__dask_keys__()[index],
16091614
self.function,
16101615
]
16111616
+ list(self.args),
@@ -1825,9 +1830,11 @@ def _meta(self):
18251830
caselist = [
18261831
(
18271832
meta_nonempty(c[i]._meta) if isinstance(c[i], Expr) else c[i],
1828-
meta_nonempty(c[i + 1]._meta)
1829-
if isinstance(c[i + 1], Expr)
1830-
else c[i + 1],
1833+
(
1834+
meta_nonempty(c[i + 1]._meta)
1835+
if isinstance(c[i + 1], Expr)
1836+
else c[i + 1]
1837+
),
18311838
)
18321839
for i in range(0, len(c), 2)
18331840
]
@@ -1939,7 +1946,7 @@ def _projection_columns(self):
19391946
def _task(self, index: int):
19401947
return (
19411948
getattr,
1942-
(self.frame._name, index),
1949+
self.frame.__dask_keys__()[index],
19431950
"index",
19441951
)
19451952

@@ -1981,15 +1988,13 @@ def _divisions(self):
19811988
def _layer(self):
19821989
non_empties = [i for i, length in enumerate(self.lens) if length != 0]
19831990
# If all empty, collapse into one partition
1991+
frame_keys = self.frame.__dask_keys__()
19841992
if len(non_empties) == 0:
1985-
return {(self._name, 0): (self.frame._name, 0)}
1993+
return {(self._name, 0): frame_keys[0]}
19861994

19871995
# drop empty partitions by mapping each partition in a new graph to a particular
19881996
# partition on the old graph.
1989-
dsk = {
1990-
(self._name, i): (self.frame._name, div)
1991-
for i, div in enumerate(non_empties)
1992-
}
1997+
dsk = {(self._name, i): frame_keys[div] for i, div in enumerate(non_empties)}
19931998
ddf_keys = list(dsk.values())
19941999

19952000
overlap = [
@@ -2044,9 +2049,9 @@ def _simplify_down(self):
20442049

20452050
def _layer(self):
20462051
name = "part-" + self._name
2052+
20472053
dsk = {
2048-
(name, i): (len, (self.frame._name, i))
2049-
for i in range(self.frame.npartitions)
2054+
(name, i): (len, key) for i, key in enumerate(self.frame.__dask_keys__())
20502055
}
20512056
dsk[(self._name, 0)] = (tuple, list(dsk.keys()))
20522057
return dsk
@@ -2226,12 +2231,12 @@ def _task(self, index: int):
22262231
op = safe_head
22272232
else:
22282233
op = M.head
2229-
return (op, (self.frame._name, index), self.n)
2234+
return (op, self.frame.__dask_keys__()[index], self.n)
22302235

22312236

22322237
class BlockwiseHeadIndex(BlockwiseHead):
22332238
def _task(self, index: int):
2234-
return (operator.getitem, (self.frame._name, index), slice(0, self.n))
2239+
return (operator.getitem, self.frame.__dask_keys__()[index], slice(0, self.n))
22352240

22362241

22372242
class Tail(Expr):
@@ -2289,12 +2294,16 @@ def _divisions(self):
22892294
return self.frame.divisions
22902295

22912296
def _task(self, index: int):
2292-
return (M.tail, (self.frame._name, index), self.n)
2297+
return (M.tail, self.frame.__dask_keys__()[index], self.n)
22932298

22942299

22952300
class BlockwiseTailIndex(BlockwiseTail):
22962301
def _task(self, index: int):
2297-
return (operator.getitem, (self.frame._name, index), slice(-self.n, None))
2302+
return (
2303+
operator.getitem,
2304+
self.frame.__dask_keys__()[index],
2305+
slice(-self.n, None),
2306+
)
22982307

22992308

23002309
class Binop(Elemwise):
@@ -2582,16 +2591,18 @@ def _divisions(self):
25822591
return tuple(divisions)
25832592

25842593
def _task(self, index: int):
2585-
return (self.frame._name, self.partitions[index])
2594+
return self.frame.__dask_keys__()[self.partitions[index]]
25862595

25872596
def _simplify_down(self):
25882597
if isinstance(self.frame, Blockwise) and not isinstance(
25892598
self.frame, (BlockwiseIO, Fused)
25902599
):
25912600
operands = [
2592-
Partitions(op, self.partitions)
2593-
if (isinstance(op, Expr) and not self.frame._broadcast_dep(op))
2594-
else op
2601+
(
2602+
Partitions(op, self.partitions)
2603+
if (isinstance(op, Expr) and not self.frame._broadcast_dep(op))
2604+
else op
2605+
)
25952606
for op in self.frame.operands
25962607
]
25972608
return type(self.frame)(*operands)
@@ -3130,18 +3141,19 @@ def _broadcast_dep(self, dep: Expr):
31303141
return dep.npartitions == 1
31313142

31323143
def _task(self, index):
3133-
graph = {self._name: (self.exprs[0]._name, index)}
3144+
graph = {self._name: self.exprs[0].__dask_keys__()[index]}
31343145
for _expr in self.exprs:
3146+
frame_keys = _expr.__dask_keys__()
31353147
if isinstance(_expr, Fused):
31363148
subgraph, name = _expr._task(index)[1:3]
31373149
graph.update(subgraph)
31383150
graph[(name, index)] = name
31393151
elif self._broadcast_dep(_expr):
31403152
# When _expr is being broadcasted, we only
31413153
# want to define a fused task for index 0
3142-
graph[(_expr._name, 0)] = _expr._task(0)
3154+
graph[frame_keys[0]] = _expr._task(0)
31433155
else:
3144-
graph[(_expr._name, index)] = _expr._task(index)
3156+
graph[frame_keys[index]] = _expr._task(index)
31453157

31463158
for i, dep in enumerate(self.dependencies()):
31473159
graph[self._blockwise_arg(dep, index)] = "_" + str(i)
@@ -3247,9 +3259,11 @@ def maybe_align_partitions(*exprs, divisions):
32473259
from dask_expr._repartition import Repartition
32483260

32493261
return [
3250-
Repartition(df, new_divisions=divisions, force=True)
3251-
if isinstance(df, Expr) and df.ndim > 0
3252-
else df
3262+
(
3263+
Repartition(df, new_divisions=divisions, force=True)
3264+
if isinstance(df, Expr) and df.ndim > 0
3265+
else df
3266+
)
32533267
for df in exprs
32543268
]
32553269

dask_expr/_groupby.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,25 +1316,27 @@ def _divisions(self):
13161316
return self.frame.divisions
13171317

13181318
def _layer(self) -> dict:
1319-
dsk = {(self._name, 0): (self.cum_raw._name, 0)}
1319+
dsk = {(self._name, 0): self.cum_raw.__dask_keys__()[0]}
13201320
name_cum = "cum-last" + self._name
1321-
1322-
for i in range(1, self.frame.npartitions):
1321+
cum_last_keys = self.cum_last.__dask_keys__()
1322+
for i, frame_key in enumerate(self.frame.__dask_keys__()):
1323+
if i == 0:
1324+
continue
13231325
# store each cumulative step to graph to reduce computation
13241326
if i == 1:
1325-
dsk[(name_cum, i)] = (self.cum_last._name, i - 1)
1327+
dsk[(name_cum, i)] = cum_last_keys[i - 1]
13261328
else:
13271329
# aggregate with previous cumulation results
13281330
dsk[(name_cum, i)] = (
13291331
_cum_agg_filled,
13301332
(name_cum, i - 1),
1331-
(self.cum_last._name, i - 1),
1333+
cum_last_keys[i - 1],
13321334
self.aggregate,
13331335
self.initial,
13341336
)
13351337
dsk[(self._name, i)] = (
13361338
_cum_agg_aligned,
1337-
(self.frame._name, i),
1339+
frame_key,
13381340
(name_cum, i),
13391341
self.by,
13401342
self.operand("columns"),

0 commit comments

Comments
 (0)