A small package that makes it easy to use ceviche with a JAX optimizer.
This package is not yet published. Once it is, install it with:

```sh
pip install javiche
```

or

```sh
conda install javiche
```
Import the decorator:

```python
from javiche import jaxit
```
Decorate your function (it will be differentiated using ceviche's `jacobian`, i.e. via HIPS autograd):

```python
@jaxit()
def square(A):
    """squares number/array"""
    return A**2
```
Now you can use JAX as usual:

```python
import jax

grad_fn = jax.grad(square)
grad_fn(2.0)
```

```
Array(4., dtype=float32, weak_type=True)
```
In this toy example, this was already possible without the `jaxit()` decorator. However, `jaxit()`-decorated functions can contain autograd operators (but no JAX operators). Without the decorator, `jax.grad` fails on such a function:
```python
import autograd.numpy as npa

def sin(A):
    """computes sin of number/array using autograd's numpy"""
    return npa.sin(A)

grad_sin = jax.grad(sin)
try:
    print(grad_sin(0.0))
except Exception as e:
    print(e)
```

```
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(0.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
  primal = 0.0
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[], weak_type=True), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```
With `@jaxit()`, the same kind of function can be differentiated:

```python
@jaxit()
def cos(A):
    """computes cos of number/array using autograd's numpy"""
    return npa.cos(A)

grad_cos = jax.grad(cos)
try:
    print(grad_cos(0.0))
except Exception as e:
    print(e)
```

```
-0.0
```
This library is intended for use with ceviche while running a JAX optimization stack, as demonstrated in the inverse design example.