Allow explicit matmul precision #18934

Open
sbodenstein opened this issue Dec 12, 2023 · 3 comments
Labels: enhancement (New feature or request)

Comments

@sbodenstein (Contributor) commented Dec 12, 2023

This is a proposal to improve and generalize jax.lax.Precision. The current design has a number of limitations:

  1. It's difficult to write device-independent code, because jax.lax.Precision means different numerics on different hardware (e.g. jax.lax.Precision.DEFAULT with FP32 inputs does the matmul in FP32 on V100, TF32 on A100/H100, and BF16 on TPUs). This is a major user footgun when porting JAX functions between TPU and GPU (see the sketch after this list).
  2. It's not general enough to support alternative matmul precision modes. For example, CUTLASS supports faster FP32 matmuls via 3-pass TF32, and GPUs with bfloat16 support can use bfloat16_3x, bfloat16_6x, or bfloat16_9x. Similarly, jax.lax.Precision('fastest') is ambiguous: the user might want float8, bfloat16, float16, or tensorfloat32.
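
To make the first footgun concrete, here is a minimal sketch of the current behaviour (relying only on the string aliases that jax.lax.Precision accepted at the time of this issue):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024), dtype=jnp.float32)

# The same call gives different numerics depending on the backend:
# FP32 accumulation on V100, TF32 on A100/H100, BF16 on TPU.
y = jnp.matmul(x, x, precision=jax.lax.Precision.DEFAULT)

# The string aliases collapse onto the three-level enum, so naming a
# specific algorithm does not pin down the numerics:
assert jax.lax.Precision('bfloat16_3x') == jax.lax.Precision.HIGH
assert jax.lax.Precision('tensorfloat32') == jax.lax.Precision.HIGH
```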

The simplest way to support this is to extend the jax.lax.Precision enum with extra entries:

  1. Breaking change: jax.lax.Precision('bfloat16_3x') and jax.lax.Precision('tensorfloat32') no longer evaluate to jax.lax.Precision.HIGH; they become separate enum members. The current design is highly misleading: a user who writes jax.lax.Precision('bfloat16_3x') would expect bfloat16_3x precision, but they only get it on TPU (on an A100 they get tensorfloat32).
  2. If a user requests a precision mode that is not supported on a device (e.g. jax.lax.Precision('tensorfloat32') on TPU), this should fail rather than silently defaulting to some other mode with different numerics. A rough sketch of both points follows this list.
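
Here is a hypothetical sketch of the proposed semantics; the enum members, support table, and validation helper below are illustrative only, not existing JAX API:

```python
import enum

# Illustrative only: each member names a specific matmul algorithm
# rather than a device-dependent "default/high/highest" level.
class MatmulPrecision(enum.Enum):
    DEFAULT = 'default'
    TENSORFLOAT32 = 'tensorfloat32'
    BFLOAT16_3X = 'bfloat16_3x'
    FLOAT32 = 'float32'

# Hypothetical per-backend support table.
_SUPPORTED = {
    'tpu': {MatmulPrecision.DEFAULT, MatmulPrecision.BFLOAT16_3X,
            MatmulPrecision.FLOAT32},
    'gpu': {MatmulPrecision.DEFAULT, MatmulPrecision.TENSORFLOAT32,
            MatmulPrecision.BFLOAT16_3X, MatmulPrecision.FLOAT32},
}

def validate_precision(platform: str, precision: MatmulPrecision) -> MatmulPrecision:
    """Fails loudly instead of silently substituting different numerics."""
    if precision not in _SUPPORTED.get(platform, set()):
        raise ValueError(
            f'{precision.value!r} matmul precision is not supported on '
            f'{platform!r}; request a mode this backend implements.')
    return precision

# validate_precision('tpu', MatmulPrecision.TENSORFLOAT32) raises ValueError
# rather than quietly running the matmul with different numerics.
```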

@hawkinsp, @cheshire

sbodenstein added the enhancement (New feature or request) label on Dec 12, 2023
@andportnoy (Contributor) commented
Related: openxla/stablehlo#755.

@jewillco commented
Is it possible to use a full proto for this? I'm thinking of cases such as integer matrix multiplication where you might want arbitrary bit counts for the inputs and output, even if they are stored in wider declared data types.
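
For illustration only, here is one hypothetical shape such a structured config could take (the class and field names below are invented for this sketch; nothing like it exists in JAX or XLA today):

```python
from dataclasses import dataclass

# Hypothetical structured precision config: it separates the significant
# bit counts of the operands and accumulator from the declared storage
# dtypes, which a single enum value cannot express.
@dataclass(frozen=True)
class DotPrecisionConfig:
    lhs_significant_bits: int   # e.g. 4-bit values stored as int8
    rhs_significant_bits: int
    accumulator_bits: int       # e.g. accumulate in int32
    num_passes: int = 1         # e.g. 3 for a bfloat16_3x-style algorithm
```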

@cheshire (Member) commented
@pschuh
