136 changes: 45 additions & 91 deletions botorch/models/pairwise_gp.py
@@ -31,12 +31,11 @@
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from gpytorch import settings
from gpytorch.constraints import GreaterThan, Positive
from gpytorch.constraints import Positive
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels.rbf_kernel import RBFKernel
from gpytorch.lazy.added_diag_lazy_tensor import AddedDiagLazyTensor
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.lazy.lazy_tensor import LazyTensor
from gpytorch.likelihoods.noise_models import HomoskedasticNoise
from gpytorch.means.constant_mean import ConstantMean
from gpytorch.mlls import MarginalLogLikelihood
from gpytorch.models.gp import GP
@@ -53,29 +52,31 @@ class PairwiseGP(Model, GP):

Implementation is based on [Chu2005preference]_.
Also see [Brochu2010tutorial]_ for additional reference.

Note that in [Chu2005preference]_ the likelihood of a pairwise comparison
is :math:`\Phi\left(\frac{f(x_1) - f(x_2)}{\sqrt{2}\sigma}\right)`, i.e., a noise
scale :math:`\sigma` is used in the denominator. To maintain consistency with
the usage of kernels elsewhere in botorch, we instead omit :math:`\sigma` in the
code (implicitly setting it to 1) and use a ScaleKernel to scale the function.
"""

def __init__(
self,
datapoints: Tensor,
comparisons: Tensor,
covar_module: Optional[Module] = None,
noise_module: Optional[HomoskedasticNoise] = None,
**kwargs,
) -> None:
super().__init__()
r"""A probit-likelihood GP with Laplace approximation model.

A probit-likelihood GP with Laplace approximation model that learns via
pairwise comparison data. By default it uses a scaled-RBF kernel.
r"""A probit-likelihood GP with Laplace approximation model that learns via
pairwise comparison data. By default it uses a scaled RBF kernel.

Args:
datapoints: A `batch_shape x n x d` tensor of training features.
comparisons: A `batch_shape x m x 2` tensor of training comparisons;
comparisons[i] is a noisy indicator suggesting that the utility value
of the comparisons[i, 0]-th datapoint is greater than that of the
comparisons[i, 1]-th.
covar_module: Covariance module
noise_module: Noise module
"""

# Compatibility variables with fit_gpytorch_*: Dummy likelihood
@@ -132,26 +133,24 @@ def __init__(
param.requires_grad = False

# set covariance module
if noise_module is None:
noise_module = HomoskedasticNoise(
noise_prior=SmoothedBoxPrior(-5, 5, 0.5, transform=torch.log),
noise_constraint=GreaterThan(1e-4), # if None, 1e-4 by default
batch_shape=self._input_batch_shape,
)
self.noise_module = noise_module

# set covariance module
# the default outputscale here is only a rule of thumb, meant to keep
# estimates away from scale values that would make Phi(f(x)) saturate
# at 0 or 1
if covar_module is None:
ls_prior = GammaPrior(1.2, 0.5)
ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate
covar_module = RBFKernel(
batch_shape=self._input_batch_shape,
ard_num_dims=self.dim,
lengthscale_prior=ls_prior,
lengthscale_constraint=Positive(
transform=None, initial_value=ls_prior_mode
covar_module = ScaleKernel(
RBFKernel(
batch_shape=self._input_batch_shape,
ard_num_dims=self.dim,
lengthscale_prior=ls_prior,
lengthscale_constraint=Positive(
transform=None, initial_value=ls_prior_mode
),
),
outputscale_prior=SmoothedBoxPrior(a=1, b=4),
)

self.covar_module = covar_module

self._x0 = None # will store temporary results for warm-starting
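For downstream users, a usage sketch of the new default: the saturation of :math:`\Phi` is now controlled through the kernel's outputscale rather than a noise module. `train_X` and `train_comp` are illustrative names, and the prior below simply mirrors the default rather than being required:

```python
import torch
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.priors import SmoothedBoxPrior

from botorch.models.pairwise_gp import PairwiseGP

train_X = torch.rand(4, 2)                   # 4 items with 2 features (made up)
train_comp = torch.tensor([[0, 1], [2, 3]])  # item 0 > item 1, item 2 > item 3

# Passing a custom scaled kernel overrides the default constructed above.
covar_module = ScaleKernel(
    RBFKernel(ard_num_dims=train_X.shape[-1]),
    outputscale_prior=SmoothedBoxPrior(a=1, b=4),
)
model = PairwiseGP(train_X, train_comp, covar_module=covar_module)
```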
@@ -191,14 +190,6 @@ def __deepcopy__(self, memo) -> PairwiseGP:
self.__deepcopy__ = dcp
return new_model

@property
def std_noise(self) -> Tensor:
return self.noise_module.noise

@std_noise.setter
def std_noise(self, value: Tensor) -> None:
self.noise_module.initialize(noise=value)

@property
def num_outputs(self) -> int:
r"""The number of outputs of the model."""
@@ -212,16 +203,10 @@ def _has_no_data(self):
or self.comparisons is None
)

def _calc_covar(
self, X1: Tensor, X2: Tensor, observation_noise: bool = False
) -> Union[Tensor, LazyTensor]:
def _calc_covar(self, X1: Tensor, X2: Tensor) -> Union[Tensor, LazyTensor]:
r"""Calculate the covariance matrix given two sets of datapoints"""
X1, X2 = X1, X2
covar = self.covar_module(X1, X2)
if observation_noise:
noise_shape = self._input_batch_shape + self.covar.shape[-1:]
noise = self.noise_module(shape=noise_shape)
covar = AddedDiagLazyTensor(covar, noise)
return covar.evaluate()

def _batch_chol_inv(self, mat_chol: Tensor) -> Tensor:
@@ -302,7 +287,7 @@ def _add_jitter(self, X: Tensor) -> Tensor:
return X

def _calc_z(
self, utility: Tensor, D: Tensor, std_noise: Tensor
self, utility: Tensor, D: Tensor
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
r"""Calculate z score.

@@ -311,7 +296,6 @@ def _calc_z(
Args:
utility: A Tensor of shape `batch_size x n`, the utility at MAP point
D: as in self.D
std_noise: comparison noise (as a Tensor since it's a pytorch param)

Returns:
z: z score calculated as in [Chu2005preference]_.
@@ -320,7 +304,7 @@ def _calc_z(
hazard: hazard function defined as pdf(z)/cdf(z)
"""

scaled_util = (utility / (math.sqrt(2) * std_noise)).unsqueeze(-1).to(D)
scaled_util = (utility / math.sqrt(2)).unsqueeze(-1).to(D)
z = (D @ scaled_util).squeeze(-1)
std_norm = torch.distributions.normal.Normal(
torch.zeros(1, dtype=z.dtype, device=z.device),
@@ -336,9 +320,7 @@ def _calc_z(
hazard = torch.exp(z_logpdf - z_logcdf)
return z, z_logpdf, z_logcdf, hazard
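A small numeric sketch of what `_calc_z` computes per comparison now that the noise scale is fixed at 1 (the tensors below are made up):

```python
import math

import torch

utility = torch.tensor([1.0, 0.2])  # MAP utilities for two items
D = torch.tensor([[1.0, -1.0]])     # a single comparison: item 0 > item 1

z = (D @ (utility / math.sqrt(2)).unsqueeze(-1)).squeeze(-1)
std_norm = torch.distributions.Normal(torch.zeros(1), torch.ones(1))
z_logpdf = std_norm.log_prob(z)
z_logcdf = torch.log(std_norm.cdf(z))
hazard = torch.exp(z_logpdf - z_logcdf)  # pdf(z) / cdf(z)
```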

def _grad_likelihood_f_sum(
self, utility: Tensor, D: Tensor, std_noise: Tensor
) -> Tensor:
def _grad_likelihood_f_sum(self, utility: Tensor, D: Tensor) -> Tensor:
r"""Compute the sum over of grad. of negative Log-LH wrt utility f.
Original grad should be of dimension m x n, as in (6) from [Chu2005preference]_.
Sum over the first dimension and return a tensor of shape n
@@ -347,20 +329,17 @@ def _grad_likelihood_f_sum(
Args:
utility: A Tensor of shape `batch_size x n`
D: A Tensor of shape `batch_size x m x n` as in self.D
std_noise: A Tensor of shape `batch_size x 1`, as in self.std_noise

Returns:
The sum over the first dimension of the gradient of the negative
log-likelihood wrt utility f.
"""
_, _, _, h = self._calc_z(utility, D, std_noise)
h_factor = (h / (math.sqrt(2) * std_noise)).unsqueeze(-2)
_, _, _, h = self._calc_z(utility, D)
h_factor = (h / math.sqrt(2)).unsqueeze(-2)
grad = (h_factor @ (-D)).squeeze(-2)

return grad
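One way to sanity-check the simplified closed form is against autograd on the negative log-likelihood itself; a rough standalone sketch with made-up tensors, mirroring the math above rather than calling the private method:

```python
import math

import torch

utility = torch.tensor([1.0, 0.2], requires_grad=True)
D = torch.tensor([[1.0, -1.0]])

std_norm = torch.distributions.Normal(torch.zeros(1), torch.ones(1))
z = (D @ (utility / math.sqrt(2)).unsqueeze(-1)).squeeze(-1)
neg_log_lik = -torch.log(std_norm.cdf(z)).sum()
neg_log_lik.backward()

# utility.grad should match (h / sqrt(2)) @ (-D) summed over comparisons,
# i.e. the value returned by _grad_likelihood_f_sum without std_noise.
```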

def _hess_likelihood_f_sum(
self, utility: Tensor, D: Tensor, DT: Tensor, std_noise: Tensor
) -> Tensor:
def _hess_likelihood_f_sum(self, utility: Tensor, D: Tensor, DT: Tensor) -> Tensor:
r"""Compute the sum over of hessian of neg. Log-LH wrt utility f.

Original hess should be of dimension m x n x n, as in (7) from
@@ -371,13 +350,12 @@ def _hess_likelihood_f_sum(
utility: A Tensor of shape `batch_size x n`
D: A Tensor of shape `batch_size x m x n` as in self.D
DT: Transpose of D. A Tensor of shape `batch_size x n x m` as in self.DT
std_noise: A Tensor of shape `batch_size x 1`, as in self.std_noise

Returns:
The sum over the first dimension of the Hessian of the negative
log-likelihood wrt utility f.
"""
z, _, _, h = self._calc_z(utility, D, std_noise)
mul_factor = h * (h + z) / (2 * (std_noise ** 2))
z, _, _, h = self._calc_z(utility, D)
mul_factor = h * (h + z) / 2
weighted_DT = DT * mul_factor.unsqueeze(-2).expand(*DT.size())
hess = weighted_DT @ D
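In matrix form, this change simply drops the :math:`\sigma^2` factor from the per-comparison weights; with :math:`h(z) = \phi(z)/\Phi(z)`, the Hessian assembled here is

```latex
\nabla_f^2\bigl(-\log p(\mathcal{D} \mid f)\bigr)
  = D^\top \operatorname{diag}(w)\, D,
\qquad
w_j = \frac{h(z_j)\,\bigl(h(z_j) + z_j\bigr)}{2}.
```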

@@ -389,7 +367,6 @@ def _grad_posterior_f(
datapoints: Tensor,
D: Tensor,
DT: Tensor,
std_noise: Tensor,
covar_chol: Tensor,
covar_inv: Tensor,
ret_np: bool = False,
@@ -405,7 +382,6 @@
datapoints: A Tensor of shape `batch_size x n x d` as in self.datapoints
D: A Tensor of shape `batch_size x m x n` as in self.D
DT: Transpose of D. A Tensor of shape `batch_size x n x m` as in self.DT
std_noise: A Tensor of shape `batch_size x 1`, as in self.std_noise
covar_chol: A Tensor of shape `batch_size x n x n`, as in self.covar_chol
covar_inv: A Tensor of shape `batch_size x n x n`, as in self.covar_inv
ret_np: return a numpy array if true, otherwise a Tensor
@@ -416,7 +392,7 @@
utility = torch.tensor(utility, dtype=self.datapoints.dtype)
prior_mean = prior_mean.cpu()

b = self._grad_likelihood_f_sum(utility, D, std_noise)
b = self._grad_likelihood_f_sum(utility, D)

# g_ = covar_inv x (utility - pred_prior)
p = (utility - prior_mean).unsqueeze(-1).to(covar_chol)
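For reference, with the noise scale gone, the quantity assembled here is the gradient of the negative log posterior at :math:`f`,

```latex
\nabla_f\bigl(-\log p(f \mid \mathcal{D})\bigr)
  = K^{-1}(f - \mu) \;-\; \sum_{j=1}^{m} \frac{h(z_j)}{\sqrt{2}}\, D_j^\top ,
```

where :math:`\mu` is the prior mean and the second term (with its sign) is the `b` returned by `_grad_likelihood_f_sum`.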
@@ -434,7 +410,6 @@ def _hess_posterior_f(
datapoints: Tensor,
D: Tensor,
DT: Tensor,
std_noise: Tensor,
covar_chol: Tensor,
covar_inv: Tensor,
ret_np: bool = False,
@@ -450,15 +425,14 @@
datapoints: A Tensor of shape `batch_size x n x d` as in self.datapoints
D: A Tensor of shape `batch_size x m x n` as in self.D
DT: Transpose of D. A Tensor of shape `batch_size x n x m` as in self.DT
std_noise: A Tensor of shape `batch_size x 1`, as in self.std_noise
covar_chol: A Tensor of shape `batch_size x n x n`, as in self.covar_chol
covar_inv: A Tensor of shape `batch_size x n x n`, as in self.covar_inv
ret_np: return a numpy array if true, otherwise a Tensor
"""
if ret_np:
utility = torch.tensor(utility, dtype=self.datapoints.dtype)

hl = self._hess_likelihood_f_sum(utility, D, DT, std_noise)
hl = self._hess_likelihood_f_sum(utility, D, DT)
hess = hl + covar_inv
return hess.numpy() if ret_np else hess

@@ -470,7 +444,7 @@ def _posterior_f(self, utility: Union[Tensor, np.ndarray]) -> Tensor:
Args:
utility: A Tensor of shape `batch_size x n`
"""
_, _, z_logcdf, _ = self._calc_z(utility, self.D, self.std_noise)
_, _, z_logcdf, _ = self._calc_z(utility, self.D)
loss1 = -(torch.sum(z_logcdf, dim=-1))
inv_prod = torch.cholesky_solve(utility.unsqueeze(-1), self.covar_chol)
loss2 = 0.5 * (utility.unsqueeze(-2) @ inv_prod).squeeze(-1).squeeze(-1)
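With :math:`\sigma` fixed at 1 and the prior mean taken as zero (as the `loss2` term above does), the objective minimized at the MAP point is, up to an additive constant,

```latex
-\log p(f \mid \mathcal{D})
  = -\sum_{j=1}^{m} \log \Phi(z_j) \;+\; \frac{1}{2}\, f^\top K^{-1} f,
\qquad
z_j = \frac{(D f)_j}{\sqrt{2}} ,
```

whose two terms correspond to `loss1` and `loss2`.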
@@ -536,23 +510,11 @@ def _update(self, **kwargs) -> None:
dp_v = self.datapoints.view(-1, self.n, self.dim).cpu()
D_v = self.D.view(-1, self.m, self.n).cpu()
DT_v = self.DT.view(-1, self.n, self.m).cpu()
# Use `expand` here since we need to expand std_noise along
# the batch shape dimensions if we start off as non-batch model,
# but later conditioned on batched new data
sn_v = self.std_noise.expand(*init_x0_size[:-1], 1).reshape(-1).cpu()
ch_v = self.covar_chol.view(-1, self.n, self.n).cpu()
ci_v = self.covar_inv.view(-1, self.n, self.n).cpu()
x = np.empty(x0.shape)
for i in range(x0.shape[0]):
fsolve_args = (
dp_v[i],
D_v[i],
DT_v[i],
sn_v[i],
ch_v[i],
ci_v[i],
True,
)
fsolve_args = (dp_v[i], D_v[i], DT_v[i], ch_v[i], ci_v[i], True)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
x[i] = optimize.fsolve(
@@ -571,7 +533,6 @@ def _update(self, **kwargs) -> None:
self.datapoints.cpu(),
self.D.cpu(),
self.DT.cpu(),
self.std_noise.cpu(),
self.covar_chol.cpu(),
self.covar_inv.cpu(),
True,
@@ -597,9 +558,7 @@ def _update(self, **kwargs) -> None:
# when calling forward() in order to obtain correct gradients
# self.likelihood_hess is updated here is for the rare case where we
# do not want to call forward()
self.likelihood_hess = self._hess_likelihood_f_sum(
f, self.D, self.DT, self.std_noise
)
self.likelihood_hess = self._hess_likelihood_f_sum(f, self.D, self.DT)

# Lazy update hlcov_eye, which is used in calculating posterior during training
self.pred_cov_fac_need_update = True
@@ -655,11 +614,10 @@ def _util_newton_updates(self, x0, max_iter=1, xtol=None) -> Tensor:
finishing `max_iter` updates
"""
xtol = float("-Inf") if xtol is None else xtol
dp, D, DT, sn, ch, ci = (
dp, D, DT, ch, ci = (
self.datapoints,
self.D,
self.DT,
self.std_noise,
self.covar_chol,
self.covar_inv,
)
@@ -669,7 +627,7 @@ def _util_newton_updates(self, x0, max_iter=1, xtol=None) -> Tensor:
x = x0
eye = None
while i < max_iter and diff > xtol:
hl = self._hess_likelihood_f_sum(x, D, DT, sn)
hl = self._hess_likelihood_f_sum(x, D, DT)
cov_hl = covar @ hl
if eye is None:
eye = torch.eye(
Expand All @@ -678,7 +636,7 @@ def _util_newton_updates(self, x0, max_iter=1, xtol=None) -> Tensor:
device=self.datapoints.device,
).expand(cov_hl.shape)
cov_hl = cov_hl + eye # add 1 to cov_hl
g = self._grad_posterior_f(x, dp, D, DT, sn, ch, ci)
g = self._grad_posterior_f(x, dp, D, DT, ch, ci)
cov_g = covar @ g.unsqueeze(-1)
x_update = torch.solve(cov_g, cov_hl).solution.squeeze(-1)
x_next = x - x_update
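A standalone sketch of one such Newton step with made-up tensors, using `torch.linalg.solve` in place of the deprecated `torch.solve` shown in the diff:

```python
import torch

K = torch.eye(3)                    # prior covariance at training points (made up)
H = 0.5 * torch.eye(3)              # Hessian of the negative log-likelihood
g = torch.tensor([0.3, -0.1, 0.2])  # gradient of the negative log posterior
x = torch.zeros(3)                  # current utility iterate

# x_next = x - (K @ H + I)^{-1} (K @ g) is algebraically equal to
# x - (H + K^{-1})^{-1} g, but avoids explicitly inverting K.
x_update = torch.linalg.solve(K @ H + torch.eye(3), K @ g)
x_next = x - x_update
```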
@@ -789,7 +747,7 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
self.utility = self._util_newton_updates(self.utility, max_iter=1)

hl = self.likelihood_hess = self._hess_likelihood_f_sum(
self.utility, self.D, self.DT, self.std_noise
self.utility, self.D, self.DT
)
covar = self.covar
# Apply matrix inversion lemma on eq. in page 27 of [Brochu2010tutorial]_
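The identity referenced in the comment is, assuming the likelihood Hessian :math:`H` is invertible,

```latex
\bigl(K^{-1} + H\bigr)^{-1} = K - K\,\bigl(K + H^{-1}\bigr)^{-1} K ,
```

which lets the Laplace-approximated posterior covariance be formed from solves against well-conditioned matrices instead of explicitly inverting :math:`K`.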
@@ -883,11 +841,12 @@ def posterior(
X: A `batch_shape x q x d`-dim Tensor, where `d` is the dimension
of the feature space and `q` is the number of points considered jointly.
output_indices: As defined in parent Model class, not used for this model.
observation_noise: If True, add observation noise to the posterior.
observation_noise: Ignored (since noise is not identifiable from scale
in probit models).

Returns:
A `Posterior` object, representing joint
distributions over `q` points. Includes observation noise if specified.
distributions over `q` points.
"""
self.eval() # make sure model is in eval mode

@@ -899,10 +858,6 @@

post = self(X)

if observation_noise:
noise_module = self.noise_module(shape=post.mean.shape).evaluate()
post = MultivariateNormal(post.mean, post.covariance_matrix + noise_module)

return GPyTorchPosterior(post)
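End-to-end, the updated `posterior` behaves as in the sketch below; the names are illustrative, and `observation_noise` is now accepted but ignored:

```python
import torch

from botorch.models.pairwise_gp import PairwiseGP

train_X = torch.rand(5, 2)                   # 5 items, 2 features (made up)
train_comp = torch.tensor([[0, 1], [2, 3]])  # pairwise winners on the left

model = PairwiseGP(train_X, train_comp)
post = model.posterior(torch.rand(3, 2))     # posterior over latent utilities
mean, variance = post.mean, post.variance
```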

def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
@@ -918,8 +873,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:

Returns:
A (deepcopied) `Model` object of the same type, representing the
original model conditioned on the new observations `(X, Y)` (and
possibly noise observations passed in via kwargs).
original model conditioned on the new observations `(X, Y)`.
"""
new_model = deepcopy(self)
if self._has_no_data():