
triton_autotuner: Rounding modifier required for instruction 'cvt' #15900

Open
KeremTurgutlu opened this issue May 7, 2023 · 11 comments
Labels
bug Something isn't working NVIDIA GPU Issues specific to NVIDIA GPUs


@KeremTurgutlu

KeremTurgutlu commented May 7, 2023

Description

Getting the following error when trying to run code on an A100 80GB on Google Cloud, using the Debian Deep Learning image (c0-deeplearning-common-cu113-v20230501-debian-10). This code is tested and works on TPU (using the t5x library). I don't know whether this error is related to my setup, but these are the steps I took after creating the instance and before running the code:

  1. Created a new conda environment with py3.9

  2. Installed the latest JAX with CUDA support (4.8.0 as of writing): pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

  3. Cloned the t5x library and installed an editable local version: -e git+https://github.com/google-research/t5x.git@2b010160e7fe8a4505a6d1032a7b737a633636e5#egg=t5x

  4. Installed an extra dependency: pip install t5

  5. Upgraded cuDNN to 8.6.0, since JAX complained that it requires at least that version, by manually downloading cudnn-linux-x86_64-8.6.0.163_cuda11-archive.tar.xz and then running the following:

$ sudo cp cudnn-*-archive/include/cudnn*.h /usr/local/cuda/include 
$ sudo cp -P cudnn-*-archive/lib/libcudnn* /usr/local/cuda/lib64 
$ sudo chmod a+r /usr/local/cuda/include/cudnn*.h /usr/local/cuda/lib64/libcudnn*
  6. Verified that the GPU is usable by JAX:
>>> import jax
>>> jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0]).device()
StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)
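To confirm which cuDNN headers actually ended up in /usr/local/cuda after the manual copy in step 5, one option is to read the version macros from the header. This is a minimal sketch assuming cuDNN 8.x, whose CUDNN_MAJOR, CUDNN_MINOR and CUDNN_PATCHLEVEL macros live in cudnn_version.h; the helper names are hypothetical. It only inspects the headers, not the libcudnn that JAX loads at runtime, so the two could still disagree.

```python
import re
from pathlib import Path


def cudnn_version_from_header(text):
    """Parse (major, minor, patch) from the contents of cudnn_version.h, or None."""
    parts = []
    for key in ("MAJOR", "MINOR", "PATCHLEVEL"):
        # Only match plain numeric defines, not e.g. the derived CUDNN_VERSION macro.
        match = re.search(rf"#define\s+CUDNN_{key}\s+(\d+)", text)
        if match is None:
            return None
        parts.append(int(match.group(1)))
    return tuple(parts)


def installed_cudnn_version(include_dir="/usr/local/cuda/include"):
    """Read the cuDNN version from headers in include_dir (assumed install location)."""
    header = Path(include_dir) / "cudnn_version.h"
    if not header.is_file():
        return None
    return cudnn_version_from_header(header.read_text())


if __name__ == "__main__":
    version = installed_cudnn_version()
    print("cuDNN headers:", version)
    if version is not None and version < (8, 6, 0):
        print("older than the 8.6 that JAX asked for")
```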

The following is the error I get when running a t5x pretraining script using train.py.

Traceback (most recent call last):
  File "/home/keremturgutlu/t5x/t5x/train.py", line 835, in <module>
    config_utils.run(main)
  File "/home/keremturgutlu/t5x/t5x/config_utils.py", line 214, in run
    gin_utils.run(main)
  File "/home/keremturgutlu/t5x/t5x/gin_utils.py", line 129, in run
    app.run(
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/keremturgutlu/t5x/t5x/train.py", line 788, in main
    _main(argv)
  File "/home/keremturgutlu/t5x/t5x/train.py", line 830, in _main
    train_using_gin()
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/keremturgutlu/t5x/t5x/train.py", line 614, in train
    trainer.compile_train(dummy_batch)
  File "/home/keremturgutlu/t5x/t5x/trainer.py", line 538, in compile_train
    self._compiled_train_step = self._partitioner.compile(
  File "/home/keremturgutlu/t5x/t5x/partitioning.py", line 805, in compile
    return partitioned_fn.lower(*args).compile()
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/stages.py", line 600, in compile
    self._lowering.compile(**kw),
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 234; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 238; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 242; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 246; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 250; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 254; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 258; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 262; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 266; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 270; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 274; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 278; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 282; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 286; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 290; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 294; error   : Rounding modifier required for instruction 'cvt'
ptxas fatal   : Ptx assembly aborted due to errors
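The ptxas output above repeats one message for many PTX lines. A small, hypothetical helper like the following collapses such a log into one entry per distinct message, which makes it easy to see that every failure is the same `cvt` complaint. It assumes the "ptxas <file>, line <n>; error : <message>" layout shown above and ignores the final "ptxas fatal" line.

```python
import re
from collections import defaultdict

# Matches e.g. "ptxas /var/tmp/tempfile-gpu-..., line 234; error   : Rounding modifier ..."
PTXAS_ERROR = re.compile(
    r"ptxas (?P<file>\S+), line (?P<line>\d+); error\s*:\s*(?P<msg>.+)"
)


def summarize_ptxas_errors(log):
    """Group ptxas error lines by message -> sorted list of PTX line numbers."""
    by_message = defaultdict(list)
    for raw in log.splitlines():
        # search() rather than match(): the first error is embedded in the XlaRuntimeError line.
        match = PTXAS_ERROR.search(raw)
        if match:
            by_message[match.group("msg").strip()].append(int(match.group("line")))
    return {msg: sorted(lines) for msg, lines in by_message.items()}
```

Applied to the full log above, this should produce a single message covering PTX lines 234 through 294.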

What jax/jaxlib version are you using?

jax==0.4.7 jaxlib==0.4.7+cuda11.cudnn86
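When answering this question, it can help to read the versions from installed package metadata rather than from memory. A minimal standard-library sketch follows; the package list is an assumption for this particular setup, and the helper names are hypothetical. The local version suffix of jaxlib (for example +cuda11.cudnn86 above) indicates which CUDA/cuDNN build of the wheel was installed.

```python
from importlib import metadata

# Assumption: the distributions relevant to this report.
PACKAGES = ("jax", "jaxlib", "t5x", "t5")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version string, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def build_tag(version):
    """Return the local version label, e.g. 'cuda11.cudnn86' for '0.4.7+cuda11.cudnn86'."""
    if version is None or "+" not in version:
        return None
    return version.split("+", 1)[1]


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}  build: {build_tag(version)}")
```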

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.9.16 | packaged by conda-forge | (main, Feb 1 2023, 21:39:03) [GCC 11.3.0] on linux

NVIDIA GPU info

Sun May  7 01:53:18 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:05.0 Off |                    0 |
| N/A   34C    P0    82W / 400W |      0MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_May__3_19:15:13_PDT_2021
Cuda compilation tools, release 11.3, V11.3.109
Build cuda_11.3.r11.3/compiler.29920130_0
@cheshire
Member

cheshire commented May 7, 2023

@KeremTurgutlu How do we know that it originates in the triton_autotuner?

@cheshire
Member

cheshire commented May 7, 2023

@hawkinsp Is it hard to propagate the C++ stack trace along with the error?

@mjsML mjsML added the NVIDIA GPU Issues specific to NVIDIA GPUs label May 7, 2023
@KeremTurgutlu
Author

KeremTurgutlu commented May 7, 2023

@KeremTurgutlu How do we know that it originates in the triton_autotuner?

I am not 100% sure that's the root cause, but I probably should have pasted this as well:

[triton_autotuner.cc:271] failure: internal: ptxas exited with non-zero error code 65280, output: ptxas /var/tmp/tempfile-gpu-3b3b9d27-29193-5fb10408aefa4, line 234; error : rounding modifier required for instruction 'cvt'

I was able to run the code successfully after installing the NVIDIA driver, CUDA (12.1), cuDNN, and JAX from scratch:

  1. Launched an A100 in Google Cloud with the base Ubuntu 18.04 image.

  2. Installed the latest NVIDIA driver with CUDA 12.1.

  3. Installed Miniconda and created a conda env.

  4. Installed JAX and cuDNN:


# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  5. Installed t5x from source and installed t5.
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB           On | 00000000:00:05.0 Off |                    0 |
| N/A   40C    P0               60W / 400W|  73901MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      9080      C   python3                                   73890MiB |
+---------------------------------------------------------------------------------------+

nvcc not installed.

@cheshire
Member

cheshire commented May 8, 2023

Oh, this is CUDA 12; that would probably explain it. We have had some bugs filed on CUDA 12 before.

@hawkinsp
Member

hawkinsp commented May 8, 2023

I think this is actually a case of too old a CUDA installation, not the other way around. The image is named c0-deeplearning-common-cu113-v20230501-debian-10: note "cu113".

JAX is built for CUDA 11.8 (or CUDA 12), and if I recall correctly, Ampere GPU support wasn't added until well after 11.3.

Can you update to CUDA 11.8 or newer?

Note that nvidia-smi reports the CUDA version supported by the driver, not the version of the installed libraries. Both need to be sufficiently new.
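Because the driver and the toolkit report separate CUDA versions, both are worth checking. The sketch below parses the two kinds of strings shown earlier in this thread (the nvidia-smi header and nvcc --version) and flags anything older than 11.8, the minimum suggested above. The function names are hypothetical, and collecting the text (for example via subprocess) is left out.

```python
import re

# Assumption: JAX wheels target CUDA 11.8 or newer, per the comment above.
MIN_CUDA = (11, 8)


def _major_minor(pattern, text):
    match = re.search(pattern, text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def driver_cuda(nvidia_smi_output):
    """Highest CUDA version the *driver* supports, from the nvidia-smi header."""
    return _major_minor(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_output)


def toolkit_cuda(version_output):
    """CUDA release of a toolkit binary, from `nvcc --version` or `ptxas --version`."""
    return _major_minor(r"release (\d+)\.(\d+)", version_output)


def cuda_problems(nvidia_smi_output, nvcc_output, minimum=MIN_CUDA):
    """List human-readable problems; an empty list means both look new enough."""
    problems = []
    checks = (
        ("driver", driver_cuda(nvidia_smi_output)),
        ("toolkit", toolkit_cuda(nvcc_output)),
    )
    for label, found in checks:
        if found is None:
            problems.append(f"{label}: version not found")
        elif found < minimum:
            problems.append(
                f"{label}: CUDA {found[0]}.{found[1]} < {minimum[0]}.{minimum[1]}"
            )
    return problems
```

With the outputs from the original report (driver CUDA 11.6, nvcc release 11.3) both entries are flagged; in the later setup where nvcc is not installed, the toolkit entry is reported as not found.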

@KeremTurgutlu
Author

Sorry if it was not clear, but what I wanted to mention was that the issue was fixed when I installed CUDA 12 from scratch instead of using the Google Cloud image.

@KeremTurgutlu
Author

KeremTurgutlu commented May 10, 2023

I recently got this error (it might be related to the 0.4.9 release; looking into it):

(myenv) keremturgutlu@gpu:~$ python -c "import jax; print(jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0]).device())"
2023-05-10 07:12:19.701059: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2092, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype))
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1190, in full
    return broadcast(fill_value, shape)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 756, in broadcast
    return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 784, in broadcast_in_dim
    return broadcast_in_dim_p.bind(
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Edit: I tried again on a newly created instance and wasn't able to reproduce the error.

@euclaise

euclaise commented May 18, 2023

Getting a similar error when trying to run a custom model on an A6000, after installing with pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html on top of Runpod's PyTorch or TensorFlow images. I tried both CUDA 11 and CUDA 12; same issue.

https://pastebin.com/raw/MUsYZje8

Update: it seems to happen only with bfloat16. float32 works fine.
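For background on the error text: narrowing float32 to bfloat16 discards the low 16 mantissa bits, so the conversion has to choose a rounding mode, which PTX expresses through cvt rounding modifiers such as .rn (round to nearest, ties to even). Widening bfloat16 back to float32 is exact. The pure-Python emulation below only illustrates that difference; it is not how XLA or ptxas implement the conversion, and NaN inputs are not handled.

```python
import struct


def f32_bits(x):
    """Bit pattern of x as an IEEE float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def f32_to_bf16(x):
    """float32 -> bfloat16 bits, round to nearest, ties to even (NaN not handled)."""
    bits = f32_bits(x)
    lsb = (bits >> 16) & 1  # lowest bit that survives the truncation
    return ((bits + 0x7FFF + lsb) >> 16) & 0xFFFF


def bf16_to_f32(h):
    """bfloat16 bits -> float32; exact, since it only appends 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", h << 16))[0]


# Truncating instead of rounding gives a different answer for some inputs,
# which is why the narrowing conversion needs an explicit rounding mode.
x = 1.01171875  # 1 + 3 * 2**-8, exactly halfway between two bfloat16 values
print(bf16_to_f32(f32_to_bf16(x)))     # round to nearest even: 1.015625
print(bf16_to_f32(f32_bits(x) >> 16))  # plain truncation: 1.0078125
```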

@hawkinsp
Member

@euclaise Do you have another copy of ptxas installed? Is there one in your PATH? My strong suspicion is still "we are finding an ancient ptxas".
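To answer the "is there one in your PATH?" question, it helps to list every ptxas on PATH in search order, not just the first. This is a sketch assuming a Linux-style executable named ptxas, and the helper name is hypothetical. Running ptxas --version on each hit then shows which CUDA release it comes from. Whether XLA actually uses the PATH copy or another one is not something this sketch determines.

```python
import os
from pathlib import Path


def find_all_on_path(name="ptxas", path_env=None):
    """Every executable called `name` on PATH, in search order, without duplicates."""
    if path_env is None:
        path_env = os.environ.get("PATH", "")
    hits = []
    for directory in path_env.split(os.pathsep):
        if not directory:
            continue  # skip empty entries (POSIX treats them as the current directory)
        candidate = Path(directory) / name
        if candidate.is_file() and os.access(candidate, os.X_OK) and str(candidate) not in hits:
            hits.append(str(candidate))
    return hits


if __name__ == "__main__":
    for path in find_all_on_path():
        print(path)  # run `<path> --version` on each to see its CUDA release
```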

@euclaise

euclaise commented May 24, 2023

@hawkinsp I don't, but it should be whatever is used in this image: https://hub.docker.com/r/runpod/pytorch/

@euclaise

euclaise commented May 29, 2023

@hawkinsp

ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_Mar__8_18:17:32_PST_2022
Cuda compilation tools, release 11.6, V11.6.124
Build cuda_11.6.r11.6/compiler.31057947_0

After some testing, it appears to be caused by my accidentally mixing bfloat16 values with float32 ones. It seems a check for that is missing somewhere before assembly.
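If the trigger really is accidental mixing of bfloat16 and float32, one defensive option is to check and normalize the floating-point dtypes of a parameter pytree before compiling. JAX promotes a bfloat16/float32 combination to float32, so a single stray float32 leaf silently introduces conversions. The helpers below are hypothetical, not part of t5x or JAX, and casting everything to bfloat16 is only one choice; some training setups deliberately keep certain values in float32.

```python
import jax
import jax.numpy as jnp


def floating_dtypes(tree):
    """Set of floating-point dtype names among the leaves of a pytree."""
    return {
        str(leaf.dtype)
        for leaf in jax.tree_util.tree_leaves(tree)
        if jnp.issubdtype(leaf.dtype, jnp.floating)
    }


def cast_floating(tree, dtype=jnp.bfloat16):
    """Cast every floating-point leaf to `dtype`; integer and bool leaves are left alone."""
    return jax.tree_util.tree_map(
        lambda leaf: leaf.astype(dtype) if jnp.issubdtype(leaf.dtype, jnp.floating) else leaf,
        tree,
    )


# Hypothetical parameters with one float32 leaf mixed in among bfloat16 ones.
params = {
    "w": jnp.ones((2, 2), jnp.float32),
    "b": jnp.zeros((2,), jnp.bfloat16),
    "step": jnp.array(0, jnp.int32),
}
print(floating_dtypes(params))                 # mixed: bfloat16 and float32
print(floating_dtypes(cast_floating(params)))  # only bfloat16 after casting
```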


5 participants