<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Galore_Truncated_Randomized_CUR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

SVD

In [2]:
import torch

# Generate a random 3072x3072 tensor
A = torch.randn(3072, 3072)

# Perform Singular Value Decomposition (SVD)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)

# Take the first 128 columns of U to form a 3072x128 orthogonal matrix
U_128 = U[:, :128]

# Verify orthogonality (U_128.T @ U_128 should be close to identity)
orthogonality_check = U_128.T @ U_128
identity_matrix = torch.eye(128, device=U_128.device)

print("Orthogonality check (should be close to identity):")
print(torch.allclose(orthogonality_check, identity_matrix, atol=1e-6))

print('orthogonal mat:', U_128.shape)


Orthogonality check (should be close to identity):
False
orthogonal mat: torch.Size([3072, 128])


Truncated SVD

In [7]:
import torch

# Generate a random 3072x3072 tensor
A = torch.randn(3072, 3072)

# Use PyTorch's svd_lowrank (efficient low-rank SVD)
U, S, Vh = torch.svd_lowrank(A, q=128)  # Compute only the first 128 singular values/vectors

# Truncated projection: Reduce A to 3072x128
A_128 = U @ torch.diag(S)

print("Reduced matrix shape:", A_128.shape)  # Expected: (3072, 128)

# Verify orthogonality
orthogonality_check = U.T @ U
identity_matrix = torch.eye(128, device=U.device)
print("Orthogonality check (should be close to identity):", torch.allclose(orthogonality_check, identity_matrix, atol=1e-6))



Reduced matrix shape: torch.Size([3072, 128])
Orthogonality check (should be close to identity): True


Randomized SVD

In [9]:
import torch

def randomized_svd(A, k=128, n_iter=5):
    """
    Compute the randomized SVD of matrix A.

    Args:
        A (torch.Tensor): Input matrix (m x n).
        k (int): Number of singular values/vectors to compute.
        n_iter (int): Number of power iterations (improves accuracy for structured matrices).

    Returns:
        U (torch.Tensor): Left singular vectors (m x k).
        S (torch.Tensor): Singular values (k).
        Vh (torch.Tensor): Right singular vectors (k x n).
    """
    m, n = A.shape

    # Step 1: Random Projection
    Q = torch.randn(n, k, device=A.device)  # Random Gaussian matrix
    Y = A @ Q  # Project A onto a lower-dimensional space

    # Step 2: Power Iteration (improves approximation for structured data)
    for _ in range(n_iter):
        Y = A @ (A.T @ Y)

    # Step 3: Orthonormalization (QR decomposition)
    Q, _ = torch.linalg.qr(Y)  # Q is m x k

    # Step 4: Compute SVD on the projected small matrix
    B = Q.T @ A  # k x n matrix
    U_hat, S, Vh = torch.linalg.svd(B, full_matrices=False)  # SVD of reduced matrix

    # Step 5: Compute final U
    U = Q @ U_hat  # Convert back to original space

    return U, S, Vh

# Generate a large random 3072x3072 matrix
A = torch.randn(3072, 3072, device="cuda")  # Using GPU if available

# Compute randomized SVD with 128 components
U, S, Vh = randomized_svd(A, k=128, n_iter=5)

# Verify results
print("final U shape:", U.shape)   # (3072, 128)
print("S shape:", S.shape)   # (128,)
print("Vh shape:", Vh.shape) # (128, 3072)

# Verify orthogonality (U.T @ U should be identity)
orthogonality_check = U.T @ U
identity_matrix = torch.eye(128, device=U.device)
print("Orthogonality check (should be close to identity):", torch.allclose(orthogonality_check, identity_matrix, atol=1e-6))


final U shape: torch.Size([3072, 128])
S shape: torch.Size([128])
Vh shape: torch.Size([128, 3072])
Orthogonality check (should be close to identity): False


CUR Decomposition

In [6]:
import torch

def cur_decomposition(A, k=128):
    """
    Compute the CUR decomposition of matrix A.

    Args:
        A (torch.Tensor): Input matrix (m x n).
        k (int): Number of rows and columns to select.

    Returns:
        C (torch.Tensor): Selected columns (m x k).
        U (torch.Tensor): Low-rank representation (k x k).
        R (torch.Tensor): Selected rows (k x n).
    """
    m, n = A.shape

    # Compute column selection probabilities (based on squared column norms)
    col_norms = torch.norm(A, dim=0) ** 2
    col_probs = col_norms / torch.sum(col_norms)

    # Select k columns based on probabilities
    col_indices = torch.multinomial(col_probs, k, replacement=False)
    C = A[:, col_indices]

    # Compute row selection probabilities (based on squared row norms)
    row_norms = torch.norm(A, dim=1) ** 2
    row_probs = row_norms / torch.sum(row_norms)

    # Select k rows based on probabilities
    row_indices = torch.multinomial(row_probs, k, replacement=False)
    R = A[row_indices, :]

    # Compute U (pseudo-inverse of intersection submatrix)
    W = A[row_indices[:, None], col_indices]  # Intersection submatrix
    U = torch.linalg.pinv(W)  # Compute pseudo-inverse

    return C, U, R

# Generate a random 3072x3072 matrix
A = torch.randn(3072, 3072, device="cuda")  # Using GPU if available

# Perform CUR decomposition with 128 selected columns/rows
C, U, R = cur_decomposition(A, k=128)

# Verify shapes
print("C final shape:", C.shape)   # Expected: (3072, 128)
print("U shape:", U.shape)   # Expected: (128, 128)
print("R shape:", R.shape)   # Expected: (128, 3072)

# Check reconstruction error
A_approx = C @ U @ R
reconstruction_error = torch.norm(A - A_approx) / torch.norm(A)
print("Relative reconstruction error:", reconstruction_error.item())


C shape: torch.Size([3072, 128])
U shape: torch.Size([128, 128])
R shape: torch.Size([128, 3072])
Relative reconstruction error: 9.713813781738281
orthogonal mat: torch.Size([3072, 128])
