diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index e172ec5c2a69e..0a0d0d4dc10fd 100644 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -8,6 +8,7 @@ - [DataType Supports](#datatype-supports) - [Docker](#docker) - [Linux](#linux) + - [Environment variable setup](#environment-variable-setup) - [TODO](#todo) @@ -281,5 +282,24 @@ cmake --build build --config release Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay. +## Environment variable setup + +### GGML_CANN_ASYNC_MODE + +Enables asynchronous operator submission. Disabled by default. + +### GGML_CANN_MEM_POOL + +Specifies the memory pool management strategy: + +- vmm: Utilizes a virtual memory manager pool. If hardware support for VMM is unavailable, falls back to the legacy (leg) memory pool. + +- prio: Employs a priority queue-based memory pool management. +- leg: Uses a fixed-size buffer pool. + +### GGML_CANN_DISABLE_BUF_POOL_CLEAN + +Controls automatic cleanup of the memory pool. This option is only effective when using the prio or leg memory pool strategies. + ## TODO - Support more models and data types. diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 7ef80a4793314..ba2cef0c25fb2 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -37,6 +37,7 @@ #include #include #include +#include #include "../include/ggml-cann.h" #include "../include/ggml.h" @@ -103,6 +104,9 @@ const ggml_cann_device_info& ggml_cann_info(); void ggml_cann_set_device(int32_t device); int32_t ggml_cann_get_device(); +std::optional get_env(const std::string& name); +bool parse_bool(const std::string& value); + /** * @brief Abstract base class for memory pools used by CANN. */ @@ -354,7 +358,8 @@ struct ggml_backend_cann_context { : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) { ggml_cann_set_device(device); description = aclrtGetSocName(); - async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr); + + bool async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or("")); GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF"); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 605b6a73c3a13..360d3ae85f775 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -31,6 +31,8 @@ #include #include #include +#include +#include #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -92,6 +94,26 @@ int32_t ggml_cann_get_device() { return id; } +/** + * @brief Get the value of the specified environment variable (name). + * if not empty, return a std::string object + */ +std::optional get_env(const std::string& name) { + const char* val = std::getenv(name.c_str()); + if (!val) return std::nullopt; + std::string res = std::string(val); + std::transform(res.begin(), res.end(), res.begin(), ::tolower); + return res; +} + +/** + * @brief Verify whether the environment variable is a valid value. + */ +bool parse_bool(const std::string& value) { + std::unordered_set valid_values = {"on", "1", "yes", "y", "enable", "true"}; + return valid_values.find(value) != valid_values.end(); +} + /** * @brief Initialize the CANN device information. * @@ -213,7 +235,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_buf_prio(int device) : device(device) { - disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); } /** @@ -409,7 +431,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_buf(int device) : device(device) { - disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); } /** @@ -730,16 +752,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { */ std::unique_ptr ggml_backend_cann_context::new_pool_for_device( int device) { - bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr); - if (!disable_vmm && ggml_cann_info().devices[device].vmm) { - GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); - return std::unique_ptr(new ggml_cann_pool_vmm(device)); - } - bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr); - if (enable_buf_prio) { + std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); + + if (mem_pool_type == "prio") { GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); return std::unique_ptr(new ggml_cann_pool_buf_prio(device)); } + + if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") { + GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_vmm(device)); + } + GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device); return std::unique_ptr(new ggml_cann_pool_buf(device)); }