Closed
Description
Issue description
I'm struggling to understand how to properly use input normalization with `SingleTaskGP` (presumably the same issue arises for other models). As far as I know, the default kernel hyperparameters are tuned for normalized inputs and standardized targets. Hence, I would expect reasonable posterior predictions when the inputs are normalized, and somewhat unreasonable length scales when they are not. However, I observed the exact opposite when running the small code example below. I'm unsure whether I'm using `Normalize` incorrectly, whether the observed behaviour is intended, or whether I'm completely off track. Any help is much appreciated!
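The expectation above rests on simple arithmetic. Below is a minimal, library-free sketch, assuming that `Normalize` with explicit `bounds` maps each input linearly from `[lower, upper]` to `[0, 1]` and that `Standardize` rescales targets to zero mean and unit standard deviation. The helpers `min_max_normalize`, `standardize` and `rbf` are hypothetical illustrations, not BoTorch functions. The last lines check that a squared-exponential kernel with lengthscale ℓ on unit-scaled inputs gives the same value as one with lengthscale ℓ·(upper − lower) on raw inputs, which is why a lengthscale that looks reasonable in one space can look very different in the other.

```python
import math
import statistics


def min_max_normalize(xs, lower, upper):
    """Hypothetical helper: map values linearly from [lower, upper] to [0, 1]."""
    return [(x - lower) / (upper - lower) for x in xs]


def standardize(ys):
    """Hypothetical helper: shift to zero mean, scale to unit standard deviation.

    Assumption: uses the sample (n - 1) standard deviation.
    """
    mean = statistics.fmean(ys)
    std = statistics.stdev(ys)
    return [(y - mean) / std for y in ys]


def rbf(a, b, lengthscale):
    """Squared-exponential kernel exp(-(a - b)^2 / (2 * lengthscale^2))."""
    return math.exp(-((a - b) ** 2) / (2 * lengthscale**2))


lower, upper = -2.0, 2.0
raw = [-2.0, -0.5, 1.0, 2.0]
unit = min_max_normalize(raw, lower, upper)
print(unit)  # [0.0, 0.375, 0.75, 1.0]

z = standardize([0.1, -0.4, 0.9, 0.3])
print(statistics.fmean(z), statistics.stdev(z))  # approximately 0 and 1

# A lengthscale on [0, 1] inputs corresponds to that value times the
# bound width on raw inputs: 0.25 * (2.0 - (-2.0)) = 1.0.
ell_unit = 0.25
ell_raw = ell_unit * (upper - lower)
print(rbf(unit[1], unit[2], ell_unit) == rbf(raw[1], raw[2], ell_raw))  # True
```

For the f(x) = sin(πx) used in this issue, the period of 2 on [−2, 2] becomes 0.5 on the unit interval.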
Code example
```python
import torch
import matplotlib.pyplot as plt
from botorch.models.transforms import Normalize, Standardize
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.fit import fit_gpytorch_scipy
import math


def f(x):
    return torch.sin(math.pi * x)


def train_gp(x, y, input_transform):
    gp = SingleTaskGP(
        train_X=x,
        train_Y=y,
        outcome_transform=Standardize(m=1),
        input_transform=input_transform,
    )
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    mll.train()
    fit_gpytorch_scipy(mll)
    mll.eval()
    return gp


def plot_gp(gp, x_plot, color, label):
    p = gp.posterior(x_plot)
    m = p.mean.squeeze().detach()
    s = p.variance.sqrt().squeeze().detach()
    plt.plot(x_plot, m, color, label=label)
    plt.fill_between(x_plot.squeeze(), m + s, m - s, color=color, alpha=0.3)


# Problem setup
d = 1
bounds = torch.tensor([[-2.0], [2.0]])
train_X = bounds[0] + (bounds[1] - bounds[0]) * torch.rand((10, 1))
train_Y = f(train_X)
train_Y += 0.1 * torch.randn_like(train_Y)
eval_X = torch.linspace(bounds[0].item(), bounds[1].item(), 500).unsqueeze(-1)

# Train 3 models with different input normalization
gp1 = train_gp(train_X, train_Y, Normalize(d=d, bounds=bounds))
gp2 = train_gp(train_X, train_Y, Normalize(d=d, bounds=bounds, transform_on_eval=False))
gp3 = train_gp(train_X, train_Y, None)

# Visualize
plt.figure()
plt.plot(eval_X, f(eval_X), "k")
plt.plot(train_X, train_Y, "ro")
plot_gp(gp1, eval_X, "C0", label="Normalize")
plot_gp(gp2, eval_X, "C1", label="Normalize(transform_on_eval=False)")
plot_gp(gp3, eval_X, "C2", label="No Normalization")
plt.legend()
plt.show()
```
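To make the difference between `gp1` and `gp2` concrete, here is a minimal, library-free sketch. It assumes that with `transform_on_eval=False` the model conditions on normalized training inputs while the points passed to `posterior()` are used without normalization; `to_unit` and `rbf` are hypothetical helpers, and the lengthscale is an illustrative value rather than a fitted one.

```python
import math

LOWER, UPPER = -2.0, 2.0


def to_unit(x):
    """Hypothetical helper: min-max map from [LOWER, UPPER] to [0, 1]."""
    return (x - LOWER) / (UPPER - LOWER)


def rbf(a, b, lengthscale):
    """Squared-exponential kernel exp(-(a - b)^2 / (2 * lengthscale^2))."""
    return math.exp(-((a - b) ** 2) / (2 * lengthscale**2))


ell = 0.25     # illustrative lengthscale in normalized units
train_x = 0.0  # raw training input
test_x = 0.0   # the same location, queried at evaluation time

# Both inputs normalized: the kernel sees identical points.
consistent = rbf(to_unit(train_x), to_unit(test_x), ell)

# Assumption for transform_on_eval=False: training input normalized,
# evaluation input left in raw units, so 0.5 is compared against 0.0.
mixed = rbf(to_unit(train_x), test_x, ell)

print(consistent, mixed)  # 1.0 vs. exp(-2), roughly 0.135
```

Under this assumption, a query at a training location would be treated as if it were two lengthscales away from that training point.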
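Example Output
[Figure: posterior mean ± one standard deviation for gp1, gp2 and gp3, plotted with f(x) and the training points]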
System Info
Please provide information about your setup, including
- BoTorch Version: 0.5.0
- GPyTorch Version: 1.5.0
- PyTorch Version: 1.8.1
- Computer OS: macOS Big Sur (11.3.1)
