In [1]:
import torch
import linear_operator

# Linear Operator Kronecker Example

We'll construct two symmetric positive semi-definite (PSD) matrices: $\mathbf A \in \mathbb R^{25 \times 25}$ and $\mathbf B \in \mathbb R^{100 \times 100}$.

The Kronecker product $\mathbf A \otimes \mathbf B$ is a $2500 \times 2500$ matrix. Linear algebra operations on $\mathbf A \otimes \mathbf B$, such as matrix-vector products and eigendecompositions, can be much faster if we exploit the structure of the Kronecker product instead of treating it as a generic dense matrix.

By wrapping $\mathbf A$ and $\mathbf B$ in a `KroneckerProductLinearOperator`, we can perform algebraic operations on the Kronecker product in a structure-aware way.
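To see where a speedup can come from, here is a minimal sketch in plain PyTorch (no `linear_operator` calls; `kron_matvec` is a hypothetical helper, not part of either library) of a Kronecker matrix-vector product computed without ever forming $\mathbf A \otimes \mathbf B$. It relies on the identity $(\mathbf A \otimes \mathbf B)\,\mathrm{vec}(\mathbf X) = \mathrm{vec}(\mathbf A \mathbf X \mathbf B^\top)$, where $\mathrm{vec}$ is row-major flattening, which matches the index layout of `torch.kron`. We have not checked whether `KroneckerProductLinearOperator` uses exactly this code path.

```python
import torch


def kron_matvec(A: torch.Tensor, B: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Compute (A ⊗ B) @ x without materializing the Kronecker product.

    Uses (A ⊗ B) vec(X) = vec(A X B^T) with row-major vec(),
    which matches torch.kron's layout: x[j * q + l] == X[j, l].
    """
    X = x.reshape(A.shape[1], B.shape[1])
    return (A @ X @ B.T).reshape(-1)


torch.manual_seed(0)
A = torch.randn(25, 25, dtype=torch.float64)
A = A @ A.T  # 25 x 25 PSD
B = torch.randn(100, 100, dtype=torch.float64)
B = B @ B.T  # 100 x 100 PSD
x = torch.randn(2500, dtype=torch.float64)

dense_result = torch.kron(A, B) @ x          # forms the full 2500 x 2500 matrix
structured_result = kron_matvec(A, B, x)     # only touches the small factors
print(torch.allclose(dense_result, structured_result))  # True
```

For $p \times p$ and $q \times q$ factors, the dense product costs $O(p^2 q^2)$ operations, while the reshaped version costs $O(pq(p + q))$.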

In [2]:
A = torch.randn(25, 25)
A = A @ A.T  # A 25 x 25 PSD matrix

B = torch.randn(100, 100)
B = B @ B.T  # A 100 x 100 PSD matrix

## Naively Computing $\mathbf A \otimes \mathbf B$ Eigenvalues (Without LinearOperator)

In [3]:
kron = torch.kron(A, B)
%timeit torch.linalg.eigh(kron)

277 ms ± 805 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Computing $\mathbf A \otimes \mathbf B$ Eigenvalues With LinearOperator

We will begin by wrapping `A` and `B` with the `to_linear_operator` function. After this, all math operations take place within the `LinearOperator` abstraction, which keeps track of the Kronecker product structure.

If we know that $\mathbf A \otimes \mathbf B$ has Kronecker structure, we can compute its eigenvalues efficiently. The linear_operator package keeps track of this structure: calling `torch.kron` on `LinearOperator`s returns a `KroneckerProductLinearOperator`, which encodes the Kronecker structure.
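The mathematical fact being exploited is that if $\mathbf A = \mathbf Q_A \boldsymbol\Lambda_A \mathbf Q_A^\top$ and $\mathbf B = \mathbf Q_B \boldsymbol\Lambda_B \mathbf Q_B^\top$, then $\mathbf A \otimes \mathbf B = (\mathbf Q_A \otimes \mathbf Q_B)(\boldsymbol\Lambda_A \otimes \boldsymbol\Lambda_B)(\mathbf Q_A \otimes \mathbf Q_B)^\top$. The sketch below checks this identity in plain PyTorch (float64, random factors of the same sizes as above). It illustrates the math only and is not a claim about how `linear_operator` implements `eigh` internally.

```python
import torch

torch.manual_seed(0)
A = torch.randn(25, 25, dtype=torch.float64)
A = A @ A.T  # symmetric PSD
B = torch.randn(100, 100, dtype=torch.float64)
B = B @ B.T  # symmetric PSD

# Decompose each small factor separately.
evals_A, evecs_A = torch.linalg.eigh(A)
evals_B, evecs_B = torch.linalg.eigh(B)

# Eigenvalues of A ⊗ B are all pairwise products; torch.kron on 1-D
# tensors returns exactly that (the flattened outer product).
kron_evals = torch.kron(evals_A, evals_B)
# Eigenvectors of A ⊗ B are Kronecker products of the factors' eigenvectors.
kron_evecs = torch.kron(evecs_A, evecs_B)

# Compare with the dense decomposition (feasible here because 2500 is small).
K = torch.kron(A, B)
dense_evals = torch.linalg.eigvalsh(K)  # ascending order
print(torch.allclose(kron_evals.sort().values, dense_evals))
# Each column of kron_evecs should be scaled by its eigenvalue under K.
print(torch.allclose(K @ kron_evecs, kron_evecs * kron_evals))
```

The factored eigenvalues come back unsorted, which is why they are sorted before comparing with `torch.linalg.eigvalsh`.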

In [4]:
A_lo = linear_operator.to_linear_operator(A)
B_lo = linear_operator.to_linear_operator(B)
kron = torch.kron(A_lo, B_lo)
print(kron.__class__)  # It's not a torch.Tensor!

<class 'linear_operator.operators.kronecker_product_linear_operator.KroneckerProductLinearOperator'>


In [5]:
%timeit torch.linalg.eigh(kron)  # It's much faster!!!

3.05 µs ± 22.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
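A rough operation count shows why the gap is so large. Assuming a dense symmetric eigendecomposition of an $n \times n$ matrix costs on the order of $n^3$ operations (constants ignored), decomposing the two factors instead of the full product gives:

$$
\underbrace{(25 \cdot 100)^3}_{\text{dense } \mathbf A \otimes \mathbf B} = 2500^3 \approx 1.56 \times 10^{10}
\qquad \text{vs.} \qquad
\underbrace{25^3 + 100^3}_{\text{factors } \mathbf A,\ \mathbf B} = 1{,}015{,}625 \approx 1.02 \times 10^{6}
$$

That is a reduction of roughly $1.5 \times 10^4$ in this estimate; forming the $2500$ eigenvalue products adds only $O(pq)$ work.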


## The Lazy Evaluation of LinearOperators

The linear_operator package avoids explicitly instantiating any `LinearOperator` as a dense matrix. Any composition or decoration operation on `LinearOperator`s returns another `LinearOperator`, which records the structure of the operator as a tree-like object.

For example, adding two linear operators returns a `SumLinearOperator`.
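To illustrate the idea of a tree of lazy operators, here is a toy sketch. The classes `LazyOp`, `LazyDense`, `LazyDiag`, `LazyKron` and `LazySum` are hypothetical and far simpler than linear_operator's real class hierarchy; they only show how `+` can build a new node instead of computing anything, and how `to_dense()` materializes the matrix by recursing through the tree.

```python
import torch


class LazyOp:
    """Hypothetical toy base class (not part of linear_operator)."""

    def __add__(self, other):
        # Adding two operators only records the operands; no arithmetic happens.
        return LazySum(self, other)


class LazyDense(LazyOp):
    def __init__(self, mat):
        self.mat = mat

    def to_dense(self):
        return self.mat

    def __repr__(self):
        return f"LazyDense{tuple(self.mat.shape)}"


class LazyDiag(LazyOp):
    def __init__(self, diag):
        self.diag = diag  # n numbers represent an n x n matrix

    def to_dense(self):
        return torch.diag(self.diag)

    def __repr__(self):
        return f"LazyDiag({self.diag.numel()})"


class LazyKron(LazyOp):
    def __init__(self, a, b):
        self.a, self.b = a, b

    def to_dense(self):
        return torch.kron(self.a.to_dense(), self.b.to_dense())

    def __repr__(self):
        return f"LazyKron({self.a!r}, {self.b!r})"


class LazySum(LazyOp):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def to_dense(self):
        # The full matrix is only built here, by recursing through the tree.
        return self.left.to_dense() + self.right.to_dense()

    def __repr__(self):
        return f"LazySum({self.left!r}, {self.right!r})"


A = torch.randn(25, 25)
B = torch.randn(100, 100)
d = torch.randn(2500).abs()

expr = LazyKron(LazyDense(A), LazyDense(B)) + LazyDiag(d)
print(expr)
# LazySum(LazyKron(LazyDense(25, 25), LazyDense(100, 100)), LazyDiag(2500))
```

Nothing of size $2500 \times 2500$ exists until `expr.to_dense()` is called.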

In [6]:
D_lo = linear_operator.operators.DiagLinearOperator(torch.randn(2500).abs())
# This is a 2500 x 2500 diagonal matrix, represented as a LinearOperator

In [7]:
print((kron + D_lo).__class__)

<class 'linear_operator.operators.sum_linear_operator.SumLinearOperator'>


The only methods that return `torch.Tensor`s are those that perform some sort of reduction of a `LinearOperator`, such as `matmul` or an eigendecomposition. The linear_operator package attempts to perform every such reduction in the most efficient way the structure of the operator allows.

For example, `(kron + D_lo)` is the sum of a Kronecker product and a diagonal matrix. Matrix multiplication distributes over addition, and both Kronecker products and diagonal matrices have very efficient `matmul` implementations.
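A plain-PyTorch sketch of this structured product: the hypothetical helper `kron_plus_diag_matvec` applies $(\mathbf K + \mathbf D)\mathbf v = \mathbf K\mathbf v + \mathbf D\mathbf v$, computing the Kronecker term with two small matrix multiplies and the diagonal term as an elementwise product. It mirrors the reasoning above but is not linear_operator's actual code.

```python
import torch

torch.manual_seed(0)
A = torch.randn(25, 25, dtype=torch.float64)
A = A @ A.T
B = torch.randn(100, 100, dtype=torch.float64)
B = B @ B.T
d = torch.randn(2500, dtype=torch.float64).abs()  # diagonal entries of D
vec = torch.randn(2500, dtype=torch.float64)


def kron_plus_diag_matvec(A, B, d, v):
    """Compute (A ⊗ B + diag(d)) @ v as (A ⊗ B) @ v + d * v."""
    # Kronecker term: reshape v (row-major) so the product becomes A @ X @ B^T.
    X = v.reshape(A.shape[1], B.shape[1])
    kron_part = (A @ X @ B.T).reshape(-1)
    # Diagonal term: an elementwise product, O(n) work.
    diag_part = d * v
    return kron_part + diag_part


structured = kron_plus_diag_matvec(A, B, d, vec)
dense = (torch.kron(A, B) + torch.diag(d)) @ vec
print(torch.allclose(structured, dense))  # True
```

For these sizes the Kronecker term needs $25 \cdot 25 \cdot 100 + 25 \cdot 100 \cdot 100 = 312{,}500$ multiply-adds and the diagonal term $2500$, compared with $2500^2 = 6{,}250{,}000$ for a dense matrix-vector product, before counting the cost of building the dense matrix.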

In [8]:
vec = torch.randn(2500)

# With LinearOperator - exploiting structure
%timeit (kron + D_lo).matmul(vec)

76.6 µs ± 101 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [9]:
# Using dense torch.Tensor - ignoring structure
%timeit (kron + D_lo).to_dense().matmul(vec)  # Much slower!

25.9 ms ± 146 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Converting LinearOperators to torch.Tensors

If, at any point, we want to explicitly instantiate the dense matrix represented by a `LinearOperator`, we can call its `to_dense()` method.

This is generally not recommended, since many `LinearOperator`s are efficient (data-sparse) representations of large matrices. Calling `to_dense()` can easily create an object that consumes all available memory!
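Some quick arithmetic makes this warning concrete. The sketch below (plain PyTorch, float32 as in this notebook) compares the bytes held by the pieces of the structured representation of `kron + D_lo` (the two factors and the diagonal) with the size of the $2500 \times 2500$ tensor that `to_dense()` returns. It counts only the stored tensors, not any temporary allocations made along the way.

```python
import torch

# Pieces of the structured representation (float32, as in this notebook).
A = torch.randn(25, 25)
B = torch.randn(100, 100)
d = torch.randn(2500).abs()

# Bytes actually stored: 625 + 10,000 + 2,500 floats.
structured_bytes = sum(t.numel() * t.element_size() for t in (A, B, d))
# Bytes in the dense 2500 x 2500 result of to_dense().
dense_bytes = (2500 * 2500) * d.element_size()

print(structured_bytes)  # 52500
print(dense_bytes)       # 25000000
```

The gap grows quickly: with $50 \times 50$ and $200 \times 200$ factors the dense matrix would be $10{,}000 \times 10{,}000$, i.e. 400 MB in float32, while the two factors plus the diagonal would need about 210 KB.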

In [10]:
dense = (kron + D_lo).to_dense()
print(dense.__class__)
print(dense.shape)  # A big matrix!

<class 'torch.Tensor'>
torch.Size([2500, 2500])
