/
optimize.py
348 lines (263 loc) · 10.2 KB
/
optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
from __future__ import annotations
import copy
import logging
import warnings
from collections.abc import Hashable, Mapping
from typing import TYPE_CHECKING, Any
import dask.config
from dask.blockwise import fuse_roots, optimize_blockwise
from dask.core import flatten
from dask.highlevelgraph import HighLevelGraph
from dask.local import get_sync
from dask_awkward.layers import AwkwardInputLayer
# Module-level logger used by the optimization failure handling below.
log = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Imported for type annotations only; avoids importing awkward at runtime.
    from awkward import Array as AwkwardArray
def all_optimizations(
    dsk: Mapping,
    keys: Hashable | list[Hashable] | set[Hashable],
    **_: Any,
) -> Mapping:
    """Run all optimizations that benefit dask-awkward computations.

    This function will run both dask-awkward specific and upstream
    general optimizations from core dask.
    """
    if not isinstance(keys, (list, set)):
        keys = (keys,)  # pragma: no cover
    keys = tuple(flatten(keys))

    if isinstance(dsk, HighLevelGraph):
        # dask-awkward specific optimization (column projection).
        dsk = optimize(dsk, keys=keys)
        # upstream blockwise optimizations for HLG input.
        dsk = optimize_blockwise(dsk, keys=keys)
        # fuse nearby layers.
        dsk = fuse_roots(dsk, keys=keys)  # type: ignore
    else:
        # wrap a low-level graph so the cull step below can run on it.
        dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())

    # cull unnecessary tasks.
    dsk = dsk.cull(set(keys))  # type: ignore
    return dsk
def optimize(
    dsk: Mapping,
    keys: Hashable | list[Hashable] | set[Hashable],
    **_: Any,
) -> Mapping:
    """Run optimizations specific to dask-awkward.

    This is currently limited to determining the necessary columns for
    input layers.
    """
    # Column projection is opt-in via the dask config system; when it is
    # disabled the graph passes through untouched.
    if not dask.config.get("awkward.optimization.enabled", default=False):
        return dsk
    return optimize_columns(dsk)  # type: ignore
def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
    """Run the column projection optimization.

    This optimization determines which columns from an
    ``AwkwardInputLayer`` are necessary for a complete computation.

    For example, if a parquet dataset is loaded with fields:
    ``["foo", "bar", "baz.x", "baz.y"]``

    And the following task graph is made:

    >>> ds = dak.from_parquet("/path/to/dataset")
    >>> z = ds["foo"] - ds["baz"]["y"]

    Upon calling z.compute() the AwkwardInputLayer created in the
    from_parquet call will only read the parquet columns ``foo`` and
    ``baz.y``.

    Parameters
    ----------
    dsk : HighLevelGraph
        Original high level dask graph

    Returns
    -------
    HighLevelGraph
        New dask graph with a modified ``AwkwardInputLayer``.
    """
    new_layers = dsk.layers.copy()  # type: ignore
    dependencies = dsk.dependencies.copy()  # type: ignore
    # For each projectable input layer, resolve any wildcard entries against
    # the layer's typetracer meta and swap in a column-projected layer.
    for name, columns in _necessary_columns(dsk).items():
        layer = new_layers[name]
        columns = _prune_wildcards(columns, layer._meta)
        new_layers[name] = layer.project_columns(columns)
    return HighLevelGraph(new_layers, dependencies)
def _projectable_input_layer_names(dsk: HighLevelGraph) -> list[str]:
    """Get list of column-projectable AwkwardInputLayer names.

    Parameters
    ----------
    dsk : HighLevelGraph
        Task graph of interest

    Returns
    -------
    list[str]
        Names of the AwkwardInputLayers in the graph that are
        column-projectable.
    """
    names: list[str] = []
    for name, layer in dsk.layers.items():
        # A layer is projectable when its IO callable supports column pruning.
        if isinstance(layer, AwkwardInputLayer) and hasattr(
            layer.io_func, "project_columns"
        ):
            names.append(name)
    return names
