# Faster SVD via Accelerated Newton-Schulz Iteration

This notebook compares the accuracy of the **CANS-SVD** algorithm with the SVD algorithms built into JAX, running on a CUDA GPU.

We will:
1. Generate an ill-conditioned random matrix.
2. Compute its SVD using several algorithms.
3. Compare reconstruction errors.

### 1. Clone the repository

In [None]:
!git clone https://github.com/fallnlove/polar-svd.git

### 2. Install dependencies

In [None]:
!pip install -r polar-svd/requirements.txt

In [None]:
%cd polar-svd

### 3. Import JAX and CANS-SVD

In [None]:
import jax
import jax.numpy as jnp
from jax import config
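config.update("jax_default_matmul_precision", "float32")  # use full float32 precision for matmuls instead of faster reduced-precision modes
# jax.device_count('gpu') raises an error (rather than returning 0) when no GPU backend exists, so check the default backend instead.
assert jax.default_backend() == "gpu", "GPU not available. Please use an environment with a GPU."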

from src import cans_svd

In [None]:
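def generate_matrix(size: int, cond_number: float, key: jax.Array) -> jax.Array:
    """
    Generate a random square matrix with a specified condition number.

    Args:
        size (int): The size of the matrix (size x size).
        cond_number (float): The desired condition number of the matrix.
        key (jax.Array): PRNG key used to sample the random orthogonal factors.

    Returns:
        jax.Array: The generated matrix.
    """
    key1, key2 = jax.random.split(key)

    # Random orthogonal left factor: Q from the QR decomposition of a Gaussian matrix.
    u = jax.random.normal(key=key1, shape=(size, size), dtype=jnp.float32)
    u, _ = jax.lax.linalg.qr(u)

    # Random orthogonal right factor, sampled the same way.
    v = jax.random.normal(key=key2, shape=(size, size), dtype=jnp.float32)
    v, _ = jax.lax.linalg.qr(v)

    # Singular values spaced logarithmically from 1 to cond_number, so max/min = cond_number.
    s = jnp.logspace(0, jnp.log10(cond_number), size, base=10, dtype=jnp.float32)

    return u @ jnp.diag(s) @ v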
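The construction fixes the spectrum explicitly, which is why the result has the requested condition number. Writing $n$ for `size` and $\kappa$ for `cond_number`, the call to `jnp.logspace` produces

$$
s_i = 10^{\frac{i-1}{n-1}\log_{10}\kappa} = \kappa^{\frac{i-1}{n-1}}, \qquad i = 1, \dots, n,
$$

so $s_1 = 1$ and $s_n = \kappa$. Since `u` and `v` are orthogonal QR factors, in exact arithmetic $A = U \operatorname{diag}(s)\, V$ is an SVD of $A$ (with the singular values in ascending order), and

$$
\kappa_2(A) = \frac{\sigma_{\max}(A)}{\sigma_{\min}(A)} = \frac{s_n}{s_1} = \kappa .
$$

As a rough guide to why $\kappa = 10^6$ is hard in float32: a backward-stable SVD can shift each singular value by up to roughly $\varepsilon\,\sigma_{\max}$, with $\varepsilon = 2^{-23} \approx 1.2 \times 10^{-7}$. Here that is about $0.12$, which is comparable to $\sigma_{\min} = 1$.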

In [None]:
def reconstruction_err(A: jax.Array, U: jax.Array, S: jax.Array, VT: jax.Array) -> jax.Array:
    """
    Compute the relative reconstruction error of an SVD.

    Args:
        A (jax.Array): The original matrix.
        U (jax.Array): The left singular vectors (as columns).
        S (jax.Array): The singular values.
        VT (jax.Array): The right singular vectors (as rows, i.e. V transposed).

    Returns:
        jax.Array: Scalar ||A - U diag(S) VT||_F / ||A||_F.
    """
    # Rebuild A from its factors: U @ diag(S) @ VT.
    A_reconstructed = U @ jnp.diag(S) @ VT

    return jnp.linalg.norm(A - A_reconstructed, ord='fro') / jnp.linalg.norm(A, ord='fro')
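A small reconstruction error does not by itself guarantee that the computed singular vectors are orthonormal: any factorization $A = X S Y$ with non-orthogonal $X, Y$ reconstructs $A$ just as well. A complementary check, sketched here as a hypothetical helper `orthogonality_err` (not part of the polar-svd repository), measures $\|Q^\top Q - I\|_F$ for a factor $Q$.

```python
import jax
import jax.numpy as jnp


def orthogonality_err(Q: jax.Array) -> jax.Array:
    """Return ||Q^T Q - I||_F, which is zero iff the columns of Q are orthonormal.

    Hypothetical helper for illustration; not part of the polar-svd repository.
    """
    k = Q.shape[1]
    gram = Q.T @ Q  # k x k matrix of column inner products
    return jnp.linalg.norm(gram - jnp.eye(k, dtype=Q.dtype), ord="fro")
```

Apply it as `orthogonality_err(U_cans)` and `orthogonality_err(VT_cans.T)`; the rows of `VT` are the right singular vectors, so transpose it first.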

### 4. Check the accuracy of the algorithms on an ill-conditioned matrix

First, let's generate a random 512 x 512 matrix with condition number $10^6$:

In [None]:
A = generate_matrix(size=512, cond_number=1e6, key=jax.random.PRNGKey(0))

Now we can compare CANS-SVD with the SVD algorithms built into JAX:
- Polar-based `jax.lax.linalg.SvdAlgorithm.POLAR`
- QR-based `jax.lax.linalg.SvdAlgorithm.QR`
- Jacobi-based `jax.lax.linalg.SvdAlgorithm.JACOBI`

In [None]:
U_cans, S_cans, VT_cans = cans_svd(A)
U_polar, S_polar, VT_polar = jax.lax.linalg.svd(A, algorithm=jax.lax.linalg.SvdAlgorithm.POLAR)
U_qr, S_qr, VT_qr = jax.lax.linalg.svd(A, algorithm=jax.lax.linalg.SvdAlgorithm.QR)
U_jacobi, S_jacobi, VT_jacobi = jax.lax.linalg.svd(A, algorithm=jax.lax.linalg.SvdAlgorithm.JACOBI)

Let's look at the relative reconstruction errors:

In [None]:
print("CANS SVD reconstruction error:", reconstruction_err(A, U_cans, S_cans, VT_cans))
print("Polar SVD reconstruction error:", reconstruction_err(A, U_polar, S_polar, VT_polar))
print("QR SVD reconstruction error:", reconstruction_err(A, U_qr, S_qr, VT_qr))
print("Jacobi SVD reconstruction error:", reconstruction_err(A, U_jacobi, S_jacobi, VT_jacobi))

As the output above shows, the polar-based SVD has a large reconstruction error on the ill-conditioned matrix.
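Reconstruction error does not separate errors in the singular values from errors in the singular vectors. Because `generate_matrix` fixes the spectrum, the computed singular values can also be checked against it directly. The sketch below uses a hypothetical helper, `singular_value_err` (not part of the polar-svd repository), that rebuilds the target spectrum and reports the largest deviation relative to $\sigma_{\max}$. A normwise measure is used because `A` is itself formed in float32, so even an exact SVD of the stored matrix would differ from the target spectrum by roughly float32 epsilon times $\sigma_{\max}$.

```python
import jax
import jax.numpy as jnp


def singular_value_err(S: jax.Array, size: int, cond_number: float) -> jax.Array:
    """Return max_i |S_i - s_i| / s_max against the spectrum used by generate_matrix.

    Hypothetical helper for illustration; not part of the polar-svd repository.
    """
    # Same target spectrum as generate_matrix: logspace from 1 to cond_number (ascending).
    s_true = jnp.logspace(0, jnp.log10(cond_number), size, base=10, dtype=jnp.float32)
    # SVD routines typically return singular values in descending order;
    # sorting makes the comparison independent of the order.
    S_sorted = jnp.sort(S)
    return jnp.max(jnp.abs(S_sorted - s_true)) / s_true[-1]
```

For the matrix above, call for example `singular_value_err(S_cans, size=512, cond_number=1e6)`, and likewise for `S_polar`, `S_qr` and `S_jacobi`.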

⚠️ **Important note:** runtime experiments were conducted on an NVIDIA B200 GPU; results may differ on older GPUs.
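If you want to measure runtimes on your own GPU, note that JAX dispatches work asynchronously, so wrapping a call in `time.perf_counter()` can end up timing only the dispatch. The sketch below (`time_fn` is a hypothetical helper, not part of the polar-svd repository) waits for the results with `jax.block_until_ready` and uses warm-up calls to keep one-time JIT compilation out of the measurement.

```python
import time
from typing import Any, Callable

import jax


def time_fn(fn: Callable[..., Any], *args: Any, n_warmup: int = 1, n_iter: int = 5) -> float:
    """Return the median wall-clock time of fn(*args) in seconds.

    Hypothetical helper for illustration; not part of the polar-svd repository.
    """
    # Warm-up calls absorb one-time costs such as JIT compilation.
    for _ in range(n_warmup):
        jax.block_until_ready(fn(*args))

    times = []
    for _ in range(n_iter):
        start = time.perf_counter()
        # Block until the asynchronously dispatched results are actually computed.
        jax.block_until_ready(fn(*args))
        times.append(time.perf_counter() - start)

    times.sort()
    return times[len(times) // 2]
```

For example, `time_fn(jax.jit(lambda x: jax.lax.linalg.svd(x, algorithm=jax.lax.linalg.SvdAlgorithm.QR)), A)`. Whether `cans_svd` can be wrapped in `jax.jit` depends on its implementation, which this sketch does not assume; `time_fn(cans_svd, A)` works either way.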