
Add UCB for optuna._gp #5224

Merged: 12 commits into optuna:master on Feb 8, 2024

Conversation

@contramundum53 (Member) commented on Feb 2, 2024

Motivation

optuna.terminator uses UCB. Implementing UCB in optuna._gp lets us build the backend of optuna.terminator on optuna._gp, essentially removing Optuna's dependency on BoTorch.

Description of the changes

  • Implement UCB in optuna._gp
  • Add a test that compares the computed values against BoTorch's

#5185 must be merged first.
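
For reference, UCB scores a candidate point by its posterior mean plus a confidence-width term. A minimal sketch (illustrative names; the mean + sqrt(beta * var) form matches the suggestion discussed below):

import torch


def ucb(mean: torch.Tensor, var: torch.Tensor, beta: float) -> torch.Tensor:
    # UCB(x) = mu(x) + sqrt(beta * var(x)); larger beta favors exploration.
    return mean + torch.sqrt(beta * var)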

@github-actions bot added the optuna.samplers label (Related to the `optuna.samplers` submodule; automatically labeled by github-actions) on Feb 2, 2024
@contramundum53 (Member, Author) commented on Feb 2, 2024

Since we don't want our CI to depend on botorch, we paste the test code here and remove it from this PR:

from __future__ import annotations

from typing import Any
from typing import Callable

from botorch.acquisition.analytic import LogExpectedImprovement
from botorch.acquisition.analytic import UpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.models.model import Model
from gpytorch.kernels import MaternKernel
from gpytorch.kernels import ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ZeroMean
import numpy as np
import pytest
import torch

from optuna._gp.acqf import AcquisitionFunctionParams
from optuna._gp.acqf import AcquisitionFunctionType
from optuna._gp.acqf import create_acqf_params
from optuna._gp.acqf import eval_acqf
from optuna._gp.gp import kernel
from optuna._gp.gp import KernelParamsTensor
from optuna._gp.gp import posterior
from optuna._gp.search_space import ScaleType
from optuna._gp.search_space import SearchSpace


@pytest.mark.parametrize(
    "acqf_type, beta, botorch_acqf_gen",
    [
        (
            AcquisitionFunctionType.LOG_EI,
            None,
            lambda model, acqf_params: LogExpectedImprovement(model, best_f=acqf_params.max_Y),
        ),
        (
            AcquisitionFunctionType.UCB,
            2.0,
            lambda model, acqf_params: UpperConfidenceBound(model, beta=acqf_params.beta),
        ),
    ],
)
@pytest.mark.parametrize(
    "x",
    [
        np.array([0.15, 0.12]),  # unbatched
        np.array([[0.15, 0.12], [0.0, 1.0]]),  # batched
    ],
)
def test_posterior_and_eval_acqf(
    acqf_type: AcquisitionFunctionType,
    beta: float | None,
    botorch_acqf_gen: Callable[[Model, AcquisitionFunctionParams], Any],
    x: np.ndarray,
) -> None:
    n_dims = 2
    X = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.1]])
    Y = np.array([1.0, 2.0, 3.0])
    kernel_params = KernelParamsTensor(
        inverse_squared_lengthscales=torch.tensor([2.0, 3.0], dtype=torch.float64),
        kernel_scale=torch.tensor(4.0, dtype=torch.float64),
        noise_var=torch.tensor(0.1, dtype=torch.float64),
    )
    search_space = SearchSpace(
        scale_types=np.full(n_dims, ScaleType.LINEAR),
        bounds=np.array([[0.0, 1.0]] * n_dims),
        steps=np.zeros(n_dims),
    )

    acqf_params = create_acqf_params(
        acqf_type=acqf_type,
        kernel_params=kernel_params,
        search_space=search_space,
        X=X,
        Y=Y,
        beta=beta,
        acqf_stabilizing_noise=0.0,
    )

    x_tensor = torch.from_numpy(x)
    x_tensor.requires_grad_(True)

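    # Compute the prior covariance and posterior moments with optuna._gp;
    # these are compared against the BoTorch/GPyTorch equivalents below.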
    prior_cov_fX_fX = kernel(
        torch.zeros(n_dims, dtype=torch.bool),
        kernel_params,
        torch.from_numpy(X),
        torch.from_numpy(X),
    )
    posterior_mean_fx, posterior_var_fx = posterior(
        kernel_params,
        torch.from_numpy(X),
        torch.zeros(n_dims, dtype=torch.bool),
        torch.from_numpy(acqf_params.cov_Y_Y_inv),
        torch.from_numpy(acqf_params.cov_Y_Y_inv_Y),
        torch.from_numpy(x),
    )

    acqf_value = eval_acqf(acqf_params, x_tensor)
    acqf_value.sum().backward()  # type: ignore
    acqf_grad = x_tensor.grad
    assert acqf_grad is not None

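    # Build an equivalent BoTorch model with the same hyperparameters as kernel_params.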
    gpytorch_likelihood = GaussianLikelihood()
    gpytorch_likelihood.noise_covar.noise = kernel_params.noise_var
    matern_kernel = MaternKernel(nu=2.5, ard_num_dims=n_dims)
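    # kernel_params stores 1 / lengthscale^2, so rsqrt() recovers the lengthscales.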
    matern_kernel.lengthscale = kernel_params.inverse_squared_lengthscales.rsqrt()
    covar_module = ScaleKernel(matern_kernel)
    covar_module.outputscale = kernel_params.kernel_scale

    botorch_model = SingleTaskGP(
        train_X=torch.from_numpy(X),
        train_Y=torch.from_numpy(Y)[:, None],
        likelihood=gpytorch_likelihood,
        covar_module=covar_module,
        mean_module=ZeroMean(),
    )
    botorch_prior_fX = botorch_model(torch.from_numpy(X))
    assert torch.allclose(botorch_prior_fX.covariance_matrix, prior_cov_fX_fX)

    botorch_model.eval()

    botorch_acqf = botorch_acqf_gen(botorch_model, acqf_params)

    x_tensor = torch.from_numpy(x)
    x_tensor.requires_grad_(True)
    botorch_posterior_fx = botorch_model.posterior(x_tensor[..., None, :])
    assert torch.allclose(posterior_mean_fx, botorch_posterior_fx.mean[..., 0, 0])
    assert torch.allclose(posterior_var_fx, botorch_posterior_fx.variance[..., 0, 0])

    botorch_acqf_value = botorch_acqf(x_tensor[..., None, :])
    botorch_acqf_value.sum().backward()  # type: ignore
    botorch_acqf_grad = x_tensor.grad
    assert botorch_acqf_grad is not None
    assert torch.allclose(acqf_value, botorch_acqf_value)
    assert torch.allclose(acqf_grad, botorch_acqf_grad)

codecov bot commented on Feb 2, 2024

Codecov Report

Attention: 1 line in your changes is missing coverage. Please review.

Comparison: base (f831ee6) 89.49% vs. head (e0ab667) 89.00%.
Report is 13 commits behind head on master.

Files                Patch %   Lines
optuna/_gp/acqf.py   94.11%    1 Missing ⚠️
@@            Coverage Diff             @@
##           master    #5224      +/-   ##
==========================================
- Coverage   89.49%   89.00%   -0.50%     
==========================================
  Files         213      213              
  Lines       15478    14532     -946     
==========================================
- Hits        13852    12934     -918     
+ Misses       1626     1598      -28     


@y0z (Member) left a comment

I left a suggestion.

(A suggestion on optuna/_gp/acqf.py was applied and the thread resolved.)
Co-authored-by: Yoshihiko Ozaki <30489874+y0z@users.noreply.github.com>
LOG_EI = 0
UCB = 1


@dataclass(frozen=True)
class AcquisitionFunctionParams:
(Collaborator) commented on the diff above:

Can we do something like this?

@dataclass(frozen=True)
class BaseAcquisitionFunction:
    kernel_params: KernelParamsTensor
    X: np.ndarray
    search_space: SearchSpace
    cov_Y_Y_inv: np.ndarray
    cov_Y_Y_inv_Y: np.ndarray
    acqf_stabilizing_noise: float

    def _posterior(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        is_categorical = torch.from_numpy(self.search_space.scale_types == ScaleType.CATEGORICAL)
        cov_fx_fX = kernel(is_categorical, self.kernel_params, x[..., None, :], torch.from_numpy(self.X))[..., 0, :]
        cov_fx_fx = kernel_at_zero_distance(self.kernel_params)
        cov_Y_Y_inv = torch.from_numpy(self.cov_Y_Y_inv)
        cov_Y_Y_inv_Y = torch.from_numpy(self.cov_Y_Y_inv_Y)

        # mean = cov_fx_fX @ inv(cov_fX_fX + noise * I) @ Y
        # var = cov_fx_fx - cov_fx_fX @ inv(cov_fX_fX + noise * I) @ cov_fx_fX.T

        mean = cov_fx_fX @ cov_Y_Y_inv_Y  # [batch]
        var = cov_fx_fx - (cov_fx_fX * (cov_fx_fX @ cov_Y_Y_inv)).sum(dim=-1)  # [batch]
        # We need to clamp the variance to avoid negative values due to numerical errors.
        return (mean, torch.clamp(var, min=0.0))

    def compute_with_no_grad(self, x: np.ndarray) -> np.ndarray:
        with torch.no_grad():
            return self.compute(torch.from_numpy(x)).detach().numpy()

    @abstractmethod
    def compute(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


@dataclass(frozen=True)
class UCB(BaseAcquisitionFunction):
    beta: float

    def compute(self, x: torch.Tensor) -> torch.Tensor:
        mean, var = self._posterior(x)
        return mean + torch.sqrt(self.beta * var)

@dataclass(frozen=True)
class LogEI(BaseAcquisitionFunction):
    max_Y: float

    def compute(self, x: torch.Tensor) -> torch.Tensor:
        # Return log E_{y ~ N(mean, var)}[max(0, y - max_Y)]
        mean, var = self._posterior(x)
        sigma = torch.sqrt(var)
        st_val = standard_logei((mean - self.max_Y) / sigma)
        return torch.log(sigma) + st_val

@contramundum53 (Member, Author) replied:

numba cannot support polymorphism like this.
If I weren't using numba, I would make it even simpler, with everything captured in a callable function.
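
For illustration, the callable-based alternative alluded to might look like this sketch (hypothetical, not the merged design; posterior_fn stands for a closure over the fitted GP posterior):

from typing import Callable

import torch


def make_ucb(
    posterior_fn: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]],
    beta: float,
) -> Callable[[torch.Tensor], torch.Tensor]:
    # Capture the posterior and beta in a closure instead of a params dataclass.
    def acqf(x: torch.Tensor) -> torch.Tensor:
        mean, var = posterior_fn(x)
        return mean + torch.sqrt(beta * var)

    return acqf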

@y0z (Member) left a comment

LGTM

@y0z removed their assignment on Feb 7, 2024
@nabenabe0928 (Collaborator) left a comment

Could you remove the default value here?

minimum_noise: float = 0.0,

(Two further review threads on optuna/_gp/acqf.py were resolved.)
contramundum53 and others added 3 commits February 8, 2024 13:23
Co-authored-by: Shuhei Watanabe <47781922+nabenabe0928@users.noreply.github.com>
Co-authored-by: Shuhei Watanabe <47781922+nabenabe0928@users.noreply.github.com>
@nabenabe0928 removed their assignment on Feb 8, 2024
@nabenabe0928 (Collaborator) left a comment

Thank you for the changes, LGTM!

@nabenabe0928 merged commit 01a2df1 into optuna:master on Feb 8, 2024
22 checks passed
@nabenabe0928 added this to the v3.6.0 milestone on Feb 8, 2024
@contramundum53 added the enhancement label (Change that does not break compatibility and does not affect public interfaces, but improves performance.) on Feb 9, 2024