In [9]:
import jax
import jax.numpy as np

@jax.jit
def svd2(A):
  """Closed-form singular value decomposition of 2x2 matrices.

  Adapted from
  https://lucidar.me/en/mathematics/singular-value-decomposition-of-a-2x2-matrix/

  U and the singular values come from the analytic eigen-decomposition of
  A @ A^T. VT is then read off U^T @ A (whose rows are s_i * v_i^T) rather
  than from an independent eigen-decomposition of A^T @ A, so that the three
  factors stay consistent with each other even when the singular values
  coincide (e.g. A is a rotation) or vanish (A is rank-deficient or zero).

  Args:
    A: array of shape [..., 2, 2] with any number of leading batch
      dimensions (including none).

  Returns:
    A tuple (U, s, VT):
      U:  [..., 2, 2] rotation matrix whose columns are left singular vectors.
      s:  [..., 2] non-negative singular values, largest first.
      VT: [..., 2, 2] orthogonal matrix whose rows are right singular vectors.
    such that A ~= U @ diag(s) @ VT.

  Raises:
    ValueError: if the trailing two dimensions of A are not (2, 2).
  """
  shape = A.shape
  if len(shape) < 2 or tuple(shape[-2:]) != (2, 2):
    raise ValueError('svd2 expects an array of shape [..., 2, 2], got %s' % (shape,))

  # Shapes are concrete Python ints even while tracing, so the flattened batch
  # size is computed in plain Python (not with jax.numpy, which would yield an
  # array that cannot be used as a reshape dimension).
  batch = 1
  for dim in shape[:-2]:
    batch *= dim
  A = np.reshape(A, (batch, 2, 2))

  # Left singular vectors: rotation that diagonalizes the symmetric A @ A^T.
  AAT = np.einsum('nij,nkj->nik', A, A)
  phi = 0.5 * np.arctan2(AAT[:,0,1] + AAT[:,1,0], AAT[:,0,0] - AAT[:,1,1])
  cos_phi = np.cos(phi)
  sin_phi = np.sin(phi)
  U = np.reshape(np.stack([cos_phi, -sin_phi, sin_phi, cos_phi], -1), (batch, 2, 2))

  # Singular values: square roots of the eigenvalues of A @ A^T.
  AAT_sum = AAT[:,0,0] + AAT[:,1,1]
  AAT_diff = np.sqrt((AAT[:,0,0] - AAT[:,1,1])**2 + 4 * AAT[:,0,1] * AAT[:,1,0])
  # Rounding can push (sum - diff) slightly below zero for (near-)singular
  # inputs; clamp so the square root never produces NaN.
  s = np.stack([
      np.sqrt(0.5 * (AAT_sum + AAT_diff)),
      np.sqrt(0.5 * np.maximum(AAT_sum - AAT_diff, 0.0))], -1)

  # Right singular vectors. Row i of M = U^T @ A equals s_i * v_i^T.
  M = np.einsum('nji,njk->nik', U, A)
  m0x, m0y = M[:,0,0], M[:,0,1]
  m1x, m1y = M[:,1,0], M[:,1,1]

  # v_0 is the normalized first row; fall back to (1, 0) when that row is
  # exactly zero (A == 0), avoiding a 0/0 division.
  norm0 = np.sqrt(m0x**2 + m0y**2)
  nonzero = norm0 > 0
  safe_norm0 = np.where(nonzero, norm0, 1.0)
  v0x = np.where(nonzero, m0x / safe_norm0, 1.0)
  v0y = np.where(nonzero, m0y / safe_norm0, 0.0)

  # v_1 is the perpendicular (-v0y, v0x), oriented to agree with the second row
  # of M. A zero projection maps to +1 (np.sign would give 0 and wipe out the
  # row, leaving VT non-orthogonal for rank-deficient inputs).
  D11 = np.where(m1y * v0x - m1x * v0y < 0, -1.0, 1.0)
  VT = np.reshape(np.stack([v0x, v0y, -v0y * D11, v0x * D11], -1), (batch, 2, 2))

  # Restore the caller's leading batch dimensions.
  U = np.reshape(U, shape)
  s = np.reshape(s, shape[:-1])
  VT = np.reshape(VT, shape)
  return U, s, VT

In [10]:
# Unit Tests.

# Fixed PRNG key, so the same 100 random 2x2 matrices are drawn on every run.
A = jax.random.normal(jax.random.PRNGKey(0), (100, 2, 2))

def batch_matmul(X, Y):
  """Pairwise product of two stacks of matrices: out[n] = X[n] @ Y[n]."""
  return np.einsum('nij,njk->nik', X, Y)

# Factorization under test, followed by the library SVD used as the reference.
U, s, VT = svd2(A)
U_, s_, VT_ = np.linalg.svd(A)

def batch_diag(x):
  """Expand vectors into diagonal matrices.

  Args:
    x: array of shape [..., d].

  Returns:
    Array of shape [..., d, d] whose [..., i, i] entries equal x[..., i] and
    whose off-diagonal entries are zero. Leading batch dimensions are kept
    as-is.
  """
  # Broadcasting x[..., i, None] against the identity keeps x[..., i] only at
  # column i, which avoids a Python loop and a round trip through host NumPy.
  return x[..., :, None] * np.eye(x.shape[-1])

# Absolute tolerance shared by every check below.
threshold = 1e-5

# The factors must multiply back to the input: U @ diag(s) @ VT == A.
assert np.all(np.abs(A - batch_matmul(U, batch_matmul(batch_diag(s), VT))) < threshold)

# Singular vectors are only defined up to sign, so U is compared against both
# +U_ and -U_.
# NOTE(review): the min over axis -2 keeps only the best-matching entry of each
# column, so this is a loose check; a max over that axis would be stricter.
U_err = np.minimum(np.abs(U_ - U).min(-2), np.abs(U_ + U).min(-2))
assert np.all(U_err < threshold)

# Same sign-tolerant comparison for VT (same caveat as for U).
VT_err = np.minimum(np.abs(VT_ - VT).min(-2), np.abs(VT_ + VT).min(-2))
assert np.all(VT_err < threshold)

# Singular values must agree element-wise with the reference.
assert np.all(np.abs(s_ - s) < threshold)



In [13]:
# Profiling: time the library SVD against svd2 on 100000 random 2x2 matrices.
# Note that this rebinds `A`, replacing the (100, 2, 2) batch used by the unit
# tests above.
A = jax.random.normal(jax.random.PRNGKey(0), (100000, 2, 2))
# block_until_ready() is called on every returned array so that each timing
# waits for the results to be computed instead of only measuring dispatch.
%timeit [x.block_until_ready() for x in np.linalg.svd(A)]
%timeit [x.block_until_ready() for x in svd2(A)]

10 loops, best of 3: 134 ms per loop
100 loops, best of 3: 2.62 ms per loop
