In [6]:
def fn(x, xs):
    """Toy scalar function of two indexable inputs.

    Returns ``2*x[0] + x[1]*x[2] + 3*xs[0] + xs[1]*xs[2]``. Only the first
    three entries of each argument are read.
    """
    linear_x = x[0] * 2
    cross_x = x[1] * x[2]
    linear_xs = xs[0] * 3
    cross_xs = xs[1] * xs[2]
    # Summed left to right, in the same order as the four terms above.
    return linear_x + cross_x + linear_xs + cross_xs

In [7]:
import jax
import jax.numpy as jnp

A proof of concept showing that JAX can differentiate a function with respect to an array argument, for example $(x, y, z)$.

In [14]:
# argnums=0 differentiates with respect to the first argument (x) only;
# the second array (xs) is treated as a constant.
jax.grad(fn, argnums=0)(
    jnp.array([1., 2., 3.]),
    jnp.array([1., 2., 3.])
)

Array([2., 3., 2.], dtype=float64)

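Now let's pretend we have a function that takes the values of the various degrees of freedom (DOFs) as a vector `q_vec`, together with a parameter `t`. Note that this redefines `fn` from above.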

In [26]:
def fn(q_vec, t):
    """Return the dot product of ``q_vec`` with itself, minus ``t``."""
    self_dot = jnp.dot(q_vec, q_vec)
    return self_dot - t

In [27]:
# Gradient of fn with respect to q_vec (argument 0); t is not differentiated.
derivatives = jax.grad(fn, argnums=0)

Compute the residual for the conservative system (the `residue` function below); this corresponds to Equations 12(a) and 12(b) in the paper.

In [43]:
def residue(q_vec, t):
    """Build the residual vector from ``derivatives(q_vec, t)``.

    The interior gradient entries (everything except the first and the last)
    are kept unchanged, and one extra entry, the first gradient component
    minus the last, is appended at the end.
    """
    grad = derivatives(q_vec, t)
    first_minus_last = grad[0] - grad[-1]
    return jnp.append(grad[1:-1], first_minus_last)

In [44]:
# Sanity check: evaluate the gradient at a sample point and display it.
dfdx = derivatives(jnp.array([1.0, 2.0, 3.0, 3.0, 4.0, 5.0]), 1.0)
dfdx

Array([ 2.,  4.,  6.,  6.,  8., 10.], dtype=float64)

In [45]:
# Interior entries of the gradient: drop the first and the last component.
dfdx[1:-1]

Array([4., 6., 6., 8.], dtype=float64)

In [46]:
# Check that jnp.append adds a single scalar to the end of the 1-D slice.
jnp.append(dfdx[1:-1], 100.0)

Array([  4.,   6.,   6.,   8., 100.], dtype=float64)

In [51]:
# Evaluate the residual at the starting point used for the solver below.
residue(jnp.array([4.0, 1.0, 3.0, 3.0, 4.0, 5.0]), 1.0)

Array([ 0.,  6.,  6.,  8., -2.], dtype=float64)

In [52]:
import jaxopt

In [57]:
# Extra argument forwarded by the solver to the residual function.
t0 = 1.0

# Gauss-Newton on residue(q, t0), starting from q_init.
solver = jaxopt.GaussNewton(residual_fun=residue, verbose=True)
q_init = jnp.array([4.0, 1.0, 3.0, 3.0, 4.0, 5.0])
opt_res = solver.run(q_init, t0)

print(opt_res.params)               # parameters returned by the solver
print(fn(opt_res.params, t0))       # fn evaluated at those parameters
print(residue(opt_res.params, t0))  # residual at those parameters

Solver: GaussNewton, Error: 5.958187643906492
Solver: GaussNewton, Error: 6.280369834735101e-16
[4.5 0.  0.  0.  0.  4.5]
39.5
[0. 0. 0. 0. 0.]
