
Minimal viable support for tf32 on the block pointer path #1172

Merged
merged 2 commits into llvm-target from jopperm/blockptr_tf32 on May 31, 2024

Conversation

@jopperm (Contributor) commented on May 22, 2024

Adds support for TF32 dot products. The main difference compared to 16-bit datatypes is that the A operand load needs to be encoded as an i32-based block load, and it must not use the VNNI format.
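As an illustration only (this kernel is not part of the PR), a minimal Triton sketch of a TF32 dot product on the block-pointer path could look like the following. It assumes float32 inputs whose sizes are multiples of the block sizes, and depending on the Triton version the precision hint is spelled input_precision="tf32" or allow_tf32=True:

```python
import triton
import triton.language as tl


@triton.jit
def tf32_dot_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                    stride_am, stride_ak, stride_bk, stride_bn,
                    stride_cm, stride_cn,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                    BLOCK_K: tl.constexpr):
    # Single output tile for brevity; a real kernel would index by program_id.
    a_block = tl.make_block_ptr(a_ptr, (M, K), (stride_am, stride_ak),
                                (0, 0), (BLOCK_M, BLOCK_K), (1, 0))
    b_block = tl.make_block_ptr(b_ptr, (K, N), (stride_bk, stride_bn),
                                (0, 0), (BLOCK_K, BLOCK_N), (1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        # On the block-pointer path described above, the f32 A tile is lowered
        # to an i32-based 2D block load and is not VNNI-packed.
        a = tl.load(a_block)
        b = tl.load(b_block)
        # Request TF32 precision for the f32 x f32 dot product.
        acc += tl.dot(a, b, input_precision="tf32")
        a_block = tl.advance(a_block, (0, BLOCK_K))
        b_block = tl.advance(b_block, (BLOCK_K, 0))
    c_block = tl.make_block_ptr(c_ptr, (M, N), (stride_cm, stride_cn),
                                (0, 0), (BLOCK_M, BLOCK_N), (1, 0))
    tl.store(c_block, acc)
```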

@jopperm self-assigned this on May 22, 2024
@jopperm linked an issue on May 22, 2024 that may be closed by this pull request
@whitneywhtsang (Contributor) commented on May 28, 2024

The failure in https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/9267818058/job/25513998366?pr=1172 is in the fallback path: it performs an invalid bitcast, e.g. %248 = llvm.bitcast %247 : vector<4xi16> to vector<4xf32>, where the source and destination type sizes differ. It should be fixed by #1204.

@jopperm changed the title from "Experimental TF32 support" to "Minimal viable support for tf32 on the block pointer path" on May 29, 2024
@jopperm marked this pull request as ready for review on May 30, 2024 at 09:23
@Dewei-Wang-sh (Contributor) left a comment:

lgtm

@@ -243,8 +244,9 @@ def matmul(a, b, res_dtype):

# Note: the torch.matmul and Triton implementations use different
# algorithms so we need to adjust tolerance.
atol = 4e-2 if dtype == torch.float32 else 1e-4
Contributor:

For this, we can double-confirm with other teams (kernel library, IGC, etc.) that have used TF32 GEMM previously, to make sure this is as expected.

@jopperm (author):

Yes, that would be good to know. I'd need to raise the bounds again in #1211.

Contributor:

What's the rationale for increasing atol vs. rtol?

@jopperm (author):

Nothing specific; I just played around with the parameters. The maximum relative error is 146 (Triton: 0.0146, torch: 0.0001). I think the underlying problem might be that the reference computation is not done with TF32 precision; I'm still investigating how to enable that.

Contributor:

Increasing rtol is better because the torch.allclose comparison is:

|input − other| ≤ atol + rtol × |other|

So increasing atol loosens the check for every element regardless of its magnitude, whereas rtol scales the allowance with the magnitude.
If we can force torch to use TF32 precision, that would be ideal.
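For illustration only (not taken from the PR; the numbers below are made up), this is how the criterion behaves when atol vs. rtol is raised:

```python
import torch

# torch.allclose checks |input - other| <= atol + rtol * |other| element-wise.
small = (torch.tensor([0.0146]), torch.tensor([0.0001]))  # tiny values, ~146x relative error
large = (torch.tensor([1000.0]), torch.tensor([1000.5]))  # large values, tiny relative error

# A big atol loosens the check everywhere, so even the near-zero pair passes.
print(torch.allclose(*small, atol=4e-2, rtol=1e-5))  # True  (0.0145 <= 0.04 + ~0)
print(torch.allclose(*large, atol=4e-2, rtol=1e-5))  # False (0.5 > 0.04 + ~0.01)

# A bigger rtol instead scales the allowance with the magnitude of the values.
print(torch.allclose(*small, atol=1e-4, rtol=1e-3))  # False (0.0145 > 1e-4 + ~1e-7)
print(torch.allclose(*large, atol=1e-4, rtol=1e-3))  # True  (0.5 <= 1e-4 + ~1.0)
```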

@jopperm (author):

I found out how to set the TF32 mode and hence was able to drop the changes to the tolerances.
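The exact call used in the test is not shown in this thread and may differ (the target here is the XPU backend); as a general sketch, stock PyTorch exposes TF32 for float32 matmuls via settings like these:

```python
import torch

# Allow reduced-precision (TF32) math for float32 matmuls so the torch.matmul
# reference runs at a precision comparable to the Triton TF32 kernel.
# The Intel XPU backend may expose its own backend-specific switch.
torch.set_float32_matmul_precision("high")

# Equivalent CUDA-specific flag on NVIDIA backends:
torch.backends.cuda.matmul.allow_tf32 = True
```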

@mfrancepillois (Contributor) left a comment:

LGTM

Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com>
@jopperm merged commit 7843ec0 into llvm-target on May 31, 2024 (2 checks passed)
@jopperm deleted the jopperm/blockptr_tf32 branch on May 31, 2024 at 15:29
@whitneywhtsang (Contributor) left a comment:

LGTM

Successfully merging this pull request may close these issues:
[GEMM] Fix functional issues with non-float16 dtype