Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XGBoost integration #65

Merged
merged 16 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc
* [skorch](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#skorch) ([example](https://github.com/optuna/optuna-examples/tree/main/pytorch/skorch_simple.py))
* [TensorBoard](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorboard) ([example](https://github.com/optuna/optuna-examples/tree/main/tensorboard/tensorboard_simple.py))
* [tf.keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorflow) ([example](https://github.com/optuna/optuna-examples/tree/main/tfkeras/tfkeras_integration.py))
* [XGBoost](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#xgboost) ([example](https://github.com/optuna/optuna-examples/tree/main/xgboost/xgboost_integration.py))

## Installation

Expand Down
11 changes: 10 additions & 1 deletion docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,13 @@ TensorFlow
:toctree: generated/
:nosignatures:

optuna_integration.TFKerasPruningCallback
optuna.integration.TFKerasPruningCallback
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved

XGBoost
-------

.. autosummary::
:toctree: generated/
:nosignatures:

optuna_integration.XGBoostPruningCallback
3 changes: 3 additions & 0 deletions optuna_integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"tensorboard": ["TensorBoardCallback"],
"tensorflow": ["TensorFlowPruningHook"],
"tfkeras": ["TFKerasPruningCallback"],
"xgboost": ["XGBoostPruningCallback"],
}


Expand All @@ -46,6 +47,7 @@
from optuna_integration.tensorboard import TensorBoardCallback
from optuna_integration.tensorflow import TensorFlowPruningHook
from optuna_integration.tfkeras import TFKerasPruningCallback
from optuna_integration.xgboost import XGBoostPruningCallback

Check warning on line 50 in optuna_integration/__init__.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/__init__.py#L50

Added line #L50 was not covered by tests
else:

class _IntegrationModule(ModuleType):
Expand Down Expand Up @@ -113,4 +115,5 @@
"TensorBoardCallback",
"TensorFlowPruningHook",
"TFKerasPruningCallback",
"XGBoostPruningCallback",
]
125 changes: 125 additions & 0 deletions optuna_integration/xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from typing import Any

import optuna

import optuna_integration


# Assume the callback-class API is available until the installed xgboost
# version proves otherwise (it stays True when xgboost is not installed, in
# which case the fallback class below only raises on instantiation).
use_callback_cls = True

with optuna_integration._imports.try_import() as _imports:
    import xgboost as xgb

    xgboost_version = xgb.__version__.split(".")
    xgboost_major_version = int(xgboost_version[0])
    xgboost_minor_version = int(xgboost_version[1])
    # xgboost introduced the `TrainingCallback` class API in 1.3; older
    # releases use the legacy function-style callback.
    use_callback_cls = (xgboost_major_version, xgboost_minor_version) >= (1, 3)

_doc = """Callback for XGBoost to prune unpromising trials.

See `the example <https://github.com/optuna/optuna-examples/blob/main/
xgboost/xgboost_integration.py>`__
if you want to add a pruning callback which observes validation accuracy of
a XGBoost model.

Args:
trial:
A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
objective function.
observation_key:
An evaluation metric for pruning, e.g., ``validation-error`` and
``validation-merror``. When using the Scikit-Learn API, the index number of
``eval_set`` must be included in the ``observation_key``, e.g.,
``validation_0-error`` and ``validation_0-merror``. Please refer to ``eval_metric``
in `XGBoost reference <https://xgboost.readthedocs.io/en/latest/parameter.html>`_
for further details.
"""

if _imports.is_successful() and use_callback_cls:

    class XGBoostPruningCallback(xgb.callback.TrainingCallback):
        __doc__ = _doc

        def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
            self._trial = trial
            self._observation_key = observation_key
            self._is_cv = False

        def before_training(self, model: Any) -> Any:
            # `model` is typed as `Any` because `_PackedBooster` (used by
            # `xgb.cv`) is not part of xgboost's public API as of 1.3.
            # Anything that is not a plain `Booster` is treated as cv mode.
            self._is_cv = not isinstance(model, xgb.Booster)
            return model

        def after_iteration(self, model: Any, epoch: int, evals_log: dict) -> bool:
            # Flatten the nested history into a `{"dataset-metric": score}`
            # map; in cv mode each entry is a (mean, stddev) pair, and only
            # the mean across folds is kept.
            flattened = {}
            for dataset, metrics in evals_log.items():
                for metric, scores in metrics.items():
                    assert isinstance(scores, list), scores
                    latest = scores[-1]
                    flattened[dataset + "-" + metric] = latest[0] if self._is_cv else latest

            self._trial.report(flattened[self._observation_key], step=epoch)
            if self._trial.should_prune():
                message = "Trial was pruned at iteration {}.".format(epoch)
                raise optuna.TrialPruned(message)
            # Returning False tells xgboost not to stop training.
            return False

elif _imports.is_successful():

Check warning on line 81 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L81

Added line #L81 was not covered by tests

def _get_callback_context(env: "xgb.core.CallbackEnv") -> str: # type: ignore

Check warning on line 83 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L83

Added line #L83 was not covered by tests
"""Return whether the current callback context is cv or train.

.. note::
`Reference
<https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/callback.py>`_.
"""

if env.model is None and env.cvfolds is not None:
context = "cv"

Check warning on line 92 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L91-L92

Added lines #L91 - L92 were not covered by tests
else:
context = "train"
return context

Check warning on line 95 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L94-L95

Added lines #L94 - L95 were not covered by tests

class XGBoostPruningCallback: # type: ignore
__doc__ = _doc

Check warning on line 98 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L97-L98

Added lines #L97 - L98 were not covered by tests

def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
self._trial = trial
self._observation_key = observation_key

Check warning on line 102 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L100-L102

Added lines #L100 - L102 were not covered by tests

def __call__(self, env: "xgb.core.CallbackEnv") -> None: # type: ignore
context = _get_callback_context(env)
evaluation_result_list = env.evaluation_result_list
if context == "cv":

Check warning on line 107 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L104-L107

Added lines #L104 - L107 were not covered by tests
# Remove a third element: the stddev of the metric across the
# cross-validation folds.
evaluation_result_list = [

Check warning on line 110 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L110

Added line #L110 was not covered by tests
(key, metric) for key, metric, _ in evaluation_result_list
]
current_score = dict(evaluation_result_list)[self._observation_key]
self._trial.report(current_score, step=env.iteration)
if self._trial.should_prune():
message = "Trial was pruned at iteration {}.".format(env.iteration)
raise optuna.TrialPruned(message)

Check warning on line 117 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L113-L117

Added lines #L113 - L117 were not covered by tests

else:

class XGBoostPruningCallback: # type: ignore
__doc__ = _doc

Check warning on line 122 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L121-L122

Added lines #L121 - L122 were not covered by tests

def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
_imports.check()

Check warning on line 125 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L124-L125

Added lines #L124 - L125 were not covered by tests
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ all = [
"skorch",
"tensorboard",
"tensorflow",
"xgboost",
"torch",
]

Expand Down
79 changes: 79 additions & 0 deletions tests/test_xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import optuna
from optuna.testing.pruners import DeterministicPruner
import pytest

from optuna_integration._imports import try_import
from optuna_integration.xgboost import XGBoostPruningCallback


# Import xgboost lazily so the test module can still be collected when the
# optional dependency is absent.
with try_import():
    import xgboost as xgb

# Mark every test in this module as an integration test.
pytestmark = pytest.mark.integration


def test_xgboost_pruning_callback_call() -> None:
    """Calling ``after_iteration`` directly prunes iff the pruner says so."""
    evals_log = {"validation": {"logloss": [1.0]}}

    # With a pruner that never prunes, the callback must return normally.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    callback = XGBoostPruningCallback(study.ask(), "validation-logloss")
    callback.after_iteration(model=None, epoch=1, evals_log=evals_log)

    # With a pruner that always prunes, the callback must raise TrialPruned.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    callback = XGBoostPruningCallback(study.ask(), "validation-logloss")
    with pytest.raises(optuna.TrialPruned):
        callback.after_iteration(model=None, epoch=1, evals_log=evals_log)


def test_xgboost_pruning_callback() -> None:
    """``xgb.train`` with the callback prunes exactly when the pruner asks."""

    def objective(trial: optuna.trial.Trial) -> float:
        features = np.asarray([[1.0]])
        dtrain = xgb.DMatrix(features, label=[1.0])
        dtest = xgb.DMatrix(features, label=[1.0])

        callback = XGBoostPruningCallback(trial, "validation-logloss")
        xgb.train(
            {"objective": "binary:logistic"},
            dtrain,
            1,
            evals=[(dtest, "validation")],
            verbose_eval=False,
            callbacks=[callback],
        )
        return 1.0

    # Always-prune pruner: the single trial ends up pruned.
    pruned = optuna.create_study(pruner=DeterministicPruner(True))
    pruned.optimize(objective, n_trials=1)
    assert pruned.trials[0].state == optuna.trial.TrialState.PRUNED

    # Never-prune pruner: the trial completes with the objective's value.
    completed = optuna.create_study(pruner=DeterministicPruner(False))
    completed.optimize(objective, n_trials=1)
    assert completed.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert completed.trials[0].value == 1.0


def test_xgboost_pruning_callback_cv() -> None:
    """``xgb.cv`` with the callback prunes exactly when the pruner asks."""

    def objective(trial: optuna.trial.Trial) -> float:
        dtrain = xgb.DMatrix(np.ones((2, 1)), label=[1.0, 1.0])
        callback = XGBoostPruningCallback(trial, "test-logloss")
        xgb.cv(
            {
                "objective": "binary:logistic",
            },
            dtrain,
            callbacks=[callback],
            nfold=2,
        )
        return 1.0

    # Always-prune pruner: the single trial ends up pruned.
    pruned = optuna.create_study(pruner=DeterministicPruner(True))
    pruned.optimize(objective, n_trials=1)
    assert pruned.trials[0].state == optuna.trial.TrialState.PRUNED

    # Never-prune pruner: the trial completes with the objective's value.
    completed = optuna.create_study(pruner=DeterministicPruner(False))
    completed.optimize(objective, n_trials=1)
    assert completed.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert completed.trials[0].value == 1.0