In [3]:
import torch
import time
import statistics as stats

def orthogonalize(M: torch.Tensor) -> torch.Tensor:
    """Approximately orthogonalize a 2-D matrix with a fixed polynomial iteration.

    The matrix is divided by its norm (``torch.linalg.norm`` with default
    arguments) and then refined by six steps of
    ``X <- X @ (a*I + b*A + c*A@A)`` with ``A = X.T @ X``, using a
    hard-coded ``(a, b, c)`` schedule. Only matrix products are used; no
    decomposition is computed. The result is an approximation: its singular
    values end up close to, but not exactly, 1.

    Args:
        M: 2-D floating-point tensor (``shape[0]`` and ``shape[1]`` are
            both read). The caller's tensor is not modified in place.

    Returns:
        A tensor with the same shape as ``M``. An all-zero input yields an
        all-zero output instead of NaNs.
    """
    abc_list = (
        (3955 / 1024, -8306 / 1024, 5008 / 1024),
        (3735 / 1024, -6681 / 1024, 3463 / 1024),
        (3799 / 1024, -6499 / 1024, 3211 / 1024),
        (4019 / 1024, -6385 / 1024, 2906 / 1024),
        (2677 / 1024, -3029 / 1024, 1162 / 1024),
        (2172 / 1024, -1833 / 1024, 682 / 1024),
    )
    # Work on the tall orientation so the Gram matrix M.T @ M is the
    # smaller of the two possible squares.
    transpose = M.shape[1] > M.shape[0]
    if transpose:
        M = M.T
    # Clamp the norm to the smallest positive normal value of the dtype so a
    # zero matrix gives 0 / tiny == 0 rather than 0 / 0 == NaN. Any norm
    # above that threshold is left untouched, so other inputs are unaffected.
    scale = torch.linalg.norm(M).clamp_min(torch.finfo(M.dtype).tiny)
    M = M / scale
    # The identity depends only on the column count, which never changes
    # during the iteration, so build it once instead of once per step.
    I = torch.eye(M.shape[1], device=M.device, dtype=M.dtype)
    for a, b, c in abc_list:
        A = M.T @ M
        M = M @ (a * I + b * A + c * A @ A)
    if transpose:
        M = M.T
    return M

def orthogonalize_svd(M):
    """Orthogonalize ``M`` exactly by discarding its singular values.

    Computes the reduced SVD of ``M`` and multiplies the left factor by the
    (already transposed) right factor, ignoring the singular values.

    Args:
        M: Tensor accepted by ``torch.linalg.svd``.

    Returns:
        The product ``U @ Vh`` of the reduced SVD factors.
    """
    decomposition = torch.linalg.svd(M, full_matrices=False)
    left_vectors = decomposition.U
    right_vectors_h = decomposition.Vh
    return torch.matmul(left_vectors, right_vectors_h)

def sync(device) -> None:
    """Wait for pending CUDA work on ``device``; do nothing for other devices.

    Args:
        device: Device specifier. May be a string such as ``"cuda"``,
            ``"cuda:0"`` or ``"cpu"``, or a ``torch.device``. Any value
            whose device type is not ``"cuda"`` (including ``None``) is a
            no-op.
    """
    # Compare the device *type* rather than the whole specifier so that
    # indexed strings ("cuda:1") and torch.device objects also synchronize;
    # an exact match against "cuda" would silently skip them.
    if isinstance(device, torch.device):
        kind = device.type
    else:
        kind = str(device).partition(":")[0]
    if kind == "cuda":
        torch.cuda.synchronize(torch.device(device))

def benchmark(fn, x, device, warmup=5, iters=20):
    """Return the median wall-clock time of ``fn`` in milliseconds.

    ``fn`` is first called ``warmup`` times without timing, then ``iters``
    times with timing. Every call receives a fresh clone of ``x``, and
    ``sync(device)`` is invoked before and after each timed call so the
    interval brackets the whole call.

    Args:
        fn: Callable taking a single tensor argument.
        x: Input tensor; cloned before every call.
        device: Device specifier forwarded to ``sync``.
        warmup: Number of untimed calls made first.
        iters: Number of timed calls.

    Returns:
        Median of the per-call durations, in milliseconds.
    """

    def timed_call():
        # Clone outside the timed window so the copy is not measured.
        fresh = x.clone()
        sync(device)
        started = time.perf_counter()
        _result = fn(fresh)
        sync(device)
        finished = time.perf_counter()
        return (finished - started) * 1000

    for _ in range(warmup):
        _result = fn(x.clone())
        sync(device)

    durations_ms = [timed_call() for _ in range(iters)]
    return stats.median(durations_ms)

def rel_error(A, B):
    """Return the norm of ``A - B`` divided by the norm of ``B``.

    Both norms use ``torch.linalg.norm`` with default arguments.

    Args:
        A: Tensor being evaluated.
        B: Reference tensor; its norm is the denominator.

    Returns:
        The ratio as a Python number.
    """
    difference_norm = torch.linalg.norm(A - B)
    reference_norm = torch.linalg.norm(B)
    ratio = difference_norm / reference_norm
    return ratio.item()

if __name__ == "__main__":
    # Fixed seed so the random test matrices are reproducible.
    torch.manual_seed(0)
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # Square problem sizes, smallest to largest.
    sizes = [(side, side) for side in (64, 128, 256, 512, 1024, 2048)]
    for m, n in sizes:
        sample = torch.randn(m, n, device=device)
        # Time the iterative method and the SVD method on the same input.
        t1 = benchmark(orthogonalize, sample, device)
        t2 = benchmark(orthogonalize_svd, sample, device)
        # Compare the iterative result against the SVD result.
        approx_q = orthogonalize(sample.clone())
        exact_q = orthogonalize_svd(sample.clone())
        e = rel_error(approx_q, exact_q)
        print(f"Size {m}x{n} | Orthogonalize: {t1:.2f} ms | SVD: {t2:.2f} ms | rel err: {e:.3e}")



Size 64x64 | Orthogonalize: 0.96 ms | SVD: 1.27 ms | rel err: 1.470e-02
Size 128x128 | Orthogonalize: 0.95 ms | SVD: 2.80 ms | rel err: 6.394e-02
Size 256x256 | Orthogonalize: 0.89 ms | SVD: 7.60 ms | rel err: 6.531e-02
Size 512x512 | Orthogonalize: 1.02 ms | SVD: 21.24 ms | rel err: 6.234e-02
Size 1024x1024 | Orthogonalize: 4.18 ms | SVD: 74.05 ms | rel err: 8.226e-02
Size 2048x2048 | Orthogonalize: 28.88 ms | SVD: 583.89 ms | rel err: 9.820e-02
