Merge pull request #5224 from contramundum53/gp-ucb
Add UCB for `optuna._gp`
nabenabe0928 committed Feb 8, 2024
2 parents a8aa5c1 + e0ab667 commit 01a2df1
Showing 4 changed files with 115 additions and 38 deletions.
65 changes: 32 additions & 33 deletions optuna/_gp/acqf.py
@@ -1,13 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum
import math
from typing import TYPE_CHECKING

import numpy as np

from optuna._gp.gp import kernel
from optuna._gp.gp import kernel_at_zero_distance
from optuna._gp.gp import KernelParamsTensor
from optuna._gp.gp import posterior
from optuna._gp.search_space import ScaleType
@@ -45,52 +45,45 @@ def standard_logei(z: torch.Tensor) -> torch.Tensor:
    return vals


def logei(mean: torch.Tensor, var: torch.Tensor, f0: torch.Tensor) -> torch.Tensor:
def logei(mean: torch.Tensor, var: torch.Tensor, f0: float) -> torch.Tensor:
    # Return log E_{y ~ N(mean, var)}[max(0, y - f0)]
    sigma = torch.sqrt(var)
    st_val = standard_logei((mean - f0) / sigma)
    val = torch.log(sigma) + st_val
    return val


def eval_logei(
    kernel_params: KernelParamsTensor,
    X: torch.Tensor,
    is_categorical: torch.Tensor,
    cov_Y_Y_inv: torch.Tensor,
    cov_Y_Y_inv_Y: torch.Tensor,
    max_Y: torch.Tensor,
    x: torch.Tensor,
    # Additional noise to prevent numerical instability.
    # Usually this is set to a very small value.
    stabilizing_noise: float,
) -> torch.Tensor:
    cov_fx_fX = kernel(is_categorical, kernel_params, x[..., None, :], X)[..., 0, :]
    cov_fx_fx = kernel_at_zero_distance(kernel_params)
    (mean, var) = posterior(cov_Y_Y_inv, cov_Y_Y_inv_Y, cov_fx_fX, cov_fx_fx)
    val = logei(mean, var + stabilizing_noise, max_Y)

    return val


def ucb(mean: torch.Tensor, var: torch.Tensor, beta: float) -> torch.Tensor:
    return mean + torch.sqrt(beta * var)


# TODO(contramundum53): Consider abstraction for acquisition functions.
# NOTE: The acquisition function is intentionally not a class, so that numba can be integrated in the future.
class AcquisitionFunctionType(IntEnum):
    LOG_EI = 0
    UCB = 1


@dataclass(frozen=True)
class AcquisitionFunctionParams:
    # Currently only logEI is supported.
    acqf_type: AcquisitionFunctionType
    kernel_params: KernelParamsTensor
    X: np.ndarray
    search_space: SearchSpace
    cov_Y_Y_inv: np.ndarray
    cov_Y_Y_inv_Y: np.ndarray
    max_Y: np.ndarray
    max_Y: float
    beta: float | None
    acqf_stabilizing_noise: float


def create_acqf_params(
    acqf_type: AcquisitionFunctionType,
    kernel_params: KernelParamsTensor,
    search_space: SearchSpace,
    X: np.ndarray,
    Y: np.ndarray,
    beta: float | None = None,
    acqf_stabilizing_noise: float = 1e-12,
) -> AcquisitionFunctionParams:
    X_tensor = torch.from_numpy(X)
@@ -102,30 +95,36 @@ def create_acqf_params(
    cov_Y_Y_inv = np.linalg.inv(cov_Y_Y)

    return AcquisitionFunctionParams(
        acqf_type=acqf_type,
        kernel_params=kernel_params,
        X=X,
        search_space=search_space,
        cov_Y_Y_inv=cov_Y_Y_inv,
        cov_Y_Y_inv_Y=cov_Y_Y_inv @ Y,
        max_Y=np.max(Y),
        beta=beta,
        acqf_stabilizing_noise=acqf_stabilizing_noise,
    )


def eval_acqf(acqf_params: AcquisitionFunctionParams, x: torch.Tensor) -> torch.Tensor:
    return eval_logei(
        kernel_params=acqf_params.kernel_params,
        X=torch.from_numpy(acqf_params.X),
        is_categorical=torch.from_numpy(
            acqf_params.search_space.scale_types == ScaleType.CATEGORICAL
        ),
        cov_Y_Y_inv=torch.from_numpy(acqf_params.cov_Y_Y_inv),
        cov_Y_Y_inv_Y=torch.from_numpy(acqf_params.cov_Y_Y_inv_Y),
        max_Y=torch.tensor(acqf_params.max_Y, dtype=torch.float64),
        x=x,
        stabilizing_noise=acqf_params.acqf_stabilizing_noise,
    mean, var = posterior(
        acqf_params.kernel_params,
        torch.from_numpy(acqf_params.X),
        torch.from_numpy(acqf_params.search_space.scale_types == ScaleType.CATEGORICAL),
        torch.from_numpy(acqf_params.cov_Y_Y_inv),
        torch.from_numpy(acqf_params.cov_Y_Y_inv_Y),
        x,
    )

    if acqf_params.acqf_type == AcquisitionFunctionType.LOG_EI:
        return logei(mean=mean, var=var + acqf_params.acqf_stabilizing_noise, f0=acqf_params.max_Y)
    elif acqf_params.acqf_type == AcquisitionFunctionType.UCB:
        assert acqf_params.beta is not None, "beta must be given to UCB."
        return ucb(mean=mean, var=var, beta=acqf_params.beta)
    else:
        assert False, "Unknown acquisition function type."


def eval_acqf_no_grad(acqf_params: AcquisitionFunctionParams, x: np.ndarray) -> np.ndarray:
    with torch.no_grad():
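For intuition, here is a minimal self-contained sketch of the two acquisition functions that the new dispatch in eval_acqf selects between, given a Gaussian posterior f(x) ~ N(mu, var). The concrete values of mu, var, f0, and beta below are made up for illustration and are not part of the Optuna API:

import torch

mu = torch.tensor(1.5, dtype=torch.float64)    # posterior mean at x
var = torch.tensor(0.04, dtype=torch.float64)  # posterior variance at x
f0 = 2.0    # incumbent best observed value (max_Y)
beta = 2.0  # UCB exploration weight

# UCB: mean plus a beta-weighted standard deviation, mu + sqrt(beta * var).
ucb_value = mu + torch.sqrt(beta * var)

# logEI: log E[max(0, f(x) - f0)]. Computed naively here; the committed
# standard_logei instead works in log space so the value stays finite even
# when (mu - f0) / sigma is very negative.
normal = torch.distributions.Normal(0.0, 1.0)
sigma = torch.sqrt(var)
z = (mu - f0) / sigma
ei = sigma * (z * normal.cdf(z) + torch.exp(normal.log_prob(z)))
log_ei = torch.log(ei)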
14 changes: 9 additions & 5 deletions optuna/_gp/gp.py
@@ -97,14 +97,18 @@ def kernel_at_zero_distance(


def posterior(
    kernel_params: KernelParamsTensor,
    X: torch.Tensor,  # [len(trials), len(params)]
    is_categorical: torch.Tensor,  # bool[len(params)]
    cov_Y_Y_inv: torch.Tensor,  # [len(trials), len(trials)]
    cov_Y_Y_inv_Y: torch.Tensor,  # [len(trials)]
    cov_fx_fX: torch.Tensor,  # [(batch,) len(trials)]
    cov_fx_fx: torch.Tensor,  # Scalar or [(batch,)]
) -> tuple[torch.Tensor, torch.Tensor]:  # [(batch,)], [(batch,)]
    x: torch.Tensor,  # [(batch,) len(params)]
) -> tuple[torch.Tensor, torch.Tensor]:  # (mean: [(batch,)], var: [(batch,)])
    cov_fx_fX = kernel(is_categorical, kernel_params, x[..., None, :], X)[..., 0, :]
    cov_fx_fx = kernel_at_zero_distance(kernel_params)

    # mean = cov_fx_fX @ inv(cov_fX_fX + noise * I) @ Y
    # var = cov_fx_fx - cov_fx_fX @ inv(cov_fX_fX + noise * I) @ cov_fx_fX.T

    mean = cov_fx_fX @ cov_Y_Y_inv_Y  # [batch]
    var = cov_fx_fx - (cov_fx_fX * (cov_fx_fX @ cov_Y_Y_inv)).sum(dim=-1)  # [batch]
    # We need to clamp the variance to avoid negative values due to numerical errors.
@@ -142,7 +146,7 @@ def fit_kernel_params(
    Y: np.ndarray,  # [len(trials)]
    is_categorical: np.ndarray,  # [len(params)]
    log_prior: Callable[[KernelParamsTensor], torch.Tensor],
    minimum_noise: float = 0.0,
    minimum_noise: float,
    initial_kernel_params: KernelParamsTensor | None = None,
) -> KernelParamsTensor:
    n_params = X.shape[1]
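The comments inside posterior above spell out the standard GP posterior equations. As a sanity check, the same algebra in a self-contained NumPy sketch; the RBF kernel and all names here are illustrative stand-ins, not the optuna._gp implementation (which also handles categorical dimensions):

import numpy as np

def rbf_kernel(a: np.ndarray, b: np.ndarray, scale: float = 4.0) -> np.ndarray:
    # Squared-exponential kernel between the rows of a and b.
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return scale * np.exp(-0.5 * d2)

X = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.1]])  # observed inputs
Y = np.array([1.0, 2.0, 3.0])                       # observed values
x = np.array([[0.15, 0.12]])                        # query point
noise = 0.1

cov_Y_Y_inv = np.linalg.inv(rbf_kernel(X, X) + noise * np.eye(len(X)))
cov_fx_fX = rbf_kernel(x, X)  # [1, len(trials)]
cov_fx_fx = rbf_kernel(x, x)  # [1, 1]

# mean = cov_fx_fX @ inv(cov_fX_fX + noise * I) @ Y
mean = cov_fx_fX @ cov_Y_Y_inv @ Y
# var = cov_fx_fx - cov_fx_fX @ inv(cov_fX_fX + noise * I) @ cov_fx_fX.T
var = np.diag(cov_fx_fx - cov_fx_fX @ cov_Y_Y_inv @ cov_fx_fX.T)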
1 change: 1 addition & 0 deletions optuna/samplers/_gp/sampler.py
@@ -172,6 +172,7 @@ def sample_relative(
        self._kernel_params_cache = kernel_params

        acqf_params = acqf.create_acqf_params(
            acqf_type=acqf.AcquisitionFunctionType.LOG_EI,
            kernel_params=kernel_params,
            search_space=internal_search_space,
            X=normalized_params,
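Note that the sampler still hard-codes acqf_type=acqf.AcquisitionFunctionType.LOG_EI; this PR adds UCB to the toolbox but does not switch the sampler to it. If a caller wanted UCB instead, the call would presumably look like the sketch below; the beta value is arbitrary and the Y=... argument name is a placeholder, since the remaining arguments are truncated in the diff above:

acqf_params = acqf.create_acqf_params(
    acqf_type=acqf.AcquisitionFunctionType.UCB,
    kernel_params=kernel_params,
    search_space=internal_search_space,
    X=normalized_params,
    Y=score_vals,  # placeholder name; the actual argument is truncated above
    beta=2.0,      # exploration weight; required when acqf_type is UCB
)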
73 changes: 73 additions & 0 deletions tests/gp_tests/test_acqf.py
@@ -0,0 +1,73 @@
from __future__ import annotations

import sys

import numpy as np
import pytest


# TODO(contramundum53): Remove this block after torch supports Python 3.12.
if sys.version_info >= (3, 12):
    pytest.skip("PyTorch does not support python 3.12.", allow_module_level=True)

import torch

from optuna._gp.acqf import AcquisitionFunctionType
from optuna._gp.acqf import create_acqf_params
from optuna._gp.acqf import eval_acqf
from optuna._gp.gp import KernelParamsTensor
from optuna._gp.search_space import ScaleType
from optuna._gp.search_space import SearchSpace


@pytest.mark.parametrize(
    "acqf_type, beta",
    [
        (AcquisitionFunctionType.LOG_EI, None),
        (AcquisitionFunctionType.UCB, 2.0),
    ],
)
@pytest.mark.parametrize(
    "x",
    [
        np.array([0.15, 0.12]),  # unbatched
        np.array([[0.15, 0.12], [0.0, 1.0]]),  # batched
    ],
)
def test_eval_acqf(
    acqf_type: AcquisitionFunctionType,
    beta: float | None,
    x: np.ndarray,
) -> None:
    n_dims = 2
    X = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.1]])
    Y = np.array([1.0, 2.0, 3.0])
    kernel_params = KernelParamsTensor(
        inverse_squared_lengthscales=torch.tensor([2.0, 3.0], dtype=torch.float64),
        kernel_scale=torch.tensor(4.0, dtype=torch.float64),
        noise_var=torch.tensor(0.1, dtype=torch.float64),
    )
    search_space = SearchSpace(
        scale_types=np.full(n_dims, ScaleType.LINEAR),
        bounds=np.array([[0.0, 1.0]] * n_dims),
        steps=np.zeros(n_dims),
    )

    acqf_params = create_acqf_params(
        acqf_type=acqf_type,
        kernel_params=kernel_params,
        search_space=search_space,
        X=X,
        Y=Y,
        beta=beta,
        acqf_stabilizing_noise=0.0,
    )

    x_tensor = torch.from_numpy(x)
    x_tensor.requires_grad_(True)

    acqf_value = eval_acqf(acqf_params, x_tensor)
    acqf_value.sum().backward()  # type: ignore
    acqf_grad = x_tensor.grad
    assert acqf_grad is not None

    assert acqf_value.shape == x.shape[:-1]

    assert torch.all(torch.isfinite(acqf_value))
    assert torch.all(torch.isfinite(acqf_grad))
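The test exercises both acquisition types with unbatched and batched inputs, and checks that values and gradients are finite. To run just this module (assuming a development install of Optuna with PyTorch available):

pytest tests/gp_tests/test_acqf.py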
