Skip to content

Commit

Permalink
Add a test
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Dec 11, 2020
1 parent 56ada8f commit 171241e
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/python-gpu/test_gpu_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,17 @@

import numpy as np
import xgboost as xgb
from xgboost.compat import PANDAS_INSTALLED

from hypothesis import given, strategies, assume, settings, note

if PANDAS_INSTALLED:
from hypothesis.extra.pandas import column, data_frames, range_indexes
else:
def noop(*args, **kwargs):
pass
column, data_frames, range_indexes = noop, noop, noop

sys.path.append("tests/python")
import testing as tm
from test_predict import run_threaded_predict # noqa
Expand Down Expand Up @@ -259,3 +268,29 @@ def test_predict_leaf_dart(self, param, dataset):
param['booster'] = 'dart'
param['tree_method'] = 'gpu_hist'
self.run_predict_leaf_booster(param, 10, dataset)

@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.skipif(**tm.no_pandas())
@given(df=data_frames([column('x0', elements=strategies.integers(min_value=0, max_value=3)),
column('x1', elements=strategies.integers(min_value=0, max_value=5))],
index=range_indexes(min_size=20, max_size=50)))
@settings(deadline=None)
def test_predict_categorical_split(self, df):
from sklearn.metrics import mean_squared_error

df = df.astype('category')
x0, x1 = df['x0'].to_numpy(), df['x1'].to_numpy()
y = (x0 * 10 - 20) + (x1 - 2)
dtrain = xgb.DMatrix(df, label=y, enable_categorical=True)

params = {'tree_method': 'gpu_hist', 'predictor': 'gpu_predictor',
'enable_experimental_json_serialization': True,
'max_depth': 3, 'learning_rate': 1.0, 'base_score': 0.0, 'eval_metric': 'rmse'}

eval_history = {}
bst = xgb.train(params, dtrain, num_boost_round=5, evals=[(dtrain, 'train')],
verbose_eval=False, evals_result=eval_history)

pred = bst.predict(dtrain)
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)

0 comments on commit 171241e

Please sign in to comment.