Skip to content

Commit

Permalink
Don't validate feature when number of rows is 0.
Browse files: browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 5, 2020
1 parent d6386e4 commit 56c42d3
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
3 changes: 3 additions & 0 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2025,6 +2025,9 @@ def _validate_features(self, data):
Validate Booster and data's feature_names are identical.
Set feature_names and feature_types from DMatrix
"""
if data.num_row() == 0:
return

if self.feature_names is None:
self.feature_names = data.feature_names
self.feature_types = data.feature_types
Expand Down
11 changes: 6 additions & 5 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,8 @@ async def _direct_predict_impl(client, data, predict_fn):


# pylint: disable=too-many-statements
async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
async def _predict_async(client, model, data, missing, validate_features,
**kwargs):
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
Expand All @@ -778,7 +779,7 @@ def mapped_predict(partition, is_df):
with config.config_context(**_global_config):
booster.set_param({'nthread': worker.nthreads})
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
predt = booster.predict(m, validate_features=False, **kwargs)
predt = booster.predict(m, validate_features=validate_features, **kwargs)
if is_df:
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
import cudf # pylint: disable=import-error
Expand Down Expand Up @@ -820,7 +821,7 @@ def dispatched_predict(worker_id, list_of_orders, list_of_parts):
)
predt = booster.predict(
data=local_part,
validate_features=local_part.num_row() != 0,
validate_features=validate_features,
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((dask.delayed(predt), columns), order)
Expand Down Expand Up @@ -877,7 +878,7 @@ async def map_function(func):
return predictions


def predict(client, model, data, missing=numpy.nan, **kwargs):
def predict(client, model, data, missing=numpy.nan, validate_features=True, **kwargs):
'''Run prediction with a trained booster.
.. note::
Expand Down Expand Up @@ -908,7 +909,7 @@ def predict(client, model, data, missing=numpy.nan, **kwargs):
_assert_dask_support()
client = _xgb_get_client(client)
return client.sync(_predict_async, client, model, data,
missing=missing, **kwargs)
missing=missing, validate_features=validate_features, **kwargs)


async def _inplace_predict_async(client, model, data,
Expand Down
37 changes: 37 additions & 0 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,30 @@ def test_sklearn_grid_search():
assert len(means) == len(set(means))


def test_empty_dmatrix_training_continuation(client):
    """Check training continuation and prediction on a one-row training set.

    A booster is trained for two rounds on a single-row ``DaskDMatrix`` with a
    larger validation set (1001 rows, split into chunks of 10), then training
    is continued from that booster for two more rounds.  Prediction on the
    training data must return exactly one row.

    NOTE(review): the test name suggests the one-row training set leaves some
    workers with an empty DMatrix -- presumably this depends on the ``client``
    fixture having more than one worker; confirm against the fixture.
    """
    kRows, kCols = 1, 97
    # Derive the feature names from kCols so they cannot drift from the data
    # width if the column count is ever changed.
    columns = ['X' + str(i) for i in range(kCols)]

    X = dd.from_array(np.random.randn(kRows, kCols))
    y = dd.from_array(np.random.rand(kRows))
    X.columns = columns
    dtrain = xgb.dask.DaskDMatrix(client, X, y)

    # Validation data is much larger than the training data and is split into
    # many small chunks.
    kRows += 1000
    X = dd.from_array(np.random.randn(kRows, kCols), chunksize=10)
    X.columns = columns
    y = dd.from_array(np.random.rand(kRows), chunksize=10)
    valid = xgb.dask.DaskDMatrix(client, X, y)

    out = xgb.dask.train(client, {'tree_method': 'hist'},
                         dtrain=dtrain, num_boost_round=2,
                         evals=[(valid, 'validation')])

    # Continue training from the booster produced by the first run.
    out = xgb.dask.train(client, {'tree_method': 'hist'},
                         dtrain=dtrain, xgb_model=out['booster'],
                         num_boost_round=2,
                         evals=[(valid, 'validation')])
    assert xgb.dask.predict(client, out, dtrain).compute().shape[0] == 1


def run_empty_dmatrix_reg(client, parameters):
def _check_outputs(out, predictions):
assert isinstance(out['booster'], xgb.dask.Booster)
Expand All @@ -371,6 +395,19 @@ def _check_outputs(out, predictions):
data=dtrain).compute()
_check_outputs(out, predictions)

# valid has more rows than train
kRows += 1
X = dd.from_array(np.random.randn(kRows, kCols))
y = dd.from_array(np.random.rand(kRows))
valid = xgb.dask.DaskDMatrix(client, X, y)
out = xgb.dask.train(client, parameters,
dtrain=dtrain,
evals=[(valid, 'validation')],
num_boost_round=2)
predictions = xgb.dask.predict(client=client, model=out,
data=dtrain).compute()
_check_outputs(out, predictions)

# train has more rows than evals
valid = dtrain
kRows += 1
Expand Down

0 comments on commit 56c42d3

Please sign in to comment.