Closed
Description
Hello! Are there any plans to add a dot_general argument to Einsum? If not, would a PR be accepted?
Linear accepts an injectable dot_general, but Einsum does not. The main reason not to add one is that jnp.einsum marks dot_general injection as experimental, but perhaps the argument could be marked as experimental in Einsum as well?
My specific motivation: when I use Flax NNX modules with what should be a standard mixed-precision scheme (bf16 compute dtype, fp32 param dtype), the gradients are computed in bf16, all-reduced in bf16, and only then cast to fp32. I want the GEMM to output fp32 gradients, which are then reduced in fp32. More context is in jax-ml/jax#27496.
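To illustrate the kind of behavior an injectable dot_general would make possible, here is a minimal pure-JAX sketch. It assumes a plain 2-D matmul, bf16 activations, and fp32 master weights. The helper `bf16_matmul` is hypothetical and is not part of Flax or JAX. It uses `jax.custom_vjp` so that the backward GEMMs request fp32 outputs via `preferred_element_type`, rather than producing bf16 gradients that are cast afterwards. The all-reduce step is not shown.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _forward(x, w):
    # Hypothetical mixed-precision matmul: bf16 operands, fp32 accumulation,
    # result returned in bf16 (the "compute dtype").
    y = lax.dot_general(
        x.astype(jnp.bfloat16), w.astype(jnp.bfloat16),
        (((1,), (0,)), ((), ())),          # contract x[:, k] with w[k, :]
        preferred_element_type=jnp.float32,
    )
    return y.astype(jnp.bfloat16)


@jax.custom_vjp
def bf16_matmul(x, w):
    return _forward(x, w)


def _fwd(x, w):
    # Save the primal inputs for the backward pass.
    return _forward(x, w), (x, w)


def _bwd(res, g):
    x, w = res
    xb, wb, gb = (a.astype(jnp.bfloat16) for a in (x, w, g))
    # dL/dw = x^T @ g, with the GEMM itself emitting fp32 (no bf16 rounding).
    dw = lax.dot_general(xb, gb, (((0,), (0,)), ((), ())),
                         preferred_element_type=jnp.float32)
    # dL/dx = g @ w^T, returned in the activation dtype.
    dx = lax.dot_general(gb, wb, (((1,), (1,)), ((), ())),
                         preferred_element_type=jnp.float32)
    # Cotangents must match the primal dtypes: bf16 for x, fp32 for w.
    return dx.astype(x.dtype), dw.astype(w.dtype)


bf16_matmul.defvjp(_fwd, _bwd)
```

With `x = jnp.ones((4, 8), jnp.bfloat16)` and `w = jnp.ones((8, 3), jnp.float32)`, `jax.grad(lambda w: bf16_matmul(x, w).astype(jnp.float32).sum())(w)` yields an fp32 gradient produced directly by the GEMM. A module-level hook such as Einsum's proposed dot_general argument would be one place to plug in logic like this.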