After building from source, error occurred: "device kernel image is invalid" #1955

Closed
akakakakakaa opened this issue Jul 18, 2023 · 6 comments

Comments

@akakakakakaa

akakakakakaa commented Jul 18, 2023

I tried this on a V100 with CUDA 11.7.

After digging through the source code, I found that it works when I change the CUDA version in this line:
https://github.com/openai/triton/blob/9e3e10c5edb4a062cf547ae73e6ebfb19aad7bdf/python/setup.py#L129

So, when I want to install Triton from source, do I need to control the CUDA version by editing setup.py?
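
For reference, one way to see which ptxas an installed Triton build actually bundles, compared with the system CUDA toolkit, is something like the sketch below. The third_party/cuda/bin location inside the installed package is an assumption for builds from around this time and may differ between Triton versions.

# Rough check of the ptxas bundled with an installed Triton build (the
# third_party/cuda/bin layout is an assumption; adjust it if your version
# stores the CUDA binaries elsewhere).
import os
import subprocess

import triton

pkg_dir = os.path.dirname(triton.__file__)
bundled_ptxas = os.path.join(pkg_dir, "third_party", "cuda", "bin", "ptxas")

for name, path in [("bundled", bundled_ptxas), ("system", "/usr/local/cuda/bin/ptxas")]:
    if os.path.exists(path):
        out = subprocess.run([path, "--version"], capture_output=True, text=True)
        print(f"{name}: {path}\n{out.stdout.strip()}\n")
    else:
        print(f"{name}: {path} not found")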

@bhack

bhack commented Jul 29, 2023

@akakakakakaa Did you get this error at runtime? Do you have a small code gist to reproduce it?

@akakakakakaa

akakakakakaa commented Jul 31, 2023

@bhack I tried to install exactly as written in the README, but used pip install . instead of pip install -e ., because in my case pip install -e . can't recognize the hidden directory.

git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build-time dependency
pip install .

After installing, I tried to run 06-fused-attention.py and at runtime I got the error "device kernel image is invalid".
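
As for a small gist: the error does not seem specific to the fused-attention tutorial; any kernel that goes through Triton's compile-and-launch path should hit it on an affected install. A minimal reproducer along the lines of the 01-vector-add tutorial (a sketch, not the exact tutorial code):

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the vectors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


n = 4096
x = torch.rand(n, device="cuda")
y = torch.rand(n, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(n, 256),)
# On an affected install this launch fails with "device kernel image is invalid".
add_kernel[grid](x, y, out, n, BLOCK_SIZE=256)
print(torch.allclose(out, x + y))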

@huangxiao2008

I met the same problem.

@zy-fang

zy-fang commented Jan 19, 2024

I have encountered the same problem. How did you solve it?

@xingjinglu

xingjinglu commented Feb 6, 2024

I have encountered the same problem and solved it.
The reason: on the main branch of Triton, the default versions of ptxas, cuobjdump, and nvdisasm that Triton uses are from CUDA 12.x (this is set in triton/python/setup.py). So when you build Triton for CUDA 11.x, you need to point it at the right CUDA binaries by setting their paths in the environment.

My environment is:

  1. Driver Version: 470.141.03 CUDA Version: 11.4
  2. torch: conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia

Build Triton from source as below:

export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas                                                                      
export TRITON_CUOBJDUMP_PATH=/usr/local/cuda/bin/cuobjdump                                                              
export TRITON_NVDISASM_PATH=/usr/local/cuda/bin/nvdisasm  

cd triton/python
pip install -e .

Test:
python tutorials/01-vector-add.py

The result is as below:

tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
The maximum difference between torch and triton is 0.0
vector-add-performance:
size Triton Torch
0 4096.0 11.377778 11.130435
1 8192.0 21.787235 23.813955
2 16384.0 44.521738 41.795915
3 32768.0 73.142858 72.710056
4 65536.0 127.336788 127.336788
5 131072.0 199.399583 200.620406
6 262144.0 283.296835 285.767442
7 524288.0 381.023277 371.659727
8 1048576.0 412.608613 416.101597
9 2097152.0 444.311871 449.646643
10 4194304.0 463.766462 468.393097
11 8388608.0 472.615390 479.385543
12 16777216.0 477.602370 484.554523
13 33554432.0 478.037844 484.414634
14 67108864.0 479.979873 488.623552
15 134217728.0 479.870017 489.126924

This is a short summary of how to build Triton from source.
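
One extra sanity check (a small sketch, assuming the TRITON_PTXAS_PATH variable above is exported in the same shell): confirm that the ptxas you are pointing Triton at really reports a CUDA 11.x version before rebuilding.

# Sanity check: print the version of the ptxas that TRITON_PTXAS_PATH points to.
# Assumes the environment variables from the build steps above are already exported.
import os
import subprocess

ptxas = os.environ.get("TRITON_PTXAS_PATH")
if ptxas is None:
    print("TRITON_PTXAS_PATH is not set; Triton will fall back to its bundled ptxas")
else:
    result = subprocess.run([ptxas, "--version"], capture_output=True, text=True)
    print(result.stdout.strip())  # should report a CUDA 11.x release on a CUDA 11 driver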

@sujuyu

sujuyu commented May 14, 2024


@xingjinglu Your suggestion worked for me, thank you very much.
