Minimal viable support for f16 dot product with f16 accumulator on the block pointer path #1211
Conversation
Force-pushed from 37835e5 to df9b126.
https://github.com/triton-lang/triton/blob/v2.1.0/python/tutorials/08-experimental-block-pointer.py#L161,#L176

    if res_dtype in [torch.float16]:
        # We observed high relative errors on small numbers when only using 16 bit for accumulation;
        # hence, use a more restricted input set here.
        a = torch.randint(low=-8, high=8, size=(512, 512), device='xpu', dtype=dtype) / 16
        b = torch.randint(low=-8, high=8, size=(512, 512), device='xpu', dtype=dtype) / 16
As an alternative to raising the tolerances, we could use restricted inputs (here: [-0.5, 0.5]
in 1/16th increments). WDYT?
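For context, a minimal sketch of what the restricted-input approach looks like end to end (the device fallback and the float32 reference are illustrative additions, not part of the diff):

```python
import torch

dtype = torch.float16

# Illustrative guard: fall back to CPU when no XPU-enabled PyTorch build is available.
device = 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() else 'cpu'

# Integers in [-8, 8) divided by 16 give multiples of 1/16 in [-0.5, 0.4375].
# These values (and their pairwise products) are exactly representable in float16,
# so the remaining error comes from the 16-bit accumulation rather than the inputs.
a = torch.randint(low=-8, high=8, size=(512, 512), device=device, dtype=dtype) / 16
b = torch.randint(low=-8, high=8, size=(512, 512), device=device, dtype=dtype) / 16

# A float32 reference to compare the 16-bit-accumulated kernel output against.
ref = torch.matmul(a.float(), b.float())
```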
Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com>
Force-pushed from 55a34ee to 01da0fc.
    if res_dtype in [torch.float16]:
        # We observed high relative errors on small numbers when only using 16 bit for accumulation;
        # hence, use a more restricted input set here.
        a = torch.randint(low=-8, high=8, size=(512, 512), device='xpu', dtype=dtype) / 16
Why are we generating random integer values for float16?
While the results overall look reasonable to me, there are outliers that would require a relative tolerance of >> 10, so I looked into restricting the inputs as an alternative. I'll bring this up later in the call.
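To illustrate the point about the outliers, a hedged sketch of how a few small reference values dominate the relative error (`out_f16_acc` and the clamp epsilon are hypothetical placeholders, not code from this PR):

```python
import torch

def max_relative_error(out_f16_acc: torch.Tensor, ref_f32: torch.Tensor) -> torch.Tensor:
    # Wherever the float32 reference is close to zero, even a tiny absolute
    # accumulation error becomes a huge relative error; a handful of such
    # elements is what pushes the required rtol far above 10 on unrestricted inputs.
    abs_err = (out_f16_acc.float() - ref_f32).abs()
    return (abs_err / ref_f32.abs().clamp_min(1e-6)).max()
```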
Superseded by #1258.
This PR adds support for using the `bf16 += bf16 x bf16` variant of the DPAS instruction in the `10-experimental-block-pointer.py` tutorial. I extended the `triton_gen.dpas` op's verifier to support non-f32 accumulator types, and added corresponding test cases for bf16 (as well as for f16 accumulation, which I missed in #1211).

I had to raise the absolute tolerance again (compared to f16 accumulation). `1e-2` for `bf16` and `1e-3` for `f16` matches what `torch.finfo` reports as the "resolution" of these data types, though that might just be a coincidence.

Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com>
Adds support for accumulating the dot result with 16-bit precision.
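As a quick check of where those tolerances come from (assuming stock PyTorch; the commented-out comparison is only indicative of how the tutorial's check might use them, not the actual call):

```python
import torch

# torch.finfo reports the "resolution" the PR description refers to.
print(torch.finfo(torch.float16).resolution)   # 0.001 -> matches atol=1e-3 for f16 accumulation
print(torch.finfo(torch.bfloat16).resolution)  # 0.01  -> matches atol=1e-2 for bf16 accumulation

# Indicative comparison against a higher-precision reference with those absolute tolerances:
# torch.testing.assert_close(result, reference, atol=torch.finfo(res_dtype).resolution, rtol=0.1)
```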