Skip to content

Commit

Permalink
Correct argument order for SGPR added loss term (#1670)
Browse files Browse the repository at this point in the history
[Fixes #1657]
  • Loading branch information
gpleiss authored Jun 23, 2021
1 parent aaba5fb commit cb712b2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gpytorch/mlls/inducing_point_kernel_added_loss_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@


class InducingPointKernelAddedLossTerm(AddedLossTerm):
def __init__(self, variational_dist, prior_dist, likelihood):
def __init__(self, prior_dist, variational_dist, likelihood):
self.prior_dist = prior_dist
self.variational_dist = variational_dist
self.likelihood = likelihood

def loss(self, *params):
prior_covar = self.prior_dist.lazy_covariance_matrix
variational_covar = self.variational_dist.lazy_covariance_matrix
diag = prior_covar.diag() - variational_covar.diag()
diag = variational_covar.diag() - prior_covar.diag()
shape = prior_covar.shape[:-1]
noise_diag = self.likelihood._shaped_noise_covar(shape, *params).diag()
return 0.5 * (diag / noise_diag).sum()

0 comments on commit cb712b2

Please sign in to comment.