
Setting up a custom GPyTorch model for BoTorch #546

@r-ashwin

Issue description

I am trying to use a multi-task GP model from GPyTorch with BoTorch's qMaxValueEntropy. I get an UnsupportedError because the objective kwarg is not supported. See the error below:

```
---------------------------------------------------------------------------

UnsupportedError                          Traceback (most recent call last)
<ipython-input-9-e910224785b8> in <module>
    223 candidate_set = torch.rand(size=[1000, 1]) # MES requires a candidate set
    224 from botorch.acquisition.objective import ScalarizedObjective
--> 225 qSMES = qScalarizedMES(model, candidate_set=candidate_set, weights=torch.tensor([1.,0.]))

<ipython-input-9-e910224785b8> in __init__(self, model, candidate_set, weights, num_fantasies, num_mv_samples, num_y_samples, use_gumbel, maximize, X_pending)
     65         """
     66         sampler = SobolQMCNormalSampler(num_y_samples)
---> 67         super().__init__(model=model, sampler=sampler)
     68 
     69         # Batch GP models (e.g. fantasized models) are not currently supported

~\Anaconda3\lib\site-packages\botorch\acquisition\monte_carlo.py in __init__(self, model, sampler, objective, X_pending)
     69             if model.num_outputs != 1:
     70                 raise UnsupportedError(
---> 71                     "Must specify an objective when using a multi-output model."
     72                 )
     73             objective = IdentityMCObjective()

UnsupportedError: Must specify an objective when using a multi-output model.
```
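The check that raises lives in `MCAcquisitionFunction.__init__`: for a multi-output model it requires an `objective` to be passed. A possible workaround, sketched below, is to build an `MCAcquisitionObjective` from the same scalarization weights and forward it to the base class. This is untested and assumes `LinearMCObjective` is available in BoTorch 0.2.5:

```python
# Sketch of a possible workaround (untested, assuming botorch 0.2.5):
# MCAcquisitionFunction.__init__ raises UnsupportedError for multi-output
# models unless an `objective` is given, so pass one built from `weights`.
from botorch.acquisition.objective import LinearMCObjective

# Inside qScalarizedMES.__init__, instead of
#     super().__init__(model=model, sampler=sampler)
# one could try:
#     objective = LinearMCObjective(weights=weights)
#     super().__init__(model=model, sampler=sampler, objective=objective)
```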

Code example

See the code below to reproduce the error:

```python
import torch
import gpytorch
import math
from matplotlib import cm
from matplotlib import pyplot as plt
import numpy as np
from botorch.models import MultiTaskGP
def test_1d(X):
    a = 16
    f = 1*X**2 + torch.sin(a*X)
    dfx = 1*2*X + a * torch.cos(a*X)
    return f, dfx
x = torch.linspace(0.15, .65, 5)
f, dfx = test_1d(x)
train_x = x.unsqueeze(-1)
train_y = torch.stack((f, dfx),dim=1)
print(train_x.size())
plt.plot(x.numpy(), f.numpy())
plt.plot(x.numpy(), dfx.numpy(), ls='--', c='gray')

from botorch.posteriors import GPyTorchPosterior
from gpytorch.distributions import MultitaskMultivariateNormal
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.likelihoods import MultitaskGaussianLikelihood

class GPModelWithDerivatives(gpytorch.models.ExactGP, GPyTorchModel):
    num_outputs = 2  # to inform GPyTorchModel API (only to interface with BoTorch)
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        self.base_kernel = gpytorch.kernels.RBFKernelGrad(ard_num_dims=1)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)
    
likelihood = MultitaskGaussianLikelihood(num_tasks=2)  # one task for the value, one for the x-derivative
model = GPModelWithDerivatives(train_x, train_y, likelihood)

# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iter = 2 if smoke_test else 500


# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam([
    {'params': model.parameters()},  # Includes GaussianLikelihood parameters
], lr=0.05)

# "Loss" for GPs - the marginal log likelihood
# likelihood.noise_covar.raw_noise_constraint.upper_bound = torch.tensor([1e-6, 1e-6])
likelihood.noise_covar.register_constraint("raw_noise", gpytorch.constraints.LessThan(1e-4) )
likelihood.noise_covar.register_constraint("raw_noise", gpytorch.constraints.GreaterThan(1e-8) )
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()
print(model.likelihood.noise.squeeze())

from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.acquisition.objective import ScalarizedObjective

# Scalarized MES
from typing import Optional

from torch import Tensor

from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.exceptions import UnsupportedError
from botorch.models.model import Model
from botorch.models.utils import check_no_nans
from botorch.sampling.samplers import SobolQMCNormalSampler
from botorch.utils.transforms import match_batch_shape, t_batch_mode_transform

CLAMP_LB = 1.0e-8

class qScalarizedMES(MCAcquisitionFunction):
    r"""The acquisition function for Max-value Entropy Search.

    This acquisition function computes the mutual information of
    max values and a candidate point X. See [Wang2018mves]_ for
    a detailed discussion.

    Multi-output models are handled by scalarizing the outcomes with a
    `weights` tensor. q > 1 is supported through cyclic optimization and
    fantasies.

    Example:
        >>> model = SingleTaskGP(train_X, train_Y)
        >>> candidate_set = torch.rand(1000, bounds.size(1))
        >>> candidate_set = bounds[0] + (bounds[1] - bounds[0]) * candidate_set
        >>> MES = qScalarizedMES(model, candidate_set, weights)
        >>> mes = MES(test_X)
    """

    def __init__(
        self,
        model: Model,
        candidate_set: Tensor,
        weights: Tensor,
        num_fantasies: int = 16,
        num_mv_samples: int = 10,
        num_y_samples: int = 128,
        use_gumbel: bool = True,
        maximize: bool = True,
        X_pending: Optional[Tensor] = None,
    ) -> None:
        r"""Single-outcome max-value entropy search acquisition function.

        Args:
            model: A fitted model.
            candidate_set: A `n x d` Tensor including `n` candidate points to
                discretize the design space. Max values are sampled from the
                (joint) model posterior over these points.
            weights: A one-dimensional Tensor of scalarization weights, one
                per model output.
            num_fantasies: Number of fantasies to generate. The higher this
                number the more accurate the model (at the expense of model
                complexity, wall time and memory). Ignored if `X_pending` is `None`.
            num_mv_samples: Number of max value samples.
            num_y_samples: Number of posterior samples at specific design point `X`.
            use_gumbel: If True, use Gumbel approximation to sample the max values.
            X_pending: A `m x d`-dim Tensor of `m` design points that have been
                submitted for function evaluation but have not yet been evaluated.
            maximize: If True, consider the problem a maximization problem.
        """
        sampler = SobolQMCNormalSampler(num_y_samples)
        super().__init__(model=model, sampler=sampler)

        # Batch GP models (e.g. fantasized models) are not currently supported
        if self.model.train_inputs[0].ndim > 2:
            raise NotImplementedError(
                "Batch GP models (e.g. fantasized models) "
                "are not yet supported by qMaxValueEntropy"
            )

        self._init_model = model  # only used for the `fantasize()` in `set_X_pending()`
        train_inputs = match_batch_shape(model.train_inputs[0], candidate_set)
        self.candidate_set = torch.cat([candidate_set, train_inputs], dim=0)
        self.fantasies_sampler = SobolQMCNormalSampler(num_fantasies)
        self.num_fantasies = num_fantasies
        self.use_gumbel = use_gumbel
        self.num_mv_samples = num_mv_samples
        self.maximize = maximize
        self.weight = 1.0 if maximize else -1.0
        
        self.register_buffer("weights", torch.as_tensor(weights))

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        r"""Compute max-value entropy at the design points `X`.

        Args:
            X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
                with `1` `d`-dim design points each.

        Returns:
            A `batch_shape`-dim Tensor of MVE values at the given design points `X`.
        """
        # Compute the posterior, posterior mean, variance and std
        posterior = self.model.posterior(X.unsqueeze(-3), observation_noise=False)
        mean = self.weight * posterior.mean.squeeze(-1).squeeze(-1)
        # batch_shape x num_fantasies
        variance = posterior.variance.clamp_min(CLAMP_LB).view_as(mean)
        check_no_nans(mean)
        check_no_nans(variance)
        
        posterior = self.model.posterior(X)
        samples = self.sampler(posterior)  # n x b x q x o
        scalarized_samples = samples.matmul(self.weights)  # n x b x q
        scalarized_mean = mean.matmul(self.weights)  # b x q
            
        ig = self._compute_information_gain(
            X=X, mean_M=scalarized_mean, variance_M=variance, covar_mM=variance.unsqueeze(-1)
        )

        return ig.mean(dim=0)  # average over the fantasies
    
    def _compute_information_gain(
        self, X: Tensor, mean_M: Tensor, variance_M: Tensor, covar_mM: Tensor
    ) -> Tensor:
        r"""Computes the information gain at the design points `X`.

        Approximately computes the information gain at the design points `X`,
        for both MES with noisy observations and multi-fidelity MES with noisy
        observation and trace observations.

        The implementation is inspired from the paper on multi-fidelity MES by
        Takeno et. al. [Takeno2019mfmves]_. The notations in the comments in this
        function follows the Appendix A in the paper.

        Args:
            X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
                with `1` `d`-dim design point each.
            mean_M, variance_M: `batch_shape x num_fantasies`-dim Tensors of
                `batch_shape` t-batches with `num_fantasies` fantasies.
                `num_fantasies = 1` for non-fantasized models.
                All are obtained without noise.
            covar_mM: `batch_shape x num_fantasies x (1 + num_trace_observations)`
                -dim Tensor. `num_fantasies = 1` for non-fantasized models.
                All are obtained without noise.

        Returns:
            A `num_fantasies x batch_shape`-dim Tensor of information gains at the
            given design points `X`.
        """

        # compute the std_m, variance_m with noisy observation
        posterior_m = self.model.posterior(X.unsqueeze(-3), observation_noise=True)
        mean_m = self.weight * posterior_m.mean.squeeze(-1)
        # batch_shape x num_fantasies x (1 + num_trace_observations)
        variance_m = posterior_m.mvn.covariance_matrix
        # batch_shape x num_fantasies x (1 + num_trace_observations)^2
        check_no_nans(variance_m)

        # compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
        samples_m = self.weight * self.sampler(posterior_m).squeeze(-1)
        # s_m x batch_shape x num_fantasies x (1 + num_trace_observations)
        L = torch.cholesky(variance_m)
        temp_term = torch.cholesky_solve(covar_mM.unsqueeze(-1), L).transpose(-2, -1)
        # equivalent to torch.matmul(covar_mM.unsqueeze(-2), torch.inverse(variance_m))
        # batch_shape x num_fantasies x 1 x (1 + num_trace_observations)

        mean_pt1 = torch.matmul(temp_term, (samples_m - mean_m).unsqueeze(-1))
        mean_new = mean_pt1.squeeze(-1).squeeze(-1) + mean_M
        # s_m x batch_shape x num_fantasies
        variance_pt1 = torch.matmul(temp_term, covar_mM.unsqueeze(-1))
        variance_new = variance_M - variance_pt1.squeeze(-1).squeeze(-1)
        # batch_shape x num_fantasies
        stdv_new = variance_new.clamp_min(CLAMP_LB).sqrt()
        # batch_shape x num_fantasies

        # define normal distribution to compute cdf and pdf
        normal = torch.distributions.Normal(
            torch.zeros(1, device=X.device, dtype=X.dtype),
            torch.ones(1, device=X.device, dtype=X.dtype),
        )

        # Compute p(fM <= f* | ym, x, Dt)
        view_shape = (
            [self.num_mv_samples] + [1] * (len(X.shape) - 2) + [self.num_fantasies]
        )  # s_M x batch_shape x num_fantasies
        if self.X_pending is None:
            view_shape[-1] = 1
        max_vals = self.posterior_max_values.view(view_shape).unsqueeze(1)
        # s_M x 1 x batch_shape x num_fantasies
        normalized_mvs_new = (max_vals - mean_new) / stdv_new
        # s_M x s_m x batch_shape x num_fantasies =
        # s_M x 1 x batch_shape x num_fantasies - s_m x batch_shape x num_fantasies
        cdf_mvs_new = normal.cdf(normalized_mvs_new).clamp_min(CLAMP_LB)

        # Compute p(fM <= f* | x, Dt)
        stdv_M = variance_M.sqrt()
        normalized_mvs = (max_vals - mean_M) / stdv_M
        # s_M x 1 x batch_shape x num_fantasies =
        # s_M x 1 x 1 x num_fantasies - batch_shape x num_fantasies
        cdf_mvs = normal.cdf(normalized_mvs).clamp_min(CLAMP_LB)
        # s_M x 1 x batch_shape x num_fantasies

        # Compute log(p(ym | x, Dt))
        log_pdf_fm = posterior_m.mvn.log_prob(self.weight * samples_m).unsqueeze(0)
        # 1 x s_m x batch_shape x num_fantasies

        # H0 = H(ym | x, Dt)
        H0 = posterior_m.mvn.entropy()  # batch_shape x num_fantasies

        # regression adjusted H1 estimation, H1_hat = H1_bar - beta * (H0_bar - H0)
        # H1 = E_{f*|x, Dt}[H(ym|f*, x, Dt)]
        Z = cdf_mvs_new / cdf_mvs  # s_M x s_m x batch_shape x num_fantasies
        h1 = -Z * Z.log() - Z * log_pdf_fm  # s_M x s_m x batch_shape x num_fantasies
        check_no_nans(h1)
        dim = [0, 1]  # dimension of fm samples, fM samples
        H1_bar = h1.mean(dim=dim)
        h0 = -log_pdf_fm
        H0_bar = h0.mean(dim=dim)
        cov = ((h1 - H1_bar) * (h0 - H0_bar)).mean(dim=dim)
        beta = cov / (h0.var(dim=dim) * h1.var(dim=dim)).sqrt()
        H1_hat = H1_bar - beta * (H0_bar - H0)
        ig = H0 - H1_hat  # batch_shape x num_fantasies
        ig = ig.permute(-1, *range(ig.dim() - 1))  # num_fantasies x batch_shape
        return ig
    
candidate_set = torch.rand(size=[1000, 1])  # MES requires a candidate set
# The following line raises the UnsupportedError shown above:
qSMES = qScalarizedMES(model, candidate_set=candidate_set, weights=torch.tensor([1., 0.]))

```
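Once the constructor issue is resolved, the acquisition function would be evaluated on t-batched design points. A minimal, hypothetical usage sketch (it cannot run today, since construction currently raises):

```python
# Hypothetical usage once qScalarizedMES constructs successfully:
# evaluate the acquisition value at 5 one-dimensional design points,
# shaped batch_shape x q x d = 5 x 1 x 1, as forward() expects.
test_X = torch.rand(5, 1, 1)
acq_values = qSMES(test_X)  # expected shape: (5,)
```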

System Info

  • BoTorch version: 0.2.5
  • GPyTorch version: 1.1.1
  • PyTorch version: 1.5.0+cpu
  • OS: Windows
