# linalg

> Numerical linear algebra implementations

In [None]:
#| default_exp linalg

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch

## 2×2 and 3×3 Hermitian eigenvalue solvers

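Direct (non-iterative) eigenvalue solvers for small Hermitian matrices, whose characteristic polynomials have closed-form roots (see the [direct calculation section of Wikipedia's *Eigenvalue algorithm* article](https://en.wikipedia.org/wiki/Eigenvalue_algorithm#Direct_calculation) for details).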

In [None]:
#| exporti
def _is_square(A: torch.Tensor) -> bool:
    """Check that the two trailing dimensions of ``A`` are equal.

    Any leading (batch) dimensions are ignored. A tensor with fewer than two
    dimensions cannot be unpacked and raises ``ValueError``.

    Returns:
        True when the last two dimensions match.

    Raises:
        AssertionError: ``"Matrix is not square"`` if they differ. Raised
            explicitly rather than via ``assert`` so the check still runs
            under ``python -O``.
    """
    *_, n_rows, n_cols = A.shape
    if n_rows != n_cols:
        raise AssertionError("Matrix is not square")
    return True


def _is_hermitian(A: torch.Tensor) -> bool:
    """Check that ``A`` equals its conjugate transpose ``A.mH``.

    The comparison uses ``torch.testing.assert_close`` with its default,
    dtype-dependent tolerances, so tiny round-off asymmetry is accepted.

    Returns:
        True when the check passes.

    Raises:
        AssertionError: With message ``"Matrix is not Hermitian"`` on mismatch.
    """
    torch.testing.assert_close(A, A.mH, msg="Matrix is not Hermitian")
    return True

In [None]:
#| export
def eigvalsh(A: torch.Tensor, check_valid: bool = True) -> torch.Tensor:
    """Compute eigenvalues of (batched) 2×2 or 3×3 Hermitian matrices in closed form.

    Args:
        A: Tensor of shape ``(..., n, n)`` with ``n`` equal to 2 or 3.
        check_valid: If True, first check that ``A`` is square and Hermitian.

    Returns:
        The eigenvalues computed by ``eigvalsh2`` (n == 2) or ``eigvalsh3``
        (n == 3) from the upper-triangular entries of ``A``.

    Raises:
        AssertionError: If ``check_valid`` is set and ``A`` is not square or
            not Hermitian.
        ValueError: If the trailing dimension is neither 2 nor 3.
    """
    if check_valid:
        _is_square(A)
        _is_hermitian(A)
    n = A.shape[-1]
    if n not in (2, 3):
        raise ValueError("Only supports 2×2 and 3×3 matrices")
    solver = eigvalsh2 if n == 2 else eigvalsh3
    # Only the upper triangle (diagonal included) is read; for a Hermitian
    # matrix the lower triangle is its conjugate. Row-major order yields
    # (ii, ij, jj) or (ii, ij, ik, jj, jk, kk), matching the solvers' arguments.
    rows, cols = torch.triu_indices(n, n, device=A.device)
    entries = A[..., rows, cols].split(1, dim=-1)
    return solver(*entries)

In [None]:
#| export
def eigvalsh2(ii, ij, jj):
    """Closed-form eigenvalues of 2×2 Hermitian matrices ``[[ii, ij], [conj(ij), jj]]``.

    Args:
        ii, jj: Diagonal entries. For complex input only the real part is used,
            since the diagonal of a Hermitian matrix is real.
        ij: Upper off-diagonal entry; may be complex.

    The arguments are expected to broadcast together and to share a trailing
    dimension of size 1, as produced by ``eigvalsh``'s ``split(1, dim=-1)``.

    Returns:
        Non-complex tensor of shape ``(..., 2)`` with the eigenvalues in
        ascending order.
    """
    ii, jj = (d.real if d.is_complex() else d for d in (ii, jj))
    # tr² − 4·det rewritten as (ii − jj)² + 4|ij|². This is a sum of non-negative
    # terms, so it cannot dip below zero from cancellation (NaN after sqrt)
    # when the eigenvalues coincide. Using |ij|² rather than ij² is what makes
    # it correct for complex Hermitian input.
    gap = ((ii - jj).square() + 4 * ij.abs().square()).sqrt()
    tr = ii + jj
    return (tr + torch.concat([-gap, gap], dim=-1)) / 2

In [None]:
#| export
def eigvalsh3(ii, ij, ik, jj, jk, kk):
    """Closed-form eigenvalues of 3×3 Hermitian matrices given by their upper triangle.

    The matrix is ``[[ii, ij, ik], [conj(ij), jj, jk], [conj(ik), conj(jk), kk]]``.
    Uses the trigonometric solution of the depressed characteristic cubic.

    Args:
        ii, jj, kk: Diagonal entries. For complex input only the real part is
            used, since the diagonal of a Hermitian matrix is real.
        ij, ik, jk: Upper off-diagonal entries; may be complex.

    The arguments are expected to broadcast together and to share a trailing
    dimension of size 1, as produced by ``eigvalsh``'s ``split(1, dim=-1)``.

    Returns:
        Non-complex tensor of shape ``(..., 3)`` with the eigenvalues in
        ascending order.
    """
    import math

    ii, jj, kk = (d.real if d.is_complex() else d for d in (ii, jj, kk))
    # Squared magnitudes of the off-diagonal entries.
    ij2, ik2, jk2 = (o.abs().square() for o in (ij, ik, jk))

    # Shift by q = tr(A) / 3 so that B = A − q·I is traceless.
    q = (ii + jj + kk) / 3
    di, dj, dk = ii - q, jj - q, kk - q
    # p² = tr(B²) / 6; the eigenvalues of B then lie in [−2p, 2p].
    p = ((di.square() + dj.square() + dk.square() + 2 * (ij2 + ik2 + jk2)) / 6).sqrt()

    # det(B) for Hermitian B. The two cyclic off-diagonal products are complex
    # conjugates of each other, so together they contribute 2·Re(ij·jk·conj(ik)).
    cyclic = ij * jk * ik.conj()
    cyclic = cyclic.real if cyclic.is_complex() else cyclic
    det = di * dj * dk + 2 * cyclic - di * jk2 - dj * ik2 - dk * ij2

    # r = det(B / p) / 2. If p == 0 then A == q·I and det == 0, so any finite
    # divisor gives r = 0 and every eigenvalue collapses to q below.
    # The clamp keeps acos in-domain despite rounding.
    p_safe = torch.where(p > 0, p, torch.ones_like(p))
    r = (det / (2 * p_safe**3)).clamp(-1.0, 1.0)
    phi = torch.acos(r) / 3  # phi in [0, π/3]

    # Eigenvalues are q + 2p·cos(phi + 2πk/3). For phi in [0, π/3], k = 0 gives
    # the largest and k = 1 gives the smallest.
    largest = q + 2 * p * torch.cos(phi)
    smallest = q + 2 * p * torch.cos(phi + 2 * math.pi / 3)
    middle = 3 * q - largest - smallest  # eigenvalues sum to the trace
    return torch.concat([smallest, middle, largest], dim=-1)

In [None]:
#| hide
# Run nbdev's exporter. It writes the `#| export` / `#| exporti` cells out to
# Python modules; this notebook targets `linalg`, per its `#| default_exp`.
# Side effect: writes files to disk.
import nbdev

nbdev.nbdev_export()