<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Galore_Truncated_Randomized_CUR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

SVD

In [None]:
import torch
import time

# Generate a random 3072x3072 tensor
A = torch.randn(3072, 3072)

st = time.time()
# Perform Singular Value Decomposition (SVD)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)

# Take the first 128 columns of U to form a 3072x128 orthogonal matrix
U_128 = U[:, :128]
en = time.time()

print('time:', en-st)
# Verify orthogonality (U_128.T @ U_128 should be close to identity)
orthogonality_check = U_128.T @ U_128
identity_matrix = torch.eye(128, device=U_128.device)

print("Orthogonality check (should be close to identity):")
print(torch.allclose(orthogonality_check, identity_matrix, atol=1e-6))

print('orthogonal mat:', U_128.shape)


time: 10.747417211532593
Orthogonality check (should be close to identity):
False
orthogonal mat: torch.Size([3072, 128])


Truncated SVD

In [None]:
import torch

# Generate a random 3072x3072 tensor
A = torch.randn(3072, 3072)

st = time.time()
# Use PyTorch's svd_lowrank (efficient low-rank SVD)
U, S, Vh = torch.svd_lowrank(A, q=128)  # Compute only the first 128 singular values/vectors

# Truncated projection: Reduce A to 3072x128
A_128 = U @ torch.diag(S)
en = time.time()

print('time:', en-st)
print("Reduced matrix shape:", A_128.shape)  # Expected: (3072, 128)

# Verify orthogonality
orthogonality_check = U.T @ U
identity_matrix = torch.eye(128, device=U.device)
print("Orthogonality check (should be close to identity):", torch.allclose(orthogonality_check, identity_matrix, atol=1e-6))



time: 0.26743578910827637
Reduced matrix shape: torch.Size([3072, 128])
Orthogonality check (should be close to identity): True


Randomized SVD

In [None]:
import torch

def randomized_svd(A, k=128, n_iter=5):
    """
    Compute the randomized SVD of matrix A.

    Args:
        A (torch.Tensor): Input matrix (m x n).
        k (int): Number of singular values/vectors to compute.
        n_iter (int): Number of power iterations (improves accuracy for structured matrices).

    Returns:
        U (torch.Tensor): Left singular vectors (m x k).
        S (torch.Tensor): Singular values (k).
        Vh (torch.Tensor): Right singular vectors (k x n).
    """
    m, n = A.shape

    # Step 1: Random Projection
    Q = torch.randn(n, k, device=A.device)  # Random Gaussian matrix
    Y = A @ Q  # Project A onto a lower-dimensional space

    # Step 2: Power Iteration (improves approximation for structured data)
    for _ in range(n_iter):
        Y = A @ (A.T @ Y)

    # Step 3: Orthonormalization (QR decomposition)
    Q, _ = torch.linalg.qr(Y)  # Q is m x k

    # Step 4: Compute SVD on the projected small matrix
    B = Q.T @ A  # k x n matrix
    U_hat, S, Vh = torch.linalg.svd(B, full_matrices=False)  # SVD of reduced matrix

    # Step 5: Compute final U
    U = Q @ U_hat  # Convert back to original space

    return U, S, Vh

# Generate a large random 3072x3072 matrix
A = torch.randn(3072, 3072, device="cuda")  # Using GPU if available

# Compute randomized SVD with 128 components
U, S, Vh = randomized_svd(A, k=128, n_iter=5)

# Verify results
print("final U shape:", U.shape)   # (3072, 128)
print("S shape:", S.shape)   # (128,)
print("Vh shape:", Vh.shape) # (128, 3072)

# Verify orthogonality (U.T @ U should be identity)
orthogonality_check = U.T @ U
identity_matrix = torch.eye(128, device=U.device)
print("Orthogonality check (should be close to identity):", torch.allclose(orthogonality_check, identity_matrix, atol=1e-6))


final U shape: torch.Size([3072, 128])
S shape: torch.Size([128])
Vh shape: torch.Size([128, 3072])
Orthogonality check (should be close to identity): False


CUR Decomposition

In [None]:
import torch

def cur_decomposition(A, k=128):
    """
    Compute the CUR decomposition of matrix A.

    Args:
        A (torch.Tensor): Input matrix (m x n).
        k (int): Number of rows and columns to select.

    Returns:
        C (torch.Tensor): Selected columns (m x k).
        U (torch.Tensor): Low-rank representation (k x k).
        R (torch.Tensor): Selected rows (k x n).
    """
    m, n = A.shape

    # Compute column selection probabilities (based on squared column norms)
    col_norms = torch.norm(A, dim=0) ** 2
    col_probs = col_norms / torch.sum(col_norms)

    # Select k columns based on probabilities
    col_indices = torch.multinomial(col_probs, k, replacement=False)
    C = A[:, col_indices]

    # Compute row selection probabilities (based on squared row norms)
    row_norms = torch.norm(A, dim=1) ** 2
    row_probs = row_norms / torch.sum(row_norms)

    # Select k rows based on probabilities
    row_indices = torch.multinomial(row_probs, k, replacement=False)
    R = A[row_indices, :]

    # Compute U (pseudo-inverse of intersection submatrix)
    W = A[row_indices[:, None], col_indices]  # Intersection submatrix
    U = torch.linalg.pinv(W)  # Compute pseudo-inverse

    return C, U, R

# Generate a random 3072x3072 matrix
A = torch.randn(3072, 3072, device="cuda")  # Using GPU if available

# Perform CUR decomposition with 128 selected columns/rows
C, U, R = cur_decomposition(A, k=128)

# Verify shapes
print("C final shape:", C.shape)   # Expected: (3072, 128)
print("U shape:", U.shape)   # Expected: (128, 128)
print("R shape:", R.shape)   # Expected: (128, 3072)

