# linalg

> Numerical linear algebra utilities

In [None]:
#| default_exp linalg

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch

## 2×2 and 3×3 Hermitian eigenvalue solvers

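Closed-form eigenvalue solvers for 2×2 and 3×3 Hermitian matrices, for which direct formulas exist (see the [direct calculation section](https://en.wikipedia.org/wiki/Eigenvalue_algorithm#Direct_calculation) of Wikipedia's *Eigenvalue algorithm* article for details). Matrices are batched along the first dimension, i.e. `A` has shape `(batch, n, n)`, and eigenvalues are returned in ascending order.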

In [None]:
#| exporti
def _is_square(A: torch.Tensor) -> bool:
    _, i, j, *_ = A.shape
    assert i == j, "Matrix is not square"


def _is_hermitian(A: torch.Tensor) -> bool:
    return torch.testing.assert_close(
        A, A.transpose(1, 2).conj(), msg="Matrix is not Hermitian"
    )

In [None]:
#| export
def eigvalsh(A: torch.Tensor, check_valid: bool = True) -> torch.Tensor:
    """Eigenvalues of 2×2 or 3×3 Hermitian matrices, in ascending order.

    Args:
        A: a batch of matrices with shape ``(batch, n, n, ...)`` (matrix
            dimensions at positions 1 and 2), or a single ``(n, n)`` matrix.
            ``n`` must be 2 or 3. Only the upper triangle is read.
        check_valid: if True, first check that ``A`` is square and Hermitian.

    Returns:
        Eigenvalues with shape ``(batch, n, ...)`` -- or ``(n,)`` for a single
        ``(n, n)`` matrix -- in ascending order along the eigenvalue dimension.

    Raises:
        AssertionError: if ``check_valid`` is set and ``A`` is not square or
            not Hermitian.
        ValueError: if ``n`` is not 2 or 3.
    """
    if A.ndim == 2:
        # Single matrix: add a batch dimension, solve, then strip it again.
        return eigvalsh(A.unsqueeze(0), check_valid=check_valid)[0]
    if check_valid:
        _is_square(A)
        _is_hermitian(A)
    n = A.shape[1]
    if n not in (2, 3):
        raise ValueError("Only supports 2×2 and 3×3 matrices")
    # Gather the upper triangle row by row: (0,0), (0,1), ..., (n-1,n-1).
    # Advanced indexing gives shape (batch, n*(n+1)/2, ...); splitting on dim 1
    # yields (batch, 1, ...) entries in the argument order of the solvers.
    rows, cols = torch.triu_indices(n, n, device=A.device)
    entries = A[:, rows, cols].split(1, dim=1)
    solver = eigvalsh2 if n == 2 else eigvalsh3
    return solver(*entries)

In [None]:
#| export
def eigvalsh2(ii, ij, jj):
    """Eigenvalues of 2×2 Hermitian matrices ``[[ii, ij], [conj(ij), jj]]``.

    Args:
        ii, jj: diagonal entries; only their real part is used.
        ij: upper off-diagonal entry (real or complex).
        Each argument is shaped ``(batch, 1, ...)``, as produced by
        ``eigvalsh``.

    Returns:
        Tensor of shape ``(batch, 2, ...)`` holding the eigenvalues in
        ascending order along dimension 1.
    """
    # A Hermitian matrix has a real diagonal: drop any imaginary part so the
    # result is real-valued.
    ii, jj = (
        x.real if isinstance(x, torch.Tensor) and x.is_complex() else x
        for x in (ii, jj)
    )
    mid = (ii + jj) / 2
    # Half the gap between the eigenvalues, sqrt(((ii - jj)/2)^2 + |ij|^2).
    # As a sum of squares it can never be negative, unlike tr^2 - 4*det, which
    # can cancel to a tiny negative value (-> NaN from sqrt) when the
    # eigenvalues are nearly equal. |ij|^2 (not ij^2) is the correct term for
    # complex off-diagonal entries.
    half_gap = (((ii - jj) / 2).square() + ij.abs().square()).sqrt()
    return torch.concat([mid - half_gap, mid + half_gap], dim=1)

In [None]:
#| export
def eigvalsh3(ii, ij, ik, jj, jk, kk):
    """Eigenvalues of 3×3 Hermitian matrices via the trigonometric closed form.

    The matrix is ``[[ii, ij, ik], [conj(ij), jj, jk], [conj(ik), conj(jk), kk]]``.
    Writing ``A = q*I + p*B`` with ``q = tr(A)/3`` and ``p`` chosen so that
    ``B`` has squared Frobenius norm 6, the eigenvalues are
    ``q + 2*p*cos(phi + 2*pi*k/3)`` for ``k = 0, 1, 2``, where
    ``phi = arccos(det(B)/2)/3``.

    Args:
        ii, jj, kk: diagonal entries; only their real part is used.
        ij, ik, jk: upper off-diagonal entries (real or complex).
        Each argument is shaped ``(batch, 1, ...)``, as produced by
        ``eigvalsh``.

    Returns:
        Tensor of shape ``(batch, 3, ...)`` holding the eigenvalues in
        ascending order along dimension 1.
    """
    # A Hermitian matrix has a real diagonal: drop any imaginary part.
    ii, jj, kk = (
        x.real if isinstance(x, torch.Tensor) and x.is_complex() else x
        for x in (ii, jj, kk)
    )

    q = (ii + jj + kk) / 3
    # Squared magnitudes so complex off-diagonal entries are handled correctly.
    p1 = ij.abs().square() + ik.abs().square() + jk.abs().square()
    p2 = (ii - q).square() + (jj - q).square() + (kk - q).square() + 2 * p1
    p = (p2 / 6).sqrt()

    # B = (A - q*I) / p. When p == 0 we have A == q*I and the division would
    # be 0/0 (NaN); dividing by 1 instead makes B the zero matrix, so r = 0
    # and all three eigenvalues below collapse to q. Normalising the entries
    # before taking the determinant (instead of dividing det(A - q*I) by p**3)
    # also keeps the intermediates O(1), so p**3 cannot overflow or underflow.
    scale = torch.where(p > 0, p, torch.ones_like(p))
    r = deth3(
        (ii - q) / scale, ij / scale, ik / scale,
        (jj - q) / scale, jk / scale, (kk - q) / scale,
    ) / 2
    # Mathematically r lies in [-1, 1]; clamp rounding excursions before arccos.
    r = r.clamp(-1, 1)
    phi = r.arccos() / 3

    eig_max = q + 2 * p * phi.cos()
    eig_min = q + 2 * p * (phi + 2 * torch.pi / 3).cos()
    # The eigenvalues sum to the trace, 3q.
    eig_mid = 3 * q - eig_min - eig_max
    # Stack on dim 1, the same axis as eigvalsh2's output and the axis that
    # eigvalsh splits the matrix entries on.
    return torch.concat([eig_min, eig_mid, eig_max], dim=1)

In [None]:
#| exporti
def deth3(ii, ij, ik, jj, jk, kk):
    """Determinant of the 3×3 Hermitian matrix with the given upper triangle.

    The matrix is ``[[ii, ij, ik], [conj(ij), jj, jk], [conj(ik), conj(jk), kk]]``.
    All operations are elementwise, so the arguments may be tensors of any
    mutually broadcastable shape.

    Returns:
        The real-valued determinant, with the broadcast shape of the inputs.
    """
    # A Hermitian matrix has a real diagonal: drop any imaginary part.
    ii, jj, kk = (
        x.real if isinstance(x, torch.Tensor) and x.is_complex() else x
        for x in (ii, jj, kk)
    )
    # The two off-diagonal triple products, ij*jk*conj(ik) and its complex
    # conjugate, sum to twice the real part; for real entries this reduces to
    # 2*ij*ik*jk.
    cross = ij * jk * ik.conj()
    if cross.is_complex():
        cross = cross.real
    return (
        ii * jj * kk
        + 2 * cross
        - ii * jk.abs().square()
        - jj * ik.abs().square()
        - kk * ij.abs().square()
    )

In [None]:
#| hide
import nbdev

# Run nbdev's export step so the `#| export` / `#| exporti` cells above are
# written to the `linalg` module (named by the `#| default_exp linalg` cell).
# NOTE(review): called without arguments this presumably exports every
# notebook in the project, not just this one -- confirm against nbdev docs.
nbdev.nbdev_export()