# Low-rank approximation

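Let $A$ be a $D\times D$ symmetric positive semi-definite matrix with singular value decomposition $A = U\,\Sigma\,V^\intercal$, where the singular values are sorted so that $\sigma_1 \ge \sigma_2 \ge \cdots \ge \sigma_D \ge 0$.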
By the Eckart–Young theorem, the best rank-$d$ approximation of $A$ in the Frobenius norm is
$$
    A_{(d)} = \arg\min_{B:\,{\rm rank}(B) \le d}\|A - B\|_F = U\,\Sigma_{(d)}\,V^\intercal
$$
where $\|A\|_F = \sqrt{{\rm Tr}(A\,A^\intercal)}$ is the Frobenius norm and
$\Sigma_{(d)} = {\rm diag}(\sigma_1, \ldots, \sigma_d, 0, \ldots, 0)$
is the diagonal matrix that keeps the $d$ largest singular values and sets the remaining $D-d$ to zero.

In [1]:
import jax
import jax.numpy as jnp

In [2]:
# Render any inline matplotlib figures at high resolution for HiDPI ("retina")
# displays; it has no effect on the array outputs shown in this notebook.
%config InlineBackend.figure_format = "retina"

In [3]:
# Fixed seed (named instead of a bare magic number) so that
# Restart & Run All reproduces the printed arrays below.
SEED = 314
key = jax.random.PRNGKey(SEED)

In [4]:
# Ambient dimension D and factor rank d.
D = 10
d = 3

# One draw of shape (2, d, D), unpacked along the leading axis into two d×D factors.
W, L = jax.random.normal(key, shape=(2, d, D))

In [5]:
# Compact array printing: 5 decimals, wide lines, fixed-point rather than
# scientific notation so the matrices below line up column by column.
jnp.set_printoptions(
    suppress=True,
    linewidth=200,
    precision=5,
)

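## Recovering a low-rank decomposition

Stacking the factors as $Z = \begin{bmatrix} W \\ L \end{bmatrix}$ (a $2d \times D$ matrix) gives $F = W^\intercal W + L^\intercal L = Z^\intercal Z$, so $F$ is positive semi-definite with rank at most $2d$.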

In [6]:
# F is the sum of the Gram matrices of the two d×D factors W and L,
# so it is D×D with rank at most 2d.
F = jnp.matmul(W.T, W) + jnp.matmul(L.T, L)
F

Array([[ 3.43371,  0.52476, -0.34925, -2.74982,  2.87597,  2.03692, -1.51675, -0.6612 , -1.81309, -4.35611],
       [ 0.52476, 13.78217, -5.23019,  6.06569, -0.97464,  2.33045,  2.44654, -1.14212,  2.84491,  2.45289],
       [-0.34925, -5.23019,  4.08506, -1.85839,  0.99625, -0.93045, -1.50251,  2.7351 , -2.19246,  0.21692],
       [-2.74982,  6.06569, -1.85839, 11.78426, -3.83138,  3.53937, -2.26391, -0.98846,  4.79112,  5.58263],
       [ 2.87597, -0.97464,  0.99625, -3.83138,  4.42151, -0.03921, -0.81   ,  0.72523, -2.93381, -6.41965],
       [ 2.03692,  2.33045, -0.93045,  3.53937, -0.03921,  4.9531 , -3.54693, -1.9295 ,  1.35893, -0.54921],
       [-1.51675,  2.44654, -1.50251, -2.26391, -0.81   , -3.54693,  5.01821, -0.21129,  0.85256,  0.73781],
       [-0.6612 , -1.14212,  2.7351 , -0.98846,  0.72523, -1.9295 , -0.21129,  4.58182, -3.25161,  2.97801],
       [-1.81309,  2.84491, -2.19246,  4.79112, -2.93381,  1.35893,  0.85256, -3.25161,  4.67505,  1.57064],
       [-4.35611,  

In [7]:
# Stack the two factors row-wise into a single (2d × D) matrix Z.
Z = jnp.r_[W, L]
ZtZ = Z.T @ Z

# Zᵀ Z must reproduce F. Check it numerically rather than by eye: the printed
# arrays are wide and get truncated. The tolerance assumes float32 matmuls at
# full precision, as in the outputs shown here.
assert jnp.allclose(ZtZ, F, rtol=1e-4, atol=1e-4), "Z.T @ Z does not match F"
ZtZ

Array([[ 3.43371,  0.52476, -0.34925, -2.74982,  2.87597,  2.03692, -1.51675, -0.6612 , -1.81309, -4.35611],
       [ 0.52476, 13.78217, -5.23019,  6.06569, -0.97464,  2.33045,  2.44654, -1.14212,  2.84491,  2.45289],
       [-0.34925, -5.23019,  4.08506, -1.85839,  0.99625, -0.93045, -1.50251,  2.7351 , -2.19246,  0.21692],
       [-2.74982,  6.06569, -1.85839, 11.78426, -3.83138,  3.53937, -2.26391, -0.98846,  4.79112,  5.58263],
       [ 2.87597, -0.97464,  0.99625, -3.83138,  4.42151, -0.03921, -0.81   ,  0.72523, -2.93381, -6.41965],
       [ 2.03692,  2.33045, -0.93045,  3.53937, -0.03921,  4.9531 , -3.54693, -1.9295 ,  1.35893, -0.54921],
       [-1.51675,  2.44654, -1.50251, -2.26391, -0.81   , -3.54693,  5.01821, -0.21129,  0.85256,  0.73781],
       [-0.6612 , -1.14212,  2.7351 , -0.98846,  0.72523, -1.9295 , -0.21129,  4.58182, -3.25161,  2.97801],
       [-1.81309,  2.84491, -2.19246,  4.79112, -2.93381,  1.35893,  0.85256, -3.25161,  4.67505,  1.57064],
       [-4.35611,  

## Explicit low-rank approximation by SVD — three slices, same results

In [8]:
# Same F as above, recomputed so this section can be run on its own.
F = jnp.matmul(W.T, W) + jnp.matmul(L.T, L)

### Slicing the left singular vectors

In [9]:
u, s, vh = jnp.linalg.svd(F, full_matrices=False)

# Truncate via the *left* factor: keep the first d columns of U and the first
# d rows of diag(s). Those rows are zero beyond column d, so only the top-d
# rows of Vᵀ contribute to the product.
u[:, :d] @ jnp.diag(s)[:d, :] @ vh

Array([[ 2.3883 , -0.3611 , -0.35461, -2.02488,  2.6544 ,  1.16671, -0.92652, -1.37972, -0.73123, -5.39314],
       [-0.3611 , 12.1414 , -5.95995,  5.95702, -2.04106,  1.80041,  3.17812, -3.10316,  4.50432,  1.77073],
       [-0.35461, -5.95995,  3.0545 , -2.2618 ,  0.3656 , -0.97916, -1.55652,  1.80161, -1.99234,  0.38266],
       [-2.02488,  5.95701, -2.2618 , 10.61903, -4.42717,  4.39949, -2.38358, -1.65962,  4.814  ,  6.45695],
       [ 2.6544 , -2.04105,  0.3656 , -4.42717,  3.4902 ,  0.02296, -0.40818, -0.79563, -1.89074, -6.4475 ],
       [ 1.1667 ,  1.80041, -0.97916,  4.39949,  0.02296,  4.077  , -3.22864, -2.12922,  1.88739, -1.45299],
       [-0.92652,  3.17812, -1.55652, -2.38358, -0.40818, -3.22864,  4.4876 ,  0.64199, -0.18068,  1.27827],
       [-1.37972, -3.10316,  1.80161, -1.65962, -0.79563, -2.12922,  0.64199,  1.96664, -1.21934,  2.55965],
       [-0.73123,  4.50432, -1.99234,  4.814  , -1.89074,  1.88739, -0.18068, -1.21934,  2.55251,  2.48996],
       [-5.39314,  

### Slicing the right singular vectors

In [10]:
# Recomputed so this cell runs on its own.
u, s, vh = jnp.linalg.svd(F, full_matrices=False)

# Truncate via the *right* factor: keep all of U, the first d columns of
# diag(s), and the first d rows of Vᵀ.
u @ jnp.diag(s)[:, :d] @ vh[:d, :]

Array([[ 2.3883 , -0.3611 , -0.35461, -2.02488,  2.6544 ,  1.16671, -0.92652, -1.37972, -0.73123, -5.39314],
       [-0.3611 , 12.1414 , -5.95995,  5.95702, -2.04106,  1.80041,  3.17812, -3.10316,  4.50432,  1.77073],
       [-0.35461, -5.95995,  3.0545 , -2.2618 ,  0.3656 , -0.97916, -1.55652,  1.80161, -1.99234,  0.38266],
       [-2.02488,  5.95701, -2.2618 , 10.61903, -4.42717,  4.39949, -2.38358, -1.65962,  4.814  ,  6.45695],
       [ 2.6544 , -2.04105,  0.3656 , -4.42717,  3.4902 ,  0.02296, -0.40818, -0.79563, -1.89074, -6.4475 ],
       [ 1.1667 ,  1.80041, -0.97916,  4.39949,  0.02296,  4.077  , -3.22864, -2.12922,  1.88739, -1.45299],
       [-0.92652,  3.17812, -1.55652, -2.38358, -0.40818, -3.22864,  4.4876 ,  0.64199, -0.18068,  1.27827],
       [-1.37972, -3.10316,  1.80161, -1.65962, -0.79563, -2.12922,  0.64199,  1.96664, -1.21934,  2.55965],
       [-0.73123,  4.50432, -1.99234,  4.814  , -1.89074,  1.88739, -0.18068, -1.21934,  2.55251,  2.48996],
       [-5.39314,  

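### Zeroing the trailing singular values
Setting $\sigma_{d+1}, \ldots, \sigma_D$ to zero is all a low-rank approximation needs: $U\,\Sigma_{(d)}\,V^\intercal$ is exactly $A_{(d)}$.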

In [11]:
# JAX arrays are immutable: `.at[d:].set(0.0)` already returns a new array,
# so no defensive `s.copy()` is needed. Zero every singular value past the
# d-th and rebuild the diagonal matrix Σ_(d).
s_diag = jnp.diag(s.at[d:].set(0.0))
u @ s_diag @ vh

Array([[ 2.3883 , -0.3611 , -0.35461, -2.02488,  2.6544 ,  1.16671, -0.92652, -1.37972, -0.73123, -5.39314],
       [-0.3611 , 12.1414 , -5.95995,  5.95702, -2.04106,  1.80041,  3.17812, -3.10316,  4.50432,  1.77073],
       [-0.35461, -5.95995,  3.0545 , -2.2618 ,  0.3656 , -0.97916, -1.55652,  1.80161, -1.99234,  0.38266],
       [-2.02488,  5.95701, -2.2618 , 10.61903, -4.42717,  4.39949, -2.38358, -1.65962,  4.814  ,  6.45695],
       [ 2.6544 , -2.04105,  0.3656 , -4.42717,  3.4902 ,  0.02296, -0.40818, -0.79563, -1.89074, -6.4475 ],
       [ 1.1667 ,  1.80041, -0.97916,  4.39949,  0.02296,  4.077  , -3.22864, -2.12922,  1.88739, -1.45299],
       [-0.92652,  3.17812, -1.55652, -2.38358, -0.40818, -3.22864,  4.4876 ,  0.64199, -0.18068,  1.27827],
       [-1.37972, -3.10316,  1.80161, -1.65962, -0.79563, -2.12922,  0.64199,  1.96664, -1.21934,  2.55965],
       [-0.73123,  4.50432, -1.99234,  4.814  , -1.89074,  1.88739, -0.18068, -1.21934,  2.55251,  2.48996],
       [-5.39314,  

## SVD without explicitly computing $F$

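### Option 1 — explicit SVD of $Z$
Since $F = Z^\intercal Z$, if $Z = U\,{\rm diag}(s)\,V^\intercal$ then $F = V\,{\rm diag}(s)^2\,V^\intercal$. Taking $R$ as the first $d$ rows of ${\rm diag}(s)\,V^\intercal$ gives $R^\intercal R = F_{(d)}$, using only the small $2d \times D$ matrix $Z$.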

In [12]:
# Stack the factors so that F = Zᵀ Z, with Z of shape (2d, D).
Z = jnp.r_[W, L]
assert Z.shape == (2 * d, D)

# Thin SVD of the small factor: Z = U diag(S) Vh, hence F = Vhᵀ diag(S)² Vh.
Z_svd = jnp.linalg.svd(Z, full_matrices=False)
S = Z_svd.S

# R = diag(S[:d]) Vh[:d] has shape (d, D), and Rᵀ R is the rank-d truncation of F.
# Broadcasting S[:d] over the rows avoids building a (d × 2d) diagonal slice.
R = S[:d, None] * Z_svd.Vh[:d]
R

Array([[-0.93406,  2.50154, -0.99415,  2.93843, -1.59326,  0.67353,  0.19101, -0.32054,  1.5005 ,  2.53536],
       [ 1.2001 ,  1.97021, -1.23696,  0.28369,  0.97543,  1.1487 , -0.14118, -1.34327,  0.51564, -2.47216],
       [ 0.2749 , -1.41492,  0.73218,  1.37993, -0.01608,  1.51784, -2.10504, -0.24397,  0.1874 , -0.21139]], dtype=float32)

In [13]:
# Gram matrix of R: matches the rank-d truncation of F up to float32 round-off.
jnp.matmul(R.T, R)

Array([[ 2.3883 , -0.3611 , -0.35461, -2.02488,  2.6544 ,  1.16671, -0.92652, -1.37973, -0.73123, -5.39314],
       [-0.3611 , 12.14141, -5.95996,  5.95703, -2.04106,  1.80042,  3.17812, -3.10317,  4.50433,  1.77073],
       [-0.35461, -5.95996,  3.0545 , -2.2618 ,  0.3656 , -0.97916, -1.55653,  1.80161, -1.99235,  0.38266],
       [-2.02488,  5.95703, -2.2618 , 10.61904, -4.42717,  4.3995 , -2.38358, -1.65962,  4.814  ,  6.45695],
       [ 2.6544 , -2.04106,  0.3656 , -4.42717,  3.4902 ,  0.02296, -0.40818, -0.79563, -1.89074, -6.4475 ],
       [ 1.16671,  1.80042, -0.97916,  4.3995 ,  0.02296,  4.07701, -3.22864, -2.12922,  1.88739, -1.45299],
       [-0.92652,  3.17812, -1.55653, -2.38358, -0.40818, -3.22864,  4.4876 ,  0.64199, -0.18067,  1.27827],
       [-1.37973, -3.10317,  1.80161, -1.65962, -0.79563, -2.12922,  0.64199,  1.96664, -1.21934,  2.55965],
       [-0.73123,  4.50433, -1.99235,  4.814  , -1.89074,  1.88739, -0.18067, -1.21934,  2.55252,  2.48996],
       [-5.39314,  

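### Option 3 — right singular vectors via $Z\,Z^\intercal$
Decompose the $2d\times 2d$ Gram matrix $Z\,Z^\intercal = U\,{\rm diag}(s)^2\,U^\intercal$, then recover the right singular vectors as $V^\intercal = {\rm diag}(s)^{-1} U^\intercal Z$ (for nonzero $s_i$). Rows may differ in sign from Option 1, because singular vectors are only defined up to sign; this does not affect $R^\intercal R$.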

In [14]:
Z = jnp.r_[W, L]
# Z Zᵀ is a small (2d × 2d) symmetric PSD matrix, so its SVD is its
# eigendecomposition: the returned values are eigenvalues of Z Zᵀ, i.e. the
# squared singular values of Z.
singular_vectors, eigenvalues, _ = jnp.linalg.svd(Z @ Z.T, hermitian=True, full_matrices=False)
singular_values = jnp.sqrt(eigenvalues)  # singular values of Z

# Right singular vectors: Vᵀ = Σ⁻¹ Uᵀ Z. Use a pseudo-inverse of Σ (1/σ only
# where σ > 0). Otherwise a zero singular value (rank-deficient Z, e.g. when
# 2d > D) gives an inf row that becomes NaN (0 · inf) in the product below.
inv_singular_values = jnp.where(singular_values > 0, 1 / singular_values, 0.0)
lr_new = jnp.diag(inv_singular_values) @ singular_vectors.T @ Z  # solving for right singular vectors
lr_new = jnp.diag(singular_values)[:d] @ lr_new  # keeping top-d singular values
lr_new

Array([[ 0.93407, -2.50154,  0.99415, -2.93843,  1.59326, -0.67353, -0.19101,  0.32054, -1.5005 , -2.53536],
       [ 1.2001 ,  1.97021, -1.23696,  0.28369,  0.97543,  1.1487 , -0.14118, -1.34327,  0.51564, -2.47215],
       [ 0.2749 , -1.41491,  0.73218,  1.37993, -0.01608,  1.51784, -2.10504, -0.24397,  0.1874 , -0.21138]], dtype=float32)

In [15]:
# Gram matrix of the top-d factor; unaffected by the sign of any individual row.
jnp.matmul(lr_new.T, lr_new)

Array([[ 2.3883 , -0.3611 , -0.35461, -2.02488,  2.6544 ,  1.16671, -0.92652, -1.37972, -0.73123, -5.39314],
       [-0.3611 , 12.14139, -5.95995,  5.95702, -2.04106,  1.80042,  3.17812, -3.10317,  4.50433,  1.77073],
       [-0.35461, -5.95995,  3.0545 , -2.2618 ,  0.3656 , -0.97916, -1.55652,  1.80161, -1.99235,  0.38266],
       [-2.02488,  5.95702, -2.2618 , 10.61903, -4.42717,  4.39949, -2.38358, -1.65962,  4.814  ,  6.45695],
       [ 2.6544 , -2.04106,  0.3656 , -4.42717,  3.49021,  0.02296, -0.40818, -0.79563, -1.89074, -6.4475 ],
       [ 1.16671,  1.80042, -0.97916,  4.39949,  0.02296,  4.077  , -3.22864, -2.12922,  1.88739, -1.45299],
       [-0.92652,  3.17812, -1.55652, -2.38358, -0.40818, -3.22864,  4.4876 ,  0.64199, -0.18067,  1.27827],
       [-1.37972, -3.10317,  1.80161, -1.65962, -0.79563, -2.12922,  0.64199,  1.96664, -1.21934,  2.55965],
       [-0.73123,  4.50433, -1.99235,  4.814  , -1.89074,  1.88739, -0.18067, -1.21934,  2.55252,  2.48996],
       [-5.39314,  

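### Option 3.1 — right singular vectors (with early slicing)
Same as Option 3, but the top-$d$ singular values and vectors are selected by slicing rather than by multiplying with a rectangular slice of a diagonal matrix.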

In [16]:
Z = jnp.r_[W, L]
# Z Zᵀ is a small (2d × 2d) symmetric PSD matrix, so its SVD is its
# eigendecomposition; the values are the squared singular values of Z.
singular_vectors, eigenvalues, _ = jnp.linalg.svd(Z @ Z.T, hermitian=True, full_matrices=False)
singular_values = jnp.sqrt(eigenvalues)  # singular values of Z

# Slice to the top-d components *before* solving Vᵀ = Σ⁻¹ Uᵀ Z. The discarded
# (possibly zero) singular values are then never divided by, and their rows are
# never computed. This still requires the top-d values to be nonzero (rank(Z) >= d).
top_values = singular_values[:d]
right_vectors = (singular_vectors[:, :d].T @ Z) / top_values[:, None]  # V_dᵀ, shape (d, D)
lr_new = top_values[:, None] * right_vectors  # Σ_d V_dᵀ; rows may flip sign vs. Option 1
lr_new

Array([[ 0.93407, -2.50154,  0.99415, -2.93843,  1.59326, -0.67353, -0.19101,  0.32054, -1.5005 , -2.53536],
       [ 1.2001 ,  1.97021, -1.23696,  0.28369,  0.97543,  1.1487 , -0.14118, -1.34327,  0.51564, -2.47215],
       [ 0.2749 , -1.41491,  0.73218,  1.37993, -0.01608,  1.51784, -2.10504, -0.24397,  0.1874 , -0.21138]], dtype=float32)

In [17]:
# Gram matrix of the top-d factor; unaffected by the sign of any individual row.
jnp.matmul(lr_new.T, lr_new)

Array([[ 2.3883 , -0.3611 , -0.35461, -2.02488,  2.6544 ,  1.16671, -0.92652, -1.37972, -0.73123, -5.39314],
       [-0.3611 , 12.14139, -5.95995,  5.95702, -2.04106,  1.80042,  3.17812, -3.10317,  4.50433,  1.77073],
       [-0.35461, -5.95995,  3.0545 , -2.2618 ,  0.3656 , -0.97916, -1.55652,  1.80161, -1.99235,  0.38266],
       [-2.02488,  5.95702, -2.2618 , 10.61903, -4.42717,  4.39949, -2.38358, -1.65962,  4.814  ,  6.45695],
       [ 2.6544 , -2.04106,  0.3656 , -4.42717,  3.49021,  0.02296, -0.40818, -0.79563, -1.89074, -6.4475 ],
       [ 1.16671,  1.80042, -0.97916,  4.39949,  0.02296,  4.077  , -3.22864, -2.12922,  1.88739, -1.45299],
       [-0.92652,  3.17812, -1.55652, -2.38358, -0.40818, -3.22864,  4.4876 ,  0.64199, -0.18067,  1.27827],
       [-1.37972, -3.10317,  1.80161, -1.65962, -0.79563, -2.12922,  0.64199,  1.96664, -1.21934,  2.55965],
       [-0.73123,  4.50433, -1.99235,  4.814  , -1.89074,  1.88739, -0.18067, -1.21934,  2.55252,  2.48996],
       [-5.39314,  