# Check reconstruction error
A_approx = C @ U @ R
reconstruction_error = torch.norm(A - A_approx) / torch.norm(A)
print("Relative reconstruction error:", reconstruction_error.item())


C shape: torch.Size([3072, 128])
U shape: torch.Size([128, 128])
R shape: torch.Size([128, 3072])
Relative reconstruction error: 9.713813781738281
orthogonal mat: torch.Size([3072, 128])


Kerbel SVD

In [None]:
import torch

def rbf_kernel_torch(A, gamma=None):
    """
    Compute the RBF Kernel using PyTorch.

    Args:
        A (torch.Tensor): Input matrix (m x n).
        gamma (float): Kernel coefficient (default: 1 / feature_dim).

    Returns:
        K (torch.Tensor): Kernel matrix (m x m).
    """
    m, n = A.shape
    if gamma is None:
        gamma = 1.0 / n  # Default gamma = 1/n_features

    # Compute squared Euclidean distance using torch.cdist
    squared_distances = torch.cdist(A, A, p=2) ** 2

    # Compute RBF Kernel
    K = torch.exp(-gamma * squared_distances)
    return K

def kernel_svd(A, k=128):
    """
    Compute Kernel SVD (K-SVD) using the RBF kernel.

    Args:
        A (torch.Tensor): Input matrix (m x n).
        k (int): Number of singular components to keep.

    Returns:
        U_k (torch.Tensor): Kernelized left singular vectors (m x k).
        S_k (torch.Tensor): Singular values (k).
        V_k (torch.Tensor): Kernelized right singular vectors (k x n).
    """
    device = A.device

    # Step 1: Compute Kernel Matrix using RBF
    K = rbf_kernel_torch(A).to(device)  # Kernel transformation

    # Step 2: Perform SVD on Kernel Matrix
    U, S, Vh = torch.linalg.svd(K, full_matrices=False)

    # Step 3: Select top-k components
    U_k = U[:, :k]  # Left singular vectors
    S_k = S[:k]  # Singular values
    V_k = Vh[:k, :]  # Right singular vectors

    return U_k, S_k, V_k

# Generate a random 3072x3072 matrix (GPU-accelerated)
A = torch.randn(3072, 3072, device="cuda")  # Use GPU if available

# Compute Kernel SVD with 128 components
U_k, S_k, V_k = kernel_svd(A, k=128)

# Verify output shapes
print("U_k shape:", U_k.shape)  # Expected: (3072, 128)
print("S_k shape:", S_k.shape)  # Expected: (128,)
print("V_k shape:", V_k.shape)  # Expected: (128, 3072)

# Compute reconstruction error (optional)
K_approx = U_k @ torch.diag(S_k) @ V_k
error = torch.norm(K_approx - rbf_kernel_torch(A)) / torch.norm(rbf_kernel_torch(A))
print("Relative reconstruction error:", error.item())


U_k shape: torch.Size([3072, 128])
S_k shape: torch.Size([128])
V_k shape: torch.Size([128, 3072])
Relative reconstruction error: 0.11192391067743301


V1: FFT-Based Projection (Fast Fourier Transform)

In [5]:
import torch

def fft_svd_projection(A, k=128):
    """
    Approximate SVD using Fast Fourier Transform (FFT),
    projecting a 3072x3072 matrix to a lower-rank 3072x128 representation.

    Args:
        A (torch.Tensor): Input matrix (m x n).
        k (int): Number of frequency components to keep (columns in the output).

    Returns:
        A_k (torch.Tensor): 3072x128 low-rank approximation using FFT.
    """
    device = A.device
    m, n = A.shape

    # Compute 2D Fourier Transform
    A_fft = torch.fft.fft2(A)

    # Keep only the top-k frequencies along columns
    A_fft_k = torch.zeros_like(A_fft)
    A_fft_k[:, :k] = A_fft[:, :k]  # Retain low-frequency components in columns

    # Inverse FFT to reconstruct low-rank matrix
    A_k = torch.fft.ifft2(A_fft_k).real  # Take the real part after inverse FFT

    return A_k[:, :k]  # Return only the first 128 columns

# Fix the seed for reproducibility
torch.manual_seed(42)
# Generate a random 3072x3072 matrix
A = torch.randn(3072, 3072, device="cuda")  # Use GPU if available
print(A.sum())

# Compute FFT-based low-rank projection (output: 3072x128)
A_k = fft_svd_projection(A, k=128)
print(A_k.sum())

# Verify output shape
print("Low-rank projected matrix shape:", A_k.shape)  # Expected: (3072, 128)

# # Check approximation error
# reconstruction_error = torch.norm(A[:, :128] - A_k) / torch.norm(A[:, :128])
# print("Relative reconstruction error:", reconstruction_error.item())


tensor(-3362.1399, device='cuda:0')
tensor(-186.6012, device='cuda:0')
Low-rank projected matrix shape: torch.Size([3072, 128])


V2: FFT-Based Projection (Fast Fourier Transform)

In [7]:
import torch

def fft_low_rank_projection_torch(matrix: torch.Tensor, rank_k: int = 128) -> torch.Tensor:
    """
    Perform FFT-based low-rank projection from [N x N] to [N x rank_k] using PyTorch.

    Args:
        matrix (torch.Tensor): Input tensor of shape [N, N].
        rank_k (int): Rank of the low-dimensional projection (number of frequency components to keep).

    Returns:
        torch.Tensor: Projected tensor of shape [N, rank_k].
    """
    assert matrix.shape[0] == matrix.shape[1], "Input must be a square matrix"
    N = matrix.shape[0]
    assert rank_k <= N, "rank_k must be <= input dimension"

    # Step 1: FFT along the columns (dim=1)
    matrix_fft = torch.fft.fft(matrix, dim=1)

    # Step 2: Keep only the first rank_k frequency components
    matrix_fft_reduced = matrix_fft[:, :rank_k]

    # Step 3: Take real part (or use both real and imag if needed)
    matrix_proj = matrix_fft_reduced.real

    return matrix_proj

