# Singular Value Decomposition

In [1]:
import torch
import numpy as np
_ = torch.manual_seed(0)

Generate a rank-deficient matrix W

In [5]:
# Dimensions of W: d rows by k columns.
d = 10
k = 10

# Multiplying a (d x r) matrix by an (r x k) matrix gives a (d x k) matrix whose
# rank is at most r, so a small r yields a rank-deficient W.
W_rank = 2
W = torch.matmul(torch.randn(d, W_rank), torch.randn(W_rank, k))
print(W)

tensor([[ 0.2355,  0.1049,  0.0277,  0.2717, -0.1069,  0.3124,  0.2668,  0.1211,
         -0.9132,  0.2465],
        [ 0.8507,  0.5331,  0.0902,  0.5400, -0.4245,  1.0022,  1.0993,  0.4476,
         -3.0183,  1.0352],
        [-0.2607, -0.4673, -0.0085,  0.7045,  0.2058, -0.0584, -0.6039, -0.1573,
          0.3727, -0.6024],
        [-0.3829, -0.9508,  0.0042,  1.7915,  0.3681,  0.1304, -1.1195, -0.2485,
          0.0672, -1.1328],
        [ 0.1658,  1.1941, -0.0511, -3.0148, -0.3542, -0.6964,  1.1723,  0.1595,
          1.3917,  1.2245],
        [ 0.6174,  0.9482,  0.0301, -1.2144, -0.4478,  0.2682,  1.2910,  0.3620,
         -1.1710,  1.2778],
        [-0.8711, -0.3998, -0.1016, -0.9711,  0.3983, -1.1457, -0.9972, -0.4486,
          3.3560, -0.9229],
        [-0.7299, -0.7823, -0.0569,  0.4664,  0.4451, -0.5942, -1.2286, -0.4055,
          1.9997, -1.1929],
        [-0.5895, -0.9108, -0.0284,  1.1752,  0.4290, -0.2516, -1.2375, -0.3460,
          1.1083, -1.2252],
        [ 0.7297,  

Evaluate the rank of the matrix W

In [6]:
# Estimate the numerical rank of W with PyTorch's own routine rather than
# np.linalg.matrix_rank: the NumPy call relies on an implicit tensor -> ndarray
# conversion, which fails for tensors on a GPU or tensors that require grad.
# int(...) turns the 0-dim result tensor into a plain Python int so that W_rank
# can be used directly as a slice bound in later cells.
W_rank = int(torch.linalg.matrix_rank(W))
print(f'Rank of W: {W_rank}')

Rank of W: 2


Compute the singular value decomposition (SVD) of the matrix W.

In [8]:
# Perform a reduced SVD on W (W = U @ diag(S) @ Vh).
# torch.svd is deprecated; torch.linalg.svd is its replacement. Note that it
# returns Vh (V already transposed) rather than V, and full_matrices=False
# matches the reduced decomposition that torch.svd computed by default.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.t()  # keep V defined as before (W is real here, so a plain transpose suffices)

# For a rank-r factorization, keep only the first r singular values together
# with the corresponding columns of U and rows of Vh.
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = Vh[:W_rank, :]  # rows of Vh already have the (r, k) layout; no transpose needed

# Compute B = U_r @ S_r and A = V_r, so that B @ A reconstructs W
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')

Shape of B: torch.Size([10, 2])
Shape of A: torch.Size([2, 10])


Given the same input, compare the output computed with the original matrix W against the output computed with the factors B and A obtained from the decomposition.

In [9]:
# Generate random bias and input.
# W has shape (d, k), so the output (and therefore the bias) has d entries while
# the input x must have k entries. Sizing x with d only worked because d == k.
bias = torch.randn(d)
x = torch.randn(k)

# Compute y = Wx + bias
y = W @ x + bias
# Compute y' = (B*A)x + bias
y_prime = (B @ A) @ x + bias

print("Original y using W:\n", y)
print("")
print("y' computed using BA:\n", y_prime)

Original y using W:
 tensor([ 0.1179, -3.5418,  2.4757,  5.2276, -5.0896, -6.7406,  5.7523,  5.8828,
         5.4308, -5.6584])

y' computed using BA:
 tensor([ 0.1179, -3.5418,  2.4757,  5.2276, -5.0896, -6.7406,  5.7523,  5.8828,
         5.4308, -5.6584])


In [10]:
# Compare storage cost: the dense matrix W versus its two factors B and A.
dense_param_count = W.numel()
factored_param_count = B.numel() + A.numel()
print("Total parameters of W: ", dense_param_count)
print("Total parameters of B and A: ", factored_param_count)

Total parameters of W:  100
Total parameters of B and A:  40
