<a href="https://colab.research.google.com/github/jonbarron/svd2/blob/master/svd2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as np

@jax.jit
def svd2(A):
  """Closed-form singular value decomposition of a batch of 2x2 matrices.

  Adapted from
  https://lucidar.me/en/mathematics/singular-value-decomposition-of-a-2x2-matrix/

  Args:
    A: array of shape [n, 2, 2].

  Returns:
    A tuple (U, s, VT):
      U:  [n, 2, 2] rotation matrices whose columns are left singular vectors.
      s:  [n, 2] singular values, non-negative, largest first.
      VT: [n, 2, 2] orthogonal matrices whose rows are right singular vectors,
          chosen so that U[i] @ diag(s[i]) @ VT[i] reconstructs A[i].

  Degenerate inputs are handled without producing NaNs: when A @ A.T is a
  multiple of the identity (equal singular values, including the zero matrix)
  U is the identity, and when a singular value is zero the matching row of VT
  is still a unit vector orthogonal to the other row.

  Note: intermediate quantities are fourth powers of the entries of A, so
  inputs of extremely large or small magnitude can overflow or underflow.
  """

  def rotation(X):
    # For a symmetric X, the rotation by theta = atan2(a, b) / 2 diagonalizes
    # it. cos(theta) and sin(theta) are obtained as the real and imaginary
    # parts of the square root of the unit complex number (b + i*a) / |b + i*a|.
    a = X[:,0,1] + X[:,1,0]
    b = X[:,0,0] - X[:,1,1]
    norm = np.sqrt(a**2 + b**2)
    # norm == 0 means X is a multiple of the identity, so any rotation is
    # valid; pick theta = 0 instead of dividing 0 by 0.
    nonzero = norm > 0
    safe_norm = np.where(nonzero, norm, 1)
    z = np.sqrt(np.where(nonzero, (b + 1j*a) / safe_norm, 1))
    z_real = np.real(z)
    z_imag = np.imag(z)
    # |z| is 1 up to rounding, so q is ~2; scaling by q/2 pulls (cos, sin)
    # closer to unit length.
    q = (1 + 1/(z_real**2 + z_imag**2))
    cos = 0.5 * z_real * q
    sin = 0.5 * z_imag * q
    return np.reshape(np.stack([cos, -sin, sin, cos], -1), [-1, 2, 2])

  AAT = np.einsum('nij,nkj->nik', A, A)

  U = rotation(AAT)

  # Singular values from the trace (c) and eigenvalue gap (d) of A @ A.T.
  a = AAT[:,0,1] * AAT[:,1,0]
  b = AAT[:,0,0] - AAT[:,1,1]
  c = AAT[:,0,0] + AAT[:,1,1]
  d = np.sqrt(b**2 + 4*a)
  # c - d can round to a tiny negative number for singular A; clamp it so the
  # smaller singular value is 0 rather than NaN.
  s = np.stack([np.sqrt(0.5*(c+d)), np.sqrt(0.5*np.maximum(c-d, 0))], -1)

  # Rows of B = U^T A are s[k] * (k-th right singular vector). Deriving VT from
  # B (instead of independently diagonalizing A.T @ A) keeps U and VT
  # consistent even when the singular values are equal or nearly equal.
  B = np.einsum('nji,njk->nik', U, A)

  # First right singular vector: the normalized first row of B. That row is
  # zero only when A itself is zero, in which case any unit vector works.
  row_norm = np.sqrt(B[:,0,0]**2 + B[:,0,1]**2)
  has_row = row_norm > 0
  safe_row_norm = np.where(has_row, row_norm, 1)
  v0x = np.where(has_row, B[:,0,0] / safe_row_norm, 1)
  v0y = np.where(has_row, B[:,0,1] / safe_row_norm, 0)

  # Second right singular vector: the perpendicular (-v0y, v0x), with the sign
  # that makes its singular value non-negative. A zero projection maps to +1
  # (np.sign would give 0 and wipe out the row).
  proj = B[:,1,1] * v0x - B[:,1,0] * v0y
  sgn = np.where(proj < 0, -1.0, 1.0)

  VT = np.reshape(np.stack([
    v0x, v0y,
    -sgn * v0y, sgn * v0x], -1), [-1, 2, 2])

  return U, s, VT

In [2]:
# Unit Tests.

# Random batch of 100 2x2 test matrices; each matrix is then multiplied by its
# own exp(normal) factor.
# NOTE(review): both draws use the same PRNGKey(0), so the scale factors may be
# correlated with the matrix entries. Left unchanged so the test data (and the
# tolerance-based checks that depend on it) stay exactly as they were.
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (100, 2, 2))
A = A * np.exp(jax.random.normal(key, (A.shape[0], 1, 1)))

# Closed-form result, plus the library SVD used as the reference.
U, s, VT = svd2(A)
U_, s_, VT_ = np.linalg.svd(A)


def batch_matmul(X, Y):
  """Multiplies two batches of matrices pairwise: [n,i,j] x [n,j,k] -> [n,i,k]."""
  return np.einsum('nij,njk->nik', X, Y)

def batch_diag(x):
  """Expands a batch of vectors into a batch of diagonal matrices.

  Args:
    x: array of shape [..., k].

  Returns:
    A floating-point array of shape [m, k, k], where m is the product of the
    leading dimensions of x (they are flattened). Matrix i carries the i-th
    vector of x on its diagonal and zeros everywhere else.
  """
  k = x.shape[-1]
  on_diagonal = np.eye(k, dtype=bool)
  # Select rather than multiply by an identity matrix, so that NaN or inf
  # entries of x cannot leak into the off-diagonal zeros.
  D = np.where(on_diagonal, x[..., :, None], 0.0)
  return np.reshape(D, [-1, k, k])

tol = 1e-5  # absolute tolerance shared by every check below


def _err_up_to_sign(X_ref, X):
  """Element-wise minimum of min(|X_ref - X|) and min(|X_ref + X|) over axis -2.

  Comparing against both X and -X makes the check insensitive to a sign flip
  between the two decompositions.
  """
  err_same_sign = np.min(np.abs(X_ref - X), -2)
  err_flipped_sign = np.min(np.abs(X_ref + X), -2)
  return np.minimum(err_same_sign, err_flipped_sign)


# U @ diag(s) @ VT must reproduce the input.
assert np.all(np.abs(batch_matmul(U, batch_matmul(batch_diag(s), VT)) - A) < tol)

# Singular vectors are compared against the reference up to sign.
U_err = _err_up_to_sign(U_, U)
assert np.all(np.abs(U_err) < tol)

VT_err = _err_up_to_sign(VT_, VT)
assert np.all(np.abs(VT_err) < tol)

# Singular values must match the reference directly.
assert np.all(np.abs(s - s_) < tol)



In [5]:
# Profiling
# Times the library SVD against svd2 on the same batch of 100000 random 2x2
# matrices. Note that this rebinds the name `A`.
A = jax.random.normal(jax.random.PRNGKey(0), (100000, 2, 2))
# block_until_ready() is called on every returned array so that each timing
# waits for the results to be computed.
%timeit [x.block_until_ready() for x in np.linalg.svd(A)]
%timeit [x.block_until_ready() for x in svd2(A)]

10 loops, best of 3: 51.8 ms per loop
1000 loops, best of 3: 1.36 ms per loop