# Example usage
# Fix the seed for reproducibility
torch.manual_seed(42)
matrix = torch.randn(3072, 3072, device="cuda")  # Use GPU if available
print(matrix.sum())
matrix_projected = fft_low_rank_projection_torch(matrix, rank_k=128)
print(matrix_projected.sum())
print("Projected shape:", matrix_projected.shape)  # torch.Size([3072, 128])


tensor(-3362.1399, device='cuda:0')
tensor(-8680.4062, device='cuda:0')
Projected shape: torch.Size([3072, 128])


V3: FFT-Based Projection (Fast Fourier Transform)

In [9]:
import torch

def fft_low_rank_projection_torch(matrix: torch.Tensor, rank_k: int = 128, use_complex: bool = True) -> torch.Tensor:
    """
    Perform FFT-based low-rank projection from [N x N] to [N x rank_k] (or [N x 2*rank_k] if use_complex) using PyTorch.

    Args:
        matrix (torch.Tensor): Input tensor of shape [N, N].
        rank_k (int): Rank of the low-dimensional projection (number of frequency components to keep).
        use_complex (bool): If True, concatenate both real and imag parts (output shape: [N, 2*rank_k]).

    Returns:
        torch.Tensor: Projected tensor of shape [N, rank_k] or [N, 2*rank_k].
    """
    assert matrix.shape[0] == matrix.shape[1], "Input must be a square matrix"
    N = matrix.shape[0]
    assert rank_k <= N, "rank_k must be <= input dimension"

    # FFT along columns
    matrix_fft = torch.fft.fft(matrix, dim=1)

    # Truncate to low-rank components
    matrix_fft_reduced = matrix_fft[:, :rank_k]

    if use_complex:
        # Concatenate real and imag parts → shape: [N, 2 * rank_k]
        matrix_proj = torch.cat([matrix_fft_reduced.real, matrix_fft_reduced.imag], dim=1)
    else:
        # Use only real part → shape: [N, rank_k]
        matrix_proj = matrix_fft_reduced.real

    return matrix_proj

# Fix the seed for reproducibility
torch.manual_seed(42)
matrix = torch.randn(3072, 3072, device="cuda")  # Use GPU if available
print(matrix.sum())
matrix_projected = fft_low_rank_projection_torch(matrix, rank_k=128)
print(matrix_projected.sum())
print("Projected shape:", matrix_projected.shape)  # torch.Size([3072, 128])

tensor(-3362.1399, device='cuda:0')
tensor(-25897.6562, device='cuda:0')
Projected shape: torch.Size([3072, 256])


V4: FFT-Based Projection (Fast Fourier Transform)

In [11]:
import torch
import torch.fft as fft

def fft_low_rank_projection(matrix: torch.Tensor, k: int = 128) -> torch.Tensor:
    """
    Project a square matrix to rank-k approximation using FFT frequency component selection.

    Args:
        matrix: Input square tensor [N x N]
        k: Target rank of the output tensor [N x k]

    Returns:
        Projected tensor of size [N x k] containing the most significant frequency components
    """
    assert matrix.dim() == 2, "Input must be a 2D tensor"
    assert matrix.size(0) == matrix.size(1), "Input matrix must be square"
    assert k <= matrix.size(1), f"Target rank k={k} must be ≤ original dimension {matrix.size(1)}"

    n = matrix.size(0)
    device = matrix.device

    # 1. Compute 2D FFT
    fft_matrix = fft.fft2(matrix)

    # 2. Create frequency mask for top-k components
    center = n // 2
    half_k = k // 2

    # Calculate column ranges to keep (centered around DC component)
    start_col = center - half_k
    end_col = center + half_k + (k % 2)  # Handle odd k

    # Create and apply mask
    mask = torch.zeros_like(fft_matrix, dtype=torch.bool)
    mask[:, start_col:end_col] = True
    truncated_fft = fft_matrix * mask

    # 3. Inverse FFT and column selection
    low_rank_approx = fft.ifft2(truncated_fft).real
    return low_rank_approx[:, start_col:end_col]


# Example usage
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create test matrix
    torch.manual_seed(42)
    A = torch.randn(3072, 3072, device=device)
    print(A.sum())
    # Project to rank-k approximation
    k = 128
    A_proj = fft_low_rank_projection(A, k=k)
    print(A_proj.sum())
    print(f"Original: {A.shape} | Projected (k={k}): {A_proj.shape}")
    print(f"Norm preservation: {A_proj.norm() / A.norm():.3f}")

Using device: cuda
tensor(-3362.1399, device='cuda:0')
tensor(-0.7844, device='cuda:0')
Original: torch.Size([3072, 3072]) | Projected (k=128): torch.Size([3072, 128])
Norm preservation: 0.041


V5: Randomized FFT Low-Rank Projection

In [15]:
import torch

