In [None]:
import torch
import numpy as np
import time

def lowrank_eigh_plus_low_rank_symmetrize(Q: torch.Tensor, L: torch.Tensor, U: torch.Tensor, V: torch.Tensor, alpha: float):
    """
    Q is ``(m, rank)``; L is ``(rank, )``; U and V are the low rank correction such that U V^T is ``(m, m)``.

    This computes eigendecomposition of A, where

    ``M = Q diag(L) Q^T + alpha * (U V^T)``;

    ``A = (M + M^T) / 2``
    """
    m, rank = Q.shape
    _, k = V.shape

    # project U and V out of the Q subspace via gram-schmidt
    Q_T_U = Q.T @ U
    U_perp = U - Q @ Q_T_U

    Q_T_V = Q.T @ V
    V_perp = V - Q @ Q_T_V

    R = torch.hstack([U_perp, V_perp])
    Q_perp, _ = torch.linalg.qr(R)

    Q_B = torch.hstack([Q, Q_perp])
    r_B = Q_B.shape[1]

    # project and compute new eigendecomposition
    A_proj = torch.zeros((r_B, r_B), device=Q.device, dtype=Q.dtype)
    A_proj[:rank, :rank] = L.diag_embed()

    Q_perp_T_U = Q_perp.T @ U
    Q_B_T_U = torch.vstack([Q_T_U, Q_perp_T_U])

    Q_perp_T_V = Q_perp.T @ V
    Q_B_T_V = torch.vstack([Q_T_V, Q_perp_T_V])

    update_proj = (alpha / 2.0) * (Q_B_T_U @ Q_B_T_V.T + Q_B_T_V @ Q_B_T_U.T)
    A_proj += update_proj

    L_prime, S = torch.linalg.eigh(A_proj)

    # unproject and sort
    Q_prime = Q_B @ S

    idx = torch.argsort(L_prime)
    L_prime = L_prime[idx]
    Q_prime = Q_prime[:, idx]

    return L_prime, Q_prime

# --- Setup for the Comparison ---
m = 500
rank = 20
k = 5
alpha = 0.75

print(f"--- Verification on a smaller matrix ---")
print(f"m={m}, rank={rank}, k={k}\n")

# Generate random data with the correct properties
np.random.seed(42)
Q, _ = np.linalg.qr(np.random.randn(m, rank))
L = np.sort(np.random.rand(rank) * 100)[::-1]
U = np.random.randn(m, k)
V = np.random.randn(m, k)

# --- 1. Run the Naive Method ---
start_time = time.time()
# Explicitly form the full m x m matrix
A_orig = Q @ np.diag(L) @ Q.T
A_update = (alpha / 2.0) * (U @ V.T + V @ U.T)
A_sym_naive = A_orig + A_update

# Compute its full eigendecomposition
L_naive, Q_naive = np.linalg.eigh(A_sym_naive)

# Sort eigenvalues/eigenvectors in descending order for comparison
sort_indices_naive = np.argsort(L_naive)[::-1]
L_naive = L_naive[sort_indices_naive]
Q_naive = Q_naive[:, sort_indices_naive]
naive_time = time.time() - start_time
print(f"Naive method took: {naive_time:.6f} seconds")

# --- 2. Run the Efficient Method ---
start_time = time.time()
Q_efficient, L_efficient = lowrank_eigh_plus_low_rank_symmetrize(Q, L, U, V, alpha)
efficient_time = time.time() - start_time
print(f"Efficient method took: {efficient_time:.6f} seconds")
print(f"Speedup: {naive_time / efficient_time:.2f}x\n")

# --- 3. Compare the Results ---

# The effective rank of the new system is at most rank + 2k
new_rank = Q_efficient.shape[1]
print(f"Effective rank of the result: {new_rank} (<= {rank + 2*k})")

# We only need to compare the top `new_rank` eigenvalues/vectors from the naive method
L_naive_top = L_naive[:new_rank]
Q_naive_top = Q_naive[:, :new_rank]

# Compare Eigenvalues
print("\n--- Eigenvalue Comparison ---")
print("First 5 eigenvalues (Naive):   ", L_naive_top[:5])
print("First 5 eigenvalues (Efficient):", L_efficient[:5])
eigenvalues_match = np.allclose(L_naive_top, L_efficient, 1e-3, 1e-3)
print(f"\nEigenvalues match? {eigenvalues_match}")

# Compare Eigenvectors
# We compare absolute values to account for the potential sign flip
print("\n--- Eigenvector Comparison ---")
# To handle sign ambiguity, compare the absolute value of the dot products
# of corresponding vectors. They should all be close to 1.
# einsum computes dot products of corresponding columns: sum_i (Q1_i*Q2_i)
col_dot_products = np.abs(np.einsum('ij,ij->j', Q_naive_top, Q_efficient))
eigenvectors_match = np.allclose(col_dot_products, 1.0)
print("Abs dot products of first 5 corresponding eigenvectors:", col_dot_products[:5])
print(f"\nEigenvectors match (span the same space)? {eigenvectors_match}")

