From ba0533100d562b48ed2ae600d1368be114ef3414 Mon Sep 17 00:00:00 2001 From: akieslinger Date: Thu, 19 Dec 2024 10:40:19 +0100 Subject: [PATCH 01/11] Refactor: Moves cuda graph executable update step to separate function. --- ggml/src/ggml-cuda/ggml-cuda.cu | 39 ++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0b06be729864e..715606f324746 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2337,6 +2337,28 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } #endif + +#ifdef USE_CUDA_GRAPH +static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { + + cudaGraphExecUpdateResultInfo result_info; + cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + if (stat == cudaErrorGraphExecUpdateFailure) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); +#endif + // The pre-existing graph exec cannot be updated due to violated constraints + // so instead clear error and re-instantiate + cudaGetLastError(); + CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); + cuda_ctx->cuda_graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + } else { + GGML_ASSERT(stat == cudaSuccess); + } +} +#endif + static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -2585,21 +2607,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, } // Update graph executable - cudaGraphExecUpdateResultInfo result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); - if (stat == cudaErrorGraphExecUpdateFailure) { -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); -#endif - // The pre-existing graph exec cannot be updated due to violated constraints - // so instead clear error and re-instantiate - cudaGetLastError(); - CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); - cuda_ctx->cuda_graph->instance = nullptr; - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); - } else { - GGML_ASSERT(stat == cudaSuccess); - } + update_cuda_graph_executable(cuda_ctx); + // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); #else From 22c24294964dced1671a5308b1076eadc054d4c9 Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Wed, 8 Jan 2025 13:18:18 +0000 Subject: [PATCH 02/11] Refactor: Moves cuda graph update check to separate function. --- ggml/src/ggml-cuda/ggml-cuda.cu | 53 +++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 715606f324746..40a0392749cca 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2337,6 +2337,36 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } #endif +#ifdef USE_CUDA_GRAPH +static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) { + + if (cuda_ctx->cuda_graph->instance == nullptr) { + cuda_graph_update_required = true; + } + + // Check if the graph size has changed + if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + cuda_graph_update_required = true; + cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + } + + // Loop over nodes in GGML graph to determine if CUDA graph update is required + // and store properties to allow this comparison for the next token + for (int i = 0; i < cgraph->n_nodes; i++) { + bool has_matching_properties = true; + if (!cuda_graph_update_required) { + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + if (!has_matching_properties) { + cuda_graph_update_required = true; + } + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + + return cuda_graph_update_required; +} +#endif + #ifdef USE_CUDA_GRAPH static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { @@ -2398,28 +2428,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, } if (use_cuda_graph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { - cuda_graph_update_required = true; - } - - // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { - cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); - } - - // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token - for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = true; - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); - } - if (!has_matching_properties) { - cuda_graph_update_required = true; - } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); - } + cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required); // Loop over nodes in GGML graph to obtain info needed for CUDA graph cuda_ctx->cuda_graph->updated_kernel_arg.clear(); From eb3ea69850d5e2838f5c2eda4feb11bd43bdaef1 Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Wed, 8 Jan 2025 13:47:14 +0000 Subject: [PATCH 03/11] Refactor: Moves cuda graph maintenance (update or adjusting copy parameters) to separate function for improved readability. --- ggml/src/ggml-cuda/ggml-cuda.cu | 93 ++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 43 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 40a0392749cca..de5709f060493 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2337,6 +2337,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } #endif + +#ifdef USE_CUDA_GRAPH +static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) { + + if (cuda_graph_update_required) { + // Extract nodes from graph + // First call with null argument gets number of nodes in graph + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); + // Subsequent call with non-null argument gets nodes + cuda_ctx->cuda_graph->nodes.clear(); + cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); + cuda_ctx->cuda_graph->params.clear(); + cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); + if (cuda_ctx->cuda_graph->num_nodes > 0) { + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); + + // Loop over nodes, and extract kernel parameters from each node + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + cudaGraphNodeType node_type; + CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); + if (node_type == cudaGraphNodeTypeKernel) { + cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime + if (stat == cudaErrorInvalidDeviceFunction) { + // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. + // We don't need to update blas nodes, so clear error and move on. + cudaGetLastError(); + } else { + GGML_ASSERT(stat == cudaSuccess); + } + } + } + } + } else { + // One of the arguments to the copy kernel is updated for each token, hence we need to + // replace that argument with the updated value in the CUDA graph + // on update steps, the live parameters will already be captured + int k = 0; + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { + char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); + cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; + CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); + } + } + } +} +#endif + + #ifdef USE_CUDA_GRAPH static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) { @@ -2571,49 +2620,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, } // Perform update to graph (if required for this token), and change copy parameter (required for every token) - - if (cuda_graph_update_required) { - // Extract nodes from graph - // First call with null argument gets number of nodes in graph - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); - // Subsequent call with non-null argument gets nodes - cuda_ctx->cuda_graph->nodes.clear(); - cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); - cuda_ctx->cuda_graph->params.clear(); - cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); - if (cuda_ctx->cuda_graph->num_nodes > 0) { - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); - - // Loop over nodes, and extract kernel parameters from each node - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - cudaGraphNodeType node_type; - CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); - if (node_type == cudaGraphNodeTypeKernel) { - cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime - if (stat == cudaErrorInvalidDeviceFunction) { - // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. - // We don't need to update blas nodes, so clear error and move on. - cudaGetLastError(); - } else { - GGML_ASSERT(stat == cudaSuccess); - } - } - } - } - } - - // One of the arguments to the copy kernel is updated for each token, hence we need to - // replace that argument with the updated value in the CUDA graph - if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured - int k = 0; - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { - char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); - cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; - CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); - } - } - } + maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required); // Update graph executable update_cuda_graph_executable(cuda_ctx); From ed10ff58a6c317eb29bb7f266e4c0cbb5ba3a1ec Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Wed, 8 Jan 2025 14:06:07 +0000 Subject: [PATCH 04/11] Fix: Adds missing reference to maintain_cuda_graph() definition. --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index de5709f060493..e2beceb6f7a8c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2339,7 +2339,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra #ifdef USE_CUDA_GRAPH -static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) { +static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) { if (cuda_graph_update_required) { // Extract nodes from graph From 37518b7dda860ba4b35bc9eba8f476fb820dd0d4 Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Thu, 9 Jan 2025 09:15:12 +0000 Subject: [PATCH 05/11] Refactor: Improves structure and abstractions by moving CUDA graph evaluation and capture to its own function. --- ggml/src/ggml-cuda/ggml-cuda.cu | 161 +++++++++++++++++--------------- 1 file changed, 85 insertions(+), 76 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e2beceb6f7a8c..d7be9562dddf0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2438,11 +2438,95 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { } #endif + +static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + [[maybe_unused]] std::vector & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph, + bool & cuda_graph_update_required) { + + while (!graph_evaluated_or_captured) { + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. + // With the use of CUDA graphs, the execution will be performed by the graph launch. + if (!use_cuda_graph || cuda_graph_update_required) { + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + +#ifndef NDEBUG + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->buffer); + assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || + ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft)); + } + } +#endif + + bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); + if (!ok) { + GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + } + } + +#ifdef USE_CUDA_GRAPH + if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture + if (cuda_ctx->cuda_graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); + cuda_ctx->cuda_graph->graph = nullptr; + } + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + +#if 0 + if (disable_cuda_graphs_due_to_failed_capture) { + use_cuda_graph = false; + cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true; +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__); +#endif + } else { + graph_evaluated_or_captured = true; // CUDA graph has been captured + } +#endif + graph_evaluated_or_captured = true; // CUDA graph has been captured + } else { + graph_evaluated_or_captured = true; // ggml graph has been directly evaluated + } + } + + if (use_cuda_graph) { + if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + } + + // Perform update to graph (if required for this token), and change copy parameter (required for every token) + maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required); + + // Update graph executable + update_cuda_graph_executable(cuda_ctx); + + // Launch graph + CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); +#else + graph_evaluated_or_captured = true; +#endif // USE_CUDA_GRAPH + } +} + + static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_cuda_set_device(cuda_ctx->device); + // vector of pointers to CUDA cpy kernels, which are required to identify + // kernel parameters which need updated in the graph for each token + std::vector ggml_cuda_cpy_fn_ptrs; + #ifdef USE_CUDA_GRAPH static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); @@ -2453,9 +2537,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool use_cuda_graph = true; bool cuda_graph_update_required = false; - // vector of pointers to CUDA cpy kernels, which are required to identify - // kernel parameters which need updated in the graph for each token - std::vector ggml_cuda_cpy_fn_ptrs; if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { @@ -2559,79 +2640,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool graph_evaluated_or_captured = false; - while (!graph_evaluated_or_captured) { - // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. - // With the use of CUDA graphs, the execution will be performed by the graph launch. - if (!use_cuda_graph || cuda_graph_update_required) { - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; - } - -#ifndef NDEBUG - assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); - for (int j = 0; j < GGML_MAX_SRC; j++) { - if (node->src[j] != nullptr) { - assert(node->src[j]->buffer); - assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || - ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft)); - } - } -#endif - - bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); - if (!ok) { - GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); - } - GGML_ASSERT(ok); - } - } - -#ifdef USE_CUDA_GRAPH - if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture - if (cuda_ctx->cuda_graph->graph != nullptr) { - CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); - cuda_ctx->cuda_graph->graph = nullptr; - } - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); - -#if 0 - if (disable_cuda_graphs_due_to_failed_capture) { - use_cuda_graph = false; - cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__); -#endif - } else { - graph_evaluated_or_captured = true; // CUDA graph has been captured - } -#endif - graph_evaluated_or_captured = true; // CUDA graph has been captured - } else { - graph_evaluated_or_captured = true; // ggml graph has been directly evaluated - } - } - - if (use_cuda_graph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); - } - - // Perform update to graph (if required for this token), and change copy parameter (required for every token) - maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required); - - // Update graph executable - update_cuda_graph_executable(cuda_ctx); - - // Launch graph - CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); -#else - graph_evaluated_or_captured = true; -#endif // USE_CUDA_GRAPH - } - + evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); return GGML_STATUS_SUCCESS; } From 0cdc1339198d36dfe318d8e9a77106fd7349306c Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Thu, 9 Jan 2025 12:16:37 +0000 Subject: [PATCH 06/11] Refactor: Moves node graph checks and copy ops into individual function for improved readability. --- ggml/src/ggml-cuda/ggml-cuda.cu | 119 ++++++++++++++++++-------------- 1 file changed, 66 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index d7be9562dddf0..0e570e993fa44 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2284,6 +2284,70 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } + +#ifdef USE_CUDA_GRAPH +static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + std::vector & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) { + + // Loop over nodes in GGML graph to obtain info needed for CUDA graph + cuda_ctx->cuda_graph->updated_kernel_arg.clear(); + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) { + use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__); +#endif + } + + if (node->op == GGML_OP_MUL_MAT_ID) { + use_cuda_graph = false; // This node type is not supported by CUDA graph capture +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); +#endif + } + + if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) { + // disable CUDA graphs for batch size > 1 for now. + // Changes in batch size or context size can cause changes to the grid size of some kernels. + use_cuda_graph = false; +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); +#endif + } + + if (node->op == GGML_OP_CPY) { + // store the copy op parameter which changes with each token. + cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); + // store a pointer to each copy op CUDA kernel to identify it later + void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + if (!ptr) { + use_cuda_graph = false; +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); +#endif + } else { + if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { + ggml_cuda_cpy_fn_ptrs.push_back(ptr); + } + } + } + + if (!use_cuda_graph) { + break; + } + } + + return use_cuda_graph; +} +#endif + + #ifdef USE_CUDA_GRAPH static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { graph_node_properties->node_address = node->data; @@ -2560,59 +2624,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, if (use_cuda_graph) { cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required); - // Loop over nodes in GGML graph to obtain info needed for CUDA graph - cuda_ctx->cuda_graph->updated_kernel_arg.clear(); - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; - } - - if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) { - use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__); -#endif - } - - if (node->op == GGML_OP_MUL_MAT_ID) { - use_cuda_graph = false; // This node type is not supported by CUDA graph capture -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); -#endif - } - - if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) { - // disable CUDA graphs for batch size > 1 for now. - // Changes in batch size or context size can cause changes to the grid size of some kernels. - use_cuda_graph = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); -#endif - } - - if (node->op == GGML_OP_CPY) { - // store the copy op parameter which changes with each token. - cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); - // store a pointer to each copy op CUDA kernel to identify it later - void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); - if (!ptr) { - use_cuda_graph = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); -#endif - } else { - if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { - ggml_cuda_cpy_fn_ptrs.push_back(ptr); - } - } - } - - if (!use_cuda_graph) { - break; - } - } + use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, + ggml_cuda_cpy_fn_ptrs, use_cuda_graph); // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. if (use_cuda_graph && cuda_graph_update_required) { From dd95edfcfb630851172e0019c54c7d7f229f1cb3 Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Thu, 9 Jan 2025 12:44:14 +0000 Subject: [PATCH 07/11] Refactor: Removes code permanently excluded from compilation to increase readability. --- ggml/src/ggml-cuda/ggml-cuda.cu | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0e570e993fa44..3e76e5b77e932 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2543,19 +2543,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); cuda_ctx->cuda_graph->graph = nullptr; } - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); -#if 0 - if (disable_cuda_graphs_due_to_failed_capture) { - use_cuda_graph = false; - cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__); -#endif - } else { - graph_evaluated_or_captured = true; // CUDA graph has been captured - } -#endif + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); graph_evaluated_or_captured = true; // CUDA graph has been captured } else { graph_evaluated_or_captured = true; // ggml graph has been directly evaluated From 98d4e55f5ac784ff54b024d7f894525745e74b30 Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Thu, 9 Jan 2025 13:09:38 +0000 Subject: [PATCH 08/11] Style: Adds missing newline --- ggml/src/ggml-cuda/ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3e76e5b77e932..94d26c1ec6e60 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2643,6 +2643,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool graph_evaluated_or_captured = false; evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + return GGML_STATUS_SUCCESS; } From fcd62d9de6f7c61c17d579a35add9a4ef427fd5f Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Mon, 13 Jan 2025 09:19:07 +0000 Subject: [PATCH 09/11] Style: Consolidates several neighboring '#ifdef USE_CUDA_GRAPH' into a single one --- ggml/src/ggml-cuda/ggml-cuda.cu | 8 -------- 1 file changed, 8 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 94d26c1ec6e60..44f4f6b6a9a0e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2345,10 +2345,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud return use_cuda_graph; } -#endif -#ifdef USE_CUDA_GRAPH static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { graph_node_properties->node_address = node->data; graph_node_properties->node_op = node->op; @@ -2399,10 +2397,8 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return true; } -#endif -#ifdef USE_CUDA_GRAPH static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) { if (cuda_graph_update_required) { @@ -2447,10 +2443,8 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto } } } -#endif -#ifdef USE_CUDA_GRAPH static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) { if (cuda_ctx->cuda_graph->instance == nullptr) { @@ -2478,10 +2472,8 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, return cuda_graph_update_required; } -#endif -#ifdef USE_CUDA_GRAPH static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { cudaGraphExecUpdateResultInfo result_info; From 62f2f62453b889d993bc480af19f06a7b681a892 Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Mon, 13 Jan 2025 09:35:36 +0000 Subject: [PATCH 10/11] Refactor: Makes 'cuda_graph_update_required' a local variable --- ggml/src/ggml-cuda/ggml-cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 44f4f6b6a9a0e..4a7c210bfab39 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2445,7 +2445,9 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto } -static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) { +static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { + + bool cuda_graph_update_required = false; if (cuda_ctx->cuda_graph->instance == nullptr) { cuda_graph_update_required = true; @@ -2603,7 +2605,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, } if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required); + cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, use_cuda_graph); From 5226732fc4e02df883d67759e3b1d718e692c6ac Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 13 Jan 2025 16:44:02 +0100 Subject: [PATCH 11/11] remove double lines between functions --- ggml/src/ggml-cuda/ggml-cuda.cu | 7 ------- 1 file changed, 7 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4a7c210bfab39..a7f6712da3c79 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2284,7 +2284,6 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } - #ifdef USE_CUDA_GRAPH static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, std::vector & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) { @@ -2346,7 +2345,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud return use_cuda_graph; } - static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { graph_node_properties->node_address = node->data; graph_node_properties->node_op = node->op; @@ -2398,7 +2396,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return true; } - static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) { if (cuda_graph_update_required) { @@ -2444,7 +2441,6 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto } } - static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { bool cuda_graph_update_required = false; @@ -2475,7 +2471,6 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, return cuda_graph_update_required; } - static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { cudaGraphExecUpdateResultInfo result_info; @@ -2496,7 +2491,6 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { } #endif - static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, [[maybe_unused]] std::vector & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { @@ -2564,7 +2558,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } - static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;