Add CatBoost integration #61

Merged · 8 commits merged on Jan 29, 2024
Changes from 4 commits
1 change: 1 addition & 0 deletions README.md
@@ -19,6 +19,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc

* [AllenNLP](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#allennlp) ([example](https://github.com/optuna/optuna-examples/tree/main/allennlp))
* [Catalyst](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#catalyst) ([example](https://github.com/optuna/optuna-examples/blob/main/pytorch/catalyst_simple.py))
* [CatBoost](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#catboost) ([example](https://github.com/optuna/optuna-examples/blob/main/catboost/catboost_pruning.py))
* [Chainer](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainer) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainer_integration.py))
* [ChainerMN](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainermn) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainermn_simple.py))
* FastAI ([V1](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv1) ([example](https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv1_simple.py)), [V2](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv2) ([example](https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv2_simple.py)))
14 changes: 11 additions & 3 deletions docs/source/reference/index.rst
@@ -4,8 +4,8 @@ API Reference for Optuna-Integration

The Optuna-Integration package contains classes used to integrate Optuna with external machine learning frameworks.

All of these classes can be imported in two ways. One is "`from optuna.integration import xxx`" like a module in Optuna,
and the other is "`from optuna_integration import xxx`" as an Optuna-Integration specific module.
The former is provided for backward compatibility.
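
For example, the class this PR adds can be imported either way (a minimal sketch; both statements resolve to the same class, as the sentence above states):

```python
# Optuna-Integration-specific import path:
from optuna_integration import CatBoostPruningCallback

# Equivalent import kept for backward compatibility:
# from optuna.integration import CatBoostPruningCallback
```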

For most of the ML frameworks supported by Optuna, the corresponding Optuna integration class serves only to implement a callback object and functions, compliant with the framework's specific callback API, to be called with each intermediate step in the model training. The functionality implemented in these callbacks across the different ML frameworks includes:
@@ -36,6 +36,15 @@ Catalyst

optuna.integration.CatalystPruningCallback

CatBoost
--------

.. autosummary::
:toctree: generated/
:nosignatures:

optuna.integration.CatBoostPruningCallback

Chainer
-------

@@ -110,4 +119,3 @@ TensorFlow
:nosignatures:

optuna.integration.TFKerasPruningCallback

112 changes: 112 additions & 0 deletions optuna_integration/catboost.py
@@ -0,0 +1,112 @@
from typing import Any
from typing import Optional

import optuna
from optuna._experimental import experimental_class
from packaging import version

from optuna_integration._imports import try_import


with try_import() as _imports:
    import catboost as cb

    if version.parse(cb.__version__) < version.parse("0.26"):
        raise ImportError(f"You don't have CatBoost>=0.26! CatBoost version: {cb.__version__}")


@experimental_class("3.0.0")
class CatBoostPruningCallback:
    """Callback for CatBoost to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    catboost/catboost_pruning.py>`__
    if you want to add a pruning callback which observes validation accuracy of
    a CatBoost model.

    .. note::
        :class:`optuna.TrialPruned` cannot be raised in
        :meth:`~optuna.integration.CatBoostPruningCallback.after_iteration`,
        which CatBoost calls internally. Unlike other pruning callbacks,
        you must call
        :meth:`~optuna.integration.CatBoostPruningCallback.check_pruned`
        manually after training to raise :class:`optuna.TrialPruned`.

    .. note::
        This callback cannot be used with CatBoost on GPUs because CatBoost doesn't support
        user-defined callbacks for GPU training.
        Please refer to `CatBoost issue <https://github.com/catboost/catboost/issues/1792>`_.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        metric:
            An evaluation metric for pruning, e.g., ``Logloss`` or ``AUC``.
            Please refer to the
            `CatBoost reference
            <https://catboost.ai/docs/references/eval-metric__supported-metrics.html>`_
            for further details.
        eval_set_index:
            The index of the target validation dataset.
            If you set only one ``eval_set``, ``eval_set_index`` may be left as :obj:`None`.
            If you set multiple datasets as ``eval_set``, pass the index of the target
            dataset in ``eval_set``, e.g., ``0`` or ``1`` when ``eval_set`` contains
            two datasets.
    """

    def __init__(
        self, trial: optuna.trial.Trial, metric: str, eval_set_index: Optional[int] = None
    ) -> None:
        default_valid_name = "validation"
        self._trial = trial
        self._metric = metric
        if eval_set_index is None:
            self._valid_name = default_valid_name
        else:
            self._valid_name = default_valid_name + "_" + str(eval_set_index)
        self._pruned = False
        self._message = ""

    def after_iteration(self, info: Any) -> bool:
        """Report an evaluation metric value for Optuna pruning after each CatBoost iteration.

        This method is called by CatBoost.

        Args:
            info:
                A ``SimpleNamespace`` containing ``iteration``, ``validation_name``,
                ``metric_name`` and the history of losses.
                For example, ``SimpleNamespace(iteration=2, metrics={
                'learn': {'Logloss': [0.6, 0.5]},
                'validation': {'Logloss': [0.7, 0.6], 'AUC': [0.8, 0.9]}
                })``.

        Returns:
            A boolean value. If :obj:`False`, CatBoost internally stops the optimization
            with Optuna's pruning logic without raising :class:`optuna.TrialPruned`.
            Otherwise, the optimization continues.
        """
        step = info.iteration - 1
        if self._valid_name not in info.metrics:
            raise ValueError(
                'The entry associated with the validation name "{}" '
                "is not found in the evaluation result list {}.".format(self._valid_name, info)
            )
        metrics = info.metrics[self._valid_name]
        if self._metric not in metrics:
            raise ValueError(
                'The entry associated with the metric name "{}" '
                "is not found in the evaluation result list {}.".format(self._metric, info)
            )
        current_score = metrics[self._metric][-1]
        self._trial.report(current_score, step=step)
        if self._trial.should_prune():
            self._message = "Trial was pruned at iteration {}.".format(step)
            self._pruned = True
            return False
        return True

    def check_pruned(self) -> None:
        """Raise :class:`optuna.TrialPruned` manually if the CatBoost optimization is pruned."""
        if self._pruned:
            raise optuna.TrialPruned(self._message)
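
For reference, a minimal usage sketch distilled from the tests in this PR (the toy data, `iterations=10`, and the `get_best_score` lookup are illustrative assumptions; the parts the PR prescribes are passing the callback via `callbacks=` and calling `check_pruned()` manually after training):

```python
import catboost as cb
import numpy as np
import optuna

from optuna_integration.catboost import CatBoostPruningCallback


def objective(trial: optuna.trial.Trial) -> float:
    # Toy data; substitute a real train/validation split.
    train_x = np.asarray([[1.0], [2.0]])
    train_y = np.asarray([[1.0], [0.0]])

    pruning_callback = CatBoostPruningCallback(trial, "Logloss")
    gbm = cb.CatBoostClassifier(objective="Logloss", eval_metric="Logloss", iterations=10)
    gbm.fit(
        train_x,
        train_y,
        eval_set=[(train_x, train_y)],
        verbose=0,
        callbacks=[pruning_callback],  # reports "Logloss" to the trial each iteration
    )

    # CatBoost cannot raise from inside after_iteration, so raise TrialPruned here.
    pruning_callback.check_pruned()

    # With a single eval_set, CatBoost names the validation split "validation".
    return gbm.get_best_score()["validation"]["Logloss"]


study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=5)
```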
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -53,6 +53,8 @@ document = [
]
all = [
"catalyst",
"catboost>=0.26; sys_platform!='darwin'",
"catboost>=0.26,<1.2; sys_platform=='darwin'",
"fastai",
"mxnet",
"shap",
135 changes: 135 additions & 0 deletions tests/test_catboost.py
@@ -0,0 +1,135 @@
import types

import numpy as np
import optuna
from optuna.testing.pruners import DeterministicPruner
import pytest

from optuna_integration._imports import try_import
from optuna_integration.catboost import CatBoostPruningCallback


with try_import():
    import catboost as cb

pytestmark = pytest.mark.integration


def test_catboost_pruning_callback_call() -> None:
    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = study.ask()
    pruning_callback = CatBoostPruningCallback(trial, "Logloss")
    info = types.SimpleNamespace(
        iteration=1, metrics={"learn": {"Logloss": [1.0]}, "validation": {"Logloss": [1.0]}}
    )
    assert pruning_callback.after_iteration(info)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = study.ask()
    pruning_callback = CatBoostPruningCallback(trial, "Logloss")
    info = types.SimpleNamespace(
        iteration=1, metrics={"learn": {"Logloss": [1.0]}, "validation": {"Logloss": [1.0]}}
    )
    assert not pruning_callback.after_iteration(info)


METRICS = ["AUC", "Accuracy"]
EVAL_SET_INDEXES = [None, 0, 1]


@pytest.mark.parametrize("metric", METRICS)
@pytest.mark.parametrize("eval_set_index", EVAL_SET_INDEXES)
def test_catboost_pruning_callback_init_param(metric: str, eval_set_index: int) -> None:
def objective(trial: optuna.trial.Trial) -> float:
train_x = np.asarray([[1.0], [2.0]])
train_y = np.asarray([[1.0], [0.0]])
valid_x = np.asarray([[1.0], [2.0]])
valid_y = np.asarray([[1.0], [0.0]])

if eval_set_index is None:
eval_set = [(valid_x, valid_y)]
pruning_callback = CatBoostPruningCallback(trial, metric)
else:
eval_set = [(valid_x, valid_y), (valid_x, valid_y)]
pruning_callback = CatBoostPruningCallback(trial, metric, eval_set_index)

param = {
"objective": "Logloss",
"eval_metric": metric,
}

gbm = cb.CatBoostClassifier(**param)
gbm.fit(
train_x,
train_y,
eval_set=eval_set,
verbose=0,
callbacks=[pruning_callback],
)

# Invoke pruning manually.
pruning_callback.check_pruned()

return 1.0

study = optuna.create_study(pruner=DeterministicPruner(True))
study.optimize(objective, n_trials=1)
assert study.trials[0].state == optuna.trial.TrialState.PRUNED

study = optuna.create_study(pruner=DeterministicPruner(False))
study.optimize(objective, n_trials=1)
assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
assert study.trials[0].value == 1.0


# TODO(Hemmi): Remove the skip decorator after CatBoost's error handling is fixed.
# See https://github.com/optuna/optuna/pull/4190 for more details.
@pytest.mark.skip(reason="Temporarily skipped due to unknown CatBoost error.")
@pytest.mark.parametrize(
    "metric, eval_set_index",
    [
        ("foo_metric", None),
        ("AUC", 100),
    ],
)
def test_catboost_pruning_callback_errors(metric: str, eval_set_index: int) -> None:
    # This test aims to cover the ValueError blocks in CatBoostPruningCallback.after_iteration().
    # However, CatBoost currently terminates with a SystemError when python>=3.9 or
    # pytest>=7.2.0, and otherwise terminates with a RecursionError. This is because
    # after_iteration() is called from a Cython function in the catboost library, which
    # causes the unexpected error behavior. The difference in error type is mainly because
    # the _Py_CheckRecursionLimit variable used in the limited C API was removed after
    # Python 3.9.

    def objective(trial: optuna.trial.Trial) -> float:
        train_x = np.asarray([[1.0], [2.0]])
        train_y = np.asarray([[1.0], [0.0]])
        valid_x = np.asarray([[1.0], [2.0]])
        valid_y = np.asarray([[1.0], [0.0]])

        pruning_callback = CatBoostPruningCallback(trial, metric, eval_set_index)
        param = {
            "objective": "Logloss",
            "eval_metric": "AUC",
        }

        gbm = cb.CatBoostClassifier(**param)
        gbm.fit(
            train_x,
            train_y,
            eval_set=[(valid_x, valid_y)],
            verbose=0,
            callbacks=[pruning_callback],
        )

        # Invoke pruning manually.
        pruning_callback.check_pruned()

        return 1.0

    # Unknown validation name or metric.
    study = optuna.create_study(pruner=DeterministicPruner(False))

    with pytest.raises(ValueError):
        study.optimize(objective, n_trials=1)
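
Since the end-to-end error test is skipped for now, the `ValueError` branches in `after_iteration` can still be exercised directly with a hand-built `SimpleNamespace`, bypassing CatBoost entirely (a sketch, not part of this PR):

```python
import types

import optuna
import pytest

from optuna_integration.catboost import CatBoostPruningCallback


def test_after_iteration_unknown_metric_direct() -> None:
    trial = optuna.create_study().ask()
    callback = CatBoostPruningCallback(trial, "foo_metric")
    info = types.SimpleNamespace(
        iteration=1, metrics={"learn": {"Logloss": [1.0]}, "validation": {"Logloss": [1.0]}}
    )
    # "foo_metric" is absent from the validation metrics, so the callback raises.
    with pytest.raises(ValueError):
        callback.after_iteration(info)
```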