In [1]:
import torch
import numpy as np

### Generate a rank-deficient matrix W

In [2]:
d, k = 10, 10
W_rank = 2
W = torch.randn(d, W_rank) @ torch.randn(W_rank, k)
print(W.shape, W.dtype)
W

torch.Size([10, 10]) torch.float32


tensor([[-0.2690,  1.4512,  3.2885,  0.9916, -0.3192, -0.5900,  0.8768, -0.2148,
          0.8403, -1.7063],
        [-1.0573, -0.9110,  0.6831, -1.8105, -0.7579, -1.0126,  1.4853,  0.7139,
         -0.2056,  1.2645],
        [ 0.3463,  1.1518,  1.3556,  1.3295,  0.1842,  0.1631, -0.2336, -0.4349,
          0.5200, -1.4426],
        [ 1.5132,  3.8337,  3.7048,  4.7743,  0.8948,  0.9494, -1.3756, -1.6176,
          1.6362, -4.8585],
        [ 0.0119, -0.8964, -1.6858, -0.7620,  0.0766,  0.1905, -0.2855,  0.2055,
         -0.4786,  1.0784],
        [ 0.4635, -0.4313, -1.8368,  0.0770,  0.3946,  0.6080, -0.8974, -0.1173,
         -0.3504,  0.4466],
        [-0.8219, -2.0977, -2.0409, -2.6064, -0.4848, -0.5126,  0.7426,  0.8822,
         -0.8969,  2.6575],
        [-0.9006, -2.4675, -2.5487, -3.0019, -0.5186, -0.5284,  0.7637,  1.0065,
         -1.0723,  3.1155],
        [-2.0535, -1.7387,  1.3834, -3.4900, -1.4744, -1.9727,  2.8938,  1.3793,
         -0.3831,  2.4191],
        [ 0.0707,  

### Evaluate rank of the matrix W

In [3]:
W_rank = np.linalg.matrix_rank(W)
print(f"Rank of W: {W_rank}")

Rank of W: 2


### Calculate SVD decomposition of the W matrix

In [4]:
U, S, V = torch.svd(W)
print(f"U shape: {U.shape}, S shape: {S.shape}, V shape: {V.shape}")
print(f"U rank: {np.linalg.matrix_rank(U)}, S rank: {np.linalg.matrix_rank(S)}, V rank: {np.linalg.matrix_rank(V)}")

U shape: torch.Size([10, 10]), S shape: torch.Size([10]), V shape: torch.Size([10, 10])
U rank: 10, S rank: 1, V rank: 10


#### For rank-r factorization, keep only the first r singular values (corresponding columns of U and V)

In [5]:
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = V[:, :W_rank].t()
print(f"U_r shape: {U_r.shape}, S_r shape: {S_r.shape}, V_r shape: {V_r.shape}")
print(f"U_r rank: {np.linalg.matrix_rank(U_r)}, S_r rank: {np.linalg.matrix_rank(S_r)}, V_r rank: {np.linalg.matrix_rank(V_r)}")

U_r shape: torch.Size([10, 2]), S_r shape: torch.Size([2, 2]), V_r shape: torch.Size([2, 10])
U_r rank: 2, S_r rank: 2, V_r rank: 2


#### Creating the A and B matrices

In [6]:
B = U_r @ S_r
A = V_r
print(f"B shape: {B.shape}, A shape: {A.shape}")
print(f"B rank: {np.linalg.matrix_rank(B)}, A rank: {np.linalg.matrix_rank(A)}")

B shape: torch.Size([10, 2]), A shape: torch.Size([2, 10])
B rank: 2, A rank: 2


### For the same input, check whether the original matrix and matrices after decomposition are giving the same result

In [7]:
bias = torch.randn(d, 1)
x = torch.randn(d, 1)
print(f"bias shape: {bias.shape}, x shape: {x.shape}")

bias shape: torch.Size([10, 1]), x shape: torch.Size([10, 1])


#### Compute y = Wx + bias

In [8]:
y = W @ x + bias
print(f"Output from original W: {y}")

Output from original W: tensor([[ 1.5075],
        [ 1.4519],
        [-0.6265],
        [-3.7423],
        [-1.2784],
        [-0.2425],
        [ 1.5680],
        [ 2.8867],
        [ 2.1521],
        [ 0.1100]])


In [9]:
y_prime = (B @ A) @ x + bias
print(f"Output from decomposed W: {y_prime}")

Output from decomposed W: tensor([[ 1.5075],
        [ 1.4519],
        [-0.6265],
        [-3.7423],
        [-1.2784],
        [-0.2425],
        [ 1.5680],
        [ 2.8867],
        [ 2.1521],
        [ 0.1100]])


In [10]:
print(f"Total parameters in W: {W.nelement()}")
print(f"Total parameters in B and A: {B.nelement() + A.nelement()}")

Total parameters in W: 100
Total parameters in B and A: 40
