Skip to content
This repository has been archived by the owner on Jul 16, 2021. It is now read-only.

Commit

Permalink
Add support for scipy sparse, fixing sparse tests (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomlaube authored and mrocklin committed Jul 26, 2018
1 parent dd4b5f5 commit 6cd5f13
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
conda config --add channels conda-forge
conda create -q -n test-environment python=${PYTHON}
source activate test-environment
conda install -q coverage flake8 pytest pytest-cov pytest-xdist numpy pandas xgboost dask distributed scikit-learn sparse
conda install -q coverage flake8 pytest pytest-cov pytest-xdist numpy pandas xgboost dask distributed scikit-learn sparse scipy
pip install -e .
conda list test-environment
- run:
Expand Down
6 changes: 5 additions & 1 deletion dask_xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

try:
import sparse
import scipy.sparse as ss
except ImportError:
sparse = False
ss = False

from dask import delayed
from dask.distributed import wait, default_client
Expand Down Expand Up @@ -51,8 +53,10 @@ def concat(L):
return np.concatenate(L, axis=0)
elif isinstance(L[0], (pd.DataFrame, pd.Series)):
return pd.concat(L, axis=0)
elif ss and isinstance(L[0], ss.spmatrix):
return ss.vstack(L, format='csr')
elif sparse and isinstance(L[0], sparse.SparseArray):
return sparse.concatenate(L[0], axis=0)
return sparse.concatenate(L, axis=0)
else:
raise TypeError("Data must be either numpy arrays or pandas dataframes"
". Got %s" % type(L[0]))
Expand Down
58 changes: 33 additions & 25 deletions dask_xgboost/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
import xgboost as xgb
import sparse
import scipy.sparse

import pytest

Expand Down Expand Up @@ -98,6 +99,21 @@ def test_dmatrix_kwargs(c, s, a, b):
assert np.abs(result - dresult_incompat).sum() > 0.02


def _test_container(dbst, predictions, X_type):
dtrain = xgb.DMatrix(X_type(X), label=y)
bst = xgb.train(param, dtrain)

result = bst.predict(dtrain)
dresult = dbst.predict(dtrain)

correct = (result > 0.5) == y
dcorrect = (dresult > 0.5) == y

assert dcorrect.sum() >= correct.sum()
assert isinstance(predictions, np.ndarray)
assert ((predictions > 0.5) != labels).sum() < 2


@gen_cluster(client=True, timeout=None)
def test_numpy(c, s, a, b):
xgb.rabit.init() # workaround for "Doing rabit call after Finalize"
Expand All @@ -106,48 +122,40 @@ def test_numpy(c, s, a, b):
dbst = yield dxgb.train(c, param, dX, dy)
dbst = yield dxgb.train(c, param, dX, dy) # we can do this twice

dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train(param, dtrain)
predictions = dxgb.predict(c, dbst, dX)
assert isinstance(predictions, da.Array)
predictions = yield c.compute(predictions)
_test_container(dbst, predictions, np.array)

result = bst.predict(dtrain)
dresult = dbst.predict(dtrain)

correct = (result > 0.5) == y
dcorrect = (dresult > 0.5) == y
assert dcorrect.sum() >= correct.sum()
@gen_cluster(client=True, timeout=None)
def test_scipy_sparse(c, s, a, b):
xgb.rabit.init() # workaround for "Doing rabit call after Finalize"
dX = da.from_array(X, chunks=(2, 2)).map_blocks(scipy.sparse.csr_matrix)
dy = da.from_array(y, chunks=(2,))
dbst = yield dxgb.train(c, param, dX, dy)
dbst = yield dxgb.train(c, param, dX, dy) # we can do this twice

predictions = dxgb.predict(c, dbst, dX)
assert isinstance(predictions, da.Array)
predictions = yield c.compute(predictions)._result()
assert isinstance(predictions, np.ndarray)

assert ((predictions > 0.5) != labels).sum() < 2
predictions_result = yield c.compute(predictions)
_test_container(dbst, predictions_result, scipy.sparse.csr_matrix)


@gen_cluster(client=True, timeout=None)
def test_sparse(c, s, a, b):
xgb.rabit.init() # workaround for "Doing rabit call after Finalize"
dX = da.from_array(sparse.COO.from_numpy(X), chunks=(2, 2))
dy = da.from_array(sparse.COO.from_numpy(y), chunks=(2,))
dX = da.from_array(X, chunks=(2, 2)).map_blocks(sparse.COO)
dy = da.from_array(y, chunks=(2,))
dbst = yield dxgb.train(c, param, dX, dy)
dbst = yield dxgb.train(c, param, dX, dy) # we can do this twice

dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train(param, dtrain)

result = bst.predict(dtrain)
dresult = dbst.predict(dtrain)

correct = (result > 0.5) == y
dcorrect = (dresult > 0.5) == y
assert dcorrect.sum() >= correct.sum()

predictions = dxgb.predict(c, dbst, dX)
assert isinstance(predictions, da.Array)
predictions = yield c.compute(predictions)._result()
assert isinstance(predictions, np.ndarray)

assert ((predictions > 0.5) != labels).sum() < 2
predictions_result = yield c.compute(predictions)
_test_container(dbst, predictions_result, sparse.COO)


def test_synchronous_api(loop): # noqa
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

requires = open('requirements.txt').read().strip().split('\n')
install_requires = []
extras_require = {'sparse': ['sparse']}
extras_require = {'sparse': ['sparse', 'scipy']}
for r in requires:
if ';' in r:
# requirements.txt conditional dependencies need to be reformatted for
Expand Down

0 comments on commit 6cd5f13

Please sign in to comment.