Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose CUDA device setting in public API #1840

Merged
merged 2 commits into from Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Expand Up @@ -209,9 +209,9 @@ endif

ifdef WHISPER_CUBLAS
ifeq ($(shell expr $(NVCC_VERSION) \>= 11.6), 1)
CUDA_ARCH_FLAG=native
CUDA_ARCH_FLAG ?= native
else
CUDA_ARCH_FLAG=all
CUDA_ARCH_FLAG ?= all
endif

CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
Expand Down
3 changes: 2 additions & 1 deletion whisper.cpp
Expand Up @@ -1060,7 +1060,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
#ifdef GGML_USE_CUBLAS
if (params.use_gpu && ggml_cublas_loaded()) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init(0);
backend_gpu = ggml_backend_cuda_init(params.gpu_device);
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
}
Expand Down Expand Up @@ -3213,6 +3213,7 @@ int whisper_ctx_init_openvino_encoder(
struct whisper_context_params whisper_context_default_params() {
struct whisper_context_params result = {
/*.use_gpu =*/ true,
/*.gpu_device =*/ 0,
};
return result;
}
Expand Down
1 change: 1 addition & 0 deletions whisper.h
Expand Up @@ -86,6 +86,7 @@ extern "C" {

struct whisper_context_params {
bool use_gpu;
int gpu_device; // CUDA device
};

typedef struct whisper_token_data {
Expand Down