def randomized_fft_low_rank_projection(matrix: torch.Tensor, rank_k: int = 128) -> torch.Tensor:
    """
    Apply Randomized FFT-based projection to approximate low-rank version of a square matrix.

    Args:
        matrix (torch.Tensor): Input [n x n] matrix.
        rank_k (int): Target low-rank dimension.

    Returns:
        torch.Tensor: Projected [n x 2*rank_k] matrix (real + imag parts).
    """
    assert matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1], "matrix must be square"

    n = matrix.shape[0]
    assert rank_k <= n, "rank_k must be less than or equal to matrix dimension"

    # Step 1: FFT along columns
    matrix_fft = torch.fft.fft(matrix, dim=1)  # shape: [n, n], dtype: complex

    # Step 2: Random Gaussian projection, converted to complex type
    Omega = torch.randn(n, rank_k, device=matrix.device).to(dtype=matrix_fft.dtype)
    Y = matrix_fft @ Omega  # shape: [n, rank_k], dtype: complex

    # Step 3: Concatenate real and imaginary parts
    Y_proj = torch.cat([Y.real, Y.imag], dim=1)  # shape: [n, 2 * rank_k], dtype: float

    return Y_proj

# Example usage
torch.manual_seed(42)
matrix = torch.randn(3072, 3072, device=device)
print(matrix.sum())
matrix_projected = randomized_fft_low_rank_projection(matrix, rank_k=128)
print(matrix_projected.sum())
print("Projected shape:", matrix_projected.shape)  # torch.Size([3072, 256])


tensor(-3362.1399, device='cuda:0')
tensor(886795.2500, device='cuda:0')
Projected shape: torch.Size([3072, 256])


V6: Randomized FFT Low-Rank Projection

In [17]:
import torch
import torch.fft as fft
from math import ceil

def randomized_fft_low_rank(matrix: torch.Tensor, k: int = 128, oversampling: int = 10, power_iter: int = 0) -> torch.Tensor:
    """
    Randomized FFT for low-rank approximation using the subspace iteration method.

    Args:
        matrix: Input matrix [M x N]
        k: Target rank
        oversampling: Additional samples for numerical stability (typically 5-10)
        power_iter: Number of power iterations to improve accuracy (0-2)

    Returns:
        Low-rank approximation matrix [M x k]
    """
    m, n = matrix.shape
    device = matrix.device
    l = k + oversampling

    # Step 1: Generate random test matrix
    omega = torch.randn(n, l, device=device)

    # Step 2: Form sketch matrix Y = A @ Ω
    Y = matrix @ omega

    # Step 3: Power iterations (optional)
    for _ in range(power_iter):
        Y = matrix @ (matrix.T @ Y)

    # Step 4: Compute QR decomposition of Y to get basis Q
    Q, _ = torch.linalg.qr(Y)

    # Step 5: Project A onto the basis
    B = Q.T @ matrix

    # Step 6: Compute FFT of the small matrix B
    B_fft = fft.fft(B, dim=1)

    # Step 7: Keep only top-k frequency components
    center = n // 2
    half_k = k // 2
    start_col = center - half_k
    end_col = center + half_k + (k % 2)

    mask = torch.zeros(n, dtype=torch.bool, device=device)
    mask[start_col:end_col] = True
    B_fft_trunc = B_fft[:, mask]

    # Step 8: Inverse FFT and reconstruct
    B_trunc = fft.ifft(B_fft_trunc, dim=1).real
    return Q @ B_trunc

# Example usage
torch.manual_seed(42)
matrix = torch.randn(3072, 3072, device=device)
print(matrix.sum())
matrix_projected = randomized_fft_low_rank(matrix, k=128)
print(matrix_projected.sum())
print("Projected shape:", matrix_projected.shape)  # torch.Size([3072, 256])


tensor(-3362.1399, device='cuda:0')
tensor(-109.6517, device='cuda:0')
Projected shape: torch.Size([3072, 128])


V7: Frequency Domain Low-Rank Matrix Completion

In [19]:
import torch

def frequency_domain_low_rank_projection(matrix: torch.Tensor, rank_k: int = 128, mask_ratio: float = 0.5) -> torch.Tensor:
    """
    Frequency-Domain Low-Rank Projection with random frequency masking.

    Args:
        matrix (torch.Tensor): Input [n x n] matrix.
        rank_k (int): Desired output rank (e.g., 128).
        mask_ratio (float): Fraction of frequency components to retain.

    Returns:
        torch.Tensor: Projected matrix of shape [n x rank_k].
    """
    n = matrix.shape[0]
    assert matrix.shape[0] == matrix.shape[1], "Only square matrices supported"
    assert 0 < mask_ratio <= 1.0, "mask_ratio must be in (0, 1]"

    # Step 1: FFT
    matrix_fft = torch.fft.fft2(matrix)

    # Step 2: Apply random frequency mask
    mask = (torch.rand_like(matrix_fft.real) < mask_ratio).to(matrix_fft.dtype)
    matrix_fft_masked = matrix_fft * mask

    # Step 3: Inverse FFT back to spatial domain
    matrix_reconstructed = torch.fft.ifft2(matrix_fft_masked).real

    # Step 4: Low-rank projection via SVD
    U, S, Vh = torch.linalg.svd(matrix_reconstructed, full_matrices=False)

    # Step 5: Project to rank_k (compressed form)
    X_proj = U[:, :rank_k] * S[:rank_k]  # shape: [n, rank_k]

    return X_proj

# Example usage
matrix = torch.randn(3072, 3072)
matrix_projected = frequency_domain_low_rank_projection(matrix, rank_k=128, mask_ratio=0.4)
print("Projected shape:", matrix_projected.shape)  # torch.Size([3072, 128])


Projected shape: torch.Size([3072, 128])


V8: Frequency Domain Low-Rank Matrix Completion Projection

In [20]:
import torch
import torch.nn.functional as F

