
Why do parameters still change if the linesearch fails? #9

Closed

KeAWang opened this issue Feb 5, 2019 · 2 comments

Comments


KeAWang commented Feb 5, 2019

Using the example code for GP regression (updated to use the master branch of gpytorch) and manually setting max_ls=1 with lr=0.1 to force a line-search failure:

import math
import sys

import torch
import gpytorch
from matplotlib import pyplot as plt

sys.path.append('../../../PyTorch-LBFGS/functions/')
from LBFGS import FullBatchLBFGS

# Training data is 100 points in [0,1], regularly spaced
train_x = torch.linspace(0, 1, 100)
# True function is sin(2*pi*x) with Gaussian noise
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use full-batch L-BFGS optimizer
optimizer = FullBatchLBFGS(model.parameters(), lr=0.1)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

# define closure
def closure():
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    return loss

loss = closure()
loss.backward()

training_iter = 10
for i in range(training_iter):

    # perform step and update curvature
    options = {'closure': closure, 'current_loss': loss, 'max_ls': 1}
    loss, _, lr, _, F_eval, G_eval, _, fail = optimizer.step(options)

    print('Iter %d/%d - Loss: %.16f - LR: %.3f - Func Evals: %0.0f - Grad Evals: %0.0f - Raw-Lengthscale: %.16f - Raw_Noise: %.16f - Fail: %s' % (
        i + 1, training_iter, loss.item(), lr, F_eval, G_eval,
        model.covar_module.base_kernel.raw_lengthscale.item(),
        model.likelihood.raw_noise.item(),
        fail
        ))

I get the following output:

Iter 1/10 - Loss: 0.9470608234405518 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000004039440 - Raw_Noise: 0.0000000014626897 - Fail: True
Iter 2/10 - Loss: 0.9470605254173279 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000004354150 - Raw_Noise: 0.0000000029325762 - Fail: True
Iter 3/10 - Loss: 0.9470604658126831 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000005869125 - Raw_Noise: 0.0000000007250947 - Fail: True
Iter 4/10 - Loss: 0.9470607042312622 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000006170258 - Raw_Noise: 0.0000000007257976 - Fail: True
Iter 5/10 - Loss: 0.9470604062080383 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000004235302 - Raw_Noise: 0.0000000022039979 - Fail: True
Iter 6/10 - Loss: 0.9470605254173279 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000002037677 - Raw_Noise: 0.0000000036952374 - Fail: True
Iter 7/10 - Loss: 0.9470604062080383 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000005822218 - Raw_Noise: 0.0000000052020050 - Fail: True
Iter 8/10 - Loss: 0.9470604062080383 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000008010645 - Raw_Noise: 0.0000000066766215 - Fail: True
Iter 9/10 - Loss: 0.9470604658126831 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000005968881 - Raw_Noise: 0.0000000059220691 - Fail: True
Iter 10/10 - Loss: 0.9470604658126831 - LR: 0.000 - Func Evals: 3 - Grad Evals: 2 - Raw-Lengthscale: -0.0000000002212988 - Raw_Noise: 0.0000000066797297 - Fail: True

Clearly the line search is failing and the lr is set to 0. So why are the parameters still changing?

hjmshi (Owner) commented Feb 5, 2019

Hi Alex,

Is the inplace option set to True? (It is on by default.)

If the inplace option is set to True, then the algorithm only tracks the search direction and the current and previous steplengths, so it attempts to recover the original iterate by applying the update -alpha_k * p_k to undo the trial step. Because floating-point arithmetic is not exactly invertible, this can introduce numerical error and leave the recovered iterate slightly perturbed from the original. (This was designed with neural networks in mind, where storing and constantly reloading the original set of parameters may not be ideal.)
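
Here is a minimal standalone sketch (mine, not from the repo) of why an in-place undo does not exactly restore the original parameters in floating point:

import torch

torch.manual_seed(0)

x = torch.randn(5)       # original iterate
p = torch.randn(5)       # search direction
alpha = 0.1              # trial steplength

x_orig = x.clone()       # what inplace=False would store and reload
x.add_(p, alpha=alpha)   # take the trial step in place: x <- x + alpha * p
x.add_(p, alpha=-alpha)  # undo it in place:             x <- x - alpha * p

print(torch.equal(x, x_orig))    # typically False
print((x - x_orig).abs().max())  # small but nonzero rounding error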

If you set inplace to False, the algorithm instead stores the original parameters (the current iterate) and reloads them at every line-search iteration, so a failed line search leaves the parameters exactly where they started.

Let me know if this resolves your issue!


KeAWang commented Feb 6, 2019

Ah I see. Thank you so much for your clarification! Indeed the parameters no longer change if inplace is False.
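
For reference, a minimal sketch of the fix, assuming inplace is passed through the same options dict as max_ls:

# ask the line search to store and reload the original iterate
# instead of undoing trial steps arithmetically
options = {'closure': closure, 'current_loss': loss,
           'max_ls': 1, 'inplace': False}
loss, _, lr, _, F_eval, G_eval, _, fail = optimizer.step(options)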

hjmshi closed this as completed Feb 6, 2019