[RFC] Move input transforms to GPyTorch #2114
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This diff presents a minimal implementation of input transforms in GPyTorch, as requested in #1652. This should be viewed together with pytorch/botorch#1372. The input transforms themselves are currently implemented in https://github.com/pytorch/botorch/blob/cdd668d18b2a7e35bed09b7a2b2fca40e5fd2067/botorch/models/transforms/input.py
What this does:
transform_inputs
from BoTorchModel
to GPyTorchGP
class, with some modifications to explicitly identify whether given inputs are train or test inputs.InputTransform.forward
call to useis_training_input
argument instead ofself.training
check to apply the transforms that havetransform_on_train=True
.preprocess_transform
method since this is no-longer needed.ExactGP
models, it transforms both train and test inputs in__call__
. Fortrain_inputs
it always usesis_training_input=True
. For genericinputs
, it usesis_training_input=self.training
which signals that these are training inputs when the model is intrain
mode, and that these are test inputs when the model is ineval
mode.ApproximateGP
models, it applies the transform toinputs
in__call__
usingis_training_input=self.training
. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transformsinducing_points
, thus fixes the previous bug withinducing_points
getting transformed intrain
but not getting transformed ineval
. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).SingleTaskVariationalGP
, it moves theinput_transform
attribute down to_SingleTaskVariationalGP
, which is the actualApproximateGP
instance. This makes the transform accessible from GPyTorch.What this doesn't do:
DeterministicModel
s. Those will still need to deal with their own transforms, which is not implemented here. If we makeModel
inherit fromGP
, we can keep the existing setup with very minimal changes.self.transform_inputs
. This is just made into a no-op and the clean-up is left for later.InputTransform
classes to GPyTorch. That'll be done if we decide to go forward with this design.PairwiseGP
.PairwiseGP
has some non-standard use of input transforms, so it needs an audit to make sure things still work fine.ApproximateGP.fantasize
. This may need some changes similar toExactGP.get_fantasy_model
.PyroGP
andDeepGP
.