
Cannot use GPU on Ubuntu 16.04, CUDA 11.0 #6046

Closed
holl- opened this issue Mar 12, 2021 · 15 comments
Labels
bug · needs info · NVIDIA GPU · P2 (eventual)

Comments


holl- commented Mar 12, 2021

I have a GeForce RTX 3090 with CUDA 11.0 installed on Ubuntu 16.04 and the installation works fine with TensorFlow.
The path /usr/local/cuda points to that installation.

I installed Jax into my Python 3.8.6 conda environment by running

pip3 install --upgrade jax jaxlib==0.1.62+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

I can import Jax from Python but the first operation throws an error.

from jax import numpy
numpy.zeros(4)
2021-03-12 21:26:30.353284: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:191] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2021-03-12 21:26:30.353307: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:194] Used ptxas at /usr/local/cuda-11.0/bin/ptxas
2021-03-12 21:26:30.353808: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:682] failed to get PTX kernel "broadcast_2" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2021-03-12 21:26:30.353849: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1881] Execution of replica 0 failed: Internal: Could not find the corresponding function
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1181, in __repr__
    s = np.array2string(self._value, prefix=prefix, suffix=',',
  File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1122, in _value
    self._npy_value = _force(self).device_buffer.to_py()
  File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1333, in _force
    result = force_fun(x)
  File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1357, in force_fun
    return compiled.execute([x.device_buffer])[0]
RuntimeError: Internal: Could not find the corresponding function

Running nvcc --version prints

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Thu_Jun_11_22:26:38_PDT_2020
Cuda compilation tools, release 11.0, V11.0.194
Build cuda_11.0_bu.TC445_37.28540450_0

Is this a bug or am I doing something wrong?

holl- added the bug label on Mar 12, 2021

holl- commented Mar 16, 2021

Not sure if it's related, but building from source fails with the following error:

ERROR: /home/holl/.cache/bazel/_bazel_holl/0b9919be22653e3173d0276d146dc0c1/external/org_tensorflow/tensorflow/stream_executor/gpu/BUILD:226:11: C++ compilation of rule '@org_tensorflow//tensorflow/stream_executor/gpu:asm_compiler' failed (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command
  (cd /home/holl/.cache/bazel/_bazel_holl/0b9919be22653e3173d0276d146dc0c1/execroot/__main__ && \
  exec env - \
    PATH=/home/holl/bin:/home/holl/.local/bin:/home/holl/miniconda3/envs/jax/bin:/home/holl/miniconda3/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/usr/local/cuda/bin \
    PWD=/proc/self/cwd \
    TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 \
    TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/stream_executor/gpu/_objs/asm_compiler/asm_compiler.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/stream_executor/gpu/_objs/asm_compiler/asm_compiler.pic.o' -DTF_USE_SNAPPY -DEIGEN_MPL2_ONLY '-DEIGEN_MAX_ALIGN_BYTES=64' -DHAVE_SYS_UIO_H -D__CLANG_SUPPORT_DYN_ANNOTATION__ -iquote external/org_tensorflow -iquote bazel-out/k8-opt/bin/external/org_tensorflow -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/nsync -iquote bazel-out/k8-opt/bin/external/nsync -iquote external/eigen_archive -iquote bazel-out/k8-opt/bin/external/eigen_archive -iquote external/gif -iquote bazel-out/k8-opt/bin/external/gif -iquote external/libjpeg_turbo -iquote bazel-out/k8-opt/bin/external/libjpeg_turbo -iquote external/com_google_protobuf -iquote bazel-out/k8-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/k8-opt/bin/external/zlib -iquote external/com_googlesource_code_re2 -iquote bazel-out/k8-opt/bin/external/com_googlesource_code_re2 -iquote external/farmhash_archive -iquote bazel-out/k8-opt/bin/external/farmhash_archive -iquote external/fft2d -iquote bazel-out/k8-opt/bin/external/fft2d -iquote external/highwayhash -iquote bazel-out/k8-opt/bin/external/highwayhash -iquote external/double_conversion -iquote bazel-out/k8-opt/bin/external/double_conversion -iquote external/snappy -iquote bazel-out/k8-opt/bin/external/snappy -iquote external/local_config_cuda -iquote bazel-out/k8-opt/bin/external/local_config_cuda -iquote external/local_config_tensorrt -iquote bazel-out/k8-opt/bin/external/local_config_tensorrt -Ibazel-out/k8-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual -Ibazel-out/k8-opt/bin/external/local_config_tensorrt/_virtual_includes/tensorrt_headers -isystem external/nsync/public -isystem bazel-out/k8-opt/bin/external/nsync/public -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/gif -isystem bazel-out/k8-opt/bin/external/gif -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/farmhash_archive/src -isystem bazel-out/k8-opt/bin/external/farmhash_archive/src -isystem external/double_conversion -isystem bazel-out/k8-opt/bin/external/double_conversion -isystem external/local_config_cuda/cuda -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda -isystem external/local_config_cuda/cuda/cuda/include -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda/cuda/include -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -fno-canonical-system-headers -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections -Wno-sign-compare -Wno-stringop-truncation -mavx '-std=c++14' -DEIGEN_AVOID_STL_ARRAY -Iexternal/gemmlowp -Wno-sign-compare '-ftemplate-depth=900' -fno-exceptions '-DGOOGLE_CUDA=1' '-DTENSORFLOW_USE_NVCC=1' -msse3 -DTENSORFLOW_MONOLITHIC_BUILD -pthread -c external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc -o bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/stream_executor/gpu/_objs/asm_compiler/asm_compiler.pic.o)
Execution platform: @local_execution_config_platform//:platform
external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc: In function ‘void stream_executor::LogPtxasTooOld(const string&, int, int)’:
external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:190:62: error: converting to ‘absl::lts_2020_02_25::container_internal::raw_hash_set<absl::lts_2020_02_25::container_internal::FlatHashSetPolicy<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >, absl::lts_2020_02_25::hash_internal::Hash<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >, std::equal_to<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >, std::allocator<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> > >::init_type {aka std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int>}’ from initializer list would use explicit constructor ‘constexpr std::tuple< <template-parameter-1-1> >::tuple(_UElements&& ...) [with _UElements = {const std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&, int&, int&}; <template-parameter-2-2> = void; _Elements = {std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int}]’
   if (already_logged->insert({ptxas_path, cc_major, cc_minor}).second) {
                                                              ^
At global scope:
cc1plus: warning: unrecognized command line option ‘-Wno-stringop-truncation’
Target //build:build_wheel failed to build
INFO: Elapsed time: 1298.922s, Critical Path: 60.78s
INFO: 2938 processes: 806 internal, 2132 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully
Traceback (most recent call last):
  File "build/build.py", line 521, in <module>
    main()
  File "build/build.py", line 516, in main
    shell(command)
  File "build/build.py", line 51, in shell
    output = subprocess.check_output(cmd)
  File "/home/holl/miniconda3/envs/jax/lib/python3.8/subprocess.py", line 415, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/home/holl/miniconda3/envs/jax/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-3.7.2-linux-x86_64', 'run', '--verbose_failures=true', '--config=short_logs', '--config=avx_posix', '--config=mkl_open_source_only', '--config=cuda', '--define=xla_python_enable_gpu=true', ':build_wheel', '--', '--output_path=/home/holl/jax/jax/dist']' returned non-zero exit status 1.


danieljtait commented Mar 29, 2021

Getting the same issue on an RTX 30 Series card with jaxlib==0.1.64+cuda110; nvcc --version prints

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Wed_Jul_22_19:09:09_PDT_2020
Cuda compilation tools, release 11.0, V11.0.221
Build cuda_11.0_bu.TC445_37.28845127_0

I was able to build from source without error, but I get the same set of warnings and then the same errors when actually trying to do anything:

2021-03-29 11:23:18.595649: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:191] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2021-03-29 11:23:18.595690: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:194] Used ptxas at /usr/local/cuda/bin/ptxas
2021-03-29 11:23:18.595754: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:297] Couldn't read CUDA driver version.
2021-03-29 11:23:18.711470: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:682] failed to get PTX kernel "broadcast_2" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2021-03-29 11:23:18.711527: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1960] Execution of replica 0 failed: Internal: Could not find the corresponding function

I should add that I was also able to install and use TensorFlow without issue.

@danieljtait

However, after upgrading to CUDA Toolkit 11.2 and using jaxlib==0.1.64+cuda112, everything seems to be working 🤷‍♂️
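
Presumably that means installing the matching wheel, i.e. the install command from the original report with only the CUDA suffix changed:

pip3 install --upgrade jax jaxlib==0.1.64+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html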


holl- commented Mar 29, 2021

@danieljtait Very interesting. Unfortunately I'm stuck with CUDA 11.0 for the moment because I use TensorFlow as well. But good to know there is a potential workaround.


njzjz commented Jul 22, 2021

> Not sure if it's related, but building from source fails with the following error:
> […]

I had the same error when I built TensorFlow. Have you resolved this problem? Thank you!


holl- commented Jul 22, 2021

Unfortunately not with CUDA 11.0. If there is a fix, I would be interested, too.


njzjz commented Jul 23, 2021

It should be caused by a non-empty braced-init-list. Using std::make_tuple instead works for me. Here's the patch:

From d1f3e960bba01a3c50731e6165aabf3ab277f5cf Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Date: Thu, 22 Jul 2021 14:17:37 -0400
Subject: [PATCH 1/2] use std::make_tuple

---
 .../profiler/internal/gpu/cupti_tracer.cc     | 56 +++++++++----------
 .../stream_executor/gpu/asm_compiler.cc       |  2 +-
 2 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
index df9c3b7efd7..93a04c0fd2b 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
+++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
@@ -159,70 +159,70 @@ DecodeDriverMemcpy(CUpti_CallbackId cbid, const void *params) {
   switch (cbid) {
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2: {
       const auto *p = reinterpret_cast<const cuMemcpyHtoD_v2_params *>(params);
-      return {p->ByteCount, CuptiTracerEventType::MemcpyH2D, false};
+      return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyH2D, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2: {
       const auto *p =
           reinterpret_cast<const cuMemcpyHtoDAsync_v2_params *>(params);
-      return {p->ByteCount, CuptiTracerEventType::MemcpyH2D, true};
+      return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyH2D, true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2: {
       const auto *p = reinterpret_cast<const cuMemcpyDtoH_v2_params *>(params);
-      return {p->ByteCount, CuptiTracerEventType::MemcpyD2H, false};
+      return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyD2H, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2: {
       const auto *p =
           reinterpret_cast<const cuMemcpyDtoHAsync_v2_params *>(params);
-      return {p->ByteCount, CuptiTracerEventType::MemcpyD2H, true};
+      return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyD2H, true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2: {
       const auto *p = reinterpret_cast<const cuMemcpyDtoD_v2_params *>(params);
-      return {p->ByteCount, CuptiTracerEventType::MemcpyD2D, false};
+      return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyD2D, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2: {
       const auto *p =
           reinterpret_cast<const cuMemcpyDtoDAsync_v2_params *>(params);
-      return {p->ByteCount, CuptiTracerEventType::MemcpyD2D, true};
+      return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyD2D, true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: {
       const auto *p = reinterpret_cast<const cuMemcpy_params *>(params);
-      return {p->ByteCount, CuptiTracerEventType::MemcpyOther, false};
+      return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyOther, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: {
       const auto *p = reinterpret_cast<const cuMemcpyAsync_params *>(params);
-      return {p->ByteCount, CuptiTracerEventType::MemcpyOther, true};
+      return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyOther, true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpy2D_v2: {
       const auto *p = reinterpret_cast<const cuMemcpy2D_v2_params *>(params);
-      return {Bytes2D(p->pCopy), MemcpyKind(p->pCopy), false};
+      return std::make_tuple(Bytes2D(p->pCopy), MemcpyKind(p->pCopy), false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpy2DAsync_v2: {
       const auto *p =
           reinterpret_cast<const cuMemcpy2DAsync_v2_params *>(params);
-      return {Bytes2D(p->pCopy), MemcpyKind(p->pCopy), true};
+      return std::make_tuple(Bytes2D(p->pCopy), MemcpyKind(p->pCopy), true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpy3D_v2: {
       const auto *p = reinterpret_cast<const cuMemcpy3D_v2_params *>(params);
-      return {Bytes3D(p->pCopy), MemcpyKind(p->pCopy), true};
+      return std::make_tuple(Bytes3D(p->pCopy), MemcpyKind(p->pCopy), true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpy3DAsync_v2: {
       const auto *p =
           reinterpret_cast<const cuMemcpy3DAsync_v2_params *>(params);
-      return {Bytes3D(p->pCopy), MemcpyKind(p->pCopy), true};
+      return std::make_tuple(Bytes3D(p->pCopy), MemcpyKind(p->pCopy), true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpyPeer: {
       const auto *p2p_params =
           reinterpret_cast<const cuMemcpyPeer_params *>(params);
-      return {p2p_params->ByteCount, CuptiTracerEventType::MemcpyP2P, false};
+      return std::make_tuple(p2p_params->ByteCount, CuptiTracerEventType::MemcpyP2P, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemcpyPeerAsync: {
       const auto *p2p_params =
           reinterpret_cast<const cuMemcpyPeerAsync_params *>(params);
-      return {p2p_params->ByteCount, CuptiTracerEventType::MemcpyP2P, true};
+      return std::make_tuple(p2p_params->ByteCount, CuptiTracerEventType::MemcpyP2P, true);
     }
     default: {
       LOG(ERROR) << "Unsupported memcpy activity observed: " << cbid;
-      return {0, CuptiTracerEventType::Unsupported, false};
+      return std::make_tuple(0, CuptiTracerEventType::Unsupported, false);
     }
   }
 }
@@ -232,58 +232,58 @@ DecodeDriverMemset(CUpti_CallbackId cbid, const void *params) {
   switch (cbid) {
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD8_v2: {
       const auto *p = reinterpret_cast<const cuMemsetD8_v2_params *>(params);
-      return {p->N, CuptiTracerEventType::Memset, false};
+      return std::make_tuple(p->N, CuptiTracerEventType::Memset, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD16_v2: {
       const auto *p = reinterpret_cast<const cuMemsetD16_v2_params *>(params);
-      return {p->N, CuptiTracerEventType::Memset, false};
+      return std::make_tuple(p->N, CuptiTracerEventType::Memset, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD32_v2: {
       const auto *p = reinterpret_cast<const cuMemsetD32_v2_params *>(params);
-      return {p->N, CuptiTracerEventType::Memset, false};
+      return std::make_tuple(p->N, CuptiTracerEventType::Memset, false);
     }
    case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D8_v2: {
      const auto *p = reinterpret_cast<const cuMemsetD2D8_v2_params *>(params);
-      return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, false};
+      return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D16_v2: {
       const auto *p = reinterpret_cast<const cuMemsetD2D16_v2_params *>(params);
-      return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, false};
+      return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D32_v2: {
       const auto *p = reinterpret_cast<const cuMemsetD2D32_v2_params *>(params);
-      return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, false};
+      return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, false);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD8Async: {
       const auto *p = reinterpret_cast<const cuMemsetD8Async_params *>(params);
-      return {p->N, CuptiTracerEventType::Memset, true};
+      return std::make_tuple(p->N, CuptiTracerEventType::Memset, true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD16Async: {
       const auto *p = reinterpret_cast<const cuMemsetD16Async_params *>(params);
-      return {p->N, CuptiTracerEventType::Memset, true};
+      return std::make_tuple(p->N, CuptiTracerEventType::Memset, true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD32Async: {
       const auto *p = reinterpret_cast<const cuMemsetD32Async_params *>(params);
-      return {p->N, CuptiTracerEventType::Memset, true};
+      return std::make_tuple(p->N, CuptiTracerEventType::Memset, true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D8Async: {
       const auto *p =
           reinterpret_cast<const cuMemsetD2D8Async_params *>(params);
-      return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, true};
+      return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, true);
     }
     case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D16Async: {
       const auto *p =
           reinterpret_cast<const cuMemsetD2D16Async_params *>(params);
-      return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, true};
+      return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, true);
    }
    case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D32Async: {
       const auto *p =
           reinterpret_cast<const cuMemsetD2D32Async_params *>(params);
-      return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, true};
+      return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, true);
     }
     default: {
       LOG(ERROR) << "Unsupported memset activity observed: " << cbid;
-      return {0, CuptiTracerEventType::Unsupported, false};
+      return std::make_tuple(0, CuptiTracerEventType::Unsupported, false);
     }
   }
 }
diff --git a/tensorflow/stream_executor/gpu/asm_compiler.cc b/tensorflow/stream_executor/gpu/asm_compiler.cc
index 0ade37f4c32..66ca927dd8f 100644
--- a/tensorflow/stream_executor/gpu/asm_compiler.cc
+++ b/tensorflow/stream_executor/gpu/asm_compiler.cc
@@ -187,7 +187,7 @@ static void LogPtxasTooOld(const std::string& ptxas_path, int cc_major,

   absl::MutexLock lock(mutex);

-  if (already_logged->insert({ptxas_path, cc_major, cc_minor}).second) {
+  if (already_logged->insert(std::make_tuple(ptxas_path, cc_major, cc_minor)).second) {
     LOG(WARNING) << "Falling back to the CUDA driver for PTX compilation; "
                     "ptxas does not support CC "
                  << cc_major << "." << cc_minor;
--
2.31.1

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Jul 27, 2021
Non-empty braced-init-lists cause some errors when I build from source.
I don't know the exact reason. It may be related to my build tools, but
the safest way is not to use non-empty braced-init-lists.
The error is:
```
external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc: In function ‘void stream_executor::LogPtxasTooOld(const string&, int, int)’:
external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:190:62: error: converting to ‘absl::lts_2020_02_25::container_internal::raw_hash_set<absl::lts_2020_02_25::container_internal::FlatHashSetPolicy<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >, absl::lts_2020_02_25::hash_internal::Hash<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >, std::equal_to<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >, std::allocator<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> > >::init_type {aka std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int>}’ from initializer list would use explicit constructor ‘constexpr std::tuple< <template-parameter-1-1> >::tuple(_UElements&& ...) [with _UElements = {const std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&, int&, int&}; <template-parameter-2-2> = void; _Elements = {std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int}]’
   if (already_logged->insert({ptxas_path, cc_major, cc_minor}).second) {
```
google/jax#6046 (comment)
reported the same error.
Using `std::make_tuple` instead makes the error disappear.
@peizhaoli05

I encountered the same issue. I am using CUDA 11.1 with JAX 0.2.18.

2021-08-11 22:19:59.368992: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:235] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2021-08-11 22:19:59.369003: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:238] Used ptxas at ptxas
2021-08-11 22:19:59.369425: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:625] failed to get PTX kernel "shift_right_logical_3" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2021-08-11 22:19:59.369444: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2040] Execution of replica 0 failed: Internal: Could not find the corresponding function
Traceback (most recent call last):
  File "/home/lallazhao/alphafold/debug.py", line 8, in <module>
    key = random.PRNGKey(0)
  File "/home/lallazhao/Downloads/enter/envs/bio/lib/python3.8/site-packages/jax/_src/random.py", line 75, in PRNGKey
    k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
  File "/home/lallazhao/Downloads/enter/envs/bio/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 382, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/lallazhao/Downloads/enter/envs/bio/lib/python3.8/site-packages/jax/core.py", line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/lallazhao/Downloads/enter/envs/bio/lib/python3.8/site-packages/jax/core.py", line 604, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/lallazhao/Downloads/enter/envs/bio/lib/python3.8/site-packages/jax/interpreters/xla.py", line 262, in apply_primitive
    return compiled_fun(*args)
  File "/home/lallazhao/Downloads/enter/envs/bio/lib/python3.8/site-packages/jax/interpreters/xla.py", line 378, in _execute_compiled_primitive
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Internal: Could not find the corresponding function

Process finished with exit code 1


hawkinsp commented Aug 12, 2021

@peizhaoli05 The ptxas found by JAX on your system (typically in your PATH or /usr/local/cuda) is too old for your GPU. Make sure an up-to-date ptxas appears in your path. If that doesn't solve the problem, you'll probably need to use a newer version of CUDA.
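
For reference, a quick way to check which ptxas a process will pick up and whether it is new enough (a minimal sketch, assuming ptxas resolves via PATH the way the shell finds it; JAX may also look under /usr/local/cuda/bin):

import shutil
import subprocess

# Locate ptxas the same way a PATH lookup would.
ptxas = shutil.which("ptxas") or "/usr/local/cuda/bin/ptxas"
print("ptxas:", ptxas)

# CUDA 11.1 is the first release whose ptxas supports CC 8.6 (RTX 30 series),
# so the reported release should be at least 11.1 for these cards.
print(subprocess.run([ptxas, "--version"], capture_output=True, text=True).stdout)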

@peizhaoli05

Hi @hawkinsp
I think I'm using a CUDA version that is supported by JAX. Do you know how to find the ptxas path that JAX picks up? Thanks.


mshafiei commented Feb 3, 2022

I'm using CUDA 11.5 and installed jaxlib and jax from source (0.1.75 and 0.2.28). I am getting CUDA_ERROR_MISALIGNED_ADDRESS randomly while optimizing a neural net. I tried reducing the size of the network, doing simple SGD instead of using optax, and so on, but the problem persists. Using an older TensorFlow doesn't help either. I also hit this error with the prebuilt libraries from the sources recommended in the JAX tutorial, which is what motivated me to build jaxlib from source. Any suggestion would be much appreciated. This is the error I'm getting on the command line:

2022-02-03 10:19:47.906580: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:1163] failed to enqueue async memcpy from device to host: CUDA_ERROR_MISALIGNED_ADDRESS: misaligned address; host dst: 0x55f8cf8a3600; GPU src: 0x7f02c8000000; size: 48=0x30
2022-02-03 10:19:47.906611: E external/org_tensorflow/tensorflow/stream_executor/stream.cc:334] Error recording event in stream: Error recording CUDA event: CUDA_ERROR_MISALIGNED_ADDRESS: misaligned address; not marking stream as bad, as the Event object may be at fault. Monitor for further errors.
2022-02-03 10:19:47.906626: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:618] unable to add host callback: CUDA_ERROR_MISALIGNED_ADDRESS: misaligned address

Python reports the following error message:


Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: <module>)
INTERNAL: stream did not block host until done; was already in an error state
  File "/home/mohammad/Projects/optimizer/jax/jax/_src/dispatch.py", line 432, in _check_special
    if config.jax_debug_nans and np.any(np.isnan(buf.to_py())):
  File "/home/mohammad/Projects/optimizer/jax/jax/_src/dispatch.py", line 427, in check_special
    _check_special(name, buf.xla_shape(), buf)
  File "/home/mohammad/Projects/optimizer/jax/jax/_src/api.py", line 141, in _nan_check_posthook
    dispatch.check_special(xla.xla_call_p, buffers)
  File "/home/mohammad/Projects/optimizer/DifferentiableSolver/Flash_No_Flash/train.py", line 220, in <module> (Current frame)
    params = update2(params,batch)
  File "/home/mohammad/bin/anaconda3/envs/autoint4/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/mohammad/bin/anaconda3/envs/autoint4/lib/python3.9/runpy.py", line 97, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/mohammad/bin/anaconda3/envs/autoint4/lib/python3.9/runpy.py", line 268, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/mohammad/bin/anaconda3/envs/autoint4/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/mohammad/bin/anaconda3/envs/autoint4/lib/python3.9/runpy.py", line 197, in _run_module_as_main

@Guitaricet

In my case the problem was caused by my PATH. CUDA did not add itself to it automatically (although nvidia-smi worked), so I had to add this to my .bashrc:

export PATH=/usr/local/cuda-11/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
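
With the paths fixed (and the shell reloaded), a quick sanity check (a minimal sketch; the exact device repr varies across jaxlib versions):

import jax
import jax.numpy as jnp

# JAX should now list the GPU instead of failing at the first operation.
print(jax.devices())   # e.g. [GpuDevice(id=0)]
print(jnp.zeros(4))    # the repro from the original report; should print [0. 0. 0. 0.]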


Waterkin commented Apr 29, 2022

Me too.

sudhakarsingh27 added the NVIDIA GPU, P0 (urgent), P2 (eventual), and needs info labels and removed the P0 (urgent) label on Aug 10, 2022
@sudhakarsingh27 (Collaborator)

@holl- was this resolved?

@hawkinsp (Member)

There are a bunch of unrelated issues in this bug, but I think I'm comfortable calling the original issue moot because we don't support CUDA 11.0 any more. If any of the folks who posted in this bug are still experiencing problems, please open a new issue? Thanks!
