Skip to content

Commit

Permalink
Start objective.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 5, 2020
1 parent 4f29f18 commit 4d57fee
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .core import _deprecate_positional_args
from .training import train as worker_train
from .tracker import RabitTracker
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
from .sklearn import xgboost_model_doc


Expand Down Expand Up @@ -1152,6 +1152,10 @@ async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
sample_weight_eval_set,
self.missing)

if callable(self.objective):
obj = _objective_decorator(self.objective)
else:
obj = None
metric = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
Expand All @@ -1165,6 +1169,7 @@ async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
feval=metric,
obj=obj,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks)
Expand Down Expand Up @@ -1256,6 +1261,10 @@ async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
sample_weight_eval_set,
self.missing)

if callable(self.objective):
obj = _objective_decorator(self.objective)
else:
obj = None
metric = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
Expand All @@ -1268,6 +1277,7 @@ async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=obj,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
Expand Down
37 changes: 37 additions & 0 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,43 @@ def test_feature_weights(self, client):
assert poly_increasing[0] > 0.08
assert poly_decreasing[0] < -0.08

@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_sklearn())
def test_custom_objective(self, client):
    '''A callable objective passed to DaskXGBRegressor should be invoked
    once per boosting round, and squared-error results must match the
    native objective.'''
    from sklearn.datasets import load_boston
    X, y = load_boston(return_X_y=True)
    X, y = da.from_array(X), da.from_array(y)
    rounds = 20

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'log')

        def sqr(labels, predts):
            # Append a line per invocation so we can count how many
            # times the custom objective was actually called.
            with open(path, 'a') as fd:
                print('Running sqr', file=fd)
            grad = predts - labels
            hess = np.ones(shape=labels.shape[0])
            return grad, hess

        reg = xgb.dask.DaskXGBRegressor(n_estimators=rounds, objective=sqr,
                                        tree_method='hist')
        reg.fit(X, y, eval_set=[(X, y)])

        # Check the obj is run exactly `rounds` times.
        with open(path, 'r') as fd:
            out = fd.readlines()
        assert len(out) == rounds

        results_custom = reg.evals_result()

        # Train again with the built-in objective for comparison.
        reg = xgb.dask.DaskXGBRegressor(n_estimators=rounds,
                                        tree_method='hist')
        reg.fit(X, y, eval_set=[(X, y)])
        results_native = reg.evals_result()

        np.testing.assert_allclose(results_custom['validation_0']['rmse'],
                                   results_native['validation_0']['rmse'])
        # Fix: `tm.non_increasing` returns a bool that was previously
        # discarded, so the check did nothing — assert its result.
        assert tm.non_increasing(results_native['validation_0']['rmse'])

def test_data_initialization(self):
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
generate unnecessary copies of data.
Expand Down

0 comments on commit 4d57fee

Please sign in to comment.