In [1]:
import numpy as np
import scipy


import jax 
import jax.numpy as jnp

In [2]:
jax.devices()

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!


Metal device set to: Apple M2 Max




[METAL(id=0)]

In [165]:
a = np.random.normal(size=(500, 500))

In [166]:
b = a @ a.T

In [167]:
b32 = b.astype('float32')

In [168]:
sqrt_np = scipy.linalg.sqrtm(b32)

In [169]:
def sqrtm_newton_schulz(a):
    k = 10
    normalization = np.trace(a)
    y = a.copy() / normalization
    z = np.eye(a.shape[0])
    identity = np.eye(a.shape[0])
    for i in range(k):
        y_now = 0.5 * y @ (3. * identity - z @ y)
        z_now = 0.5 * (3. * identity - z @ y) @ z
        y = y_now
        z = z_now
    return y * np.sqrt(normalization)

In [170]:
sqrt_ns = sqrtm_newton_schulz(b32)

In [171]:
np.mean(sqrt_np.ravel())

0.03860921

In [172]:
np.mean((sqrt_ns-sqrt_np).ravel())

-0.0013150125887562334

In [175]:
%%timeit

sqrt_np = scipy.linalg.sqrtm(b32)

281 ms ± 8.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [176]:
%%timeit

sqrt_ns = sqrtm_newton_schulz(b32)

106 ms ± 2.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [157]:
jax.devices()

[METAL(id=0)]

In [177]:
b32_j = jnp.array(b32)

In [178]:
@jax.jit
def sqrtm_newton_schulz_jax(a):
    k = 10
    normalization = jnp.trace(a)
    y = a.copy() / normalization
    z = jnp.eye(a.shape[0])
    identity = jnp.eye(a.shape[0])
    for i in range(k):
        y_now = 0.5 * y @ (3. * identity - z @ y)
        z_now = 0.5 * (3. * identity - z @ y) @ z
        y = y_now
        z = z_now
    return y * jnp.sqrt(normalization)

In [179]:
sqrt_ns_j = sqrtm_newton_schulz(b32_j)

In [180]:
np.mean((sqrt_ns_j-sqrt_np).ravel())

Array(-0.00131493, dtype=float32)

In [181]:
%%timeit

sqrt_ns = sqrtm_newton_schulz_jax(b32_j)

468 µs ± 65.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [163]:
@jax.jit
def sqrtm_newton_schulz_jax_loop(a):

    def body_fun(i, pars):
        y, z = pars
        y_now = 0.5 * y @ (3. * identity - z @ y)
        z_now = 0.5 * (3. * identity - z @ y) @ z
        return (y_now, z_now)
    k = 10
    normalization = jnp.trace(a)
    y = a.copy() / normalization
    z = jnp.eye(a.shape[0])
    identity = jnp.eye(a.shape[0])
    (y, z) = jax.lax.fori_loop(0, k, body_fun, (y, z))
    return y * jnp.sqrt(normalization)

In [164]:
%%timeit
for i in range(100):
    sqrt_ns = sqrtm_newton_schulz_jax_loop(b32_j)

56.6 ms ± 278 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [134]:
sqrt_np

array([[ 2.33130240e+00,  2.81554043e-01,  1.09425280e-02,
        -7.68378615e-01, -1.02408528e+00,  3.05229928e-02,
         1.98599786e-01,  2.92746842e-01, -1.48482010e-01,
         9.51892585e-02],
       [ 2.81554103e-01,  2.14266109e+00, -7.85147488e-01,
         5.59842885e-01,  5.01165032e-01, -8.71865273e-01,
         4.98518914e-01,  2.88285971e-01,  3.87258716e-02,
        -5.86577393e-02],
       [ 1.09424656e-02, -7.85147965e-01,  3.17536473e+00,
        -4.03983951e-01, -8.71151984e-01, -1.30707189e-03,
        -1.08615709e-02,  1.96505621e-01, -9.26938951e-01,
        -1.41628847e-01],
       [-7.68378675e-01,  5.59843779e-01, -4.03984666e-01,
         2.60584831e+00,  1.31551996e-01,  2.57390589e-01,
        -1.27274916e-01, -6.37116313e-01,  1.01954436e+00,
         3.44771147e-01],
       [-1.02408552e+00,  5.01166105e-01, -8.71152043e-01,
         1.31551310e-01,  3.40686631e+00, -3.67351413e-01,
         6.89415693e-01, -3.45420480e-01,  1.00823689e+00,
        -5.

In [135]:
sqrt_ns_j

Array([[ 2.3300555e+00,  2.8366032e-01,  1.1663947e-02, -7.7059537e-01,
        -1.0251204e+00,  3.1487048e-02,  1.9835164e-01,  2.9107773e-01,
        -1.4646366e-01,  9.5632270e-02],
       [ 2.8366035e-01,  2.1357379e+00, -7.8712708e-01,  5.6494278e-01,
         5.0138843e-01, -8.7825042e-01,  5.0081027e-01,  2.9328823e-01,
         3.7338983e-02, -6.1612599e-02],
       [ 1.1664016e-02, -7.8712684e-01,  3.1747723e+00, -4.0239438e-01,
        -8.7089884e-01, -2.9422785e-03, -1.0293233e-02,  1.9796506e-01,
        -9.2764670e-01, -1.4238487e-01],
       [-7.7059507e-01,  5.6494278e-01, -4.0239418e-01,  2.6013591e+00,
         1.3032401e-01,  2.6101714e-01, -1.2847038e-01, -6.4096165e-01,
         1.0223191e+00,  3.4644756e-01],
       [-1.0251204e+00,  5.0138843e-01, -8.7089902e-01,  1.3032392e-01,
         3.4053125e+00, -3.6871079e-01,  6.9005990e-01, -3.4581569e-01,
         1.0108309e+00, -5.8314741e-01],
       [ 3.1487007e-02, -8.7825042e-01, -2.9422895e-03,  2.6101699e-01,
   

In [183]:
jax.devices('gpu')

RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: cpu,METAL