def multiresolution_fft_low_rank_projection(matrix: torch.Tensor, rank_k: int = 128, coarse_scale: float = 0.25) -> torch.Tensor:
    """
    Perform multiresolution FFT-based low-rank projection.

    Args:
        matrix (torch.Tensor): Input square matrix [n x n].
        rank_k (int): Target projection dimension.
        coarse_scale (float): Scale factor for coarse FFT (e.g., 0.25 = 1/4 resolution).

    Returns:
        torch.Tensor: Projected matrix of shape [n, rank_k].
    """
    n = matrix.shape[0]
    assert matrix.shape[0] == matrix.shape[1], "Only square matrices are supported"
    assert 0 < coarse_scale < 1.0, "coarse_scale must be between 0 and 1"

    # Step 1: High-resolution FFT
    fine_fft = torch.fft.fft2(matrix)  # shape: [n, n], complex

    # Step 2: Downsample and compute coarse FFT
    down_n = int(n * coarse_scale)
    matrix_downsampled = F.interpolate(matrix.unsqueeze(0).unsqueeze(0), size=(down_n, down_n), mode='bilinear', align_corners=False).squeeze()
    coarse_fft = torch.fft.fft2(matrix_downsampled, s=(n, n))  # upsample back to [n, n]

    # Step 3: Combine coarse and fine FFTs (e.g., average or weighted sum)
    combined_fft = (fine_fft + coarse_fft) / 2  # shape: [n, n], complex

    # Step 4: Inverse FFT to get multiresolution spatial signal
    multires_spatial = torch.fft.ifft2(combined_fft).real  # shape: [n, n], real

    # Step 5: Low-rank projection (SVD to rank_k)
    U, S, Vh = torch.linalg.svd(multires_spatial, full_matrices=False)
    X_proj = U[:, :rank_k] * S[:rank_k]  # shape: [n, rank_k]

    return X_proj

# Example usage
matrix = torch.randn(3072, 3072)
matrix_projected = multiresolution_fft_low_rank_projection(matrix, rank_k=128, coarse_scale=0.25)
print("Projected shape:", matrix_projected.shape)  # torch.Size([3072, 128])


Projected shape: torch.Size([3072, 128])


v9: Multiresolution FFT (Fast Fourier Transform) for Low-Rank Approximation Projection

In [21]:
import numpy as np
from numpy.fft import fft, ifft

def multires_fft_projection(X, output_dim=128, resolution_levels=3, seed=42):
    """
    Perform a multiresolution FFT-based projection from [N x D] to [N x output_dim].

    Args:
        X: Input matrix of shape [N x D].
        output_dim: Target dimension after projection.
        resolution_levels: Number of multiresolution splits (controls frequency mixing).
        seed: Random seed for reproducibility.

    Returns:
        Projected matrix of shape [N x output_dim].
    """
    np.random.seed(seed)
    N, D = X.shape
    assert output_dim <= D, "Output dim must be <= input dim"
    assert D % output_dim == 0, "Input dim must be divisible by output dim"

    # Step 1: Apply FFT across each row (convert to frequency domain)
    X_fft = fft(X, axis=1)

    # Step 2: Multiresolution frequency mixing
    mixed_freq = np.zeros((N, output_dim), dtype=np.complex128)
    step = D // output_dim

    for i in range(output_dim):
        # Aggregate frequency bands at multiple resolutions
        indices = [i * step + (j % step) for j in range(resolution_levels)]
        for idx in indices:
            weight = np.random.randn() + 1j * np.random.randn()  # complex weight
            mixed_freq[:, i] += weight * X_fft[:, idx]

    # Step 3: Convert back to time domain and take real part
    projected = np.real(ifft(mixed_freq, axis=1))

    return projected

# Example usage
if __name__ == "__main__":
    D = 3072
    output_dim = 128
    X = np.random.randn(D, D)  # Simulate input matrix [3072 x 3072]
    X_proj = multires_fft_projection(X, output_dim=output_dim)

    print("Projected shape:", X_proj.shape)  # Should be (3072, 128)


Projected shape: (3072, 128)


V10: Thresholding for Sparsity and Low-Rank Approximation Projection

In [22]:
import numpy as np
from numpy.fft import fft, ifft

def fft_threshold_projection(W, output_dim=128, threshold=1e-2, axis=1):
    """
    FFT-based low-rank and sparse projection from [D, D] to [D, output_dim].

    Args:
        W: Input matrix of shape [D, D] (square matrix).
        output_dim: Desired output dimension (e.g., 128).
        threshold: Threshold for sparsity (small coefficients are set to 0).
        axis: Axis along which to apply FFT (1 = columns, 0 = rows).

    Returns:
        Projected matrix of shape [D, output_dim].
    """
    # 1. Apply FFT along the specified axis
    W_fft = fft(W, axis=axis)

    # 2. Keep only real part if imaginary part is negligible
    if np.all(np.isclose(W_fft.imag, 0)):
        W_fft = W_fft.real

    # 3. Thresholding for sparsity
    W_fft[np.abs(W_fft) < threshold] = 0

    # 4. Truncate to output_dim along the FFT axis (low-rank approximation)
    if axis == 1:
        W_proj_fft = W_fft[:, :output_dim]
    else:
        W_proj_fft = W_fft[:output_dim, :]

    # 5. Optionally apply inverse FFT to get back to time domain (commented if you want spectral domain)
    # W_proj = ifft(W_proj_fft, axis=axis).real
    # return W_proj

    return W_proj_fft

# Example usage
D = 3072
W = np.random.randn(D, D)
W_proj = fft_threshold_projection(W, output_dim=128, threshold=1e-2)

print("Projected shape:", W_proj.shape)


Projected shape: (3072, 128)


V11: Subspace-Tracking with FFT for Low-Rank Approximation

In [31]:
import torch
import torch.fft

