In [2]:
import torch
import numpy as np

# Seed torch's global RNG with a fixed value so the random tensors drawn in the
# following cells are the same on every run. The return value is bound to `_`,
# presumably to keep the cell from echoing it as output.
_ = torch.manual_seed(0)

In [3]:
# Build a deliberately rank-deficient weight matrix: the product of a
# (d x w_rank) factor and a (w_rank x k) factor has rank at most w_rank.
d = k = 10
w_rank = 2
w = torch.matmul(torch.randn(d, w_rank), torch.randn(w_rank, k))

In [4]:
# Measure the numerical rank of w with torch's own routine, so the tensor is
# not implicitly converted to a NumPy array (that conversion fails for CUDA or
# grad-tracking tensors). Converting to a plain int keeps w_rank usable as a
# slice bound in later cells.
w_rank = int(torch.linalg.matrix_rank(w))
print(f'Rank of w: {w_rank}')

Rank of w: 2


In [5]:
# Perform a reduced SVD on W (W = U @ diag(S) @ V^T).
# torch.linalg.svd replaces the deprecated torch.svd and returns V^T (Vh)
# directly, so no extra transpose is needed to obtain the right factor.
U, S, Vh = torch.linalg.svd(w, full_matrices=False)
V = Vh.t()  # right singular vectors as columns, for code that expects `V`

# Keep only the leading w_rank singular values and their vectors.
U_r = U[:, :w_rank]
S_r = torch.diag(S[:w_rank])
V_r = Vh[:w_rank, :]

# Fold the singular values into the left factor so that W is approximated by B @ A.
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')

Shape of B: torch.Size([10, 2])
Shape of A: torch.Size([2, 10])


In [6]:
# Compare the output of a linear map computed with the full matrix W against
# the same map computed with the low-rank factorisation B @ A.
bias = torch.randn(d)
# w has shape (d, k), so the input vector must have k entries. Sizing it with
# d only worked because d == k.
x = torch.randn(k)

y = w @ x + bias

y_prime = (B @ A) @ x + bias

# Verify the match programmatically instead of relying on the printout alone;
# the absolute tolerance absorbs floating-point round-off from the reconstruction.
assert torch.allclose(y, y_prime, atol=1e-4), "B @ A does not reproduce W @ x"

print("Original y using W:\n", y)
print("")
print("y' computed using BA:\n", y_prime)

Original y using W:
 tensor([ 7.2684e+00,  2.3162e+00,  7.7151e+00, -1.0446e+01, -8.1639e-03,
        -3.7270e+00, -1.1146e+01,  2.0207e+00, -9.6258e+00, -4.1163e+00])

y' computed using BA:
 tensor([ 7.2684e+00,  2.3162e+00,  7.7151e+00, -1.0446e+01, -8.1638e-03,
        -3.7270e+00, -1.1146e+01,  2.0207e+00, -9.6258e+00, -4.1163e+00])


In [8]:
# Report how many scalars are stored for the dense matrix versus the two
# low-rank factors combined.
print("Total parameters of W: ", w.numel())
print("Total parameters of B and A: ", B.numel() + A.numel())

Total parameters of W:  100
Total parameters of B and A:  40
