# Singular Value Decomposition

In [3]:
import pandas as pd
import torch
import numpy as np

In [5]:
# Fix the global PyTorch RNG seed so the random tensors generated below are
# reproducible. The return value is bound to `_` so it is not echoed as output.
SEED = 0
_ = torch.manual_seed(SEED)

Generate a rank-deficient matrix W

In [20]:
# W will have shape (d, k).
d = 10
k = 10

# Build a rank-deficient matrix on purpose.
#
# The rank of a matrix is the number of linearly independent rows it contains,
# i.e. the number of non-zero rows left after reducing it to row echelon form.
# The rank tells us how many independent directions the matrix can span in a
# vector space (higher rank = more independent information in the matrix).
#
# Multiplying a (d x W_rank) matrix by a (W_rank x k) matrix gives a (d x k)
# matrix whose rank is at most W_rank.
W_rank = 2
left_factor = torch.randn(d, W_rank)
right_factor = torch.randn(W_rank, k)
W = left_factor @ right_factor

Evaluate the rank of the matrix W

In [21]:
# Numerically estimate the rank of W. torch.linalg.matrix_rank works on the
# tensor directly, so there is no implicit tensor -> NumPy conversion (which
# fails for tensors on a GPU or tensors that require grad).
# .item() turns the 0-d result tensor into a plain Python int for printing.
W_rank_calculated = torch.linalg.matrix_rank(W).item()
print(f"Rank of W: {W_rank_calculated}")

Rank of W: 2


Calculate the singular value decomposition (SVD) of the matrix W

In [27]:
# Perform a reduced SVD on W: W = U @ diag(S) @ Vh, where Vh is V^T.
# torch.linalg.svd replaces the deprecated torch.svd; unlike torch.svd it
# returns Vh (already transposed) rather than V.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.t()  # keep V defined with the same meaning torch.svd gave it

# For a rank-r factorization, keep only the first r singular values (returned
# in descending order) with the matching columns of U and rows of Vh.
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = Vh[:W_rank, :]  # shape (W_rank, k); same as V[:, :W_rank].t()

# Compute B = U_r @ S_r and A = V_r, so that B @ A reconstructs W
B = U_r @ S_r
A = V_r
print(f"Shape of B: {B.shape}")
print(f"Shape of A: {A.shape}")

Shape of B: torch.Size([10, 2])
Shape of A: torch.Size([2, 10])


Given the same input, check the output using the original W matrix and the matrices resulting from the decomposition

In [31]:
# Generate a random bias and a random input.
# W has shape (d, k): the bias is added to W @ x, so it has d entries, while
# the input x must have k entries. (Using d for x only works when d == k.)
bias = torch.randn(d)
x = torch.randn(k)

# Compute y = Wx + b
y = W @ x + bias
# Compute y' = BAx + b
y_prime = (B @ A) @ x + bias

# Both outputs should agree up to floating-point error; fail loudly otherwise
# instead of relying on a visual comparison of the printed tensors.
assert torch.allclose(y, y_prime, rtol=1e-4, atol=1e-4), "B @ A does not reproduce W"

print(f"Original y using W: \n {y}")
print()
print(f"y' using BA: \n {y_prime}")

Original y using W: 
 tensor([ 2.6094,  0.0902, -3.1996,  0.8718,  3.6818,  2.6134,  3.1317, -0.5616,
         1.2925,  3.5579])

y' using BA: 
 tensor([ 2.6094,  0.0902, -3.1996,  0.8718,  3.6818,  2.6134,  3.1317, -0.5616,
         1.2925,  3.5579])


In [32]:
# Compare storage cost: element count of the full matrix W versus the combined
# element count of the two low-rank factors B and A.
params_full = W.numel()
print("Total parameters of W: ", params_full)
params_low_rank = B.numel() + A.numel()
print("Total parameters of B and A: ", params_low_rank)

Total parameters of W:  100
Total parameters of B and A:  40
