
Add KernelLinearOperator, deprecate KeOpsLinearOperator #62

Merged: 25 commits merged into main from linops_keops on Jun 2, 2023

Conversation

@gpleiss (Member) commented May 5, 2023

KeOpsLinearOperator does not correctly backpropagate gradients if the covar_func closes over parameters.

KernelLinearOperator corrects for this, and is set up to replace LazyEvaluatedKernelTensor in GPyTorch down the line.

[Addresses issues in #2296]
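A minimal sketch of the new pattern (assuming the API introduced in this PR; exact hyperparameter shape conventions may differ): hyperparameters are passed to KernelLinearOperator as kwargs instead of being closed over by covar_func, so they enter the operator's representation and receive gradients.

```python
import torch
from linear_operator.operators import KernelLinearOperator

def covar_func(x1, x2, lengthscale):
    # RBF kernel written as a function of its inputs AND its hyperparameters;
    # nothing is closed over from an enclosing scope.
    x1_ = x1 / lengthscale
    x2_ = x2 / lengthscale
    sq_dist = (x1_.unsqueeze(-2) - x2_.unsqueeze(-3)).pow(2).sum(dim=-1)
    return sq_dist.mul(-0.5).exp()

x1, x2 = torch.randn(100, 3), torch.randn(50, 3)
lengthscale = torch.nn.Parameter(torch.ones(1, 3))

# lengthscale is part of the operator's representation, so it receives
# gradients through _bilinear_derivative:
covar = KernelLinearOperator(x1, x2, lengthscale=lengthscale, covar_func=covar_func)
```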

@gpleiss gpleiss requested a review from Balandat May 5, 2023 21:12
@gpleiss (Member, Author) commented May 10, 2023

@jacobrgardner @Balandat can one of you take a look at this?

@m-julian commented

This implementation works perfectly when using one KeOps kernel, but there are a few problems when combining kernels. This is not specifically a problem with KernelLinearOperator, but with the way things are currently implemented in gpytorch/linear_operator. Hopefully some of these comments are useful for structuring the code.

I tried to multiply two KeOps kernels (which makes a ProductKernel). These are the problems I found:

  1. ProductKernel calls to_dense if x1 and x2 are not identical and returns a DenseLinearOperator, since MulLinearOperator is currently only used for square symmetric matrices. KeOps can be used for large matrices where x1 is not equal to x2, so calling to_dense could give memory errors. Instead, this should return some object that represents the multiplication of the two symbolic matrices without actually computing them.
  2. If x1 is equal to x2, then you end up with a MulLinearOperator, which calls root_decomposition on both KernelLinearOperators. Then there are two issues:
    2.1. You might run out of memory when doing the root decomposition if the matrix is very large (because of KeOps).
    2.2. If the matrix fits into memory, it will be approximated, since Lanczos is the default algorithm. This is fixed by setting gpytorch.settings.fast_computations(covar_root_decomposition=False) (see the sketch below), but it took me a while to figure out why I was getting different results when using KeOps kernels. If using a very large KeOps symbolic matrix, I don't think root_decomposition should be called at all; instead, the symbolic matrix should be used until some reduction operation is performed. (So should MulLinearOperator be used with KernelLinearOperator instances in the first place?)
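For reference, the workaround from 2.2 as a minimal sketch (model and train_x are hypothetical placeholders for whatever GP model and inputs you are using):

```python
import gpytorch

# Disable the (approximate) Lanczos-based fast root decomposition so that
# root decompositions are computed exactly instead:
with gpytorch.settings.fast_computations(covar_root_decomposition=False):
    output = model(train_x)  # hypothetical model / inputs
```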

I was also wondering: could KernelLinearOperator subclass DenseLinearOperator, since the KernelLinearOperator class is used to represent a full covariance matrix (just computed on the fly when a reduction operation is applied)? That way, all the checks for isinstance(self, DenseLinearOperator) (for example, this one) would also work for KernelLinearOperator. This would get around the problem with MulLinearOperator, but it might cause other problems, so I am not sure if it is reasonable.

@gpleiss (Member, Author) commented May 24, 2023

@m-julian there is unfortunately no way to do a symbolic element-wise multiplication of two kernels. KeOps (and LinearOperator) can keep things symbolic by using matrix-multiplication-based algorithms (CG/Lanczos) for solves and log determinants. Unfortunately, matrix multiplication does not distribute over the element-wise product. The current Lanczos solution comes out of the Product Kernel Interpolation paper (it was the best solution we could come up with).

Therefore, I don't know of a better way to handle the product of two matrices than what we currently do in code.
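To see why there is no symbolic shortcut, here is a quick numerical check (a minimal sketch): the matvecs A @ v and B @ v that symbolic operators provide cannot be recombined into the matvec of the element-wise product.

```python
import torch

torch.manual_seed(0)
A, B = torch.randn(4, 4), torch.randn(4, 4)
v = torch.randn(4)

# Symbolic operators give cheap A @ v and B @ v, but no identity assembles
# (A * B) @ v from those two matvecs alone:
print(torch.allclose((A * B) @ v, (A @ v) * (B @ v)))  # False
```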

I was also wondering: could KernelLinearOperator subclass DenseLinearOperator, since the KernelLinearOperator class is used to represent a full covariance matrix (just computed on the fly when a reduction operation is applied)?

This would probably be a can of worms. DenseLinearOperator assumes that the matrix is represented by a tensor.
And most of our LinearOperator classes represent full covariance matrices that are computed on the fly when reduction operations are called.
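For illustration, a minimal sketch of that assumption (as I understand the linear_operator internals; the tensor attribute is how DenseLinearOperator stores its matrix):

```python
import torch
from linear_operator.operators import DenseLinearOperator

# DenseLinearOperator wraps a matrix that already exists in memory:
dense_op = DenseLinearOperator(torch.randn(5, 5))
print(dense_op.tensor.shape)  # torch.Size([5, 5])

# A KernelLinearOperator has no such stored tensor; its entries only
# materialize when something forces evaluation, e.g. .to_dense().
```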

@Turakar (Contributor) commented May 25, 2023

Regarding the product discussion: for a general implementation that would be part of GPyTorch, I also do not know of a better approach than what @gpleiss pointed out. However, from a user's perspective, you, @m-julian, could of course write a custom KeOps kernel that fuses the two. Most likely, this would even be faster than two separate kernels, since you only need one trip from global memory to processor registers on the GPU.
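For concreteness, a sketch of such a fused kernel with pykeops (hypothetical fused_rbf_product; for two RBF kernels the algebra happens to collapse into a single exponential, while other kernel pairs would simply multiply the two symbolic expressions):

```python
import torch
from pykeops.torch import LazyTensor

def fused_rbf_product(x1, x2, ls1: float, ls2: float):
    # One symbolic KeOps formula for the product of two RBF kernels:
    # a single reduction pass over the data instead of two kernel trips.
    x1_i = LazyTensor(x1[:, None, :])  # (N, 1, D), symbolic
    x2_j = LazyTensor(x2[None, :, :])  # (1, M, D), symbolic
    sq_dist = ((x1_i - x2_j) ** 2).sum(-1)
    # exp(-d^2 / (2 ls1^2)) * exp(-d^2 / (2 ls2^2)) fuses into one exp:
    return (-0.5 * sq_dist * (1.0 / ls1**2 + 1.0 / ls2**2)).exp()

# The result stays symbolic; reductions (e.g. matvecs) run in one pass:
K = fused_rbf_product(torch.randn(1000, 3), torch.randn(500, 3), 1.0, 2.0)
out = K @ torch.randn(500, 1)  # (1000, 1), computed via a KeOps reduction
```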

@Balandat (Collaborator) left a comment

I can't say I checked all the indexing logic in exhaustive detail, but hopefully we have some test coverage for that?

Review comments on linear_operator/operators/_linear_operator.py and linear_operator/operators/kernel_linear_operator.py (all resolved).
gpleiss and others added 20 commits June 2, 2023 18:58
KeOpsLinearOperator does not correctly backpropagate gradients if the
covar_func closes over parameters.

KernelLinearOperator corrects for this, and is set up to replace
LazyEvaluatedKernelTensor in GPyTorch down the line.
Previously, only positional args were added to the LinearOperator
representation, and so only positional args would receive gradients from
_bilinear_derivative.

This commit also adds Tensor/LinearOperator kwargs to the
representation, and so kwarg Tensor/LinearOperators will also receive
gradients.
Co-authored-by: Max Balandat <Balandat@users.noreply.github.com>
gpleiss and others added 4 commits June 2, 2023 18:58
@gpleiss gpleiss merged commit 7affaf3 into main Jun 2, 2023
@gpleiss gpleiss deleted the linops_keops branch June 2, 2023 19:07
Balandat added a commit to Balandat/linear_operator that referenced this pull request Jun 3, 2023
cornellius-gp#62 introduced an inconsistency in the `linear_ops` property of `KroneckerProductLinearOperator` (by making it a `list` rather than a `tuple` in some cases). This broke some downstream usage that relied on it being a tuple.
"are incompatible for a Kronecker product."
)

if len(batch_broadcast_shape): # Otherwise all linear_ops are non-batch, and we don't need to expand
A Collaborator commented:

This introduced an inconsistency in the type of linear_ops (a list rather than a tuple), which resulted in some downstream breakages in botorch. Fixed in #66.
