
[Question] Proper use of Normalize for inputs (possible bug?) #874

@lukasfro

Issue description

I'm struggling to understand how to properly use input normalization for the SingleTaskGP (presumably the same issue arises for other models). As far as I know, the default kernel hyper-parameter priors are chosen for normalized inputs and standardized targets. Hence, I would expect reasonable posterior predictions when the inputs are normalized, but somewhat unreasonable length scales for unnormalized inputs. However, I observed the exact opposite when running the small code example below. I'm unsure whether I'm using Normalize incorrectly, whether the observed behaviour is intended, or whether I'm completely off track. Any help is much appreciated!
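To make explicit what I expect Normalize to do, here is a minimal sketch (my own illustration, not verified against the implementation): with fixed bounds, the transform should map inputs affinely onto the unit cube, so that the default hyper-parameter priors see inputs in [0, 1]^d.

import torch
from botorch.models.transforms import Normalize

# Assumed behaviour: with fixed bounds, Normalize maps inputs affinely
# from [lower, upper] onto the unit cube [0, 1]^d.
bounds = torch.tensor([[-2.0], [2.0]])
tf = Normalize(d=1, bounds=bounds)
X = torch.tensor([[-2.0], [0.0], [2.0]])
print(tf.transform(X))  # expected: tensor([[0.0000], [0.5000], [1.0000]])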

Code example

import math

import torch
import matplotlib.pyplot as plt

from botorch.models.transforms import Normalize, Standardize
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.optim.fit import fit_gpytorch_scipy

def f(x):
    return torch.sin(math.pi * x)


def train_gp(x, y, input_transform):
    gp = SingleTaskGP(
        train_X=x,
        train_Y=y,
        outcome_transform=Standardize(m=1),
        input_transform=input_transform,
    )
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    mll.train()
    fit_gpytorch_scipy(mll)
    mll.eval()

    return gp


def plot_gp(gp, x_plot, color, label):
    p = gp.posterior(x_plot)
    m = p.mean.squeeze().detach()
    s = p.variance.sqrt().squeeze().detach()

    plt.plot(x_plot, m, color, label=label)
    plt.fill_between(x_plot.squeeze(), m + s, m - s, color=color, alpha=0.3)


# Problem set up
d = 1
bounds = torch.tensor([[-2.0], [2.0]])
train_X = bounds[0] + (bounds[1] - bounds[0]) * torch.rand((10, 1))
train_Y = f(train_X)
train_Y += 0.1 * torch.randn_like(train_Y)
eval_X = torch.linspace(bounds[0].item(), bounds[1].item(), 500).unsqueeze(-1)

# Train 3 models with different input normalization
gp1 = train_gp(train_X, train_Y, Normalize(d=d, bounds=bounds))
gp2 = train_gp(train_X, train_Y, Normalize(d=d, bounds=bounds, transform_on_eval=False))
gp3 = train_gp(train_X, train_Y, None)

# Visualize
plt.figure()
plt.plot(eval_X, f(eval_X), "k")
plt.plot(train_X, train_Y, "ro")
plot_gp(gp1, eval_X, "C0", label="Normalize")
plot_gp(gp2, eval_X, "C1", label="Normalize(transform_on_eval=False)")
plot_gp(gp3, eval_X, "C2", label="No Normalization")
plt.legend()
plt.show()
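To make the claim about unreasonable length scales concrete, a snippet like the following (my addition to the example above; it assumes the default SingleTaskGP covariance module, a ScaleKernel wrapping a Matern kernel) prints the learned length scale of each fitted model:

# Assumes the default SingleTaskGP covariance: ScaleKernel(MaternKernel).
for name, gp in [("Normalize", gp1), ("transform_on_eval=False", gp2), ("None", gp3)]:
    ls = gp.covar_module.base_kernel.lengthscale.detach().squeeze()
    print(f"{name}: lengthscale = {ls.item():.3f}")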

Example Output

[Figure: example_output — the true function, noisy training data, and the posterior mean ± one standard deviation for the three models]

System Info

  • BoTorch Version: 0.5.0
  • GPyTorch Version: 1.5.0
  • PyTorch Version: 1.8.1
  • Computer OS: macOS Big Sur (11.3.1)
