In [None]:
# Don't run this if you already have an environment set up with torch.

!pip install torch

In [1]:
import torch

In [16]:
def data_aware_low_rank(
    A: torch.Tensor,
    X: torch.Tensor,
    k: int,
    tol: float = 1e-7
    ):
  """
  Compute the data-aware low-rank decomposition of a matrix by solving min_{L,R}||X(LR - A)||_F^2
  A: Input matrix of shape (d, n) to be decomposed into low-rank factors
  X: Calibration data matrix of shape (m, d) (each row is a datapoint)
  k: Target rank
  tol: Singular values of X below this threshold are not inverted (treated as zero)

  Returns {"L": L, "R": R} with L of shape (d, k') and R of shape (k', n), where
  k' = min(k, m, d, n); in the usual case k <= min(m, d, n) this is simply k.

  Method: take the thin SVD X = U diag(s) V^T (r = min(m, d) components). Choosing
  L = V diag(s)^+ Ub_k makes X L = U Ub_k, so the problem reduces to the best
  rank-k approximation of U^T (X A). That is its truncated SVD
  Ub_k diag(sb_k) Vb_k^T (Eckart-Young), which gives R = diag(sb_k) Vb_k^T.
  """

  m, d = X.shape        # No. of datapoints, dimension of each datapoint

  assert A.shape[0] == d, "Dimension mismatch between X and A!"
  assert m >= k, "The number of datapoints should be at least the target rank."

  Y = X @ A

  # Thin SVD: U is (m, r), S is (r,), Vh is (r, d) with r = min(m, d). One code
  # path covers both m <= d and m > d and never materializes a full (m, m) U.
  U, S, Vh = torch.linalg.svd(X, full_matrices=False)

  # Best rank-k approximation of the projected targets U^T Y.
  Ub, Sb, Vbh = torch.linalg.svd(U.T @ Y, full_matrices=False)

  # Pseudo-inverse of X's singular values; zeros_like keeps X's dtype and device.
  inv_sing_vals = torch.where(S >= tol, 1.0 / S, torch.zeros_like(S))

  # Slicing every factor with :k keeps L and R shape-consistent even when k
  # exceeds the number of available singular values (min(r, n)).
  L = Vh.T @ torch.diag(inv_sing_vals) @ Ub[:, :k]
  R = torch.diag(Sb[:k]) @ Vbh[:k, :]

  return {"L": L, "R": R}

In [15]:
from collections import namedtuple

def data_aware_low_rank_regH(
    A: torch.Tensor,
    H: torch.Tensor,
    k: int,
    sigma_reg: float = 1e-5
    ):
  """
  Compute the data-aware low-rank decomposition with regularized Hessian by solving min_{L,R}||(A - LR)H^{1/2}||_F^2
  A: Input matrix of shape (p, d) to be decomposed into low-rank factors
  H: Input Hessian H = X'X, symmetric of shape (d, d)
  k: Target rank
  sigma_reg: Regularization tolerance for not inverting very small eigenvalues. If the
             smallest eigenvalue of H is below sigma_reg, every eigenvalue is shifted up
             by the same amount so the smallest equals sigma_reg (i.e. H -> H + c*I).

  Returns {"L": L, "R": R} with L of shape (p, k') and R of shape (k', d), where
  k' = min(k, p, d).

  Method: with H = Q diag(lam) Q^T, Y = A H^{1/2} Q = (A Q) diag(sqrt(lam)). The
  truncated SVD U_k diag(s_k) Vh_k is the best rank-k approximation of Y
  (Eckart-Young). Mapping it back gives L = U_k and
  R = diag(s_k) Vh_k diag(lam^{-1/2}) Q^T.
  """

  assert H.shape[0] == H.shape[1] and torch.allclose(H, H.T), "Hessian is not symmetric."
  assert A.shape[1] == H.shape[0], "Dimension mismatch between A and H."

  eigvals, Q = torch.linalg.eigh(H)

  # Shift the spectrum, equivalent to H + c*I, so that the smallest eigenvalue
  # is sigma_reg. Working on the eigenvalues avoids building an identity matrix,
  # and with it any device/dtype mismatch.
  min_eig = eigvals.min()
  if min_eig < sigma_reg:
    eigvals = eigvals + (sigma_reg - min_eig)

  sqrt_eigvals = eigvals.sqrt()

  # A @ H^{1/2} @ Q == (A @ Q) * sqrt(lam) column-wise; H^{1/2} is never formed.
  Y = (A @ Q) * sqrt_eigvals
  U, Sigma, Vh = torch.linalg.svd(Y, full_matrices=False)

  L = U[:, :k]
  # diag(Sigma_k) @ Vh_k @ diag(1/sqrt(lam)) @ Q^T, with the diagonals as broadcasts.
  R = (Sigma[:k, None] * Vh[:k, :] / sqrt_eigvals) @ Q.T

  return {"L": L, "R": R}


In [20]:
torch.manual_seed(42)

# Problem sizes: d = rows of A / dimension of each datapoint, n = columns of A,
# m = number of calibration datapoints, k = target rank.
# NOTE: n == d here. That makes A.shape[1] == H.shape[0], which
# data_aware_low_rank_regH asserts; data_aware_low_rank instead needs A.shape[0] == d.
d, n, m, k = 8, 8, 10, 5

A = torch.randn(d, n)   # matrix to approximate
X = torch.randn(m, d)   # calibration data, one datapoint per row
H = X.T @ X             # Gram matrix of the calibration data (the "Hessian")

In [None]:
result = data_aware_low_rank_regH(A, H, k)
L, R = result["L"], result["R"]

# Report input and factor shapes: L has A.shape[0] rows, and R has H.shape[0] columns.
print("\n".join(f"{name}.shape: {t.shape}" for name, t in (("H", H), ("A", A), ("L", L), ("R", R))))