def fft_low_rank_projection(X: torch.Tensor, target_dim: int, mode='row') -> torch.Tensor:
    """
    Perform low-rank projection of a square matrix using FFT and random subspace projection.

    Args:
        X (torch.Tensor): Input tensor of shape [N, N] (e.g., [3072, 3072])
        target_dim (int): Desired reduced dimension (e.g., 128)
        mode (str): 'row' or 'col' – whether to project rows or columns

    Returns:
        torch.Tensor: Projected tensor of shape [N, target_dim]
    """
    assert X.ndim == 2 and X.shape[0] == X.shape[1], "X must be a square 2D tensor"

    N = X.shape[0]

    # Step 1: FFT transform along the mode
    if mode == 'row':
        X_fft = torch.fft.fft(X, dim=1)  # FFT along rows
    elif mode == 'col':
        X_fft = torch.fft.fft(X, dim=0)  # FFT along columns
    else:
        raise ValueError("Mode must be 'row' or 'col'")

    # Step 2: Use only the real part for projection
    X_fft_real = X_fft.real

    # Step 3: Generate a random Gaussian projection matrix
    projection_matrix = torch.randn(N, target_dim, device=X.device) / (target_dim ** 0.5)

    # Step 4: Project the real FFT subspace
    X_proj = X_fft_real @ projection_matrix  # Shape: [3072, 128]

    return X_proj

X = torch.randn(3072, 3072)  # Input tensor
X_proj = fft_low_rank_projection(X, target_dim=128)
print(X_proj.shape)  # Should print: torch.Size([3072, 128])



torch.Size([3072, 128])


V12: Spectral Clustering via FFT for Low-Rank Approximation Projection

In [25]:
import numpy as np
from scipy.fft import fft, ifft
from sklearn.cluster import SpectralClustering
from scipy.linalg import svd

def fft_spectral_clustering_projection(matrix, target_dim):
    """
    Project a matrix to lower dimension using FFT and spectral clustering

    Args:
    matrix: Input matrix of shape (n, n)
    target_dim: Target dimension for projection (output will be (n, target_dim))

    Returns:
    Projected matrix of shape (n, target_dim)
    """
    n = matrix.shape[0]

    # Step 1: Apply 2D FFT to the matrix
    fft_matrix = fft(fft(matrix, axis=0), axis=1)

    # Step 2: Keep only the low-frequency components (top left quadrant)
    # This is a simple low-pass filter approach
    cutoff = target_dim // 2
    filtered_fft = np.zeros_like(fft_matrix)
    filtered_fft[:cutoff, :cutoff] = fft_matrix[:cutoff, :cutoff]

    # Step 3: Inverse FFT to get a smoothed version of the matrix
    smoothed_matrix = np.real(ifft(ifft(filtered_fft, axis=0), axis=1))

    # Step 4: Perform spectral clustering on the smoothed matrix
    # We'll use the eigenvectors as our projection
    spectral = SpectralClustering(n_clusters=target_dim,
                                affinity='precomputed',
                                random_state=42)

    # Create affinity matrix (absolute value of smoothed matrix)
    affinity_matrix = np.abs(smoothed_matrix)

    # Get the eigenvectors (this is the computationally intensive part)
    # We'll use SVD to approximate the eigenvectors
    _, _, vh = svd(affinity_matrix, full_matrices=False)
    projection_matrix = vh[:target_dim, :].T

    return projection_matrix

# Example usage:
if __name__ == "__main__":
    # Generate a random 3072x3072 matrix (in practice, use your actual matrix)
    np.random.seed(42)
    original_matrix = np.random.randn(3072, 3072)

    # Project to 3072x128
    projected_matrix = fft_spectral_clustering_projection(original_matrix, 128)

    print(f"Original matrix shape: {original_matrix.shape}")
    print(f"Projected matrix shape: {projected_matrix.shape}")

Original matrix shape: (3072, 3072)
Projected matrix shape: (3072, 128)


V14: Spectral Clustering via FFT for Low-Rank Approximation Projection

In [29]:
import torch
import torch.fft
import numpy as np
from sklearn.cluster import SpectralClustering

def fft_spectral_lowrank_projection(input_tensor, target_dim=128):
    """
    Projects a [3072x3072] tensor to [3072x128] using FFT and spectral clustering.

    Args:
        input_tensor: torch.Tensor of shape [3072, 3072]
        target_dim: desired output dimension (default: 128)

    Returns:
        projected_tensor: torch.Tensor of shape [3072, target_dim]
    """
    assert input_tensor.shape == (3072, 3072), "Input tensor must be 3072x3072"

    # Step 1: Compute 2D FFT of the input tensor
    fft_tensor = torch.fft.fft2(input_tensor)

    # Step 2: Shift zero frequency to center and get magnitude spectrum
    fft_shifted = torch.fft.fftshift(fft_tensor)
    magnitude_spectrum = torch.abs(fft_shifted)

    # Step 3: Convert to numpy for spectral clustering (sklearn doesn't work with torch tensors)
    magnitude_np = magnitude_spectrum.cpu().numpy()

    # Step 4: Perform spectral clustering on the magnitude spectrum
    # We'll cluster the rows into target_dim clusters
    spectral = SpectralClustering(
        n_clusters=target_dim,
        affinity='nearest_neighbors',
        n_neighbors=10,
        random_state=42
    )

    # Fit the model to the magnitude spectrum
    cluster_labels = spectral.fit_predict(magnitude_np)

    # Step 5: Create projection matrix by averaging within clusters
    projection_matrix = torch.zeros((3072, target_dim), device=input_tensor.device)

    for cluster_id in range(target_dim):
        # Get indices of rows belonging to this cluster
        cluster_indices = np.where(cluster_labels == cluster_id)[0]

        if len(cluster_indices) > 0:
            # Average the original rows (not the FFT) for this cluster
            cluster_rows = input_tensor[cluster_indices]
            projection_matrix[:, cluster_id] = cluster_rows.mean(dim=0)

    return projection_matrix

