Minimal viable support for tf32 on the block pointer path #1172
Conversation
Force-pushed from fb22772 to 9896ed2
The failure in https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/9267818058/job/25513998366?pr=1172 is in the fallback path: it is performing an invalid bitcast, e.g.:
Force-pushed from f8cfe71 to 37de28a
lgtm
@@ -243,8 +244,9 @@ def matmul(a, b, res_dtype):
    # Note: the torch.matmul and Triton implementations use different
    # algorithms so we need to adjust tolerance.
    atol = 4e-2 if dtype == torch.float32 else 1e-4
For this, we can double-check with other teams (kernel library, IGC, etc.) that have used TF32 GEMM previously, to make sure this is as expected.
Yes, that would be good to know. I'd need to raise the bounds again in #1211.
What's the rationale for increasing atol vs. rtol?
Nothing specific, I just played around with the parameters. The maximum relative error is 146 (triton: 0.0146, torch: 0.0001). I think the underlying problem might be that the reference computation is not done with TF32 precision; still investigating how to enable that.
Increasing rtol is better because the torch.allclose comparison is:

|input − other| ≤ atol + rtol × |other|

So increasing atol loosens the comparison regardless of the magnitude of the values being compared, whereas rtol scales with the reference.
If we can force torch to use TF32 precision that would be ideal.
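For illustration, a small sketch of that bound using the numbers quoted above (triton: 0.0146, torch: 0.0001); the one-element tensors are a stand-in, not the actual test data:

```python
import torch

triton_out = torch.tensor([0.0146])
torch_ref = torch.tensor([0.0001])

# torch.allclose checks |input - other| <= atol + rtol * |other| elementwise.
# With a tiny reference value, the rtol term contributes almost nothing,
# so for this particular element only a large atol makes the check pass.
print(torch.allclose(triton_out, torch_ref, atol=4e-2, rtol=1e-5))  # True
print(torch.allclose(triton_out, torch_ref, atol=1e-4, rtol=1e-5))  # False
```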
Found out how to set the TF32 mode and hence was able to drop the changes to the tolerances.
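For reference, one way to put the torch.matmul reference computation into TF32 mode is PyTorch's device-agnostic matmul-precision knob; whether this is the exact mechanism used in the test (and whether the XPU backend honors it) is an assumption here:

```python
import torch

# Allow PyTorch to use a reduced-precision (e.g. TF32) path for float32 matmuls.
# "highest" restores full FP32 precision if you want to compare both modes.
torch.set_float32_matmul_precision("high")

a = torch.randn(512, 512, dtype=torch.float32)
b = torch.randn(512, 512, dtype=torch.float32)
ref = torch.matmul(a, b)  # reference now eligible for TF32 execution
```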
LGTM
Force-pushed from 5030564 to e588e8c
Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com>
Force-pushed from e588e8c to 3b09e6d
Force-pushed from f131540 to 7e0640b
LGTM
Adds support for TF32 dot products. The main difference compared to 16-bit datatypes is that the A operand load needs to be encoded as an `i32`-based block load, and it must not use the VNNI format.
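
For illustration, a minimal sketch (not the PR's actual test) of a block-pointer GEMM kernel that exercises this path by feeding fp32 operands to tl.dot; the kernel name, block sizes, and the spelling of the precision knob (input_precision="tf32" vs. the older allow_tf32=True) are assumptions and may differ across Triton versions:

```python
import triton
import triton.language as tl


@triton.jit
def tf32_matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                       stride_am, stride_ak, stride_bk, stride_bn,
                       stride_cm, stride_cn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointers for the A and B tiles. The elements are fp32, so the
    # A-operand load is lowered to an i32-based block load without VNNI packing.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K),
                                    strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_M, 0),
                                    block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N),
                                    strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_N),
                                    block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # fp32 inputs with TF32 precision requested for the dot product.
        acc += tl.dot(a, b, input_precision="tf32")
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N),
                                    strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                    block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(c_block_ptr, acc, boundary_check=(0, 1))
```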