def _layers_with_annotation(dsk: HighLevelGraph, key: str) -> list[str]:
return [n for n, v in dsk.layers.items() if (v.annotations or {}).get(key)]
def _ak_output_layer_names(dsk: HighLevelGraph) -> list[str]:
    """Get a list of output layer names.

    Output layers are identified by a truthy 'ak_output' annotation.

    Parameters
    ----------
    dsk : HighLevelGraph
        Graph of interest.

    Returns
    -------
    list[str]
        Names of the output layers.
    """
    return _layers_with_annotation(dsk, key="ak_output")
def _opt_touch_all_layer_names(dsk: HighLevelGraph) -> list[str]:
return [n for n, v in dsk.layers.items() if hasattr(v, "_opt_touch_all")]
# return _layers_with_annotation(dsk, "ak_touch_all")
def _has_projectable_awkward_io_layer(dsk: HighLevelGraph) -> bool:
    """Check if a graph has at least one project-able AwkwardInputLayer."""
    return any(
        isinstance(layer, AwkwardInputLayer)
        and hasattr(layer.io_func, "project_columns")
        for layer in dsk.layers.values()
    )
def _touch_all_data(*args, **kwargs):
    """Mock writing an ak.Array to disk by touching its data buffers."""
    # Imported lazily so that merely importing this module does not pull
    # in awkward.
    import awkward as ak

    for value in list(args) + list(kwargs.values()):
        if isinstance(value, ak.Array):
            value.layout._touch_data(recursive=True)
def _mock_output(layer):
    """Return a deep copy of *layer* whose task runs ``_touch_all_data``.

    The task's original callable is replaced (its arguments are kept), so
    executing the mocked layer only touches the data buffers of its awkward
    arguments instead of performing the real output work.
    """
    assert len(layer.dsk) == 1
    mocked = copy.deepcopy(layer)
    mocked.mapping = {
        key: (_touch_all_data,) + task[1:] for key, task in mocked.mapping.items()
    }
    return mocked
def _touch_and_call_fn(fn, *args, **kwargs):
    """Touch the data buffers of all awkward arguments, then call *fn*."""
    _touch_all_data(*args, **kwargs)
    return fn(*args, **kwargs)
def _touch_and_call(layer):
    """Return a deep copy of *layer* wrapped with ``_touch_and_call_fn``.

    Unlike ``_mock_output`` the original callable is preserved: the wrapper
    touches all awkward argument buffers first and then runs the original
    task.
    """
    assert len(layer.dsk) == 1
    wrapped = copy.deepcopy(layer)
    wrapped.mapping = {
        key: (_touch_and_call_fn,) + task for key, task in wrapped.mapping.items()
    }
    return wrapped
