# Demo: `LinearOperator`

`linear_operator` (https://github.com/cornellius-gp/linear_operator) is a library for structured linear algebra built on PyTorch.

Due to its history as the linear algebra backend for GPyTorch (), it assumes (with a few exceptions) that the involved matrices are symmetric positive definite. This can and should be relaxed to more general structured matrices (indefinite, non-square) as we think about developing `linear_operator` into a more general library.


### Installation:

**Stable:**
`pip install linear_operator`

**Latest main branch:**
`pip install git+https://github.com/cornellius-gp/linear_operator.git`

In [1]:
from typing import Tuple

import torch

from linear_operator.operators import DiagLinearOperator, BlockDiagLinearOperator, KroneckerProductLinearOperator

In [2]:
def make_pd(n: int, b: Tuple[int, ...]=()):
    """Helper for generating random positive definite matrices."""
    a = torch.rand(*b, n, n)
    return a @ a.transpose(-1, -2) + torch.diag_embed(0.1 + torch.rand(*b, n))

### Simple example: Diagonal matrices

Consider diagonal matrices of size $n \times n$.

Multiplying two of them takes $\mathcal O(n)$ time when the diagonal structure is used, but $\mathcal O(n^3)$ time as dense matrices. Memory drops from $\mathcal O(n^2)$ to $\mathcal O(n)$.

In [3]:
# Note: Using n > 2500 would demonstrate the benefits even more, but
# a PyTorch bug in eigh causes failures on certain setups:
# https://github.com/pytorch/pytorch/issues/83818

diag1 = 0.1 + torch.rand(2500)
diag2 = 0.1 + torch.rand(2500)

Diag1 = diag1.diag()  # dense: 2500 x 2500 = 6.25M elements
Diag2 = diag2.diag()  # dense: 6.25M elements

Diag1_lo = DiagLinearOperator(diag1)  # stores only the 2.5K diagonal elements
Diag2_lo = DiagLinearOperator(diag2)  # 2.5K elements

#### Addition

The set of diagonal matrices is closed under addition, and `LinearOperator` understands that: the result is again a `DiagLinearOperator` rather than a dense Tensor.

In [4]:
result = Diag1_lo + Diag2_lo
result

<linear_operator.operators.diag_linear_operator.DiagLinearOperator at 0x111590af0>

In [5]:
assert torch.equal(result.diagonal(), diag1 + diag2)

#### Matmul

Multiplying diagonal matrices just means creating a diagonal matrix whose diagonal is the element-wise product of the two diagonals. Naively this takes $\mathcal O(n^3)$ time and $\mathcal O(n^2)$ memory; using structure, both are $\mathcal O(n)$.
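For reference, here is a plain-PyTorch sketch of the identity the structured path relies on. `diag_matmul` is a hypothetical helper for illustration, not `linear_operator`'s code: only the $n$ diagonal entries are ever stored or multiplied, while the dense reference builds two $n \times n$ matrices.

```python
import torch


def diag_matmul(d1: torch.Tensor, d2: torch.Tensor) -> torch.Tensor:
    """Return the diagonal of diag(d1) @ diag(d2) in O(n) time and memory."""
    return d1 * d2


n = 200
d1 = 0.1 + torch.rand(n)
d2 = 0.1 + torch.rand(n)

# Dense reference: two n x n matrices and an O(n^3) matmul.
dense = torch.diag(d1) @ torch.diag(d2)

# Structured path: an element-wise product of the two diagonals.
structured = diag_matmul(d1, d2)

assert torch.allclose(dense.diagonal(), structured)
# Every off-diagonal entry of the dense product is exactly zero.
off_diag = dense - torch.diag(dense.diagonal())
assert torch.all(off_diag == 0)
```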

In [6]:
matmul = (Diag1 @ Diag2).diag()
matmul_lo = (Diag1_lo @ Diag2_lo).diagonal()

assert torch.equal(matmul, matmul_lo)

In [7]:
t_d = %timeit -o (Diag1 @ Diag2).diag()

17.3 ms ± 245 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
t_lo = %timeit -o (Diag1_lo @ Diag2_lo).diagonal()

3.26 µs ± 46.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [9]:
t_d.average / t_lo.average

5324.419178392022

Improvements: 
- $2{,}500$-fold reduction in memory
- more than 3 orders of magnitude faster

#### Eigendecomposition

This uses `__torch_function__` to dispatch `torch.linalg.eigh` to a custom implementation that essentially just returns the diagonal elements and the identity matrix. To match `torch.linalg.eigh` exactly, it should also sort the eigenvalues in ascending order and permute the eigenvectors accordingly, which would be easy to add.

Time complexity goes from $\mathcal O(n^3)$ to $\mathcal O(1)$ (without sorting). Memory complexity goes from $\mathcal O(n^2)$ to $\mathcal O(n)$. 

Of course, if the user were aware of the structure, they could do this manually. The point is that `LinearOperator` does these things automatically (see below for more complex examples). Think of it as operator fusion on steroids, the steroids being linear algebra simplifications for structured operators; this is something that basic notions of sparsity cannot achieve.
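The sorting step mentioned above can be sketched in plain PyTorch. `diag_eigh` is a hypothetical helper, not `linear_operator`'s implementation: it sorts the diagonal in ascending order (the convention of `torch.linalg.eigh`) and permutes the columns of the identity to match, at $\mathcal O(n \log n)$ cost for the sort. The dense eigenvector matrix is materialized here only to make the check easy.

```python
import torch


def diag_eigh(d: torch.Tensor):
    """Eigendecomposition of diag(d), sorted ascending like torch.linalg.eigh.

    Hypothetical helper for illustration; not linear_operator's implementation.
    """
    evals, perm = torch.sort(d)
    # Column j of the identity is the eigenvector for d[j], so permuting the
    # identity's columns by `perm` keeps eigenvectors aligned with sorted evals.
    evecs = torch.eye(d.shape[-1], dtype=d.dtype)[:, perm]
    return evals, evecs


d = 0.1 + torch.rand(300, dtype=torch.double)
ev, Q = diag_eigh(d)

# Reconstruct diag(d) = Q diag(ev) Q^T and compare against the dense routine.
D = torch.diag(d)
assert torch.allclose(Q @ torch.diag(ev) @ Q.T, D)
assert torch.allclose(ev, torch.linalg.eigh(D).eigenvalues)
```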

In [10]:
evals, evecs = torch.linalg.eigh(Diag1)
evals_lo, evecs_lo = torch.linalg.eigh(Diag1_lo)

assert torch.allclose(evals, torch.sort(evals_lo).values)

In [11]:
t_d = %timeit -o torch.linalg.eigh(Diag1)

915 ms ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
t_lo = %timeit -o torch.linalg.eigh(Diag1_lo)

4.98 µs ± 33.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [13]:
t_d.average / t_lo.average

183581.73564684932

Improvements: 
- $2{,}500$-fold reduction in memory
- 5 orders of magnitude faster

### Simple-ish example: Block-diagonal matrices

For a block-diagonal matrix with a fixed block size, matrix-vector multiplication is $\mathcal O(n)$ using structure, but $\mathcal O(n^2)$ as a dense matrix. The same holds for memory.
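One way to get the structured cost is to split $v$ into block-sized chunks and multiply each chunk by its own block. A plain-PyTorch sketch of that idea, assuming equally sized blocks stored in one batched tensor (`block_diag_mvm` is a hypothetical helper, not `linear_operator`'s implementation; `torch.block_diag` only builds the dense reference):

```python
import torch


def block_diag_mvm(blocks: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Multiply a block-diagonal matrix by v without materializing it.

    blocks: (num_blocks, b, b) tensor holding the diagonal blocks.
    v:      (num_blocks * b, k) tensor.
    Costs O(num_blocks * b^2 * k) instead of O((num_blocks * b)^2 * k).
    """
    num_blocks, b, _ = blocks.shape
    k = v.shape[-1]
    # Split v into one (b, k) chunk per block and multiply all chunks in a batch.
    out = torch.bmm(blocks, v.reshape(num_blocks, b, k))
    # Stack the per-block results back into a single (num_blocks * b, k) tensor.
    return out.reshape(num_blocks * b, k)


blocks = torch.rand(200, 10, 10)
w = torch.rand(2000, 1)
dense = torch.block_diag(*blocks)  # 2000 x 2000 dense reference
assert torch.allclose(dense @ w, block_diag_mvm(blocks, w), atol=1e-5)
```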

In [14]:
# create a 2000 x 2000 block-diagonal matrix with 200 random symmetric positive definite 10 x 10 blocks on the diagonal
BDiag_lo = BlockDiagLinearOperator(make_pd(10, (200,)))

# instantiate the full (dense) matrix
BDiag = BDiag_lo.to_dense()

#### Matrix-vector Multiplication (MVM)

In [15]:
v = torch.rand(2000, 1)

mvm = BDiag @ v
mvm_lo = BDiag_lo @ v

assert torch.allclose(mvm, mvm_lo)

In [16]:
t_d = %timeit -o BDiag @ v

324 µs ± 2.07 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [17]:
t_lo = %timeit -o BDiag_lo @ v

48.1 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [18]:
t_d.average / t_lo.average

6.734157966151394

Improvements: 
- $200$-fold reduction in memory ($2000^2 = 4$M dense elements vs. $200 \cdot 10^2 = 20$K stored block elements)
- $\approx 6.7$ times faster (dense matmuls are highly optimized, so there is not a ton to gain)

#### SVD

We can construct the SVD of a block-diagonal matrix from the SVDs of its blocks. This lets us compute the SVDs of 200 $10 \times 10$ matrices in a batch under the hood rather than the SVD of a $2000 \times 2000$ matrix.

Since computing an SVD takes cubic time, if there are $n_b$ blocks of size $n \times n$ in the matrix, this reduces the complexity from $\mathcal O (n_b^3 n^3)$ to $\mathcal O (n_b n^3)$.
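That the SVD of a block-diagonal matrix is assembled from the blocks' SVDs can be checked directly in plain PyTorch. This sketch does not call `linear_operator`, and it assumes a smaller $500 \times 500$ matrix (50 blocks of size 10) so that the dense reference SVD stays quick:

```python
import torch

# 50 random 10 x 10 blocks -> a 500 x 500 block-diagonal matrix.
blocks = torch.rand(50, 10, 10, dtype=torch.double)
dense = torch.block_diag(*blocks)

# Dense route: one SVD of the full matrix, O(n_b^3 n^3).
S_dense = torch.linalg.svdvals(dense)

# Structured route: one batched SVD over the small blocks, O(n_b n^3).
U_b, S_b, Vh_b = torch.linalg.svd(blocks)

# The singular values of a block-diagonal matrix are the union of the blocks'
# singular values; sort them to match the dense (descending) ordering.
S_structured = torch.sort(S_b.flatten(), descending=True).values
assert torch.allclose(S_dense, S_structured)

# The singular vectors are block-diagonal as well, so the block factors
# reassemble into a full SVD of the dense matrix.
U_full = torch.block_diag(*U_b)
Vh_full = torch.block_diag(*Vh_b)
assert torch.allclose(U_full @ torch.diag(S_b.flatten()) @ Vh_full, dense)
```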

In [19]:
# TODO: This worked with torch.svd in the past; may need to patch `__torch_function__` for eigh

U, S, V = torch.linalg.svd(BDiag)
U_lo, S_lo, V_lo = torch.linalg.svd(BDiag_lo)

torch.allclose(S, torch.sort(S_lo, descending=True).values, atol=1e-5)

True

In [20]:
t_d = %timeit -o torch.linalg.svd(BDiag)

506 ms ± 30.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
t_lo = %timeit -o torch.linalg.svd(BDiag_lo)

7.33 µs ± 98.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [22]:
t_d.average / t_lo.average

69007.02781403827

Improvements: 
- more than 4 orders of magnitude faster

### More complex: Kronecker matrices

If $A$ is $n \times n$ and $B$ is $m \times m$, then $A\otimes B$ is $nm \times nm$.

In [23]:
A = make_pd(20)
B = make_pd(500)

Kron_lo = KroneckerProductLinearOperator(A, B)
Kron = Kron_lo.to_dense()

assert torch.allclose(Kron, torch.kron(A, B))

#### MVM

Naively, an MVM takes $\mathcal O(n^2m^2)$ time. Exploiting the Kronecker structure, it takes $\mathcal O (nm (n+m))$ time. Memory goes from $\mathcal O(n^2m^2)$ to $\mathcal O(n^2 + m^2)$.
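The $\mathcal O(nm(n+m))$ cost comes from never forming $A \otimes B$: reshape each column of $v$ into an $n \times m$ matrix $X$ and use $(A \otimes B)\,\mathrm{vec}(X) = \mathrm{vec}(A X B^\top)$, where $\mathrm{vec}$ stacks rows (PyTorch's row-major order). A plain-PyTorch sketch of this identity (`kron_mvm` is a hypothetical helper, not `linear_operator`'s implementation; the matrices are smaller than above to keep the dense check cheap):

```python
import torch


def kron_mvm(A: torch.Tensor, B: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Compute (A ⊗ B) @ v without forming the nm x nm Kronecker product.

    A: (n, n), B: (m, m), v: (n * m, k). Each column of v is reshaped
    (row-major) into an (n, m) matrix X, and (A ⊗ B) vec(X) = vec(A X B^T).
    The two matmuls cost O(nm(n + m)) per column of v.
    """
    n, m = A.shape[-1], B.shape[-1]
    k = v.shape[-1]
    # One (n, m) matrix per column of v, stacked along a leading batch dim.
    X = v.T.reshape(k, n, m)
    # A and B broadcast over the batch of k matrices.
    Y = A @ X @ B.T
    return Y.reshape(k, n * m).T


A_small = torch.rand(20, 20, dtype=torch.double)
B_small = torch.rand(50, 50, dtype=torch.double)
w = torch.rand(20 * 50, 3, dtype=torch.double)

# Dense reference: torch.kron materializes the full 1000 x 1000 matrix.
assert torch.allclose(torch.kron(A_small, B_small) @ w, kron_mvm(A_small, B_small, w))
```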

In [24]:
v = torch.rand(Kron_lo.shape[-1], 1)

mvm = Kron @ v
mvm_lo = Kron_lo @ v

assert torch.allclose(mvm, mvm_lo)

In [25]:
t_d = %timeit -o Kron @ v

15.3 ms ± 240 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [26]:
t_lo = %timeit -o Kron_lo @ v

93 µs ± 1.46 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [27]:
t_d.average / t_lo.average

164.64672578765354

Improvements:
- $\approx 400$-fold reduction in memory ($10^8$ dense elements vs. $20^2 + 500^2 = 250{,}400$ elements for the factors)
- more than 2 orders of magnitude faster

### Even more fun with Kronecker structure

In [28]:
from linear_operator.operators import KroneckerProductLinearOperator, ConstantDiagLinearOperator

In [29]:
Kron_lt = KroneckerProductLinearOperator(A, B)

Let's add a constant diagonal:

In [30]:
Diag_lt = ConstantDiagLinearOperator(1 + torch.rand(1), Kron_lt.shape[-1])
Diag = Diag_lt.to_dense()

KaddD = Kron + Diag
KaddD_lt = Kron_lt + Diag_lt
KaddD_lt

<linear_operator.operators.kronecker_product_added_diag_linear_operator.KroneckerProductAddedDiagLinearOperator at 0x1077dac70>

In [31]:
assert torch.allclose(KaddD_lt.to_dense(), KaddD)

#### Solve

Solving $(A \otimes B + a I)x = v$ naively means solving an $nm \times nm$ linear system.

We can be smarter by noting that the inverse of $A \otimes B + a I$ is cheap to compute:
1. Perform an eigendecomposition $A \otimes B = \sum_j e_j q_j q_j^T$. This is cheap because the eigendecomposition of $A \otimes B$ can be constructed from the (small and cheap-to-compute) eigendecompositions of $A$ and $B$: its eigenvalues are the pairwise products of their eigenvalues, and its eigenvectors are the Kronecker products of their eigenvectors.
2. $A \otimes B + a I$ has the same eigenvectors $q_j$ as $A \otimes B$, with each eigenvalue $e_j$ shifted to $e_j + a$.
3. The inverse is obtained by taking reciprocals of the shifted eigenvalues: $(A \otimes B + a I)^{-1} = \sum_j (e_j + a)^{-1} q_j q_j^T$.

In the end, this takes the solve from $\mathcal O(n^3m^3)$ to $\mathcal O(n^3 + m^3)$ complexity. We don't have to do anything other than express $A \otimes B$ and the constant diagonal with the right operators; the rest we get for free (modulo registering this with `solve` via `__torch_function__`). The same idea extends to more Kronecker factors, where the cost of the eigendecompositions goes from $\mathcal O(\prod_i n_i^3)$ to $\mathcal O(\sum_i n_i^3)$.

One thing to be careful about is numerical robustness, which can be a problem when computing eigendecompositions. In general, hairy linear algebra should happen in double rather than single precision.
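The three steps of the eigendecomposition trick can be sketched in plain PyTorch, in double precision as advised. `kron_plus_diag_solve` and `random_spd` are hypothetical helpers (not `linear_operator`'s API); the sketch assumes symmetric $A$ and $B$, a scalar shift $a$ that keeps every shifted eigenvalue away from zero, and the row-major reshape trick for applying Kronecker products. $A \otimes B$ is formed only for the dense reference check, with small sizes to keep it fast:

```python
import torch


def kron_plus_diag_solve(A: torch.Tensor, B: torch.Tensor, a: float, v: torch.Tensor) -> torch.Tensor:
    """Solve (A ⊗ B + a I) x = v via eigendecompositions of A and B only.

    Assumes A (n x n) and B (m x m) are symmetric and that a keeps every
    shifted eigenvalue away from zero; v has shape (n * m, k).
    """
    n, m = A.shape[-1], B.shape[-1]
    k = v.shape[-1]
    # Step 1: small eigendecompositions, O(n^3 + m^3).
    evals_A, Q_A = torch.linalg.eigh(A)
    evals_B, Q_B = torch.linalg.eigh(B)
    # Eigenvalues of A ⊗ B are the products evals_A[i] * evals_B[j], stored as
    # an (n, m) grid matching the row-major reshape of v below.
    # Step 2: shift every eigenvalue by a.
    shifted = evals_A[:, None] * evals_B[None, :] + a
    # Each column of v becomes an (n, m) matrix X (row-major vec).
    X = v.T.reshape(k, n, m)
    # Step 3: rotate into the eigenbasis, (Q_A ⊗ Q_B)^T vec(X) = vec(Q_A^T X Q_B),
    # divide by the shifted eigenvalues, and rotate back.
    Z = Q_A.T @ X @ Q_B
    Y = Q_A @ (Z / shifted) @ Q_B.T
    return Y.reshape(k, n * m).T


def random_spd(n: int) -> torch.Tensor:
    """Random symmetric positive definite matrix in double precision."""
    g = torch.rand(n, n, dtype=torch.double)
    return g @ g.T + torch.diag(0.1 + torch.rand(n, dtype=torch.double))


A_small, B_small, shift = random_spd(10), random_spd(40), 1.5
w = torch.rand(10 * 40, 2, dtype=torch.double)

# Dense reference: forms the full 400 x 400 system.
K_dense = torch.kron(A_small, B_small) + shift * torch.eye(10 * 40, dtype=torch.double)
assert torch.allclose(torch.linalg.solve(K_dense, w), kron_plus_diag_solve(A_small, B_small, shift, w))
```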

In [32]:
x = torch.linalg.solve(KaddD, v)
x_lt = torch.linalg.solve(KaddD_lt, v)

assert torch.allclose(x, x_lt, atol=1e-2, rtol=1e-4)  # hard solve, need to increase tolerance here

In [33]:
t_d = %timeit -o torch.linalg.solve(Kron + Diag, v)

1.15 s ± 35.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [34]:
t_lo = %timeit -o torch.linalg.solve(Kron_lt + Diag_lt, v)

20.1 ms ± 137 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [35]:
t_d.average / t_lo.average

57.28185047652546

Improvements:
- $\approx 400$-fold reduction in memory
- more than 50 times faster

### A non-PSD example: Triangular matrices 

We have `torch.triangular_solve` (now superseded by `torch.linalg.solve_triangular`) to bring solve complexity down from $\mathcal O(n^3)$ to $\mathcal O(n^2)$, but the user has to trace through all of their code to know when it's safe to call it. If we retain the structural information, we can dispatch to the right solve automatically, which lets us write structure-agnostic code and still get the linear algebra optimizations for free.
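A minimal sketch of the dispatch idea in plain PyTorch, using `torch.linalg.solve_triangular`. The `lower_triangular` flag of the hypothetical `structured_solve` helper is a stand-in for the structural information that an operator type carries; here the caller still has to supply it, which is exactly the burden described above. The matrix is kept well conditioned (unit diagonal, small off-diagonal entries) so that both solves agree:

```python
import torch


def structured_solve(M: torch.Tensor, b: torch.Tensor, lower_triangular: bool = False) -> torch.Tensor:
    """Solve M x = b, using an O(n^2) triangular solve when M is known to be lower triangular.

    `lower_triangular` is a hypothetical stand-in for the structural information
    an operator type such as TriangularLinearOperator carries automatically.
    """
    if lower_triangular:
        return torch.linalg.solve_triangular(M, b, upper=False)  # forward substitution
    return torch.linalg.solve(M, b)  # general LU-based solve, O(n^3)


n = 500
# Unit diagonal plus small off-diagonal entries keeps the matrix well conditioned.
L = torch.eye(n, dtype=torch.double) + torch.rand(n, n, dtype=torch.double).tril(-1) / n
b = torch.rand(n, 1, dtype=torch.double)

x_general = structured_solve(L, b)
x_tri = structured_solve(L, b, lower_triangular=True)
assert torch.allclose(x_general, x_tri)
assert torch.allclose(L @ x_tri, b)
```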

In [36]:
from linear_operator.operators import TriangularLinearOperator

In [37]:
tri = torch.eye(500) + torch.rand(500, 500, dtype=torch.double).tril()
tri_lo = TriangularLinearOperator(tri)

assert torch.equal(tri, tri_lo.to_dense())

In [38]:
tri_inv = torch.inverse(tri)
tri_lo_inv = tri_lo.inverse()  # TODO: Handle in torch.inverse by registering via __torch_function__

assert torch.allclose(tri_inv, tri_lo_inv.to_dense())

In [39]:
t_d = %timeit -o torch.inverse(tri)

3.34 ms ± 90 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [40]:
t_lo = %timeit -o tri_lo.inverse()

672 ns ± 3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [41]:
t_d.average / t_lo.average

4973.516163490839

Improvements:
- uses half of the memory
- $\approx 5{,}000$ times faster