Fixed a corner-case bug in the gradient of lax.pow when the exponent is zero (#12041).
Breaking changes
jax.checkpoint, also known as jax.remat, no longer supports the concrete option, following the previous version's deprecation; see JEP 11830.
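Code that used the concrete option typically did so to branch on a plain Python value inside the checkpointed function. The sketch below shows one possible migration. It assumes the branch depends only on a hashable, non-array argument, which is then marked static with jax.checkpoint's static_argnums parameter. The function f and its use_sin flag are hypothetical. JEP 11830 remains the authoritative migration guide.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical example: use_sin is a Python bool, so marking it static lets
# ordinary Python control flow depend on it without the removed `concrete` option.
@partial(jax.checkpoint, static_argnums=(1,))
def f(x, use_sin):
    if use_sin:
        return jnp.sin(x)
    return jnp.cos(x)

# Differentiate with respect to x only; use_sin is passed through unchanged.
d_sin = jax.grad(f)(1.0, True)    # cos(1.0)
d_cos = jax.grad(f)(1.0, False)   # -sin(1.0)
```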
Changes
Added jax.pure_callback that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with jax.jit or jax.pmap).
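A minimal sketch of jax.pure_callback inside a jitted function follows. The caller must declare the shape and dtype of the callback's result up front, here with jax.ShapeDtypeStruct. The host function host_sin is hypothetical and stands in for any pure NumPy code. As far as we know, pure_callback is not differentiable unless you supply a custom derivative rule yourself.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Runs on the host as plain NumPy; keep the declared output dtype.
    return np.sin(x).astype(x.dtype)

@jax.jit
def f(x):
    # Describe the result so JAX can trace through the opaque callback.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)

out = f(jnp.arange(3.0))
```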
Deprecations:
The deprecated DeviceArray.tile() method has been removed. Use jax.numpy.tile (#11944).
DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead.
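Both replacements are plain function calls rather than array methods. Here is a short sketch; the array x is only an illustration.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([1, 2, 3])

# Instead of the removed x.tile(...), call jax.numpy.tile as a function.
tiled = jnp.tile(x, 2)       # [1, 2, 3, 1, 2, 3]
grid = jnp.tile(x, (2, 1))   # two stacked copies, shape (2, 3)

# Instead of the deprecated x.to_py(), convert to a host NumPy array.
host = np.asarray(x)
```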