Euclid-equivariant operations and harmonic polynomials for JAX.
This library is intended as a faster and full-featured substitute for the e3nn and e3x Euclidean equivariance backends, replacing slow components in Machine Learned Interatomic Potentials (MLIPs) with carefully optimized and open-source CUDA kernels.
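For readers new to the terminology, the pure-JAX sketch below shows what "equivariant" and "harmonic polynomial" mean in the simplest nontrivial case. It does not use e3j's API, and the helper names are illustrative only. The degree-2 harmonic polynomials x_i x_j - |x|^2 δ_ij / 3 form a traceless symmetric matrix T(x). Rotating the input by R gives T(Rx) = R T(x) Rᵀ, and every entry has zero Laplacian.

```python
import jax
import jax.numpy as jnp


def degree2_harmonic(x):
    """Degree-2 harmonic polynomials of x in R^3, as a traceless symmetric matrix.

    Entry (i, j) is x_i * x_j - |x|^2 * delta_ij / 3.
    """
    return jnp.outer(x, x) - jnp.dot(x, x) / 3.0 * jnp.eye(3)


def random_rotation(key):
    """Sample a proper rotation matrix (det = +1) via a QR decomposition."""
    q, r = jnp.linalg.qr(jax.random.normal(key, (3, 3)))
    q = q * jnp.sign(jnp.diag(r))           # fix the column signs of Q
    return q * jnp.sign(jnp.linalg.det(q))  # negate (3x3) if det = -1


def laplacian(f, x):
    """Laplacian of every output entry of f at x (trace of the Hessian)."""
    hess = jax.hessian(f)(x)  # output axes first, then two input axes
    return jnp.trace(hess, axis1=hess.ndim - 2, axis2=hess.ndim - 1)


key_x, key_r = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (3,))
R = random_rotation(key_r)

# Equivariance: rotating the input rotates the output as R @ T @ R.T
lhs = degree2_harmonic(R @ x)
rhs = R @ degree2_harmonic(x) @ R.T
print(bool(jnp.allclose(lhs, rhs, atol=1e-4)))                               # True
# Harmonicity: each polynomial entry has a vanishing Laplacian
print(bool(jnp.allclose(laplacian(degree2_harmonic, x), 0.0, atol=1e-4)))  # True
```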
As of version 0.2.0, our MLIP library mlip uses e3j as its equivariance backend.
Note:
e3j is currently in pre-release, with version 0.1.0 planned for early June 2026. Additional CUDA kernels and dedicated Pallas kernels for TPU will be rolled out progressively.
The e3j package consists of a thin JAX-based Python API, which runs on CPU, GPU and TPU and currently supports Python 3.11 through 3.14 (inclusive).
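Because the Python API is plain JAX, a quick way to see which backend it will run on, and whether the interpreter falls within the supported range, is a check like the one below. This is a minimal sketch; the helper names are hypothetical and not part of e3j.

```python
import sys

import jax


def python_supported(version=None):
    """Check a (major, minor) version against the supported range 3.11 to 3.14."""
    major_minor = tuple((version or sys.version_info)[:2])
    return (3, 11) <= major_minor <= (3, 14)


def describe_backend():
    """Report the JAX version, default backend and device platforms in use."""
    return {
        "jax": jax.__version__,
        "backend": jax.default_backend(),  # "cpu", "gpu" or "tpu"
        "platforms": sorted({d.platform for d in jax.devices()}),
    }


print(python_supported((3, 12)))  # True
print(describe_backend())
```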
To run efficiently on GPU, pull in our CUDA binaries via the "e3j[ops]" extra, for example in a requirements file:
```
# requirements.txt
jax[cuda13_local] ~= 0.8.0
e3j[ops] == 0.1.0b0
```

See the JAX installation instructions for more information.
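The `~=` pin on jax is a PEP 440 compatible-release specifier: `~= 0.8.0` accepts any 0.8.x release at or above 0.8.0 but rejects 0.9.0. The stdlib-only sketch below mimics that rule for plain X.Y.Z versions. It is an illustration, not a replacement for pip's resolver or the `packaging` library, and the helper names are hypothetical.

```python
from importlib import metadata


def release_tuple(version):
    """Numeric release segment padded to 3 parts, e.g. "0.8" -> (0, 8, 0).

    Simplification: only plain X.Y.Z versions are handled; suffixes such as
    "b0" or ".dev1" would need the `packaging` library.
    """
    parts = [int(p) for p in version.split(".")[:3]]
    return tuple(parts + [0] * (3 - len(parts)))


def satisfies_compatible_release(installed, pinned):
    """Mimic `~= pinned` for X.Y.Z pins: >= pinned and same X.Y prefix."""
    inst, pin = release_tuple(installed), release_tuple(pinned)
    return inst >= pin and inst[:2] == pin[:2]


def installed_version(dist):
    """Installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


print(satisfies_compatible_release("0.8.2", "0.8.0"))  # True
print(satisfies_compatible_release("0.9.0", "0.8.0"))  # False
print(installed_version("jax"))
```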
Our dependencies are managed with uv. After cloning the repository, you can build from source by running one of the following:
```
# Existing CUDA 13 install, with the `e3j_ops` kernels:
uv sync --group cuda13_local --extra ops
# Install CUDA 13 via pip (add `--group exp` for the benchmark dependencies):
uv sync --group cuda13 --extra ops
```

The Python build relies internally on CMake, scikit-build and pybind11. The Makefile also provides alternative recipes for building the kernels, the C++ tests and the Python bindings.
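The choice between the two commands can be sketched as a small hypothetical helper that assembles the `uv sync` arguments. It assumes that an `nvcc` on PATH indicates an existing CUDA toolkit (a rough heuristic that does not check the CUDA version) and that the `exp` group can be combined with either CUDA group; neither assumption is confirmed by the project.

```python
import shlex
import shutil


def uv_sync_command(local_cuda=None, benchmarks=False):
    """Build the `uv sync` argv for a source build (hypothetical helper).

    local_cuda=None guesses from whether `nvcc` is on PATH, which is only a
    rough proxy for "an existing CUDA 13 install".
    """
    if local_cuda is None:
        local_cuda = shutil.which("nvcc") is not None
    cmd = ["uv", "sync", "--group", "cuda13_local" if local_cuda else "cuda13"]
    if benchmarks:
        cmd += ["--group", "exp"]  # assumed: benchmark dependencies
    cmd += ["--extra", "ops"]      # the extra that provides the CUDA kernels
    return cmd


print(shlex.join(uv_sync_command(local_cuda=True)))
# uv sync --group cuda13_local --extra ops
```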
The e3j_ops Python package contains only our CUDA binaries and the bindings to their associated XLA handlers; it is not meant to be used standalone.
The JAX primitives wrapping these custom XLA handlers are defined in the e3j.ops subpackage of e3j, and are available provided the e3j_ops binaries can be found in the environment.
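To find out whether the CUDA kernels can be used in the current environment, one can check whether the e3j_ops module can be located and whether JAX sees a GPU. This minimal sketch assumes the import name is `e3j_ops`, matching the package name above. It does not verify that e3j.ops actually registered its primitives, and the helper names are hypothetical.

```python
import importlib.util

import jax


def e3j_ops_installed():
    """True if a top-level `e3j_ops` module can be found (without importing it)."""
    return importlib.util.find_spec("e3j_ops") is not None


def cuda_kernels_usable():
    """Rough check: e3j_ops binaries present and JAX sees at least one GPU."""
    has_gpu = any(d.platform == "gpu" for d in jax.devices())
    return e3j_ops_installed() and has_gpu


print(e3j_ops_installed(), cuda_kernels_usable())
```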