I'm trying to implement an fp32/tf32 matrix multiplication kernel using Pallas. However, the numeric results have more error than I hoped for. Specifically, I have the following code:
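(The exact snippet isn't reproduced on this page; the following is a minimal sketch of that kind of comparison, where the kernel, shapes, and seed are assumptions rather than the original code.)

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(precision, a_ref, b_ref, c_ref):
    # One program computes the whole (small) product; `precision`
    # controls whether the GPU dot is allowed to use tf32.
    c_ref[...] = jnp.dot(a_ref[...], b_ref[...], precision=precision)

def matmul_pallas(a, b, precision):
    return pl.pallas_call(
        functools.partial(matmul_kernel, precision),
        out_shape=jax.ShapeDtypeStruct((a.shape[0], b.shape[1]), jnp.float32),
    )(a, b)

k1, k2 = jax.random.split(jax.random.key(0))
A = jax.random.normal(k1, (128, 128), dtype=jnp.float32)
B = jax.random.normal(k2, (128, 128), dtype=jnp.float32)

# XLA references: DEFAULT may use tf32 on recent GPUs, HIGHEST forces full fp32.
C_ref_tf32 = jnp.dot(A, B, precision=jax.lax.Precision.DEFAULT)
C_ref_no_tf32 = jnp.dot(A, B, precision=jax.lax.Precision.HIGHEST)
C_pallas_tf32 = matmul_pallas(A, B, jax.lax.Precision.DEFAULT)
C_pallas_no_tf32 = matmul_pallas(A, B, jax.lax.Precision.HIGHEST)

print("tf32:   ", jnp.abs(C_ref_tf32 - C_pallas_tf32).max())
print("no tf32:", jnp.abs(C_ref_no_tf32 - C_pallas_no_tf32).max())
```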
While the numeric difference may seem somewhat small, the difference between `C_pallas_tf32` and the others becomes significant in larger applications. I am specifically curious why there is a difference between `C_ref_tf32` and `C_pallas_tf32`. Both of them should be using tf32, so I expected them to be very close to equal, much like `C_ref_no_tf32` and `C_pallas_no_tf32`.
Two main questions:

1. Do you know why this may be the case?
2. Is there any way to get Pallas/JAX to dump the Pallas kernel's PTX? That way maybe I could at least inspect what it's doing.
I know that it's unreasonable to expect bitwise equality with floating point numbers, but this error does seem really hard to understand.
System info (python version, jaxlib version, accelerator, etc.)
This is the same issue encountered here: triton-lang/triton#4574. I applied the same fix recommended there and was able to get the same result for TF32 between XLA and Pallas. You can try it by pulling this branch: #23262.
> Is there any way to get Pallas/JAX to dump the Pallas kernel's PTX? That way maybe I could at least inspect what it's doing.
You can pass `debug=True` to `pallas_call` and it will dump the Triton IR. But in this case you wouldn't see anything suspicious, since it's due to rounding issues.
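For instance, with the hypothetical `matmul_kernel` sketched above:

```python
out = pl.pallas_call(
    functools.partial(matmul_kernel, jax.lax.Precision.DEFAULT),
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    debug=True,  # dumps the lowered Triton IR when the kernel is compiled
)(A, B)
```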
The right solution here is probably to allow inline assembly in Pallas since we don't have that functionality yet.
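For context: tf32 keeps only 10 of fp32's 23 mantissa bits, so how the inputs are rounded during demotion matters, and PTX exposes a round-to-nearest conversion as `cvt.rna.tf32.f32`. As an illustration only (not the actual fix applied above), that demotion can be emulated in plain JAX roughly like this, ignoring NaN/Inf edge cases:

```python
import jax
import jax.numpy as jnp

def round_to_tf32(x):
    # tf32 keeps the top 10 of fp32's 23 mantissa bits. Adding half of
    # the dropped range (0x1000) before clearing the low 13 bits rounds
    # to nearest, ties away from zero. NaN/Inf carry-out is ignored here.
    u = jax.lax.bitcast_convert_type(x, jnp.uint32)
    u = (u + jnp.uint32(0x1000)) & jnp.uint32(0xFFFFE000)
    return jax.lax.bitcast_convert_type(u, jnp.float32)

# Hypothetical check: pre-rounding the inputs and then doing a full-fp32 dot
# separates the effect of the input conversion from the accumulation:
# C = jnp.dot(round_to_tf32(A), round_to_tf32(B),
#             precision=jax.lax.Precision.HIGHEST)
```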