Skip to content

Commit

Permalink
More fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Apr 19, 2019
1 parent a3e78d1 commit 42395b6
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 15 deletions.
4 changes: 3 additions & 1 deletion dask/dataframe/core.py
Expand Up @@ -3789,7 +3789,9 @@ def map_partitions(func, *args, **kwargs):
meta_index = getattr(make_meta(dfs[0]), 'index', None) if dfs else None

if meta is no_default:
meta = _emulate(func, *args, udf=True, **kwargs2)
# Use non-normalized kwargs here, as we want the real values (not
# delayed values)
meta = _emulate(func, *args, udf=True, **kwargs)
else:
meta = make_meta(meta, index=meta_index)

Expand Down
34 changes: 24 additions & 10 deletions dask/dataframe/multi.py
Expand Up @@ -70,7 +70,6 @@
from . import methods
from .shuffle import shuffle, rearrange_by_divisions
from .utils import strip_unknown_categories
from ..utils import M


def align_partitions(*dfs):
Expand Down Expand Up @@ -203,6 +202,17 @@ def require(divisions, parts, required=None):
required = {'left': [0], 'right': [1], 'inner': [0, 1], 'outer': []}


def merge_chunk(lhs, *args, **kwargs):
    """Merge one partition via ``lhs.merge``, fixing the empty-result index dtype.

    Accepts an extra ``empty_index_dtype`` keyword (popped before the merge):
    when the merged result has zero rows, its index is cast to that dtype.
    This works around a pandas bug where an empty merge result always gets
    an ``int64`` index regardless of the input index dtypes.
    """
    target_dtype = kwargs.pop('empty_index_dtype', None)
    result = lhs.merge(*args, **kwargs)
    if target_dtype is not None and not len(result):
        result.index = result.index.astype(target_dtype)
    return result


