
Commit

Merge branch 'master' of https://github.com/optuna/optuna into update-black
Alnusjaponica committed Feb 19, 2024
2 parents a85929d + 1bcda7a commit c3f9671
Showing 23 changed files with 113 additions and 3,937 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yml
@@ -19,7 +19,7 @@ formats: all

# Optionally set the version of Python and requirements required to build your docs
python:
  # `sphinx` requires either Python >= 3.8 or `typed-ast` to reflect type comments
  # `sphinx` requires either Python >= 3.8 or `typed-ast` to reflect type comments
  # in the documentation. See: https://github.com/sphinx-doc/sphinx/pull/6984
  install:
    - method: pip
21 changes: 0 additions & 21 deletions docs/source/reference/integration.rst
@@ -18,27 +18,6 @@ For most of the ML frameworks supported by Optuna, the corresponding Optuna inte

For scikit-learn, an integrated :class:`~optuna.integration.OptunaSearchCV` estimator is available that combines scikit-learn BaseEstimator functionality with access to a class-level ``Study`` object.
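
A minimal usage sketch (illustrative only, not part of this file; the estimator and the search space below are assumed for demonstration):

.. code-block:: python

    import optuna
    from sklearn.svm import SVC

    param_distributions = {
        "C": optuna.distributions.FloatDistribution(1e-10, 1e10, log=True),
    }
    search = optuna.integration.OptunaSearchCV(SVC(gamma="auto"), param_distributions, n_trials=20)
    # search.fit(X, y)  # the underlying Study is then exposed, e.g. via the fitted study_ attribute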

LightGBM
--------

.. autosummary::
   :toctree: generated/
   :nosignatures:

   optuna.integration.LightGBMPruningCallback
   optuna.integration.lightgbm.train
   optuna.integration.lightgbm.LightGBMTuner
   optuna.integration.lightgbm.LightGBMTunerCV

MLflow
------

.. autosummary::
   :toctree: generated/
   :nosignatures:

   optuna.integration.MLflowCallback

Dependencies of each integration
--------------------------------

50 changes: 42 additions & 8 deletions optuna/_gp/gp.py
@@ -8,6 +8,8 @@

import numpy as np

from optuna.logging import get_logger


if TYPE_CHECKING:
    import scipy.optimize as so
@@ -18,6 +20,7 @@
    so = _LazyImport("scipy.optimize")
    torch = _LazyImport("torch")

logger = get_logger(__name__)

# This GP implementation uses the following notation:
# X[len(trials), len(params)]: observed parameter values.
@@ -141,21 +144,15 @@ def marginal_log_likelihood(
)


def fit_kernel_params(
def _fit_kernel_params(
    X: np.ndarray,  # [len(trials), len(params)]
    Y: np.ndarray,  # [len(trials)]
    is_categorical: np.ndarray,  # [len(params)]
    log_prior: Callable[[KernelParamsTensor], torch.Tensor],
    minimum_noise: float,
    initial_kernel_params: KernelParamsTensor | None = None,
    initial_kernel_params: KernelParamsTensor,
) -> KernelParamsTensor:
    n_params = X.shape[1]
    if initial_kernel_params is None:
        initial_kernel_params = KernelParamsTensor(
            inverse_squared_lengthscales=torch.ones(n_params, dtype=torch.float64),
            kernel_scale=torch.tensor(1.0, dtype=torch.float64),
            noise_var=torch.tensor(1.0, dtype=torch.float64),
        )

    # We apply log transform to enforce the positivity of the kernel parameters.
    # Note that we cannot just use the constraint because of the numerical instability
@@ -188,6 +185,8 @@ def loss_func(raw_params: np.ndarray) -> tuple[float, np.ndarray]:

    # jac=True means loss_func returns the gradient for gradient descent.
    res = so.minimize(loss_func, initial_raw_params, jac=True)
    if not res.success:
        raise RuntimeError(f"Optimization failed: {res.message}")

    # TODO(contramundum53): Handle the case where the optimization fails.
    raw_params_opt_tensor = torch.from_numpy(res.x)
@@ -197,3 +196,38 @@ def loss_func(raw_params: np.ndarray) -> tuple[float, np.ndarray]:
        kernel_scale=torch.exp(raw_params_opt_tensor[n_params]),
        noise_var=torch.exp(raw_params_opt_tensor[n_params + 1]) + minimum_noise,
    )


def fit_kernel_params(
    X: np.ndarray,
    Y: np.ndarray,
    is_categorical: np.ndarray,
    log_prior: Callable[[KernelParamsTensor], torch.Tensor],
    minimum_noise: float,
    initial_kernel_params: KernelParamsTensor | None = None,
) -> KernelParamsTensor:
    default_initial_kernel_params = KernelParamsTensor(
        inverse_squared_lengthscales=torch.ones(X.shape[1], dtype=torch.float64),
        kernel_scale=torch.tensor(1.0, dtype=torch.float64),
        noise_var=torch.tensor(1.0, dtype=torch.float64),
    )
    if initial_kernel_params is None:
        initial_kernel_params = default_initial_kernel_params

    error = None
    # First try optimizing the kernel params with the provided initial_kernel_params,
    # but if it fails, rerun the optimization with the default initial_kernel_params.
    # This increases the robustness of the optimization.
    for init_kernel_params in [initial_kernel_params, default_initial_kernel_params]:
        try:
            return _fit_kernel_params(
                X, Y, is_categorical, log_prior, minimum_noise, init_kernel_params
            )
        except RuntimeError as e:
            error = e

    logger.warn(
        f"The optimization of kernel_params failed: \n{error}\n"
        "The default initial kernel params will be used instead."
    )
    return default_initial_kernel_params
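
The log-space trick and the `jac=True` handling above can be illustrated with a small standalone sketch (not Optuna code; the toy objective is an assumption for demonstration). Raw parameters are optimized without constraints, `exp` keeps the actual parameters positive, and returning `(loss, gradient)` lets `scipy.optimize.minimize` consume the gradient directly:

import numpy as np
import scipy.optimize as so


def toy_loss_func(raw_params: np.ndarray) -> tuple[float, np.ndarray]:
    theta = np.exp(raw_params)  # positive by construction
    loss = float(np.sum((theta - 2.0) ** 2))  # toy objective with its optimum at theta = 2
    grad_raw = 2.0 * (theta - 2.0) * theta  # chain rule through the exp transform
    return loss, grad_raw


res = so.minimize(toy_loss_func, np.zeros(3), jac=True)
if not res.success:
    raise RuntimeError(f"Optimization failed: {res.message}")
print(np.exp(res.x))  # recovered positive parameters, approximately [2.0, 2.0, 2.0]

The retry loop in `fit_kernel_params` then simply feeds a default starting point into the same routine when the user-supplied one makes the optimization fail, trading one extra run for robustness.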
70 changes: 36 additions & 34 deletions optuna/integration/__init__.py
@@ -35,6 +35,41 @@
}


__all__ = [
    "AllenNLPExecutor",
    "AllenNLPPruningCallback",
    "BoTorchSampler",
    "CatalystPruningCallback",
    "CatBoostPruningCallback",
    "ChainerPruningExtension",
    "ChainerMNStudy",
    "CmaEsSampler",
    "PyCmaSampler",
    "DaskStorage",
    "MLflowCallback",
    "WeightsAndBiasesCallback",
    "KerasPruningCallback",
    "LightGBMPruningCallback",
    "LightGBMTuner",
    "LightGBMTunerCV",
    "TorchDistributedTrial",
    "PyTorchIgnitePruningHandler",
    "PyTorchLightningPruningCallback",
    "OptunaSearchCV",
    "ShapleyImportanceEvaluator",
    "SkorchPruningCallback",
    "MXNetPruningCallback",
    "SkoptSampler",
    "TensorBoardCallback",
    "TensorFlowPruningHook",
    "TFKerasPruningCallback",
    "XGBoostPruningCallback",
    "FastAIV1PruningCallback",
    "FastAIV2PruningCallback",
    "FastAIPruningCallback",
]


if TYPE_CHECKING:
    from optuna.integration.allennlp import AllenNLPExecutor
    from optuna.integration.allennlp import AllenNLPPruningCallback
@@ -77,6 +112,7 @@ class _IntegrationModule(ModuleType):
    imports all submodules and their dependencies (e.g., chainer, keras, lightgbm) all at once.
    """

    __all__ = __all__
    __file__ = globals()["__file__"]
    __path__ = [os.path.dirname(__file__)]

@@ -113,37 +149,3 @@ def _get_module(self, module_name: str) -> ModuleType:
)

sys.modules[__name__] = _IntegrationModule(__name__)

__all__ = [
"AllenNLPExecutor",
"AllenNLPPruningCallback",
"BoTorchSampler",
"CatalystPruningCallback",
"CatBoostPruningCallback",
"ChainerPruningExtension",
"ChainerMNStudy",
"CmaEsSampler",
"PyCmaSampler",
"DaskStorage",
"MLflowCallback",
"WeightsAndBiasesCallback",
"KerasPruningCallback",
"LightGBMPruningCallback",
"LightGBMTuner",
"LightGBMTunerCV",
"TorchDistributedTrial",
"PyTorchIgnitePruningHandler",
"PyTorchLightningPruningCallback",
"OptunaSearchCV",
"ShapleyImportanceEvaluator",
"SkorchPruningCallback",
"MXNetPruningCallback",
"SkoptSampler",
"TensorBoardCallback",
"TensorFlowPruningHook",
"TFKerasPruningCallback",
"XGBoostPruningCallback",
"FastAIV1PruningCallback",
"FastAIV2PruningCallback",
"FastAIPruningCallback",
]
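
As context for the `_IntegrationModule` docstring above, here is a generic sketch of the lazy-import pattern it applies (an assumption-level illustration, not the Optuna implementation; `mypackage` and `heavy_feature` are made-up names). The package replaces its own entry in `sys.modules` with a `ModuleType` subclass whose `__getattr__` imports a submodule only when it is first accessed:

import importlib
import sys
from types import ModuleType

_LAZY_ATTRS = {"heavy_feature": "mypackage.heavy_feature"}  # attribute name -> submodule path (made up)


class _LazyModule(ModuleType):
    def __getattr__(self, name: str) -> ModuleType:
        if name in _LAZY_ATTRS:
            module = importlib.import_module(_LAZY_ATTRS[name])
            setattr(self, name, module)  # cache so later lookups bypass __getattr__
            return module
        raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")


# Installed at the bottom of the package's __init__.py, mirroring the
# `sys.modules[__name__] = _IntegrationModule(__name__)` line above.
sys.modules[__name__] = _LazyModule(__name__)

Until an attribute such as `heavy_feature` is touched, none of the heavy third-party dependencies are imported, which keeps importing the package itself cheap.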
19 changes: 0 additions & 19 deletions optuna/integration/_lightgbm_tuner/__init__.py

This file was deleted.

138 changes: 0 additions & 138 deletions optuna/integration/_lightgbm_tuner/_train.py

This file was deleted.
