In [1]:
import numpy as np
import scipy


import jax 
import jax.numpy as jnp

In [2]:
# Show the devices JAX can run on in this kernel.
jax.devices()

In [165]:
# Fixed seed so Restart & Run All regenerates the same test matrix, and therefore
# the same reference square root, error metrics and comparable timings.
rng = np.random.default_rng(0)
a = rng.normal(size=(500, 500))

In [166]:
# a @ a.T is symmetric positive semi-definite (positive definite whenever `a`
# has full rank), so it has a well-defined principal square root.
b = a @ a.T

In [167]:
# Single-precision copy of the test matrix; all benchmarks below run on float32 input.
b32 = b.astype(np.float32)

In [168]:
# Reference result from SciPy (despite the `_np` suffix). Computed on the float32
# matrix so it is compared like-for-like with the float32 inputs used below.
# NOTE(review): relies on `scipy.linalg` being reachable after a bare
# `import scipy`; an explicit `import scipy.linalg` is more robust on older SciPy.
sqrt_np = scipy.linalg.sqrtm(b32)

In [169]:
def sqrtm_newton_schulz(a, num_iters=10):
    """Approximate the principal matrix square root with the coupled Newton-Schulz iteration.

    The input is scaled by its trace before iterating. For a symmetric positive
    definite matrix this puts every eigenvalue in (0, 1], which keeps the
    iteration inside its convergence region. The result is rescaled by
    ``sqrt(trace(a))`` at the end.

    Parameters
    ----------
    a : ndarray, shape (n, n)
        Square matrix, assumed symmetric positive definite (not checked).
    num_iters : int, optional
        Number of Newton-Schulz iterations. The default of 10 matches the value
        that used to be hard-coded.

    Returns
    -------
    ndarray, shape (n, n)
        Approximation of ``sqrtm(a)``.
    """
    normalization = np.trace(a)
    # Division already allocates a new array, so no explicit copy of `a` is needed.
    y = a / normalization
    # np.eye defaults to float64, so float32 inputs are promoted to float64 here.
    identity = np.eye(a.shape[0])
    z = identity
    for _ in range(num_iters):
        # Factor shared by both updates: compute it once per iteration, not twice.
        t = 3. * identity - z @ y
        # Tuple assignment so both updates use the previous (y, z) pair.
        y, z = 0.5 * y @ t, 0.5 * t @ z
    return y * np.sqrt(normalization)

In [170]:
# NumPy Newton-Schulz approximation of sqrtm(b32), to be compared with the SciPy result.
sqrt_ns = sqrtm_newton_schulz(b32)

In [171]:
# Typical magnitude of the entries of the reference square root, for scale
# when reading the error numbers below.
np.mean(sqrt_np.ravel())

In [172]:
# Relative Frobenius-norm error of the NumPy Newton-Schulz result vs. SciPy.
# A plain mean of the signed differences lets positive and negative errors
# cancel, so it can be near zero even when the two matrices disagree badly.
np.linalg.norm(sqrt_ns - sqrt_np) / np.linalg.norm(sqrt_np)

In [175]:
%%timeit

# Baseline timing: SciPy's sqrtm on the 500x500 float32 matrix.
sqrt_np = scipy.linalg.sqrtm(b32)

In [176]:
%%timeit

# Timing of the NumPy Newton-Schulz implementation.
# NOTE(review): sqrtm_newton_schulz builds its identity with np.eye (float64 by
# default), so the float32 input is promoted and most of this work runs in
# float64. Keep that in mind when comparing against the float32 JAX timings.
sqrt_ns = sqrtm_newton_schulz(b32)

In [157]:
# Re-check of the JAX device list (same call as near the top of the notebook).
jax.devices()

In [177]:
# Move the float32 test matrix onto JAX's default device.
b32_j = jnp.asarray(b32)

In [178]:
@jax.jit
def sqrtm_newton_schulz_jax(a):
    """JIT-compiled coupled Newton-Schulz approximation of the matrix square root.

    Scales ``a`` by its trace, runs 10 iterations (the Python loop is unrolled
    when the function is traced), and rescales the result by ``sqrt(trace(a))``.
    Assumes ``a`` is square and symmetric positive definite (not checked).
    """
    num_iters = 10
    trace_a = jnp.trace(a)
    eye = jnp.eye(a.shape[0])
    y = a.copy() / trace_a
    z = eye
    for _ in range(num_iters):
        # Both updates share this factor.
        factor = 3. * eye - z @ y
        y, z = 0.5 * y @ factor, 0.5 * factor @ z
    return y * jnp.sqrt(trace_a)

In [179]:
# Use the JIT-compiled JAX implementation (not the NumPy one), so this cell
# actually produces and validates the JAX result. This first call also triggers
# compilation, so the JAX %%timeit below measures steady-state runtime.
# block_until_ready() waits until the asynchronously dispatched result exists.
sqrt_ns_j = sqrtm_newton_schulz_jax(b32_j).block_until_ready()

In [180]:
# Relative Frobenius-norm error of the JAX Newton-Schulz result vs. SciPy.
# A signed mean of the differences would let positive and negative errors cancel.
np.linalg.norm(np.asarray(sqrt_ns_j) - sqrt_np) / np.linalg.norm(sqrt_np)

In [181]:
%%timeit

sqrt_ns = sqrtm_newton_schulz_jax(b32_j)

In [163]:
@jax.jit
def sqrtm_newton_schulz_jax_loop(a):
    """Coupled Newton-Schulz matrix square root using ``jax.lax.fori_loop``.

    Scales ``a`` by its trace, runs 10 iterations through ``fori_loop`` (the
    body is traced once instead of being unrolled), and rescales the result by
    ``sqrt(trace(a))``. Assumes ``a`` is square and symmetric positive definite
    (not checked).
    """
    num_iters = 10
    scale = jnp.trace(a)
    # Defined before the loop body that closes over it.
    eye = jnp.eye(a.shape[0])

    def step(_, carry):
        y, z = carry
        factor = 3. * eye - z @ y
        return (0.5 * y @ factor, 0.5 * factor @ z)

    init = (a.copy() / scale, eye)
    y_final, _ = jax.lax.fori_loop(0, num_iters, step, init)
    return y_final * jnp.sqrt(scale)

In [164]:
%%timeit
for i in range(100):
    sqrt_ns = sqrtm_newton_schulz_jax_loop(b32_j)

In [134]:
# Inspect the SciPy reference square root.
sqrt_np

In [135]:
# Inspect the JAX Newton-Schulz result for a side-by-side look with `sqrt_np`.
sqrt_ns_j

In [183]:
# jax.devices('gpu') raises RuntimeError when no GPU backend is available.
# Show an empty list instead, so Restart & Run All also completes on CPU-only machines.
try:
    gpu_devices = jax.devices('gpu')
except RuntimeError:
    gpu_devices = []
gpu_devices