This is a proposal to improve and generalize `jax.lax.Precision`. The current design has a number of limitations:
- It's difficult to write device-independent code, because `jax.lax.Precision` means different numerics on different hardware (e.g. `jax.lax.Precision.DEFAULT` with FP32 inputs does matmuls in FP32 on V100, TF32 on A100/H100, and BF16 on TPU), which is a major user footgun when porting JAX functions between TPU and GPU.
- It's not general enough to support alternative matmul precision modes. For example, CUTLASS supports faster FP32 matmuls via 3-pass TF32, and GPUs with bfloat16 support could use `bfloat16_3x`, `bfloat16_6x`, or `bfloat16_9x` (see the sketch after this list). A `jax.lax.Precision('fastest')` mode is also underspecified: the user might want `float8`, `bfloat16`, `float16`, or `tensorfloat32`.
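For context, `bfloat16_3x` approximates an FP32 matmul by splitting each FP32 operand into a bfloat16 high part plus a bfloat16 residual and issuing three bfloat16 matmuls with FP32 accumulation. A minimal sketch of the decomposition (not the actual lowering JAX/XLA would perform):

```python
import jax.numpy as jnp
from jax import lax

def _split_bf16(x):
    # Split an f32 array into a bf16 high part and a bf16 residual.
    hi = x.astype(jnp.bfloat16)
    lo = (x - hi.astype(jnp.float32)).astype(jnp.bfloat16)
    return hi, lo

def matmul_bf16_3x(a, b):
    # Three bf16 matmuls accumulated in f32. The lo*lo cross term is
    # dropped because it falls below f32 precision.
    a_hi, a_lo = _split_bf16(a)
    b_hi, b_lo = _split_bf16(b)
    dot = lambda x, y: lax.dot(x, y, preferred_element_type=jnp.float32)
    return dot(a_hi, b_hi) + dot(a_hi, b_lo) + dot(a_lo, b_hi)
```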
The simplest way of supporting this is to extend the `jax.lax.Precision` enum to support extra types:
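For instance, the enum could grow explicit, device-agnostic members along these lines (member names here are illustrative, not an agreed-upon API):

```python
import enum

class Precision(enum.Enum):
    DEFAULT = 'default'              # fastest mode the device offers
    FLOAT8 = 'float8'
    BFLOAT16 = 'bfloat16'
    FLOAT16 = 'float16'
    TENSORFLOAT32 = 'tensorfloat32'
    BFLOAT16_3X = 'bfloat16_3x'      # 3-pass bf16 emulation of f32
    FLOAT32 = 'float32'              # true f32 matmul
```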
- Breaking change: `jax.lax.Precision('bfloat16_3x')` and `jax.lax.Precision('tensorfloat32')` no longer evaluate to `jax.lax.Precision.HIGH`; they become separate enum values. The current design is highly misleading: a user who writes `jax.lax.Precision('bfloat16_3x')` would expect bfloat16_3x precision, but they only get it on TPU (on A100 they get tensorfloat32).
- If a user requests a precision mode not supported on a device (e.g. `jax.lax.Precision('tensorfloat32')` on TPU), this should fail rather than silently defaulting to some other type with different numerics, as sketched below.
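A minimal sketch of that fail-fast behavior; the support table here is purely illustrative, not an authoritative statement of what each backend implements:

```python
# Hypothetical per-backend support table, for this sketch only.
_SUPPORTED = {
    'tpu': {'DEFAULT', 'BFLOAT16', 'BFLOAT16_3X', 'FLOAT32'},
    'gpu': {'DEFAULT', 'TENSORFLOAT32', 'FLOAT16', 'FLOAT32'},
}

def check_precision(precision_name: str, backend: str) -> None:
    # Fail loudly instead of silently substituting different numerics.
    if precision_name not in _SUPPORTED[backend]:
        raise NotImplementedError(
            f'Precision {precision_name!r} is not supported on {backend!r}')

check_precision('TENSORFLOAT32', 'tpu')  # raises NotImplementedError
```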
Is it possible to use a full proto for this? I'm thinking of cases such as integer matrix multiplication where you might want arbitrary bit counts for the inputs and output, even if they are stored in wider declared data types.
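One way to picture that, using a Python dataclass as a stand-in for the suggested proto (all field names here are hypothetical):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class DotPrecisionConfig:
    # Effective bit counts, which may be narrower than the storage dtype.
    lhs_bits: int
    rhs_bits: int
    accum_bits: int

# e.g. 8-bit-quality integer inputs accumulated at 32 bits, even if the
# arrays themselves are stored as int32:
config = DotPrecisionConfig(lhs_bits=8, rhs_bits=8, accum_bits=32)
```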