In [2]:
import jax.numpy as jnp
import jax

In [14]:
from typing import Callable
from functools import partial
@partial(jax.jit, static_argnums=(0,), static_argnames=("num_iters",))
def newtons_method(f: Callable, x0: jnp.ndarray, *, num_iters: int = 5) -> jnp.ndarray:
    """Approximate a root of the scalar function ``f`` with Newton's method.

    Applies a fixed number of Newton updates ``x <- x - f(x) / f'(x)`` starting
    from ``x0``; there is no convergence test or early exit.

    Args:
        f: Function whose root is sought. Its derivative is taken with
            ``jax.grad``, so ``f`` must return a scalar (presumably ``x0`` is a
            scalar too). ``f`` is a static ``jax.jit`` argument: it must be
            hashable, every new function object (e.g. a fresh lambda) triggers a
            retrace, and any globals it reads are baked in at trace time.
        x0: Initial guess.
        num_iters: Number of Newton updates to apply (default 5, the previously
            hard-coded count). Keyword-only and static, so the loop bounds stay
            concrete Python ints; ``fori_loop`` then lowers to ``scan`` and the
            result remains reverse-mode differentiable.

    Returns:
        The iterate after ``num_iters`` updates (``x0`` itself when 0).

    Raises:
        ValueError: If ``num_iters`` is negative.

    Note:
        No guard is applied where ``f'(x) == 0``; such a step divides by zero
        and produces inf/nan.
    """
    if num_iters < 0:
        raise ValueError(f"num_iters must be non-negative, got {num_iters}")

    fprime = jax.grad(f)

    def newton_step(_i, x):
        # Loop index is unused; each step is the plain Newton update.
        return x - f(x) / fprime(x)

    return jax.lax.fori_loop(0, num_iters, newton_step, x0)

In [21]:

# Root-finding demo: f(x) = x**2 - y**2 has its positive root at x = y = 3,
# which Newton's method reaches starting from x0 = 1.
y = jnp.array(3.0)


def f(x: jnp.ndarray) -> jnp.ndarray:
    """Residual ``x**2 - y**2`` whose positive root is the global ``y``."""
    return x**2 - y**2


# NOTE: `newtons_method` treats `f` as a static jit argument, so `f` is traced
# once per function object and the global `y` it reads is baked in at that
# point. Rebinding `y` later will NOT change `newtons_method(f, ...)` results.
print(newtons_method(f, jnp.array(1.0)))

# Differentiate the computed root with respect to the target `c` through every
# Newton iteration. The parameter is named `c` (not `y`) so it does not shadow
# the global `y`; the root of x**2 - c**2 is c, so the derivative is ~1.
jax.grad(lambda c: newtons_method(lambda x: x**2 - c**2, 2.0))(y)


3.0


Array(1., dtype=float32, weak_type=True)

In [22]:
import jax.random as jrandom
# Fixed seed so the sampled vector (and everything derived from it) is reproducible.
key = jrandom.PRNGKey(seed=0)
# Five independent standard-normal samples used as the test input below.
x = jrandom.normal(key, shape=(5,))

In [23]:
# Display the five standard-normal samples drawn from PRNGKey(0).
x

Array([ 0.18784384, -1.2833426 , -0.2710917 ,  1.2490594 ,  0.24447003],      dtype=float32)

In [34]:
def function(v):
    """Dot product of the float32 mask ``(v > 0)`` with ``v``.

    For a finite 1-D ``v`` this is the sum of its strictly positive entries.
    Under ``jax.grad`` the comparison contributes no derivative, so the gradient
    with respect to ``v`` is the 0/1 mask itself.
    """
    positive_mask = (v > 0).astype(jnp.float32)
    return jnp.dot(positive_mask, v)

In [35]:
# Gradient of `function` at the sampled x. The (x > 0) mask has no derivative,
# so the result is that mask as floats: 1 where x > 0, 0 elsewhere.
grad_function = jax.grad(function)
grad_function(x)

Array([1., 0., 0., 1., 1.], dtype=float32)