Studies on autodiff with JAX #440

Closed

kinnala opened this issue Jul 21, 2020 · 6 comments

kinnala (Owner) commented Jul 21, 2020

I'll collect here some tests I'm currently doing with JAX, inspired by #439.

First, I tried recovering the standard linear (Laplace) problem by differentiating its energy density:

from skfem import *
from jax import jit, grad
from jax.numpy import vectorize


def energy(du0, du1):
    # energy density of the Laplace problem: .5 * |grad(u)|^2
    return .5 * (du0 ** 2 + du1 ** 2)

@jit
def jacf(du0, du1):
    # pointwise gradient of the energy density w.r.t. (du0, du1)
    return vectorize(grad(energy, (0, 1)))(du0, du1)

m = MeshTri()
basis = InteriorBasis(m, ElementTriP1())

# bilinear form built from the autodiff'd energy density
@BilinearForm
def bilinf1(u, v, w):
    from skfem.helpers import dot, grad
    Ju = jacf(*u.grad)
    return dot(Ju, grad(v))

# the standard Laplace bilinear form, for comparison
@BilinearForm
def bilinf2(u, v, w):
    from skfem.helpers import dot, grad
    return dot(grad(u), grad(v))

A = bilinf1.assemble(basis)
B = bilinf2.assemble(basis)

This gives the expected zero difference, since the gradient of the quadratic energy density with respect to (du0, du1) is just (du0, du1), i.e. exactly the integrand of the standard Laplace form:

In [1]: (A - B).todense()                                                       
Out[1]: 
matrix([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
kinnala closed this as completed Jul 21, 2020

kinnala (Owner) commented Jul 21, 2020

I'm next looking into reimplementing ex10.py using JAX.

@kinnala
Copy link
Owner Author

kinnala commented Jul 21, 2020

Got up to the point where I'm successfully recreating the RHS of ex10.py using JAX. Now trying to figure out how to get the Hessian.
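For reference, JAX also provides hessian (jacfwd composed with jacrev); for a scalar function of two scalar arguments it returns the full 2x2 second-derivative structure as nested tuples. A minimal sketch at a single point, using the minimal surface integrand from ex10.py as the example:

from jax import hessian
import jax.numpy as jnp

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

# nested tuples: ((d2F/ddu0ddu0, d2F/ddu0ddu1), (d2F/ddu1ddu0, d2F/ddu1ddu1))
print(hessian(F, argnums=(0, 1))(0.1, 0.1))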

kinnala (Owner) commented Jul 21, 2020

I think this is more of a "how to use JAX" question. Trying to use vectorize on the Hessian results in

ValueError: output shape (2,) does not match core dimensions () on vectorized function with excluded=frozenset() and signature=None

I was trying to run

from jax import jacrev, jacfwd
from jax.numpy import vectorize
import jax.numpy as jnp

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

vectorize(jacfwd(jacrev(F, (0, 1)), (0, 1)))(0.1, 0.1)
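One workaround that avoids the shape mismatch: vectorize with no signature assumes a scalar-valued function, so take one second derivative at a time, since each grad(grad(F, i), j) maps scalars to a scalar. A sketch (this is essentially the component-wise approach used in the next comment):

from jax import grad
from jax.numpy import vectorize
import jax.numpy as jnp

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

# scalar -> scalar, so vectorize needs no signature
d2F_01 = vectorize(grad(grad(F, 0), 1))
print(d2F_01(0.1, 0.1))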

kinnala (Owner) commented Jul 21, 2020

Now I was able to implement ex10.py using JAX:

# Minimal surface problem (ex10.py) via JAX autodiff

import numpy as np
from skfem import *
from jax import grad
from jax.numpy import vectorize
import jax.numpy as jnp


def F(du0, du1):
    # minimal surface energy density: sqrt(1 + |grad(u)|^2)
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

def jac_eval(du0, du1):
    # pointwise gradient of F, shape (2,) + du0.shape
    out = np.zeros((2,) + du0.shape)
    for i in range(2):
        out[i] = vectorize(grad(F, i))(du0, du1)
    return out

def hess_eval(du0, du1):
    # pointwise Hessian of F, shape (2, 2) + du0.shape
    out = np.zeros((2, 2) + du0.shape)
    for i in range(2):
        for j in range(2):
            out[i, j] = vectorize(grad(grad(F, i), j))(du0, du1)
    return out

m = MeshTri()
m.refine(5)
basis = InteriorBasis(m, ElementTriP1())

@LinearForm
def linf_rhs(v, w):
    from skfem.helpers import dot, grad
    Jw = np.array(jac_eval(*w['prev'].grad))
    return -dot(Jw, grad(v))

@BilinearForm
def bilinf_hess(u, v, w):
    from skfem.helpers import ddot, grad, prod
    Hw = np.array(hess_eval(*w['prev'].grad))
    return ddot(Hw, prod(grad(u), grad(v)))

x = np.zeros(basis.N)
I = m.interior_nodes()
D = m.boundary_nodes()
x[D] = np.sin(np.pi * m.p[0, D])

# damped Newton iteration
for itr in range(100):
    prev = basis.interpolate(x)
    K = asm(bilinf_hess, basis, prev=prev)
    f = asm(linf_rhs, basis, prev=prev)
    x_prev = x.copy()
    x += .7 * solve(*condense(K, f, I=I))
    if np.linalg.norm(x - x_prev) < 1e-7:
        break
    print(np.linalg.norm(x - x_prev))

from skfem.visuals.matplotlib import plot3, show
plot3(m, x)
show()

There is some overhead from autodiff, which I think can be reduced by learning JAX better, but it's not huge. The result is correct.
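One likely source of that overhead: jac_eval and hess_eval rebuild vectorize(grad(...)) on every call. A sketch (not benchmarked here) that jit-compiles each pointwise derivative once, outside the forms, so the Newton iteration reuses the compiled kernels:

import jax.numpy as jnp
from jax import jit, grad
from jax.numpy import vectorize

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

# compile once; jit caches the kernel per input shape, so repeated
# assembly calls with same-shaped arrays skip retracing
jac_fns = [jit(vectorize(grad(F, i))) for i in range(2)]
hess_fns = [[jit(vectorize(grad(grad(F, i), j))) for j in range(2)]
            for i in range(2)]

jac_eval and hess_eval would then index into these lists instead of reconstructing the vectorized gradients on each call.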

bhaveshshrimali (Contributor) commented
Hi Tom,
This looks great. Thanks a lot for providing this example.

gdmcbain mentioned this issue Jul 22, 2020
bhaveshshrimali (Contributor) commented
I came back to this to experiment a bit with AD (and its potential for solving #439). I think the following produces a correct answer for ex10.py (judging by the plot) using jacfwd/jacrev and hessian from JAX, although it is orders of magnitude slower than your MWE.

import numpy as np
from skfem import *
from jax import jit, jacfwd, hessian
import jax.numpy as jnp

@jit
def F(du):
    # minimal surface energy density; du has shape (2, nelems, nqp)
    return jnp.sqrt(1 + du[0] ** 2 + du[1] ** 2)

@jit
def jac_eval(dw):
    # jacfwd(F)(dw) has shape F(dw).shape + dw.shape; only the
    # same-point entries are nonzero, so summing over the output
    # axes collapses it to the pointwise gradient (2, nelems, nqp)
    Jw = jnp.sum(jacfwd(F)(dw).sum(axis=0), axis=0)
    return Jw

@jit
def hess_eval(dw):
    # same trick for the Hessian, collapsing to (2, 2, nelems, nqp)
    Hw = hessian(F)(dw)
    Hw = jnp.sum(Hw.sum(axis=0), axis=0)
    Hw = jnp.sum(Hw.sum(axis=2), axis=1)
    return Hw

def hess_anal(du):
    # analytical Hessian of F, kept for cross-checking (unused below)
    normu = np.sqrt(1. + du[0] ** 2 + du[1] ** 2)
    if du.ndim > 1:
        eyeu = np.zeros((2, 2) + du.shape[1:])
        eyeu[0, 0] = 1.
        eyeu[1, 1] = 1.
    else:
        eyeu = np.eye(2)
    Hw = normu ** 2. * eyeu - np.einsum('i...,j...->ij...', du, du)
    Hw /= normu ** 3.
    return Hw

m = MeshTri()
m.refine(3)
basis = InteriorBasis(m, ElementTriP1())

@LinearForm
def linf_rhs(v, w):
    from skfem.helpers import dot, grad
    Jw = np.array(jac_eval(w['prev'].grad))
    return -dot(Jw, grad(v))

@BilinearForm
def bilinf_hess(u, v, w):
    from skfem.helpers import ddot, grad, prod
    Hw = np.array(hess_eval(w['prev'].grad))
    return ddot(Hw, prod(grad(u), grad(v)))

x = np.zeros(basis.N)
I = m.interior_nodes()
D = m.boundary_nodes()
x[D] = np.sin(np.pi * m.p[0, D])

for itr in range(100):
    prev = basis.interpolate(x)
    K = asm(bilinf_hess, basis, prev=prev)
    f = asm(linf_rhs, basis, prev=prev)
    x_prev = x.copy()
    x += solve(*condense(K, f, I=I))
    if np.linalg.norm(x - x_prev) < 1.e-7:
        break
    print(np.linalg.norm(x - x_prev))

from skfem.visuals.matplotlib import plot3, show
plot3(m, x)
show()

I am still not sure about the output of jacrev, jacfwd, and hessian, but after experimenting a bit (it was a simple fluke, to be honest) I found that summing along the above axes did the trick (judging by the shape of the output). The documentation seems rather sparse on using these functions.
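For what it's worth, the shapes follow a simple rule: jacfwd and jacrev return an array of shape output_shape + input_shape, and hessian appends input_shape once more. Since F acts pointwise, the value at one quadrature point depends only on du at that same point, so all cross-point entries are zero and summing over the extra output axes collapses the array to the pointwise derivative; that is why the sums above give the right result. A small shape check, assuming 3 elements and 4 quadrature points:

import jax.numpy as jnp
from jax import jacfwd, hessian

def F(du):
    return jnp.sqrt(1 + du[0] ** 2 + du[1] ** 2)

du = jnp.ones((2, 3, 4))     # (component, element, quadrature point)
print(F(du).shape)           # (3, 4)
print(jacfwd(F)(du).shape)   # (3, 4, 2, 3, 4): output shape + input shape
print(hessian(F)(du).shape)  # (3, 4, 2, 3, 4, 2, 3, 4)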

The above is terribly slow when translated to the hyperelasticity example, so it's probably better to shelve that for a later time.
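The slowness is plausibly that same structure: hessian(F) over the full (2, nelems, nqp) array materializes all of those zero cross-point blocks. A sketch of an alternative hess_eval (hypothetical, not benchmarked here) that vmaps a single-point Hessian and never builds the zeros:

import jax.numpy as jnp
from jax import jit, vmap, hessian

def f(d):
    # energy density at a single quadrature point, d.shape == (2,)
    return jnp.sqrt(1 + d[0] ** 2 + d[1] ** 2)

@jit
def hess_eval_vmapped(dw):
    # dw.shape == (2, nelems, nqp)
    duf = dw.reshape(2, -1).T              # (npoints, 2)
    H = vmap(hessian(f))(duf)              # (npoints, 2, 2)
    return jnp.moveaxis(H, 0, -1).reshape((2, 2) + dw.shape[1:])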
