Implement optuna.terminator using optuna._gp #5241

Merged · 40 commits · Feb 14, 2024
Changes from 39 commits
Commits (40)
7c4a4a3  Implement optuna.terminator using optuna/_gp. (y0z, Feb 9, 2024)
beb801a  Apply linter (y0z, Feb 9, 2024)
7cc188b  Apply mypy (y0z, Feb 9, 2024)
f501e44  Fix LazyImport (y0z, Feb 9, 2024)
e333aa2  Skip tests for python 3.12 (y0z, Feb 9, 2024)
63073c9  Add URL for the paper (y0z, Feb 9, 2024)
32bd4d1  Fix for flake8. (y0z, Feb 9, 2024)
fc0afd0  Fix test for terminator (y0z, Feb 9, 2024)
b49223b  Use lazy import (y0z, Feb 9, 2024)
fdd7b5a  Update pyproject.toml (y0z, Feb 9, 2024)
cad8a96  Merge branch 'master' into feature/terminator-gp (contramundum53, Feb 9, 2024)
e5276e3  Update optuna/terminator/improvement/evaluator.py (y0z, Feb 13, 2024)
565ef4c  Refactoring terminator. (y0z, Feb 13, 2024)
12005c4  Apply black. (y0z, Feb 13, 2024)
6abea75  Refactoring evaluator. (y0z, Feb 13, 2024)
f01cf16  Fix test_acqf. (y0z, Feb 13, 2024)
2185372  Fix test_acqf. (y0z, Feb 13, 2024)
3a1fd10  Use eval_acqf_no_grad. (y0z, Feb 13, 2024)
a681a28  Update optuna/terminator/improvement/evaluator.py (y0z, Feb 13, 2024)
aa35607  Update optuna/terminator/improvement/evaluator.py (y0z, Feb 13, 2024)
f58a50e  Update optuna/terminator/improvement/evaluator.py (y0z, Feb 13, 2024)
5c8bf70  Fix error. (y0z, Feb 13, 2024)
46b443d  Update optuna/terminator/improvement/evaluator.py (y0z, Feb 13, 2024)
273ea71  Update optuna/terminator/improvement/evaluator.py (y0z, Feb 13, 2024)
5e9a43d  Update optuna/terminator/improvement/evaluator.py (y0z, Feb 13, 2024)
8de5cf4  Update optuna/terminator/improvement/evaluator.py (y0z, Feb 13, 2024)
24a4c4e  Fix skipping method for terminator test. (y0z, Feb 13, 2024)
64362dc  Fix errors. (y0z, Feb 13, 2024)
2de1dfe  Fix mypy error. (y0z, Feb 13, 2024)
e004e68  Fix tutorial. (y0z, Feb 13, 2024)
29a39b6  Revert "Fix tutorial." (y0z, Feb 13, 2024)
0ef8897  Apply black. (y0z, Feb 13, 2024)
f1ef6e4  Update optuna/terminator/improvement/evaluator.py (y0z, Feb 13, 2024)
9df6742  Fix variable name. (y0z, Feb 13, 2024)
819e00d  Apply black. (y0z, Feb 13, 2024)
437c337  Update tests/terminator_tests/test_callback.py (y0z, Feb 13, 2024)
06c9eab  Update tests/terminator_tests/test_erroreval.py (y0z, Feb 13, 2024)
db1e1ef  Update tests/terminator_tests/test_terminator.py (y0z, Feb 13, 2024)
1b9acf8  Update tests/visualization_tests/test_terminator_improvement.py (y0z, Feb 13, 2024)
5b52d8b  Update optuna/_gp/acqf.py (y0z, Feb 13, 2024)
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
@@ -95,8 +95,9 @@ jobs:
           fi

           if [ "${{ matrix.python-version }}" = "3.12" ] ; then
-            # TODO(not522): Remove ignores when BoTorch supports Python 3.12
+            # TODO(not522): Remove ignores when BoTorch/Torch supports Python 3.12
             ignore_option="--ignore tests/terminator_tests/ \
+              --ignore tests/terminator_tests/improvement_tests/ \
               --ignore tests/visualization_tests/test_terminator_improvement.py"
           else
             ignore_option=""
8 changes: 8 additions & 0 deletions optuna/_gp/acqf.py
@@ -57,11 +57,16 @@
    return mean + torch.sqrt(beta * var)
+
+
+def lcb(mean: torch.Tensor, var: torch.Tensor, beta: float) -> torch.Tensor:
+    return mean - torch.sqrt(beta * var)


# TODO(contramundum53): consider abstraction for acquisition functions.
# NOTE: Acquisition function is not class on purpose to integrate numba in the future.
class AcquisitionFunctionType(IntEnum):
    LOG_EI = 0
    UCB = 1
+    LCB = 2


@dataclass(frozen=True)
@@ -122,6 +127,9 @@
    elif acqf_params.acqf_type == AcquisitionFunctionType.UCB:
        assert acqf_params.beta is not None, "beta must be given to UCB."
        return ucb(mean=mean, var=var, beta=acqf_params.beta)
+    elif acqf_params.acqf_type == AcquisitionFunctionType.LCB:
+        assert acqf_params.beta is not None, "beta must be given to LCB."
+        return lcb(mean=mean, var=var, beta=acqf_params.beta)
    else:
        assert False, "Unknown acquisition function type."

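The new lcb is the mirror image of ucb, and the pair is what a GP-based regret bound needs. As a rough illustration of that use (my sketch with made-up posterior values, not code from this PR): the gap between the best upper confidence bound and the best lower confidence bound over the evaluated points upper-bounds how much the incumbent can still be improved.

import torch

def ucb(mean: torch.Tensor, var: torch.Tensor, beta: float) -> torch.Tensor:
    return mean + torch.sqrt(beta * var)

def lcb(mean: torch.Tensor, var: torch.Tensor, beta: float) -> torch.Tensor:
    return mean - torch.sqrt(beta * var)

# Toy GP posterior over three evaluated points (maximization).
mean = torch.tensor([0.3, 0.8, 0.5])
var = torch.tensor([0.04, 0.01, 0.09])
beta = 2.0

# max UCB upper-bounds the best attainable value, max LCB lower-bounds the
# incumbent's value, so their gap upper-bounds the remaining regret.
regret_bound = float(ucb(mean, var, beta).max() - lcb(mean, var, beta).max())
print(regret_bound)  # a small gap suggests the study has plateaued
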
32 changes: 32 additions & 0 deletions optuna/_gp/prior.py
@@ -0,0 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from optuna._gp import gp


if TYPE_CHECKING:
    import torch
else:
    from optuna._imports import _LazyImport

    torch = _LazyImport("torch")


DEFAULT_MINIMUM_NOISE_VAR = 1e-6


def default_log_prior(kernel_params: "gp.KernelParamsTensor") -> "torch.Tensor":
    # Log of the prior distribution of the kernel parameters.

    def gamma_log_prior(x: "torch.Tensor", concentration: float, rate: float) -> "torch.Tensor":
        # We omit the constant factor `rate ** concentration / Gamma(concentration)`.
        return (concentration - 1) * torch.log(x) - rate * x

    # NOTE(contramundum53): The parameters below were picked qualitatively.
    # TODO(contramundum53): Check whether these priors are appropriate.
    return (
        gamma_log_prior(kernel_params.inverse_squared_lengthscales, 2, 0.5).sum()
        + gamma_log_prior(kernel_params.kernel_scale, 2, 1)
        + gamma_log_prior(kernel_params.noise_var, 1.1, 20)
    )
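
Because default_log_prior drops the Gamma normalizing constant, a quick sanity check (my sketch, not part of this diff) is to compare it against torch.distributions.Gamma.log_prob: the two should differ by a constant that does not depend on x, so the omission cannot affect the argmax of the log posterior over kernel parameters.

import torch

def gamma_log_prior(x: torch.Tensor, concentration: float, rate: float) -> torch.Tensor:
    # Same expression as in prior.py: the unnormalized Gamma log-density.
    return (concentration - 1) * torch.log(x) - rate * x

x = torch.tensor([0.1, 0.5, 2.0, 10.0])
concentration, rate = 2.0, 0.5
full = torch.distributions.Gamma(concentration, rate).log_prob(x)
diff = gamma_log_prior(x, concentration, rate) - full
# The difference equals lgamma(concentration) - concentration * log(rate),
# the same for every x, i.e. exactly the omitted constant.
print(torch.allclose(diff, diff[0].expand_as(diff)))  # True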
24 changes: 6 additions & 18 deletions optuna/samplers/_gp/sampler.py
@@ -26,6 +26,7 @@
    import optuna._gp.acqf as acqf
    import optuna._gp.gp as gp
    import optuna._gp.optim as optim
+    import optuna._gp.prior as prior
    import optuna._gp.search_space as gp_search_space
else:
    from optuna._imports import _LazyImport
@@ -35,22 +36,7 @@
    gp = _LazyImport("optuna._gp.gp")
    optim = _LazyImport("optuna._gp.optim")
    acqf = _LazyImport("optuna._gp.acqf")
-
-
-def log_prior(kernel_params: "gp.KernelParamsTensor") -> "torch.Tensor":
-    # Log of prior distribution of kernel parameters.
-
-    def gamma_log_prior(x: "torch.Tensor", concentration: float, rate: float) -> "torch.Tensor":
-        # We omit the constant factor `rate ** concentration / factorial(concentration)`.
-        return (concentration - 1) * torch.log(x) - rate * x
-
-    # NOTE(contramundum53): The parameters below were picked qualitatively.
-    # TODO(contramundum53): Check whether these priors are appropriate.
-    return (
-        gamma_log_prior(kernel_params.inverse_squared_lengthscales, 2, 0.5).sum()
-        + gamma_log_prior(kernel_params.kernel_scale, 2, 1)
-        + gamma_log_prior(kernel_params.noise_var, 1.1, 20)
-    )
+    prior = _LazyImport("optuna._gp.prior")


@experimental_class("3.6.0")
@@ -96,8 +82,10 @@
        self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
        self._intersection_search_space = optuna.search_space.IntersectionSearchSpace()
        self._n_startup_trials = n_startup_trials
-        self._log_prior: "Callable[[gp.KernelParamsTensor], torch.Tensor]" = log_prior
-        self._minimum_noise: float = 1e-6
+        self._log_prior: "Callable[[gp.KernelParamsTensor], torch.Tensor]" = (
+            prior.default_log_prior
+        )
+        self._minimum_noise: float = prior.DEFAULT_MINIMUM_NOISE_VAR
        # We cache the kernel parameters for initial values of fitting the next time.
        self._kernel_params_cache: "gp.KernelParamsTensor | None" = None
        self._optimize_n_samples: int = 2048
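For context, these defaults feed GPSampler, which this hunk rewires to prior.default_log_prior and prior.DEFAULT_MINIMUM_NOISE_VAR. A minimal usage sketch (the sampler is experimental as of v3.6.0 per the decorator above; the quadratic objective is made up for illustration):

import optuna

def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float("x", -5.0, 5.0)
    return (x - 1.0) ** 2

# After this change, GPSampler picks up prior.default_log_prior and
# prior.DEFAULT_MINIMUM_NOISE_VAR instead of module-local defaults.
sampler = optuna.samplers.GPSampler(seed=42)
study = optuna.create_study(direction="minimize", sampler=sampler)
study.optimize(objective, n_trials=30)
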
230 changes: 0 additions & 230 deletions optuna/terminator/improvement/_preprocessing.py

This file was deleted.
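
Taken together, these changes reimplement the terminator's improvement evaluator on top of optuna._gp. A rough end-to-end sketch using the documented optuna.terminator entry points (TerminatorCallback and report_cross_validation_scores are existing public API; the objective and its synthetic cross-validation scores are made up so the default error evaluator has something to work with):

import optuna
from optuna.terminator import TerminatorCallback, report_cross_validation_scores

def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float("x", -10.0, 10.0)
    value = x ** 2
    # The default error evaluator estimates the objective's noise from
    # reported cross-validation scores; we fabricate a few per trial.
    report_cross_validation_scores(trial, [value * 0.9, value, value * 1.1])
    return value

study = optuna.create_study(direction="minimize")
# The callback terminates the study once the estimated regret bound
# drops below the estimated evaluation error.
study.optimize(objective, n_trials=100, callbacks=[TerminatorCallback()])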