Skip to content

Commit

Permalink
Add has_parallel_type
Browse files Browse the repository at this point in the history
This is a bit nicer than `isinstance(obj, parallel_types())` because it
also works with lazily registered types.
  • Loading branch information
mrocklin committed Jan 16, 2019
1 parent 0dbe40d commit 7cf4587
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
12 changes: 9 additions & 3 deletions dask/dataframe/core.py
Expand Up @@ -60,7 +60,7 @@ def _concat(args):
return args
if isinstance(first(core.flatten(args)), np.ndarray):
return da.core.concatenate3(args)
if not isinstance(args[0], parallel_types()):
if not has_parallel_type(args[0]):
try:
return pd.Series(args)
except Exception:
Expand Down Expand Up @@ -3651,7 +3651,7 @@ def map_partitions(func, *args, **kwargs):
(apply, func, (tuple, [(arg._name, 0) for arg in args]), kwargs)}
graph = HighLevelGraph.from_collections(name, layer, dependencies=args)
return Scalar(graph, name, meta)
elif not (isinstance(meta, parallel_types()) or is_arraylike(meta)):
elif not (has_parallel_type(meta) or is_arraylike(meta)):
# If `meta` is not a pandas object, the concatenated results will be a
# different type
meta = _concat([meta])
Expand Down Expand Up @@ -4500,12 +4500,18 @@ def parallel_types():
if v is not get_parallel_type_object)


def has_parallel_type(x):
    """Return True if ``x`` has a dask dataframe equivalent type.

    Unlike a plain ``isinstance(x, parallel_types())`` check, this first
    dispatches on ``x`` so that any lazily registered types are added to
    the registry before the membership test runs.
    """
    # Dispatching forces lazy registration of x's type, if applicable.
    get_parallel_type(x)
    return isinstance(x, parallel_types())


def new_dd_object(dsk, name, meta, divisions):
"""Generic constructor for dask.dataframe objects.
Decides the appropriate output class based on the type of `meta` provided.
"""
if isinstance(meta, parallel_types()):
if has_parallel_type(meta):
return get_parallel_type(meta)(dsk, name, meta, divisions)
elif is_arraylike(meta):
import dask.array as da
Expand Down
4 changes: 2 additions & 2 deletions dask/dataframe/io/io.py
Expand Up @@ -14,7 +14,7 @@
from ... import array as da
from ...delayed import delayed

from ..core import DataFrame, Series, Index, new_dd_object, parallel_types
from ..core import DataFrame, Series, Index, new_dd_object, has_parallel_type
from ..shuffle import set_partition
from ..utils import insert_meta_param_description, check_meta, make_meta

Expand Down Expand Up @@ -165,7 +165,7 @@ def from_pandas(data, npartitions=None, chunksize=None, sort=True, name=None):
if isinstance(getattr(data, 'index', None), pd.MultiIndex):
raise NotImplementedError("Dask does not support MultiIndex Dataframes.")

if not isinstance(data, parallel_types()):
if not has_parallel_type(data):
raise TypeError("Input must be a pandas DataFrame or Series")

if ((npartitions is None) == (chunksize is None)):
Expand Down
9 changes: 8 additions & 1 deletion dask/dataframe/tests/test_dataframe.py
Expand Up @@ -16,7 +16,8 @@
from dask.compatibility import PY2
from dask.utils import put_lines, M

from dask.dataframe.core import repartition_divisions, aca, _concat, Scalar
from dask.dataframe.core import (repartition_divisions, aca, _concat, Scalar,
has_parallel_type)
from dask.dataframe import methods
from dask.dataframe.utils import (assert_eq, make_meta, assert_max_deps,
PANDAS_VERSION)
Expand Down Expand Up @@ -3364,3 +3365,9 @@ def test_scalar_with_array():

da.utils.assert_eq(df.x.values + df.x.mean(),
ddf.x.values + ddf.x.mean())


def test_has_parallel_type():
    """Pandas containers are parallelizable; plain scalars are not."""
    for pandas_obj in (pd.DataFrame(), pd.Series()):
        assert has_parallel_type(pandas_obj)
    assert not has_parallel_type(123)

0 comments on commit 7cf4587

Please sign in to comment.