
Error: DriverError(CUDA_ERROR_NOT_FOUND, "named symbol not found") when loading cast_f32_bf16 #2041

Closed
evgenyigumnov opened this issue Apr 11, 2024 · 3 comments

Comments

@evgenyigumnov
Contributor

```
root@C.10515727:~/ai-server$ cargo run
    Finished dev [unoptimized + debuginfo] target(s) in 0.17s
     Running `target/debug/ai-server`
retrieved the files in 16.361172ms
Error: DriverError(CUDA_ERROR_NOT_FOUND, "named symbol not found") when loading cast_f32_bf16
root@C.10515727:~/ai-server$ nvidia-smi --query-gpu=name,compute_cap,driver_version --format=csv
name, compute_cap, driver_version
NVIDIA GeForce RTX 2080 Ti, 7.5, 535.161.07
NVIDIA GeForce RTX 2080 Ti, 7.5, 535.161.07
NVIDIA GeForce RTX 2080 Ti, 7.5, 535.161.07
NVIDIA GeForce RTX 2080 Ti, 7.5, 535.161.07
NVIDIA GeForce RTX 2080 Ti, 7.5, 535.161.07
NVIDIA GeForce RTX 2080 Ti, 7.5, 535.161.07
NVIDIA GeForce RTX 2080 Ti, 7.5, 535.161.07
NVIDIA GeForce RTX 2080 Ti, 7.5, 535.161.07
root@C.10515727:~/ai-server$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Fri_Jan__6_16:45:21_PST_2023
Cuda compilation tools, release 12.0, V12.0.140
Build cuda_12.0.r12.0/compiler.32267302_0
root@C.10515727:~/ai-server$ uname -a
Linux fb0f7633e4cb 5.4.0-172-generic #190-Ubuntu SMP Fri Feb 2 23:24:22 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux
```
@evgenyigumnov
Contributor Author

```toml
[dependencies]
candle-nn = "0.4.1"
candle-core = "0.4.1"
candle-datasets = "0.4.1"
candle-transformers = "0.4.1"
candle-examples = "0.4.1"
hf-hub = "0.3.2"
tokenizers = "0.15.2"
```

@edesalve

Hi @evgenyigumnov, upgrading the CUDA driver to >= 545 should work: #1761.

@dommyrock

dommyrock commented Apr 11, 2024

Got the same issue today while running the gemma ("codegemma-7b-it") example on my legacy RTX 2070 GPU 😞

Error: DriverError(CUDA_ERROR_NOT_FOUND, "named symbol not found") when loading cast_u32_bf16

nvidia-smi --query-gpu=name,compute_cap,driver_version --format=csv
~ NVIDIA GeForce RTX 2070 with Max-Q Design, 7.5, 552.12

The issue is that our GPUs don't support the required compute capability.

As seen in the kernel code here:
https://github.com/huggingface/candle/blob/main/candle-kernels/src/cast.cu#L73

the bf16 cast kernels are only compiled for architectures with compute capability >= 800 (8.0), while my GPU and @evgenyigumnov's 2080 Ti fall under the unsupported architectures.

And it's set to >= 800 for a reason:

> The 16-bit __nv_bfloat16 floating-point version of atomicAdd() is only supported by devices of compute capability 8.x and higher.

Docs: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities

Search for "The 16-bit __nv_bfloat16"
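The same >= 800 gate the kernel applies at compile time (via `__CUDA_ARCH__`) can be mirrored on the host before trying to load bf16 kernels. A minimal sketch in plain Rust — `supports_bf16_kernels` is a hypothetical helper, not part of candle — assuming the capability string comes from `nvidia-smi --query-gpu=compute_cap`:

```rust
/// Returns true if a GPU with the given compute capability string
/// (e.g. "7.5") can run kernels gated behind `#if __CUDA_ARCH__ >= 800`.
/// Hypothetical helper for illustration; not part of candle's API.
fn supports_bf16_kernels(compute_cap: &str) -> bool {
    let mut parts = compute_cap.trim().split('.');
    let major: u32 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
    let minor: u32 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
    // __CUDA_ARCH__ encodes the capability as major * 100 + minor * 10,
    // so 7.5 -> 750 and 8.6 -> 860.
    major * 100 + minor * 10 >= 800
}

fn main() {
    // RTX 2070 / 2080 Ti (Turing) report 7.5 -> bf16 casts compiled out.
    assert!(!supports_bf16_kernels("7.5"));
    // Ampere (e.g. 8.0 / 8.6) and newer pass the gate.
    assert!(supports_bf16_kernels("8.6"));
    println!("ok");
}
```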

GPU capabilities:
https://developer.nvidia.com/cuda-gpus (RTX GPUs are listed under "CUDA-Enabled GeForce and TITAN Products")

Related Issue : #1911

So unfortunately neither @evgenyigumnov nor I can run code that depends on that kernel function on our current GPU architectures :( .
Compute capability 8.0 mostly means the 3000-series chips, the generation after ours (as seen on the "GPU capabilities" page).
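A practical workaround on pre-Ampere cards is to load the weights in F32 (or F16) instead of BF16. A sketch of the selection logic — the `DType` enum here is a stand-in for illustration, not candle's actual type; in candle itself the dtype is typically passed to the loader (e.g. the dtype argument of `VarBuilder::from_mmaped_safetensors`), so the same idea applies there:

```rust
/// Stand-in for candle's DType; illustrative only.
#[derive(Debug, PartialEq)]
enum DType {
    BF16,
    F32,
}

/// Pick a weight dtype the GPU can actually run: bf16 needs compute
/// capability major version >= 8, everything older falls back to f32.
fn pick_dtype(cap_major: u32) -> DType {
    if cap_major >= 8 { DType::BF16 } else { DType::F32 }
}

fn main() {
    assert_eq!(pick_dtype(7), DType::F32);  // Turing (2070 / 2080 Ti)
    assert_eq!(pick_dtype(8), DType::BF16); // Ampere and newer
    println!("ok");
}
```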


Some relevant docs if you want to know more:

Compute capabilities:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities

bfloat16:
https://en.wikipedia.org/wiki/Bfloat16_floating-point_format

Virtual-architecture macros:
https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#virtual-architecture-macros

Using __CUDA_ARCH__:
https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#using-cuda-arch

__nv_bfloat16 samples:
https://github.com/NVIDIA/cuda-samples/tree/master/Samples/3_CUDA_Features/bf16TensorCoreGemm

Candle bfloat16 added:
ec79fc4

Hope it helps : )
