In [1]:
import jax
from jax import jit
import jax.numpy as jnp

In [11]:
@jax.jit
def f(x, A, b):
    """Least-squares objective ``||A x - b||^2``.

    Args:
        x: 1-D array, the point at which the objective is evaluated.
        A: 2-D array such that ``A @ x`` is defined.
        b: 1-D array with the same shape as ``A @ x``.

    Returns:
        Scalar equal to ``(A @ x - b) @ (A @ x - b)``.
    """
    # Keep the residual purely linear in x: the closed forms used elsewhere in
    # this notebook (gradient 2 A^T (A x - b), Hessian 2 A^T A) only hold for
    # ||A x - b||^2, not for a residual with an extra ||x|| term.
    res = A @ x - b
    return res @ res


# Gradient of f with respect to its first argument (x).
gradf = jax.grad(f)

In [12]:
n = 1000

key = jax.random.PRNGKey(1)
# Each draw needs its own subkey: reusing one key makes the draws dependent,
# and two draws of the same shape (x and b) come out identical.
key_x, key_A, key_b = jax.random.split(key, 3)

# float32 is spelled out because that is what JAX produces here anyway:
# a float64 request is truncated to float32 unless the jax_enable_x64 flag
# is turned on at start-up.
x = jax.random.normal(key_x, (n, ), dtype=jnp.float32)
A = jax.random.normal(key_A, (n, n), dtype=jnp.float32)
b = jax.random.normal(key_b, (n, ), dtype=jnp.float32)

In [13]:
%timeit 2 * A.T @ (A @ x - b)
%timeit gradf(x, A, b).block_until_ready()

jit_grad_f = jax.jit(gradf)
%timeit jit_grad_f(x, A, b).block_until_ready()

1.09 ms ± 43.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.66 ms ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
205 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Гессиан

In [5]:
# JIT-compiled Hessian of f with respect to its first argument (x).
hess_func = jax.jit(jax.hessian(f))

# Difference between the matrix 2 * A^T A and the autodiff Hessian;
# its norm is displayed as the cell output.
hessian_gap = 2 * A.T @ A - hess_func(x, A, b)
jnp.linalg.norm(hessian_gap)

DeviceArray(0., dtype=float32)

In [6]:
# Time the JIT-compiled Hessian; block_until_ready() is called on the result
# so the timed statement waits for the returned array.
%timeit hess_func(x, A, b).block_until_ready()

16.5 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
hess_jit = jax.jit(jax.jacfwd(jax.jacrev(f)))
%timeit hess_jit(x, A, b).block_until_ready()

16.9 ms ± 923 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Сравним прямой (forward) и обратный (reverse) режимы дифференцирования

In [8]:
# JIT-compiled derivative of f w.r.t. x, built once per differentiation mode:
# forward (jacfwd) first, reverse (jacrev) second.
fwd_grad, rev_grad = (jax.jit(mode(f)) for mode in (jax.jacfwd, jax.jacrev))

In [9]:
# Time both modes WITHOUT calling block_until_ready() on the result;
# the following cell repeats the same measurement with it.
%timeit fwd_grad(x, A, b)
%timeit rev_grad(x, A, b)

6.77 ms ± 168 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
159 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [10]:
# Time both modes again, this time calling block_until_ready() on each result
# so the timed statement waits for the returned array.
%timeit fwd_grad(x, A, b).block_until_ready()
%timeit rev_grad(x, A, b).block_until_ready()

6.56 ms ± 29.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
200 µs ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
