Adding dot_general to Einsum Module #4664

@jfc4050

Description

Hello! Are there any plans to add a dot_general argument to Einsum? If not, would a PR be accepted?

Linear accepts an injectable dot_general, but Einsum does not. The main reason not to do this is that jnp.einsum marks
dot_general injection as experimental, but perhaps it could be marked as experimental in Einsum as well?
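
To make the request concrete, here is a minimal sketch of what forwarding an injected dot_general through an einsum could look like in plain JAX. It assumes the underscore-prefixed `_dot_general` keyword on `jnp.einsum` (the experimental hook mentioned above); `einsum_with_hook` and `logging_dot_general` are hypothetical names for illustration, not part of Flax or JAX.

```python
import jax.numpy as jnp
from jax import lax

# Records the dimension_numbers seen by the hook. Under jit this runs at
# trace time, so it counts traces rather than executions.
calls = []


def logging_dot_general(lhs, rhs, dimension_numbers, *args, **kwargs):
    """Hypothetical drop-in for lax.dot_general that logs its calls."""
    calls.append(dimension_numbers)
    # Forward everything else (precision, preferred_element_type, ...) unchanged.
    return lax.dot_general(lhs, rhs, dimension_numbers, *args, **kwargs)


def einsum_with_hook(subscripts, x, kernel, dot_general=lax.dot_general):
    """Hypothetical helper: an einsum whose contraction primitive is injectable.

    Assumption: jnp.einsum accepts the experimental `_dot_general` keyword.
    """
    return jnp.einsum(subscripts, x, kernel, _dot_general=dot_general)


x = jnp.ones((4, 3), dtype=jnp.bfloat16)
kernel = jnp.ones((3, 5), dtype=jnp.bfloat16)
y = einsum_with_hook("bi,io->bo", x, kernel, dot_general=logging_dot_general)
```

An Einsum module could take a similar approach: store the callable as an attribute and pass it through at call time, defaulting to `lax.dot_general` so existing behavior is unchanged.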

My specific motivation: I'm finding that when we use Flax NNX modules in what we'd expect to be a standard mixed-precision scheme (bf16 compute dtype and fp32 param dtype), gradients end up being computed in bf16, all-reduced in bf16, and only then cast to fp32. I want the GEMM to output an fp32 gradient, which is then reduced in fp32. More context about this in jax-ml/jax#27496
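
The behavior described above can be reproduced without Flax. The sketch below first shows that the default weight gradient, although returned as fp32, only contains bf16-representable values. It then shows one possible way (not necessarily what an injected dot_general would do in Flax) to have the backward GEMM emit fp32 directly, using `jax.custom_vjp` and `preferred_element_type`. The names `loss_default`, `matmul_fp32_wgrad`, and `loss_custom` are hypothetical, and the all-reduce step is not shown.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss_default(w_fp32, x_bf16):
    # fp32 params are cast to the bf16 compute dtype before the GEMM.
    y = jnp.dot(x_bf16, w_fp32.astype(jnp.bfloat16))
    return y.astype(jnp.float32).sum()


@jax.custom_vjp
def matmul_fp32_wgrad(x_bf16, w_fp32):
    """Hypothetical matmul: bf16 forward pass, fp32 weight gradient."""
    return jnp.dot(x_bf16, w_fp32.astype(x_bf16.dtype))


def _fwd(x_bf16, w_fp32):
    return matmul_fp32_wgrad(x_bf16, w_fp32), (x_bf16, w_fp32)


def _bwd(residuals, g):
    x_bf16, w_fp32 = residuals
    w_bf16 = w_fp32.astype(x_bf16.dtype)
    # Input gradient stays in the compute dtype: (b, o) @ (o, i) -> (b, i).
    dx = jnp.dot(g, w_bf16.T)
    # Weight gradient: contract the batch axis of x and g, emit fp32 output.
    dw = lax.dot_general(
        x_bf16, g,
        dimension_numbers=(((0,), (0,)), ((), ())),
        preferred_element_type=jnp.float32,
    )
    return dx, dw


matmul_fp32_wgrad.defvjp(_fwd, _bwd)


def loss_custom(w_fp32, x_bf16):
    return matmul_fp32_wgrad(x_bf16, w_fp32).astype(jnp.float32).sum()


x = jax.random.normal(jax.random.PRNGKey(0), (64, 8), dtype=jnp.bfloat16)
w = jnp.ones((8, 4), dtype=jnp.float32)

g_default = jax.grad(loss_default)(w, x)  # fp32 dtype, bf16-rounded values
g_custom = jax.grad(loss_custom)(w, x)    # fp32 output straight from the GEMM
```

In a data-parallel setup, a gradient like `g_custom` would already be fp32 when it reaches the all-reduce, which is the ordering asked for here.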