# Final check
eigenvalues_match, eigenvectors_match

--- Verification on a smaller matrix ---
m=500, rank=20, k=5

Naive method took: 0.032005 seconds
Efficient method took: 0.001231 seconds
Speedup: 25.99x

Effective rank of the result: 30 (<= 30)

--- Eigenvalue Comparison ---
First 5 eigenvalues (Naive):    [211.77418889 202.73271337 189.16824266 184.81831786 166.29866334]
First 5 eigenvalues (Efficient): [211.77418889 202.73271337 189.16824266 184.81831786 166.29866334]

Eigenvalues match? False

--- Eigenvector Comparison ---
Abs dot products of first 5 corresponding eigenvectors: [1. 1. 1. 1. 1.]

Eigenvectors match (span the same space)? False


(False, False)

In [19]:
# --- (Previous code for setup and running both methods remains the same) ---

# --- 3. Corrected Comparison ---

print("\n--- Corrected Verification ---")
# The efficient method gives the r_B eigenvalues of the non-null space
# The naive method gives ALL m eigenvalues.
# The true spectrum of A_sym consists of the r_B values from L_efficient
# plus (m - r_B) zero eigenvalues.

# From the naive result, let's isolate the non-zero eigenvalues
# Use a tolerance to handle floating point inaccuracies around zero
tolerance = 1e-9
L_naive_nonzero = L_naive[np.abs(L_naive) > tolerance]

# Sort both sets for a consistent comparison
# Sorting by value (descending) is fine
L_naive_nonzero_sorted = np.sort(L_naive_nonzero)[::-1]
L_efficient_sorted = np.sort(L_efficient)[::-1]

print(f"Number of non-zero eigenvalues from Naive method: {len(L_naive_nonzero_sorted)}")
print(f"Number of eigenvalues from Efficient method: {len(L_efficient_sorted)}")

# Now, the two arrays should be nearly identical.
# Note: The number of non-zero eigenvalues might be slightly less than r_B if the
# projected matrix A_proj itself had a null space. So we compare the shorter list.
min_len = min(len(L_naive_nonzero_sorted), len(L_efficient_sorted))

print("\nTop 5 positive eigenvalues (Naive non-zero):", L_naive_nonzero_sorted[:5])
print("Top 5 positive eigenvalues (Efficient):     ", L_efficient_sorted[:5])
print("\nBottom 5 negative eigenvalues (Naive non-zero):", L_naive_nonzero_sorted[-5:])
print("Bottom 5 negative eigenvalues (Efficient):      ", L_efficient_sorted[-5:])

eigenvalues_match = np.allclose(L_naive_nonzero_sorted[:min_len], L_efficient_sorted[:min_len])
print(f"\nNon-zero eigenvalues match? {eigenvalues_match}")

if eigenvalues_match:
    print("\n✅ Success! The non-zero spectrum computed by both methods is numerically equivalent.")
else:
    print("\n❌ Failure! The non-zero spectra do not match.")


--- Corrected Verification ---
Number of non-zero eigenvalues from Naive method: 30
Number of eigenvalues from Efficient method: 30

Top 5 positive eigenvalues (Naive non-zero): [211.77418889 202.73271337 189.16824266 184.81831786 166.29866334]
Top 5 positive eigenvalues (Efficient):      [211.77418889 202.73271337 189.16824266 184.81831786 166.29866334]

Bottom 5 negative eigenvalues (Naive non-zero): [-153.07656709 -158.2962256  -190.11453938 -194.77103617 -216.46304863]
Bottom 5 negative eigenvalues (Efficient):       [-153.07656709 -158.2962256  -190.11453938 -194.77103617 -216.46304863]

Non-zero eigenvalues match? True

✅ Success! The non-zero spectrum computed by both methods is numerically equivalent.


In [10]:
np.set_printoptions(suppress=True)


In [13]:
L_naive

array([ 211.77418889,  202.73271337,  189.16824266,  184.81831786,
        166.29866334,   90.34898431,   75.91631542,   74.23042114,
         72.62846709,   70.85026662,   69.02819053,   62.66894748,
         57.75809159,   50.24803961,   38.78562805,   33.1105224 ,
         30.96686479,   28.95922855,   24.11128348,   20.78014552,
         16.92810617,   10.16167673,    9.0918865 ,    7.0582667 ,
          5.11779475,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.     

In [12]:
L_efficient

array([ 211.77418889,  202.73271337,  189.16824266,  184.81831786,
        166.29866334,   90.34898431,   75.91631542,   74.23042114,
         72.62846709,   70.85026662,   69.02819053,   62.66894748,
         57.75809159,   50.24803961,   38.78562805,   33.1105224 ,
         30.96686479,   28.95922855,   24.11128348,   20.78014552,
         16.92810617,   10.16167673,    9.0918865 ,    7.0582667 ,
          5.11779475, -153.07656709, -158.2962256 , -190.11453938,
       -194.77103617, -216.46304863])