1. Implement the four operations: vectorize, devectorize, SPD matrix square root, and SPD matrix logarithm

In [None]:
import torch

# 1) Vectorize (lower triangular, incl diagonal)
def vectorize(M: torch.Tensor) -> torch.Tensor:
    """Return the lower triangle (including the diagonal) of ``M`` as a vector.

    Entries are taken in ``torch.tril_indices`` order, i.e. row by row:
    ``M[0,0], M[1,0], M[1,1], M[2,0], ...``.

    Args:
        M: A square matrix of shape ``(n, n)``, or a batch of square matrices
            of shape ``(..., n, n)``.

    Returns:
        A tensor of shape ``(..., n * (n + 1) // 2)``.

    Raises:
        ValueError: If ``M`` has fewer than two dimensions or is not square
            in its last two dimensions.
    """
    if M.dim() < 2 or M.shape[-2] != M.shape[-1]:
        raise ValueError(
            f"vectorize expects a square matrix of shape (..., n, n), got {tuple(M.shape)}"
        )
    n = M.shape[-1]
    idx = torch.tril_indices(n, n, device=M.device)
    # Index the LAST two dims so leading batch dims are preserved.
    return M[..., idx[0], idx[1]]

# 2) Devectorize 
def devectorize(v: torch.Tensor, n: int, device=None, dtype=None) -> torch.Tensor:
    """Rebuild a symmetric ``n x n`` matrix from its lower-triangle vector.

    This is the inverse of ``vectorize`` for symmetric matrices. Entries of
    ``v`` fill the lower triangle (diagonal included) in
    ``torch.tril_indices`` order, and are then mirrored into the upper
    triangle.

    Args:
        v: Tensor of shape ``(..., n * (n + 1) // 2)``. Leading dimensions
            are treated as batch dimensions.
        n: Side length of the output matrix.
        device: Target device; defaults to ``v.device``.
        dtype: Target dtype; defaults to ``v.dtype``.

    Returns:
        A symmetric tensor of shape ``(..., n, n)``.

    Raises:
        ValueError: If the last dimension of ``v`` is not ``n * (n + 1) // 2``.
    """
    device = device if device is not None else v.device
    dtype = dtype if dtype is not None else v.dtype

    expected = n * (n + 1) // 2
    if v.dim() == 0 or v.shape[-1] != expected:
        raise ValueError(
            f"devectorize expects v with last dimension n*(n+1)/2 = {expected}, "
            f"got shape {tuple(v.shape)}"
        )
    # index_put requires the source to match the destination's dtype/device,
    # so convert explicitly; otherwise the dtype/device arguments would fail.
    v = v.to(device=device, dtype=dtype)

    L = torch.zeros((*v.shape[:-1], n, n), device=device, dtype=dtype)
    idx = torch.tril_indices(n, n, device=device)
    L[..., idx[0], idx[1]] = v
    # Mirror only the strictly-lower part so the diagonal is counted once.
    return L + torch.tril(L, diagonal=-1).transpose(-2, -1)

# 3) Matrix square root for SPD
def matrix_sqrt_spd(M: torch.Tensor) -> torch.Tensor:
    """Return the principal square root of a symmetric positive (semi-)definite matrix.

    Uses the eigendecomposition ``M = V diag(w) V^H`` and returns
    ``V diag(sqrt(w)) V^H``.

    Args:
        M: Symmetric (or complex Hermitian) PSD matrix of shape ``(n, n)``,
            or a batch of shape ``(..., n, n)``. As with
            ``torch.linalg.eigh``, only the lower triangle is read.

    Returns:
        Tensor of the same shape as ``M`` whose square is (numerically) ``M``.
    """
    eigvals, eigvecs = torch.linalg.eigh(M)
    # Round-off can yield tiny negative eigenvalues for (near-)singular PSD
    # input; clamp them so sqrt does not produce NaN for the whole result.
    sqrt_vals = torch.sqrt(eigvals.clamp_min(0))
    # Scale the columns of V by sqrt(w) via broadcasting (== V @ diag(sqrt(w)));
    # the conjugate transpose of the last two dims works for batches and for
    # complex Hermitian input.
    return (eigvecs * sqrt_vals.unsqueeze(-2)) @ eigvecs.transpose(-2, -1).conj()

# 4) Matrix logarithm for SPD 
def matrix_log_spd(M: torch.Tensor) -> torch.Tensor:
    """Return the principal matrix logarithm of a symmetric positive-definite matrix.

    Uses the eigendecomposition ``M = V diag(w) V^H`` and returns
    ``V diag(log(w)) V^H``.

    Args:
        M: Symmetric (or complex Hermitian) positive-definite matrix of shape
            ``(n, n)``, or a batch of shape ``(..., n, n)``. As with
            ``torch.linalg.eigh``, only the lower triangle is read.

    Returns:
        Tensor of the same shape as ``M``. Non-positive eigenvalues are not
        guarded against; they yield ``-inf``/NaN entries via ``torch.log``.
    """
    eigvals, eigvecs = torch.linalg.eigh(M)
    log_vals = torch.log(eigvals)
    # Scale the columns of V by log(w) via broadcasting (== V @ diag(log(w)));
    # the conjugate transpose of the last two dims works for batches and for
    # complex Hermitian input.
    return (eigvecs * log_vals.unsqueeze(-2)) @ eigvecs.transpose(-2, -1).conj()