# Example usage:
if __name__ == "__main__":
    # Create a random 3072x3072 tensor
    original_tensor = torch.randn(3072, 3072)

    # Project to 3072x128
    projected_tensor = fft_spectral_lowrank_projection(original_tensor)

    print(f"Original shape: {original_tensor.shape}")
    print(f"Projected shape: {projected_tensor.shape}")



Original shape: torch.Size([3072, 3072])
Projected shape: torch.Size([3072, 128])


V15: Spectral Clustering via FFT for Low-Rank Approximation Projection

In [30]:
import torch
import torch.fft
from sklearn.cluster import SpectralClustering
import numpy as np

def fft_spectral_projection(input_tensor, target_dim=128, n_neighbors=5):
    """
    Projects [3072x3072] → [3072x128] using FFT frequency analysis + spectral clustering.
    More efficient implementation than original version.

    Args:
        input_tensor: torch.Tensor of shape [3072, 3072]
        target_dim: desired output dimension (default: 128)
        n_neighbors: number of neighbors for spectral clustering affinity

    Returns:
        projected_tensor: torch.Tensor of shape [3072, target_dim]
    """
    device = input_tensor.device

    # Stage 1: FFT Frequency Compression
    fft_tensor = torch.fft.fft2(input_tensor)
    fft_shifted = torch.fft.fftshift(fft_tensor)
    magnitude = torch.abs(fft_shifted)

    # Compress columns first by keeping top frequencies
    compressed_cols = magnitude[:, 1536-64:1536+64]  # Keep center 128 columns

    # Stage 2: Spectral Clustering on Reduced Data
    # Convert to numpy for sklearn (more efficient on smaller array)
    data_np = compressed_cols.cpu().numpy()

    # Use sparse affinity matrix for faster clustering
    spectral = SpectralClustering(
        n_clusters=target_dim,
        affinity='nearest_neighbors',
        n_neighbors=n_neighbors,
        random_state=42,
        assign_labels='discretize'  # Faster than eigen decomposition
    )
    cluster_labels = spectral.fit_predict(data_np)

    # Stage 3: Create Projection Matrix
    projection = torch.zeros((3072, target_dim), device=device)

    for cluster_id in range(target_dim):
        mask = torch.from_numpy(cluster_labels == cluster_id).to(device)
        if mask.any():
            # Weighted average based on magnitude importance
            weights = magnitude[:, 1536-64:1536+64].mean(dim=1)
            cluster_data = input_tensor[mask]
            weighted_avg = (cluster_data * weights[mask].view(-1, 1)).sum(dim=0) / weights[mask].sum()
            projection[:, cluster_id] = weighted_avg

    return projection

# Example usage:
if __name__ == "__main__":
    torch.manual_seed(42)

    # Create test tensor (could be covariance matrix, similarity matrix, etc.)
    original = torch.randn(3072, 3072)
    original = original @ original.T  # Make PSD for better clustering

    projected = fft_spectral_projection(original)

    print(f"Original shape: {original.shape}")
    print(f"Projected shape: {projected.shape}")
    print(f"Projection norm: {torch.norm(projected, dim=0).mean():.4f}")

Original shape: torch.Size([3072, 3072])
Projected shape: torch.Size([3072, 128])
Projection norm: 944.1972


V16: Thresholding for Sparsity and Low-Rank Approximation Projection

In [33]:
import torch
import torch.nn.functional as F

def low_rank_sparse_projection(X, proj_dim=128, sparsity_ratio=0.1, thresholding='topk'):
    """
    Projects input tensor X [3072 x 3072] to [3072 x 128] using a low-rank projection with thresholding.

    Args:
        X (Tensor): Input tensor of shape [3072, 3072]
        proj_dim (int): Target projection dimension (e.g., 128)
        sparsity_ratio (float): Proportion of values to keep (e.g., 0.1 means keep top 10%)
        thresholding (str): 'topk' or 'value'

    Returns:
        X_proj (Tensor): Projected tensor of shape [3072, 128]
    """
    input_dim = X.size(1)  # should be 3072

    # Step 1: Initialize projection matrix W ∈ ℝ^{3072×128}
    W = torch.randn(input_dim, proj_dim, device=X.device)

    # Step 2: Apply projection
    X_proj = X @ W  # shape: [3072 x 128]

    # Step 3: Apply sparsity via thresholding
    if thresholding == 'topk':
        k = int(sparsity_ratio * X_proj.numel())
        if k < 1:
            return torch.zeros_like(X_proj)
        # Get the k largest absolute values
        threshold = torch.topk(X_proj.abs().flatten(), k, sorted=False).values.min()
        X_proj = torch.where(X_proj.abs() >= threshold, X_proj, torch.zeros_like(X_proj))

    elif thresholding == 'value':
        threshold = sparsity_ratio  # use directly as a value threshold
        X_proj = torch.where(X_proj.abs() >= threshold, X_proj, torch.zeros_like(X_proj))

    else:
        raise ValueError("Invalid thresholding type. Use 'topk' or 'value'.")

    return X_proj

X = torch.randn(3072, 3072)
X_proj = low_rank_sparse_projection(X, proj_dim=128, sparsity_ratio=0.05, thresholding='topk')
print(X_proj.shape)  # should be [3072, 128]



torch.Size([3072, 128])


v17: Thresholding for Sparsity and Low-Rank Approximation Projection

In [34]:
import torch
import torch.nn.functional as F

