Introduction of CUDA Graphs to llama.cpp #6766

Merged: 20 commits merged into ggerganov:master on May 8, 2024

Conversation

@agray3 (Contributor) commented Apr 19, 2024

See Issue #6763

@agray3 (Contributor, Author) commented Apr 19, 2024

Labelled as DRAFT since this will need some more testing across different models, CUDA versions, etc before it is merged. See #6763.

@phymbert added the "need feedback" (Testing and feedback with results are needed) and "performance" (Speed related topics) labels on Apr 19, 2024
@phymbert (Collaborator) commented Apr 19, 2024

@ggerganov please restart the GitHub CI runner for the benchmark

EDIT: the job just failed: https://github.com/ggerganov/llama.cpp/actions/runs/8753504588/job/24040839682?pr=6766

@JohannesGaessler (Collaborator)

./llama-bench and ./perplexity are broken with this PR, which is why I'm using ./main for testing. Results for LLaMA 2 7b q4_0:

| GPU | Test | t/s master | t/s PR | Speedup |
| --- | --- | ---: | ---: | ---: |
| 1x RTX 3090 | ./main pp | 3445 | 3185 | 0.92 |
| 1x RTX 3090 | ./main tg | 116.31 | 128.27 | 1.10 |
| 1x RTX 4090 | ./main pp | 6668 | 6963 | 0.96 |
| 1x RTX 4090 | ./main tg | 138.31 | 146.81 | 1.06 |

The speedup does not seem to be consistent: for a batch size of 1 it's faster but for batch sizes >> 1 it's slower. Still, the potential speedup is higher than I thought so it seems I was wrong when I previously said using CUDA graphs would not be worthwhile.

Error for perplexity
johannesg@johannes-romed82t-00 ~/Projects/llama.cpp                                                                         [15:04:17]
> $ ./perplexity --file wikitext-2-raw/wiki.test.raw --model models/opt/${model_name}-${quantization}.gguf -ngl 99 -c 4096
main: build = 2699 (cec409aa)
main: built with cc (GCC) 13.2.1 20230801 for x86_64-pc-linux-gnu
main: seed  = 1713531858
llama_model_loader: loaded meta data with 22 key-value pairs and 291 tensors from models/opt/llama_2-7b-q4_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = LLaMA v2
llama_model_loader: - kv   2:                       llama.context_length u32              = 4096
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 32
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                          general.file_type u32              = 2
llama_model_loader: - kv  11:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  12:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  13:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  14:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  15:                      tokenizer.ggml.merges arr[str,61249]   = ["▁ t", "e r", "i n", "▁ a", "e n...
llama_model_loader: - kv  16:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  17:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  18:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  19:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  20:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  21:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type q4_0:  225 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 4096
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 4096
llm_load_print_meta: n_embd_v_gqa     = 4096
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 11008
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = Q4_0
llm_load_print_meta: model params     = 6.74 B
llm_load_print_meta: model size       = 3.56 GiB (4.54 BPW) 
llm_load_print_meta: general.name     = LLaMA v2
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
llm_load_tensors: ggml ctx size =    0.30 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =    70.31 MiB
llm_load_tensors:      CUDA0 buffer size =  3577.56 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 4096
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =  2048.00 MiB
llama_new_context_with_model: KV self size  = 2048.00 MiB, K (f16): 1024.00 MiB, V (f16): 1024.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   296.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    16.01 MiB
llama_new_context_with_model: graph nodes  = 1030
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 64 / 128 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 958.351 ms
perplexity: calculating perplexity over 81 chunks, n_ctx=4096, batch_size=2048, n_seq=1
CUDA error: invalid argument
  current device: 0, in function ggml_backend_cuda_graph_compute at ggml-cuda.cu:2515
  cuGraphKernelNodeGetParams_v2(nodes[i], &paramsDriver[i])
GGML_ASSERT: ggml-cuda.cu:60: !"CUDA error"
[1]    1644936 IOT instruction (core dumped)  ./perplexity --file wikitext-2-raw/wiki.test.raw --model  -ngl 99 -c 4096

@agray3 (Contributor, Author) commented Apr 19, 2024

Thanks for these tests. I haven't yet optimized/tested for batch sizes greater than one - it might be a good idea for me to only enable CUDA graphs for batch size 1 initially. I'll also look at the failures.

it seems I was wrong when I previously said using CUDA graphs would not be worthwhile.

It's not obvious - even without CUDA graphs, llama.cpp already does a good job of pre-launching all kernels in the GGML graph, so CPU-side launch overheads are not the issue. But CUDA graphs also optimise GPU-side launch overheads, reducing the "gaps" between kernels, and that is the benefit we are seeing here.
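For illustration, here is a minimal sketch of the standard capture/instantiate/launch pattern that CUDA graphs use (generic CUDA code, not the PR's implementation; the kernel and buffer names are placeholders and error checking is omitted):

```cpp
#include <cuda_runtime.h>

__global__ void dummy_kernel(float * x) { x[threadIdx.x] *= 2.0f; }

// Capture a sequence of kernel launches into a graph once, then replay the whole
// graph with a single launch, which is what shrinks the per-kernel gaps on the GPU.
void run_with_cuda_graph(cudaStream_t stream, float * d_x) {
    cudaGraph_t     graph    = nullptr;
    cudaGraphExec_t instance = nullptr;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);   // record, don't execute
    for (int i = 0; i < 32; ++i) {
        dummy_kernel<<<1, 256, 0, stream>>>(d_x);                  // stands in for the GGML graph's kernels
    }
    cudaStreamEndCapture(stream, &graph);

    cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);   // relatively expensive, done once
    cudaGraphLaunch(instance, stream);                             // cheap, repeated every evaluation
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(instance);
    cudaGraphDestroy(graph);
}
```

On subsequent evaluations only the graph launch (or a parameter update of the instantiated graph, when kernel arguments change) is needed, which is where the token-generation speedup comes from.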

@ardfork (Contributor) commented Apr 19, 2024

Tried to add ROCm/HIP compatibility, but it errors with:

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon RX 6700 XT, compute capability 10.3, VMM: no
| model                          |       size |     params | backend    | ngl | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
| llama 7B Q6_K                  |   5.53 GiB |     7.24 B | ROCm       |  99 | pp 512     |    763.24 ± 2.89 |
CUDA error: operation not permitted when stream is capturing
  current device: 0, in function alloc at llama.cpp/ggml-cuda.cu:233
  hipMalloc((void **) &ptr, look_ahead_size)
GGML_ASSERT: llama.cpp/ggml-cuda.cu:60: !"CUDA error"

Here is my patch:

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 481065b2..f1553f7c 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -117,6 +117,23 @@
 #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+#define CUDA_KERNEL_NODE_PARAMS_v2 hipKernelNodeParams
+#define CUresult hipError_t
+#define cuGetErrorString hipDrvGetErrorString
+#define cuGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cuGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphGetNodes hipGraphGetNodes
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaGraph_t hipGraph_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaStreamCaptureModeGlobal hipStreamCaptureModeGlobal
+#define cudaStreamEndCapture hipStreamEndCapture
 #else
 #include <cuda_runtime.h>
 #include <cuda.h>
@@ -208,14 +225,12 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
 
 #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
 
-#if !defined(GGML_USE_HIPBLAS)
 static const char * cu_get_error_str(CUresult err) {
     const char * err_str;
     cuGetErrorString(err, &err_str);
     return err_str;
 }
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
-#endif
 
 #if CUDART_VERSION >= 11100
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@@ -389,6 +404,16 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #endif
     return c;
 }
+
+struct cudaGraphExecUpdateResultInfo {
+    cudaGraphNode_t errorFromNode;
+    cudaGraphNode_t errorNode;
+    cudaGraphExecUpdateResult result;
+};
+
+static __host__ __forceinline__ cudaError_t cudaGraphExecUpdate(cudaGraphExec_t hGraphExec, cudaGraph_t hGraph, cudaGraphExecUpdateResultInfo* resultInfo ) {
+    return hipGraphExecUpdate(hGraphExec, hGraph, &resultInfo->errorNode, &resultInfo->result);
+}
 #endif // defined(GGML_USE_HIPBLAS)
 
 // TODO: move to ggml-common.h

@sorasoras

.\llama-bench.exe -m W:\model\mistral-7b-v0.1.Q5_K_S.gguf -ngl 99 -sm none
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 2 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: yes
  Device 1: Tesla P40, compute capability 6.1, VMM: yes
| model                          |       size |     params | backend    | ngl | sm         | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------- | ---------------: |

It does not seem to work on the P40.

And I cannot get it to compile on ROCm:

W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2410:5: error: unknown type name 'cudaGraph_t'
    cudaGraph_t graph = nullptr;
    ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2411:5: error: unknown type name 'cudaGraphExec_t'; did you mean 'hipGraphExec_t'?
    cudaGraphExec_t instance = nullptr;
    ^~~~~~~~~~~~~~~
    hipGraphExec_t
C:/Program Files/AMD/ROCm/5.7/include\hip/hip_runtime_api.h:1098:30: note: 'hipGraphExec_t' declared here
typedef struct hipGraphExec* hipGraphExec_t;
                             ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2458:63: error: use of undeclared identifier 'cudaStreamCaptureModeGlobal'
        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal));
                                                              ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2501:9: error: unknown type name 'cudaGraphNode_t'; did you mean 'hipGraphNode_t'?
        cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
        ^~~~~~~~~~~~~~~
        hipGraphNode_t
C:/Program Files/AMD/ROCm/5.7/include\hip/hip_runtime_api.h:1094:30: note: 'hipGraphNode_t' declared here
typedef struct hipGraphNode* hipGraphNode_t;
                             ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2502:9: error: unknown type name 'CUDA_KERNEL_NODE_PARAMS_v2'
        CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH];
        ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2503:9: error: unknown type name 'cudaKernelNodeParams'; did you mean 'hipKernelNodeParams'?
        cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];
        ^~~~~~~~~~~~~~~~~~~~
        hipKernelNodeParams
C:/Program Files/AMD/ROCm/5.7/include\hip/hip_runtime_api.h:1139:3: note: 'hipKernelNodeParams' declared here
} hipKernelNodeParams;
  ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2516:38: error: use of undeclared identifier 'cudaGraphKernelNodeGetParams'; did you mean 'hipGraphKernelNodeGetParams'?
                cudaError_t statRT = cudaGraphKernelNodeGetParams(nodes[i], &paramsRuntime[i]); // Get params using runtime
                                     ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
                                     hipGraphKernelNodeGetParams
C:/Program Files/AMD/ROCm/5.7/include\hip/hip_runtime_api.h:6607:12: note: 'hipGraphKernelNodeGetParams' declared here
hipError_t hipGraphKernelNodeGetParams(hipGraphNode_t node, hipKernelNodeParams* pNodeParams);
           ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2541:9: error: unknown type name 'cudaGraphExecUpdateResultInfo'; did you mean 'hipGraphExecUpdateResult'?
        cudaGraphExecUpdateResultInfo resultInfo;
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        hipGraphExecUpdateResult
C:/Program Files/AMD/ROCm/5.7/include\hip/hip_runtime_api.h:1199:3: note: 'hipGraphExecUpdateResult' declared here
} hipGraphExecUpdateResult;
  ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2545:20: error: use of undeclared identifier 'cudaGraphLaunch'; did you mean 'hipGraphLaunch'?
        CUDA_CHECK(cudaGraphLaunch(cudaGraph.instance, cuda_ctx->stream()));
                   ^~~~~~~~~~~~~~~
                   hipGraphLaunch
W:/git/test/agray3/llama.cpp/./ggml-cuda/common.cuh:186:40: note: expanded from macro 'CUDA_CHECK'
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
                                       ^
W:/git/test/agray3/llama.cpp/./ggml-cuda/common.cuh:180:22: note: expanded from macro 'CUDA_CHECK_GEN'
        auto err_ = (err);                                                          \
                     ^
C:/Program Files/AMD/ROCm/5.7/include\hip/hip_runtime_api.h:6538:12: note: 'hipGraphLaunch' declared here
hipError_t hipGraphLaunch(hipGraphExec_t graphExec, hipStream_t stream);
           ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2816:62: warning: unused parameter 'buffer' [-Wunused-parameter]
GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
                                                             ^
W:/git/test/agray3/llama.cpp/ggml-cuda.cu:2816:77: warning: unused parameter 'size' [-Wunused-parameter]
GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
                                                                            ^
10 warnings and 9 errors generated when compiling for gfx1100.
ninja: build stopped: subcommand failed.

@Engininja2 (Contributor)

nodes, paramsDriver, and paramsRuntime are being used across multiple calls of the function but their data is only loaded in an earlier call. Should they be static?

@agray3 (Contributor, Author) commented Apr 22, 2024

nodes, paramsDriver, and paramsRuntime are being used across multiple calls of the function but their data is only loaded in an earlier call. Should they be static?

Good spot! I had this earlier during development and somehow it got lost in refactoring/tidying but still worked. Now fixed.
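To make the pitfall concrete, a tiny hypothetical sketch (not the PR's code; MAX_NODES_IN_CUDA_GRAPH appears in the build errors above, its value here is a placeholder):

```cpp
#include <cuda_runtime.h>

#define MAX_NODES_IN_CUDA_GRAPH 4096   // placeholder value

// Automatic locals are re-created on every call, so node/param data filled in
// during the capture call would be gone on later calls that re-use the graph.
// Declaring them static (or moving them into a per-context struct) makes the
// data persist across calls.
void graph_compute_step(bool params_already_captured) {
    static cudaGraphNode_t      nodes[MAX_NODES_IN_CUDA_GRAPH];
    static cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];

    if (!params_already_captured) {
        // first call: fill nodes[] and params[] from the captured CUDA graph
    } else {
        // later calls: read the saved params and patch only what changed
    }
    (void) nodes; (void) params;
}
```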

@agray3 (Contributor, Author) commented Apr 22, 2024

@JohannesGaessler I think the llama-bench and perplexity issues should now be fixed with the latest commit - can you confirm from your end? Perplexity is slower with CUDA graphs at the moment because it has batch size > 1 - as above, I think I should only enable CUDA graphs for batch size 1 initially.

@agray3 (Contributor, Author) commented Apr 22, 2024

It does not seems to work at P40 and I cannot get it compile on ROCM

Tried to add ROCm HIP compatibility but it error

The P40 issue may be due to the CUDA version in use; we need CUDA >= 12.0 for the functionality here. I have now added a macro such that this new code won't be compiled with earlier CUDA versions, and the original code will be used instead. Similarly, it now won't be compiled with HIP/ROCm (which can be added in a follow-up if there is adequate support and performance benefit on that platform).
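Something along these lines, as a sketch of the gating being described (the macro name USE_CUDA_GRAPH and the exact condition are assumptions, not necessarily what the PR uses; GGML_USE_HIPBLAS and CUDART_VERSION are the existing HIP/CUDA build macros):

```cpp
// Compile the CUDA graph path only with a new-enough CUDA toolkit and only when
// not building for HIP/ROCm; otherwise keep the pre-existing kernel-launch path.
#if !defined(GGML_USE_HIPBLAS) && defined(CUDART_VERSION) && CUDART_VERSION >= 12000
#define USE_CUDA_GRAPH
#endif

static void evaluate_ggml_graph(/* ... */) {
#ifdef USE_CUDA_GRAPH
    // capture / update / launch the CUDA graph
#else
    // original path: launch each kernel individually
#endif
}
```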

@sorasoras

It does not seems to work at P40 and I cannot get it compile on ROCM

Tried to add ROCm HIP compatibility but it error

The P40 issue may be due to the CUDA version in use, we need CUDA >= 12.0 for the functionality here. I have now added a macro such that this new code won't be compiled with earlier CUDA, and the original code will instead be used. Similarly it now won't be compiled with HIP/ROCm (which can be added in a follow up if there is adequate support and performance benefits on that platform).

I am using CUDA 12.4 for my P40, but this might have something to do with the hardware, right?

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 551.78                 Driver Version: 551.78         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla P40                    WDDM  |   00000000:05:00.0 Off |                  Off |
| N/A   32C    P8             11W /  250W |       0MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla P40                    WDDM  |   00000000:6F:00.0 Off |                  Off |
| N/A   26C    P8              9W /  250W |       0MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

@sorasoras commented Apr 22, 2024

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 2 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: yes
  Device 1: Tesla P40, compute capability 6.1, VMM: yes
| model                          |       size |     params | backend    | ngl | sm         | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------- | ---------------: |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | CUDA       |  99 | none       | pp 512     |    435.30 ± 0.13 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | CUDA       |  99 | none       | tg 128     |     21.68 ± 0.02 |

build: 8960fe86 (2713)

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 2 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: yes
  Device 1: Tesla P40, compute capability 6.1, VMM: yes
| model                          |       size |     params | backend    | ngl | sm         | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------- | ---------------: |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | CUDA       |  99 | none       | pp 512     |    435.63 ± 0.08 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | CUDA       |  99 | none       | tg 128     |     23.22 ± 0.05 |

build: 800f4fe (2701)

@agray3 OK, it seems to work now.
There is a decent improvement in token generation.

I encountered an error during runtime:

CUDA error: invalid argument
  current device: 0, in function ggml_backend_cuda_graph_compute at W:\git\test\agray3\llama.cpp\ggml-cuda.cu:2545
  cudaGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.params[i])
GGML_ASSERT: W:\git\test\agray3\llama.cpp\ggml-cuda.cu:60: !"CUDA error"

@agray3 (Contributor, Author) commented Apr 22, 2024

@sorasoras thanks for testing. Can you let me know the exact command for which you are seeing a failure, so I can try and reproduce? I don't have access to P40 but I have done a test on P100 that works OK.

@sorasoras

@sorasoras thanks for testing. Can you let me know the exact command for which you are seeing a failure, so I can try and reproduce? I don't have access to P40 but I have done a test on P100 that works OK.

I am running batch inference for translations.
It works for llama-bench, but the error occurs when I run it with server and do inference via the OpenAI-compatible API:

server.exe -m W:\model\Qwen1.5_Q5KS.gguf -c 2048 -ngl 99 -cb --port 8081 -sm none -mg 0 -a Qwen_1.5_Q5KS

My guess is it has some issue with continuous batching.

@ardfork (Contributor) commented Apr 22, 2024

With your changes, it now works with ROCm/HIP (with the patch linked below), but it is slower, making it likely not worth enabling on that platform. I'm using an RX 6700 XT.

| model | size | params | backend | ngl | test | t/s master | t/s PR | speedup |
| --- | ---: | ---: | --- | --: | --- | ---: | ---: | ---: |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | ROCm | 99 | pp 512 | 850.47 ± 0.18 | 850.52 ± 0.08 | 1.00 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | ROCm | 99 | tg 128 | 54.05 ± 0.02 | 52.86 ± 0.05 | 0.98 |

The patch can be found here: https://gist.github.com/ardfork/a223a10d20961707e7b5f3ee0b76c7d5 - I didn't want to bloat your PR comments with a wall of text that will be useless for most people reading it.

@JohannesGaessler (Collaborator)

Some single GPU test results:

| GPU | Model | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | ---: | ---: | ---: |
| 1x RTX 4090 | LLaMA 3 8b f16 | pp512 | 9776 | 9777 | 1.00 |
| 1x RTX 4090 | LLaMA 3 8b f16 | tg128 | 54.92 | 55.96 | 1.02 |
| 1x RTX 4090 | LLaMA 3 8b q4_0 | pp512 | 7387 | 7385 | 1.00 |
| 1x RTX 4090 | LLaMA 3 8b q4_0 | tg128 | 137.50 | 146.02 | 1.06 |
| 1x P40 | LLaMA 2 7b q4_0 | pp512 | 908 | 909 | 1.00 |
| 1x P40 | LLaMA 2 7b q4_0 | tg128 | 53.84 | 51.67 | 0.96 |
| 1x RX 6800 | LLaMA 2 7b q4_0 | pp512 | 712 | 713 | 1.00 |
| 1x RX 6800 | LLaMA 2 7b q4_0 | tg128 | 60.25 | 60.66 | 1.01 |

There is a performance regression for the P40. I was not able to run multi-GPU tests because llama-bench crashes. Error with 6x RTX 4090:

CUDA error: operation not supported
  current device: 2, in function ggml_backend_cuda_graph_compute at ggml-cuda.cu:2558
  cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, __null, __null, 0)
GGML_ASSERT: ggml-cuda.cu:60: !"CUDA error"

Error with 3x P40:

CUDA error: an illegal memory access was encountered
  current device: 2, in function ggml_backend_cuda_synchronize at ggml-cuda.cu:2403
  cudaStreamSynchronize(cuda_ctx->stream())
GGML_ASSERT: ggml-cuda.cu:60: !"CUDA error"

@sorasoras commented Apr 22, 2024

With your changes, it now works with ROCm HIP (with patch below), but it is slower, making it likely not worth enabling it on that platform. I'm using a RX 6700 XT.

| model | size | params | backend | ngl | test | t/s master | t/s PR | speedup |
| --- | ---: | ---: | --- | --: | --- | ---: | ---: | ---: |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | ROCm | 99 | pp 512 | 850.47 ± 0.18 | 850.52 ± 0.08 | 1.00 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | ROCm | 99 | tg 128 | 54.05 ± 0.02 | 52.86 ± 0.05 | 0.98 |
The patch can be found here: https://gist.github.com/ardfork/a223a10d20961707e7b5f3ee0b76c7d5, didn't want to bloat your PR comments with a wall of text that will be useless for most people reading it.

It works decently on RDNA3, though. I forgot to record the values, but it's about 2% faster for token generation.

@agray3 (Contributor, Author) commented Apr 22, 2024

OK, thanks. I've now disabled CUDA graphs for multi-GPU and batch size > 1, which should prevent these crashes and regressions (I can investigate those cases later). I can also disable it for Pascal; I'll have a look at that tomorrow (also assessing Volta etc.).

@sorasoras

[image]
That was funny :)

@agray3 (Contributor, Author) commented Apr 23, 2024

I've reproduced the llama-bench regression on Pascal (CC 6) and Volta (CC 7), so I've now added code to disable CUDA graphs for CC < 8. I've also added an environment variable: setting export LLAMACPP_DISABLE_CUDA_GRAPHS=1 will optionally disable them on any GPU.
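Roughly the kind of runtime check being described, as a sketch (the helper name and structure are assumptions, not the PR's exact code; the environment variable name is the one mentioned above):

```cpp
#include <cstdlib>
#include <cuda_runtime.h>

// Use CUDA graphs only on compute capability >= 8.0 (regressions were seen on
// Pascal/Volta) and only if the user has not disabled them via the environment.
static bool cuda_graphs_enabled(int device) {
    if (std::getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr) {
        return false;
    }
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major >= 8;
}
```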

@jdecourval (Contributor)

With your changes, it now works with ROCm HIP (with patch below), but it is slower, making it likely not worth enabling it on that platform. I'm using a RX 6700 XT.
| model | size | params | backend | ngl | test | t/s master | t/s PR | speedup |
| --- | ---: | ---: | --- | --: | --- | ---: | ---: | ---: |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | ROCm | 99 | pp 512 | 850.47 ± 0.18 | 850.52 ± 0.08 | 1.00 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | ROCm | 99 | tg 128 | 54.05 ± 0.02 | 52.86 ± 0.05 | 0.98 |
The patch can be found here: https://gist.github.com/ardfork/a223a10d20961707e7b5f3ee0b76c7d5, didn't want to bloat your PR comments with a wall of text that will be useless for most people reading it.

It works decently on RDNA3, though. I forgot to record the values, but it's about 2% faster for token generation.

Here are the results on my 7900 XTX:

| model | size | params | backend | ngl | test | t/s before | t/s after |
| --- | ---: | ---: | --- | --: | --- | ---: | ---: |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | pp 512 | 737.72 ± 1.57 | 738.31 ± 1.48 |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | tg 128 | 29.00 ± 0.05 | 28.94 ± 0.02 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | pp 512 | 3219.19 ± 8.14 | 3218.28 ± 8.52 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | tg 128 | 87.13 ± 0.06 | 87.37 ± 0.06 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | pp 512 | 1165.00 ± 4.31 | 1165.52 ± 4.11 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | tg 128 | 55.08 ± 0.04 | 55.39 ± 0.04 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | pp 512 | 3273.69 ± 10.66 | 3268.28 ± 10.24 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | tg 128 | 87.24 ± 0.07 | 87.51 ± 0.05 |

Over those 4 models, your PR yields a significant, but very small average TG speedup of 0.2%.

@JohannesGaessler (Collaborator)

I just noticed that instead of using GitHub's built-in draft feature you added "DRAFT:" to the title. Please let me know when you think the PR is ready for review; I'll take a look then.

@agray3 (Contributor, Author) commented May 8, 2024

I was wondering if it would be possible to use cudaStreamBeginCaptureToGraph to keep track of which nodes in the ggml graph correspond to which nodes in the cuda graph. In other words, is it possible to query the nodes of the cuda graph as they are captured, or only after cudaStreamEndCapture? If so, I think this could allow us to implement a more reliable update mechanism in the future.

Unfortunately the graph is not exposed in the app until cudaStreamEndCapture(). You could possibly do something like this manually by introducing a wrapper around all kernel launches, plus some mechanism in that wrapper to keep track of things.

@slaren (Collaborator) commented May 8, 2024

Unfortunately the graph is not exposed in the app until cudaStreamEndCapture(). You could possibly do something like this manually by introducing a wrapper to all kernel launches, plus some mechanism in that wrapper to keep track of things.

If I understand cudaStreamBeginCaptureToGraph correctly, I imagine it would be possible to call cudaStreamEndCapture after every op, take note of the current number of nodes, and then restart the capture using cudaStreamBeginCaptureToGraph. Probably wouldn't be very efficient, though. It's not very clear from the documentation if cudaStreamBeginCaptureToGraph clears the graph, or requires an empty graph, or appends to it.
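A sketch of that idea, assuming CUDA >= 12.3 where cudaStreamBeginCaptureToGraph is available; whether re-beginning the capture appends to the existing graph is exactly the open question above, so treat this as illustrative only (the op-enqueue callback is a placeholder):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Per-op node counting: pause the capture after each op, query how many nodes
// the graph now contains, then resume capturing into the same graph.
void capture_with_node_counts(cudaStream_t stream, cudaGraph_t graph,
                              void (*enqueue_op[])(cudaStream_t), int n_ops) {
    size_t prev_nodes = 0;
    cudaStreamBeginCaptureToGraph(stream, graph, nullptr, nullptr, 0, cudaStreamCaptureModeRelaxed);
    for (int i = 0; i < n_ops; ++i) {
        enqueue_op[i](stream);                        // enqueue this ggml op's kernel(s)
        cudaStreamEndCapture(stream, &graph);         // pause capture so the graph can be inspected
        size_t n_nodes = 0;
        cudaGraphGetNodes(graph, nullptr, &n_nodes);  // query the node count only
        printf("op %d added %zu node(s)\n", i, (size_t)(n_nodes - prev_nodes));
        prev_nodes = n_nodes;
        cudaStreamBeginCaptureToGraph(stream, graph, nullptr, nullptr, 0, cudaStreamCaptureModeRelaxed);
    }
    cudaStreamEndCapture(stream, &graph);
}
```

As noted above, the repeated end/begin round trips would likely add noticeable CPU overhead for graphs with many kernels.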

@JohannesGaessler (Collaborator) left a review comment

Based on static code analysis I only have minor comments, like i++ vs. ++i, and the things I otherwise mentioned. Definitely nothing that I would consider worth blocking a merge for; let me also check the performance and correctness.

assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
    if (node->src[j] != nullptr) {
        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
Suggested change (Collaborator):

-        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+        GGML_ASSERT(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));

Reply (Collaborator):

These are correctness checks for ggml-backend that are only meant to be enabled in debug builds (that code was already there before). In normal circumstances this should never happen, so there is no need to check it in release builds.

}

#ifndef NDEBUG
    assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));

Suggested change (Collaborator):

-    assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+    GGML_ASSERT(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));

ggml-cuda.cu (outdated), comment on lines 2625 to 2635:

#if 0
    if (disable_cuda_graphs_due_to_failed_capture) {
        use_cuda_graph = false;
        cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
#ifndef NDEBUG
        fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
#endif
    } else {
        graph_evaluated_or_captured = true; // CUDA graph has been captured
    }
#endif
Comment (Collaborator):
What is the purpose of this code block?

@slaren (Collaborator) replied on May 8, 2024:
This code used to check for graph capture failures, but I removed that since the implementation had issues. I think that there shouldn't be any capture failures if the ggml graph is checked correctly for incompatible ops early on. Also it uses relaxed capture mode now, which allows ops such as allocations to succeed even if they cannot be captured into the graph.
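For context, a generic illustration of the capture-mode difference mentioned here (not the PR's code; the function and buffer names are placeholders):

```cpp
#include <cuda_runtime.h>

// Under cudaStreamCaptureModeGlobal, "potentially unsafe" calls made while the
// capture is active (e.g. cudaMalloc) fail, much like the hipMalloc
// "operation not permitted when stream is capturing" error reported earlier in
// this thread. Under cudaStreamCaptureModeRelaxed such calls are allowed to
// execute normally; they simply are not recorded into the graph.
static cudaGraph_t capture_relaxed(cudaStream_t stream) {
    cudaGraph_t graph = nullptr;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
    void * scratch = nullptr;
    cudaMalloc(&scratch, 1 << 20);   // would have broken a Global-mode capture
    // ... enqueue the kernels to be captured on `stream` here ...
    cudaStreamEndCapture(stream, &graph);
    cudaFree(scratch);
    return graph;
}
```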

@agray3 (Contributor, Author) commented May 8, 2024

If I understand cudaStreamBeginCaptureToGraph correctly, I imagine it would be possible to call cudaStreamEndCapture after every op, take note of the current number of nodes, and then restart the capture using cudaStreamBeginCaptureToGraph. Probably wouldn't be very efficient, though. It's not very clear from the documentation if cudaStreamBeginCaptureToGraph clears the graph, or requires an empty graph, or appends to it.

Yes, you are right, that should work - I've not tried it but it should append. But as you say I think it would have quite substantial overhead because there are a lot of kernels in each graph.

@JohannesGaessler (Collaborator) left a review comment

| GPU | Model | Model Size [GiB] | Test | t/s master | t/s f42312e | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| RTX 4090 | llama 8B F16 | 14.96 | tg128 | 56.68 | 57.85 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 4.33 | tg128 | 149.28 | 159.67 | 1.07 |

According to my testing the results are bit-for-bit identical; the evaluation is simply faster. It's unfortunate that it currently only works for batch size 1, since that means it won't be compatible with >1 server slots or speculative decoding such as with #6828 (I'm not sure whether the constant switching between batch sizes causes issues, but I consider that my own problem to fix).

@ggerganov (Owner)

I don't have Ampere hardware handy, but the PR builds and works on RTX 2060 as usual

@slaren (Collaborator) commented May 8, 2024

I have been testing the VRAM usage of the CUDA graphs, and for me it is ~10MB for 7B, ~12MB for 13B, and ~18MB for 30B. So I think it is low enough that it is unlikely to cause any regressions, and can be left enabled by default.

@slaren merged commit bc4bba3 into ggerganov:master on May 8, 2024
64 checks passed
@fat-tire (Contributor) commented May 9, 2024

I don't have Ampere hardware handy, but the PR builds and works on RTX 2060 as usual

I'm guessing 30x0 is good too (also Ampere)...

@JohannesGaessler (Collaborator)

Did anyone check server performance before and after this PR? I am seeing no difference in terms of user request throughput on 1x RTX 4090 even with a single server slot: #6828 (comment)

@mofosyne added the "Review Complexity: High" (Generally requires in-depth knowledge of LLMs or GPUs) label on May 9, 2024
@slaren (Collaborator) commented May 9, 2024

If it is being disabled for some reason, you can find out with a debug build.

@JohannesGaessler (Collaborator) commented May 9, 2024

According to the log with debugging enabled this seems to be the reason:

ggml_backend_cuda_graph_compute: disabling CUDA graphs due to batch size > 1 [ffn_inp-0] [4096 256 1 1]
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to batch size > 1 [ffn_inp-0] [4096 256 1 1]
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to batch size > 1 [ffn_inp-0] [4096 256 1 1]
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to batch size > 1 [ffn_inp-0] [4096 256 1 1]
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to too many consecutive updates

I think the problem is that the graph is initially NULL, so this sets cuda_graph_update_required = true, and because the microbatches are > 1 this also doesn't set the graph to non-NULL. As a consequence, a sufficiently long prompt triggers disable_due_to_too_many_updates = true and the graphs are permanently disabled.
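Sketching that diagnosis as hypothetical code (the field names follow the log messages above; the threshold and surrounding control flow are assumptions, not the PR's exact implementation):

```cpp
// The counter keeps advancing on ubatches where graphs are disabled for batch
// size > 1, so a long prompt eventually sets disable_due_to_too_many_updates
// and CUDA graphs stay off for the rest of the run.
struct cuda_graph_state {
    bool instantiated                    = false;  // no graph captured yet
    int  number_consecutive_updates      = 0;
    bool disable_due_to_too_many_updates = false;
};

void on_graph_eval(cuda_graph_state & g, int batch_size) {
    const bool use_cuda_graph             = batch_size == 1 && !g.disable_due_to_too_many_updates;
    const bool cuda_graph_update_required = !g.instantiated;   // initially true

    if (cuda_graph_update_required) {
        g.number_consecutive_updates++;            // advances even when use_cuda_graph is false
    } else {
        g.number_consecutive_updates = 0;
    }
    if (g.number_consecutive_updates >= 4 /* placeholder threshold */) {
        g.disable_due_to_too_many_updates = true;
    }

    if (use_cuda_graph) {
        // capture or update the CUDA graph, then launch it
        g.instantiated = true;
    } else {
        // fall back to launching kernels individually; no graph gets instantiated
    }
}
```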

@slaren (Collaborator) commented May 9, 2024

The logic to determine when to disable graphs is not great, but it is necessary to avoid degrading performance when CUDA graphs cannot be used. For example, with pipeline parallelism tensor addresses change with every eval, so it must be disabled. But maybe prompt evals shouldn't increase the counter.

@agray3 (Contributor, Author) commented May 10, 2024

@JohannesGaessler can you let me know how to reproduce this? I can then try to relax these checks appropriately. Thanks.

@JohannesGaessler (Collaborator)

The easiest way is to just run main with a prompt that is much longer than --ubatch-size and with --ubatch-size > 1.

@agray3 (Contributor, Author) commented May 10, 2024

Thanks. Yeah I confirm it works if I comment out the line
cuda_ctx->cuda_graph->number_consecutive_updates++;
I just need to work out the best way to stop this counter incrementing due to long prompts - will take a look next week.

@agray3 (Contributor, Author) commented May 14, 2024

Thanks. Yeah I confirm it works if I comment out the line cuda_ctx->cuda_graph->number_consecutive_updates++; I just need to work out the best way to stop this counter incrementing due to long prompts - will take a look next week.

From playing around with the code, it seems that, in main.cpp, we know the prompt is still evaluating when n_past==0, so we could use that information to disable CUDA graphs while the prompt is still evaluating. However, the question is how to propagate that info down to the ggml-cuda module, since there are several layers in between. Firstly, @JohannesGaessler @slaren, do you know if there is another way to assess whether the prompt is evaluating at a deeper level? If not, then I think it will require a series of API calls through the layers to set/unset a flag in ggml-cuda.

@slaren (Collaborator) commented May 14, 2024

I think you can use the already existing code that checks for batch size > 1 to do this. I changed your implementation to check for an add operation instead of a softmax, but the idea is the same. Did you find it unreliable?

llama.cpp knows it is a prompt (or batch) evaluation when the call to llama_decode has batch.n_tokens > 1. The way I would like to do this in the long run is by implementing the graph_plan functions of the ggml-backend interface using CUDA graphs, and have llama.cpp use a graph plan for single-token batches when possible (we would also need to add a graph_plan_update function; there isn't one currently).
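For illustration, a sketch of what such a check on the graph itself could look like (the helper name and the specific op inspected are assumptions based on the debug log earlier in the thread, e.g. "[ffn_inp-0] [4096 256 1 1]", where ne[0] is the embedding width and ne[1] the number of tokens in the micro-batch; the PR's exact check may differ):

```cpp
#include "ggml.h"

// Returns true if any element-wise add in the graph operates on more than one
// token, i.e. the micro-batch size is > 1. Illustrative only.
static bool ggml_graph_has_batch_size_gt_1(const struct ggml_cgraph * cgraph) {
    for (int i = 0; i < cgraph->n_nodes; i++) {
        const struct ggml_tensor * node = cgraph->nodes[i];
        if (node->op == GGML_OP_ADD && node->ne[1] > 1) {
            return true;   // more than one token in this micro-batch
        }
    }
    return false;
}
```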

agray3 added a commit to agray3/llama.cpp that referenced this pull request May 15, 2024
As discussed in PR ggerganov#6766, CUDA graphs were being disabled in the presence of long prompts.
This fixes the issue by preventing the consecutive-update counter from incrementing unnecessarily
for tokens in which CUDA graphs are disabled due to batch size > 1.
@agray3 (Contributor, Author) commented May 15, 2024

Thanks, I found the issue - the counter is being unnecessarily incremented even for tokens where graphs are disabled. See the simple fix at #7302

slaren pushed a commit that referenced this pull request May 15, 2024
As discussed in PR #6766, CUDA graphs were being disabled in the presence of long prompts.
This fixes the issue by preventing the consecutive-update counter from incrementing unnecessarily
for tokens in which CUDA graphs are disabled due to batch size > 1.
teleprint-me pushed a commit to teleprint-me/llama.cpp that referenced this pull request May 17, 2024
As discussed in PR ggerganov#6766, CUDA graphs were being disabled in the presence of long prompts.
This fixes the issue by preventing the consecutive-update counter from incrementing unnecessarily
for tokens in which CUDA graphs are disabled due to batch size > 1.