v0.3.8
New Features:
- `stax.Elementwise`: a layer for generic elementwise functions that requires the user to specify only a scalar-valued `nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)]`. The NTK computation (thanks to @SiuMath) and vectorization over the underlying `Kernel` happen automatically under the hood. If you can't derive the `nngp_fn` for your function, use `stax.ElementwiseNumerical`. See docs for more details.
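To illustrate the contract an `nngp_fn` has to satisfy, here is a small standalone NumPy sketch that uses `erf` as the example nonlinearity. The choice of `erf` and all helper names (`erf_nngp_fn`, `sample_pair`, `erf_grad`) are ours for illustration and are not part of neural-tangents. The sketch compares the closed-form `nngp_fn` with a Monte Carlo estimate of `E[fn(x_1) * fn(x_2)]` for a centered Gaussian pair. It also shows why no extra input is needed for the NTK: by Price's theorem, the derivative of `nngp_fn` with respect to `cov12` equals `E[fn'(x_1) * fn'(x_2)]`. This is one standard way to obtain that term from `nngp_fn` alone; it is not a claim about the library's internals.

```python
import math

import numpy as np


def erf_nngp_fn(cov12, var1, var2):
    """Closed-form E[erf(x1) * erf(x2)] for (x1, x2) ~ N(0, [[var1, cov12], [cov12, var2]])."""
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return 2 / np.pi * np.arcsin(2 * cov12 / np.sqrt(prod))


def sample_pair(cov12, var1, var2, n, seed=0):
    """Draw n samples of (x1, x2) from the centered bivariate Gaussian above."""
    rng = np.random.default_rng(seed)
    cov = np.array([[var1, cov12], [cov12, var2]])
    xs = rng.multivariate_normal(np.zeros(2), cov, size=n)
    return xs[:, 0], xs[:, 1]


# Elementwise erf on arrays, built from the standard library's scalar math.erf.
erf = np.vectorize(math.erf)


def erf_grad(x):
    # d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2)
    return 2 / np.sqrt(np.pi) * np.exp(-x ** 2)


cov12, var1, var2 = 0.3, 0.8, 0.5
x1, x2 = sample_pair(cov12, var1, var2, n=200_000)

# NNGP term: the closed form should agree with a Monte Carlo estimate of E[fn(x1) fn(x2)].
nngp_closed = erf_nngp_fn(cov12, var1, var2)
nngp_mc = np.mean(erf(x1) * erf(x2))

# NTK ingredient: d nngp_fn / d cov12 (central finite difference here)
# should agree with a Monte Carlo estimate of E[fn'(x1) fn'(x2)] (Price's theorem).
eps = 1e-5
dnngp_fd = (erf_nngp_fn(cov12 + eps, var1, var2)
            - erf_nngp_fn(cov12 - eps, var1, var2)) / (2 * eps)
dnngp_mc = np.mean(erf_grad(x1) * erf_grad(x2))

print(f"NNGP   closed form {nngp_closed:.4f}   Monte Carlo {nngp_mc:.4f}")
print(f"dNNGP  finite diff {dnngp_fd:.4f}   Monte Carlo {dnngp_mc:.4f}")
```

With neural-tangents installed, a function shaped like `erf_nngp_fn` is what you would pass as the `nngp_fn` argument of `stax.Elementwise`. Check the docs for the exact signature of the version you use.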
Bugfixes:
- Compatibility with JAX 0.2.21.
Full Changelog: v0.3.7...v0.3.8