def sparse_lowrank_projection(input_tensor, output_dim=128, sparsity_threshold=0.01):
    """
    Projects a large square matrix to a lower dimensional space with sparsity constraints

    Args:
        input_tensor: torch.Tensor of shape [d, d] (e.g., [3072, 3072])
        output_dim: target dimension for projection (e.g., 128)
        sparsity_threshold: values below this will be set to zero

    Returns:
        Projected tensor of shape [d, output_dim] (e.g., [3072, 128])
    """
    d = input_tensor.size(0)

    # Step 1: Create a random projection matrix with sparse initialization
    # Using Kaiming initialization with adjusted sparsity
    projection_matrix = torch.zeros(d, output_dim)
    torch.nn.init.kaiming_uniform_(projection_matrix, mode='fan_in', nonlinearity='linear')

    # Apply thresholding to enforce sparsity
    projection_matrix[torch.abs(projection_matrix) < sparsity_threshold] = 0

    # Step 2: Normalize columns to maintain stability
    projection_matrix = F.normalize(projection_matrix, p=2, dim=0)

    # Step 3: Project the input tensor
    projected_tensor = torch.matmul(input_tensor, projection_matrix)

    return projected_tensor

# Example usage:
if __name__ == "__main__":
    # Create a random 3072x3072 tensor
    large_tensor = torch.randn(3072, 3072)

    # Project to 3072x128
    projected = sparse_lowrank_projection(large_tensor, output_dim=128)

    print(f"Input shape: {large_tensor.shape}")
    print(f"Output shape: {projected.shape}")

Input shape: torch.Size([3072, 3072])
Output shape: torch.Size([3072, 128])


V18: FFT-based CUR projection

In [35]:
import torch
import torch.fft

def fft_based_projection(X: torch.Tensor, out_dim: int = 128, method='magnitude') -> torch.Tensor:
    """
    FFT-based CUR-style projection of a matrix [N x N] -> [N x out_dim]

    Args:
        X (torch.Tensor): Input matrix of shape [N, N]
        out_dim (int): Target number of columns to project to
        method (str): 'magnitude' to use top freq by energy, 'random' for uniform freq sampling

    Returns:
        torch.Tensor: Projected matrix of shape [N, out_dim]
    """
    assert X.shape[0] == X.shape[1], "Input must be square"
    N = X.shape[0]
    assert out_dim <= N, "Output dimension must be <= input size"

    # Step 1: FFT along the column axis (dim=0)
    X_fft = torch.fft.fft(X, dim=0)

    # Step 2: Choose frequency indices based on energy
    if method == 'magnitude':
        energy = torch.sum(torch.abs(X_fft)**2, dim=1)  # Energy along each freq row
        topk_indices = torch.topk(energy, out_dim, largest=True).indices
    elif method == 'random':
        topk_indices = torch.randperm(N)[:out_dim]
    else:
        raise ValueError("method must be 'magnitude' or 'random'")

    # Step 3: Select those rows (frequencies), then inverse FFT
    X_fft_selected = X_fft[topk_indices]
    X_projected = torch.fft.ifft(X_fft_selected, dim=0).real  # [out_dim, N]

    # Step 4: Transpose to [N x out_dim]
    return X_projected.T.contiguous()

# Example
X = torch.randn(3072, 3072)
X_proj = fft_based_projection(X, out_dim=128)
print(X_proj.shape)  # Should be [3072, 128]


torch.Size([3072, 128])


V19: FFT-based CUR projection

In [36]:
import torch
import torch.fft
import numpy as np

def fft_based_cur_projection(input_tensor, target_dim=128):
    """
    FFT-based CUR-like projection without SVD
    Projects [3072 x 3072] tensor to [3072 x target_dim]

    Args:
        input_tensor: torch.Tensor of shape [3072, 3072]
        target_dim: desired output dimension (default: 128)

    Returns:
        Projected tensor of shape [3072, target_dim]
    """
    n = input_tensor.size(0)

    # Step 1: Compute 2D FFT of the input matrix
    fft_matrix = torch.fft.fft2(input_tensor)

    # Step 2: Select important frequency components (top target_dim in magnitude)
    magnitudes = torch.abs(fft_matrix)

    # Flatten and get indices of top magnitudes
    flat_magnitudes = magnitudes.view(-1)
    _, top_indices = torch.topk(flat_magnitudes, k=target_dim)

    # Convert flat indices to 2D coordinates
    rows = top_indices // n
    cols = top_indices % n

    # Step 3: Create sampling probability distribution based on magnitude
    prob_dist = flat_magnitudes / flat_magnitudes.sum()

    # Step 4: Sample columns based on the probability distribution
    # (Here we're actually selecting frequency components rather than columns)
    sampled_indices = torch.multinomial(prob_dist, target_dim, replacement=False)

    # Step 5: Construct the projection matrix using selected frequency components
    projection_matrix = torch.zeros(n, target_dim, dtype=torch.complex64)

    for i, idx in enumerate(sampled_indices):
        row = idx // n
        col = idx % n
        projection_matrix[:, i] = fft_matrix[:, col] * (magnitudes[row, col] / magnitudes[:, col].sum())

    # Step 6: Convert back to spatial domain
    projected_real = torch.fft.ifft(projection_matrix, dim=0).real

    # Normalize the output
    projected_real = projected_real / torch.norm(projected_real, dim=0, keepdim=True)

    return projected_real

# Example usage
if __name__ == "__main__":
    # Create a random 3072x3072 tensor
    input_tensor = torch.randn(3072, 3072)

    # Project to 3072x128
    projected = fft_based_cur_projection(input_tensor, target_dim=128)

    print(f"Input shape: {input_tensor.shape}")
    print(f"Projected shape: {projected.shape}")

Input shape: torch.Size([3072, 3072])
Projected shape: torch.Size([3072, 128])
