Cannot use GPU on Ubuntu 16.04, CUDA 11.0 #6046
Not sure if it's related, but building from source fails with the following error:
Getting the same issue on an RTX 30 Series with
I was able to build from source without error, but I get the same set of warnings, followed by errors, when actually trying to do anything.
I should add that I was also able to install and use TensorFlow without issue.
However, after upgrading to CUDA Toolkit |
@danieljtait Very interesting. Unfortunately I'm stuck with CUDA |
I had the same error when I built TensorFlow. Have you resolved this problem? Thank you!
Unfortunately not with CUDA 11.0. If there is a fix, I would be interested, too. |
It should be caused by a non-empty braced-init-list. Using `std::make_tuple` instead:

From d1f3e960bba01a3c50731e6165aabf3ab277f5cf Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Date: Thu, 22 Jul 2021 14:17:37 -0400
Subject: [PATCH 1/2] use std::make_tuple
---
.../profiler/internal/gpu/cupti_tracer.cc | 56 +++++++++----------
.../stream_executor/gpu/asm_compiler.cc | 2 +-
2 files changed, 29 insertions(+), 29 deletions(-)
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
index df9c3b7efd7..93a04c0fd2b 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
+++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
@@ -159,70 +159,70 @@ DecodeDriverMemcpy(CUpti_CallbackId cbid, const void *params) {
switch (cbid) {
case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2: {
const auto *p = reinterpret_cast<const cuMemcpyHtoD_v2_params *>(params);
- return {p->ByteCount, CuptiTracerEventType::MemcpyH2D, false};
+ return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyH2D, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2: {
const auto *p =
reinterpret_cast<const cuMemcpyHtoDAsync_v2_params *>(params);
- return {p->ByteCount, CuptiTracerEventType::MemcpyH2D, true};
+ return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyH2D, true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2: {
const auto *p = reinterpret_cast<const cuMemcpyDtoH_v2_params *>(params);
- return {p->ByteCount, CuptiTracerEventType::MemcpyD2H, false};
+ return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyD2H, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2: {
const auto *p =
reinterpret_cast<const cuMemcpyDtoHAsync_v2_params *>(params);
- return {p->ByteCount, CuptiTracerEventType::MemcpyD2H, true};
+ return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyD2H, true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2: {
const auto *p = reinterpret_cast<const cuMemcpyDtoD_v2_params *>(params);
- return {p->ByteCount, CuptiTracerEventType::MemcpyD2D, false};
+ return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyD2D, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2: {
const auto *p =
reinterpret_cast<const cuMemcpyDtoDAsync_v2_params *>(params);
- return {p->ByteCount, CuptiTracerEventType::MemcpyD2D, true};
+ return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyD2D, true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: {
const auto *p = reinterpret_cast<const cuMemcpy_params *>(params);
- return {p->ByteCount, CuptiTracerEventType::MemcpyOther, false};
+ return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyOther, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: {
const auto *p = reinterpret_cast<const cuMemcpyAsync_params *>(params);
- return {p->ByteCount, CuptiTracerEventType::MemcpyOther, true};
+ return std::make_tuple(p->ByteCount, CuptiTracerEventType::MemcpyOther, true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpy2D_v2: {
const auto *p = reinterpret_cast<const cuMemcpy2D_v2_params *>(params);
- return {Bytes2D(p->pCopy), MemcpyKind(p->pCopy), false};
+ return std::make_tuple(Bytes2D(p->pCopy), MemcpyKind(p->pCopy), false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpy2DAsync_v2: {
const auto *p =
reinterpret_cast<const cuMemcpy2DAsync_v2_params *>(params);
- return {Bytes2D(p->pCopy), MemcpyKind(p->pCopy), true};
+ return std::make_tuple(Bytes2D(p->pCopy), MemcpyKind(p->pCopy), true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpy3D_v2: {
const auto *p = reinterpret_cast<const cuMemcpy3D_v2_params *>(params);
- return {Bytes3D(p->pCopy), MemcpyKind(p->pCopy), true};
+ return std::make_tuple(Bytes3D(p->pCopy), MemcpyKind(p->pCopy), true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpy3DAsync_v2: {
const auto *p =
reinterpret_cast<const cuMemcpy3DAsync_v2_params *>(params);
- return {Bytes3D(p->pCopy), MemcpyKind(p->pCopy), true};
+ return std::make_tuple(Bytes3D(p->pCopy), MemcpyKind(p->pCopy), true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpyPeer: {
const auto *p2p_params =
reinterpret_cast<const cuMemcpyPeer_params *>(params);
- return {p2p_params->ByteCount, CuptiTracerEventType::MemcpyP2P, false};
+ return std::make_tuple(p2p_params->ByteCount, CuptiTracerEventType::MemcpyP2P, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemcpyPeerAsync: {
const auto *p2p_params =
reinterpret_cast<const cuMemcpyPeerAsync_params *>(params);
- return {p2p_params->ByteCount, CuptiTracerEventType::MemcpyP2P, true};
+ return std::make_tuple(p2p_params->ByteCount, CuptiTracerEventType::MemcpyP2P, true);
}
default: {
LOG(ERROR) << "Unsupported memcpy activity observed: " << cbid;
- return {0, CuptiTracerEventType::Unsupported, false};
+ return std::make_tuple(0, CuptiTracerEventType::Unsupported, false);
}
}
}
@@ -232,58 +232,58 @@ DecodeDriverMemset(CUpti_CallbackId cbid, const void *params) {
switch (cbid) {
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD8_v2: {
const auto *p = reinterpret_cast<const cuMemsetD8_v2_params *>(params);
- return {p->N, CuptiTracerEventType::Memset, false};
+ return std::make_tuple(p->N, CuptiTracerEventType::Memset, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD16_v2: {
const auto *p = reinterpret_cast<const cuMemsetD16_v2_params *>(params);
- return {p->N, CuptiTracerEventType::Memset, false};
+ return std::make_tuple(p->N, CuptiTracerEventType::Memset, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD32_v2: {
const auto *p = reinterpret_cast<const cuMemsetD32_v2_params *>(params);
- return {p->N, CuptiTracerEventType::Memset, false};
+ return std::make_tuple(p->N, CuptiTracerEventType::Memset, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D8_v2: {
const auto *p = reinterpret_cast<const cuMemsetD2D8_v2_params *>(params);
- return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, false};
+ return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D16_v2: {
const auto *p = reinterpret_cast<const cuMemsetD2D16_v2_params *>(params);
- return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, false};
+ return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D32_v2: {
const auto *p = reinterpret_cast<const cuMemsetD2D32_v2_params *>(params);
- return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, false};
+ return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, false);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD8Async: {
const auto *p = reinterpret_cast<const cuMemsetD8Async_params *>(params);
- return {p->N, CuptiTracerEventType::Memset, true};
+ return std::make_tuple(p->N, CuptiTracerEventType::Memset, true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD16Async: {
const auto *p = reinterpret_cast<const cuMemsetD16Async_params *>(params);
- return {p->N, CuptiTracerEventType::Memset, true};
+ return std::make_tuple(p->N, CuptiTracerEventType::Memset, true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD32Async: {
const auto *p = reinterpret_cast<const cuMemsetD32Async_params *>(params);
- return {p->N, CuptiTracerEventType::Memset, true};
+ return std::make_tuple(p->N, CuptiTracerEventType::Memset, true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D8Async: {
const auto *p =
reinterpret_cast<const cuMemsetD2D8Async_params *>(params);
- return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, true};
+ return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D16Async: {
const auto *p =
reinterpret_cast<const cuMemsetD2D16Async_params *>(params);
- return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, true};
+ return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, true);
}
case CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D32Async: {
const auto *p =
reinterpret_cast<const cuMemsetD2D32Async_params *>(params);
- return {p->dstPitch * p->Height, CuptiTracerEventType::Memset, true};
+ return std::make_tuple(p->dstPitch * p->Height, CuptiTracerEventType::Memset, true);
}
default: {
LOG(ERROR) << "Unsupported memset activity observed: " << cbid;
- return {0, CuptiTracerEventType::Unsupported, false};
+ return std::make_tuple(0, CuptiTracerEventType::Unsupported, false);
}
}
}
diff --git a/tensorflow/stream_executor/gpu/asm_compiler.cc b/tensorflow/stream_executor/gpu/asm_compiler.cc
index 0ade37f4c32..66ca927dd8f 100644
--- a/tensorflow/stream_executor/gpu/asm_compiler.cc
+++ b/tensorflow/stream_executor/gpu/asm_compiler.cc
@@ -187,7 +187,7 @@ static void LogPtxasTooOld(const std::string& ptxas_path, int cc_major,
absl::MutexLock lock(mutex);
- if (already_logged->insert({ptxas_path, cc_major, cc_minor}).second) {
+ if (already_logged->insert(std::make_tuple(ptxas_path, cc_major, cc_minor)).second) {
LOG(WARNING) << "Falling back to the CUDA driver for PTX compilation; "
"ptxas does not support CC "
<< cc_major << "." << cc_minor;
--
2.31.1
Non-empty braced-init-lists cause some errors when I build from source. I don't know the exact reason; it may be related to my build tools, but the safest fix is to avoid non-empty braced-init-lists. The error is:

```
external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc: In function ‘void stream_executor::LogPtxasTooOld(const string&, int, int)’:
external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:190:62: error: converting to
‘absl::lts_2020_02_25::container_internal::raw_hash_set<absl::lts_2020_02_25::container_internal::FlatHashSetPolicy<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >,
absl::lts_2020_02_25::hash_internal::Hash<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >,
std::equal_to<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >,
std::allocator<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> > >::init_type
{aka std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int>}’
from initializer list would use explicit constructor
‘constexpr std::tuple< <template-parameter-1-1> >::tuple(_UElements&& ...)
[with _UElements = {const std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&, int&, int&};
<template-parameter-2-2> = void;
_Elements = {std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int}]’
   if (already_logged->insert({ptxas_path, cc_major, cc_minor}).second) {
```

google/jax#6046 (comment) reported the same error. Using `std::make_tuple` instead makes the error disappear.
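The failure mode above can be reproduced in miniature. With some older standard-library implementations, `std::tuple`'s element-wise constructor is conditionally `explicit`, so copy-list-initialization of a tuple in a `return` statement is rejected, while a direct call to `std::make_tuple` compiles. This is a minimal sketch, not TensorFlow code; the `Info` alias and `MakeInfo` function are hypothetical names chosen to mirror the `ptxas` case:

```cpp
#include <string>
#include <tuple>

// Mirrors the element types of the set used in asm_compiler.cc.
using Info = std::tuple<std::string, int, int>;

Info MakeInfo(const std::string& path, int major, int minor) {
  // On the affected toolchains this form is rejected because it would
  // use an explicit tuple constructor via copy-list-initialization:
  //   return {path, major, minor};
  // std::make_tuple constructs the tuple directly, so it always works:
  return std::make_tuple(path, major, minor);
}
```

The same substitution is exactly what the patch above applies throughout `cupti_tracer.cc` and `asm_compiler.cc`.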
I'm encountering the same issue. I am using CUDA 11.1 with Jax 0.2.18.
@peizhaoli05 The |
Hi @hawkinsp |
I'm using CUDA 11.5. I installed jaxlib and jax from source (1.75.0 and 0.2.28). I am getting CUDA_ERROR_MISALIGNED_ADDRESS randomly while optimizing a neural net. I tried reducing the size of the network and doing plain SGD instead of using optax, etc., but the problem persists. I also tried an older TensorFlow, but it doesn't help. I even hit this error when using prebuilt libraries from the sources recommended in the JAX tutorial, which is what motivated me to install jaxlib from source. Any suggestion would be much appreciated. Below is the error I'm getting on the command line.
The Python error handler shows the following message:
In my case the problem was caused by my PATH: CUDA did not add itself to it automatically (although nvidia-smi worked), so I had to add this to my
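For reference, a typical fix of this kind appends the CUDA toolkit's directories to the relevant environment variables in the shell profile. This is a sketch under the assumption that the toolkit lives at `/usr/local/cuda` (as in the original report); adjust the path to match the actual installation:

```shell
# Assumed toolkit location; change if CUDA is installed elsewhere.
CUDA_HOME="/usr/local/cuda"

# Make nvcc/ptxas findable and let the loader locate the CUDA runtime.
export PATH="${CUDA_HOME}/bin:${PATH}"
export LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}"
```

After adding lines like these to the shell profile, a new shell (or `source ~/.bashrc`) is needed for them to take effect.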
Me too.
@holl- was this resolved? |
There are a bunch of unrelated issues in this bug, but I think I'm comfortable calling the original issue moot because we don't support CUDA 11.0 any more. If any of the folks who posted in this bug are still experiencing problems, please open a new issue? Thanks! |
I have a GeForce RTX 3090 with CUDA 11.0 installed on Ubuntu 16.04, and the installation works fine with TensorFlow.
The path
/usr/local/cuda
points to that installation. I installed Jax into my Python 3.8.6 conda environment by running
I can import Jax from Python, but the first operation throws an error.
Running
nvcc --version
prints
Is this a bug, or am I doing something wrong?