Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XGBoost integration #65

Merged
merged 16 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc
* [skorch](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#skorch) ([example](https://github.com/optuna/optuna-examples/tree/main/pytorch/skorch_simple.py))
* [TensorBoard](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorboard) ([example](https://github.com/optuna/optuna-examples/tree/main/tensorboard/tensorboard_simple.py))
* [tf.keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorflow) ([example](https://github.com/optuna/optuna-examples/tree/main/tfkeras/tfkeras_integration.py))
* [XGBoost](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#xgboost) ([example](https://github.com/optuna/optuna-examples/tree/main/xgboost/xgboost_integration.py))

## Installation

Expand Down
11 changes: 10 additions & 1 deletion docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,13 @@ TensorFlow
:toctree: generated/
:nosignatures:

optuna_integration.TFKerasPruningCallback
optuna.integration.TFKerasPruningCallback
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved

XGBoost
----------
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved

.. autosummary::
:toctree: generated/
:nosignatures:

optuna.integration.XGBoostPruningCallback
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved
123 changes: 123 additions & 0 deletions optuna_integration/xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Any

import optuna

# FIX: ``optuna_integration._imports`` was referenced below without ever being
# imported, which raises ``NameError`` as soon as this module is imported.
from optuna_integration._imports import try_import


# Whether the installed XGBoost exposes the callback-class API
# (``xgb.callback.TrainingCallback``); defaults to True when xgboost is absent
# so that the stub class branch is selected further below.
use_callback_cls = True

with try_import() as _imports:
    import xgboost as xgb

    # Parse "major.minor[.patch]" into comparable integers.
    xgboost_version = xgb.__version__.split(".")
    xgboost_major_version = int(xgboost_version[0])
    xgboost_minor_version = int(xgboost_version[1])
    # The TrainingCallback interface was introduced in XGBoost 1.3; a tuple
    # comparison expresses ">= (1, 3)" directly instead of the original
    # two-clause boolean.
    use_callback_cls = (xgboost_major_version, xgboost_minor_version) >= (1, 3)

_doc = """Callback for XGBoost to prune unpromising trials.

See `the example <https://github.com/optuna/optuna-examples/blob/main/
xgboost/xgboost_integration.py>`__
if you want to add a pruning callback which observes validation accuracy of
a XGBoost model.

Args:
trial:
A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
objective function.
observation_key:
An evaluation metric for pruning, e.g., ``validation-error`` and
``validation-merror``. When using the Scikit-Learn API, the index number of
``eval_set`` must be included in the ``observation_key``, e.g.,
``validation_0-error`` and ``validation_0-merror``. Please refer to ``eval_metric``
in `XGBoost reference <https://xgboost.readthedocs.io/en/latest/parameter.html>`_
for further details.
"""

if _imports.is_successful() and use_callback_cls:

    class XGBoostPruningCallback(xgb.callback.TrainingCallback):
        __doc__ = _doc

        def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
            self._trial = trial
            self._observation_key = observation_key
            self._is_cv = False

        def before_training(self, model: Any) -> Any:
            # ``xgb.train`` hands us an ``xgb.Booster``; ``xgb.cv`` hands us a
            # ``_PackedBooster``, which is not public API as of xgboost 1.3 —
            # hence the ``Any`` type and the isinstance-based detection.
            self._is_cv = not isinstance(model, xgb.Booster)
            return model

        def after_iteration(self, model: Any, epoch: int, evals_log: dict) -> bool:
            # Flatten the nested history into a ``{dataset-metric: latest score}`` map.
            latest_scores = {}
            for dataset, metrics in evals_log.items():
                for metric, scores in metrics.items():
                    assert isinstance(scores, list), scores
                    # In cv mode each entry is a (mean, stddev) pair across the
                    # folds; keep only the mean.
                    value = scores[-1][0] if self._is_cv else scores[-1]
                    latest_scores[dataset + "-" + metric] = value

            current_score = latest_scores[self._observation_key]
            self._trial.report(current_score, step=epoch)
            if self._trial.should_prune():
                message = "Trial was pruned at iteration {}.".format(epoch)
                raise optuna.TrialPruned(message)
            # Returning False tells XGBoost that training should continue.
            return False

elif _imports.is_successful():

def _get_callback_context(env: "xgb.core.CallbackEnv") -> str: # type: ignore
"""Return whether the current callback context is cv or train.

.. note::
`Reference
<https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/callback.py>`_.
"""

if env.model is None and env.cvfolds is not None:
context = "cv"
else:
context = "train"
return context

class XGBoostPruningCallback: # type: ignore
__doc__ = _doc

def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
self._trial = trial
self._observation_key = observation_key

def __call__(self, env: "xgb.core.CallbackEnv") -> None: # type: ignore
context = _get_callback_context(env)
evaluation_result_list = env.evaluation_result_list
if context == "cv":
# Remove a third element: the stddev of the metric across the
# cross-validation folds.
evaluation_result_list = [
(key, metric) for key, metric, _ in evaluation_result_list
]
current_score = dict(evaluation_result_list)[self._observation_key]
self._trial.report(current_score, step=env.iteration)
if self._trial.should_prune():
message = "Trial was pruned at iteration {}.".format(env.iteration)
raise optuna.TrialPruned(message)

else:

class XGBoostPruningCallback: # type: ignore
__doc__ = _doc

def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
_imports.check()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ all = [
"skorch",
"tensorboard",
"tensorflow",
"xgboost",
]

[tool.setuptools.packages.find]
Expand Down
79 changes: 79 additions & 0 deletions tests/test_xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import pytest

import optuna
from optuna_integration._imports import try_import
from optuna_integration.xgboost import XGBoostPruningCallback
from optuna.testing.pruners import DeterministicPruner
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import numpy as np
import pytest
import optuna
from optuna_integration._imports import try_import
from optuna_integration.xgboost import XGBoostPruningCallback
from optuna.testing.pruners import DeterministicPruner
import numpy as np
import optuna
from optuna.testing.pruners import DeterministicPruner
import pytest
from optuna_integration._imports import try_import
from optuna_integration.xgboost import XGBoostPruningCallback

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... and this is done!



with try_import():
import xgboost as xgb

pytestmark = pytest.mark.integration


def test_xgboost_pruning_callback_call() -> None:
    """Drive ``after_iteration`` directly, without running real training."""
    evals_log = {"validation": {"logloss": [1.0]}}

    # With a deactivated pruner the callback must return normally.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = study.ask()
    callback = XGBoostPruningCallback(trial, "validation-logloss")
    callback.after_iteration(model=None, epoch=1, evals_log=evals_log)

    # With an activated pruner the callback must raise ``TrialPruned``.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = study.ask()
    callback = XGBoostPruningCallback(trial, "validation-logloss")
    with pytest.raises(optuna.TrialPruned):
        callback.after_iteration(model=None, epoch=1, evals_log=evals_log)


def test_xgboost_pruning_callback() -> None:
    """End-to-end check of the callback through a tiny ``xgb.train`` run."""

    def objective(trial: optuna.trial.Trial) -> float:
        features = np.asarray([[1.0]])
        labels = [1.0]
        dtrain = xgb.DMatrix(features, label=labels)
        dtest = xgb.DMatrix(features, label=labels)

        callback = XGBoostPruningCallback(trial, "validation-logloss")
        xgb.train(
            {"objective": "binary:logistic"},
            dtrain,
            1,
            evals=[(dtest, "validation")],
            verbose_eval=False,
            callbacks=[callback],
        )
        return 1.0

    # An always-pruning pruner must mark the trial PRUNED.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.PRUNED

    # A never-pruning pruner must let the trial COMPLETE with our return value.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert study.trials[0].value == 1.0


def test_xgboost_pruning_callback_cv() -> None:
    """End-to-end check of the callback through a tiny ``xgb.cv`` run."""

    def objective(trial: optuna.trial.Trial) -> float:
        dtrain = xgb.DMatrix(np.ones((2, 1)), label=[1.0, 1.0])
        params = {
            "objective": "binary:logistic",
        }

        # FIX: use the callback imported from ``optuna_integration.xgboost``
        # (as the other tests in this module do) rather than the
        # ``optuna.integration`` re-export this package is migrating away from.
        pruning_callback = XGBoostPruningCallback(trial, "test-logloss")
        xgb.cv(params, dtrain, callbacks=[pruning_callback], nfold=2)
        return 1.0

    # An always-pruning pruner must mark the trial PRUNED.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.PRUNED

    # A never-pruning pruner must let the trial COMPLETE with our return value.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert study.trials[0].value == 1.0