In [2]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
from functools import partial

In [37]:
@partial(jax.jit, static_argnums=(0,))
def NewtonStep(f, x0, delta):
    """Perform one damped Newton update ``x0 - delta * J(x0)^{-1} f(x0)``.

    Args:
        f: Residual function whose root is sought. It is a static (hashable)
            argument of the jitted function, so every distinct ``f`` triggers
            its own trace/compilation.
        x0: Current iterate. Expected to be a 1-D array so that the Jacobian
            of ``f`` at ``x0`` is a square 2-D matrix.
        delta: Damping factor applied to the Newton direction
            (``delta = 1`` is the classical, undamped step).

    Returns:
        The next iterate, with the same shape as ``x0``.
    """
    residual = f(x0)
    # No inner jax.jit is needed: the whole function is already traced by the
    # outer jit decorator.
    jacobian_at_x0 = jax.jacobian(f)(x0)
    # Solve J @ direction = f(x0) instead of forming J^{-1} explicitly; a
    # linear solve is cheaper and numerically more accurate than inv + dot.
    direction = jnp.linalg.solve(jacobian_at_x0, residual)
    return x0 - delta * direction

def NewtonMethod(f, x0, delta=0.1, max_iters=100, precision=1e-4):
    """Find a root of ``f`` by repeated damped Newton steps.

    Iterates ``NewtonStep`` from ``x0`` until the L1 norm of the update
    (``sum(|x_new - x_old|)``) falls below ``precision`` or ``max_iters``
    steps have been taken.

    Args:
        f: Residual function; passed to ``NewtonStep`` as a static jit
            argument, so it must be hashable.
        x0: Initial guess.
        delta: Damping factor forwarded to every ``NewtonStep`` call.
        max_iters: Maximum number of Newton steps.
        precision: Threshold on the L1 norm of a single update.

    Returns:
        The last iterate computed. No convergence flag is returned: when the
        iteration budget is exhausted the result may still be far from a
        root. If ``max_iters <= 0`` the initial guess is returned unchanged.
    """
    # Seed the result with the initial guess so that an empty loop
    # (max_iters <= 0) returns x0 instead of raising UnboundLocalError.
    x = x0
    for _ in range(max_iters):
        x = NewtonStep(f, x0, delta)
        # Stopping test is on the step size, not on the residual f(x).
        if jnp.sum(jnp.abs(x - x0)) < precision:
            return x
        x0 = x
    return x

In [38]:
def f(x):
    """Example residual: ``x`` scaled by its own 2-norm (``ord=2``).

    The only root is the zero vector.
    """
    scale = jnp.linalg.norm(x, ord=2)
    return x * scale
# Initial guess: a length-5 vector of ones.
x0 = jnp.ones((5,))

In [39]:
# Run the damped Newton iteration on f from x0 with the default settings
# (delta=0.1, max_iters=100, precision=1e-4).
# NOTE: for f(x) = x * ||x|| the Newton direction J^{-1} f(x) equals x / 2, so
# each step scales x by (1 - delta / 2) = 0.95. The value shown below,
# ~0.00592, is 0.95**100: the loop used all 100 iterations without meeting
# the precision test, so this is NOT a converged root (the true root is 0).
NewtonMethod(f, x0)

Array([0.00592053, 0.00592053, 0.00592053, 0.00592053, 0.00592053],      dtype=float32)

In [4]:
# Coefficient shared by both equations of the planar system below.
mu = 0.5


@jax.jit
def system(t, y, args):
    """Right-hand side of a two-variable autonomous ODE in ``(v, w)``.

    Computes
        dv/dt = mu*v + w - mu*v**2
        dw/dt = v + mu*w + 2*v**2

    Args:
        t: Unused by the body (the system has no explicit time dependence).
        y: Array whose last axis holds ``(v, w)``; leading axes are treated
            as batch dimensions.
        args: Unused by the body. The ``(t, y, args)`` signature presumably
            matches an external ODE-solver convention -- confirm with caller.

    Returns:
        Array with the same shape as ``y`` holding ``(dv/dt, dw/dt)`` on the
        last axis.

    Note:
        ``mu`` is a module-level global read while the function is traced by
        ``jax.jit``; rebinding ``mu`` after the first call does not affect an
        already-compiled trace.
    """
    v = y[..., 0]
    w = y[..., 1]
    v_squared = v**2
    dv_dt = mu * v + w - mu * v_squared
    dw_dt = v + mu * w + 2 * v_squared
    return jnp.stack((dv_dt, dw_dt), axis=-1)