Merge pull request #65 from buruzaemon/feature/add-xgboost-integration
Add XGBoost integration
nabenabe0928 committed Feb 8, 2024
2 parents 661aa9c + 3470d9b commit dec6d6c
Showing 6 changed files with 219 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -32,6 +32,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc
* [skorch](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#skorch) ([example](https://github.com/optuna/optuna-examples/tree/main/pytorch/skorch_simple.py))
* [TensorBoard](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorboard) ([example](https://github.com/optuna/optuna-examples/tree/main/tensorboard/tensorboard_simple.py))
* [tf.keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorflow) ([example](https://github.com/optuna/optuna-examples/tree/main/tfkeras/tfkeras_integration.py))
* [XGBoost](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#xgboost) ([example](https://github.com/optuna/optuna-examples/tree/main/xgboost/xgboost_integration.py))

## Installation

Expand Down
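The README entry added above links to a full example in optuna-examples. As a quick orientation, here is a minimal end-to-end sketch of the new integration; the toy data, hyperparameter ranges, and pruner choice below are illustrative only and are not taken from this PR:

```python
import numpy as np
import optuna
import xgboost as xgb

from optuna_integration import XGBoostPruningCallback


def objective(trial: optuna.trial.Trial) -> float:
    # Toy data for illustration; a real study would use a meaningful dataset.
    rng = np.random.RandomState(seed=0)
    X = rng.rand(100, 4)
    y = (X[:, 0] > 0.5).astype(int)
    dtrain = xgb.DMatrix(X[:80], label=y[:80])
    dvalid = xgb.DMatrix(X[80:], label=y[80:])

    params = {
        "objective": "binary:logistic",
        "eval_metric": "logloss",
        "max_depth": trial.suggest_int("max_depth", 2, 6),
        "eta": trial.suggest_float("eta", 1e-3, 0.3, log=True),
    }
    # Report "validation-logloss" each boosting round and prune bad trials early.
    pruning_callback = XGBoostPruningCallback(trial, "validation-logloss")
    bst = xgb.train(
        params,
        dtrain,
        num_boost_round=100,
        evals=[(dvalid, "validation")],
        callbacks=[pruning_callback],
        verbose_eval=False,
    )
    preds = bst.predict(dvalid)
    return float(np.mean((preds > 0.5) == y[80:]))


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
```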
11 changes: 10 additions & 1 deletion docs/source/reference/index.rst
@@ -152,4 +152,13 @@ TensorFlow
   :toctree: generated/
   :nosignatures:

   optuna_integration.TFKerasPruningCallback
   optuna.integration.TFKerasPruningCallback

XGBoost
-------

.. autosummary::
   :toctree: generated/
   :nosignatures:

   optuna_integration.XGBoostPruningCallback
3 changes: 3 additions & 0 deletions optuna_integration/__init__.py
@@ -23,6 +23,7 @@
"tensorboard": ["TensorBoardCallback"],
"tensorflow": ["TensorFlowPruningHook"],
"tfkeras": ["TFKerasPruningCallback"],
"xgboost": ["XGBoostPruningCallback"],
}


@@ -46,6 +47,7 @@
    from optuna_integration.tensorboard import TensorBoardCallback
    from optuna_integration.tensorflow import TensorFlowPruningHook
    from optuna_integration.tfkeras import TFKerasPruningCallback
    from optuna_integration.xgboost import XGBoostPruningCallback
else:

    class _IntegrationModule(ModuleType):
@@ -113,4 +115,5 @@ def _get_module(self, module_name: str) -> ModuleType:
"TensorBoardCallback",
"TensorFlowPruningHook",
"TFKerasPruningCallback",
"XGBoostPruningCallback",
]
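The `_IntegrationModule` class this hunk touches implements lazy loading, so `import optuna_integration` stays cheap until an exported attribute is first accessed. Its body falls outside this diff, so the following is only a sketch of the common shape of such a loader; the names and details here are assumed, not taken from this PR:

```python
# A minimal sketch of a lazy-loading module in the spirit of _IntegrationModule.
import importlib
from types import ModuleType


class _LazyModule(ModuleType):
    # Maps submodule name -> exported attribute names, mirroring the dict above.
    _import_structure = {"xgboost": ["XGBoostPruningCallback"]}

    def __getattr__(self, name: str):
        # Import the owning submodule on first attribute access only.
        for module_name, attrs in self._import_structure.items():
            if name in attrs:
                module = importlib.import_module(f"{self.__name__}.{module_name}")
                return getattr(module, name)
        raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
```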
125 changes: 125 additions & 0 deletions optuna_integration/xgboost.py
@@ -0,0 +1,125 @@
from typing import Any

import optuna

import optuna_integration


use_callback_cls = True

with optuna_integration._imports.try_import() as _imports:
    import xgboost as xgb

    xgboost_version = xgb.__version__.split(".")
    xgboost_major_version = int(xgboost_version[0])
    xgboost_minor_version = int(xgboost_version[1])
    use_callback_cls = (
        xgboost_major_version >= 1 and xgboost_minor_version >= 3
    ) or xgboost_major_version >= 2

_doc = """Callback for XGBoost to prune unpromising trials.
See `the example <https://github.com/optuna/optuna-examples/blob/main/
xgboost/xgboost_integration.py>`__
if you want to add a pruning callback which observes validation accuracy of
a XGBoost model.
Args:
trial:
A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
objective function.
observation_key:
An evaluation metric for pruning, e.g., ``validation-error`` and
``validation-merror``. When using the Scikit-Learn API, the index number of
``eval_set`` must be included in the ``observation_key``, e.g.,
``validation_0-error`` and ``validation_0-merror``. Please refer to ``eval_metric``
in `XGBoost reference <https://xgboost.readthedocs.io/en/latest/parameter.html>`_
for further details.
"""

if _imports.is_successful() and use_callback_cls:

    class XGBoostPruningCallback(xgb.callback.TrainingCallback):
        __doc__ = _doc

        def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
            self._trial = trial
            self._observation_key = observation_key
            self._is_cv = False

        def before_training(self, model: Any) -> Any:
            # The Any type is used because _PackedBooster is not yet exposed
            # in the public interface as of xgboost 1.3.
            if isinstance(model, xgb.Booster):
                self._is_cv = False
            else:
                self._is_cv = True
            return model

        def after_iteration(self, model: Any, epoch: int, evals_log: dict) -> bool:
            evaluation_results = {}
            # Flatten the evaluation history to a `{dataset-metric: score}` layout.
            for dataset, metrics in evals_log.items():
                for metric, scores in metrics.items():
                    assert isinstance(scores, list), scores
                    key = dataset + "-" + metric
                    if self._is_cv:
                        # Remove the stddev of the metric across the
                        # cross-validation folds.
                        evaluation_results[key] = scores[-1][0]
                    else:
                        evaluation_results[key] = scores[-1]

            current_score = evaluation_results[self._observation_key]
            self._trial.report(current_score, step=epoch)
            if self._trial.should_prune():
                message = "Trial was pruned at iteration {}.".format(epoch)
                raise optuna.TrialPruned(message)
            # The training should not stop.
            return False

elif _imports.is_successful():

    def _get_callback_context(env: "xgb.core.CallbackEnv") -> str:  # type: ignore
        """Return whether the current callback context is cv or train.

        .. note::
            `Reference
            <https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/callback.py>`_.
        """
        if env.model is None and env.cvfolds is not None:
            context = "cv"
        else:
            context = "train"
        return context

    class XGBoostPruningCallback:  # type: ignore
        __doc__ = _doc

        def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
            self._trial = trial
            self._observation_key = observation_key

        def __call__(self, env: "xgb.core.CallbackEnv") -> None:  # type: ignore
            context = _get_callback_context(env)
            evaluation_result_list = env.evaluation_result_list
            if context == "cv":
                # Remove the third element: the stddev of the metric across the
                # cross-validation folds.
                evaluation_result_list = [
                    (key, metric) for key, metric, _ in evaluation_result_list
                ]
            current_score = dict(evaluation_result_list)[self._observation_key]
            self._trial.report(current_score, step=env.iteration)
            if self._trial.should_prune():
                message = "Trial was pruned at iteration {}.".format(env.iteration)
                raise optuna.TrialPruned(message)

else:

    class XGBoostPruningCallback:  # type: ignore
        __doc__ = _doc

        def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
            _imports.check()
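To make the `observation_key` convention concrete: with the class-based callback above, `xgb.train` fills `evals_log` as `{dataset: {metric: [scores]}}`, and `after_iteration` flattens it to `dataset-metric` keys before reporting to the trial. A small self-contained sketch of that flattening, with toy scores rather than real training output:

```python
# Illustration of the flattening done in after_iteration (train mode).
# In cv mode, each score is a (mean, stddev) tuple and only the mean is kept.
evals_log = {"validation": {"logloss": [0.69, 0.52, 0.41]}}

evaluation_results = {}
for dataset, metrics in evals_log.items():
    for metric, scores in metrics.items():
        evaluation_results[dataset + "-" + metric] = scores[-1]

assert evaluation_results == {"validation-logloss": 0.41}
# So observation_key="validation-logloss" selects the latest validation score.
```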
1 change: 1 addition & 0 deletions pyproject.toml
@@ -70,6 +70,7 @@ all = [
"skorch",
"tensorboard",
"tensorflow",
"xgboost",
"torch",
]

79 changes: 79 additions & 0 deletions tests/test_xgboost.py
@@ -0,0 +1,79 @@
import numpy as np
import optuna
from optuna.testing.pruners import DeterministicPruner
import pytest

from optuna_integration._imports import try_import
from optuna_integration.xgboost import XGBoostPruningCallback


with try_import():
    import xgboost as xgb

pytestmark = pytest.mark.integration


def test_xgboost_pruning_callback_call() -> None:
    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = study.ask()
    pruning_callback = XGBoostPruningCallback(trial, "validation-logloss")
    pruning_callback.after_iteration(
        model=None, epoch=1, evals_log={"validation": {"logloss": [1.0]}}
    )

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = study.ask()
    pruning_callback = XGBoostPruningCallback(trial, "validation-logloss")
    with pytest.raises(optuna.TrialPruned):
        pruning_callback.after_iteration(
            model=None, epoch=1, evals_log={"validation": {"logloss": [1.0]}}
        )


def test_xgboost_pruning_callback() -> None:
    def objective(trial: optuna.trial.Trial) -> float:
        dtrain = xgb.DMatrix(np.asarray([[1.0]]), label=[1.0])
        dtest = xgb.DMatrix(np.asarray([[1.0]]), label=[1.0])

        pruning_callback = XGBoostPruningCallback(trial, "validation-logloss")
        xgb.train(
            {"objective": "binary:logistic"},
            dtrain,
            1,
            evals=[(dtest, "validation")],
            verbose_eval=False,
            callbacks=[pruning_callback],
        )
        return 1.0

    study = optuna.create_study(pruner=DeterministicPruner(True))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.PRUNED

    study = optuna.create_study(pruner=DeterministicPruner(False))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert study.trials[0].value == 1.0


def test_xgboost_pruning_callback_cv() -> None:
    def objective(trial: optuna.trial.Trial) -> float:
        dtrain = xgb.DMatrix(np.ones((2, 1)), label=[1.0, 1.0])
        params = {
            "objective": "binary:logistic",
        }

        pruning_callback = XGBoostPruningCallback(trial, "test-logloss")
        xgb.cv(params, dtrain, callbacks=[pruning_callback], nfold=2)
        return 1.0

    study = optuna.create_study(pruner=DeterministicPruner(True))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.PRUNED

    study = optuna.create_study(pruner=DeterministicPruner(False))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert study.trials[0].value == 1.0
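These tests rely on `optuna.testing.pruners.DeterministicPruner` to force `trial.should_prune()` to a fixed answer. Its source is not part of this diff; conceptually it behaves like the following sketch, assumed from its usage above rather than copied from Optuna:

```python
# A rough sketch of a deterministic pruner: it ignores the reported values
# and always returns the boolean it was constructed with, so tests can
# exercise both the pruned and completed paths reliably.
import optuna


class _FixedPruner(optuna.pruners.BasePruner):
    def __init__(self, is_pruning: bool) -> None:
        self._is_pruning = is_pruning

    def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:
        return self._is_pruning
```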