def merge_indexed_dataframes(lhs, rhs, left_index=True, right_index=True, **kwargs):
""" Join two partitioned dataframes along their index """
how = kwargs.get('how', 'left')
Expand All @@ -214,11 +224,12 @@ def merge_indexed_dataframes(lhs, rhs, left_index=True, right_index=True, **kwar

name = 'join-indexed-' + tokenize(lhs, rhs, **kwargs)

meta = lhs._meta_nonempty.merge(rhs._meta_nonempty, **kwargs)
kwargs['empty_index_dtype'] = meta.index.dtype

dsk = dict()
for i, (a, b) in enumerate(parts):
dsk[(name, i)] = (apply, M.merge, [a, b], kwargs)

meta = lhs._meta_nonempty.merge(rhs._meta_nonempty, **kwargs)
dsk[(name, i)] = (apply, merge_chunk, [a, b], kwargs)

graph = HighLevelGraph.from_collections(name, dsk, dependencies=[lhs, rhs])
return new_dd_object(graph, name, meta, divisions)
Expand Down Expand Up @@ -270,7 +281,8 @@ def hash_join(lhs, left_on, rhs, right_on, how='inner',
token = tokenize(lhs2, rhs2, npartitions, shuffle, **kwargs)
name = 'hash-join-' + token

dsk = {(name, i): (apply, M.merge, [(lhs2._name, i), (rhs2._name, i)], kwargs)
kwargs['empty_index_dtype'] = meta.index.dtype
dsk = {(name, i): (apply, merge_chunk, [(lhs2._name, i), (rhs2._name, i)], kwargs)
for i in range(npartitions)}

divisions = [None] * (npartitions + 1)
Expand All @@ -283,10 +295,11 @@ def single_partition_join(left, right, **kwargs):
# new index will not necessarily correspond the current divisions

meta = left._meta_nonempty.merge(right._meta_nonempty, **kwargs)
kwargs['empty_index_dtype'] = meta.index.dtype
name = 'merge-' + tokenize(left, right, **kwargs)
if left.npartitions == 1 and kwargs['how'] in ('inner', 'right'):
left_key = first(left.__dask_keys__())
dsk = {(name, i): (apply, M.merge, [left_key, right_key], kwargs)
dsk = {(name, i): (apply, merge_chunk, [left_key, right_key], kwargs)
for i, right_key in enumerate(right.__dask_keys__())}

if kwargs.get('right_index') or right._contains_index_name(
Expand All @@ -297,7 +310,7 @@ def single_partition_join(left, right, **kwargs):

elif right.npartitions == 1 and kwargs['how'] in ('inner', 'left'):
right_key = first(right.__dask_keys__())
dsk = {(name, i): (apply, M.merge, [left_key, right_key], kwargs)
dsk = {(name, i): (apply, merge_chunk, [left_key, right_key], kwargs)
for i, left_key in enumerate(left.__dask_keys__())}

if kwargs.get('left_index') or left._contains_index_name(
Expand Down Expand Up @@ -393,10 +406,11 @@ def merge(left, right, how='inner', on=None, left_on=None, right_on=None,
left = rearrange_by_divisions(left, left_on, right.divisions,
max_branch, shuffle=shuffle)
right = right.clear_divisions()
return map_partitions(M.merge, left, right, meta=meta, how=how, on=on,
left_on=left_on, right_on=right_on,
return map_partitions(merge_chunk, left, right, meta=meta, how=how,
on=on, left_on=left_on, right_on=right_on,
left_index=left_index, right_index=right_index,
suffixes=suffixes, indicator=indicator)
suffixes=suffixes, indicator=indicator,
empty_index_dtype=meta.index.dtype)
# Catch all hash join
else:
return hash_join(left, left.index if left_index else left_on,
Expand Down
16 changes: 12 additions & 4 deletions dask/dataframe/tests/test_multi.py
Expand Up @@ -436,6 +436,14 @@ def test_merge_by_index_patterns(how, shuffle):
'd': [5, 4, 3, 2]},
index=list('fghi'))

def pd_merge(left, right, **kwargs):
    """``pd.merge`` with the empty-result index cast back to ``left``'s dtype.

    Works around a pandas bug where an empty merge result gets an ``int64``
    index even when the inputs had, e.g., an ``object`` index.
    """
    merged = pd.merge(left, right, **kwargs)
    if len(merged):
        return merged
    return merged.set_index(merged.index.astype(left.index.dtype))

for pdl, pdr in [(pdf1l, pdf1r), (pdf2l, pdf2r), (pdf3l, pdf3r),
(pdf4l, pdf4r), (pdf5l, pdf5r), (pdf6l, pdf6r),
(pdf7l, pdf7r)]:
Expand All @@ -449,22 +457,22 @@ def test_merge_by_index_patterns(how, shuffle):

assert_eq(dd.merge(ddl, ddr, how=how, left_index=True,
right_index=True, shuffle=shuffle),
pd.merge(pdl, pdr, how=how, left_index=True,
pd_merge(pdl, pdr, how=how, left_index=True,
right_index=True))
assert_eq(dd.merge(ddr, ddl, how=how, left_index=True,
right_index=True, shuffle=shuffle),
pd.merge(pdr, pdl, how=how, left_index=True,
pd_merge(pdr, pdl, how=how, left_index=True,
right_index=True))

assert_eq(dd.merge(ddl, ddr, how=how, left_index=True,
right_index=True, shuffle=shuffle,
indicator=True),
pd.merge(pdl, pdr, how=how, left_index=True,
pd_merge(pdl, pdr, how=how, left_index=True,
right_index=True, indicator=True))
assert_eq(dd.merge(ddr, ddl, how=how, left_index=True,
right_index=True, shuffle=shuffle,
indicator=True),
pd.merge(pdr, pdl, how=how, left_index=True,
pd_merge(pdr, pdl, how=how, left_index=True,
right_index=True, indicator=True))

assert_eq(ddr.merge(ddl, how=how, left_index=True,
Expand Down

0 comments on commit 42395b6

Please sign in to comment.