
[Feature Request] Input Transformations #1652

Open
Balandat opened this issue Jun 11, 2021 · 5 comments

@Balandat
Collaborator

🚀 Feature Request

Essentially, upstream BoTorch's input transforms (https://github.com/pytorch/botorch/blob/master/botorch/models/transforms/input.py) to GPyTorch. This would allow adding either fixed or learnable transformations that are automatically applied when training models.

Motivation

This makes it possible to do things like normalize inputs, but also to combine the GP with a learnable transformation, which simplifies model setup. We currently have this in BoTorch, where we essentially apply the transform in the forward methods.

Additional context

We recently worked on input transformations that can change the shape of the input (pytorch/botorch#819), which caused some headaches around how to best set this up without a ton of boilerplate code. We were hoping to do this in __call__ rather than in the forward method, but this collides with some of GPyTorch's assumptions. Moving this functionality upstream into gpytorch would allow us to solve these challenges more organically.

Describe alternatives you've considered
You could keep doing this as we do it right now, but that requires adding boilerplate transformation code to every implementation of forward.

@Balandat
Collaborator Author

cc @sdaulton @saitcakmak

@saitcakmak
Collaborator

saitcakmak commented Jun 11, 2021

I'll just add concrete examples of what we tried and where each approach fails.

Apply one-to-many transforms in model.forward(): for a batch x q x d-dim input, these return a batch x new_q x m-dim output. This fails at the reshape operation here:

train_labels = train_labels.reshape(*train_labels.shape[: -len(self.train_shape)], self._train_shape.numel())
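
To make the shape problem concrete, here is a minimal, hypothetical one-to-many transform (the RepeatPointsTransform name and its k-copies behavior are made up for this sketch, not BoTorch API) showing how the input grows while the targets do not:

import torch

class RepeatPointsTransform(torch.nn.Module):
    """Hypothetical one-to-many transform: maps each input point to k
    perturbed copies, so a ... x q x d input becomes a ... x (k * q) x d output."""

    def __init__(self, k=3, scale=0.01):
        super().__init__()
        self.k = k
        self.scale = scale

    def forward(self, X):
        # repeat each point k times along the q dimension, then jitter
        X_rep = X.repeat_interleave(self.k, dim=-2)
        return X_rep + self.scale * torch.randn_like(X_rep)

train_X = torch.randn(4, 3)                    # n x d training inputs
print(RepeatPointsTransform()(train_X).shape)  # torch.Size([12, 3])
# if this runs inside model.forward, the train inputs grow from n to k * n
# points while train_labels keeps n entries, so the reshape above no longer
# lines up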

Define a wrapper around ExactGP.__call__ and apply the transforms before calling ExactGP.__call__: this leads to

raise RuntimeError("You must train on the training inputs!")

We could maybe get around this by wrapping the call in a with debug(False) block, but that also breaks tests, so it is probably not a good idea.
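
For reference, a rough sketch of that wrapper idea (the TransformedExactGP class and its input_transform attribute are assumptions for illustration, not an existing API); gpytorch.settings.debug(False) silences the training-inputs check but also disables other validation:

import gpytorch

class TransformedExactGP(gpytorch.models.ExactGP):
    """Hypothetical sketch: transform inputs before delegating to ExactGP.__call__."""

    def __init__(self, train_x, train_y, likelihood, input_transform):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.input_transform = input_transform

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

    def __call__(self, *args, **kwargs):
        args = [self.input_transform(a) for a in args]
        # debug(False) bypasses the "You must train on the training inputs!"
        # check, but it also turns off other sanity checks (and breaks tests)
        with gpytorch.settings.debug(False):
            return super().__call__(*args, **kwargs)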

The current proposal at pytorch/botorch#819 is to apply the transforms in model.forward at training time and in the posterior call in eval mode to get around these issues (and to do this in each model's forward and each posterior). This is admittedly not a good design, so upstreaming the input transforms would make it much cleaner.
Edit: Just realized that this actually breaks things. So, we don't have a proper way of applying one-to-many transforms currently.
Edit 2: With some changes to input transforms (storing train inputs as transformed), applying them in both model.forward and posterior now works.

@gpleiss
Member

gpleiss commented Jun 14, 2021

@Balandat and @saitcakmak, is this the interface you have in mind?

class MyGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        # ...
        self.transform = InputTransform(...)

    def forward(self, x):
        x = self.transform(x)
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

@saitcakmak
Collaborator

So, that is the current (pre pytorch/botorch#819) interface on the BoTorch end, and there are some issues with it. It does not play nicely with transforms that change the shape of x. We also want certain transforms to apply to all inputs, including train_inputs, and others to apply only to test inputs. This interface doesn't work well for the latter, since forward is called after the test inputs have been concatenated with train_inputs. Because of these issues, I think forward is not the right place to apply the transforms.

I think ExactGP.__call__ (similarly ApproximateGP.__call__) is a better place to apply the transforms. I am thinking something along the lines of

    def __call__(self, *args, **kwargs):
        train_inputs = list(self.transform_inputs(self.train_inputs, train=True)) if self.train_inputs is not None else []
        inputs = [self.transform_inputs(i.unsqueeze(-1) if i.ndimension() == 1 else i, train=False) for i in args]
        ...

where we can handle the train argument along with the transform_on_train attribute of the input transform. self.transform_inputs(...) here is a simple wrapper that calls the input transform if it exists, similar to
https://github.com/pytorch/botorch/blob/f930c01c7a15ea4db951b5346f3331422fcd04e4/botorch/models/model.py#L151-L168
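
A minimal sketch of what that wrapper could look like on the GPyTorch side (the input_transform attribute and the transform_on_train flag are assumptions mirroring the BoTorch transforms, not existing GPyTorch API):

def transform_inputs(self, inputs, train=False):
    # hypothetical helper: apply self.input_transform if the model has one,
    # skipping training inputs when the transform opts out via transform_on_train
    input_transform = getattr(self, "input_transform", None)
    if input_transform is None:
        return inputs
    if train and not getattr(input_transform, "transform_on_train", True):
        return inputs
    if isinstance(inputs, (list, tuple)):
        return type(inputs)(input_transform(x) for x in inputs)
    return input_transform(inputs)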

@wjmaddox
Collaborator

wjmaddox commented Aug 10, 2021

After some discussion with @saitcakmak, it looks like placing the input transforms in ApproximateGP.__call__ for variational GPs might be the only feasible option.

If it's done in the forward pass, then the inducing points would be estimated on the raw, untransformed scale because of the call here.

I have a version of the current BoTorch transforms + variational GPs in pytorch/botorch#895, but it only works for fixed, non-learnable transforms (e.g. not learnable input warping), because it forcibly sets the "training inputs", and thus the model inputs, to be on the transformed scale, at least when using model fitting utilities such as botorch...fit_gpytorch_torch.
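
For context, a rough sketch of applying a fixed transform in ApproximateGP.__call__ (the TransformedApproximateGP class and its input_transform attribute are assumptions for illustration); because the query inputs are transformed before they reach the variational strategy, the inducing points end up being learned on the transformed scale:

import gpytorch

class TransformedApproximateGP(gpytorch.models.ApproximateGP):
    """Hypothetical sketch of a variational GP that transforms inputs in __call__."""

    def __init__(self, inducing_points, input_transform):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2)
        )
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.input_transform = input_transform

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

    def __call__(self, x, **kwargs):
        # transform only the query inputs; the inducing points are free
        # parameters of the variational strategy and live on the transformed scale
        return super().__call__(self.input_transform(x), **kwargs)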