def _get_column_reports(dsk: HighLevelGraph) -> dict[str, Any]:
    """Get the TypeTracerReport for each input layer in a task graph.

    The graph is rebuilt with mocked input/output layers and executed
    synchronously with typetracer data, so the reports record which
    columns were touched without reading any real data.  Returns an
    empty dict when there is nothing to project or when the mocked
    execution fails (subject to the ``awkward.optimization.on-fail``
    config setting).
    """
    if not _has_projectable_awkward_io_layer(dsk):
        return {}

    # imported lazily so the module itself does not depend on awkward at
    # import time.
    import awkward as ak

    layers = dsk.layers.copy()  # type: ignore
    deps = dsk.dependencies.copy()  # type: ignore
    reports = {}
    # make labelled report
    for name in _projectable_input_layer_names(dsk):
        # replace each projectable input layer with its typetracer mock and
        # keep the report object it is labelled with.
        layers[name], report = layers[name].mock()
        reports[name] = report
    for name in _ak_output_layer_names(dsk):
        layers[name] = _mock_output(layers[name])
    for name in _opt_touch_all_layer_names(dsk):
        layers[name] = _touch_and_call(layers[name])

    hlg = HighLevelGraph(layers, deps)
    # run the mocked graph up to (one key of) its final layer.
    outlayer = hlg.layers[hlg._toposort_layers()[-1]]
    try:
        out = get_sync(hlg, list(outlayer.keys())[0])
    except Exception as err:
        on_fail = dask.config.get("awkward.optimization.on-fail")
        # this is the default, throw a warning but skip the optimization.
        if on_fail == "warn":
            warnings.warn(f"Column projection optimization failed: {type(err)}, {err}")
            return {}
        # option "pass" means do not throw warning but skip the optimization.
        elif on_fail == "pass":
            log.debug("Column projection optimization failed; optimization skipped.")
            return {}
        # option "raise" to raise the exception here
        elif on_fail == "raise":
            raise
        else:
            raise ValueError(
                f"Invalid awkward.optimization.on-fail option: {on_fail}.\n"
                "Valid options are 'warn', 'pass', or 'raise'."
            )
    # touch everything the final result holds, so the reports also include
    # columns needed by the output itself.
    if isinstance(out, ak.Array):
        out.layout._touch_data(recursive=True)
    return reports
def _necessary_columns(dsk: HighLevelGraph) -> dict[str, list[str]]:
    """Pair layer names with lists of necessary columns."""
    result: dict[str, list[str]] = {}
    for name, report in _get_column_reports(dsk).items():
        touched = sorted({t for t in report.data_touched if t is not None})
        selected: list[str] = []
        for entry in touched:
            # the bare layer name itself does not identify a column.
            if entry == name:
                continue
            layer_name, column = entry.split(".", 1)
            if layer_name != name:
                continue
            if column.endswith("__list__"):
                # strip the "__list__" marker and its joining dot; the
                # resulting wildcard is resolved later by _prune_wildcards.
                base = column[:-9].rstrip(".")
                if base not in selected:
                    selected.append(f"{base}.*")
            else:
                selected.append(column)
        result[name] = selected
    return result
def _prune_wildcards(columns: list[str], meta: AwkwardArray) -> list[str]:
"""Prune wildcard '.*' suffix from necessary columns results.
The _necessary_columns logic will provide some results of the
form:
"foo.bar.*"
This function will eliminate the wildcard in one of two ways
(continuing to use "foo.bar.*" as an example):
1. If "foo.bar" has leaves (subfields) "x", "y" and "z", and _any_
of those (so "foo.bar.x", for example) also appears in the
columns list, then essentially nothing will happen (except we
drop the wildcard string), because we can be sure that a leaf
of "foo.bar" will be read (in this case it's "foo.bar.x").
2. If "foo.bar" has multiple leaves but none of them appear in the
columns list, we will just pick the first one that we find
(that is, foo.bar.fields[0]).
Parameters
----------
columns : list[str]
The "raw" columns deemed necessary by the necessary columns
logic; can still contain the wildcard syntax we've adopted.
meta : ak.Array
The metadata (typetracer array) from the AwkwardInputLayer
that is getting optimized.
Returns
-------
list[str]
Columns with the wildcard syntax pruned and (also augmented
with a leaf node if necessary).
"""
good_columns: list[str] = []
wildcard_columns: list[str] = []
for col in columns:
if ".*" in col:
wildcard_columns.append(col)
else:
good_columns.append(col)
for col in wildcard_columns:
# each time we meet a wildcard column we need to start back
# with the original meta array.
imeta = meta
colsplit = col.split(".")[:-1]
parts = list(reversed(colsplit))
while parts:
part = parts.pop()
# for unnamed roots part may be an empty string, so we
# need this if statement.
if part:
imeta = imeta[part]
for field in imeta.fields:
wholecol = f"{col[:-2]}.{field}"
if wholecol in good_columns:
break
else:
if imeta.fields:
good_columns.append(f"{col[:-2]}.{imeta.fields[0]}")
return good_columns