Studies on autodiff with JAX #440
I'm next looking into reimplementing ex10.py using JAX.

I got up to the point where I'm successfully recreating the RHS of ex10.py using JAX. Now I'm trying to figure out how to get the Hessian.
I think this is more of an issue of "how to use JAX". I was trying to run:

```python
import numpy as np
from skfem import *
from jax import jit, grad, jacrev, jacfwd
from jax.numpy import vectorize
import jax.numpy as jnp

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

vectorize(jacfwd(jacrev(F, (0, 1)), (0, 1)))(0.1, 0.1)
```
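One way to get the Hessian question answered is `jax.hessian` with `argnums`; this is a sketch of that route, not necessarily the approach settled on later in the thread:

```python
# A sketch: extract the full 2x2 Hessian of F with respect to both scalar
# arguments using jax.hessian with argnums=(0, 1).
import jax.numpy as jnp
from jax import hessian

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

# With argnums=(0, 1) the result is a nested 2-tuple of 2-tuples,
# where H[i][j] holds d^2 F / (d arg_i d arg_j).
H = hessian(F, argnums=(0, 1))(0.1, 0.1)
H = jnp.array([[H[i][j] for j in range(2)] for i in range(2)])
print(H.shape)  # (2, 2)
```

The Hessian of this F is symmetric, so `H[0, 1]` and `H[1, 0]` agree, which is a quick sanity check on the nesting order.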
Now I was able to implement ex10.py using JAX:

```python
# Minimal surface problem
import numpy as np
from skfem import *
from jax import jit, grad
from jax.numpy import vectorize
import jax.numpy as jnp

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

def jac_eval(du0, du1):
    out = np.zeros((2,) + du0.shape)
    for i in range(2):
        out[i] = vectorize(grad(F, i))(du0, du1)
    return out

def hess_eval(du0, du1):
    out = np.zeros((2, 2) + du0.shape)
    for i in range(2):
        for j in range(2):
            out[i, j] = vectorize(grad(grad(F, i), j))(du0, du1)
    return out

m = MeshTri()
m.refine(5)
basis = InteriorBasis(m, ElementTriP1())

@LinearForm
def linf_rhs(v, w):
    from skfem.helpers import dot, grad
    Jw = np.array(jac_eval(*w['prev'].grad))
    return -dot(Jw, grad(v))

@BilinearForm
def bilinf_hess(u, v, w):
    from skfem.helpers import ddot, grad, prod
    Hw = np.array(hess_eval(*w['prev'].grad))
    return ddot(Hw, prod(grad(u), grad(v)))

x = np.zeros(basis.N)
I = m.interior_nodes()
D = m.boundary_nodes()
x[D] = np.sin(np.pi * m.p[0, D])

for itr in range(100):
    prev = basis.interpolate(x)
    K = asm(bilinf_hess, basis, prev=prev)
    f = asm(linf_rhs, basis, prev=prev)
    x_prev = x.copy()
    x += .7 * solve(*condense(K, f, I=I))
    if np.linalg.norm(x - x_prev) < 1e-7:
        break
    print(np.linalg.norm(x - x_prev))

from skfem.visuals.matplotlib import plot3, show
plot3(m, x)
show()
```

There is some overhead from the autodiff, which I think can be reduced by learning JAX better, but it's not huge. The result is correct.
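On the overhead: one likely win is to `jit`-compile the vectorized derivative evaluations once, outside the Newton loop, so repeated assembly calls reuse the compiled kernels instead of retracing. A minimal sketch, where the array shape is made up to stand in for the quadrature-point arrays:

```python
# Sketch: jit the vectorized derivatives once, then reuse them.
# The (3, 4) shape is an illustrative stand-in for (nelems, nqp).
import numpy as np
import jax.numpy as jnp
from jax import jit, grad
from jax.numpy import vectorize

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

# Build and compile the two partial-derivative kernels up front.
jac_fns = [jit(vectorize(grad(F, i))) for i in range(2)]

du = np.full((3, 4), 0.1)  # mock quadrature-point gradient values
Jw = np.stack([f(du, du) for f in jac_fns])
print(Jw.shape)  # (2, 3, 4)
```

Subsequent calls with arrays of the same shape and dtype hit the compiled cache, so the per-iteration Python overhead drops to roughly the dispatch cost.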
Hi Tom,
I came back to this to experiment a bit with AD (and its potential in solving #439). I think the following produces a correct answer:

```python
import numpy as np
from skfem import *
from jax import jit, jacfwd, hessian
import jax.numpy as jnp

@jit
def F(du):
    return jnp.sqrt(1 + du[0] ** 2 + du[1] ** 2)

@jit
def jac_eval(dw):
    Jw = jnp.sum(jit(jacfwd(F))(dw).sum(axis=0), axis=0)
    return Jw

@jit
def hess_eval(dw):
    Hw = jit(hessian(F))(dw)
    Hw = jnp.sum(Hw.sum(axis=0), axis=0)
    Hw = jnp.sum(Hw.sum(axis=2), axis=1)
    return Hw

def hess_anal(du):
    # analytical Hessian, kept for verifying the autodiff result
    normu = np.sqrt(1. + du[0] ** 2 + du[1] ** 2)
    try:
        eyeu = np.zeros((2, 2) + du.shape[1:])
        eyeu[0, 0] = 1.
        eyeu[1, 1] = 1.
    except Exception:
        eyeu = np.eye(2)
    Hw = normu ** 2. * eyeu - np.einsum('i...,j...->ij...', du, du)
    Hw /= normu ** 3.
    return Hw

m = MeshTri()
m.refine(3)
basis = InteriorBasis(m, ElementTriP1())

@LinearForm
def linf_rhs(v, w):
    from skfem.helpers import dot, grad
    Jw = np.array(jac_eval(w['prev'].grad))
    return -dot(Jw, grad(v))

@BilinearForm
def bilinf_hess(u, v, w):
    from skfem.helpers import ddot, grad, prod
    Hw = np.array(hess_eval(w['prev'].grad))
    return ddot(Hw, prod(grad(u), grad(v)))

x = np.zeros(basis.N)
I = m.interior_nodes()
D = m.boundary_nodes()
x[D] = np.sin(np.pi * m.p[0, D])

for itr in range(100):
    prev = basis.interpolate(x)
    K = asm(bilinf_hess, basis, prev=prev)
    f = asm(linf_rhs, basis, prev=prev)
    x_prev = x.copy()
    x += solve(*condense(K, f, I=I))
    if np.linalg.norm(x - x_prev) < 1.e-7:
        break
    print(np.linalg.norm(x - x_prev))

from skfem.visuals.matplotlib import plot3, show
plot3(m, x)
show()
```

I am still not sure about the output shapes here. The above is terribly slow when translated to the hyperelasticity example, so it's probably better to shelve that for a later time.
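Regarding the uncertainty about the outputs: for an array-valued argument, `jacfwd` and `hessian` return the full Jacobian/Hessian, with the output axes first, followed by one copy of the input axes per differentiation. A small probe, where the `(2, 3, 4)` shape is an assumption standing in for `(2, nelems, nqp)`:

```python
# Probe the shapes that jacfwd and hessian return for an array input.
# The (2, 3, 4) shape is an illustrative stand-in for (2, nelems, nqp).
import jax.numpy as jnp
from jax import jacfwd, hessian

def F(du):
    return jnp.sqrt(1 + du[0] ** 2 + du[1] ** 2)

du = jnp.full((2, 3, 4), 0.1)
J = jacfwd(F)(du)
H = hessian(F)(du)
print(J.shape)  # (3, 4, 2, 3, 4): output axes, then input axes
print(H.shape)  # (3, 4, 2, 3, 4, 2, 3, 4): output axes, then input axes twice
```

Since each output entry depends only on its own (element, point) pair, most of those Jacobian/Hessian entries are zero; that is presumably why the axis sums in `jac_eval`/`hess_eval` recover the pointwise derivatives, and also why this route does far more work than the `vectorize(grad(...))` approach above.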
I'll collect here some tests I'm currently doing using JAX, inspired by #439 .
First, I tried to calculate the plain gradient of a usual linear problem:
This gives
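For reference, a hedged sketch of what such a plain-gradient computation might look like; the quadratic (Dirichlet-type) energy density here is my assumption, not necessarily the exact form used in the test:

```python
# Gradient of a quadratic (Poisson-type) energy density with JAX.
# The choice of energy is an assumption for illustration only.
import jax.numpy as jnp
from jax import grad

def energy(du0, du1):
    return 0.5 * (du0 ** 2 + du1 ** 2)

# For a quadratic energy the gradient is linear in the arguments.
g0 = grad(energy, 0)(0.3, 0.4)
g1 = grad(energy, 1)(0.3, 0.4)
print(g0, g1)  # approximately 0.3 and 0.4
```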