[RFC] Move input transforms to GPyTorch #2114

Status: Open. Wants to merge 1 commit into base: main.
1 change: 1 addition & 0 deletions gpytorch/models/approximate_gp.py
@@ -105,4 +105,5 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
    def __call__(self, inputs, prior=False, **kwargs):
        if inputs.dim() == 1:
            inputs = inputs.unsqueeze(-1)
+       inputs = self.apply_input_transforms(X=inputs, is_training_input=self.training)
        return self.variational_strategy(inputs, prior=prior, **kwargs)
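
For context, this diff defines no transform class of its own; the apply_input_transforms hook (added to GP in gpytorch/models/gp.py below) simply dispatches to an input_transform attribute when one is set. Below is a minimal sketch of how an ApproximateGP might pick up the hook once this patch is applied. The UnitCubeTransform class and the TransformedSVGP model are illustrative assumptions, not part of the PR; only the input_transform attribute name and the (X, is_training_input) call signature come from the diff.

import torch
import gpytorch
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class UnitCubeTransform:
    """Hypothetical transform with the (X, is_training_input) interface this RFC assumes."""

    def __init__(self, lower: float, upper: float):
        self.lower, self.upper = lower, upper

    def __call__(self, X: torch.Tensor, is_training_input: bool) -> torch.Tensor:
        # A learned transform could branch on is_training_input; this one is static.
        return (X - self.lower) / (self.upper - self.lower)


class TransformedSVGP(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points, input_transform=None):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(-2))
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.input_transform = input_transform  # the attribute apply_input_transforms looks up

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


# Inducing points live in the transformed (unit cube) space; raw inputs are scaled on the way in.
model = TransformedSVGP(torch.rand(16, 2), input_transform=UnitCubeTransform(0.0, 10.0))
output = model(torch.rand(8, 2) * 10.0)  # inputs transformed before the variational strategy

Because the patched __call__ passes is_training_input=self.training, a transform with learned parameters could, for example, update its statistics only on training inputs and freeze them in eval mode.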
21 changes: 17 additions & 4 deletions gpytorch/models/exact_gp.py
@@ -210,7 +210,11 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
        except KeyError:
            fantasy_kwargs = {}

-       full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
+       # Prediction strategy should have transformed train inputs.
+       prediction_strategy_inputs = [
+           self.apply_input_transforms(X=t_input, is_training_input=True) for t_input in full_inputs
+       ]
+       full_output = super(ExactGP, self).__call__(*prediction_strategy_inputs, **kwargs)

        # Copy model without copying training data or prediction strategy (since we'll overwrite those)
        old_pred_strat = self.prediction_strategy
@@ -229,7 +233,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):

        new_model.likelihood = old_likelihood.get_fantasy_likelihood(**fantasy_kwargs)
        new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
-           inputs, targets, full_inputs, full_targets, full_output, **fantasy_kwargs
+           inputs, targets, prediction_strategy_inputs, full_targets, full_output, **fantasy_kwargs
        )

        # if the fantasies are at the same points, we need to expand the inputs for the new model
@@ -242,8 +246,17 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
        return new_model

    def __call__(self, *args, **kwargs):
-       train_inputs = list(self.train_inputs) if self.train_inputs is not None else []
-       inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args]
+       train_inputs = (
+           [self.apply_input_transforms(X=t_input, is_training_input=True) for t_input in self.train_inputs]
+           if self.train_inputs is not None
+           else []
+       )
+       inputs = [
+           self.apply_input_transforms(
+               X=i.unsqueeze(-1) if i.ndimension() == 1 else i, is_training_input=self.training
+           )
+           for i in args
+       ]

        # Training mode: optimizing
        if self.training:
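
End to end, the ExactGP changes mean train_inputs are always transformed with is_training_input=True, while new inputs follow the model's mode: True during training, False at test time. A hedged sketch of that behavior, reusing the hypothetical UnitCubeTransform from above; the model class and data here are illustrative, and the snippet assumes this patch is applied.

import torch
import gpytorch


class TransformedExactGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, input_transform=None):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.input_transform = input_transform  # looked up by GP.apply_input_transforms

    def forward(self, x):
        # x arrives here already transformed by the patched ExactGP.__call__
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


train_x = torch.rand(20, 2) * 10.0
train_y = train_x.sin().sum(-1)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = TransformedExactGP(train_x, train_y, likelihood, UnitCubeTransform(0.0, 10.0))

model.eval()  # test inputs now get is_training_input=False; train_inputs still get True
with torch.no_grad():
    pred = model(torch.rand(5, 2) * 10.0)
    # Per the get_fantasy_model hunks above, the fantasy prediction strategy is
    # likewise built from transformed train inputs.
    fantasy_model = model.get_fantasy_model(torch.rand(3, 2) * 10.0, torch.randn(3))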
9 changes: 8 additions & 1 deletion gpytorch/models/gp.py
@@ -1,7 +1,14 @@
#!/usr/bin/env python3

+from torch import Tensor
+
from ..module import Module


class GP(Module):
-   pass
+   def apply_input_transforms(self, X: Tensor, is_training_input: bool) -> Tensor:
+       input_transform = getattr(self, "input_transform", None)
+       if input_transform is not None:
+           return input_transform(X=X, is_training_input=is_training_input)
+       else:
+           return X

Collaborator (review comment on apply_input_transforms): Is there a reason not to name this arg is_training rather than is_training_input?

Author: Just being verbose, and differentiating between the model being in training mode and the inputs being the train_inputs (or otherwise being treated as such).
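
A quick illustration of the fallback semantics, runnable only with this patch applied; the subclass and the lambda are purely illustrative stand-ins, not part of the PR. With no input_transform set, apply_input_transforms is the identity, and any callable accepting X and is_training_input keyword arguments can serve as a transform.

import torch
from gpytorch.models import GP  # the class modified in this diff


class MinimalGP(GP):  # hypothetical subclass, just to instantiate the base class
    pass


m = MinimalGP()
X = torch.rand(4, 2)
# No input_transform attribute: getattr falls through and X is returned unchanged.
assert torch.equal(m.apply_input_transforms(X=X, is_training_input=False), X)

# Any callable taking (X, is_training_input) works; a real transform would likely be a module.
m.input_transform = lambda X, is_training_input: 2.0 * X
assert torch.equal(m.apply_input_transforms(X=X, is_training_input=True), 2.0 * X)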