whisper : add GPU support via cuBLAS (ggerganov#834)
* make : add WHISPER_CUBLAS

* make : fix CUBLAS build

* whisper : disable Flash Attention + adjust memory buffers

* whisper : remove old commented code

* readme : add cuBLAS instructions

* cmake : add WHISPER_CUBLAS option

* gitignore : ignore build-cublas
ggerganov committed Apr 30, 2023
1 parent 12618e5 commit 8d56873
Showing 10 changed files with 97 additions and 46 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -12,6 +12,7 @@ build-em/
build-debug/
build-release/
build-static/
build-cublas/
build-no-accel/
build-sanitize-addr/
build-sanitize-thread/
39 changes: 36 additions & 3 deletions CMakeLists.txt
@@ -51,7 +51,7 @@ option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})

option(WHISPER_SUPPORT_SDL2 "whisper: support for libSDL2" OFF)
option(WHISPER_SDL2 "whisper: support for libSDL2" OFF)

if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
@@ -62,7 +62,8 @@ if (APPLE)
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
option(WHISPER_OPENBLAS "whisper: support for OpenBLAS" OFF)
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
endif()

option(WHISPER_PERF "whisper: enable perf timings" OFF)
@@ -127,7 +128,7 @@ if (APPLE)
endif()
endif()

if (WHISPER_SUPPORT_OPENBLAS)
if (WHISPER_OPENBLAS)
find_library(OPENBLAS_LIB
NAMES openblas libopenblas
)
@@ -141,6 +142,31 @@ if (WHISPER_SUPPORT_OPENBLAS)
endif()
endif()

if (WHISPER_CUBLAS)
cmake_minimum_required(VERSION 3.17)

find_package(CUDAToolkit)

if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")

enable_language(CUDA)

set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)

add_compile_definitions(GGML_USE_CUBLAS)

if (WHISPER_STATIC)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()

else()
message(WARNING "cuBLAS not found")
endif()
endif()

# compiler flags

if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
@@ -247,6 +273,7 @@ set(TARGET whisper)
add_library(${TARGET}
ggml.h
ggml.c
${GGML_CUDA_SOURCES}
whisper.h
whisper.cpp
)
@@ -279,6 +306,12 @@ if (BUILD_SHARED_LIBS)
)
endif()

if (GGML_CUDA_SOURCES)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
endif()

if (EMSCRIPTEN)
set_target_properties(${TARGET} PROPERTIES COMPILE_FLAGS "-msimd128")
endif()
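
As a usage sketch for the new CMake path (the directory name mirrors the new `.gitignore` entry; the flags are otherwise illustrative, assuming an installed CUDA toolkit):

```
# configure and build with cuBLAS enabled (requires CMake >= 3.17 for CUDAToolkit)
mkdir -p build-cublas && cd build-cublas
cmake -DWHISPER_CUBLAS=ON ..
make -j
```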
28 changes: 20 additions & 8 deletions Makefile
@@ -1,3 +1,5 @@
default: main bench

ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
@@ -157,6 +159,18 @@ ifdef WHISPER_OPENBLAS
LDFLAGS += -lopenblas
endif

ifdef WHISPER_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
WHISPER_OBJ += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native

ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif

ifdef WHISPER_GPROF
CFLAGS += -pg
CXXFLAGS += -pg
@@ -200,28 +214,26 @@ $(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )

default: main bench

#
# Build library
#

ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
ggml.o: ggml.c ggml.h ggml-cuda.h
$(CC) $(CFLAGS) -c $< -o $@

whisper.o: whisper.cpp whisper.h ggml.h
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@

ifndef WHISPER_COREML
WHISPER_OBJ = whisper.o
WHISPER_OBJ += whisper.o
else
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o

whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o

WHISPER_OBJ = whisper.o whisper-encoder.o whisper-encoder-impl.o
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
endif

libwhisper.a: ggml.o $(WHISPER_OBJ)
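
The hard-coded search paths above cover `/usr/local/cuda` and `/opt/cuda`, and fall back to `$(CUDA_PATH)`, so a toolkit in a non-default location can be picked up by exporting `CUDA_PATH`; the version-suffixed path below is only an example:

```
# build with cuBLAS; CUDA_PATH only matters for non-default toolkit locations
make clean
CUDA_PATH=/usr/local/cuda-11.8 WHISPER_CUBLAS=1 make -j
```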
26 changes: 20 additions & 6 deletions README.md
@@ -18,6 +18,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper)
- Low memory usage (Flash Attention)
- Zero memory allocations at runtime
- Runs on the CPU
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)

Supported platforms:
@@ -254,7 +255,7 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the instructions
# using Makefile
make clean
WHISPER_COREML=1 make -j

# using CMake
cd build
cmake -DWHISPER_COREML=1 ..
@@ -271,20 +272,33 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the instructions
whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded

system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 |

...
```
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
Next runs are faster.
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).

## NVIDIA GPU support via cuBLAS

With NVIDIA cards, the Encoder processing can be offloaded to the GPU to a large extent through cuBLAS.
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads

Now build `whisper.cpp` with cuBLAS support:

```
make clean
WHISPER_CUBLAS=1 make -j
```

Run all the examples as usual.
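
For example, after fetching a model with the repo's download script, a run looks identical to the CPU-only case (the Encoder simply goes through cuBLAS):

```
# download a model, then transcribe the bundled sample
./models/download-ggml-model.sh base.en
./main -m models/ggml-base.en.bin -f samples/jfk.wav
```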

## Limitations

- Inference only
- No GPU support (yet)

## Another example
@@ -429,7 +443,7 @@ system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:00.320]
[00:00:00.320 --> 00:00:00.370] And
[00:00:00.370 --> 00:00:00.690] so
[00:00:00.690 --> 00:00:00.850] my
4 changes: 2 additions & 2 deletions examples/CMakeLists.txt
@@ -4,7 +4,7 @@ find_package(Threads REQUIRED)

# third-party

if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# SDL2
find_package(SDL2 REQUIRED)

@@ -27,7 +27,7 @@ include(DefaultTargetOptions)

set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)

if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# common-sdl

set(TARGET common-sdl)
2 changes: 1 addition & 1 deletion examples/command/CMakeLists.txt
@@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# command
set(TARGET command)
add_executable(${TARGET} command.cpp)
2 changes: 1 addition & 1 deletion examples/stream/CMakeLists.txt
@@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# stream
set(TARGET stream)
add_executable(${TARGET} stream.cpp)
2 changes: 1 addition & 1 deletion examples/talk-llama/CMakeLists.txt
@@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# talk-llama
set(TARGET talk-llama)
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
2 changes: 1 addition & 1 deletion examples/talk/CMakeLists.txt
@@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# talk
set(TARGET talk)
#add_executable(${TARGET} talk.cpp gpt-2.cpp)
37 changes: 14 additions & 23 deletions whisper.cpp
@@ -102,7 +102,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
#define WHISPER_PRINT_DEBUG(...)
#endif

#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 16

@@ -224,11 +224,11 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1ull*1024*1024;

static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
{ MODEL_TINY, 14ull*MB },
{ MODEL_BASE, 18ull*MB },
{ MODEL_SMALL, 28ull*MB },
{ MODEL_MEDIUM, 36ull*MB },
{ MODEL_LARGE, 44ull*MB },
{ MODEL_TINY, 62ull*MB },
{ MODEL_BASE, 80ull*MB },
{ MODEL_SMALL, 120ull*MB },
{ MODEL_MEDIUM, 158ull*MB },
{ MODEL_LARGE, 198ull*MB },
};

static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
@@ -280,11 +280,11 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
};

static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
{ MODEL_TINY, 6ull*MB },
{ MODEL_BASE, 8ull*MB },
{ MODEL_SMALL, 13ull*MB },
{ MODEL_MEDIUM, 22ull*MB },
{ MODEL_LARGE, 33ull*MB },
{ MODEL_TINY, 30ull*MB },
{ MODEL_BASE, 38ull*MB },
{ MODEL_SMALL, 56ull*MB },
{ MODEL_MEDIUM, 74ull*MB },
{ MODEL_LARGE, 94ull*MB },
};

static const std::map<e_model, size_t> MEM_REQ_DECODE = {
@@ -1554,26 +1554,17 @@ static bool whisper_encode_internal(

struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);

//struct ggml_tensor * V_trans =
// ggml_permute(ctx0,
// ggml_cpy(ctx0,
// Vcur,
// ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3);

//struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

struct ggml_tensor * V =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
Vcur,
n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
);

struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

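
A note on the whisper.cpp changes: with `WHISPER_USE_FLASH_ATTN` commented out, the encoder falls back to standard scaled dot-product attention and materializes the full attention matrix, which is presumably why the scratch and encode buffers above grow several-fold. The fallback path computes:

```
\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V,
\qquad d_k = n_{\text{state}} / n_{\text{head}}
```

The reworked `ggml_cpy` also permutes `V` into its transposed layout up front (`1, 2, 0, 3` rather than `0, 2, 1, 3`), so the final `ggml_mul_mat` no longer needs a runtime `ggml_transpose`; keeping that operand contiguous is apparently what lets the cuBLAS-backed matrix multiplication consume it directly.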
