Skip to content

Commit

Permalink
Add explicit backends.py files in dask.array/dataframe
Browse files Browse the repository at this point in the history
This collects lazily registered functions for cupy, sparse, scipy, and
cudf.
  • Loading branch information
mrocklin committed Jan 17, 2019
1 parent 64366b5 commit 7566d58
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 68 deletions.
2 changes: 1 addition & 1 deletion dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from .reductions import nanprod, nancumprod, nancumsum
with ignoring(ImportError):
from . import ma
from . import random, linalg, overlap, fft
from . import random, linalg, overlap, fft, backends
from .overlap import map_overlap
from .wrap import ones, zeros, empty, full
from .creation import ones_like, zeros_like, empty_like, full_like
Expand Down
33 changes: 33 additions & 0 deletions dask/array/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .core import tensordot_lookup, concatenate_lookup


@tensordot_lookup.register_lazy('cupy')
@concatenate_lookup.register_lazy('cupy')
def register_cupy():
import cupy
concatenate_lookup.register(cupy.ndarray, cupy.concatenate)
tensordot_lookup.register(cupy.ndarray, cupy.tensordot)


@tensordot_lookup.register_lazy('sparse')
@concatenate_lookup.register_lazy('sparse')
def register_sparse():
import sparse
concatenate_lookup.register(sparse.COO, sparse.concatenate)
tensordot_lookup.register(sparse.COO, sparse.tensordot)


@concatenate_lookup.register_lazy('scipy')
def register_scipy_sparse():
import scipy.sparse

def _concatenate(L, axis=0):
if axis == 0:
return scipy.sparse.vstack(L)
elif axis == 1:
return scipy.sparse.hstack(L)
else:
msg = ("Can only concatenate scipy sparse matrices for axis in "
"{0, 1}. Got %s" % axis)
raise ValueError(msg)
concatenate_lookup.register(scipy.sparse.spmatrix, _concatenate)
32 changes: 0 additions & 32 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,38 +61,6 @@
tensordot_lookup.register((object, np.ndarray), np.tensordot)


@tensordot_lookup.register_lazy('cupy')
@concatenate_lookup.register_lazy('cupy')
def register_cupy():
import cupy
concatenate_lookup.register(cupy.ndarray, cupy.concatenate)
tensordot_lookup.register(cupy.ndarray, cupy.tensordot)


@tensordot_lookup.register_lazy('sparse')
@concatenate_lookup.register_lazy('sparse')
def register_sparse():
import sparse
concatenate_lookup.register(sparse.COO, sparse.concatenate)
tensordot_lookup.register(sparse.COO, sparse.tensordot)


@concatenate_lookup.register_lazy('scipy')
def register_scipy_sparse():
import scipy.sparse

def _concatenate(L, axis=0):
if axis == 0:
return scipy.sparse.vstack(L)
elif axis == 1:
return scipy.sparse.hstack(L)
else:
msg = ("Can only concatenate scipy sparse matrices for axis in "
"{0, 1}. Got %s" % axis)
raise ValueError(msg)
concatenate_lookup.register(scipy.sparse.spmatrix, _concatenate)


class PerformanceWarning(Warning):
""" A warning given when bad chunking may cause poor performance """

Expand Down
2 changes: 1 addition & 1 deletion dask/array/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def normalize_to_array(x):
if 'cupy' in str(type(x)):
if 'cupy' in str(type(x)): # TODO: avoid explicit reference to cupy
return x.get()
else:
return x
Expand Down
2 changes: 1 addition & 1 deletion dask/dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
read_fwf)
from .optimize import optimize
from .multi import merge, concat
from . import rolling
from . import rolling, backends
from ..base import compute
from .reshape import get_dummies, pivot_table, melt
from .utils import assert_eq
Expand Down
37 changes: 37 additions & 0 deletions dask/dataframe/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from .methods import concat_dispatch
from .core import get_parallel_type, meta_nonempty, make_meta


######################################
# cuDF: Pandas Dataframes on the GPU #
######################################


@concat_dispatch.register_lazy('cudf')
@get_parallel_type.register_lazy('cudf')
@meta_nonempty.register_lazy('cudf')
@make_meta.register_lazy('cudf')
def _register_cudf():
import cudf
import dask_cudf
get_parallel_type.register(cudf.DataFrame, lambda _: dask_cudf.DataFrame)
get_parallel_type.register(cudf.Series, lambda _: dask_cudf.Series)
get_parallel_type.register(cudf.Index, lambda _: dask_cudf.Index)

@meta_nonempty.register((cudf.DataFrame, cudf.Series, cudf.Index))
def _(x):
y = meta_nonempty(x.to_pandas()) # TODO: add iloc[:5]
return cudf.from_pandas(y)

@make_meta.register((cudf.Series, cudf.DataFrame))
def _(x):
return x.head(0)

@make_meta.register(cudf.Index)
def _(x):
return x[:0]

concat_dispatch.register(
(cudf.DataFrame, cudf.Series, cudf.Index),
cudf.concat
)
24 changes: 0 additions & 24 deletions dask/dataframe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4495,30 +4495,6 @@ def get_parallel_type_frame(o):
return get_parallel_type(o._meta)


@get_parallel_type.register_lazy('cudf')
@meta_nonempty.register_lazy('cudf')
@make_meta.register_lazy('cudf')
def _register_cudf():
import cudf
import dask_cudf
get_parallel_type.register(cudf.DataFrame, lambda _: dask_cudf.DataFrame)
get_parallel_type.register(cudf.Series, lambda _: dask_cudf.Series)
get_parallel_type.register(cudf.Index, lambda _: dask_cudf.Index)

@meta_nonempty.register((cudf.DataFrame, cudf.Series, cudf.Index))
def _(x):
y = meta_nonempty(x.to_pandas()) # TODO: add iloc[:5]
return cudf.from_pandas(y)

@make_meta.register((cudf.Series, cudf.DataFrame))
def _(x):
return x.head(0)

@make_meta.register(cudf.Index)
def _(x):
return x[:0]


def parallel_types():
return tuple(k for k, v in get_parallel_type._lookup.items()
if v is not get_parallel_type_object)
Expand Down
9 changes: 0 additions & 9 deletions dask/dataframe/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,6 @@ def _get_level_values(x, n):
concat_dispatch = Dispatch('concat')


@concat_dispatch.register_lazy('cudf')
def register_cudf():
import cudf
concat_dispatch.register(
(cudf.DataFrame, cudf.Series, cudf.Index),
cudf.concat
)


def concat(dfs, axis=0, join='outer', uniform=False, filter_warning=True):
"""Concatenate, handling some edge cases:
Expand Down

0 comments on commit 7566d58

Please sign in to comment.