v0.3.0
0.3.0 (2025-03-05)
The main changes are:
- [JAX] New JIT Uniform 1d kernel with JAX bindings
  - Computes any polynomial based on 1d uniform STPs
  - Supports arbitrary derivatives
  - Provides optional fused scatter/gather for the inputs and outputs
  - 🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉
- [Torch] Adds torch.compile support
- [Torch] Beta, limited Torch bindings to the new JIT Uniform 1d kernel
  - Enable the new kernel by setting the environment variable `CUEQUIVARIANCE_OPS_USE_JIT=1`
- [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d
  - This is a temporary API that will change: `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed`
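
The `CUEQUIVARIANCE_OPS_USE_JIT` flag mentioned above can also be set from Python instead of the shell. The sketch below assumes the flag is read when the Torch bindings are imported or first used, so it sets the variable before the import. That timing is an assumption and is not stated in these notes.

```python
import os

# Opt in to the beta JIT Uniform 1d kernel (flag name taken from the notes above).
# Assumption: the flag is read at import time or later, so set it before importing.
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"

# import cuequivariance_torch as cuet  # import only after the flag is set
```

Setting the variable in the launching shell (`CUEQUIVARIANCE_OPS_USE_JIT=1 python train.py`, where `train.py` is a placeholder script name) avoids any dependence on import order.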
Breaking Changes
- [Torch/JAX] Removed `cue.TensorProductExecution` and added `cue.Operation`, which is more lightweight and better aligned with the backend.
- [JAX] In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed to `math_dtype` and `output_dtype` respectively. This change adds consistency with the rest of the library.
- [JAX] In `cuex.equivariant_tensor_product`, the arguments `algorithm`, `precision`, `use_custom_primitive` and `use_custom_kernels` have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument `impl: str` has been added instead to select the implementation.
- [JAX] Removed `cuex.symmetric_tensor_product`. The `cuex.tensor_product` function now handles any non-homogeneous polynomial.
- [JAX] The batching support (`jax.vmap`) of `cuex.equivariant_tensor_product` is now limited to specific use cases.
- [JAX] The interface of `cuex.tensor_product` has changed. It now takes a list of `tuple[cue.Operation, cue.SegmentedTensorProduct]` instead of a single `cue.SegmentedTensorProduct`. This change allows `cuex.tensor_product` to execute any type of non-homogeneous polynomial.
- [JAX] Removed `cuex.flax_linen.Linear` to reduce maintenance burden. Use `cue.descriptors.linear` together with `cuex.equivariant_tensor_product` instead.
# Inside a Flax module method: build the linear descriptor, create its weights, apply it.
e = cue.descriptors.linear(input.irreps, output_irreps)
w = self.param(name, jax.random.normal, (e.inputs[0].dim,), input.dtype)
output = cuex.equivariant_tensor_product(e, w, input)

Fixed
- [Torch/JAX] Fixed `cue.descriptors.full_tensor_product`, which was ignoring the `irreps3_filter` argument.
- [Torch/JAX] Fixed a rare bug with `np.bincount` when using an old version of NumPy. The input is now flattened to make it work with all versions.
- [Torch] Identified a bug in the CUDA kernel and disabled the CUDA kernel for `cuet.TransposeSegments` and `cuet.TransposeIrrepsLayout`.
Added
- [Torch/JAX] Added `__mul__` to `cue.EquivariantTensorProduct` to allow rescaling the equivariant tensor product.
- [JAX] Added JAX bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomial as long as the contraction pattern (subscripts) has only one mode. It handles batched, shared, and indexed inputs and outputs. Indexed inputs and outputs are processed through atomic operations.
- [JAX] Added an `indices` argument to `cuex.equivariant_tensor_product` and `cuex.tensor_product` to handle the scatter/gather fusion.
- [Torch] Beta, limited Torch bindings to the new JIT Uniform 1d kernel (enable the new kernel by setting the environment variable `CUEQUIVARIANCE_OPS_USE_JIT=1`).
- [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change: `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed`).
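
As a rough picture of what scatter/gather fusion replaces, the unfused pattern can be written in plain NumPy: gather input rows by index, apply a per-row computation, then accumulate the results into output rows by index. This is only a reference for the semantics. The function `gather_compute_scatter` and the toy `np.square` operation are hypothetical and do not reflect the kernel's API or implementation. `np.add.at` stands in for the atomic accumulation on indexed outputs.

```python
import numpy as np

def gather_compute_scatter(x, in_idx, out_idx, num_out, fn):
    """Unfused reference: gather x[in_idx], apply fn, scatter-add into num_out rows."""
    gathered = x[in_idx]               # gather: one input row per work item
    values = fn(gathered)              # per-item computation
    out = np.zeros((num_out,) + values.shape[1:], dtype=values.dtype)
    np.add.at(out, out_idx, values)    # scatter-add: repeated indices accumulate
    return out

x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
in_idx = np.array([0, 2, 2, 1])    # which input row each work item reads
out_idx = np.array([1, 0, 1, 1])   # which output row each work item adds to
result = gather_compute_scatter(x, in_idx, out_idx, num_out=2, fn=np.square)
print(result)
```

In this unfused form the gathered and per-item arrays are materialized in memory. A fused kernel can avoid those intermediates by indexing while it computes.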
Full Changelog: v0.2.0...v0.3.0