Impossible to use the tutorials #1271

Open
lucasgrjn opened this issue Mar 2, 2023 · 26 comments

@lucasgrjn

Hi !

I am currently trying to understand how to use Triton with the tutorials. Unfortunately, I have run into two different issues:

  • for 03-matrix-multiplication.py and 06-fused-attention.py, I get:
python: /project/lib/Analysis/Utility.cpp:136: bool mlir::supportMMA(mlir::Value, int): Assertion `(version == 1 || version == 2) && "Unexpected MMA layout version found"' failed.
Aborted

The error seems to occur at the line

tl.store(c_ptrs, c, mask=c_mask)

Since I have a GTX 1080 in my computer, I am working with the Pascal architecture. MMA is supported on Volta and newer architectures. Nevertheless, is it possible to optimize the matmul for my GTX 1080?

  • for 05-layer-norm.py, the error is
Argument rematerialization not implemented

UNREACHABLE executed at /project/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp:45!
Aborted

For this one, I don't have a clue...

Does someone have some thoughts on my issues?

Thanks in advance and regards,
Lucas.

@ptillet
Collaborator

ptillet commented Mar 2, 2023

FP16 is not supported on pre-tensor-core GPUs. Can you try FP32?

@lucasgrjn
Author

When using tl.float32, nothing changes, I get the same error

@Jokeren
Contributor

Jokeren commented Mar 2, 2023

If it's a pre-Volta GPU, we don't generate the MMA layout by any means.

So perhaps we shouldn't use assert in places like:

https://github.com/openai/triton/blob/65e5a3bc24c9649d7a5e96acfc11e65bd3899fd6/lib/Analysis/Utility.cpp#L138

Feel free to modify the code and contribute.

@lucasgrjn
Author

If it's a pre-Volta GPU, we don't generate the MMA layout by any means.

Thanks ! I will take a look and see if I can find a way to avoid this issue and make a PR.

Any idea for my second issue on Argument rematerialization ?

@Jokeren
Contributor

Jokeren commented Mar 3, 2023

Any idea for my second issue on Argument rematerialization ?

Not sure how this problem is triggered yet.

@ptillet
Collaborator

ptillet commented Mar 3, 2023

We don't have pre-Volta GPUs to test things out, but we can provide some guidance if you're interested in debugging the issue. I think the main thing for layer norm would be to figure out why the codegen is any different for your 1080 than for a Volta GPU. All GPUs with compute capability <= 70 should be treated the same 🤔

@lucasgrjn
Author

Right, I see the main idea! I will take a look, but since I am a newbie in this kind of stuff, I am not sure I can go too deep, unfortunately...

@andreicozma1

I can confirm I am also getting this issue on RTX A6000

@s-JoL

s-JoL commented Mar 24, 2023

I also encounter the issue "Argument rematerialization not implemented" when running 05-layer-norm.py on an A100-80G.

@RuABraun

RuABraun commented Mar 28, 2023

Randomly (not every time) getting

Argument rematerialization not implemented
UNREACHABLE executed at /project/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp:45!

when running a custom fused linear layer (with activation, dropout, and scaling).

Edit: this was actually caused by the layer norm.

@clxyder

clxyder commented Apr 8, 2023

Hey @Dj1312 were you able to find a fix for this issue?

@Ph0rk0z

Ph0rk0z commented Apr 10, 2023

How to fix this for pascal? Even if it's slower.

@clxyder

clxyder commented Apr 11, 2023

Hey @ptillet, I'm trying to debug this issue on my Pascal card. I have outlined my particular case in the issue qwopqwop200/GPTQ-for-LLaMa#142.

I've swapped the following lines, note this is off of the v2.0.0 tag:

https://github.com/openai/triton/blob/bd5c2117f62c73a9e922d5e93353a39ab3ac269b/lib/Analysis/Utility.cpp#L136-L137

with the following:

if (version != 1 || version != 2)
    return false;

This results in the following error:

error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: builtin.unrealized_conversion_cast
Failed to emit LLVM IR
Translate to LLVM IR failedLLVM ERROR: Failed to translate TritonGPU to LLVM IR.

Do you have any suggestions?
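
A side note on the replacement condition quoted above: `version != 1 || version != 2` is true for every integer, because no value equals both 1 and 2, so the patched check would return false unconditionally. If the intent was to reject only unknown versions (an assumption; the comment does not say so), the guard would need `&&` instead. Below is a minimal Python sketch of the same boolean logic. The helper names are hypothetical and are not part of Triton.

```python
def rejects_with_or(version: int) -> bool:
    # Mirrors `version != 1 || version != 2` from the patch quoted above.
    return version != 1 or version != 2


def rejects_with_and(version: int) -> bool:
    # Assumed intent: reject only versions other than 1 and 2.
    return version != 1 and version != 2


for v in (0, 1, 2, 3):
    # Columns: version, result of the `or` form, result of the `and` form.
    print(v, rejects_with_or(v), rejects_with_and(v))
```

The `or` form rejects every version, including 1 and 2, while the `and` form rejects only 0 and 3 in this loop.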

@lucasgrjn
Author

lucasgrjn commented Apr 11, 2023

Hey @Dj1312 were you able to find a fix for this issue?

Unfortunately, no...

@Ph0rk0z

Ph0rk0z commented Apr 11, 2023

So it needs to be cast somehow? But I swear I have run other float16 code.

ptillet added a commit that referenced this issue Apr 24, 2023
Related to #1271 . I am currently working on adding support for
Pre-volta GPUs in Triton.

---------

Co-authored-by: Himanshu Pathak <himanshu@mtatva.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
@vmarkovtsev

"Argument rematerialization not implemented" is probably a regression because the tutorials work for me on version 2.0.0.dev20221105 with CUDA 11.8.

@ptillet
Collaborator

ptillet commented May 11, 2023

Our docs build runs nightly without issues on an A100. Unfortunately, it's possible there are some problems on older GPUs. I don't have any Pascal GPU I can use, so it's hard for me to repro.

@RuABraun

RuABraun commented May 12, 2023

Just to add I think people are getting this error from running pip install as that version crashes when doing

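import torch

# FusedLayerNorm is the poster's own module built on the tutorial code (definition not shown).
x = torch.randn(512).cuda()
ln = FusedLayerNorm(512).cuda()
y = ln(x)
l = y.sum()
l.backward()  # crash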

on an A100 (cuda 11.8, torch 2.0.0+cu118, triton 2.0.0) (FusedLayerNorm uses this and code from the tutorial)

It's not clear to me how to get the nightly without compiling the code (which, if I'm understanding my compilation error correctly, requires an advanced version of C++).

@ptillet
Collaborator

ptillet commented May 13, 2023

Nightly will be back up soon. Thanks for your patience! In the meantime recompiling the code shouldn't be too difficult

@cszipper

cszipper commented May 16, 2023

pip install triton==2.0.0.dev20230217 works on V100

@cebtenzzre

I tried the tutorials on my GTX 970, and didn't get very far. I'm testing on latest main (commit dd2d5f4).

03-matrix-multiplication.py, 06-fused-attention.py, and 08-experimental-block-pointer.py (duplicate lines omitted)

error: invalid element type in packLLEElements. Expected 'f32' but got 'f16'
error: 'llvm.intr.fmuladd' op requires the same type for all operands and results
Pass execution failedLLVM ERROR: Failed to translate TritonGPU to LLVM IR.

05-layer-norm.py

Traceback (most recent call last):
  File "/home/cebtenzzre/src/clones/triton/python/tutorials/05-layer-norm.py", line 367, in <module>
    test_layer_norm(1151, 8192, torch.float16)
  File "/home/cebtenzzre/src/clones/triton/python/tutorials/05-layer-norm.py", line 310, in test_layer_norm
    y_tri.backward(dy, retain_graph=True)
  File "/usr/lib/python3.11/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/usr/lib/python3.11/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/lib/python3.11/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/cebtenzzre/src/clones/triton/python/tutorials/05-layer-norm.py", line 281, in backward
    _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
  File "<string>", line 42, in _layer_norm_bwd_dx_fused
  File "/home/cebtenzzre/src/clones/triton/python/triton/compiler/compiler.py", line 465, in compile
    next_module = compile_kernel(module)
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cebtenzzre/src/clones/triton/python/triton/compiler/compiler.py", line 361, in <lambda>
    lambda src: ptx_to_cubin(src, arch))
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cebtenzzre/src/clones/triton/python/triton/compiler/compiler.py", line 160, in ptx_to_cubin
    return _triton.compile_ptx_to_cubin(ptx, ptxas, arch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Internal Triton PTX codegen error: 
ptxas /tmp/compile-ptx-src-b7492e, line 1370; error   : Feature 'scopes on atomic operations' requires .target sm_60 or higher
ptxas /tmp/compile-ptx-src-b7492e, line 1466; error   : Feature 'scopes on atomic operations' requires .target sm_60 or higher
ptxas fatal   : Ptx assembly aborted due to errors

@RuABraun

RuABraun commented Jul 6, 2023

Is there a nightly wheel available somewhere?

@mikegreen7892003

I modified the code as follows and it works.

# First store doesn't accumulate
if count == 0:
    tl.atomic_xchg(Count, 1)
else:
    # partial_dw += tl.load(DW, mask=mask)
    # partial_db += tl.load(DB, mask=mask)

# ignore the condition of count == 0 
partial_dw += tl.load(DW, mask=mask)
partial_db += tl.load(DB, mask=mask)

tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask

Maybe this condition triggers something.

@cebtenzzre

@mikegreen7892003 That will throw an IndentationError: you either need a 'pass' in the else block or need to comment out the else clause entirely. Also, you're missing a closing parenthesis.
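
The IndentationError described here can be reproduced without a GPU by compiling a reduced version of the snippet. In this sketch the Triton calls are replaced by plain placeholder statements (they are not the tutorial's code), and `compiles` is a hypothetical helper.

```python
# An else block that contains only comments has no body, so Python rejects it.
BROKEN = """
if count == 0:
    first_store = True
else:
    # partial_dw += load(DW)
    # partial_db += load(DB)

partial_dw = 0
"""

# Adding `pass` gives the else block a body.
FIXED = BROKEN.replace("else:\n", "else:\n    pass\n")


def compiles(source: str) -> bool:
    """Return True if `source` parses as valid Python (it is never executed)."""
    try:
        compile(source, "<snippet>", "exec")
        return True
    except SyntaxError:  # IndentationError is a subclass of SyntaxError
        return False


print(compiles(BROKEN))
print(compiles(FIXED))
```

Commenting out the `else:` line itself would work equally well, since the statements after it already run unconditionally.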

@ogrisel

ogrisel commented Dec 12, 2023

tried the tutorials on my GTX 970, and didn't get very far. I'm testing on latest main (commit dd2d5f4).

error: invalid element type in packLLEElements. Expected 'f32' but got 'f16'
error: 'llvm.intr.fmuladd' op requires the same type for all operands and results
Pass execution failedLLVM ERROR: Failed to translate TritonGPU to LLVM IR.

@cebtenzzre I believe this is because your GPU does not support operating on float16 inputs.

Try to edit the tutorial code to use float32 instead. In the matmul tutorial you will also have to edit the autotuning configs to reduce the num_stages values and probably the group sizes to not go above the maximum shared memory limit of the hardware.
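
To get a feel for why the autotuning configs may need shrinking, here is a back-of-the-envelope estimate of shared memory use per config. The formula (one A tile and one B tile held per pipeline stage) and the 48 KiB per-block limit are assumptions for illustration only: the compiler's real allocation may differ, and the actual limit should be checked for the device in question. The function name and the config tuples are hypothetical.

```python
def estimated_smem_bytes(block_m, block_n, block_k, num_stages, elem_bytes=4):
    """Rough estimate: num_stages copies of an A tile (M x K) and a B tile (K x N)."""
    return (block_m * block_k + block_k * block_n) * elem_bytes * num_stages


# Assumed per-block shared memory limit; not queried from any device.
ASSUMED_LIMIT_BYTES = 48 * 1024

# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages), illustrative values only.
configs = [(128, 256, 64, 3), (64, 64, 32, 4), (32, 64, 32, 2)]

for block_m, block_n, block_k, num_stages in configs:
    need = estimated_smem_bytes(block_m, block_n, block_k, num_stages)
    fits = need <= ASSUMED_LIMIT_BYTES
    print((block_m, block_n, block_k, num_stages), need, "fits" if fits else "too large")
```

Under these assumptions, with float32 elements (4 bytes) only the smallest of the three configs stays under the limit, which is consistent with the advice to lower `num_stages` and the tile sizes.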

Note for Triton developers: instead of crashing with a low-level error message for unsupported dtypes, it would be more user-friendly to raise a Python-level exception earlier with a higher-level error message.

At the moment I get on a GTX 1080 TI:

loc(fused["/home/ogrisel/code/triton-sandbox/matmul.py":72:23, "/home/ogrisel/code/triton-sandbox/matmul.py":72:33]): error: invalid element type in packLLEElements. Expected 'f32' but got 'f16'
loc(fused["/home/ogrisel/code/triton-sandbox/matmul.py":72:23, "/home/ogrisel/code/triton-sandbox/matmul.py":72:33]): error: invalid element type in packLLEElements. Expected 'f32' but got 'f16'
[...]  # repeated many times, then:
loc(fused["/home/ogrisel/code/triton-sandbox/matmul.py":72:23, "/home/ogrisel/code/triton-sandbox/matmul.py":72:33]): error: 'llvm.intr.fmuladd' op requires the same type for all operands and results
Pass execution failedLLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Aborted (core dumped)

I am not sure how to inspect which dtypes are supported by a given device, though. I had a look at https://pytorch.org/docs/stable/cuda.html, but the only thing I see would be to manually map the compute capability tuple to a list of supported dtypes.
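
A minimal sketch of such a manual mapping is shown below. The thresholds are assumptions, not an official Triton support table: float16 is gated on compute capability 7.0 (following the statement earlier in this thread that FP16 is not supported on pre-tensor-core GPUs), and bfloat16 on 8.0. The function name `supported_dtypes` is hypothetical.

```python
def supported_dtypes(capability):
    """Map a (major, minor) compute capability tuple to dtype names.

    The thresholds are assumptions for illustration, not an official table.
    """
    dtypes = ["float32"]
    if capability >= (7, 0):  # Volta and newer (tensor cores)
        dtypes.append("float16")
    if capability >= (8, 0):  # Ampere and newer
        dtypes.append("bfloat16")
    return dtypes


# A GTX 1080 Ti reports compute capability (6, 1).
print(supported_dtypes((6, 1)))  # ['float32']
print(supported_dtypes((7, 0)))  # ['float32', 'float16']
print(supported_dtypes((8, 0)))  # ['float32', 'float16', 'bfloat16']
```

With PyTorch installed, the tuple for the current device can be obtained from `torch.cuda.get_device_capability()` and passed straight to the function.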

@Ph0rk0z

Ph0rk0z commented Dec 14, 2023

Well pascal is unsupported. I mean why support a $200 24G card when everyone can buy $700 3090s or $3000 V100. 7b model should be enough for everyone :P

pingzhuu pushed a commit to siliconflow/triton that referenced this issue Apr 2, 2024
…#1505)

Related to triton-lang#1271 . I am currently working on adding support for
Pre-volta GPUs in Triton.

---------

Co-authored-by: Himanshu Pathak <himanshu@mtatva.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this issue Aug 5, 2024
…ointer (triton-lang#1272)

Addition of a possible pattern for MMA layout propagation when the
ConvertLayoutOp is inside the loop, the layout is retrieved from the
layout map instead of the ConvertLayoutOp.

Addresses Issue: triton-lang#1271

---------

Signed-off-by: Maxime France-Pillois <maxime.francepillois@codeplay.com>