
Minimal viable support for f16 dot product with f16 accumulator on the block pointer path #1211

Closed
jopperm wants to merge 2 commits

Conversation

@jopperm (Contributor) commented on May 30, 2024

Adds support for accumulating the dot result with 16-bit precision.
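
As a rough illustration of what this enables, below is a minimal sketch of a block-pointer kernel that requests a 16-bit accumulator from `tl.dot` via its `out_dtype` argument. The kernel name, shapes, and launch details are assumptions made for this example and are not taken from the PR diff.

    import triton
    import triton.language as tl

    @triton.jit
    def dot_f16_acc_kernel(a_ptr, b_ptr, c_ptr,
                           M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
        # Block pointers covering the whole (small) row-major matrices.
        a_bp = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(K, 1),
                                 offsets=(0, 0), block_shape=(M, K), order=(1, 0))
        b_bp = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(N, 1),
                                 offsets=(0, 0), block_shape=(K, N), order=(1, 0))
        c_bp = tl.make_block_ptr(c_ptr, shape=(M, N), strides=(N, 1),
                                 offsets=(0, 0), block_shape=(M, N), order=(1, 0))
        a = tl.load(a_bp)
        b = tl.load(b_bp)
        # Request a float16 accumulator instead of the default float32 one.
        c = tl.dot(a, b, out_dtype=tl.float16)
        tl.store(c_bp, c)

Launching this with, e.g., 64x64 fp16 tensors would exercise the f16-accumulating DPAS variant this PR targets.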

@jopperm self-assigned this on May 30, 2024
@jopperm linked an issue on May 30, 2024 that may be closed by this pull request
@jopperm force-pushed the jopperm/blockptr_f16accu branch 2 times, most recently from 37835e5 to df9b126, on May 31, 2024 at 09:52
@Dewei-Wang-sh (Contributor) commented:

https://github.com/triton-lang/triton/blob/v2.1.0/python/tutorials/08-experimental-block-pointer.py#L161,#L176
This is the original test; that's why I'm asking whether you have time to make the result type the same as the a/b type in #1155.

Comment on lines +227 to +231
if res_dtype in [torch.float16]:
# We observed high relative errors on small numbers when only using 16 bit for accumulation;
# hence, use a more restricted input set here.
a = torch.randint(low=-8, high=8, size=(512, 512), device='xpu', dtype=dtype) / 16
b = torch.randint(low=-8, high=8, size=(512, 512), device='xpu', dtype=dtype) / 16
@jopperm (Contributor, Author) commented:

As an alternative to raising the tolerances, we could use restricted inputs (here: [-0.5, 0.5] in 1/16th increments). WDYT?

Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com>
@jopperm marked this pull request as ready for review on May 31, 2024 at 15:45
@jopperm requested review from whitneywhtsang, etiotto and a team on May 31, 2024 at 15:45
if res_dtype in [torch.float16]:
# We observed high relative errors on small numbers when only using 16 bit for accumulation;
# hence, use a more restricted input set here.
a = torch.randint(low=-8, high=8, size=(512, 512), device='xpu', dtype=dtype) / 16
A reviewer (Contributor) asked:

Why are we generating random integer values for float16?

@jopperm (Contributor, Author) replied:

While the results overall look reasonable to me, there are outliers which would require the relative tolerance to be >> 10, so I looked for alternatives to restrict the inputs. I'll bring this up later in the call.
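
A toy illustration (numbers invented for this comment, not taken from the test) of why a relative tolerance breaks down for results near zero:

    import torch

    exact = torch.tensor(1e-4)      # reference result, very close to zero
    approx = torch.tensor(2.1e-3)   # f16-accumulated result with a small absolute error
    rel_err = (approx - exact).abs() / exact.abs()
    print(rel_err.item())           # ~20, i.e. the relative tolerance would need to be >> 10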

@jopperm (Contributor, Author) commented on Jun 5, 2024:

Superseded by #1258.

@jopperm closed this on Jun 5, 2024
@jopperm deleted the jopperm/blockptr_f16accu branch on June 5, 2024 at 19:15
whitneywhtsang pushed a commit that referenced this pull request Jun 20, 2024
This PR adds support for using the `bf16 += bf16 x bf16` variant of the
DPAS instruction in the `10-experimental-block-pointer.py` tutorial.

I extended the `triton_gen.dpas` op's verifier to support non-f32
accumulator types, and added corresponding test cases for bf16 (as well
as for f16 accumulation, which I missed in #1211).

I had to raise the absolute tolerance again (compared to f16
accumulation). `1e-2` for `bf16` and `1e-3` for `f16` matches what
`torch.finfo` returns as "resolution" for the data types, though that
might be just a coincidence.

---------

Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com>
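
For reference, the "resolution" values mentioned in the commit message can be read directly from `torch.finfo`:

    import torch

    print(torch.finfo(torch.float16).resolution)   # 0.001
    print(torch.finfo(torch.bfloat16).resolution)  # 0.01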

Successfully merging this pull request may close these issues.

[GEMM] Fix functional issues with non-float16 dtype