
Do not run AccelerateMatmul on pre-Volta GPUs #1505

Merged
merged 3 commits into main on Apr 24, 2023

Conversation

geekypathak21
Contributor

Related to #1271. I am currently working on adding support for pre-Volta GPUs in Triton.

@geekypathak21
Contributor Author

geekypathak21 commented Apr 11, 2023

I want to ask one question: in convertFMADot() we are using LLVM::FMulAddOp, which requires the same type for all operands and results. While running the passes we convert our result tensor to tensor<64x64xf32, #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>>, which has element type f32 even when we specify the result tensor as float16. Is there any specific reason for this? So currently this PR only supports float32 for the dot operation; passing a float16 type will give an error.

@Jokeren
Contributor

Jokeren commented Apr 11, 2023

So currently this PR only supports float32 for the dot operation; passing a float16 type will give an error.

What errors did you see?

@geekypathak21
Contributor Author

What errors did you see?

error: 'llvm.intr.fmuladd' op requires the same type for all operands and results
Pass execution failed
LLVM ERROR: Failed to translate TritonGPU to LLVM IR.
[1]    139783 abort (core dumped)  python3 test_files1.py

@Jokeren I got this error when passing these tensors:

 a = torch.randn((64, 16), device="cuda", dtype=torch.float16)
 b = torch.randn((16, 64), device="cuda", dtype=torch.float16)
 c = torch.empty((64, 64), device="cuda", dtype=torch.float16)

I tried dumping all the tensor types; the result tensor that ends up with a different type is:

tensor<64x64xf32, #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>>

When I tried with float32 it worked successfully.
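
For reference, a minimal kernel along these lines hits the path described above. This is only a sketch, not the PR's test file; the kernel, launch parameters, and variable names are illustrative.

 # A minimal repro sketch (not code from this PR): a single-block tl.dot matmul
 # launched with the fp16 tensors above.
 import torch
 import triton
 import triton.language as tl

 @triton.jit
 def matmul_kernel(a_ptr, b_ptr, c_ptr, N, K,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
     # One program covers the whole (small) problem, so no tiling loop is needed.
     offs_m = tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
     offs_k = tl.arange(0, BLOCK_K)
     a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])   # (BLOCK_M, BLOCK_K)
     b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])   # (BLOCK_K, BLOCK_N)
     acc = tl.dot(a, b)  # on pre-Volta this lowers through the FMA path discussed above
     tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.float16))

 a = torch.randn((64, 16), device="cuda", dtype=torch.float16)
 b = torch.randn((16, 64), device="cuda", dtype=torch.float16)
 c = torch.empty((64, 64), device="cuda", dtype=torch.float16)
 matmul_kernel[(1,)](a, b, c, 64, 16, BLOCK_M=64, BLOCK_N=64, BLOCK_K=16)

As noted above, casting a and b to float32 before the tl.dot call avoids the operand-type mismatch on pre-Volta GPUs.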

lib/Analysis/Utility.cpp (outdated review comment, resolved)
@@ -147,8 +149,6 @@ class BlockedToMMA : public mlir::RewritePattern {
      mmaEnc = triton::gpu::MmaEncodingAttr::get(
          oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
          warpsPerTile);
    } else {
      llvm_unreachable("Mma layout only supports versionMajor in {1, 2}");
Contributor Author

@geekypathak21 geekypathak21 Apr 11, 2023

Due to the assertion, control will never reach here, so I think there is no point in adding it back.

@ptillet ptillet changed the title Adding support For Pre-volta GPU Do not run AccelerateMatmul on pre-Volta GPUs Apr 11, 2023
@Ph0rk0z

Ph0rk0z commented Apr 14, 2023

I got this kind of error:

error: 'llvm.intr.fmuladd' op requires the same type for all operands and results
Pass execution failed
LLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Aborted

@Kraegge

Kraegge commented Apr 15, 2023

I get this error running it on a Tesla K80 (Kepler architecture, compute capability 3.7).

error: 'llvm.intr.fmuladd' op requires the same type for all operands and results
Pass execution failed
LLVM ERROR: Failed to translate TritonGPU to LLVM IR.

@geekypathak21
Contributor Author

@Ph0rk0z @Kraegge I am working on this bug as well and will update this PR when done.

@Ph0rk0z

Ph0rk0z commented Apr 15, 2023

Ok. I was eager to try Triton and see what kind of speeds I get on my Pascal card. I saw that GPTQ closed their bugs and assumed the best :)

@geekypathak21 geekypathak21 force-pushed the add-prevoltasupport branch 2 times, most recently from 8ef6fcf to b060693 on April 17, 2023 19:21
Contributor Author

@geekypathak21 geekypathak21 left a comment

Hey @ptillet, I have tried to add support for float16, but it requires some discussion before moving further, so I just added a warning in this commit.


if (computeCapability < 70) {
  if (oldAType.getElementType().isF16()) {
    llvm_unreachable("Float16 type is not supported with computeCapability "
Contributor Author

@geekypathak21 geekypathak21 Apr 17, 2023

A lot of people were facing this issue, so I added an error statement here.

auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();

if (computeCapability < 70) {
Contributor Author

V100 has a compute capability of 70, so I changed the check from <= to < because V100 supports the MMA layout.
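
For anyone following along, a small sketch of that boundary (the GPU-to-compute-capability examples are for context only and are not part of this PR):

 # Why the check is "< 70" rather than "<= 70": compute capability 70 (Volta/V100)
 # is the first generation with tensor cores, so it should take the MMA path.
 EXAMPLE_GPUS = {
     "Tesla K80 (Kepler)": 37,
     "Tesla P100 (Pascal)": 60,
     "GTX 1080 (Pascal)": 61,
     "Tesla V100 (Volta)": 70,
     "A100 (Ampere)": 80,
 }

 def uses_mma_layout(compute_capability: int) -> bool:
     # Equivalent to "not (computeCapability < 70)" in the C++ check above.
     return compute_capability >= 70

 for gpu, cc in EXAMPLE_GPUS.items():
     print(f"{gpu}: cc={cc}, MMA path={uses_mma_layout(cc)}")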

Collaborator

@ptillet ptillet Apr 19, 2023

The error should be somewhere else, probably at the beginning of semantic.dot in the frontend. Here, an assertion that computeCapability >= 70 should be enough.

Collaborator

Also, your check doesn't cover float8.
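
Putting these two review comments together, the kind of frontend guard being suggested might look roughly like the sketch below. The function name, argument names, and dtype spellings are assumptions for illustration, not the code that was merged.

 # Sketch of a guard at the start of the frontend dot lowering (semantic.dot),
 # rather than inside the C++ AccelerateMatmul pass. Names are illustrative.
 REDUCED_PRECISION = {"fp16", "bf16", "fp8e4", "fp8e5"}

 def check_dot_compute_capability(lhs_dtype: str, rhs_dtype: str, compute_capability: int) -> None:
     # Pre-Volta GPUs (compute capability < 70) only get the FMA fallback, which
     # currently handles fp32 operands; fp16/bf16/fp8 operands need MMA (>= 70).
     if compute_capability < 70:
         assert lhs_dtype not in REDUCED_PRECISION and rhs_dtype not in REDUCED_PRECISION, (
             "dot with fp16/bf16/fp8 operands requires compute capability >= 70; "
             "cast the inputs to float32 on pre-Volta GPUs"
         )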

Contributor Author

@ptillet Done as you have suggested 👍

@clxyder

clxyder commented Apr 21, 2023

Hey @geekypathak21, do you think it's possible to support 4- or 8-bit data representations for Pascal cards and newer?

@geekypathak21
Contributor Author

Hey @clxyder, I think it's possible to support 16-bit, but I'm not sure about 8- and 4-bit.

@geekypathak21 geekypathak21 force-pushed the add-prevoltasupport branch 2 times, most recently from 46f7e17 to 31ee8fc on April 21, 2023 08:58
@Ph0rk0z

Ph0rk0z commented Apr 21, 2023

Hey @clxyder, I think it's possible to support 16-bit, but I'm not sure about 8- and 4-bit.

But it is working in 4-bit on CUDA kernels currently. 8-bit through bitsandbytes is a bit sketchy and doesn't work with every model.

@ptillet ptillet merged commit 6d22643 into triton-lang:main Apr 24, 2023
pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024
…#1505)

Related to triton-lang#1271. I am currently working on adding support for pre-Volta GPUs in Triton.

---------

Co-authored-by: Himanshu Pathak <himanshu@mtatva.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this pull request Aug 5, 2024
These fixes allow the Triton project to build under gcc-9.

cc triton-lang#1505