In [11]:
"""
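Notebook for demonstrating and benchmarking the speed of JAX under different
usage patterns: un-jitted vs. jit-compiled calls, NumPy (host) inputs vs.
inputs passed through device_put, and different placements of jit around
grad / jacfwd / jacrev.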

"""

import numpy as np
from scipy.optimize import minimize, least_squares
import jax.numpy as jnp
from jax import grad, jit, device_put, jacfwd, jacrev, jvp
from functools import partial

In [17]:
def matrix_multiply(A, B):
    """Return the matrix product of ``A + A`` with ``B``.

    The left operand is added to itself before the product is taken, so the
    result is ``(A + A) @ B`` rather than plain ``A @ B``.  The product is
    computed with ``jnp.matmul``.

    Args:
        A: left operand; must support ``A + A``.
        B: right operand, shape-compatible with ``A`` for a matrix product.

    Returns:
        The result of ``jnp.matmul(A + A, B)``.
    """
    doubled = A + A
    return jnp.matmul(doubled, B)
# Benchmark operands: random NumPy arrays of shape (100, 1000) and (1000, 100).
Anp = np.random.randn(100, 1000)
Bnp = np.random.randn(1000, 100)
# The same operands passed through jax.device_put, so the timed calls that
# use A / B do not start from NumPy arrays.
A = device_put(Anp)
B = device_put(Bnp)

# Baseline: un-jitted call on the NumPy inputs.  block_until_ready() is
# called on the result so the timing waits for the computation to finish.
%timeit matrix_multiply(Anp,Bnp).block_until_ready()

# One-argument wrapper: B is not a parameter here, it is read from the
# enclosing (notebook-level) scope when the lambda body runs.
lmm = lambda A: matrix_multiply(A,B)

# NOTE(review): jit traces lmm, so the value of B seen at trace time is
# presumably baked into the compiled function; later cells rebind the
# notebook-level B -- confirm this cell is not relied on after that.

lmm_jit = jit(lmm)

# Jitted closure called with the device_put copy of A.
%timeit lmm_jit(A).block_until_ready()

def myfunc(func, A, B):
    """Call ``func`` positionally with ``A`` and ``B`` and return its result.

    Args:
        func: callable accepting two positional arguments.
        A: first argument forwarded to ``func``.
        B: second argument forwarded to ``func``.

    Returns:
        Whatever ``func(A, B)`` returns.
    """
    result = func(A, B)
    return result

# jit the generic wrapper with argument 0 (the callable) declared static, so
# only A and B are treated as array inputs.
myfunc_jit = jit(myfunc, static_argnums=(0,))
# Time the jitted wrapper with matrix_multiply passed in as the static callable.
%timeit myfunc_jit(matrix_multiply, A, B).block_until_ready()

@partial(jit, static_argnums=(0,))
def myfunc2(func, A):
    """Apply ``func`` to ``A`` and return the result.

    The function is wrapped in ``jit`` by the decorator, with argument 0
    (``func``) declared static.

    Args:
        func: callable accepting one positional argument.
        A: the argument forwarded to ``func``.

    Returns:
        Whatever ``func(A)`` returns.
    """
    result = func(A)
    return result

# Time the decorator-jitted wrapper, passing the closure lmm as the static
# callable and the device_put copy of A as the array argument.
%timeit myfunc2(lmm, A).block_until_ready()



188 µs ± 6.88 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46.3 µs ± 323 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
48 µs ± 84.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46.8 µs ± 85.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [3]:
# Un-jitted matrix_multiply on fresh 100x100 NumPy inputs.
# Note: this rebinds the notebook-level names A and B.
A = np.random.randn(100, 100)
B = np.random.randn(100, 100)

%timeit matrix_multiply(A, B).block_until_ready()

196 µs ± 4.96 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [5]:
# Jitted matrix_multiply, both inputs still plain NumPy arrays.
mm_jit = jit(matrix_multiply)

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)

%timeit mm_jit(A, B).block_until_ready()


70.3 µs ± 856 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [7]:
# Jitted matrix_multiply with only A passed through device_put.
mm_jit = jit(matrix_multiply)

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)
A = device_put(A)
# B is deliberately NOT passed through device_put here; it stays a NumPy
# array so this cell measures the mixed case (A via device_put, B from NumPy).

%timeit mm_jit(A, B).block_until_ready()

60.6 µs ± 1.71 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [8]:
# Jitted matrix_multiply with both inputs passed through device_put before
# timing; fastest of the three mm_jit variants in the recorded runs.
mm_jit = jit(matrix_multiply)

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)
A = device_put(A)
B = device_put(B)

%timeit mm_jit(A, B).block_until_ready()

36.3 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [17]:
def sum_mat_mul(A, B):
    """Return the sum of all entries of the matrix product ``A @ B``.

    Args:
        A: left operand of the matrix product.
        B: right operand of the matrix product.

    Returns:
        The scalar obtained by summing every element of ``jnp.matmul(A, B)``.
    """
    product = jnp.matmul(A, B)
    return jnp.sum(product)

# Gradient of sum_mat_mul with respect to argument 0 (A).  Nothing is jitted
# here: neither sum_mat_mul nor the gradient function.
gradient = grad(sum_mat_mul, argnums=0)

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)

%timeit gradient(A, B).block_until_ready()

1.37 ms ± 79.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [18]:
# jit applied to the inner function only; the gradient function returned by
# grad is itself not wrapped in jit.
smm_jit = jit(sum_mat_mul)

gradient = grad(smm_jit, argnums=0)

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)

%timeit gradient(A, B).block_until_ready()

714 µs ± 2.37 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [19]:
# jit wrapped around the whole gradient function.  In the recorded runs this
# and the following variant (jit on both the inner function and the gradient)
# are the fastest gradient options for NumPy inputs.

gradient = jit(grad(sum_mat_mul, argnums=0))

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)

%timeit gradient(A, B).block_until_ready()

85.7 µs ± 8.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [20]:
# jit on both the inner function and the gradient function.  The recorded
# timing is about the same as with jit on the gradient function alone.

gradient = jit(grad(jit(sum_mat_mul), argnums=0))

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)

%timeit gradient(A, B).block_until_ready()

82.6 µs ± 3.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [21]:
# Same jit(grad(...)) construction as with NumPy inputs, but both inputs are
# passed through device_put before timing.  This is the fastest gradient
# timing in the recorded runs.

gradient = jit(grad(sum_mat_mul, argnums=0))

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)
A = device_put(A)
B = device_put(B)

%timeit gradient(A, B).block_until_ready()

45.1 µs ± 156 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [22]:

# Second derivative built as forward-mode Jacobian (jacfwd) of the
# reverse-mode Jacobian (jacrev) of smm_jit with respect to argument 0.
# The outer composition is NOT wrapped in jit here.
# Relies on smm_jit having been defined by an earlier cell.
# NOTE(review): sum_mat_mul is linear in A, so this second derivative should
# be identically zero -- presumably fine for a timing demo, but confirm it is
# a representative workload.
hessian = jacfwd(jacrev(smm_jit, argnums=0))

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)
A = device_put(A)
B = device_put(B)

%timeit hessian(A, B).block_until_ready()

100 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
# Same jacfwd(jacrev(...)) construction, with the whole composition wrapped
# in an outer jit.  In the recorded runs this is far faster than the version
# without the outer jit (about 1.27 ms vs. about 100 ms).
# Relies on smm_jit having been defined by an earlier cell.

hessian = jit(jacfwd(jacrev(smm_jit, argnums=0)))

A = np.random.randn(100, 100)
B = np.random.randn(100, 100)
A = device_put(A)
B = device_put(B)

%timeit hessian(A, B).block_until_ready()





1.27 ms ± 18.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
