Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XGBoost integration #65

Merged
merged 16 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc
* [skorch](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#skorch) ([example](https://github.com/optuna/optuna-examples/tree/main/pytorch/skorch_simple.py))
* [TensorBoard](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorboard) ([example](https://github.com/optuna/optuna-examples/tree/main/tensorboard/tensorboard_simple.py))
* [tf.keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorflow) ([example](https://github.com/optuna/optuna-examples/tree/main/tfkeras/tfkeras_integration.py))
* [XGBoost](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#xgboost) ([example](https://github.com/optuna/optuna-examples/tree/main/xgboost/xgboost_integration.py))

## Installation

Expand Down
11 changes: 10 additions & 1 deletion docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,13 @@ TensorFlow
:toctree: generated/
:nosignatures:

optuna_integration.TFKerasPruningCallback
optuna.integration.TFKerasPruningCallback
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved

XGBoost
-------

.. autosummary::
:toctree: generated/
:nosignatures:

optuna_integration.XGBoostPruningCallback
3 changes: 3 additions & 0 deletions optuna_integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"tensorboard": ["TensorBoardCallback"],
"tensorflow": ["TensorFlowPruningHook"],
"tfkeras": ["TFKerasPruningCallback"],
"xgboost": ["XGBoostPruningCallback"],
}


Expand All @@ -46,6 +47,7 @@
from optuna_integration.tensorboard import TensorBoardCallback
from optuna_integration.tensorflow import TensorFlowPruningHook
from optuna_integration.tfkeras import TFKerasPruningCallback
from optuna_integration.xgboost import XGBoostPruningCallback

Check warning on line 50 in optuna_integration/__init__.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/__init__.py#L50

Added line #L50 was not covered by tests
else:

class _IntegrationModule(ModuleType):
Expand Down Expand Up @@ -113,4 +115,5 @@
"TensorBoardCallback",
"TensorFlowPruningHook",
"TFKerasPruningCallback",
"XGBoostPruningCallback",
]
124 changes: 124 additions & 0 deletions optuna_integration/xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from typing import Any

import optuna
import optuna_integration
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from typing import Any
import optuna
import optuna_integration
from typing import Any
import optuna
import optuna_integration

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... and this change to the imports is done as well!



# Assume the modern callback-class API until the installed XGBoost proves otherwise.
use_callback_cls = True

with optuna_integration._imports.try_import() as _imports:
    import xgboost as xgb

    # Parse "major.minor[.patch...]" into comparable integers.
    xgboost_version = xgb.__version__.split(".")
    xgboost_major_version = int(xgboost_version[0])
    xgboost_minor_version = int(xgboost_version[1])
    # The ``xgb.callback.TrainingCallback`` class API was introduced in XGBoost 1.3;
    # older versions only support plain-callable callbacks.
    use_callback_cls = (xgboost_major_version, xgboost_minor_version) >= (1, 3)

_doc = """Callback for XGBoost to prune unpromising trials.

See `the example <https://github.com/optuna/optuna-examples/blob/main/
xgboost/xgboost_integration.py>`__
if you want to add a pruning callback which observes validation accuracy of
a XGBoost model.

Args:
trial:
A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
objective function.
observation_key:
An evaluation metric for pruning, e.g., ``validation-error`` and
``validation-merror``. When using the Scikit-Learn API, the index number of
``eval_set`` must be included in the ``observation_key``, e.g.,
``validation_0-error`` and ``validation_0-merror``. Please refer to ``eval_metric``
in `XGBoost reference <https://xgboost.readthedocs.io/en/latest/parameter.html>`_
for further details.
"""

if _imports.is_successful() and use_callback_cls:

    class XGBoostPruningCallback(xgb.callback.TrainingCallback):
        __doc__ = _doc

        def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
            self._trial = trial
            self._observation_key = observation_key
            self._is_cv = False

        def before_training(self, model: Any) -> Any:
            # The use of Any type is due to _PackedBooster is not yet being exposed
            # to public interface as of xgboost 1.3. A plain ``xgb.Booster`` means
            # ordinary training; anything else is the cross-validation wrapper.
            self._is_cv = not isinstance(model, xgb.Booster)
            return model

        def after_iteration(self, model: Any, epoch: int, evals_log: dict) -> bool:
            # Flatten ``{dataset: {metric: [scores...]}}`` into ``{dataset-metric: score}``.
            flattened = {}
            for dataset_name, metric_history in evals_log.items():
                for metric_name, history in metric_history.items():
                    assert isinstance(history, list), history
                    latest = history[-1]
                    if self._is_cv:
                        # ``xgb.cv`` records (mean, stddev) pairs per fold set;
                        # keep only the mean.
                        latest = latest[0]
                    flattened[dataset_name + "-" + metric_name] = latest

            current_score = flattened[self._observation_key]
            self._trial.report(current_score, step=epoch)
            if self._trial.should_prune():
                raise optuna.TrialPruned("Trial was pruned at iteration {}.".format(epoch))
            # Returning ``False`` tells XGBoost that training should continue.
            return False

elif _imports.is_successful():

Check warning on line 80 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L80

Added line #L80 was not covered by tests

def _get_callback_context(env: "xgb.core.CallbackEnv") -> str: # type: ignore

Check warning on line 82 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L82

Added line #L82 was not covered by tests
"""Return whether the current callback context is cv or train.

.. note::
`Reference
<https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/callback.py>`_.
"""

if env.model is None and env.cvfolds is not None:
context = "cv"

Check warning on line 91 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L90-L91

Added lines #L90 - L91 were not covered by tests
else:
context = "train"
return context

Check warning on line 94 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L93-L94

Added lines #L93 - L94 were not covered by tests

class XGBoostPruningCallback: # type: ignore
__doc__ = _doc

Check warning on line 97 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L96-L97

Added lines #L96 - L97 were not covered by tests

def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
self._trial = trial
self._observation_key = observation_key

Check warning on line 101 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L99-L101

Added lines #L99 - L101 were not covered by tests

def __call__(self, env: "xgb.core.CallbackEnv") -> None: # type: ignore
context = _get_callback_context(env)
evaluation_result_list = env.evaluation_result_list
if context == "cv":

Check warning on line 106 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L103-L106

Added lines #L103 - L106 were not covered by tests
# Remove a third element: the stddev of the metric across the
# cross-validation folds.
evaluation_result_list = [

Check warning on line 109 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L109

Added line #L109 was not covered by tests
(key, metric) for key, metric, _ in evaluation_result_list
]
current_score = dict(evaluation_result_list)[self._observation_key]
self._trial.report(current_score, step=env.iteration)
if self._trial.should_prune():
message = "Trial was pruned at iteration {}.".format(env.iteration)
raise optuna.TrialPruned(message)

Check warning on line 116 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L112-L116

Added lines #L112 - L116 were not covered by tests

else:

class XGBoostPruningCallback: # type: ignore
__doc__ = _doc

Check warning on line 121 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L120-L121

Added lines #L120 - L121 were not covered by tests

def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
_imports.check()

Check warning on line 124 in optuna_integration/xgboost.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/xgboost.py#L123-L124

Added lines #L123 - L124 were not covered by tests
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ all = [
"skorch",
"tensorboard",
"tensorflow",
"torch",
"xgboost",
]

Expand Down
79 changes: 79 additions & 0 deletions tests/test_xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import pytest

import optuna
from optuna_integration._imports import try_import
from optuna_integration.xgboost import XGBoostPruningCallback
from optuna.testing.pruners import DeterministicPruner
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import numpy as np
import pytest
import optuna
from optuna_integration._imports import try_import
from optuna_integration.xgboost import XGBoostPruningCallback
from optuna.testing.pruners import DeterministicPruner
import numpy as np
import optuna
from optuna.testing.pruners import DeterministicPruner
import pytest
from optuna_integration._imports import try_import
from optuna_integration.xgboost import XGBoostPruningCallback

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... and this is done!



# ``xgboost`` is an optional dependency; ``try_import`` records a failed import so
# the module still imports and the tests can be skipped instead of erroring.
with try_import():
    import xgboost as xgb

# Mark every test in this module as an integration test.
pytestmark = pytest.mark.integration


def test_xgboost_pruning_callback_call() -> None:
    """``after_iteration`` only raises when the pruner decides to prune."""
    evals_log = {"validation": {"logloss": [1.0]}}

    # With a pruner that never prunes, the callback must return without raising.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    callback = XGBoostPruningCallback(study.ask(), "validation-logloss")
    callback.after_iteration(model=None, epoch=1, evals_log=evals_log)

    # With a pruner that always prunes, the callback must raise TrialPruned.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    callback = XGBoostPruningCallback(study.ask(), "validation-logloss")
    with pytest.raises(optuna.TrialPruned):
        callback.after_iteration(model=None, epoch=1, evals_log=evals_log)


def test_xgboost_pruning_callback() -> None:
    """End-to-end pruning through ``xgb.train`` on a one-sample dataset."""

    def objective(trial: optuna.trial.Trial) -> float:
        dtrain = xgb.DMatrix(np.asarray([[1.0]]), label=[1.0])
        dtest = xgb.DMatrix(np.asarray([[1.0]]), label=[1.0])

        callback = XGBoostPruningCallback(trial, "validation-logloss")
        xgb.train(
            {"objective": "binary:logistic"},
            dtrain,
            1,
            evals=[(dtest, "validation")],
            verbose_eval=False,
            callbacks=[callback],
        )
        return 1.0

    # An always-prune pruner terminates the single trial as PRUNED.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.PRUNED

    # A never-prune pruner lets the trial complete and return its value.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert study.trials[0].value == 1.0


def test_xgboost_pruning_callback_cv() -> None:
    """End-to-end pruning through ``xgb.cv`` with two folds."""

    def objective(trial: optuna.trial.Trial) -> float:
        dtrain = xgb.DMatrix(np.ones((2, 1)), label=[1.0, 1.0])

        callback = XGBoostPruningCallback(trial, "test-logloss")
        xgb.cv(
            {"objective": "binary:logistic"},
            dtrain,
            callbacks=[callback],
            nfold=2,
        )
        return 1.0

    # An always-prune pruner terminates the single trial as PRUNED.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.PRUNED

    # A never-prune pruner lets the trial complete and return its value.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert study.trials[0].value == 1.0