# linalg

> Numerical linear algebra utilities

In [None]:
#| default_exp linalg

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch

## 2×2 and 3×3 Hermitian eigenvalue solvers

Direct eigenvalue calculators for special matrices for which closed-form solutions exist (see [here](https://en.wikipedia.org/wiki/Eigenvalue_algorithm#Direct_calculation) for more details).

In [None]:
#| exporti
def _is_square(A: torch.Tensor) -> bool:
    """
    Check that dimensions 1 and 2 of `A` have the same size.

    Returns True when the check passes and raises
    AssertionError("Matrix is not square") otherwise.
    """
    _, n_rows, n_cols, *_ = A.shape
    if n_rows != n_cols:
        # Raised explicitly rather than via `assert` so the validation is
        # not stripped when Python runs with -O.
        raise AssertionError("Matrix is not square")
    return True


def _is_hermitian(A: torch.Tensor) -> bool:
    """
    Check that `A` equals its conjugate transpose over dimensions 1 and 2.

    The comparison uses `torch.testing.assert_close` with its default
    tolerances. Returns True when the check passes and raises
    AssertionError("Matrix is not Hermitian") otherwise.
    """
    torch.testing.assert_close(
        A, A.transpose(1, 2).conj(), msg="Matrix is not Hermitian"
    )
    # assert_close returns None; return True so the function honours its
    # declared `-> bool` signature.
    return True

In [None]:
#| export
def eigvalsh(A: torch.Tensor, check_valid: bool = True) -> torch.Tensor:
    """
    Compute the eigenvalues of a batched tensor with shape [B C C H W (D)]
    where C is 2 or 3, and the tensor is Hermitian in dimensions 1 and 2.

    If `check_valid` is True, the input is first checked for being square
    and Hermitian in dimensions 1 and 2 (an AssertionError is raised if not).

    Returns eigenvalues in a tensor with shape [B C H W (D)], i.e. the two
    matrix dimensions are replaced by a single dimension of size C holding
    the eigenvalues sorted in ascending order.

    Raises ValueError if C is neither 2 nor 3.
    """
    if check_valid:
        _is_square(A)
        _is_hermitian(A)

    n = A.shape[1]
    if n == 2:
        solver = eigvalsh2
    elif n == 3:
        solver = eigvalsh3
    else:
        raise ValueError("Only supports 2×2 and 3×3 matrices")

    # Index with explicit row/column tensors instead of star-unpacking inside
    # the subscript, which is a syntax error before Python 3.11.
    rows, cols = torch.triu_indices(n, n)
    # The upper triangle is enumerated row-major (ii, ij, [ik,] jj, [jk, kk]),
    # which is the argument order of the closed-form solvers.
    return solver(*A[:, rows, cols].split(1, dim=1))

In [None]:
#| export
def eigvalsh2(ii: torch.Tensor, ij: torch.Tensor, jj: torch.Tensor) -> torch.Tensor:
    """
    Compute the eigenvalues of a batched symmetric 2×2 tensor
    [[ii, ij], [ij, jj]] where each block has shape [B 1 H W].

    Returns eigenvalues in a tensor with shape [B 2 H W]
    sorted in ascending order along dimension 1.

    NOTE: the closed form below squares `ij` directly, so it matches the
    Hermitian case only for real-valued off-diagonal entries.
    """
    tr = ii + jj

    # Discriminant of the characteristic polynomial. Algebraically
    # tr² - 4·det == (ii - jj)² + 4·ij², but the right-hand form is a sum of
    # squares and so cannot round to a negative number (which would make
    # sqrt return NaN for nearly equal diagonal entries).
    disc = ((ii - jj).square() + 4 * ij.square()).sqrt()
    disc = torch.concat([-disc, disc], dim=1)

    eigvals = (tr + disc) / 2
    return eigvals

In [None]:
#| export
def eigvalsh3(
    ii: torch.Tensor,
    ij: torch.Tensor,
    ik: torch.Tensor,
    jj: torch.Tensor,
    jk: torch.Tensor,
    kk: torch.Tensor,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Compute the eigenvalues of a batched symmetric 3×3 tensor
    [[ii, ij, ik], [ij, jj, jk], [ik, jk, kk]] using the trigonometric
    closed-form solution. All blocks share the same shape, with a
    singleton dimension 1 (e.g. [B 1 H W D]).

    `eps` is a lower bound applied to the normalisation factor `p` so that
    (near-)isotropic matrices, for which `p` is zero, do not divide by zero.

    Returns eigenvalues concatenated along dimension 1 (e.g. [B 3 H W D])
    sorted in ascending order.
    """
    diag = torch.concat([ii, jj, kk], dim=1)
    triu = torch.concat([ij, ik, jk], dim=1)

    # q is the mean eigenvalue (trace / 3); p measures the spread of the
    # matrix around q * I.
    q = diag.sum(dim=1, keepdim=True) / 3
    p1 = triu.square().sum(dim=1, keepdim=True)
    p2 = (diag - q).square().sum(dim=1, keepdim=True)
    p = ((2 * p1 + p2) / 6).sqrt()

    # Normalise (A - q I) by p *before* taking the determinant. Dividing the
    # determinant by (p³ + eps) instead makes the result depend on the scale
    # of the input: for small-magnitude matrices eps dominates p³ and the
    # eigenvalues come out wrong. Here eps only matters when p is ~0.
    scale = p.clamp(min=eps)
    a = (ii - q) / scale
    d = (jj - q) / scale
    f = (kk - q) / scale
    b = ij / scale
    c = ik / scale
    e = jk / scale

    # Determinant of the normalised symmetric matrix [[a, b, c], [b, d, e], [c, e, f]].
    det = (
        a * d * f
        + 2 * b * c * e
        - a * e.square()
        - d * c.square()
        - f * b.square()
    )

    # In exact arithmetic r lies in [-1, 1]; clamp to absorb rounding error
    # before arccos.
    r = (det / 2).clamp(-1, 1)
    phi = r.arccos() / 3

    eig3 = q + 2 * p * phi.cos()
    eig1 = q + 2 * p * (phi + 2 * torch.pi / 3).cos()
    # The eigenvalues sum to the trace (3 * q), which gives the middle one.
    eig2 = 3 * q - eig1 - eig3
    return torch.concat([eig1, eig2, eig3], dim=1)

In [None]:
#| exporti
def deth3(ii, ij, ik, jj, jk, kk):
    """
    Determinant of the symmetric 3×3 matrix
    [[ii, ij, ik], [ij, jj, jk], [ik, jk, kk]],
    evaluated elementwise on the given tensors.
    """
    diagonal_product = ii * jj * kk
    cyclic_product = 2 * ij * ik * jk

    # Accumulate in the same left-to-right order as the closed-form
    # expression: products first, then the three diagonal-times-square terms.
    result = diagonal_product + cyclic_product
    result = result - ii * jk.square()
    result = result - jj * ik.square()
    result = result - kk * ij.square()
    return result

### Testing

Our closed-form solvers agree with `torch.linalg.eigvalsh` to within the tolerances used in the tests below (`rtol=1e-5`, `atol=1e-4`).
Unsurprisingly, our implementation is also much faster than PyTorch's general-purpose solver (see the timings below).

In [None]:
# Test the 2×2 implementation is equivalent to torch's eigvalsh
# A has layout [B C C H W] with C = 2; it is reused by the timing cells below.
A = torch.randn(100, 2, 2, 30, 30)
A = A + A.transpose(1, 2)  # Make A Hermitian (real symmetric in dims 1 and 2)

torch.testing.assert_close(
    eigvalsh(A),
    # torch.linalg.eigvalsh expects the matrix dimensions last, so move
    # dims 1 and 2 to the end, then move the eigenvalue dimension back to
    # position 1 to match the layout returned by eigvalsh.
    torch.linalg.eigvalsh(A.permute(0, -2, -1, 1, 2)).permute(0, -1, 1, 2),
    rtol=1e-5,
    atol=1e-4,
)

In [None]:
# Time diptorch's implementation
%timeit eigvalsh(A)

1.61 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Time torch's implementation
B = A.permute(0, -2, -1, 1, 2)
%timeit torch.linalg.eigvalsh(B)

43.7 ms ± 40.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Test the 3×3 implementation is equivalent to torch's eigvalsh
# A has layout [B C C H W] with C = 3; it is reused by the timing cells below.
A = torch.randn(100, 3, 3, 30, 30)
A = A + A.transpose(1, 2)  # Make A Hermitian (real symmetric in dims 1 and 2)

torch.testing.assert_close(
    eigvalsh(A),
    # torch.linalg.eigvalsh expects the matrix dimensions last, so move
    # dims 1 and 2 to the end, then move the eigenvalue dimension back to
    # position 1 to match the layout returned by eigvalsh.
    torch.linalg.eigvalsh(A.permute(0, -2, -1, 1, 2)).permute(0, -1, 1, 2),
    rtol=1e-5,
    atol=1e-4,
)

In [None]:
# Time diptorch's implementation
%timeit eigvalsh(A)

5.78 ms ± 11 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# Time torch's implementation
B = A.permute(0, -2, -1, 1, 2)
%timeit torch.linalg.eigvalsh(B)

97.8 ms ± 766 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
#| hide
import nbdev

nbdev.nbdev_export()