Merge pull request dask#569 from mrocklin/fuse-getitem
Fuse getitem optimization
mrocklin committed Aug 11, 2015
2 parents b4cf70d + 3f10b9c commit d0da36c
Showing 4 changed files with 77 additions and 34 deletions.
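
In short, this commit drops the RewriteRule-based column pushdown from dask.dataframe.optimize and replaces it with a new fuse_getitem pass in dask.optimize: when a getitem task selects columns from the result of a loading task (dataframe_from_ctable, or Castra.load_partition when castra is installed), the requested columns are written directly into the loader's column argument so only those columns are loaded. A minimal sketch of the new helper's behavior, mirroring its docstring and tests -- load is a hypothetical stand-in for the real loaders, and cull here is the 2015-era variant that returns a plain dict:

from operator import getitem
from dask.optimize import cull, fuse_getitem

def load(store, partition, columns):
    # hypothetical loader; the column list sits in argument position 3
    pass

dsk = {'x': (load, 'store', 'part', ['a', 'b']),
       'y': (getitem, 'x', 'a')}

# fuse_getitem(dsk, load, 3) rewrites 'y' to call load directly with the
# requested column; cull then drops the now-unneeded full-column 'x' task.
dsk2 = cull(fuse_getitem(dsk, load, 3), 'y')
assert dsk2 == {'y': (load, 'store', 'part', 'a')}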
31 changes: 11 additions & 20 deletions dask/dataframe/optimize.py
@@ -3,30 +3,15 @@
-a, b, c, d, e = '~a', '~b', '~c', '~d', '~e'
-from ..rewrite import RuleSet, RewriteRule
 from .io import dataframe_from_ctable
-from ..optimize import cull, fuse, inline_functions
+from ..optimize import cull, fuse, inline_functions, fuse_getitem
 from ..core import istask
-from ..utils import ignoring
 from .. import core
 from toolz import valmap
 from operator import getitem
 import operator
 
 
-rules = [
-    # Merge column access into bcolz loading
-    RewriteRule((getitem, (dataframe_from_ctable, a, b, c, d), e),
-                (dataframe_from_ctable, a, b, e, d),
-                (a, b, c, d, e)),
-    ]
-with ignoring(ImportError):
-    from castra import Castra
-    rules.append(
-        RewriteRule((getitem, (Castra.load_partition, '~c', '~part', '~cols1'),
-                     '~cols2'),
-                    (Castra.load_partition, '~c', '~part', '~cols2'),
-                    ('~c', '~part', '~cols1', '~cols2')))
-
-rewrite_rules = RuleSet(*rules)
-
 fast_functions = [getattr(operator, attr) for attr in dir(operator)
                   if not attr.startswith('_')]

@@ -37,6 +22,12 @@ def optimize(dsk, keys, **kwargs):
     else:
         dsk2 = cull(dsk, [keys])
     dsk3 = inline_functions(dsk2, fast_functions)
-    dsk4 = fuse(dsk3)
-    dsk5 = valmap(rewrite_rules.rewrite, dsk4)
-    return dsk5
+    try:
+        from castra import Castra
+        dsk4 = fuse_getitem(dsk3, Castra.load_partition, 3)
+    except ImportError:
+        dsk4 = dsk3
+    dsk5 = fuse_getitem(dsk4, dataframe_from_ctable, 3)
+    dsk6 = fuse(dsk5)
+    dsk7 = cull(dsk6, keys)
+    return dsk7
24 changes: 13 additions & 11 deletions dask/dataframe/tests/test_optimize_dataframe.py
@@ -1,7 +1,7 @@
 import pytest
 from operator import getitem
-from toolz import valmap
-from dask.dataframe.optimize import rewrite_rules, dataframe_from_ctable
+from toolz import valmap, merge
+from dask.dataframe.optimize import dataframe_from_ctable
 import dask.dataframe as dd
 import pandas as pd
 
@@ -22,18 +22,20 @@ def test_column_optimizations_with_bcolz_and_rewrite():
     bc = bcolz.ctable([[1, 2, 3], [10, 20, 30]], names=['a', 'b'])
     func = lambda x: x
     for cols in [None, 'abc', ['abc']]:
-        dsk2 = dict((('x', i),
-                     (func,
-                      (getitem,
-                       (dataframe_from_ctable, bc, slice(0, 2), cols, {}),
-                       (list, ['a', 'b']))))
-                    for i in [1, 2, 3])
+        dsk2 = merge(dict((('x', i),
+                           (dataframe_from_ctable, bc, slice(0, 2), cols, {}))
+                          for i in [1, 2, 3]),
+                     dict((('y', i),
+                           (getitem, ('x', i), (list, ['a', 'b'])))
+                          for i in [1, 2, 3]),
+                     dict((('z', i), (func, ('y', i)))
+                          for i in [1, 2, 3]))
 
-        expected = dict((('x', i), (func, (dataframe_from_ctable,
+        expected = dict((('z', i), (func, (dataframe_from_ctable,
                                            bc, slice(0, 2), (list, ['a', 'b']), {})))
-                         for i in [1, 2, 3])
-        result = valmap(rewrite_rules.rewrite, dsk2)
+                        for i in [1, 2, 3])
+
+        result = dd.optimize(dsk2, [('z', i) for i in [1, 2, 3]])
         assert result == expected
 
 
42 changes: 41 additions & 1 deletion dask/optimize.py
@@ -1,9 +1,12 @@
 from itertools import count
+from operator import getitem
+
+from toolz import identity
+
 from .compatibility import zip_longest
 from .core import (istask, get_dependencies, subs, toposort, flatten,
                    reverse_dict, add, inc, ishashable, preorder_traversal)
 from .rewrite import END
-from toolz import identity
 
 
 def cull(dsk, keys):
@@ -454,3 +457,40 @@ def merge_sync(dsk1, dsk2):
 
 # store the name iterator in the function
 merge_sync.names = ('merge_%d' % i for i in count(1))
+
+
+def fuse_getitem(dsk, func, place):
+    """ Fuse getitem with lower operation
+
+    Parameters
+    ----------
+    dsk: dict
+        dask graph
+    func: function
+        A function in a task to merge
+    place: int
+        Location in task to insert the getitem key
+
+    >>> def load(store, partition, columns):
+    ...     pass
+    >>> dsk = {'x': (load, 'store', 'part', ['a', 'b']),
+    ...        'y': (getitem, 'x', 'a')}
+    >>> dsk2 = fuse_getitem(dsk, load, 3)  # columns in arg place 3
+    >>> cull(dsk2, 'y')
+    {'y': (<function load at ...>, 'store', 'part', 'a')}
+    """
+    dsk2 = dict()
+    seen = set()
+    for k, v in dsk.items():
+        try:
+            if (istask(v) and v[0] == getitem and v[1] in dsk and
+                istask(dsk[v[1]]) and dsk[v[1]][0] == func):
+                vv = list(dsk[v[1]])
+                vv[place] = v[2]
+                dsk2[k] = tuple(vv)
+            else:
+                dsk2[k] = v
+        except TypeError:
+            dsk2[k] = v
+    return dsk2
14 changes: 12 additions & 2 deletions dask/tests/test_optimize.py
@@ -1,9 +1,9 @@
 from itertools import count
-from operator import add, mul
+from operator import add, mul, getitem
 from toolz import partial, identity
 from dask.utils import raises
 from dask.optimize import (cull, fuse, inline, inline_functions, functions_of,
-                           dealias, equivalent, sync_keys, merge_sync)
+                           dealias, equivalent, sync_keys, merge_sync, fuse_getitem)
 
 
 def inc(x):
@@ -312,3 +312,13 @@ def test_merge_sync():
                        'g2': (add, 'conflict', 3),
                        'h2': (add, 'merge_1', 3)}
     assert key_map == {'h1': 'g1', 'conflict': 'merge_1', 'h2': 'h2'}
+
+
+def test_fuse_getitem():
+    def load(*args):
+        pass
+    dsk = {'x': (load, 'store', 'part', ['a', 'b']),
+           'y': (getitem, 'x', 'a')}
+    dsk2 = fuse_getitem(dsk, load, 3)
+    dsk2 = cull(dsk2, 'y')
+    assert dsk2 == {'y': (load, 'store', 'part', 'a')}
