In [22]:
import torch
import numpy as np

SEED = 42  # seed for torch's global RNG, so the random tensors drawn below repeat across runs
_ = torch.manual_seed(SEED)  # bind the returned Generator to `_` so the cell does not echo it

In [41]:
# Build a deliberately rank-deficient weight matrix: the product of a (d, rank)
# and a (rank, k) Gaussian matrix has rank at most `rank`, far below the
# maximum possible rank min(d, k).
d = k = 10
rank = 2

left_factor = torch.randn(d, rank)
right_factor = torch.randn(rank, k)
W = left_factor @ right_factor

print(W)

tensor([[-0.2372,  1.5820,  1.9616,  2.2083,  1.6262,  1.6086,  0.5403,  1.8828,
          0.5303,  0.3491],
        [-0.4614,  2.0330,  3.3893,  4.6533,  2.7196,  2.9572,  1.2604,  2.7854,
          0.9951,  0.7921],
        [ 0.3682, -0.5472, -2.2663, -4.0833, -1.7140, -2.1839, -1.2218, -1.3201,
         -0.7569, -0.7488],
        [-0.0654, -0.3721,  0.2107,  0.8858,  0.1049,  0.3106,  0.3110, -0.1598,
          0.1180,  0.1837],
        [ 0.1002, -0.6852, -0.8353, -0.9266, -0.6940, -0.6821, -0.2247, -0.8095,
         -0.2245, -0.1456],
        [-0.0090, -0.0484,  0.0303,  0.1213,  0.0157,  0.0434,  0.0423, -0.0196,
          0.0164,  0.0250],
        [ 0.2231, -1.3954, -1.8071, -2.1085, -1.4901, -1.4976, -0.5267, -1.6932,
         -0.4955, -0.3384],
        [ 0.0945, -0.2484, -0.6256, -1.0108, -0.4857, -0.5781, -0.2919, -0.4294,
         -0.1980, -0.1804],
        [ 0.1641, -1.8107, -1.6495, -1.2820, -1.4292, -1.2304, -0.2302, -1.9041,
         -0.3918, -0.1640],
        [ 0.4905, -

In [42]:
# Thin SVD of W so that W = U @ diag(S) @ V^T.
# torch.svd is deprecated; torch.linalg.svd is its replacement. With
# full_matrices=False it matches torch.svd's default reduced decomposition, but
# it returns V already transposed (Vh). Convert back so `V` keeps the same
# meaning for the cells below.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.mH

print(f"U : {U.shape}")
print(f"S : {S.shape}")
print(f"V : {V.shape}")

U : torch.Size([10, 10])
S : torch.Size([10])
V : torch.Size([10, 10])


In [43]:
# rank factorization
# Truncate the SVD to its leading `rank` singular triplets. Singular values are
# returned in descending order, so the first `rank` entries are the dominant ones.
U_r = U[:, :rank]       # first `rank` left singular vectors, as columns
S_r = S[:rank].diag()   # (rank, rank) diagonal matrix of the top singular values
V_r = V[:, :rank].T     # first `rank` right singular vectors, as rows

print(f"U_r : {U_r.shape}")
print(f"S_r : {S_r.shape}")
print(f"V_r : {V_r.shape}")

U_r : torch.Size([10, 2])
S_r : torch.Size([2, 2])
V_r : torch.Size([2, 10])


In [44]:
# Split the truncated SVD into two factors so that W ≈ B @ A:
# the right factor is V_r as-is, and the singular values are folded into the left factor.
A = V_r
B = torch.matmul(U_r, S_r)

print(f"B : {B.shape}")
print(f"A : {A.shape}")

B : torch.Size([10, 2])
A : torch.Size([2, 10])


In [45]:
# Random bias and input for the linear map y = W @ x + bias.
# W has shape (d, k): the bias (like the output) has length d, while the
# input must have length k. Sampling x with length d only worked because d == k.
bias = torch.randn(d)
x = torch.randn(k)

In [46]:
# Dense forward pass vs. the factored one. Applying A first and then B
# (B @ (A @ x)) never materialises the full product B @ A. Avoiding that is the
# point of the low-rank factorisation: about rank*(d + k) multiply-adds instead of d*k.
y = W @ x + bias
y_prime = B @ (A @ x) + bias

In [47]:
# Show both outputs, then check numerically that they agree rather than
# relying on eyeballing the printed tensors.
print(f"Original Output : \n{y}")
print(f"Low Rank Output : \n{y_prime}")

# W was constructed with exactly `rank` nonzero singular values, so the
# truncated factorisation should differ only by floating-point round-off.
# The tolerances leave headroom for that.
max_abs_err = (y - y_prime).abs().max().item()
print(f"Max absolute difference : {max_abs_err:.2e}")
assert torch.allclose(y, y_prime, rtol=1e-4, atol=1e-4), "low-rank output diverges from original"

Original Output : 
tensor([ 1.8099,  3.7031, -1.4406,  0.4295, -1.1809,  0.4950, -1.3792, -0.5206,
        -0.6639, -4.0201])
Low Rank Output : 
tensor([ 1.8099,  3.7031, -1.4406,  0.4295, -1.1809,  0.4950, -1.3792, -0.5206,
        -0.6639, -4.0201])


In [48]:
# Storage comparison: the dense W holds d*k values, the two factors rank*(d + k).
n_params_full = W.numel()
n_params_low_rank = B.numel() + A.numel()
print(f"Total parameters in original : {n_params_full}")
print(f"Total parameters in B and A : {n_params_low_rank}")

Total parameters in original : 100
Total parameters in B and A : 40
