Enhance GPSampler performance (other than introducing local search) #5279

Merged · 20 commits · Mar 1, 2024
44 changes: 37 additions & 7 deletions optuna/_gp/gp.py
@@ -150,7 +150,9 @@ def _fit_kernel_params(
is_categorical: np.ndarray, # [len(params)]
log_prior: Callable[[KernelParamsTensor], torch.Tensor],
minimum_noise: float,
deterministic_objective: bool,
initial_kernel_params: KernelParamsTensor,
gtol: float,
) -> KernelParamsTensor:
n_params = X.shape[1]

@@ -164,7 +166,8 @@
np.log(initial_kernel_params.inverse_squared_lengthscales.detach().numpy()),
[
np.log(initial_kernel_params.kernel_scale.item()),
np.log(initial_kernel_params.noise_var.item() - minimum_noise),
# We add 0.01 * minimum_noise to initial noise_var to avoid instability.
np.log(initial_kernel_params.noise_var.item() - 0.99 * minimum_noise),
Member:

Based on the comment above (i.e., "add 0.01 * minimum_noise to initial noise_var"), I feel this line should be as below.

np.log(initial_kernel_params.noise_var.item() + 0.01 * minimum_noise)

Collaborator:

I think this line is trying to transform the variable so that the search space becomes $(-\infty, \infty)$.
So probably, what @y0z wanted to say was actually this?

transformed_noise = initial_kernel_params.noise_var.item() - minimum_noise
... = np.log(transformed_noise + 0.01 * minimum_noise)

Member:

Memo: The intention is as follows.

(initial_noise_var + 0.01 * minimum_noise) - minimum_noise = initial_noise_var - 0.99 * minimum_noise

],
]
)
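To make the reparameterization discussed in the thread above concrete, here is a minimal standalone sketch (not part of this PR's diff; the 1e-6 values and variable names are illustrative assumptions). It shows that subtracting 0.99 * minimum_noise before the log is equivalent to adding a 0.01 * minimum_noise jitter once the exp-plus-minimum_noise inverse transform is applied.

```python
import numpy as np

# Illustrative values only: minimum_noise and the initial noise variance are assumptions.
minimum_noise = 1e-6
initial_noise_var = 1e-6  # the initial value may equal minimum_noise exactly

# Forward transform used above: the raw, unconstrained optimization variable.
# Subtracting only 0.99 * minimum_noise keeps the log argument strictly positive
# even when initial_noise_var == minimum_noise.
raw_noise_param = np.log(initial_noise_var - 0.99 * minimum_noise)

# Inverse transform used inside loss_func: noise_var = exp(raw) + minimum_noise.
reconstructed_noise_var = np.exp(raw_noise_param) + minimum_noise

# reconstructed_noise_var == initial_noise_var + 0.01 * minimum_noise,
# i.e. the 0.01 * minimum_noise jitter mentioned in the memo above.
print(reconstructed_noise_var)
```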
@@ -175,27 +178,45 @@ def loss_func(raw_params: np.ndarray) -> tuple[float, np.ndarray]:
params = KernelParamsTensor(
inverse_squared_lengthscales=torch.exp(raw_params_tensor[:n_params]),
kernel_scale=torch.exp(raw_params_tensor[n_params]),
noise_var=torch.exp(raw_params_tensor[n_params + 1]) + minimum_noise,
noise_var=(
torch.tensor(minimum_noise, dtype=torch.float64)
if deterministic_objective
else torch.exp(raw_params_tensor[n_params + 1]) + minimum_noise
),
)
loss = -marginal_log_likelihood(
torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(is_categorical), params
) - log_prior(params)
loss.backward() # type: ignore
# scipy.minimize requires all the gradients to be zero for termination.
raw_noise_var_grad = raw_params_tensor.grad[n_params + 1] # type: ignore
assert not deterministic_objective or raw_noise_var_grad == 0
return loss.item(), raw_params_tensor.grad.detach().numpy() # type: ignore

# jac=True means loss_func returns the gradient for gradient descent.
res = so.minimize(loss_func, initial_raw_params, jac=True)
res = so.minimize(
# We need a high gtol value because the loss_func can have large numerical errors.
loss_func,
initial_raw_params,
jac=True,
method="l-bfgs-b",
options={"gtol": gtol},
)
if not res.success:
raise RuntimeError(f"Optimization failed: {res.message}")

# TODO(contramundum53): Handle the case where the optimization fails.
raw_params_opt_tensor = torch.from_numpy(res.x)

return KernelParamsTensor(
res = KernelParamsTensor(
inverse_squared_lengthscales=torch.exp(raw_params_opt_tensor[:n_params]),
kernel_scale=torch.exp(raw_params_opt_tensor[n_params]),
noise_var=torch.exp(raw_params_opt_tensor[n_params + 1]) + minimum_noise,
noise_var=(
torch.tensor(minimum_noise, dtype=torch.float64)
if deterministic_objective
else minimum_noise + torch.exp(raw_params_opt_tensor[n_params + 1])
),
)
return res


def fit_kernel_params(
@@ -204,7 +225,9 @@ def fit_kernel_params(
is_categorical: np.ndarray,
log_prior: Callable[[KernelParamsTensor], torch.Tensor],
minimum_noise: float,
deterministic_objective: bool,
initial_kernel_params: KernelParamsTensor | None = None,
gtol: float = 1e-2,
) -> KernelParamsTensor:
default_initial_kernel_params = KernelParamsTensor(
inverse_squared_lengthscales=torch.ones(X.shape[1], dtype=torch.float64),
@@ -221,7 +244,14 @@
for init_kernel_params in [initial_kernel_params, default_initial_kernel_params]:
try:
return _fit_kernel_params(
X, Y, is_categorical, log_prior, minimum_noise, init_kernel_params
X=X,
Y=Y,
is_categorical=is_categorical,
log_prior=log_prior,
minimum_noise=minimum_noise,
initial_kernel_params=init_kernel_params,
deterministic_objective=deterministic_objective,
gtol=gtol,
)
except RuntimeError as e:
error = e
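The hunk above switches the optimizer call to L-BFGS-B with an explicit gtol. Below is a minimal sketch of the same call pattern on a toy quadratic; the loss function, starting point, and gtol value are placeholders, not Optuna's actual marginal log-likelihood.

```python
from __future__ import annotations

import numpy as np
import scipy.optimize as so


def loss_and_grad(x: np.ndarray) -> tuple[float, np.ndarray]:
    # Toy quadratic standing in for the negative marginal log-likelihood.
    loss = float(np.sum((x - 1.0) ** 2))
    grad = 2.0 * (x - 1.0)
    return loss, grad


# jac=True tells scipy that the callable returns (value, gradient) as a pair,
# so no finite-difference gradients are computed. A looser gtol lets L-BFGS-B
# terminate earlier when the gradients are numerically noisy.
res = so.minimize(
    loss_and_grad,
    np.zeros(3),
    jac=True,
    method="l-bfgs-b",
    options={"gtol": 1e-2},
)
if not res.success:
    raise RuntimeError(f"Optimization failed: {res.message}")
print(res.x)
```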
8 changes: 6 additions & 2 deletions optuna/_gp/optim.py → optuna/_gp/optim_sample.py
@@ -7,10 +7,14 @@


def optimize_acqf_sample(
acqf_params: acqf.AcquisitionFunctionParams, n_samples: int = 2048, seed: int | None = None
acqf_params: acqf.AcquisitionFunctionParams,
*,
n_samples: int = 2048,
rng: np.random.RandomState | None = None,
) -> tuple[np.ndarray, float]:
# Normalized parameter values are sampled.
xs = sample_normalized_params(n_samples, acqf_params.search_space, seed=seed)
xs = sample_normalized_params(n_samples, acqf_params.search_space, rng=rng)
res = acqf.eval_acqf_no_grad(acqf_params, xs)

best_i = np.argmax(res)
return xs[best_i, :], res[best_i]
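For context, optimize_acqf_sample is a purely sampling-based optimizer: it evaluates the acquisition function on a batch of quasi-random candidates and returns the argmax. Below is a minimal sketch of that pattern with a stand-in acquisition function; none of the names come from Optuna's API.

```python
from __future__ import annotations

import numpy as np


def optimize_by_sampling(acqf, n_samples: int, dim: int, rng=None) -> tuple[np.ndarray, float]:
    # Evaluate the acquisition function on random candidates in the normalized
    # [0, 1]^dim space and keep the best one (no gradient-based refinement).
    rng = rng or np.random.RandomState()
    xs = rng.rand(n_samples, dim)
    values = acqf(xs)
    best = int(np.argmax(values))
    return xs[best], float(values[best])


# Stand-in acquisition function: prefer points close to 0.5 in every dimension.
x_best, value_best = optimize_by_sampling(
    lambda xs: -np.sum((xs - 0.5) ** 2, axis=1),
    n_samples=2048,
    dim=3,
    rng=np.random.RandomState(0),
)
```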
12 changes: 8 additions & 4 deletions optuna/_gp/prior.py
@@ -20,13 +20,17 @@ def default_log_prior(kernel_params: "gp.KernelParamsTensor") -> "torch.Tensor":
# Log of prior distribution of kernel parameters.

def gamma_log_prior(x: "torch.Tensor", concentration: float, rate: float) -> "torch.Tensor":
# We omit the constant factor `rate ** concentration / factorial(concentration)`.
# We omit the constant factor `rate ** concentration / Gamma(concentration)`.
return (concentration - 1) * torch.log(x) - rate * x

# NOTE(contramundum53): The parameters below were picked qualitatively.
# NOTE(contramundum53): The priors below (params and function
# shape for inverse_squared_lengthscales) were picked by heuristics.
# TODO(contramundum53): Check whether these priors are appropriate.
return (
gamma_log_prior(kernel_params.inverse_squared_lengthscales, 2, 0.5).sum()
-(
0.1 / kernel_params.inverse_squared_lengthscales
+ 0.1 * kernel_params.inverse_squared_lengthscales
).sum()
+ gamma_log_prior(kernel_params.kernel_scale, 2, 1)
+ gamma_log_prior(kernel_params.noise_var, 1.1, 20)
+ gamma_log_prior(kernel_params.noise_var, 1.1, 30)
)
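The diff replaces the Gamma log-prior on inverse_squared_lengthscales with a hand-written term and tightens the noise prior's rate from 20 to 30. The sketch below compares the two lengthscale log-priors; the evaluation points are arbitrary and both densities are unnormalized.

```python
import torch


def gamma_log_prior(x: torch.Tensor, concentration: float, rate: float) -> torch.Tensor:
    # Unnormalized log-density of a Gamma distribution, as in the diff above.
    return (concentration - 1) * torch.log(x) - rate * x


def new_lengthscale_log_prior(x: torch.Tensor) -> torch.Tensor:
    # Replacement term from this PR: penalizes both very small and very large
    # inverse squared lengthscales, with the mode at x = 1.
    return -(0.1 / x + 0.1 * x)


x = torch.logspace(-2, 2, steps=5, dtype=torch.float64)
print(gamma_log_prior(x, 2, 0.5))    # previous prior on inverse_squared_lengthscales
print(new_lengthscale_log_prior(x))  # new prior introduced by this PR
```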
7 changes: 5 additions & 2 deletions optuna/_gp/search_space.py
@@ -80,12 +80,15 @@ def round_one_normalized_param(
return param_value


def sample_normalized_params(n: int, search_space: SearchSpace, seed: int | None) -> np.ndarray:
def sample_normalized_params(
n: int, search_space: SearchSpace, rng: np.random.RandomState | None
) -> np.ndarray:
rng = rng or np.random.RandomState()
dim = search_space.scale_types.shape[0]
scale_types = search_space.scale_types
bounds = search_space.bounds
steps = search_space.steps
qmc_engine = qmc.Sobol(dim, scramble=True, seed=seed)
qmc_engine = qmc.Sobol(dim, scramble=True, seed=rng.randint(np.iinfo(np.int32).max))
param_values = qmc_engine.random(n)
for i in range(dim):
if scale_types[i] == ScaleType.CATEGORICAL:
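With the signature change above, the Sobol seed is now drawn from the caller's RandomState instead of being passed in directly, so consecutive calls with the same rng produce different but reproducible point sets. A minimal sketch of that pattern (the function name and shapes are illustrative):

```python
from __future__ import annotations

import numpy as np
from scipy.stats import qmc


def sample_unit_cube(n: int, dim: int, rng: np.random.RandomState | None) -> np.ndarray:
    # Deriving the Sobol seed from the RandomState means two consecutive calls
    # with the same rng yield different point sets, while seeding the rng once
    # keeps the whole sequence of calls reproducible.
    rng = rng or np.random.RandomState()
    engine = qmc.Sobol(dim, scramble=True, seed=rng.randint(np.iinfo(np.int32).max))
    return engine.random(n)


rng = np.random.RandomState(42)
first = sample_unit_cube(128, 3, rng)
second = sample_unit_cube(128, 3, rng)  # differs from `first`, but reproducible overall
```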
30 changes: 23 additions & 7 deletions optuna/samplers/_gp/sampler.py
@@ -25,7 +25,7 @@

import optuna._gp.acqf as acqf
import optuna._gp.gp as gp
import optuna._gp.optim as optim
import optuna._gp.optim_sample as optim_sample

Codecov / codecov/patch: Added line #L28 in optuna/samplers/_gp/sampler.py was not covered by tests.
import optuna._gp.prior as prior
import optuna._gp.search_space as gp_search_space
else:
Expand All @@ -34,7 +34,7 @@
torch = _LazyImport("torch")
gp_search_space = _LazyImport("optuna._gp.search_space")
gp = _LazyImport("optuna._gp.gp")
optim = _LazyImport("optuna._gp.optim")
optim_sample = _LazyImport("optuna._gp.optim_sample")
acqf = _LazyImport("optuna._gp.acqf")
prior = _LazyImport("optuna._gp.prior")

@@ -69,6 +69,12 @@

n_startup_trials:
Number of initial trials. Defaults to 10.

deterministic_objective:
Whether the objective function is deterministic or not.
If `True`, the sampler will fix the noise variance of the surrogate model to
the minimum value (slightly above 0 to ensure numerical stability).
Defaults to `False`.
"""

def __init__(
@@ -77,6 +83,7 @@
seed: int | None = None,
independent_sampler: BaseSampler | None = None,
n_startup_trials: int = 10,
deterministic_objective: bool = False,
) -> None:
self._rng = LazyRandomState(seed)
self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
@@ -89,6 +96,7 @@
# We cache the kernel parameters for initial values of fitting the next time.
self._kernel_params_cache: "gp.KernelParamsTensor | None" = None
self._optimize_n_samples: int = 2048
self._deterministic = deterministic_objective

def reseed_rng(self) -> None:
self._rng.rng.seed()
@@ -105,6 +113,17 @@

return search_space

def _optimize_acqf(
self,
acqf_params: "acqf.AcquisitionFunctionParams",
) -> np.ndarray:
normalized_params, _ = optim_sample.optimize_acqf_sample(
acqf_params,
n_samples=self._optimize_n_samples,
rng=self._rng.rng,
)
return normalized_params

def sample_relative(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> dict[str, Any]:
@@ -156,6 +175,7 @@
log_prior=self._log_prior,
minimum_noise=self._minimum_noise,
initial_kernel_params=self._kernel_params_cache,
deterministic_objective=self._deterministic,
)
self._kernel_params_cache = kernel_params

@@ -167,11 +187,7 @@
Y=standarized_score_vals,
)

normalized_param, _ = optim.optimize_acqf_sample(
acqf_params,
n_samples=self._optimize_n_samples,
seed=self._rng.rng.randint(np.iinfo(np.int32).max),
)
normalized_param = self._optimize_acqf(acqf_params)
return gp_search_space.get_unnormalized_param(search_space, normalized_param)

def sample_independent(
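Assuming an Optuna version that includes this PR, the new deterministic_objective option would be used as in the sketch below with a noise-free toy objective; the objective itself is illustrative.

```python
import optuna


def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float("x", -5.0, 5.0)
    y = trial.suggest_float("y", -5.0, 5.0)
    # A noise-free objective, so fixing the GP noise variance is appropriate.
    return (x - 1.0) ** 2 + (y + 2.0) ** 2


sampler = optuna.samplers.GPSampler(seed=0, deterministic_objective=True)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=30)
```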
11 changes: 7 additions & 4 deletions optuna/terminator/improvement/evaluator.py
@@ -21,15 +21,15 @@

from optuna._gp import acqf
from optuna._gp import gp
from optuna._gp import optim
from optuna._gp import optim_sample

Codecov / codecov/patch: Added line #L24 in optuna/terminator/improvement/evaluator.py was not covered by tests.
from optuna._gp import prior
from optuna._gp import search_space
else:
from optuna._imports import _LazyImport

torch = _LazyImport("torch")
gp = _LazyImport("optuna._gp.gp")
optim = _LazyImport("optuna._gp.optim")
optim_sample = _LazyImport("optuna._gp.optim_sample")
acqf = _LazyImport("optuna._gp.acqf")
prior = _LazyImport("optuna._gp.prior")
search_space = _LazyImport("optuna._gp.search_space")
@@ -141,11 +141,12 @@
Y=standarized_top_n_values,
beta=beta,
)
seed = self._rng.rng.randint(np.iinfo(np.int32).max)
# UCB over the search space. (Original: LCB over the search space. See Change 1 above.)
standardized_ucb_value = max(
acqf.eval_acqf_no_grad(ucb_acqf_params, normalized_top_n_params).max(),
optim.optimize_acqf_sample(ucb_acqf_params, self._optimize_n_samples, seed)[1],
optim_sample.optimize_acqf_sample(
ucb_acqf_params, n_samples=self._optimize_n_samples, rng=self._rng.rng
)[1],
)

# calculate min_lcb
@@ -192,6 +193,8 @@
is_categorical=(gp_search_space.scale_types == search_space.ScaleType.CATEGORICAL),
log_prior=self._log_prior,
minimum_noise=self._minimum_noise,
# TODO(contramundum53): Add option to specify this.
deterministic_objective=False,
# TODO(y0z): Add `kernel_params_cache` to speedup.
initial_kernel_params=None,
)
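The evaluator above takes the larger of the UCB values at already-observed points and at the sample-based optimum. A tiny numeric sketch of that max-of-two pattern (all values are made up):

```python
import numpy as np

# Stand-in UCB values at already-observed (top-n) points and at the optimum
# found by random sampling; the numbers are arbitrary.
ucb_at_observed = np.array([0.8, 1.1, 0.9])
ucb_at_sampled_optimum = 1.05

# The evaluator keeps whichever is larger, so the UCB estimate never falls
# below the best value attained at an observed configuration.
standardized_ucb_value = max(ucb_at_observed.max(), ucb_at_sampled_optimum)
```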
4 changes: 3 additions & 1 deletion tests/gp_tests/test_search_space.py
@@ -101,7 +101,9 @@ def test_sample_normalized_params() -> None:
bounds=np.array([(0.0, 10.0), (1.0, 10.0), (10.0, 100.0), (10.0, 100.0), (0.0, 5.0)]),
steps=np.array([0.0, 1.0, 0.0, 1.0, 1.0]),
)
samples = sample_normalized_params(n=128, search_space=search_space, seed=0)
samples = sample_normalized_params(
n=128, search_space=search_space, rng=np.random.RandomState(0)
)
assert samples.shape == (128, 5)
assert np.all((samples[:, :4] >= 0.0) & (samples[:, :4] <= 1.0))

10 changes: 8 additions & 2 deletions tests/samplers_tests/test_samplers.py
@@ -32,9 +32,13 @@


def get_gp_sampler(
*, n_startup_trials: int = 0, seed: int | None = None
*, n_startup_trials: int = 0, deterministic_objective: bool = False, seed: int | None = None
) -> optuna.samplers.GPSampler:
return optuna.samplers.GPSampler(n_startup_trials=n_startup_trials, seed=seed)
return optuna.samplers.GPSampler(
n_startup_trials=n_startup_trials,
seed=seed,
deterministic_objective=deterministic_objective,
)


parametrize_sampler = pytest.mark.parametrize(
@@ -49,6 +53,7 @@ def get_gp_sampler(
optuna.samplers.NSGAIIISampler,
optuna.samplers.QMCSampler,
lambda: get_gp_sampler(n_startup_trials=0),
lambda: get_gp_sampler(n_startup_trials=0, deterministic_objective=True),
],
)
parametrize_relative_sampler = pytest.mark.parametrize(
Expand All @@ -58,6 +63,7 @@ def get_gp_sampler(
lambda: optuna.samplers.CmaEsSampler(n_startup_trials=0),
lambda: optuna.samplers.CmaEsSampler(n_startup_trials=0, use_separable_cma=True),
lambda: get_gp_sampler(n_startup_trials=0),
lambda: get_gp_sampler(n_startup_trials=0, deterministic_objective=True),
],
)
parametrize_multi_objective_sampler = pytest.mark.parametrize(