diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml new file mode 100644 index 0000000000000..d4176fef9cee2 --- /dev/null +++ b/.github/workflows/build-linux-cross.yml @@ -0,0 +1,121 @@ +name: Build on Linux using cross-compiler +on: + workflow_dispatch: + workflow_call: + +jobs: + ubuntu-latest-riscv64-cpu-cross: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Setup Riscv + run: | + sudo dpkg --add-architecture riscv64 + sudo sed -i 's|http://azure.archive.ubuntu.com/ubuntu|http://ports.ubuntu.com/ubuntu-ports|g' \ + /etc/apt/sources.list /etc/apt/apt-mirrors.txt + sudo apt-get clean + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + build-essential \ + gcc-14-riscv64-linux-gnu \ + g++-14-riscv64-linux-gnu + + - name: Build + run: | + cmake -B build -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + ubuntu-latest-riscv64-vulkan-cross: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Riscv + run: | + sudo dpkg --add-architecture riscv64 + sudo sed -i 's|http://azure.archive.ubuntu.com/ubuntu|http://ports.ubuntu.com/ubuntu-ports|g' \ + /etc/apt/sources.list /etc/apt/apt-mirrors.txt + sudo apt-get clean + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + build-essential \ + glslc \ + gcc-14-riscv64-linux-gnu \ + g++-14-riscv64-linux-gnu \ + libvulkan-dev:riscv64 + + - name: Build + run: | + cmake -B build -DCMAKE_BUILD_TYPE=Release \ + -DGGML_VULKAN=ON \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + ubuntu-latest-arm64-vulkan-cross: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Arm64 + run: | + sudo dpkg --add-architecture arm64 + sudo sed -i 's|http://azure.archive.ubuntu.com/ubuntu|http://ports.ubuntu.com/ubuntu-ports|g' \ + /etc/apt/sources.list /etc/apt/apt-mirrors.txt + sudo apt-get clean + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + build-essential \ + glslc \ + crossbuild-essential-arm64 \ + libvulkan-dev:arm64 + + - name: Build + run: | + cmake -B build -DCMAKE_BUILD_TYPE=Release \ + -DGGML_VULKAN=ON \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=aarch64 \ + -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc \ + -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/aarch64-linux-gnu \ + 
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7db85528659d3..dc2a4c0566840 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,7 +10,7 @@ on: push: branches: - master - paths: ['.github/workflows/build.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] + paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] pull_request: types: [opened, synchronize, reopened] paths: ['.github/workflows/build.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] @@ -606,6 +606,9 @@ jobs: -DGGML_SYCL_F16=ON cmake --build build --config Release -j $(nproc) + build-linux-cross: + uses: ./.github/workflows/build-linux-cross.yml + macOS-latest-cmake-ios: runs-on: macos-latest @@ -803,7 +806,7 @@ jobs: env: OPENBLAS_VERSION: 0.3.23 SDE_VERSION: 9.33.0-2024-01-07 - VULKAN_VERSION: 1.4.304.1 + VULKAN_VERSION: 1.4.309.0 strategy: matrix: diff --git a/README.md b/README.md index 1eec944f273a8..95a05e6ed75c1 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,8 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM) - [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1) - [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct) +- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview) +- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32) #### Multimodal @@ -528,6 +530,35 @@ If your issue is with model generation quality, then please at least scan the fo - [Aligning language models to follow instructions](https://openai.com/research/instruction-following) - [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) +## XCFramework +The XCFramework is a precompiled version of the library for iOS, visionOS, tvOS, +and macOS. It can be used in Swift projects without the need to compile the +library from source. For example: +```swift +// swift-tools-version: 5.10 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "MyLlamaPackage", + targets: [ + .executableTarget( + name: "MyLlamaPackage", + dependencies: [ + "LlamaFramework" + ]), + .binaryTarget( + name: "LlamaFramework", + url: "https://github.com/ggml-org/llama.cpp/releases/download/b5046/llama-b5046-xcframework.zip", + checksum: "c19be78b5f00d8d29a25da41042cb7afa094cbf6280a225abe614b03b20029ab" + ) + ] +) +``` +The above example is using an intermediate build `b5046` of the library. This can be modified +to use a different version by changing the URL and checksum. + ## Completions Command-line completion is available for some environments. 
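The Package.swift snippet above only declares the dependency on the prebuilt XCFramework. As a quick sanity check that the downloaded framework links, the executable target can call straight into the C API — a minimal sketch, not part of this patch, assuming the xcframework exposes its headers as a module named `llama` (adjust the import if the release archive names the module differently):

```swift
// main.swift — hypothetical smoke test for the binary target declared above
import llama  // assumed module name exported by the xcframework

// llama_print_system_info() is part of the llama.cpp C API; it returns a C string
// listing the backends and CPU features the library was built with.
print(String(cString: llama_print_system_info()))
```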
diff --git a/ci/README.md b/ci/README.md index db4d9066816e8..ec3f44350394a 100644 --- a/ci/README.md +++ b/ci/README.md @@ -60,7 +60,7 @@ docker run --privileged -it \ Inside the container, execute the following commands: ```bash -apt update -y && apt install -y cmake git python3.10-venv wget +apt update -y && apt install -y bc cmake ccache git python3.10-venv time unzip wget git config --global --add safe.directory /ws GG_BUILD_MUSA=1 bash ./ci/run.sh /ci-results /ci-cache ``` diff --git a/ci/run.sh b/ci/run.sh index efc24391d2e7e..ebd662b38bf88 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -59,6 +59,8 @@ if [ ! -z ${GG_BUILD_SYCL} ]; then export ONEAPI_DEVICE_SELECTOR="level_zero:0" # Enable sysman for correct memory reporting export ZES_ENABLE_SYSMAN=1 + # to circumvent precision issues on CPY operations + export SYCL_PROGRAM_COMPILE_OPTIONS="-cl-fp32-correctly-rounded-divide-sqrt" CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON" fi @@ -69,7 +71,7 @@ fi if [ ! -z ${GG_BUILD_MUSA} ]; then # Use qy1 by default (MTT S80) MUSA_ARCH=${MUSA_ARCH:-21} - CMAKE_EXTRA="-DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}" fi ## helpers diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 17146fffc1168..829eb5b7238b9 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -114,8 +114,8 @@ if (LLAMA_LLGUIDANCE) ExternalProject_Add(llguidance_ext GIT_REPOSITORY https://github.com/guidance-ai/llguidance - # v0.6.12: - GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09 + # v0.7.10: + GIT_TAG 0309d2a6bf40abda35344a362edc71e06d5009f8 PREFIX ${CMAKE_BINARY_DIR}/llguidance SOURCE_DIR ${LLGUIDANCE_SRC} BUILD_IN_SOURCE TRUE diff --git a/common/arg.cpp b/common/arg.cpp index b6bfe6f89bead..cdf8970254446 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1,12 +1,24 @@ +#include "gguf.h" // for reading GGUF splits #include "arg.h" +#include "common.h" #include "log.h" #include "sampling.h" #include "chat.h" +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + #include #include #include +#include #include #include #include @@ -14,6 +26,14 @@ #include #include +//#define LLAMA_USE_CURL + +#if defined(LLAMA_USE_CURL) +#include +#include +#include +#endif + #include "json-schema-to-grammar.h" using json = nlohmann::ordered_json; @@ -126,46 +146,552 @@ std::string common_arg::to_string() { } // -// utils +// downloader // -static void common_params_handle_model_default( - std::string & model, - const std::string & model_url, - std::string & hf_repo, - std::string & hf_file, - const std::string & hf_token, - const std::string & model_default) { - if (!hf_repo.empty()) { - // short-hand to avoid specifying --hf-file -> default it to --model - if (hf_file.empty()) { - if (model.empty()) { - auto auto_detected = common_get_hf_file(hf_repo, hf_token); - if (auto_detected.first.empty() || auto_detected.second.empty()) { - exit(1); // built without CURL, error message already printed +struct common_hf_file_res { + std::string repo; // repo name with ":tag" removed + std::string ggufFile; + std::string mmprojFile; +}; + +#ifdef LLAMA_USE_CURL + +#ifdef __linux__ +#include +#elif defined(_WIN32) +# if !defined(PATH_MAX) +# define PATH_MAX MAX_PATH +# endif +#else +#include +#endif +#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + +// +// 
CURL utils
+//
+
+using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
+
+// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
+struct curl_slist_ptr {
+    struct curl_slist * ptr = nullptr;
+    ~curl_slist_ptr() {
+        if (ptr) {
+            curl_slist_free_all(ptr);
+        }
+    }
+};
+
+#define CURL_MAX_RETRY 3
+#define CURL_RETRY_DELAY_SECONDS 2
+
+static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
+    int remaining_attempts = max_attempts;
+
+    while (remaining_attempts > 0) {
+        LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
+
+        CURLcode res = curl_easy_perform(curl);
+        if (res == CURLE_OK) {
+            return true;
+        }
+
+        int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
+        LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
+
+        remaining_attempts--;
+        std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
+    }
+
+    LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
+
+    return false;
+}
+
+// download one single file from remote URL to local path
+static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token) {
+    // Initialize libcurl
+    curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_slist_ptr http_headers;
+    if (!curl) {
+        LOG_ERR("%s: error initializing libcurl\n", __func__);
+        return false;
+    }
+
+    bool force_download = false;
+
+    // Set the URL, allow to follow http redirection
+    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
+
+    // Check if hf-token or bearer-token was specified
+    if (!bearer_token.empty()) {
+        std::string auth_header = "Authorization: Bearer " + bearer_token;
+        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
+    }
+
+#if defined(_WIN32)
+    // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
+    // operating system. Currently implemented under MS-Windows.
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+
+    // Check if the file already exists locally
+    auto file_exists = std::filesystem::exists(path);
+
+    // If the file exists, check its JSON metadata companion file.
+    std::string metadata_path = path + ".json";
+    nlohmann::json metadata;
+    std::string etag;
+    std::string last_modified;
+
+    if (file_exists) {
+        // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
+        std::ifstream metadata_in(metadata_path);
+        if (metadata_in.good()) {
+            try {
+                metadata_in >> metadata;
+                LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
+                if (metadata.contains("url") && metadata.at("url").is_string()) {
+                    auto previous_url = metadata.at("url").get<std::string>();
+                    if (previous_url != url) {
+                        LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
+                        return false;
+                    }
                }
-                hf_repo = auto_detected.first;
-                hf_file = auto_detected.second;
-            } else {
-                hf_file = model;
+                if (metadata.contains("etag") && metadata.at("etag").is_string()) {
+                    etag = metadata.at("etag");
+                }
+                if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
+                    last_modified = metadata.at("lastModified");
+                }
+            } catch (const nlohmann::json::exception & e) {
+                LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
+                return false;
+            }
+        }
+    } else {
+        LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
+    }
+
+    // Send a HEAD request to retrieve the etag and last-modified headers
+    struct common_load_model_from_url_headers {
+        std::string etag;
+        std::string last_modified;
+    };
+
+    common_load_model_from_url_headers headers;
+
+    {
+        typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
+        auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
+            common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
+
+            static std::regex header_regex("([^:]+): (.*)\r\n");
+            static std::regex etag_regex("ETag", std::regex_constants::icase);
+            static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
+
+            std::string header(buffer, n_items);
+            std::smatch match;
+            if (std::regex_match(header, match, header_regex)) {
+                const std::string & key = match[1];
+                const std::string & value = match[2];
+                if (std::regex_match(key, match, etag_regex)) {
+                    headers->etag = value;
+                } else if (std::regex_match(key, match, last_modified_regex)) {
+                    headers->last_modified = value;
+                }
+            }
+            return n_items;
+        };
+
+        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
+        curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
+
+        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
+        if (!was_perform_successful) {
+            return false;
+        }
+
+        long http_code = 0;
+        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
+        if (http_code != 200) {
+            // HEAD not supported, we don't know if the file has changed
+            // force trigger downloading
+            force_download = true;
+            LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
+        }
+    }
+
+    bool should_download = !file_exists || force_download;
+    if (!should_download) {
+        if (!etag.empty() && etag != headers.etag) {
+            LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
+            should_download = true;
+        } else if (!last_modified.empty() && last_modified != headers.last_modified) {
+            LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
+            should_download = true;
+        }
+    }
+    if (should_download) {
+        std::string path_temporary = path + ".downloadInProgress";
+        if (file_exists) {
+            LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
+            if (remove(path.c_str()) != 0) {
+                LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
+                return false;
            }
        }
-        // make sure model path is present (for caching purposes)
-        if (model.empty()) {
-            // this is to avoid different repo having same file name, or same file name in different subdirs
-            std::string filename = hf_repo + "_" + hf_file;
-            // to make sure we don't have any slashes in the filename
-            string_replace_all(filename, "/", "_");
-            model = fs_get_cache_file(filename);
+
+        // Set the output file
+
+        struct FILE_deleter {
+            void operator()(FILE * f) const {
+                fclose(f);
+            }
+        };
+
+        std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
+        if (!outfile) {
+            LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
+            return false;
        }
-    } else if (!model_url.empty()) {
-        if (model.empty()) {
-            auto f = string_split(model_url, '#').front();
-            f = string_split(f, '?').front();
-            model = fs_get_cache_file(string_split(f, '/').back());
+
+        typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
+        auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
+            return fwrite(data, size, nmemb, (FILE *)fd);
+        };
+        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
+
+        // display download progress
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
+
+        // helper function to hide password in URL
+        auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
+            std::size_t protocol_pos = url.find("://");
+            if (protocol_pos == std::string::npos) {
+                return url;  // Malformed URL
+            }
+
+            std::size_t at_pos = url.find('@', protocol_pos + 3);
+            if (at_pos == std::string::npos) {
+                return url;  // No password in URL
+            }
+
+            return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
+        };
+
+        // start the download
+        LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
+            llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
+        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
+        if (!was_perform_successful) {
+            return false;
+        }
+
+        long http_code = 0;
+        curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
+        if (http_code < 200 || http_code >= 400) {
+            LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
+            return false;
+        }
+
+        // Causes file to be closed explicitly here before we rename it.
+        outfile.reset();
+
+        // Write the updated JSON metadata file.
+        metadata.update({
+            {"url", url},
+            {"etag", headers.etag},
+            {"lastModified", headers.last_modified}
+        });
+        std::ofstream(metadata_path) << metadata.dump(4);
+        LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
+
+        if (rename(path_temporary.c_str(), path.c_str()) != 0) {
+            LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
+            return false;
+        }
+    }
+
+    return true;
+}
+
+// download multiple files from remote URLs to local paths
+// the input is a vector of pairs <url, path>
+static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token) {
+    // Prepare download in parallel
+    std::vector<std::future<bool>> futures_download;
+    for (auto const & item : urls) {
+        futures_download.push_back(std::async(std::launch::async, [bearer_token](const std::pair<std::string, std::string> & it) -> bool {
+            return common_download_file_single(it.first, it.second, bearer_token);
+        }, item));
+    }
+
+    // Wait for all downloads to complete
+    for (auto & f : futures_download) {
+        if (!f.get()) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static bool common_download_model(
+        const common_params_model & model,
+        const std::string & bearer_token) {
+    // Basic validation of the model.url
+    if (model.url.empty()) {
+        LOG_ERR("%s: invalid model url\n", __func__);
+        return false;
+    }
+
+    if (!common_download_file_single(model.url, model.path, bearer_token)) {
+        return false;
+    }
+
+    // check for additional GGUFs split to download
+    int n_split = 0;
+    {
+        struct gguf_init_params gguf_params = {
+            /*.no_alloc = */ true,
+            /*.ctx      = */ NULL,
+        };
+        auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
+        if (!ctx_gguf) {
+            LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
+            return false;
+        }
+
+        auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
+        if (key_n_split >= 0) {
+            n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
+        }
+
+        gguf_free(ctx_gguf);
+    }
+
+    if (n_split > 1) {
+        char split_prefix[PATH_MAX] = {0};
+        char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
+
+        // Verify the first split file format
+        // and extract split URL and PATH prefixes
+        {
+            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
+                LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
+                return false;
+            }
+
+            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
+                LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
+                return false;
+            }
+        }
+
+        std::vector<std::pair<std::string, std::string>> urls;
+        for (int idx = 1; idx < n_split; idx++) {
+            char split_path[PATH_MAX] = {0};
+            llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
+
+            char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
+            llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
+
+            if (std::string(split_path) == model.path) {
+                continue; // skip the already downloaded file
+            }
+
+            urls.push_back({split_url, split_path});
+        }
+
+        // Download in parallel
+        common_download_file_multiple(urls, bearer_token);
+    }
+
+    return true;
+}
+
+/**
+ * Allow getting the HF file from the HF repo with tag (like ollama), for example:
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
+ * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return
the first GGUF file in repo)
+ *
+ * Return pair of <repo, file_name> (with "repo" already having tag removed)
+ *
+ * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
+ */
+static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
+    auto parts = string_split<std::string>(hf_repo_with_tag, ':');
+    std::string tag = parts.size() > 1 ? parts.back() : "latest";
+    std::string hf_repo = parts[0];
+    if (string_split<std::string>(hf_repo, '/').size() != 2) {
+        throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
+    }
+
+    // fetch model info from Hugging Face Hub API
+    curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_slist_ptr http_headers;
+    std::string res_str;
+    std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
+    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+    typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
+    auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
+        static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
+        return size * nmemb;
+    };
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
+#if defined(_WIN32)
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+    if (!bearer_token.empty()) {
+        std::string auth_header = "Authorization: Bearer " + bearer_token;
+        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
+    }
+    // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
+    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
+
+    CURLcode res = curl_easy_perform(curl.get());
+
+    if (res != CURLE_OK) {
+        throw std::runtime_error("error: cannot make GET request to HF API");
+    }
+
+    long res_code;
+    std::string ggufFile = "";
+    std::string mmprojFile = "";
+    curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
+    if (res_code == 200) {
+        // extract ggufFile.rfilename in json, using regex
+        {
+            std::regex pattern("\"ggufFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\"");
+            std::smatch match;
+            if (std::regex_search(res_str, match, pattern)) {
+                ggufFile = match[1].str();
+            }
+        }
+        // extract mmprojFile.rfilename in json, using regex
+        {
+            std::regex pattern("\"mmprojFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\"");
+            std::smatch match;
+            if (std::regex_search(res_str, match, pattern)) {
+                mmprojFile = match[1].str();
+            }
+        }
+    } else if (res_code == 401) {
+        throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
+    } else {
+        throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
+    }
+
+    // check response
+    if (ggufFile.empty()) {
+        throw std::runtime_error("error: model does not have ggufFile");
+    }
+
+    return { hf_repo, ggufFile, mmprojFile };
+}
+
+#else
+
+static bool common_download_file_single(const std::string &, const std::string &, const std::string
&) { + LOG_ERR("error: built without CURL, cannot download model from internet\n"); + return false; +} + +static bool common_download_file_multiple(const std::vector> &, const std::string &) { + LOG_ERR("error: built without CURL, cannot download model from the internet\n"); + return false; +} + +static bool common_download_model( + const common_params_model &, + const std::string &) { + LOG_ERR("error: built without CURL, cannot download model from the internet\n"); + return false; +} + +static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &) { + LOG_ERR("error: built without CURL, cannot download model from the internet\n"); + return {}; +} + +#endif // LLAMA_USE_CURL + +// +// utils +// + +static void common_params_handle_model( + struct common_params_model & model, + const std::string & bearer_token, + const std::string & model_path_default, + bool is_mmproj = false) { // TODO: move is_mmproj to an enum when we have more files? + // handle pre-fill default model path and url based on hf_repo and hf_file + { + if (!model.hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (model.hf_file.empty()) { + if (model.path.empty()) { + auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token); + if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { + exit(1); // built without CURL, error message already printed + } + model.hf_repo = auto_detected.repo; + model.hf_file = is_mmproj ? auto_detected.mmprojFile : auto_detected.ggufFile; + } else { + model.hf_file = model.path; + } + } + + std::string hf_endpoint = "https://huggingface.co/"; + const char * hf_endpoint_env = getenv("HF_ENDPOINT"); + if (hf_endpoint_env) { + hf_endpoint = hf_endpoint_env; + if (hf_endpoint.back() != '/') hf_endpoint += '/'; + } + model.url = hf_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; + // make sure model path is present (for caching purposes) + if (model.path.empty()) { + // this is to avoid different repo having same file name, or same file name in different subdirs + std::string filename = model.hf_repo + "_" + model.hf_file; + // to make sure we don't have any slashes in the filename + string_replace_all(filename, "/", "_"); + model.path = fs_get_cache_file(filename); + } + + } else if (!model.url.empty()) { + if (model.path.empty()) { + auto f = string_split(model.url, '#').front(); + f = string_split(f, '?').front(); + model.path = fs_get_cache_file(string_split(f, '/').back()); + } + + } else if (model.path.empty()) { + model.path = model_path_default; + } + } + + // then, download it if needed + if (!model.url.empty()) { + bool ok = common_download_model(model, bearer_token); + if (!ok) { + LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); + exit(1); } - } else if (model.empty()) { - model = model_default; } } @@ -300,10 +826,16 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } - // TODO: refactor model params in a common struct - common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH); - common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, ""); - common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, 
params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, ""); + common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH); + common_params_handle_model(params.speculative.model, params.hf_token, ""); + common_params_handle_model(params.vocoder.model, params.hf_token, ""); + + // allow --mmproj to be set from -hf + // assuming that mmproj is always in the same repo as text model + if (!params.model.hf_repo.empty() && ctx_arg.ex == LLAMA_EXAMPLE_LLAVA) { + params.mmproj.hf_repo = params.model.hf_repo; + } + common_params_handle_model(params.mmproj, params.hf_token, "", true); if (params.escape) { string_process_escapes(params.prompt); @@ -322,6 +854,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context params.kv_overrides.back().key[0] = 0; } + if (!params.tensor_buft_overrides.empty()) { + params.tensor_buft_overrides.push_back({nullptr, nullptr}); + } + if (params.reranking && params.embedding) { throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); } @@ -1561,7 +2097,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--mmproj"}, "FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md", [](common_params & params, const std::string & value) { - params.mmproj = value; + params.mmproj.path = value; + } + ).set_examples({LLAMA_EXAMPLE_LLAVA})); + add_opt(common_arg( + {"--mmproj-url"}, "URL", + "URL to a multimodal projector file for LLaVA. see examples/llava/README.md", + [](common_params & params, const std::string & value) { + params.mmproj.url = value; } ).set_examples({LLAMA_EXAMPLE_LLAVA})); add_opt(common_arg( @@ -1647,6 +2190,41 @@ common_params_context common_params_parser_init(common_params & params, llama_ex exit(0); } )); + add_opt(common_arg( + {"--override-tensor", "-ot"}, "=,...", + "override tensor buffer type", [](common_params & params, const std::string & value) { + /* static */ std::map buft_list; + if (buft_list.empty()) { + // enumerate all the devices and add their buffer types to the list + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list[ggml_backend_buft_name(buft)] = buft; + } + } + } + + for (const auto & override : string_split(value, ',')) { + std::string::size_type pos = override.find('='); + if (pos == std::string::npos) { + throw std::invalid_argument("invalid value"); + } + std::string tensor_name = override.substr(0, pos); + std::string buffer_type = override.substr(pos + 1); + + if (buft_list.find(buffer_type) == buft_list.end()) { + printf("Available buffer types:\n"); + for (const auto & it : buft_list) { + printf(" %s\n", ggml_backend_buft_name(it.second)); + } + throw std::invalid_argument("unknown buffer type"); + } + // FIXME: this leaks memory + params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); + } + } + )); add_opt(common_arg( {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", "number of layers to store in VRAM", @@ -1790,14 +2368,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH ), [](common_params & params, const std::string & value) { - params.model = value; + params.model.path = value; } ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL")); add_opt(common_arg( {"-mu", 
"--model-url"}, "MODEL_URL", "model download url (default: unused)", [](common_params & params, const std::string & value) { - params.model_url = value; + params.model.url = value; } ).set_env("LLAMA_ARG_MODEL_URL")); add_opt(common_arg( @@ -1806,35 +2384,35 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "example: unsloth/phi-4-GGUF:q4_k_m\n" "(default: unused)", [](common_params & params, const std::string & value) { - params.hf_repo = value; + params.model.hf_repo = value; } ).set_env("LLAMA_ARG_HF_REPO")); add_opt(common_arg( {"-hfd", "-hfrd", "--hf-repo-draft"}, "/[:quant]", "Same as --hf-repo, but for the draft model (default: unused)", [](common_params & params, const std::string & value) { - params.speculative.hf_repo = value; + params.speculative.model.hf_repo = value; } ).set_env("LLAMA_ARG_HFD_REPO")); add_opt(common_arg( {"-hff", "--hf-file"}, "FILE", "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)", [](common_params & params, const std::string & value) { - params.hf_file = value; + params.model.hf_file = value; } ).set_env("LLAMA_ARG_HF_FILE")); add_opt(common_arg( {"-hfv", "-hfrv", "--hf-repo-v"}, "/[:quant]", "Hugging Face model repository for the vocoder model (default: unused)", [](common_params & params, const std::string & value) { - params.vocoder.hf_repo = value; + params.vocoder.model.hf_repo = value; } ).set_env("LLAMA_ARG_HF_REPO_V")); add_opt(common_arg( {"-hffv", "--hf-file-v"}, "FILE", "Hugging Face model file for the vocoder model (default: unused)", [](common_params & params, const std::string & value) { - params.vocoder.hf_file = value; + params.vocoder.model.hf_file = value; } ).set_env("LLAMA_ARG_HF_FILE_V")); add_opt(common_arg( @@ -1979,7 +2557,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); add_opt(common_arg( {"--host"}, "HOST", - string_format("ip address to listen (default: %s)", params.hostname.c_str()), + string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()), [](common_params & params, const std::string & value) { params.hostname = value; } @@ -2454,7 +3032,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-md", "--model-draft"}, "FNAME", "draft model for speculative decoding (default: unused)", [](common_params & params, const std::string & value) { - params.speculative.model = value; + params.speculative.model.path = value; } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT")); @@ -2462,7 +3040,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-mv", "--model-vocoder"}, "FNAME", "vocoder model for audio generation (default: unused)", [](common_params & params, const std::string & value) { - params.vocoder.model = value; + params.vocoder.model.path = value; } ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( @@ -2485,10 +3063,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--tts-oute-default"}, string_format("use default OuteTTS models (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF"; - params.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf"; - params.vocoder.hf_repo = "ggml-org/WavTokenizer"; - params.vocoder.hf_file = 
"WavTokenizer-Large-75-F16.gguf"; + params.model.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF"; + params.model.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf"; + params.vocoder.model.hf_repo = "ggml-org/WavTokenizer"; + params.vocoder.model.hf_file = "WavTokenizer-Large-75-F16.gguf"; } ).set_examples({LLAMA_EXAMPLE_TTS})); @@ -2496,8 +3074,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--embd-bge-small-en-default"}, string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF"; - params.hf_file = "bge-small-en-v1.5-q8_0.gguf"; + params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF"; + params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf"; params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; @@ -2510,8 +3088,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--embd-e5-small-en-default"}, string_format("use default e5-small-v2 model (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF"; - params.hf_file = "e5-small-v2-q8_0.gguf"; + params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF"; + params.model.hf_file = "e5-small-v2-q8_0.gguf"; params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; @@ -2524,8 +3102,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--embd-gte-small-default"}, string_format("use default gte-small model (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/gte-small-Q8_0-GGUF"; - params.hf_file = "gte-small-q8_0.gguf"; + params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF"; + params.model.hf_file = "gte-small-q8_0.gguf"; params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; @@ -2538,8 +3116,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-1.5b-default"}, string_format("use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf"; params.port = 8012; params.n_gpu_layers = 99; params.flash_attn = true; @@ -2554,8 +3132,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-3b-default"}, string_format("use default Qwen 2.5 Coder 3B (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-3b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf"; params.port = 8012; params.n_gpu_layers = 99; params.flash_attn = true; @@ -2570,8 +3148,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-7b-default"}, string_format("use default Qwen 2.5 Coder 7B (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; + 
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; params.port = 8012; params.n_gpu_layers = 99; params.flash_attn = true; @@ -2586,10 +3164,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-7b-spec"}, string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; - params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; - params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; + params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.speculative.n_gpu_layers = 99; params.port = 8012; params.n_gpu_layers = 99; @@ -2605,10 +3183,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-14b-spec"}, string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; - params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; - params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; + params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.speculative.n_gpu_layers = 99; params.port = 8012; params.n_gpu_layers = 99; diff --git a/common/common.cpp b/common/common.cpp index ab1c55d53e133..6352ef8f9d037 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -7,9 +7,6 @@ #include "common.h" #include "log.h" -// Change JSON_ASSERT from assert() to GGML_ASSERT: -#define JSON_ASSERT GGML_ASSERT -#include "json.hpp" #include "llama.h" #include @@ -51,47 +48,11 @@ #include #include #endif -#if defined(LLAMA_USE_CURL) -#include -#include -#include -#endif #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif -#if defined(LLAMA_USE_CURL) -#ifdef __linux__ -#include -#elif defined(_WIN32) -# if !defined(PATH_MAX) -# define PATH_MAX MAX_PATH -# endif -#else -#include -#endif -#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 - -// -// CURL utils -// - -using curl_ptr = std::unique_ptr; - -// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one -struct curl_slist_ptr { - struct curl_slist * ptr = nullptr; - ~curl_slist_ptr() { - if (ptr) { - curl_slist_free_all(ptr); - } - } -}; -#endif // LLAMA_USE_CURL - -using json = nlohmann::ordered_json; - // // CPU utils // @@ -900,22 +861,14 @@ std::string fs_get_cache_file(const std::string & filename) { // // Model utils // + struct common_init_result common_init_from_params(common_params & params) { common_init_result iparams; auto mparams = common_model_params_to_llama(params); - llama_model * model = nullptr; - - if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams); - } else if (!params.model_url.empty()) { - 
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); - } else { - model = llama_model_load_from_file(params.model.c_str(), mparams); - } - + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); if (model == NULL) { - LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str()); + LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); return iparams; } @@ -950,7 +903,7 @@ struct common_init_result common_init_from_params(common_params & params) { llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { - LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); + LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); llama_model_free(model); return iparams; } @@ -1092,15 +1045,18 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.devices = params.devices.data(); } + if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } + mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; + if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { @@ -1108,6 +1064,13 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.kv_overrides = params.kv_overrides.data(); } + if (params.tensor_buft_overrides.empty()) { + mparams.tensor_buft_overrides = NULL; + } else { + GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern"); + mparams.tensor_buft_overrides = params.tensor_buft_overrides.data(); + } + return mparams; } @@ -1167,451 +1130,6 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p return tpp; } -#ifdef LLAMA_USE_CURL - -#define CURL_MAX_RETRY 3 -#define CURL_RETRY_DELAY_SECONDS 2 - -static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) { - int remaining_attempts = max_attempts; - - while (remaining_attempts > 0) { - LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); - - CURLcode res = curl_easy_perform(curl); - if (res == CURLE_OK) { - return true; - } - - int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000; - LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); - - remaining_attempts--; - std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); - } - - LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); - - return false; -} - -static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { - // Initialize libcurl - curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); - curl_slist_ptr http_headers; - if (!curl) { - LOG_ERR("%s: error initializing libcurl\n", __func__); - return false; - } - - bool force_download = false; - - // Set the URL, allow to follow http redirection - curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 
1L); - - // Check if hf-token or bearer-token was specified - if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer " + hf_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); - } - -#if defined(_WIN32) - // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of - // operating system. Currently implemented under MS-Windows. - curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - - // Check if the file already exists locally - auto file_exists = std::filesystem::exists(path); - - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; - nlohmann::json metadata; - std::string etag; - std::string last_modified; - - if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). - std::ifstream metadata_in(metadata_path); - if (metadata_in.good()) { - try { - metadata_in >> metadata; - LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); - if (metadata.contains("url") && metadata.at("url").is_string()) { - auto previous_url = metadata.at("url").get(); - if (previous_url != url) { - LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); - return false; - } - } - if (metadata.contains("etag") && metadata.at("etag").is_string()) { - etag = metadata.at("etag"); - } - if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { - last_modified = metadata.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - return false; - } - } - } else { - LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); - } - - // Send a HEAD request to retrieve the etag and last-modified headers - struct common_load_model_from_url_headers { - std::string etag; - std::string last_modified; - }; - - common_load_model_from_url_headers headers; - - { - typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); - auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { - common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; - - static std::regex header_regex("([^:]+): (.*)\r\n"); - static std::regex etag_regex("ETag", std::regex_constants::icase); - static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - - std::string header(buffer, n_items); - std::smatch match; - if (std::regex_match(header, match, header_regex)) { - const std::string & key = match[1]; - const std::string & value = match[2]; - if (std::regex_match(key, match, etag_regex)) { - headers->etag = value; - } else if (std::regex_match(key, match, last_modified_regex)) { - headers->last_modified = value; - } - } - return n_items; - }; - - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); - curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); - if (!was_perform_successful) { - return false; - 
} - - long http_code = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code != 200) { - // HEAD not supported, we don't know if the file has changed - // force trigger downloading - force_download = true; - LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); - } - } - - bool should_download = !file_exists || force_download; - if (!should_download) { - if (!etag.empty() && etag != headers.etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); - should_download = true; - } else if (!last_modified.empty() && last_modified != headers.last_modified) { - LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); - should_download = true; - } - } - if (should_download) { - std::string path_temporary = path + ".downloadInProgress"; - if (file_exists) { - LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); - if (remove(path.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return false; - } - } - - // Set the output file - - struct FILE_deleter { - void operator()(FILE * f) const { - fclose(f); - } - }; - - std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); - if (!outfile) { - LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); - return false; - } - - typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); - auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { - return fwrite(data, size, nmemb, (FILE *)fd); - }; - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L); - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); - - // display download progress - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); - - // helper function to hide password in URL - auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { - std::size_t protocol_pos = url.find("://"); - if (protocol_pos == std::string::npos) { - return url; // Malformed URL - } - - std::size_t at_pos = url.find('@', protocol_pos + 3); - if (at_pos == std::string::npos) { - return url; // No password in URL - } - - return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos); - }; - - // start the download - LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, - llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); - if (!was_perform_successful) { - return false; - } - - long http_code = 0; - curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code < 200 || http_code >= 400) { - LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); - return false; - } - - // Causes file to be closed explicitly here before we rename it. - outfile.reset(); - - // Write the updated JSON metadata file. 
- metadata.update({ - {"url", url}, - {"etag", headers.etag}, - {"lastModified", headers.last_modified} - }); - std::ofstream(metadata_path) << metadata.dump(4); - LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); - - if (rename(path_temporary.c_str(), path.c_str()) != 0) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return false; - } - } - - return true; -} - -struct llama_model * common_load_model_from_url( - const std::string & model_url, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params) { - // Basic validation of the model_url - if (model_url.empty()) { - LOG_ERR("%s: invalid model_url\n", __func__); - return NULL; - } - - if (!common_download_file(model_url, local_path, hf_token)) { - return NULL; - } - - // check for additional GGUFs split to download - int n_split = 0; - { - struct gguf_init_params gguf_params = { - /*.no_alloc = */ true, - /*.ctx = */ NULL, - }; - auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params); - if (!ctx_gguf) { - LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str()); - return NULL; - } - - auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT); - if (key_n_split >= 0) { - n_split = gguf_get_val_u16(ctx_gguf, key_n_split); - } - - gguf_free(ctx_gguf); - } - - if (n_split > 1) { - char split_prefix[PATH_MAX] = {0}; - char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0}; - - // Verify the first split file format - // and extract split URL and PATH prefixes - { - if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split); - return NULL; - } - - if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split); - return NULL; - } - } - - // Prepare download in parallel - std::vector> futures_download; - for (int idx = 1; idx < n_split; idx++) { - futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool { - char split_path[PATH_MAX] = {0}; - llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split); - - char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; - llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - - return common_download_file(split_url, split_path, hf_token); - }, idx)); - } - - // Wait for all downloads to complete - for (auto & f : futures_download) { - if (!f.get()) { - return NULL; - } - } - } - - return llama_model_load_from_file(local_path.c_str(), params); -} - -struct llama_model * common_load_model_from_hf( - const std::string & repo, - const std::string & remote_path, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params) { - // construct hugging face model url: - // - // --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf - // https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf - // - // --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf - // https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf - // - - std::string model_url = "https://huggingface.co/"; - model_url += 
repo; - model_url += "/resolve/main/"; - model_url += remote_path; - - return common_load_model_from_url(model_url, local_path, hf_token, params); -} - -/** - * Allow getting the HF file from the HF repo with tag (like ollama), for example: - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 - * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s - * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) - * - * Return pair of (with "repo" already having tag removed) - * - * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. - */ -std::pair common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) { - auto parts = string_split(hf_repo_with_tag, ':'); - std::string tag = parts.size() > 1 ? parts.back() : "latest"; - std::string hf_repo = parts[0]; - if (string_split(hf_repo, '/').size() != 2) { - throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); - } - - // fetch model info from Hugging Face Hub API - json model_info; - curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); - curl_slist_ptr http_headers; - std::string res_str; - std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; - curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); - typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); - auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { - static_cast(data)->append((char * ) ptr, size * nmemb); - return size * nmemb; - }; - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); -#if defined(_WIN32) - curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer " + hf_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); - } - // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response - http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); - http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); - - CURLcode res = curl_easy_perform(curl.get()); - - if (res != CURLE_OK) { - throw std::runtime_error("error: cannot make GET request to HF API"); - } - - long res_code; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); - if (res_code == 200) { - model_info = json::parse(res_str); - } else if (res_code == 401) { - throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); - } else { - throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); - } - - // check response - if (!model_info.contains("ggufFile")) { - throw std::runtime_error("error: model does not have ggufFile"); - } - json & gguf_file = model_info.at("ggufFile"); - if (!gguf_file.contains("rfilename")) { - throw std::runtime_error("error: ggufFile does not have rfilename"); - } - - return 
std::make_pair(hf_repo, gguf_file.at("rfilename")); -} - -#else - -struct llama_model * common_load_model_from_url( - const std::string & /*model_url*/, - const std::string & /*local_path*/, - const std::string & /*hf_token*/, - const struct llama_model_params & /*params*/) { - LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); - return nullptr; -} - -struct llama_model * common_load_model_from_hf( - const std::string & /*repo*/, - const std::string & /*remote_path*/, - const std::string & /*local_path*/, - const std::string & /*hf_token*/, - const struct llama_model_params & /*params*/) { - LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); - return nullptr; -} - -std::pair common_get_hf_file(const std::string &, const std::string &) { - LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); - return std::make_pair("", ""); -} - -#endif // LLAMA_USE_CURL - // // Batch utils // @@ -2035,26 +1553,3 @@ common_control_vector_data common_control_vector_load(const std::vector -json common_grammar_trigger::to_json() const { - json out { - {"type", (int) type}, - {"value", value}, - }; - if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - out["token"] = (int) token; - } - return out; -} - -template <> -common_grammar_trigger common_grammar_trigger::from_json(const json & in) { - common_grammar_trigger out; - out.type = (common_grammar_trigger_type) in.at("type").get(); - out.value = in.at("value").get(); - if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - out.token = (llama_token) in.at("token").get(); - } - return out; -} diff --git a/common/common.h b/common/common.h index 1c0f199774976..725b5123d24f9 100644 --- a/common/common.h +++ b/common/common.h @@ -121,10 +121,6 @@ struct common_grammar_trigger { common_grammar_trigger_type type; std::string value; llama_token token = LLAMA_TOKEN_NULL; - - // T can only be nlohmann::ordered_json - template T to_json() const; - template static common_grammar_trigger from_json(const T & in); }; // sampling parameters @@ -184,6 +180,13 @@ struct common_params_sampling { std::string print() const; }; +struct common_params_model { + std::string path = ""; // model local path // NOLINT + std::string url = ""; // model url to download // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT +}; + struct common_params_speculative { std::vector devices; // devices to use for offloading @@ -197,19 +200,11 @@ struct common_params_speculative { struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT - - std::string model = ""; // draft model for speculative decoding // NOLINT - std::string model_url = ""; // model url to download // NOLINT + struct common_params_model model; }; struct common_params_vocoder { - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT - - std::string model = ""; // model path // NOLINT - std::string model_url = ""; // model url to download // NOLINT + struct common_params_model model; std::string speaker_file = ""; // speaker file path // NOLINT @@ -267,12 +262,10 @@ struct common_params { struct common_params_speculative speculative; struct common_params_vocoder vocoder; - std::string model = ""; // model path // NOLINT + struct common_params_model model; + std::string model_alias = 
""; // model alias // NOLINT - std::string model_url = ""; // model url to download // NOLINT std::string hf_token = ""; // HF token // NOLINT - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT std::string prompt = ""; // NOLINT std::string system_prompt = ""; // NOLINT std::string prompt_file = ""; // store the external prompt file name // NOLINT @@ -286,6 +279,7 @@ struct common_params { std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; + std::vector tensor_buft_overrides; bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) std::vector lora_adapters; // lora adapter path with user defined scale @@ -347,7 +341,7 @@ struct common_params { common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector // NOLINT + struct common_params_model mmproj; std::vector image; // path to image file(s) // embedding @@ -546,23 +540,6 @@ struct llama_model_params common_model_params_to_llama ( common_params struct llama_context_params common_context_params_to_llama(const common_params & params); struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params); -struct llama_model * common_load_model_from_url( - const std::string & model_url, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params); - -struct llama_model * common_load_model_from_hf( - const std::string & repo, - const std::string & remote_path, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params); - -std::pair common_get_hf_file( - const std::string & hf_repo_with_tag, - const std::string & hf_token); - // clear LoRA adapters from context, then apply new list of adapters void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); diff --git a/common/llguidance.cpp b/common/llguidance.cpp index 2feeb93c87e30..8bff89ea4aa9e 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -11,25 +11,24 @@ struct llama_sampler_llg { std::string grammar_kind; std::string grammar_data; LlgTokenizer * tokenizer; - LlgConstraint * grammar; - LlgMaskResult llg_res; - bool has_llg_res; + LlgMatcher * grammar; }; -static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind, - const char * grammar_data) { +static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind, + const char * grammar_data) { LlgConstraintInit cinit; llg_constraint_init_set_defaults(&cinit, tokenizer); const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL"); if (log_level && *log_level) { cinit.log_stderr_level = atoi(log_level); } - auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data); - if (llg_get_error(c)) { - LOG_ERR("llg error: %s\n", llg_get_error(c)); - llg_free_constraint(c); + auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data); + if (llg_matcher_get_error(c)) { + LOG_ERR("llg error: %s\n", llg_matcher_get_error(c)); + llg_free_matcher(c); return nullptr; } + return c; } @@ -40,39 +39,29 @@ static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) { static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token 
token) { auto * ctx = (llama_sampler_llg *) smpl->ctx; if (ctx->grammar) { - LlgCommitResult res; - llg_commit_token(ctx->grammar, token, &res); - ctx->has_llg_res = false; + llg_matcher_consume_token(ctx->grammar, token); } } static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_llg *) smpl->ctx; if (ctx->grammar) { - if (!ctx->has_llg_res) { - if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) { - ctx->has_llg_res = true; + const uint32_t * mask = llg_matcher_get_mask(ctx->grammar); + if (mask == nullptr) { + if (llg_matcher_compute_mask(ctx->grammar) == 0) { + mask = llg_matcher_get_mask(ctx->grammar); } else { - LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar)); - llg_free_constraint(ctx->grammar); + LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar)); + llg_free_matcher(ctx->grammar); ctx->grammar = nullptr; + return; } } - if (ctx->has_llg_res) { - if (ctx->llg_res.is_stop) { - for (size_t i = 0; i < cur_p->size; ++i) { - if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) { - cur_p->data[i].logit = -INFINITY; - } - } - } else { - const uint32_t * mask = ctx->llg_res.sample_mask; - for (size_t i = 0; i < cur_p->size; ++i) { - auto token = cur_p->data[i].id; - if ((mask[token / 32] & (1 << (token % 32))) == 0) { - cur_p->data[i].logit = -INFINITY; - } - } + + for (size_t i = 0; i < cur_p->size; ++i) { + auto token = cur_p->data[i].id; + if ((mask[token / 32] & (1 << (token % 32))) == 0) { + cur_p->data[i].logit = -INFINITY; } } } @@ -80,14 +69,9 @@ static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array static void llama_sampler_llg_reset(llama_sampler * smpl) { auto * ctx = (llama_sampler_llg *) smpl->ctx; - if (!ctx->grammar) { - return; + if (ctx->grammar) { + llg_matcher_reset(ctx->grammar); } - - auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str()); - llg_free_constraint(ctx->grammar); - ctx->grammar = grammar_new; - ctx->has_llg_res = false; } static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { @@ -102,7 +86,7 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { if (ctx->grammar) { result_ctx->grammar_kind = ctx->grammar_kind; result_ctx->grammar_data = ctx->grammar_data; - result_ctx->grammar = llg_clone_constraint(ctx->grammar); + result_ctx->grammar = llg_clone_matcher(ctx->grammar); result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer); } } @@ -114,7 +98,7 @@ static void llama_sampler_llg_free(llama_sampler * smpl) { const auto * ctx = (llama_sampler_llg *) smpl->ctx; if (ctx->grammar) { - llg_free_constraint(ctx->grammar); + llg_free_matcher(ctx->grammar); llg_free_tokenizer(ctx->tokenizer); } @@ -239,9 +223,11 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g /* .grammar_data = */ grammar_data, /* .tokenizer = */ tokenizer, /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data), - /* .llg_res = */ {}, - /* .has_llg_res = */ false, }; + if (ctx->grammar) { + GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 == + llg_matcher_get_mask_byte_size(ctx->grammar)); + } } else { *ctx = { /* .vocab = */ vocab, @@ -249,15 +235,12 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g /* .grammar_data = */ {}, /* .tokenizer = */ nullptr, /* .grammar = */ nullptr, - /* .llg_res = */ {}, - /* .has_llg_res = */ false, }; } return llama_sampler_init( /* 
.iface = */ &llama_sampler_llg_i, - /* .ctx = */ ctx - ); + /* .ctx = */ ctx); } #else diff --git a/common/minja/chat-template.hpp b/common/minja/chat-template.hpp index 6e5b00f5e95ce..ef2b29cf09ef6 100644 --- a/common/minja/chat-template.hpp +++ b/common/minja/chat-template.hpp @@ -9,9 +9,16 @@ #pragma once #include "minja.hpp" -#include "json.hpp" +#include +#include +#include +#include +#include +#include +#include #include #include +#include using json = nlohmann::ordered_json; @@ -425,7 +432,7 @@ class chat_template { auto obj = json { {"tool_calls", tool_calls}, }; - if (!content.is_null() && content != "") { + if (!content.is_null() && !content.empty()) { obj["content"] = content; } message["content"] = obj.dump(2); @@ -435,13 +442,12 @@ class chat_template { if (polyfill_tool_responses && role == "tool") { message["role"] = "user"; auto obj = json { - {"tool_response", { - {"content", message.at("content")}, - }}, + {"tool_response", json::object()}, }; if (message.contains("name")) { - obj["tool_response"]["name"] = message.at("name"); + obj["tool_response"]["tool"] = message.at("name"); } + obj["tool_response"]["content"] = message.at("content"); if (message.contains("tool_call_id")) { obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); } @@ -510,7 +516,7 @@ class chat_template { static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { json messages_with_system = messages; - if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { + if (!messages_with_system.empty() && messages_with_system[0].at("role") == "system") { std::string existing_system = messages_with_system.at(0).at("content"); messages_with_system[0] = json { {"role", "system"}, diff --git a/common/minja/minja.hpp b/common/minja/minja.hpp index f86a5811a1649..0728c47196099 100644 --- a/common/minja/minja.hpp +++ b/common/minja/minja.hpp @@ -8,15 +8,26 @@ // SPDX-License-Identifier: MIT #pragma once +#include +#include +#include +#include +#include +#include #include -#include -#include -#include +#include +#include +#include #include -#include +#include #include +#include +#include +#include #include -#include "json.hpp" +#include +#include +#include using json = nlohmann::ordered_json; @@ -731,51 +742,51 @@ class TemplateToken { struct TextTemplateToken : public TemplateToken { std::string text; - TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} + TextTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, loc, pre, post), text(t) {} }; struct ExpressionTemplateToken : public TemplateToken { std::shared_ptr expr; - ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} + ExpressionTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, loc, pre, post), expr(std::move(e)) {} }; struct IfTemplateToken : public TemplateToken { std::shared_ptr condition; - IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} + IfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr 
&& c) : TemplateToken(Type::If, loc, pre, post), condition(std::move(c)) {} }; struct ElifTemplateToken : public TemplateToken { std::shared_ptr condition; - ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} + ElifTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, loc, pre, post), condition(std::move(c)) {} }; struct ElseTemplateToken : public TemplateToken { - ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} + ElseTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, loc, pre, post) {} }; struct EndIfTemplateToken : public TemplateToken { - EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} + EndIfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, loc, pre, post) {} }; struct MacroTemplateToken : public TemplateToken { std::shared_ptr name; Expression::Parameters params; - MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) - : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} + MacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) + : TemplateToken(Type::Macro, loc, pre, post), name(std::move(n)), params(std::move(p)) {} }; struct EndMacroTemplateToken : public TemplateToken { - EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} + EndMacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, loc, pre, post) {} }; struct FilterTemplateToken : public TemplateToken { std::shared_ptr filter; - FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) - : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {} + FilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) + : TemplateToken(Type::Filter, loc, pre, post), filter(std::move(filter)) {} }; struct EndFilterTemplateToken : public TemplateToken { - EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {} + EndFilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, loc, pre, post) {} }; struct ForTemplateToken : public TemplateToken { @@ -783,38 +794,38 @@ struct ForTemplateToken : public TemplateToken { std::shared_ptr iterable; std::shared_ptr condition; bool recursive; - ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, + ForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, std::shared_ptr && c, bool r) - : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} + : TemplateToken(Type::For, loc, pre, post), 
var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} }; struct EndForTemplateToken : public TemplateToken { - EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} + EndForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, loc, pre, post) {} }; struct GenerationTemplateToken : public TemplateToken { - GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {} + GenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, loc, pre, post) {} }; struct EndGenerationTemplateToken : public TemplateToken { - EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {} + EndGenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, loc, pre, post) {} }; struct SetTemplateToken : public TemplateToken { std::string ns; std::vector var_names; std::shared_ptr value; - SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) - : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} + SetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateToken(Type::Set, loc, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} }; struct EndSetTemplateToken : public TemplateToken { - EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} + EndSetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, loc, pre, post) {} }; struct CommentTemplateToken : public TemplateToken { std::string text; - CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} + CommentTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, loc, pre, post), text(t) {} }; enum class LoopControlType { Break, Continue }; @@ -830,7 +841,7 @@ class LoopControlException : public std::runtime_error { struct LoopControlTemplateToken : public TemplateToken { LoopControlType control_type; - LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {} + LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {} }; class TemplateNode { @@ -868,8 +879,8 @@ class TemplateNode { class SequenceNode : public TemplateNode { std::vector> children; public: - SequenceNode(const Location & location, std::vector> && c) - : TemplateNode(location), children(std::move(c)) {} + SequenceNode(const Location & loc, std::vector> && c) + : TemplateNode(loc), children(std::move(c)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { for (const 
auto& child : children) child->render(out, context); } @@ -878,7 +889,7 @@ class SequenceNode : public TemplateNode { class TextNode : public TemplateNode { std::string text; public: - TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} + TextNode(const Location & loc, const std::string& t) : TemplateNode(loc), text(t) {} void do_render(std::ostringstream & out, const std::shared_ptr &) const override { out << text; } @@ -887,7 +898,7 @@ class TextNode : public TemplateNode { class ExpressionNode : public TemplateNode { std::shared_ptr expr; public: - ExpressionNode(const Location & location, std::shared_ptr && e) : TemplateNode(location), expr(std::move(e)) {} + ExpressionNode(const Location & loc, std::shared_ptr && e) : TemplateNode(loc), expr(std::move(e)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { if (!expr) throw std::runtime_error("ExpressionNode.expr is null"); auto result = expr->evaluate(context); @@ -904,8 +915,8 @@ class ExpressionNode : public TemplateNode { class IfNode : public TemplateNode { std::vector, std::shared_ptr>> cascade; public: - IfNode(const Location & location, std::vector, std::shared_ptr>> && c) - : TemplateNode(location), cascade(std::move(c)) {} + IfNode(const Location & loc, std::vector, std::shared_ptr>> && c) + : TemplateNode(loc), cascade(std::move(c)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { for (const auto& branch : cascade) { auto enter_branch = true; @@ -924,7 +935,7 @@ class IfNode : public TemplateNode { class LoopControlNode : public TemplateNode { LoopControlType control_type_; public: - LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {} + LoopControlNode(const Location & loc, LoopControlType control_type) : TemplateNode(loc), control_type_(control_type) {} void do_render(std::ostringstream &, const std::shared_ptr &) const override { throw LoopControlException(control_type_); } @@ -938,9 +949,9 @@ class ForNode : public TemplateNode { bool recursive; std::shared_ptr else_body; public: - ForNode(const Location & location, std::vector && var_names, std::shared_ptr && iterable, + ForNode(const Location & loc, std::vector && var_names, std::shared_ptr && iterable, std::shared_ptr && condition, std::shared_ptr && body, bool recursive, std::shared_ptr && else_body) - : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} + : TemplateNode(loc), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { // https://jinja.palletsprojects.com/en/3.0.x/templates/#for @@ -1025,8 +1036,8 @@ class MacroNode : public TemplateNode { std::shared_ptr body; std::unordered_map named_param_positions; public: - MacroNode(const Location & location, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) - : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { + MacroNode(const Location & loc, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) + : TemplateNode(loc), name(std::move(n)), params(std::move(p)), body(std::move(b)) { for (size_t i = 0; 
i < params.size(); ++i) { const auto & name = params[i].first; if (!name.empty()) { @@ -1072,8 +1083,8 @@ class FilterNode : public TemplateNode { std::shared_ptr body; public: - FilterNode(const Location & location, std::shared_ptr && f, std::shared_ptr && b) - : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} + FilterNode(const Location & loc, std::shared_ptr && f, std::shared_ptr && b) + : TemplateNode(loc), filter(std::move(f)), body(std::move(b)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { if (!filter) throw std::runtime_error("FilterNode.filter is null"); @@ -1095,8 +1106,8 @@ class SetNode : public TemplateNode { std::vector var_names; std::shared_ptr value; public: - SetNode(const Location & location, const std::string & ns, const std::vector & vns, std::shared_ptr && v) - : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} + SetNode(const Location & loc, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateNode(loc), ns(ns), var_names(vns), value(std::move(v)) {} void do_render(std::ostringstream &, const std::shared_ptr & context) const override { if (!value) throw std::runtime_error("SetNode.value is null"); if (!ns.empty()) { @@ -1118,8 +1129,8 @@ class SetTemplateNode : public TemplateNode { std::string name; std::shared_ptr template_value; public: - SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr && tv) - : TemplateNode(location), name(name), template_value(std::move(tv)) {} + SetTemplateNode(const Location & loc, const std::string & name, std::shared_ptr && tv) + : TemplateNode(loc), name(name), template_value(std::move(tv)) {} void do_render(std::ostringstream &, const std::shared_ptr & context) const override { if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null"); Value value { template_value->render(context) }; @@ -1132,8 +1143,8 @@ class IfExpr : public Expression { std::shared_ptr then_expr; std::shared_ptr else_expr; public: - IfExpr(const Location & location, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) - : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} + IfExpr(const Location & loc, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) + : Expression(loc), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!condition) throw std::runtime_error("IfExpr.condition is null"); if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null"); @@ -1150,16 +1161,16 @@ class IfExpr : public Expression { class LiteralExpr : public Expression { Value value; public: - LiteralExpr(const Location & location, const Value& v) - : Expression(location), value(v) {} + LiteralExpr(const Location & loc, const Value& v) + : Expression(loc), value(v) {} Value do_evaluate(const std::shared_ptr &) const override { return value; } }; class ArrayExpr : public Expression { std::vector> elements; public: - ArrayExpr(const Location & location, std::vector> && e) - : Expression(location), elements(std::move(e)) {} + ArrayExpr(const Location & loc, std::vector> && e) + : Expression(loc), elements(std::move(e)) {} Value do_evaluate(const std::shared_ptr & context) const override { auto result = Value::array(); for (const auto& e : elements) { @@ -1173,8 +1184,8 @@ class ArrayExpr : public Expression { class DictExpr : public 
Expression { std::vector, std::shared_ptr>> elements; public: - DictExpr(const Location & location, std::vector, std::shared_ptr>> && e) - : Expression(location), elements(std::move(e)) {} + DictExpr(const Location & loc, std::vector, std::shared_ptr>> && e) + : Expression(loc), elements(std::move(e)) {} Value do_evaluate(const std::shared_ptr & context) const override { auto result = Value::object(); for (const auto& [key, value] : elements) { @@ -1189,8 +1200,8 @@ class DictExpr : public Expression { class SliceExpr : public Expression { public: std::shared_ptr start, end; - SliceExpr(const Location & location, std::shared_ptr && s, std::shared_ptr && e) - : Expression(location), start(std::move(s)), end(std::move(e)) {} + SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e) + : Expression(loc), start(std::move(s)), end(std::move(e)) {} Value do_evaluate(const std::shared_ptr &) const override { throw std::runtime_error("SliceExpr not implemented"); } @@ -1200,8 +1211,8 @@ class SubscriptExpr : public Expression { std::shared_ptr base; std::shared_ptr index; public: - SubscriptExpr(const Location & location, std::shared_ptr && b, std::shared_ptr && i) - : Expression(location), base(std::move(b)), index(std::move(i)) {} + SubscriptExpr(const Location & loc, std::shared_ptr && b, std::shared_ptr && i) + : Expression(loc), base(std::move(b)), index(std::move(i)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!base) throw std::runtime_error("SubscriptExpr.base is null"); if (!index) throw std::runtime_error("SubscriptExpr.index is null"); @@ -1243,8 +1254,8 @@ class UnaryOpExpr : public Expression { enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; std::shared_ptr expr; Op op; - UnaryOpExpr(const Location & location, std::shared_ptr && e, Op o) - : Expression(location), expr(std::move(e)), op(o) {} + UnaryOpExpr(const Location & loc, std::shared_ptr && e, Op o) + : Expression(loc), expr(std::move(e)), op(o) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null"); auto e = expr->evaluate(context); @@ -1269,8 +1280,8 @@ class BinaryOpExpr : public Expression { std::shared_ptr right; Op op; public: - BinaryOpExpr(const Location & location, std::shared_ptr && l, std::shared_ptr && r, Op o) - : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} + BinaryOpExpr(const Location & loc, std::shared_ptr && l, std::shared_ptr && r, Op o) + : Expression(loc), left(std::move(l)), right(std::move(r)), op(o) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!left) throw std::runtime_error("BinaryOpExpr.left is null"); if (!right) throw std::runtime_error("BinaryOpExpr.right is null"); @@ -1427,8 +1438,8 @@ class MethodCallExpr : public Expression { std::shared_ptr method; ArgumentsExpression args; public: - MethodCallExpr(const Location & location, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) - : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} + MethodCallExpr(const Location & loc, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) + : Expression(loc), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!object) throw std::runtime_error("MethodCallExpr.object is null"); if (!method) throw std::runtime_error("MethodCallExpr.method is null"); 
@@ -1526,8 +1537,8 @@ class CallExpr : public Expression { public: std::shared_ptr object; ArgumentsExpression args; - CallExpr(const Location & location, std::shared_ptr && obj, ArgumentsExpression && a) - : Expression(location), object(std::move(obj)), args(std::move(a)) {} + CallExpr(const Location & loc, std::shared_ptr && obj, ArgumentsExpression && a) + : Expression(loc), object(std::move(obj)), args(std::move(a)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!object) throw std::runtime_error("CallExpr.object is null"); auto obj = object->evaluate(context); @@ -1542,8 +1553,8 @@ class CallExpr : public Expression { class FilterExpr : public Expression { std::vector> parts; public: - FilterExpr(const Location & location, std::vector> && p) - : Expression(location), parts(std::move(p)) {} + FilterExpr(const Location & loc, std::vector> && p) + : Expression(loc), parts(std::move(p)) {} Value do_evaluate(const std::shared_ptr & context) const override { Value result; bool first = true; @@ -2460,7 +2471,7 @@ class Parser { static std::regex leading_space_regex(R"(^\s+)"); text = std::regex_replace(text, leading_space_regex, ""); } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { - if (text.length() > 0 && text[0] == '\n') { + if (!text.empty() && text[0] == '\n') { text.erase(0, 1); } } @@ -2538,7 +2549,7 @@ class Parser { TemplateTokenIterator begin = tokens.begin(); auto it = begin; TemplateTokenIterator end = tokens.end(); - return parser.parseTemplate(begin, it, end, /* full= */ true); + return parser.parseTemplate(begin, it, end, /* fully= */ true); } }; @@ -2577,7 +2588,7 @@ inline std::shared_ptr Context::builtins() { throw std::runtime_error(args.at("message").get()); })); globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { - return Value(args.at("value").dump(args.get("indent", -1), /* tojson= */ true)); + return Value(args.at("value").dump(args.get("indent", -1), /* to_json= */ true)); })); globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { auto items = Value::array(); @@ -2599,21 +2610,25 @@ inline std::shared_ptr Context::builtins() { globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { auto items = args.at("items"); if (!items.is_array()) throw std::runtime_error("object is not a list"); - if (items.size() == 0) return Value(); + if (items.empty()) return Value(); return items.at(items.size() - 1); })); globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { auto & text = args.at("text"); return text.is_null() ? 
text : Value(strip(text.get())); })); - globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr &, Value & args) { - auto text = args.at("text"); - if (text.is_null()) return text; - std::string res; - auto str = text.get(); - std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower); - return Value(res); - })); + auto char_transform_function = [](const std::string & name, const std::function & fn) { + return simple_function(name, { "text" }, [=](const std::shared_ptr &, Value & args) { + auto text = args.at("text"); + if (text.is_null()) return text; + std::string res; + auto str = text.get(); + std::transform(str.begin(), str.end(), std::back_inserter(res), fn); + return Value(res); + }); + }; + globals.set("lower", char_transform_function("lower", ::tolower)); + globals.set("upper", char_transform_function("upper", ::toupper)); globals.set("default", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { args.expectArgs("default", {2, 3}, {0, 1}); auto & value = args.args[0]; @@ -2743,12 +2758,17 @@ inline std::shared_ptr Context::builtins() { return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); auto & items = args.args[0]; - if (items.is_null()) + if (items.is_null()) { return Value::array(); - if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); + } + if (!items.is_array()) { + throw std::runtime_error("object is not iterable: " + items.dump()); + } auto filter_fn = context->get(args.args[1]); - if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + if (filter_fn.is_null()) { + throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + } auto filter_args = Value::array(); for (size_t i = 2, n = args.args.size(); i < n; i++) { @@ -2870,20 +2890,25 @@ inline std::shared_ptr Context::builtins() { auto v = arg.get(); startEndStep[i] = v; param_set[i] = true; - } } - for (auto & [name, value] : args.kwargs) { - size_t i; - if (name == "start") i = 0; - else if (name == "end") i = 1; - else if (name == "step") i = 2; - else throw std::runtime_error("Unknown argument " + name + " for function range"); - - if (param_set[i]) { - throw std::runtime_error("Duplicate argument " + name + " for function range"); - } - startEndStep[i] = value.get(); - param_set[i] = true; + } + for (auto & [name, value] : args.kwargs) { + size_t i; + if (name == "start") { + i = 0; + } else if (name == "end") { + i = 1; + } else if (name == "step") { + i = 2; + } else { + throw std::runtime_error("Unknown argument " + name + " for function range"); + } + + if (param_set[i]) { + throw std::runtime_error("Duplicate argument " + name + " for function range"); + } + startEndStep[i] = value.get(); + param_set[i] = true; } if (!param_set[1]) { throw std::runtime_error("Missing required argument 'end' for function range"); diff --git a/common/sampling.cpp b/common/sampling.cpp index 4a57b54566d4c..1b1d7e5607f30 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -219,6 +219,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co trigger_patterns_c.data(), trigger_patterns_c.size(), trigger_tokens.data(), trigger_tokens.size()) : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); + if (!grmr) { + return nullptr; + } } auto * result = new common_sampler { diff --git a/convert_hf_to_gguf.py 
b/convert_hf_to_gguf.py index 76ab4233ef2c1..cfe94deaf76ef 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -708,6 +708,12 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "7dec86086fcc38b66b7bc1575a160ae21cf705be7718b9d5598190d7c12db76f": # ref: https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k res = "superbpe" + if chkhsh == "1994ffd01900cfb37395608534236ecd63f2bd5995d6cb1004dda1af50240f15": + # ref: https://huggingface.co/trillionlabs/Trillion-7B-preview + res = "trillion" + if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224": + # ref: https://huggingface.co/inclusionAI/Ling-lite + res = "bailingmoe" if res is None: logger.warning("\n") @@ -2269,7 +2275,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) -@Model.register("Qwen2VLForConditionalGeneration") +@Model.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") class Qwen2VLModel(Model): model_arch = gguf.MODEL_ARCH.QWEN2VL @@ -3551,8 +3557,8 @@ def set_gguf_parameters(self): head_size = hidden_size // num_attention_heads rms_norm_eps = self.hparams["rms_norm_eps"] intermediate_size = self.hparams["intermediate_size"] - time_mix_extra_dim = 64 if hidden_size >= 4096 else 32 - time_decay_extra_dim = 128 if hidden_size >= 4096 else 64 + time_mix_extra_dim = self.hparams.get("lora_rank_tokenshift", 64 if hidden_size >= 4096 else 32) + time_decay_extra_dim = self.hparams.get("lora_rank_decay", 128 if hidden_size >= 4096 else 64) # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) @@ -3803,8 +3809,6 @@ def set_gguf_parameters(self): _tok_embd = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD) @@ -3814,6 +3818,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter logger.debug("A_log --> A ==> " + new_name) data_torch = -torch.exp(data_torch) + # [4 1 8192 1] -> [4 8192 1 1] + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + # assuming token_embd.weight is seen before output.weight if self._tok_embd is not None and new_name == output_name: if torch.equal(self._tok_embd, data_torch): @@ -4417,6 +4425,29 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("PLMForCausalLM") +class PLMModel(Model): + model_arch = gguf.MODEL_ARCH.PLM + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["v_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + @Model.register("T5WithLMHeadModel") @Model.register("T5ForConditionalGeneration") @Model.register("MT5ForConditionalGeneration") @@ -5105,6 +5136,105 @@ 
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) +@Model.register("BailingMoeForCausalLM") +class BailingMoeModel(Model): + model_arch = gguf.MODEL_ARCH.BAILINGMOE + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"] + + self.gguf_writer.add_rope_dimension_count(rope_dim) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_weights_scale(1.0) + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + _experts: list[dict[str, Tensor]] | None = None + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + n_embd = self.hparams["hidden_size"] + head_dim = self.hparams.get("head_dim") or n_embd // n_head + + output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) + + if name.endswith("attention.dense.weight"): + return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), data_torch)] + elif name.endswith("query_key_value.weight"): + q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2) + + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeModel.permute(q, n_head, n_head)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeModel.permute(k, n_head, n_kv_head)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v) + ] + elif name.find("mlp.experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + tensors: list[tuple[str, Tensor]] = [] + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + + return tensors + + new_name = self.map_tensor_name(name) + + if new_name == output_name and self.hparams.get("norm_head"): + data_torch = data_torch.float() + data_torch /= torch.norm(data_torch, p=2, dim=0, keepdim=True) + 1e-7 + + return [(new_name, data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if 
self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @Model.register("ChameleonForConditionalGeneration") @Model.register("ChameleonForCausalLM") # obsolete class ChameleonModel(Model): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index ca90cf592932b..1b86f4c90acf6 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -111,6 +111,8 @@ class TOKENIZER_TYPE(IntEnum): {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"}, {"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", }, {"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", }, + {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, + {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", }, ] diff --git a/docs/backend/OPENCL.md b/docs/backend/OPENCL.md index 2a946dc8df0ff..07146f7102f3d 100644 --- a/docs/backend/OPENCL.md +++ b/docs/backend/OPENCL.md @@ -145,8 +145,13 @@ A Snapdragon X Elite device with Windows 11 Arm64 is used. Make sure the followi * Clang 19 * Ninja * Visual Studio 2022 +* Powershell 7 -Powershell is used for the following instructions. +Visual Studio provides necessary headers and libraries although it is not directly used for building. +Alternatively, Visual Studio Build Tools can be installed instead of the full Visual Studio. + +Powershell 7 is used for the following commands. +If an older version of Powershell is used, these commands may not work as they are. ### I. Setup Environment @@ -196,10 +201,9 @@ ninja ## Known Issues -- Qwen2.5 0.5B model produces gibberish output with Adreno kernels. +- Currently OpenCL backend does not work on Adreno 6xx GPUs. ## TODO -- Fix Qwen2.5 0.5B - Optimization for Q6_K - Support and optimization for Q4_K diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index f97e696aaaac8..cb29075b106be 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -20,7 +20,7 @@ **oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include: - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers. -- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL and oneDNN)*. +- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*. - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs. - **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets. @@ -227,16 +227,6 @@ Upon a successful installation, SYCL is enabled for the available intel devices, **oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). 
User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup. - -**oneMKL for cuBlas**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* do not contain the cuBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *cuBLAS* backend enabled is thus required to run it on Nvidia GPUs. - -```sh -git clone https://github.com/oneapi-src/oneMKL -cd oneMKL -cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENABLE_MKLGPU_BACKEND=OFF -DENABLE_MKLCPU_BACKEND=OFF -DENABLE_CUBLAS_BACKEND=ON -DTARGET_DOMAINS=blas -cmake --build buildWithCublas --config Release -``` - **oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target: ```sh @@ -250,16 +240,6 @@ cmake --build build-nvidia --config Release **oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit. -**oneMKL for rocBlas**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* doesn't contain the rocBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *rocBLAS* backend enabled is thus required to run it on AMD GPUs. - -```sh -git clone https://github.com/oneapi-src/oneMKL -cd oneMKL -# Find your HIPTARGET with rocminfo, under the key 'Name:' -cmake -B buildWithrocBLAS -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENABLE_MKLGPU_BACKEND=OFF -DENABLE_MKLCPU_BACKEND=OFF -DENABLE_ROCBLAS_BACKEND=ON -DHIPTARGETS=${HIPTARGET} -DTARGET_DOMAINS=blas -cmake --build buildWithrocBLAS --config Release -``` - 3. **Verify installation and environment** In order to check the available SYCL devices on the machine, please use the `sycl-ls` command. @@ -322,15 +302,16 @@ cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx - cmake --build build --config Release -j -v ``` +It is possible to come across some precision issues when running tests that stem from using faster +instructions, which can be circumvented by setting the environment variable `SYCL_PROGRAM_COMPILE_OPTIONS` +as `-cl-fp32-correctly-rounded-divide-sqrt` + #### Nvidia GPU -```sh -# Export relevant ENV variables -export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithCublas/lib:$LD_LIBRARY_PATH -export LIBRARY_PATH=/path/to/oneMKL/buildWithCublas/lib:$LIBRARY_PATH -export CPLUS_INCLUDE_DIR=/path/to/oneMKL/buildWithCublas/include:$CPLUS_INCLUDE_DIR -export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR +The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices. +By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`. 
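As a sketch of the `-DoneMath_DIR` option above: assuming a oneMath build has already been installed under `~/dev/oneMath/install` (a placeholder path, not part of this repository), the Nvidia configure step could reuse that prebuilt copy instead of building oneMath as part of the project:

```sh
# Sketch only: reuse a prebuilt oneMath install (install prefix below is an example path)
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA \
    -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \
    -DoneMath_DIR=$HOME/dev/oneMath/install/lib/cmake/oneMath
cmake --build build --config Release -j -v
```

If the flag is omitted, oneMath is fetched and built automatically, as described above; pointing at a prebuilt install mainly saves configure/build time across repeated builds.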
+```sh # Build LLAMA with Nvidia BLAS acceleration through SYCL # Setting GGML_SYCL_DEVICE_ARCH is optional but can improve performance GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture @@ -345,14 +326,15 @@ cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH= cmake --build build --config Release -j -v ``` +It is possible to come across some precision issues when running tests that stem from using faster +instructions, which can be circumvented by passing the `-fno-fast-math` flag to the compiler. + #### AMD GPU -```sh -# Export relevant ENV variables -export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LD_LIBRARY_PATH -export LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LIBRARY_PATH -export CPLUS_INCLUDE_DIR=/path/to/oneMKL/buildWithrocBLAS/include:$CPLUS_INCLUDE_DIR +The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices. +By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`. +```sh # Build LLAMA with rocBLAS acceleration through SYCL ## AMD @@ -493,6 +475,12 @@ b. Enable oneAPI running environment: "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 ``` +- if you are using Powershell, enable the runtime environment with the following: + +``` +cmd.exe "/K" '"C:\Program Files (x86)\Intel\oneAPI\setvars.bat" && powershell' +``` + c. Verify installation In the oneAPI command line, run the following to print the available SYCL devices: @@ -523,13 +511,13 @@ You could download the release package for Windows directly, which including bin Choose one of following methods to build from source code. -1. Script +#### 1. Script ```sh .\examples\sycl\win-build-sycl.bat ``` -2. CMake +#### 2. CMake On the oneAPI command line window, step into the llama.cpp main directory and run the following: @@ -558,13 +546,84 @@ cmake --preset x64-windows-sycl-debug cmake --build build-x64-windows-sycl-debug -j --target llama-cli ``` -3. Visual Studio +#### 3. Visual Studio + +You have two options to use Visual Studio to build llama.cpp: +- As CMake Project using CMake presets. +- Creating a Visual Studio solution to handle the project. + +**Note**: + +All following commands are executed in PowerShell. + +##### - Open as a CMake Project + +You can use Visual Studio to open the `llama.cpp` folder directly as a CMake project. Before compiling, select one of the SYCL CMake presets: -You can use Visual Studio to open llama.cpp folder as a CMake project. Choose the sycl CMake presets (`x64-windows-sycl-release` or `x64-windows-sycl-debug`) before you compile the project. +- `x64-windows-sycl-release` + +- `x64-windows-sycl-debug` *Notes:* +- For a minimal experimental setup, you can build only the inference executable using: + + ```Powershell + cmake --build build --config Release -j --target llama-cli + ``` + +##### - Generating a Visual Studio Solution + +You can use Visual Studio solution to build and work on llama.cpp on Windows. You need to convert the CMake Project into a `.sln` file. + +If you want to use the Intel C++ Compiler for the entire `llama.cpp` project, run the following command: + +```Powershell +cmake -B build -G "Visual Studio 17 2022" -T "Intel C++ Compiler 2025" -A x64 -DGGML_SYCL=ON -DCMAKE_BUILD_TYPE=Release +``` + +If you prefer to use the Intel C++ Compiler only for `ggml-sycl`, ensure that `ggml` and its backend libraries are built as shared libraries ( i.e. 
`-DBUILD_SHARED_LIBRARIES=ON`, this is default behaviour): + +```Powershell +cmake -B build -G "Visual Studio 17 2022" -A x64 -DGGML_SYCL=ON -DCMAKE_BUILD_TYPE=Release \ + -DSYCL_INCLUDE_DIR="C:\Program Files (x86)\Intel\oneAPI\compiler\latest\include" \ + -DSYCL_LIBRARY_DIR="C:\Program Files (x86)\Intel\oneAPI\compiler\latest\lib" +``` + +If successful the build files have been written to: *path/to/llama.cpp/build* +Open the project file **build/llama.cpp.sln** with Visual Studio. + +Once the Visual Studio solution is created, follow these steps: + +1. Open the solution in Visual Studio. + +2. Right-click on `ggml-sycl` and select **Properties**. + +3. In the left column, expand **C/C++** and select **DPC++**. + +4. In the right panel, find **Enable SYCL Offload** and set it to `Yes`. + +5. Apply the changes and save. + + +*Navigation Path:* + +``` +Properties -> C/C++ -> DPC++ -> Enable SYCL Offload (Yes) +``` + +Now, you can build `llama.cpp` with the SYCL backend as a Visual Studio project. +To do it from menu: `Build -> Build Solution`. +Once it is completed, final results will be in **build/Release/bin** + +*Additional Note* + +- You can avoid specifying `SYCL_INCLUDE_DIR` and `SYCL_LIBRARY_DIR` in the CMake command by setting the environment variables: + + - `SYCL_INCLUDE_DIR_HINT` + + - `SYCL_LIBRARY_DIR_HINT` -- In case of a minimal experimental setup, the user can build the inference executable only through `cmake --build build --config Release -j --target llama-cli`. +- Above instruction has been tested with Visual Studio 17 Community edition and oneAPI 2025.0. We expect them to work also with future version if the instructions are adapted accordingly. ### III. Run the inference diff --git a/docs/build.md b/docs/build.md index fbf12c7664d50..3f1b043992545 100644 --- a/docs/build.md +++ b/docs/build.md @@ -191,7 +191,7 @@ The following compilation options are also available to tweak performance: | Option | Legal values | Default | Description | |-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | +| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. 
Can improve performance on relatively recent GPUs. | | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | @@ -218,6 +218,7 @@ By default, all supported compute capabilities are enabled. To customize this be ```bash cmake -B build -DGGML_MUSA=ON -DMUSA_ARCHITECTURES="21" +cmake --build build --config Release ``` This configuration enables only compute capability `2.1` (MTT S80) during compilation, which can help reduce compilation time. @@ -455,6 +456,96 @@ KleidiAI's microkernels implement optimized tensor operations using Arm CPU feat Depending on your build target, other higher priority backends may be enabled by default. To ensure the CPU backend is used, you must disable the higher priority backends either at compile time, e.g. -DGGML_METAL=OFF, or during run-time using the command line option `--device none`. +## OpenCL + +This provides GPU acceleration through OpenCL on recent Adreno GPU. +More information about OpenCL backend can be found in [OPENCL.md](./backend/OPENCL.md) for more information. + +### Android + +Assume NDK is available in `$ANDROID_NDK`. First, install OpenCL headers and ICD loader library if not available, + +```sh +mkdir -p ~/dev/llm +cd ~/dev/llm + +git clone https://github.com/KhronosGroup/OpenCL-Headers && \ +cd OpenCL-Headers && \ +cp -r CL $ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include + +cd ~/dev/llm + +git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && \ +cd OpenCL-ICD-Loader && \ +mkdir build_ndk && cd build_ndk && \ +cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DOPENCL_ICD_LOADER_HEADERS_DIR=$ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=24 \ + -DANDROID_STL=c++_shared && \ +ninja && \ +cp libOpenCL.so $ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android +``` + +Then build llama.cpp with OpenCL enabled, + +```sh +cd ~/dev/llm + +git clone https://github.com/ggml-org/llama.cpp && \ +cd llama.cpp && \ +mkdir build-android && cd build-android + +cmake .. -G Ninja \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DBUILD_SHARED_LIBS=OFF \ + -DGGML_OPENCL=ON + +ninja +``` + +### Windows Arm64 + +First, install OpenCL headers and ICD loader library if not available, + +```powershell +mkdir -p ~/dev/llm + +cd ~/dev/llm +git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers +mkdir build && cd build +cmake .. -G Ninja ` + -DBUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` + -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl" +cmake --build . --target install + +cd ~/dev/llm +git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader +mkdir build && cd build +cmake .. -G Ninja ` + -DCMAKE_BUILD_TYPE=Release ` + -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" ` + -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl" +cmake --build . --target install +``` + +Then build llama.cpp with OpenCL enabled, + +```powershell +cmake .. 
-G Ninja ` + -DCMAKE_TOOLCHAIN_FILE="$HOME/dev/llm/llama.cpp/cmake/arm64-windows-llvm.cmake" ` + -DCMAKE_BUILD_TYPE=Release ` + -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" ` + -DBUILD_SHARED_LIBS=OFF ` + -DGGML_OPENCL=ON +ninja +``` + ## Android To read documentation for how to build on Android, [click here](./android.md) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 430e8be512653..0f4019293d581 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -38,7 +38,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 21b95ef5e4e83..1a5de5928a526 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -41,7 +41,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: error: unable to load model\n" , __func__); diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index e7d0fbfffedb0..24dc85cf27336 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -421,7 +421,7 @@ int main(int argc, char ** argv) { g_verbose = (params.verbosity > 1); try { - lora_merge_ctx ctx(params.model, params.lora_adapters, params.out_file, params.cpuparams.n_threads); + lora_merge_ctx ctx(params.model.path, params.lora_adapters, params.out_file, params.cpuparams.n_threads); ctx.run_merge(); } catch (const std::exception & err) { fprintf(stderr, "%s\n", err.what()); diff --git a/examples/gguf-split/gguf-split.cpp b/examples/gguf-split/gguf-split.cpp index ef3ceb686f697..30e771564e808 100644 --- a/examples/gguf-split/gguf-split.cpp +++ b/examples/gguf-split/gguf-split.cpp @@ -408,8 +408,6 @@ static void gguf_merge(const split_params & split_params) { exit(EXIT_FAILURE); } - std::ofstream fout(split_params.output.c_str(), std::ios::binary); - fout.exceptions(std::ofstream::failbit); // fail fast on write errors auto * ctx_out = gguf_init_empty(); @@ -453,7 +451,6 @@ static void gguf_merge(const split_params & split_params) { gguf_free(ctx_gguf); ggml_free(ctx_meta); gguf_free(ctx_out); - fout.close(); exit(EXIT_FAILURE); } @@ -466,7 +463,6 @@ static void gguf_merge(const split_params & split_params) { gguf_free(ctx_gguf); ggml_free(ctx_meta); gguf_free(ctx_out); - fout.close(); exit(EXIT_FAILURE); } @@ -479,7 +475,6 @@ static void gguf_merge(const split_params & split_params) { gguf_free(ctx_gguf); ggml_free(ctx_meta); gguf_free(ctx_out); - fout.close(); exit(EXIT_FAILURE); } @@ -500,9 +495,11 @@ static void gguf_merge(const split_params & split_params) { fprintf(stderr, "\033[3Ddone\n"); } - - // placeholder for the meta data - { + std::ofstream fout; + if (!split_params.dry_run) { + fout.open(split_params.output.c_str(), std::ios::binary); + fout.exceptions(std::ofstream::failbit); // fail fast on write errors + // placeholder for the meta data auto 
meta_size = gguf_get_meta_size(ctx_out); ::zeros(fout, meta_size); } @@ -518,7 +515,9 @@ static void gguf_merge(const split_params & split_params) { ggml_free(ctx_metas[i]); } gguf_free(ctx_out); - fout.close(); + if (!split_params.dry_run) { + fout.close(); + } exit(EXIT_FAILURE); } fprintf(stderr, "%s: writing tensors %s ...", __func__, split_path); @@ -540,10 +539,11 @@ static void gguf_merge(const split_params & split_params) { auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor); f_input.seekg(offset); f_input.read((char *)read_data.data(), n_bytes); - - // write tensor data + padding - fout.write((const char *)read_data.data(), n_bytes); - zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes); + if (!split_params.dry_run) { + // write tensor data + padding + fout.write((const char *)read_data.data(), n_bytes); + zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes); + } } gguf_free(ctx_gguf); @@ -552,16 +552,15 @@ static void gguf_merge(const split_params & split_params) { fprintf(stderr, "\033[3Ddone\n"); } - { + if (!split_params.dry_run) { // go back to beginning of file and write the updated metadata fout.seekp(0); std::vector data(gguf_get_meta_size(ctx_out)); gguf_get_meta_data(ctx_out, data.data()); fout.write((const char *)data.data(), data.size()); - fout.close(); - gguf_free(ctx_out); } + gguf_free(ctx_out); fprintf(stderr, "%s: %s merged from %d split with %d tensors.\n", __func__, split_params.output.c_str(), n_split, total_tensors); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index f7db7861c1ad5..539bc4d6027fb 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -168,7 +168,7 @@ int main(int argc, char * argv[]) { llama_backend_init(); - llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams); + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); // create generation context llama_context * ctx = llama_init_from_model(model, cparams); diff --git a/examples/llava/README-gemma3.md b/examples/llava/README-gemma3.md index 20bf73fb5c043..3c25ee2583027 100644 --- a/examples/llava/README-gemma3.md +++ b/examples/llava/README-gemma3.md @@ -4,6 +4,26 @@ > > This is very experimental, only used for demo purpose. +## Quick started + +You can use pre-quantized model from [ggml-org](https://huggingface.co/ggml-org)'s Hugging Face account + +```bash +# build +cmake -B build +cmake --build build --target llama-gemma3-cli + +# alternatively, install from brew (MacOS) +brew install llama.cpp + +# run it +llama-gemma3-cli -hf ggml-org/gemma-3-4b-it-GGUF +llama-gemma3-cli -hf ggml-org/gemma-3-12b-it-GGUF +llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF + +# note: 1B model does not support vision +``` + ## How to get mmproj.gguf? 
```bash diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h new file mode 100644 index 0000000000000..685d6e7e09ad1 --- /dev/null +++ b/examples/llava/clip-impl.h @@ -0,0 +1,273 @@ +#include "ggml.h" +#include "gguf.h" + +#include +#include +#include +#include +#include +#include + +// Internal header for clip.cpp + +#define KEY_FTYPE "general.file_type" +#define KEY_NAME "general.name" +#define KEY_DESCRIPTION "general.description" +#define KEY_HAS_TEXT_ENC "clip.has_text_encoder" +#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" +#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" +#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_HAS_GLM_PROJ "clip.has_glm_projector" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" +#define KEY_USE_GELU "clip.use_gelu" +#define KEY_USE_SILU "clip.use_silu" +#define KEY_N_EMBD "clip.%s.embedding_length" +#define KEY_N_FF "clip.%s.feed_forward_length" +#define KEY_N_BLOCK "clip.%s.block_count" +#define KEY_N_HEAD "clip.%s.attention.head_count" +#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" +#define KEY_PROJ_DIM "clip.%s.projection_dim" +#define KEY_TOKENS "tokenizer.ggml.tokens" +#define KEY_N_POSITIONS "clip.text.context_length" +#define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_PATCH_SIZE "clip.vision.patch_size" +#define KEY_IMAGE_MEAN "clip.vision.image_mean" +#define KEY_IMAGE_STD "clip.vision.image_std" +#define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_FEATURE_LAYER "clip.vision.feature_layer" + +#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" +#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" +#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" + + +// +// tensor name constants +// + +#define TN_TOKEN_EMBD "%s.token_embd.weight" +#define TN_POS_EMBD "%s.position_embd.weight" +#define TN_CLASS_EMBD "v.class_embd" +#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat +#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" +#define TN_PATCH_BIAS "v.patch_embd.bias" +#define TN_ATTN_K "%s.blk.%d.attn_k.%s" +#define TN_ATTN_Q "%s.blk.%d.attn_q.%s" +#define TN_ATTN_V "%s.blk.%d.attn_v.%s" +#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" +#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" +#define TN_FFN_UP "%s.blk.%d.ffn_up.%s" +#define TN_LN_1 "%s.blk.%d.ln1.%s" +#define TN_LN_2 "%s.blk.%d.ln2.%s" +#define TN_LN_PRE "%s.pre_ln.%s" +#define TN_LN_POST "%s.post_ln.%s" +#define TN_TEXT_PROJ "text_projection.weight" +#define TN_VIS_PROJ "visual_projection.weight" +#define TN_LLAVA_PROJ "mm.%d.%s" +#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s" +#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" +#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" +#define TN_IMAGE_NEWLINE "model.image_newline" +#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 +#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 + +// mimicpmv +#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" +#define TN_MINICPMV_QUERY "resampler.query" +#define TN_MINICPMV_PROJ "resampler.proj.weight" +#define TN_MINICPMV_KV_PROJ "resampler.kv.weight" +#define TN_MINICPMV_ATTN "resampler.attn.%s.%s" +#define TN_MINICPMV_LN "resampler.ln_%s.%s" + +#define TN_GLM_ADAPER_CONV "adapter.conv.%s" +#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s" +#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s" +#define TN_GLM_ADAPTER_D_H_2_4H 
"adapter.linear.dense_h_to_4h.%s" +#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" +#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" +#define TN_GLM_BOI_W "adapter.boi" +#define TN_GLM_EOI_W "adapter.eoi" + +enum projector_type { + PROJECTOR_TYPE_MLP, + PROJECTOR_TYPE_MLP_NORM, + PROJECTOR_TYPE_LDP, + PROJECTOR_TYPE_LDPV2, + PROJECTOR_TYPE_RESAMPLER, + PROJECTOR_TYPE_GLM_EDGE, + PROJECTOR_TYPE_MERGER, + PROJECTOR_TYPE_GEMMA3, + PROJECTOR_TYPE_UNKNOWN, +}; + +static std::map PROJECTOR_TYPE_NAMES = { + { PROJECTOR_TYPE_MLP, "mlp" }, + { PROJECTOR_TYPE_LDP, "ldp" }, + { PROJECTOR_TYPE_LDPV2, "ldpv2"}, + { PROJECTOR_TYPE_RESAMPLER, "resampler"}, + { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, + { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, + { PROJECTOR_TYPE_GEMMA3, "gemma3"}, +}; + +static projector_type clip_projector_type_from_string(const std::string & str) { + for (const auto & pair : PROJECTOR_TYPE_NAMES) { + if (pair.second == str) { + return pair.first; + } + } + return PROJECTOR_TYPE_UNKNOWN; +} + +// +// logging +// + +static void clip_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct clip_logger_state { + ggml_log_level verbosity_thold; + ggml_log_callback log_callback; + void * log_callback_user_data; +}; + +extern struct clip_logger_state g_logger_state; + +static void clip_log_internal_v(enum ggml_log_level level, const char * format, va_list args) { + if (format == NULL) { + return; + } + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data); + } else { + char * buffer2 = (char *) calloc(len + 1, sizeof(char)); + vsnprintf(buffer2, len + 1, format, args_copy); + buffer2[len] = 0; + g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data); + free(buffer2); + } + va_end(args_copy); +} + +static void clip_log_internal(enum ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + clip_log_internal_v(level, format, args); + va_end(args); +} + +#define LOG_TMPL(level, ...) \ + do { \ + if ((level) >= g_logger_state.verbosity_thold) { \ + clip_log_internal((level), __VA_ARGS__); \ + } \ + } while (0) +#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, __VA_ARGS__) +#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, __VA_ARGS__) +#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__) + +// +// common utils +// + +static std::string string_format(const char * fmt, ...) 
{ + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), buf.size()); +} + +static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { + if (search.empty()) { + return; + } + std::string builder; + builder.reserve(s.length()); + size_t pos = 0; + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +// +// gguf utils +// + +static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) { + switch (type) { + case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]); + case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]); + case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]); + case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]); + case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]); + case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]); + case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]); + case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); + case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); + case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); + case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + default: return string_format("unknown type %d", type); + } +} + +static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + + switch (type) { + case GGUF_TYPE_STRING: + return gguf_get_val_str(ctx_gguf, i); + case GGUF_TYPE_ARRAY: + { + const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); + int arr_n = gguf_get_arr_n(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); + std::stringstream ss; + ss << "["; + for (int j = 0; j < arr_n; j++) { + if (arr_type == GGUF_TYPE_STRING) { + std::string val = gguf_get_arr_str(ctx_gguf, i, j); + // escape quotes + string_replace_all(val, "\\", "\\\\"); + string_replace_all(val, "\"", "\\\""); + ss << '"' << val << '"'; + } else if (arr_type == GGUF_TYPE_ARRAY) { + ss << "???"; + } else { + ss << gguf_data_to_str(arr_type, data, j); + } + if (j < arr_n - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } + default: + return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); + } +} diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index a1f050e39a094..1145c816c83d4 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -3,6 +3,7 @@ // I'll gradually clean and extend it // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch #include "clip.h" +#include "clip-impl.h" #include "ggml.h" #include "ggml-cpp.h" #include "ggml-cpu.h" @@ -27,17 +28,7 @@ #include #include -#if defined(LLAVA_LOG_OFF) -# define LOG_INF(...) -# define LOG_WRN(...) 
-# define LOG_ERR(...) -# define LOG_DBG(...) -#else // defined(LLAVA_LOG_OFF) -# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -#endif // defined(LLAVA_LOG_OFF) +struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; //#define CLIP_DEBUG_FUNCTIONS @@ -58,253 +49,6 @@ struct clip_image_f32 { std::vector buf; }; -static std::string format(const char * fmt, ...) { - va_list ap; - va_list ap2; - va_start(ap, fmt); - va_copy(ap2, ap); - int size = vsnprintf(NULL, 0, fmt, ap); - GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT - std::vector buf(size + 1); - int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); - GGML_ASSERT(size2 == size); - va_end(ap2); - va_end(ap); - return std::string(buf.data(), buf.size()); -} - -// -// key constants -// - -#define KEY_FTYPE "general.file_type" -#define KEY_NAME "general.name" -#define KEY_DESCRIPTION "general.description" -#define KEY_HAS_TEXT_ENC "clip.has_text_encoder" -#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" -#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" -#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" -#define KEY_HAS_GLM_PROJ "clip.has_glm_projector" -#define KEY_MINICPMV_VERSION "clip.minicpmv_version" -#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" -#define KEY_USE_GELU "clip.use_gelu" -#define KEY_USE_SILU "clip.use_silu" -#define KEY_N_EMBD "clip.%s.embedding_length" -#define KEY_N_FF "clip.%s.feed_forward_length" -#define KEY_N_BLOCK "clip.%s.block_count" -#define KEY_N_HEAD "clip.%s.attention.head_count" -#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" -#define KEY_PROJ_DIM "clip.%s.projection_dim" -#define KEY_TOKENS "tokenizer.ggml.tokens" -#define KEY_N_POSITIONS "clip.text.context_length" -#define KEY_IMAGE_SIZE "clip.vision.image_size" -#define KEY_PATCH_SIZE "clip.vision.patch_size" -#define KEY_IMAGE_MEAN "clip.vision.image_mean" -#define KEY_IMAGE_STD "clip.vision.image_std" -#define KEY_PROJ_TYPE "clip.projector_type" -#define KEY_FEATURE_LAYER "clip.vision.feature_layer" - -#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" -#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" -#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" - - -// -// tensor name constants -// - -#define TN_TOKEN_EMBD "%s.token_embd.weight" -#define TN_POS_EMBD "%s.position_embd.weight" -#define TN_CLASS_EMBD "v.class_embd" -#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat -#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" -#define TN_PATCH_BIAS "v.patch_embd.bias" -#define TN_ATTN_K "%s.blk.%d.attn_k.%s" -#define TN_ATTN_Q "%s.blk.%d.attn_q.%s" -#define TN_ATTN_V "%s.blk.%d.attn_v.%s" -#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" -#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" -#define TN_FFN_UP "%s.blk.%d.ffn_up.%s" -#define TN_LN_1 "%s.blk.%d.ln1.%s" -#define TN_LN_2 "%s.blk.%d.ln2.%s" -#define TN_LN_PRE "%s.pre_ln.%s" -#define TN_LN_POST "%s.post_ln.%s" -#define TN_TEXT_PROJ "text_projection.weight" -#define TN_VIS_PROJ "visual_projection.weight" -#define TN_LLAVA_PROJ "mm.%d.%s" -#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s" -#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" -#define TN_MVLM_PROJ_PEG 
"mm.model.peg.%d.%s" -#define TN_IMAGE_NEWLINE "model.image_newline" -#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 -#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 - -#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" -#define TN_MINICPMV_QUERY "resampler.query" -#define TN_MINICPMV_PROJ "resampler.proj.weight" -#define TN_MINICPMV_KV_PROJ "resampler.kv.weight" -#define TN_MINICPMV_ATTN "resampler.attn.%s.%s" -#define TN_MINICPMV_LN "resampler.ln_%s.%s" - -#define TN_GLM_ADAPER_CONV "adapter.conv.%s" -#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s" -#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s" -#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s" -#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" -#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" -#define TN_GLM_BOI_W "adapter.boi" -#define TN_GLM_EOI_W "adapter.eoi" - - -enum projector_type { - PROJECTOR_TYPE_MLP, - PROJECTOR_TYPE_MLP_NORM, - PROJECTOR_TYPE_LDP, - PROJECTOR_TYPE_LDPV2, - PROJECTOR_TYPE_RESAMPLER, - PROJECTOR_TYPE_GLM_EDGE, - PROJECTOR_TYPE_MERGER, - PROJECTOR_TYPE_GEMMA3, - PROJECTOR_TYPE_UNKNOWN, -}; - -static std::map PROJECTOR_TYPE_NAMES = { - { PROJECTOR_TYPE_MLP, "mlp" }, - { PROJECTOR_TYPE_LDP, "ldp" }, - { PROJECTOR_TYPE_LDPV2, "ldpv2"}, - { PROJECTOR_TYPE_RESAMPLER, "resampler"}, - { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, - { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, - { PROJECTOR_TYPE_GEMMA3, "gemma3"}, -}; - - -// -// utilities to get data from a gguf file -// - -static int get_key_idx(const gguf_context * ctx, const char * key) { - int i = gguf_find_key(ctx, key); - if (i == -1) { - LOG_ERR("key %s not found in file\n", key); - throw std::runtime_error(format("Missing required key: %s", key)); - } - - return i; -} - -static uint32_t get_u32(const gguf_context * ctx, const std::string & key) { - const int i = get_key_idx(ctx, key.c_str()); - - return gguf_get_val_u32(ctx, i); -} - -static float get_f32(const gguf_context * ctx, const std::string & key) { - const int i = get_key_idx(ctx, key.c_str()); - - return gguf_get_val_f32(ctx, i); -} - -static struct ggml_tensor * get_tensor(struct ggml_context * ctx, const std::string & name) { - struct ggml_tensor * cur = ggml_get_tensor(ctx, name.c_str()); - if (!cur) { - throw std::runtime_error(format("%s: unable to find tensor %s\n", __func__, name.c_str())); - } - - return cur; -} - -static std::string get_ftype(int ftype) { - return ggml_type_name(static_cast(ftype)); -} - -static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) { - switch (type) { - case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]); - case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]); - case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]); - case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]); - case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]); - case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]); - case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]); - case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); - case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); - case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); - case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? 
"true" : "false"; - default: return format("unknown type %d", type); - } -} - -static void replace_all(std::string & s, const std::string & search, const std::string & replace) { - if (search.empty()) { - return; - } - std::string builder; - builder.reserve(s.length()); - size_t pos = 0; - size_t last_pos = 0; - while ((pos = s.find(search, last_pos)) != std::string::npos) { - builder.append(s, last_pos, pos - last_pos); - builder.append(replace); - last_pos = pos + search.length(); - } - builder.append(s, last_pos, std::string::npos); - s = std::move(builder); -} - -static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { - const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); - - switch (type) { - case GGUF_TYPE_STRING: - return gguf_get_val_str(ctx_gguf, i); - case GGUF_TYPE_ARRAY: - { - const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); - int arr_n = gguf_get_arr_n(ctx_gguf, i); - const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); - std::stringstream ss; - ss << "["; - for (int j = 0; j < arr_n; j++) { - if (arr_type == GGUF_TYPE_STRING) { - std::string val = gguf_get_arr_str(ctx_gguf, i, j); - // escape quotes - replace_all(val, "\\", "\\\\"); - replace_all(val, "\"", "\\\""); - ss << '"' << val << '"'; - } else if (arr_type == GGUF_TYPE_ARRAY) { - ss << "???"; - } else { - ss << gguf_data_to_str(arr_type, data, j); - } - if (j < arr_n - 1) { - ss << ", "; - } - } - ss << "]"; - return ss.str(); - } - default: - return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); - } -} - -static void print_tensor_info(const ggml_tensor * tensor, const char * prefix = "") { - size_t tensor_size = ggml_nbytes(tensor); - LOG_INF("%s: n_dims = %d, name = %s, tensor_size=%zu, shape:[%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "], type = %s\n", - prefix, ggml_n_dims(tensor), tensor->name, tensor_size, - tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], ggml_type_name(tensor->type)); -} - -static projector_type clip_projector_type_from_string(const std::string & name) { - for (const auto & kv : PROJECTOR_TYPE_NAMES) { // NOLINT - if (kv.second == name) { - return kv.first; - } - } - throw std::runtime_error(format("Unknown projector type: %s", name.c_str())); -} - #ifdef CLIP_DEBUG_FUNCTIONS static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) { std::ofstream file(filename, std::ios::binary); @@ -418,6 +162,11 @@ static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u // clip layers // +enum patch_merge_type { + PATCH_MERGE_FLAT, + PATCH_MERGE_SPATIAL_UNPAD, +}; + struct clip_hparams { int32_t image_size; int32_t patch_size; @@ -427,9 +176,9 @@ struct clip_hparams { int32_t n_head; int32_t n_layer; - float eps; + patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT; - char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) + float eps; std::vector image_grid_pinpoints; int32_t image_crop_resolution; @@ -438,44 +187,44 @@ struct clip_hparams { struct clip_layer { // attention - struct ggml_tensor * k_w; - struct ggml_tensor * k_b; - struct ggml_tensor * q_w; - struct ggml_tensor * q_b; - struct ggml_tensor * v_w; - struct ggml_tensor * v_b; + struct ggml_tensor * k_w = nullptr; + struct ggml_tensor * k_b = nullptr; + struct ggml_tensor * q_w = nullptr; + struct ggml_tensor * q_b = nullptr; + struct ggml_tensor * v_w = nullptr; + struct ggml_tensor * v_b = nullptr; - struct ggml_tensor * o_w; - struct 
ggml_tensor * o_b; + struct ggml_tensor * o_w = nullptr; + struct ggml_tensor * o_b = nullptr; // layernorm 1 - struct ggml_tensor * ln_1_w; - struct ggml_tensor * ln_1_b; + struct ggml_tensor * ln_1_w = nullptr; + struct ggml_tensor * ln_1_b = nullptr; // ff - struct ggml_tensor * ff_i_w; - struct ggml_tensor * ff_i_b; + struct ggml_tensor * ff_i_w = nullptr; + struct ggml_tensor * ff_i_b = nullptr; - struct ggml_tensor * ff_o_w; - struct ggml_tensor * ff_o_b; + struct ggml_tensor * ff_o_w = nullptr; + struct ggml_tensor * ff_o_b = nullptr; // layernorm 2 - struct ggml_tensor * ln_2_w; - struct ggml_tensor * ln_2_b; + struct ggml_tensor * ln_2_w = nullptr; + struct ggml_tensor * ln_2_b = nullptr; }; struct clip_vision_model { struct clip_hparams hparams; // embeddings - struct ggml_tensor * class_embedding; - struct ggml_tensor * patch_embeddings_0; - struct ggml_tensor * patch_embeddings_1; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) - struct ggml_tensor * patch_bias; - struct ggml_tensor * position_embeddings; + struct ggml_tensor * class_embedding = nullptr; + struct ggml_tensor * patch_embeddings_0 = nullptr; + struct ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) + struct ggml_tensor * patch_bias = nullptr; + struct ggml_tensor * position_embeddings = nullptr; - struct ggml_tensor * pre_ln_w; - struct ggml_tensor * pre_ln_b; + struct ggml_tensor * pre_ln_w = nullptr; + struct ggml_tensor * pre_ln_b = nullptr; std::vector layers; @@ -485,84 +234,84 @@ struct clip_vision_model { struct ggml_tensor * projection; // LLaVA projection - struct ggml_tensor * mm_0_w = NULL; - struct ggml_tensor * mm_0_b = NULL; - struct ggml_tensor * mm_2_w = NULL; - struct ggml_tensor * mm_2_b = NULL; + struct ggml_tensor * mm_0_w = nullptr; + struct ggml_tensor * mm_0_b = nullptr; + struct ggml_tensor * mm_2_w = nullptr; + struct ggml_tensor * mm_2_b = nullptr; - struct ggml_tensor * image_newline = NULL; + struct ggml_tensor * image_newline = nullptr; // Yi type models with mlp+normalization projection - struct ggml_tensor * mm_1_w = NULL; // Yi type models have 0, 1, 3, 4 - struct ggml_tensor * mm_1_b = NULL; - struct ggml_tensor * mm_3_w = NULL; - struct ggml_tensor * mm_3_b = NULL; - struct ggml_tensor * mm_4_w = NULL; - struct ggml_tensor * mm_4_b = NULL; + struct ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4 + struct ggml_tensor * mm_1_b = nullptr; + struct ggml_tensor * mm_3_w = nullptr; + struct ggml_tensor * mm_3_b = nullptr; + struct ggml_tensor * mm_4_w = nullptr; + struct ggml_tensor * mm_4_b = nullptr; //GLMV-Edge projection - struct ggml_tensor * mm_model_adapter_conv_w; - struct ggml_tensor * mm_model_adapter_conv_b; - struct ggml_tensor * boi_w; - struct ggml_tensor * eoi_w; + struct ggml_tensor * mm_model_adapter_conv_w = nullptr; + struct ggml_tensor * mm_model_adapter_conv_b = nullptr; + struct ggml_tensor * boi_w = nullptr; + struct ggml_tensor * eoi_w = nullptr; // MobileVLM projection - struct ggml_tensor * mm_model_mlp_1_w; - struct ggml_tensor * mm_model_mlp_1_b; - struct ggml_tensor * mm_model_mlp_3_w; - struct ggml_tensor * mm_model_mlp_3_b; - struct ggml_tensor * mm_model_block_1_block_0_0_w; - struct ggml_tensor * mm_model_block_1_block_0_1_w; - struct ggml_tensor * mm_model_block_1_block_0_1_b; - struct ggml_tensor * mm_model_block_1_block_1_fc1_w; - struct ggml_tensor * mm_model_block_1_block_1_fc1_b; - struct ggml_tensor * 
mm_model_block_1_block_1_fc2_w; - struct ggml_tensor * mm_model_block_1_block_1_fc2_b; - struct ggml_tensor * mm_model_block_1_block_2_0_w; - struct ggml_tensor * mm_model_block_1_block_2_1_w; - struct ggml_tensor * mm_model_block_1_block_2_1_b; - struct ggml_tensor * mm_model_block_2_block_0_0_w; - struct ggml_tensor * mm_model_block_2_block_0_1_w; - struct ggml_tensor * mm_model_block_2_block_0_1_b; - struct ggml_tensor * mm_model_block_2_block_1_fc1_w; - struct ggml_tensor * mm_model_block_2_block_1_fc1_b; - struct ggml_tensor * mm_model_block_2_block_1_fc2_w; - struct ggml_tensor * mm_model_block_2_block_1_fc2_b; - struct ggml_tensor * mm_model_block_2_block_2_0_w; - struct ggml_tensor * mm_model_block_2_block_2_1_w; - struct ggml_tensor * mm_model_block_2_block_2_1_b; + struct ggml_tensor * mm_model_mlp_1_w = nullptr; + struct ggml_tensor * mm_model_mlp_1_b = nullptr; + struct ggml_tensor * mm_model_mlp_3_w = nullptr; + struct ggml_tensor * mm_model_mlp_3_b = nullptr; + struct ggml_tensor * mm_model_block_1_block_0_0_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_0_1_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_0_1_b = nullptr; + struct ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr; + struct ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr; + struct ggml_tensor * mm_model_block_1_block_2_0_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_2_1_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_2_1_b = nullptr; + struct ggml_tensor * mm_model_block_2_block_0_0_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_0_1_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_0_1_b = nullptr; + struct ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr; + struct ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr; + struct ggml_tensor * mm_model_block_2_block_2_0_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_2_1_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_2_1_b = nullptr; // MobileVLM_V2 projection - struct ggml_tensor * mm_model_mlp_0_w; - struct ggml_tensor * mm_model_mlp_0_b; - struct ggml_tensor * mm_model_mlp_2_w; - struct ggml_tensor * mm_model_mlp_2_b; - struct ggml_tensor * mm_model_peg_0_w; - struct ggml_tensor * mm_model_peg_0_b; + struct ggml_tensor * mm_model_mlp_0_w = nullptr; + struct ggml_tensor * mm_model_mlp_0_b = nullptr; + struct ggml_tensor * mm_model_mlp_2_w = nullptr; + struct ggml_tensor * mm_model_mlp_2_b = nullptr; + struct ggml_tensor * mm_model_peg_0_w = nullptr; + struct ggml_tensor * mm_model_peg_0_b = nullptr; // MINICPMV projection - struct ggml_tensor * mm_model_pos_embed_k; - struct ggml_tensor * mm_model_query; - struct ggml_tensor * mm_model_proj; - struct ggml_tensor * mm_model_kv_proj; - struct ggml_tensor * mm_model_attn_q_w; - struct ggml_tensor * mm_model_attn_q_b; - struct ggml_tensor * mm_model_attn_k_w; - struct ggml_tensor * mm_model_attn_k_b; - struct ggml_tensor * mm_model_attn_v_w; - struct ggml_tensor * mm_model_attn_v_b; - struct ggml_tensor * mm_model_attn_o_w; - struct ggml_tensor * mm_model_attn_o_b; - struct ggml_tensor * mm_model_ln_q_w; - struct ggml_tensor * mm_model_ln_q_b; - struct ggml_tensor * mm_model_ln_kv_w; - struct ggml_tensor * mm_model_ln_kv_b; - struct ggml_tensor * 
mm_model_ln_post_w; - struct ggml_tensor * mm_model_ln_post_b; + struct ggml_tensor * mm_model_pos_embed_k = nullptr; + struct ggml_tensor * mm_model_query = nullptr; + struct ggml_tensor * mm_model_proj = nullptr; + struct ggml_tensor * mm_model_kv_proj = nullptr; + struct ggml_tensor * mm_model_attn_q_w = nullptr; + struct ggml_tensor * mm_model_attn_q_b = nullptr; + struct ggml_tensor * mm_model_attn_k_w = nullptr; + struct ggml_tensor * mm_model_attn_k_b = nullptr; + struct ggml_tensor * mm_model_attn_v_w = nullptr; + struct ggml_tensor * mm_model_attn_v_b = nullptr; + struct ggml_tensor * mm_model_attn_o_w = nullptr; + struct ggml_tensor * mm_model_attn_o_b = nullptr; + struct ggml_tensor * mm_model_ln_q_w = nullptr; + struct ggml_tensor * mm_model_ln_q_b = nullptr; + struct ggml_tensor * mm_model_ln_kv_w = nullptr; + struct ggml_tensor * mm_model_ln_kv_b = nullptr; + struct ggml_tensor * mm_model_ln_post_w = nullptr; + struct ggml_tensor * mm_model_ln_post_b = nullptr; // gemma3 - struct ggml_tensor * mm_input_proj_w; - struct ggml_tensor * mm_soft_emb_norm_w; + struct ggml_tensor * mm_input_proj_w = nullptr; + struct ggml_tensor * mm_soft_emb_norm_w = nullptr; }; struct clip_ctx { @@ -584,11 +333,6 @@ struct clip_ctx { bool use_silu = false; int32_t ftype = 1; - bool has_class_embedding = true; - bool has_pre_norm = true; - bool has_post_norm = false; - bool has_patch_bias = false; - struct gguf_context * ctx_gguf = nullptr; struct ggml_context * ctx_data = nullptr; @@ -711,8 +455,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head)); - KQ = ggml_soft_max_inplace(ctx0, KQ); + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head); @@ -751,7 +494,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im } // post-layernorm - if (ctx->has_post_norm) { + if (model.post_ln_w) { embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "post_ln"); @@ -827,7 +570,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int patches_w = image_size_width / patch_size; const int patches_h = image_size_height / patch_size; - const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + const int num_positions = num_patches + (model.class_embedding ? 1 : 0); const int num_position_ids = ctx->has_qwen2vl_merger ? 
num_positions * 4 : num_positions; const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; @@ -879,7 +622,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); } - if (ctx->has_patch_bias) { + if (model.patch_bias) { // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); inp = ggml_add(ctx0, inp, model.patch_bias); } @@ -888,7 +631,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im if (ctx->has_llava_projector) { // concat class_embeddings and patch_embeddings - if (ctx->has_class_embedding) { + if (model.class_embedding) { embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); ggml_set_name(embeddings, "embeddings"); ggml_set_input(embeddings); @@ -925,7 +668,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im } // pre-layernorm - if (ctx->has_pre_norm) { + if (model.pre_ln_w) { embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "pre_ln"); @@ -967,7 +710,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im ctx0, Q, positions, nullptr, d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); } - Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); @@ -991,7 +733,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_inplace(ctx0, KQ); + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -1035,7 +777,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im } // post-layernorm - if (ctx->has_post_norm) { + if (model.post_ln_w) { embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "post_ln"); @@ -1075,8 +817,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); embeddings = ggml_gelu(ctx0, embeddings); - embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + if (model.mm_2_w) { + embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + } } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); @@ -1279,7 +1023,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); - Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b); struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b); // permute @@ -1293,7 +1036,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 
3)); V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_inplace(ctx0, KQ); + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -1369,547 +1112,395 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } } -// read and create ggml_context containing the tensors and their data -struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { - return clip_init(fname, clip_context_params{ - /* use_gpu */ true, - /* verbosity */ verbosity, - }); -} +struct clip_model_loader { + ggml_context_ptr ctx_meta; + gguf_context_ptr ctx_gguf; -struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) { - int verbosity = ctx_params.verbosity; - struct ggml_context * meta = NULL; + clip_ctx & ctx_clip; + std::string fname; - struct gguf_init_params params = { - /*.no_alloc = */ true, - /*.ctx = */ &meta, - }; + size_t model_size; // in bytes - struct gguf_context * ctx = gguf_init_from_file(fname, params); - if (!ctx) { - throw std::runtime_error(format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname)); - } - - if (verbosity >= 1) { - const int n_tensors = gguf_get_n_tensors(ctx); - const int n_kv = gguf_get_n_kv(ctx); - const int ftype = get_u32(ctx, KEY_FTYPE); - const std::string ftype_str = get_ftype(ftype); - const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION); - const std::string description = gguf_get_val_str(ctx, idx_desc); - const int idx_name = gguf_find_key(ctx, KEY_NAME); - if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug - const std::string name = gguf_get_val_str(ctx, idx_name); - LOG_INF("%s: model name: %s\n", __func__, name.c_str()); - } - LOG_INF("%s: description: %s\n", __func__, description.c_str()); - LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx)); - LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); - LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors); - LOG_INF("%s: n_kv: %d\n", __func__, n_kv); - LOG_INF("%s: ftype: %s\n", __func__, ftype_str.c_str()); - LOG_INF("\n"); - } - const int n_tensors = gguf_get_n_tensors(ctx); - - // kv - const int n_kv = gguf_get_n_kv(ctx); - LOG_INF("%s: loaded meta data with %d key-value pairs and %d tensors from %s\n", - __func__, n_kv, n_tensors, fname); - { - std::map n_type; + // TODO @ngxson : we should not pass clip_ctx here, it should be clip_vision_model + clip_model_loader(const char * fname, clip_ctx & ctx_clip) : ctx_clip(ctx_clip), fname(fname) { + struct ggml_context * meta = nullptr; - for (int i = 0; i < n_tensors; i++) { - enum ggml_type type = gguf_get_tensor_type(ctx, i); + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &meta, + }; - n_type[type]++; + ctx_gguf = gguf_context_ptr(gguf_init_from_file(fname, params)); + if (!ctx_gguf.get()) { + throw std::runtime_error(string_format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname)); } - LOG_INF("%s: Dumping metadata keys/values. 
Note: KV overrides do not apply in this output.\n", __func__); - for (int i = 0; i < n_kv; i++) { - const char * name = gguf_get_key(ctx, i); - const enum gguf_type type = gguf_get_kv_type(ctx, i); - const std::string type_name = - type == GGUF_TYPE_ARRAY - ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(ctx, i)), gguf_get_arr_n(ctx, i)) - : gguf_type_name(type); + ctx_meta.reset(meta); - std::string value = gguf_kv_to_str(ctx, i); - const size_t MAX_VALUE_LEN = 40; - if (value.size() > MAX_VALUE_LEN) { - value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); - } - replace_all(value, "\n", "\\n"); + const int n_tensors = gguf_get_n_tensors(ctx_gguf.get()); - LOG_INF("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); + // print gguf info + { + int ftype = -1; + get_u32(KEY_FTYPE, ftype, false); + const std::string ftype_str = ggml_type_name(static_cast(ftype)); + std::string name; + get_string(KEY_NAME, name, false); + std::string description; + get_string(KEY_DESCRIPTION, description, false); + LOG_INF("%s: model name: %s\n", __func__, name.c_str()); + LOG_INF("%s: description: %s\n", __func__, description.c_str()); + LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx_gguf.get())); + LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx_gguf.get())); + LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors); + LOG_INF("%s: n_kv: %d\n", __func__, (int)gguf_get_n_kv(ctx_gguf.get())); + LOG_INF("%s: ftype: %s\n", __func__, ftype_str.c_str()); + LOG_INF("\n"); } - // print type counts - for (auto & kv : n_type) { - if (kv.second == 0) { - continue; + // tensors + { + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name(ctx_gguf.get(), i); + const size_t offset = gguf_get_tensor_offset(ctx_gguf.get(), i); + enum ggml_type type = gguf_get_tensor_type(ctx_gguf.get(), i); + struct ggml_tensor * cur = ggml_get_tensor(meta, name); + size_t tensor_size = ggml_nbytes(cur); + model_size += tensor_size; + LOG_DBG("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n", + __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type)); } - - LOG_INF("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); } } - // data - size_t model_size = 0; - { - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx, i); - const size_t offset = gguf_get_tensor_offset(ctx, i); - enum ggml_type type = gguf_get_tensor_type(ctx, i); - struct ggml_tensor * cur = ggml_get_tensor(meta, name); - size_t tensor_size = ggml_nbytes(cur); - model_size += tensor_size; - if (verbosity >= 3) { - LOG_INF("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n", - __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type)); + void load_hparams() { + // projector type + { + std::string proj_type; + get_string(KEY_PROJ_TYPE, proj_type, false); + if (!proj_type.empty()) { + ctx_clip.proj_type = clip_projector_type_from_string(proj_type); + } + if (ctx_clip.proj_type == PROJECTOR_TYPE_UNKNOWN) { + throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str())); } } - } - clip_ctx * new_clip = new 
clip_ctx(ctx_params); + // other hparams + { + get_bool(KEY_HAS_TEXT_ENC, ctx_clip.has_text_encoder, false); + get_bool(KEY_HAS_VIS_ENC, ctx_clip.has_vision_encoder, false); + GGML_ASSERT(ctx_clip.has_vision_encoder); + GGML_ASSERT(!ctx_clip.has_text_encoder); + + // legacy keys, use KEY_PROJ_TYPE instead + get_bool(KEY_HAS_LLAVA_PROJ, ctx_clip.has_llava_projector, false); + get_bool(KEY_HAS_MINICPMV_PROJ, ctx_clip.has_minicpmv_projector, false); + get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); + get_bool(KEY_HAS_GLM_PROJ, ctx_clip.has_glm_projector, false); + get_bool(KEY_HAS_QWEN2VL_MERGER, ctx_clip.has_qwen2vl_merger, false); + // !!! do NOT extend the list above, use KEY_PROJ_TYPE instead + + get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false); + get_bool(KEY_USE_SILU, ctx_clip.use_silu, false); + + auto & hparams = ctx_clip.vision_model.hparams; + get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size); + get_u32(string_format(KEY_N_HEAD, "vision"), hparams.n_head); + get_u32(string_format(KEY_N_FF, "vision"), hparams.n_intermediate); + get_u32(string_format(KEY_N_BLOCK, "vision"), hparams.n_layer); + get_u32(string_format(KEY_PROJ_DIM, "vision"), hparams.projection_dim); + get_f32(string_format(KEY_LAYER_NORM_EPS, "vision"), hparams.eps); + get_u32(KEY_IMAGE_SIZE, hparams.image_size); + get_u32(KEY_PATCH_SIZE, hparams.patch_size); + get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); + get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); - // update projector type - { - int idx = gguf_find_key(ctx, KEY_PROJ_TYPE); - if (idx != -1) { - const std::string proj_type = gguf_get_val_str(ctx, idx); - new_clip->proj_type = clip_projector_type_from_string(proj_type); - } else { - new_clip->proj_type = PROJECTOR_TYPE_MLP; - } + { + std::string mm_patch_merge_type; + get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false); + if (mm_patch_merge_type == "spatial_unpad") { + hparams.mm_patch_merge_type = PATCH_MERGE_SPATIAL_UNPAD; + } + } + + { + int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN); + int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD); + GGML_ASSERT(idx_mean >= 0 && "image_mean not found"); + GGML_ASSERT(idx_std >= 0 && "image_std not found"); + const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean); + const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std); + for (int i = 0; i < 3; ++i) { + ctx_clip.image_mean[i] = mean_data[i]; + ctx_clip.image_std[i] = std_data[i]; + } + } - if (new_clip->proj_type == PROJECTOR_TYPE_MLP) { - if (gguf_find_tensor(ctx, format(TN_LLAVA_PROJ, 3, "weight").c_str()) != -1) { - new_clip->proj_type = PROJECTOR_TYPE_MLP_NORM; + // Load the vision feature layer indices if they are explicitly provided; + // if multiple vision feature layers are present, the values will be concatenated + // to form the final visual features. + // NOTE: gguf conversions should standardize the values of the vision feature layer to + // be non-negative, since we use -1 to mark values as unset here. 
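+            // e.g. (hypothetical values) an mmproj exporting clip.vision.feature_layer = [3, 7, 15, 26]
+            // would have the outputs of those four layers concatenated, in order, into the visual features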
+ std::vector vision_feature_layer; + get_arr_int(KEY_FEATURE_LAYER, vision_feature_layer, false); + // convert std::vector to std::unordered_set + for (auto & layer : vision_feature_layer) { + hparams.vision_feature_layer.insert(layer); } + // Calculate the deepest feature layer based on hparams and projector type + ctx_clip.max_feature_layer = get_deepest_feature_layer(&ctx_clip); + + LOG_INF("%s: text_encoder: %d\n", __func__, ctx_clip.has_text_encoder); + LOG_INF("%s: vision_encoder: %d\n", __func__, ctx_clip.has_vision_encoder); + LOG_INF("%s: llava_projector: %d\n", __func__, ctx_clip.has_llava_projector); + LOG_INF("%s: minicpmv_projector: %d\n", __func__, ctx_clip.has_minicpmv_projector); + LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version); + LOG_INF("%s: glm_projector: %d\n", __func__, ctx_clip.has_glm_projector); + LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); + LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); } } - // model size and capabilities - { - int idx = get_key_idx(ctx, KEY_HAS_TEXT_ENC); - new_clip->has_text_encoder = gguf_get_val_bool(ctx, idx); - - idx = get_key_idx(ctx, KEY_HAS_VIS_ENC); - new_clip->has_vision_encoder = gguf_get_val_bool(ctx, idx); + void load_tensors() { + std::map tensor_offset; + std::vector tensors_to_load; - idx = gguf_find_key(ctx, KEY_HAS_LLAVA_PROJ); - if (idx != -1) { - new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx); + // get offsets + for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) { + const char * name = gguf_get_tensor_name(ctx_gguf.get(), i); + tensor_offset[name] = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), i); } - idx = gguf_find_key(ctx, KEY_HAS_MINICPMV_PROJ); - if (idx != -1) { - new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx); + // create data context + struct ggml_init_params params = { + /*.mem_size =*/ (gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ctx_clip.ctx_data = ggml_init(params); + if (!ctx_clip.ctx_data) { + throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__)); } - idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION); - if (idx != -1) { - new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); - } + // helper function + auto get_tensor = [&](const std::string & name, bool required = true) { + struct ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str()); + if (!cur && required) { + throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str())); + } + if (cur) { + tensors_to_load.push_back(cur); + // add tensors to context + struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data, cur); + ggml_set_name(data_tensor, cur->name); + cur = data_tensor; + } + return cur; + }; - idx = gguf_find_key(ctx, KEY_HAS_GLM_PROJ); - if (idx != -1) { - new_clip->has_glm_projector = gguf_get_val_bool(ctx, idx); - } + auto & vision_model = ctx_clip.vision_model; - idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER); - if (idx != -1) { - new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx); - } - // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search + vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false); - GGML_ASSERT(new_clip->has_vision_encoder); - GGML_ASSERT(!new_clip->has_text_encoder); + 
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false); + vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false); - try { - idx = get_key_idx(ctx, KEY_USE_GELU); - new_clip->use_gelu = gguf_get_val_bool(ctx, idx); - } catch (std::runtime_error & /*e*/) { - new_clip->use_gelu = false; - } + vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false); + vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false); - try { - idx = get_key_idx(ctx, KEY_USE_SILU); - new_clip->use_silu = gguf_get_val_bool(ctx, idx); - } catch (std::runtime_error & /*e*/) { - new_clip->use_silu = false; + vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false); + vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); + vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); + if (vision_model.patch_embeddings_1 == nullptr) { + ctx_clip.has_qwen2vl_merger = false; } - if (verbosity >= 1) { - LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder); - LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); - LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); - LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); - LOG_INF("%s: minicpmv_version: %d\n", __func__, new_clip->minicpmv_version); - LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector); - LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); - LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); - } - } + vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false); - LOG_INF("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors); + // layers + vision_model.layers.resize(vision_model.hparams.n_layer); + for (int il = 0; il < vision_model.hparams.n_layer; ++il) { + auto & layer = vision_model.layers[il]; + layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight")); + layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight")); + layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight")); + layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight")); + layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false); + layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false); + layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight")); + layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight")); + layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false); + layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false); + layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false); + layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false); + layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false); + layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false); + layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false); + layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false); + } + + switch (ctx_clip.proj_type) { + case PROJECTOR_TYPE_MLP: + case PROJECTOR_TYPE_MLP_NORM: + { + // LLaVA projection + vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false); + vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 
0, "bias"), false); + // Yi-type llava + vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false); + vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + // missing in Yi-type llava + vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false); + vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + // Yi-type llava + vision_model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false); + vision_model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false); + vision_model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false); + vision_model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false); + if (vision_model.mm_3_w) { + // TODO: this is a hack to support Yi-type llava + ctx_clip.proj_type = PROJECTOR_TYPE_MLP_NORM; + } + vision_model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); + } break; + case PROJECTOR_TYPE_LDP: + { + // MobileVLM projection + vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + vision_model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); + vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + vision_model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); + vision_model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight")); + vision_model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight")); + vision_model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias")); + vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight")); + vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias")); + vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight")); + vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias")); + vision_model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight")); + vision_model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight")); + vision_model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias")); + vision_model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight")); + vision_model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight")); + vision_model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias")); + vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight")); + vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias")); + vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight")); + vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias")); + vision_model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight")); + vision_model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight")); + 
vision_model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias")); + } break; + case PROJECTOR_TYPE_LDPV2: + { + // MobilVLM_V2 projection + vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + vision_model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); + vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); + vision_model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias")); + vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight")); + vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias")); + } break; + case PROJECTOR_TYPE_RESAMPLER: + { + // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); + vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K); + vision_model.mm_model_query = get_tensor(TN_MINICPMV_QUERY); + vision_model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ); + vision_model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ); + vision_model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight")); + vision_model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight")); + vision_model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight")); + vision_model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias")); + vision_model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias")); + vision_model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias")); + vision_model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight")); + vision_model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias")); + vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight")); + vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias")); + vision_model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight")); + vision_model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias")); + vision_model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight")); + vision_model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias")); + } break; + case PROJECTOR_TYPE_GLM_EDGE: + { + vision_model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight")); + vision_model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias")); + vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR,"weight")); + vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"weight")); + vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"bias")); + vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); + vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight")); + vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); + vision_model.boi_w = get_tensor(TN_GLM_BOI_W); + vision_model.eoi_w = get_tensor(TN_GLM_EOI_W); + } break; + case PROJECTOR_TYPE_MERGER: + { + vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + 
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; + case PROJECTOR_TYPE_GEMMA3: + { + vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); + vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); + } break; + default: + GGML_ASSERT(false && "unknown projector type"); + } - // load tensors - { - std::vector read_buf; - struct ggml_init_params params = { - /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; + // load data + { + std::vector read_buf; - new_clip->ctx_data = ggml_init(params); - if (!new_clip->ctx_data) { - LOG_ERR("%s: ggml_init() failed\n", __func__); - clip_free(new_clip); - gguf_free(ctx); - return nullptr; - } - - auto fin = std::ifstream(fname, std::ios::binary); - if (!fin) { - LOG_ERR("cannot open model file for loading tensors\n"); - clip_free(new_clip); - gguf_free(ctx); - return nullptr; - } - - // add tensors to context - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx, i); - struct ggml_tensor * t = ggml_get_tensor(meta, name); - struct ggml_tensor * cur = ggml_dup_tensor(new_clip->ctx_data, t); - ggml_set_name(cur, name); - } - - // alloc memory and offload data - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(new_clip->backend); - new_clip->buf = ggml_backend_alloc_ctx_tensors_from_buft(new_clip->ctx_data, buft); - ggml_backend_buffer_set_usage(new_clip->buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx, i); - struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name); - const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i); - fin.seekg(offset, std::ios::beg); + auto fin = std::ifstream(fname, std::ios::binary); if (!fin) { - LOG_ERR("%s: failed to seek for tensor %s\n", __func__, name); - clip_free(new_clip); - gguf_free(ctx); - return nullptr; + throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); } - int num_bytes = ggml_nbytes(cur); - if (ggml_backend_buft_is_host(buft)) { - // for the CPU and Metal backend, we can read directly into the tensor - fin.read(reinterpret_cast(cur->data), num_bytes); - } else { - // read into a temporary buffer first, then copy to device memory - read_buf.resize(num_bytes); - fin.read(reinterpret_cast(read_buf.data()), num_bytes); - ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); - } - } - fin.close(); - } - - // vision model - if (new_clip->has_vision_encoder) { - // load vision model - auto & vision_model = new_clip->vision_model; - auto & hparams = vision_model.hparams; - hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "vision")); - hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "vision")); - hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision")); - hparams.n_layer = get_u32(ctx, format(KEY_N_BLOCK, "vision")); - hparams.image_size = get_u32(ctx, KEY_IMAGE_SIZE); - hparams.patch_size = get_u32(ctx, KEY_PATCH_SIZE); - hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision")); - hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision")); - - try { - int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); - int n = gguf_get_arr_n(ctx, idx); - const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); - for (int i = 0; i < n; ++i) { - 
hparams.image_grid_pinpoints.push_back(pinpoints[i]); - } - } catch (std::runtime_error & /*e*/) { } - // Load the vision feature layer indices if they are explicitly provided; - // if multiple vision feature layers are present, the values will be concatenated - // to form the final visual features. - // NOTE: gguf conversions should standardize the values of the vision feature layer to - // be non-negative, since we use -1 to mark values as unset here. - try { - int idx = get_key_idx(ctx, KEY_FEATURE_LAYER); - int n = gguf_get_arr_n(ctx, idx); - - const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); - - for (int i = 0; i < n; ++i) { - hparams.vision_feature_layer.insert(vision_feature_layer[i]); - } - } catch (std::runtime_error & /*e*/) { } - - try { - int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); - strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx)); - } catch (std::runtime_error & /*e*/) { - strcpy(hparams.mm_patch_merge_type, "flat"); - } - - try { - hparams.image_crop_resolution = get_u32(ctx, KEY_IMAGE_CROP_RESOLUTION); // llava-1.6 - } catch(const std::exception& /*e*/) { - hparams.image_crop_resolution = hparams.image_size; - } - - int idx_mean = get_key_idx(ctx, KEY_IMAGE_MEAN); - int idx_std = get_key_idx(ctx, KEY_IMAGE_STD); - - const float * mean_data = (const float *)gguf_get_arr_data(ctx, idx_mean); - const float * std_data = (const float *)gguf_get_arr_data(ctx, idx_std); - - for (int i = 0; i < 3; ++i) { - new_clip->image_mean[i] = mean_data[i]; - new_clip->image_std[i] = std_data[i]; - } - - // Calculate the deepest feature layer based on hparams and projector type - new_clip->max_feature_layer = get_deepest_feature_layer(new_clip); - - if (verbosity >= 2) { - LOG_INF("\n%s: vision model hparams\n", __func__); - LOG_INF("image_size %d\n", hparams.image_size); - LOG_INF("patch_size %d\n", hparams.patch_size); - LOG_INF("v_hidden_size %d\n", hparams.hidden_size); - LOG_INF("v_n_intermediate %d\n", hparams.n_intermediate); - LOG_INF("v_projection_dim %d\n", hparams.projection_dim); - LOG_INF("v_n_head %d\n", hparams.n_head); - LOG_INF("v_n_layer %d\n", hparams.n_layer); - LOG_INF("v_eps %f\n", hparams.eps); - LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); - LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); - LOG_INF("v_image_grid_pinpoints: "); - for (const auto & pp : hparams.image_grid_pinpoints) { - LOG_INF("%d ", pp); - } - LOG_INF("\n"); - LOG_INF("v_vision_feature_layer: "); - for (const auto & feature_layer: hparams.vision_feature_layer) { - LOG_INF("%d ", feature_layer); + // alloc memory and offload data + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); + ctx_clip.buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data, buft); + ggml_backend_buffer_set_usage(ctx_clip.buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + for (auto & t : tensors_to_load) { + struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data, t->name); + const size_t offset = tensor_offset[t->name]; + fin.seekg(offset, std::ios::beg); + if (!fin) { + throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name)); + } + size_t num_bytes = ggml_nbytes(cur); + if (ggml_backend_buft_is_host(buft)) { + // for the CPU and Metal backend, we can read directly into the tensor + fin.read(reinterpret_cast(cur->data), num_bytes); + } else { + // read into a 
temporary buffer first, then copy to device memory + read_buf.resize(num_bytes); + fin.read(reinterpret_cast(read_buf.data()), num_bytes); + ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); + } } - LOG_INF("\n"); - LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); - - } - - try { - vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD); - new_clip->has_class_embedding = true; - } catch (const std::exception& /*e*/) { - new_clip->has_class_embedding = false; - } - - try { - vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight")); - vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias")); - new_clip->has_pre_norm = true; - } catch (std::exception & /*e*/) { - new_clip->has_pre_norm = false; - } - - try { - vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight")); - vision_model.post_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias")); - new_clip->has_post_norm = true; - } catch (std::exception & /*e*/) { - new_clip->has_post_norm = false; - } - - try { - vision_model.patch_bias = get_tensor(new_clip->ctx_data, TN_PATCH_BIAS); - new_clip->has_patch_bias = true; - } catch (std::exception & /*e*/) { - new_clip->has_patch_bias = false; - } - - try { - vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD); - } catch(const std::exception& /*e*/) { - vision_model.patch_embeddings_0 = nullptr; - } - - try { - vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v")); - } catch(const std::exception& /*e*/) { - vision_model.position_embeddings = nullptr; - } - - try { - vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1); - } catch(const std::exception& /*e*/) { - new_clip->has_qwen2vl_merger = false; - } - - // LLaVA projection - if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) { - vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); - vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); - try { - // Yi-type llava - vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "weight")); - vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "bias")); - } catch (std::runtime_error & /*e*/) { } - try { - // missing in Yi-type llava - vision_model.mm_2_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight")); - vision_model.mm_2_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias")); - } catch (std::runtime_error & /*e*/) { } - try { - // Yi-type llava - vision_model.mm_3_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "weight")); - vision_model.mm_3_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "bias")); - } catch (std::runtime_error & /*e*/) { } - try { - // Yi-type llava - vision_model.mm_4_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "weight")); - vision_model.mm_4_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "bias")); - } catch (std::runtime_error & /*e*/) { } - try { - vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE); - // LOG_INF("%s: image_newline tensor (llava-1.6) found\n", __func__); - } catch (std::runtime_error & /*e*/) { } - } else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) { - // MobileVLM projection - vision_model.mm_model_mlp_1_w = 
get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight")); - vision_model.mm_model_mlp_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "bias")); - vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "weight")); - vision_model.mm_model_mlp_3_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "bias")); - vision_model.mm_model_block_1_block_0_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight")); - vision_model.mm_model_block_1_block_0_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight")); - vision_model.mm_model_block_1_block_0_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias")); - vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight")); - vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias")); - vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight")); - vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias")); - vision_model.mm_model_block_1_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight")); - vision_model.mm_model_block_1_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight")); - vision_model.mm_model_block_1_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias")); - vision_model.mm_model_block_2_block_0_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight")); - vision_model.mm_model_block_2_block_0_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight")); - vision_model.mm_model_block_2_block_0_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias")); - vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight")); - vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias")); - vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight")); - vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias")); - vision_model.mm_model_block_2_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight")); - vision_model.mm_model_block_2_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight")); - vision_model.mm_model_block_2_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias")); - } - else if (new_clip->proj_type == PROJECTOR_TYPE_LDPV2) - { - // MobilVLM_V2 projection - vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "weight")); - vision_model.mm_model_mlp_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "bias")); - vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "weight")); - vision_model.mm_model_mlp_2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "bias")); - vision_model.mm_model_peg_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "weight")); - vision_model.mm_model_peg_0_b = 
get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "bias")); - } - else if (new_clip->proj_type == PROJECTOR_TYPE_RESAMPLER) { - // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); - vision_model.mm_model_pos_embed_k = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD_K); - vision_model.mm_model_query = get_tensor(new_clip->ctx_data, TN_MINICPMV_QUERY); - vision_model.mm_model_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_PROJ); - vision_model.mm_model_kv_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_KV_PROJ); - vision_model.mm_model_attn_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "weight")); - vision_model.mm_model_attn_k_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "weight")); - vision_model.mm_model_attn_v_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "weight")); - vision_model.mm_model_attn_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "bias")); - vision_model.mm_model_attn_k_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "bias")); - vision_model.mm_model_attn_v_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "bias")); - vision_model.mm_model_attn_o_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "weight")); - vision_model.mm_model_attn_o_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "bias")); - vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "weight")); - vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "bias")); - vision_model.mm_model_ln_kv_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "weight")); - vision_model.mm_model_ln_kv_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "bias")); - vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); - vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); - } - else if (new_clip->proj_type == PROJECTOR_TYPE_GLM_EDGE) { - vision_model.mm_model_adapter_conv_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "weight")); - vision_model.mm_model_adapter_conv_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "bias")); - vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_LINEAR,"weight")); - vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"weight")); - vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"bias")); - vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); - vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_GATE,"weight")); - vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); - vision_model.boi_w = get_tensor(new_clip->ctx_data, TN_GLM_BOI_W); - vision_model.eoi_w = get_tensor(new_clip->ctx_data, TN_GLM_EOI_W); - } - else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) { - vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); - vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); - vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight")); - vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias")); - } 
- else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) { - vision_model.mm_input_proj_w = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ); - vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N); - } - else { - std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type]; - throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); - } + fin.close(); - vision_model.layers.resize(hparams.n_layer); + LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str()); + } + } - for (int il = 0; il < hparams.n_layer; ++il) { - auto & layer = vision_model.layers[il]; - layer.k_w = get_tensor(new_clip->ctx_data, format(TN_ATTN_K, "v", il, "weight")); - layer.q_w = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q, "v", il, "weight")); - layer.v_w = get_tensor(new_clip->ctx_data, format(TN_ATTN_V, "v", il, "weight")); - layer.o_w = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "weight")); - layer.ln_1_w = get_tensor(new_clip->ctx_data, format(TN_LN_1, "v", il, "weight")); - layer.ln_2_w = get_tensor(new_clip->ctx_data, format(TN_LN_2, "v", il, "weight")); - layer.ff_i_w = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN, "v", il, "weight")); - layer.ff_o_w = get_tensor(new_clip->ctx_data, format(TN_FFN_UP, "v", il, "weight")); - layer.k_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_K, "v", il, "bias")); - layer.q_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q, "v", il, "bias")); - layer.v_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_V, "v", il, "bias")); - layer.o_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "bias")); - layer.ln_1_b = get_tensor(new_clip->ctx_data, format(TN_LN_1, "v", il, "bias")); - layer.ln_2_b = get_tensor(new_clip->ctx_data, format(TN_LN_2, "v", il, "bias")); - layer.ff_i_b = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN, "v", il, "bias")); - layer.ff_o_b = get_tensor(new_clip->ctx_data, format(TN_FFN_UP, "v", il, "bias")); - } - } - - ggml_free(meta); - - new_clip->ctx_gguf = ctx; - - // measure mem requirement and allocate - { - new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); + void alloc_compute_meta() { + ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); clip_image_f32_batch batch; batch.size = 1; batch.data = nullptr; - ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false); - ggml_backend_sched_reserve(new_clip->sched.get(), gf); - for (size_t i = 0; i < new_clip->backend_ptrs.size(); ++i) { - ggml_backend_t backend = new_clip->backend_ptrs[i]; - ggml_backend_buffer_type_t buft = new_clip->backend_buft[i]; - size_t size = ggml_backend_sched_get_buffer_size(new_clip->sched.get(), backend); + ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, &batch, nullptr, false); + ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); + for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) { + ggml_backend_t backend = ctx_clip.backend_ptrs[i]; + ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i]; + size_t size = ggml_backend_sched_get_buffer_size(ctx_clip.sched.get(), backend); if (size > 1) { LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__, ggml_backend_buft_name(buft), @@ -1918,7 +1509,90 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p } } - return new_clip; + void get_bool(const std::string & key, bool & 
output, bool required = true) { + const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); + if (i < 0) { + if (required) throw std::runtime_error("Key not found: " + key); + return; + } + output = gguf_get_val_bool(ctx_gguf.get(), i); + } + + void get_i32(const std::string & key, int & output, bool required = true) { + const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); + if (i < 0) { + if (required) throw std::runtime_error("Key not found: " + key); + return; + } + output = gguf_get_val_i32(ctx_gguf.get(), i); + } + + void get_u32(const std::string & key, int & output, bool required = true) { + const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); + if (i < 0) { + if (required) throw std::runtime_error("Key not found: " + key); + return; + } + output = gguf_get_val_u32(ctx_gguf.get(), i); + } + + void get_f32(const std::string & key, float & output, bool required = true) { + const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); + if (i < 0) { + if (required) throw std::runtime_error("Key not found: " + key); + return; + } + output = gguf_get_val_f32(ctx_gguf.get(), i); + } + + void get_string(const std::string & key, std::string & output, bool required = true) { + const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); + if (i < 0) { + if (required) throw std::runtime_error("Key not found: " + key); + return; + } + output = std::string(gguf_get_val_str(ctx_gguf.get(), i)); + } + + void get_arr_int(const std::string & key, std::vector & output, bool required = true) { + const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); + if (i < 0) { + if (required) throw std::runtime_error("Key not found: " + key); + return; + } + int n = gguf_get_arr_n(ctx_gguf.get(), i); + output.resize(n); + const int32_t * values = (const int32_t *)gguf_get_arr_data(ctx_gguf.get(), i); + for (int i = 0; i < n; ++i) { + output[i] = values[i]; + } + } +}; + +// read and create ggml_context containing the tensors and their data +struct clip_ctx * clip_model_load(const char * fname, const int verbosity) { + return clip_init(fname, clip_context_params{ + /* use_gpu */ true, + /* verbosity */ static_cast(verbosity), + }); +} + +struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) { + g_logger_state.verbosity_thold = ctx_params.verbosity; + clip_ctx * ctx_clip = new clip_ctx(ctx_params); + + try { + clip_model_loader loader(fname, *ctx_clip); + loader.load_hparams(); + loader.load_tensors(); + loader.alloc_compute_meta(); + } catch (const std::exception & e) { + LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what()); + delete ctx_clip; + return nullptr; + } + + return ctx_clip; } void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) { @@ -2290,7 +1964,7 @@ static std::vector> uhd_slice_image(const clip_imag const int multiple = fmin(ceil(ratio), max_slice_nums); std::vector> images; - LOG_INF("%s: multiple %d\n", __func__, multiple); + LOG_DBG("%s: multiple %d\n", __func__, multiple); images.push_back(std::vector()); if (multiple <= 1) { @@ -2305,17 +1979,17 @@ static std::vector> uhd_slice_image(const clip_imag clip_image_u8 * source_image = clip_image_u8_init(); bicubic_resize(*img, *source_image, best_size.first, best_size.second); // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) - LOG_INF("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); + LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", 
__func__, img->nx, img->ny, best_size.first, best_size.second); images[images.size()-1].push_back(source_image); std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); - LOG_INF("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); + LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); clip_image_u8 * refine_image = clip_image_u8_init(); bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second); - LOG_INF("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); + LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); // split_to_patches int width = refine_image->nx; @@ -2423,12 +2097,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli bool pad_to_square = true; if (!ctx->has_vision_encoder) { - LOG_ERR("This gguf file seems to have no vision encoder\n"); + LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); return false; } auto & params = ctx->vision_model.hparams; // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing - if (strcmp(params.mm_patch_merge_type, "spatial_unpad") == 0) { + if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) { pad_to_square = false; } // free the previous res_imgs if any set @@ -2624,7 +2298,7 @@ int32_t clip_hidden_size(const struct clip_ctx * ctx) { } const char * clip_patch_merge_type(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.mm_patch_merge_type; + return ctx->vision_model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat"; } const int32_t * clip_image_grid(const struct clip_ctx * ctx) { @@ -2760,7 +2434,7 @@ static std::vector> get_2d_sincos_pos_embed(int embed_dim, co bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { if (!ctx->has_vision_encoder) { - LOG_ERR("This gguf file seems to have no vision encoder\n"); + LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); return false; } @@ -2772,7 +2446,7 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3 bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) { if (!ctx->has_vision_encoder) { - LOG_ERR("This gguf file seems to have no vision encoder\n"); + LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); return false; } @@ -2808,7 +2482,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); - const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + const int num_positions = num_patches + (model.class_embedding ? 
1 : 0); if(ctx->load_image_size==nullptr){ ctx->load_image_size= clip_image_size_init(); } @@ -2893,16 +2567,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima free(pos_embed_data); } } - else{ - { - if (ctx->has_class_embedding) { - struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); + else { + if (model.class_embedding) { + struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); - void* zero_mem = malloc(ggml_nbytes(embeddings)); - memset(zero_mem, 0, ggml_nbytes(embeddings)); - ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); - free(zero_mem); - } + void* zero_mem = malloc(ggml_nbytes(embeddings)); + memset(zero_mem, 0, ggml_nbytes(embeddings)); + ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); + free(zero_mem); } if (ctx->has_qwen2vl_merger) { @@ -2950,7 +2622,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // The patches vector is used to get rows to index into the embeds with; // we should skip dim 0 only if we have CLS to avoid going out of bounds // when retrieving the rows. - int patch_offset = ctx->has_class_embedding ? 1 : 0; + int patch_offset = model.class_embedding ? 1 : 0; int* patches_data = (int*)malloc(ggml_nbytes(patches)); for (int i = 0; i < num_patches; i++) { patches_data[i] = i + patch_offset; @@ -2989,7 +2661,10 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i assert(itype < GGML_TYPE_COUNT); ggml_type type = static_cast(itype); - auto * ctx_clip = clip_model_load(fname_inp, 2); + auto * ctx_clip = clip_init(fname_inp, clip_context_params{ + /* use_gpu */ false, + /* verbosity */ GGML_LOG_LEVEL_ERROR, + }); const auto & ctx_src = ctx_clip->ctx_gguf; const auto & ctx_data = ctx_clip->ctx_data; @@ -3066,7 +2741,7 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i f32_data = (float *)conv_buf.data(); break; default: - LOG_ERR("Please use an input file in f32 or f16\n"); + LOG_ERR("%s: Please use an input file in f32 or f16\n", __func__); gguf_free(ctx_out); return false; } @@ -3152,7 +2827,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { } std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; - throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); + throw std::runtime_error(string_format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); } int clip_is_minicpmv(const struct clip_ctx * ctx) { diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 47059ca1b9f78..783bdca3e09f2 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -1,6 +1,7 @@ #ifndef CLIP_H #define CLIP_H +#include "ggml.h" #include #include @@ -41,7 +42,7 @@ struct clip_image_f32_batch { struct clip_context_params { bool use_gpu; - int verbosity; + ggml_log_level verbosity; }; // deprecated, use clip_init diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index c36bb2eda0c70..4f89c0e15b4e9 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -10,7 +10,7 @@ #include #include -#include +#include #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include @@ -78,8 +78,12 @@ struct gemma3_context { } void init_clip_model(common_params & params) { - const char * clip_path = params.mmproj.c_str(); - ctx_clip = clip_model_load(clip_path, params.verbosity > 1); + const char * 
clip_path = params.mmproj.path.c_str(); + ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO); + if (!ctx_clip) { + LOG_ERR("Failed to load CLIP model from %s\n", clip_path); + exit(1); + } } ~gemma3_context() { @@ -232,13 +236,13 @@ int main(int argc, char ** argv) { common_init(); - if (params.mmproj.empty()) { + if (params.mmproj.path.empty()) { show_additional_info(argc, argv); return 1; } gemma3_context ctx(params); - printf("%s: %s\n", __func__, params.model.c_str()); + printf("%s: %s\n", __func__, params.model.path.c_str()); bool is_single_turn = !params.prompt.empty() && !params.image.empty(); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 40aa0876f24a7..0fe0e333a523d 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -225,7 +225,7 @@ static struct llama_model * llava_init(common_params * params) { llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); return NULL; @@ -234,14 +234,14 @@ static struct llama_model * llava_init(common_params * params) { } static struct llava_context * llava_init_context(common_params * params, llama_model * model) { - const char * clip_path = params->mmproj.c_str(); + const char * clip_path = params->mmproj.path.c_str(); auto prompt = params->prompt; if (prompt.empty()) { prompt = "describe the image in detail."; } - auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); + auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO); llama_context_params ctx_params = common_context_params_to_llama(*params); ctx_params.n_ctx = params->n_ctx < 2048 ? 
2048 : params->n_ctx; // we need a longer context size to process image embeddings @@ -283,7 +283,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { + if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { print_usage(argc, argv); return 1; } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 12f536cf5cfff..5ad970c220528 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -31,7 +31,7 @@ static struct llama_model * llava_init(common_params * params) { llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); return NULL; @@ -80,7 +80,7 @@ static void llava_free(struct llava_context * ctx_llava) { } static struct clip_ctx * clip_init_context(common_params * params) { - const char * clip_path = params->mmproj.c_str(); + const char * clip_path = params->mmproj.path.c_str(); auto prompt = params->prompt; if (prompt.empty()) { @@ -88,7 +88,7 @@ static struct clip_ctx * clip_init_context(common_params * params) { } struct clip_context_params clip_params = { /* use_gpu */ params->n_gpu_layers != 0, - /* verbosity */ params->verbosity, + /* verbosity */ GGML_LOG_LEVEL_INFO, // TODO: make this configurable }; auto * ctx_clip = clip_init(clip_path, clip_params); return ctx_clip; @@ -290,7 +290,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.mmproj.empty() || (params.image.empty())) { + if (params.mmproj.path.empty() || (params.image.empty())) { show_additional_info(argc, argv); return 1; } diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 132a7da543c2a..eca7b7f10b9e3 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -314,7 +314,7 @@ static struct llama_model * llava_init(common_params * params) { llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); return NULL; @@ -323,14 +323,14 @@ static struct llama_model * llava_init(common_params * params) { } static struct llava_context * llava_init_context(common_params * params, llama_model * model) { - const char * clip_path = params->mmproj.c_str(); + const char * clip_path = params->mmproj.path.c_str(); auto prompt = params->prompt; if (prompt.empty()) { prompt = "describe the image in detail."; } - auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); + auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO); llama_context_params ctx_params = common_context_params_to_llama(*params); ctx_params.n_ctx = params->n_ctx < 2048 ? 
2048 : params->n_ctx; // we need a longer context size to process image embeddings @@ -524,7 +524,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { + if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { print_usage(argc, argv); return 1; } diff --git a/examples/llava/test-1.jpeg b/examples/llava/test-1.jpeg new file mode 100644 index 0000000000000..7fdcaaf04b24b Binary files /dev/null and b/examples/llava/test-1.jpeg differ diff --git a/examples/llava/tests.sh b/examples/llava/tests.sh new file mode 100755 index 0000000000000..cc9bda8769ca6 --- /dev/null +++ b/examples/llava/tests.sh @@ -0,0 +1,81 @@ +#!/bin/bash + +# make sure we are in the right directory +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +#export LLAMA_CACHE="$SCRIPT_DIR/tmp" + +set -eux + +mkdir -p $SCRIPT_DIR/output + +PROJ_ROOT="$SCRIPT_DIR/../.." +cd $PROJ_ROOT + +############### + +arr_bin=() +arr_hf=() + +add_test() { + local bin=$1 + local hf=$2 + arr_bin+=("$bin") + arr_hf+=("$hf") +} + +add_test "llama-gemma3-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M" +add_test "llama-llava-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K" +add_test "llama-llava-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" +add_test "llama-llava-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M" +add_test "llama-llava-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K" +add_test "llama-llava-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K" +add_test "llama-llava-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M" +add_test "llama-minicpmv-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted +add_test "llama-minicpmv-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K" +add_test "llama-minicpmv-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0" +add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M" + +############### + +cmake --build build -j --target "${arr_bin[@]}" + +arr_res=() + +for i in "${!arr_bin[@]}"; do + bin="${arr_bin[$i]}" + hf="${arr_hf[$i]}" + + echo "Running test with binary: $bin and HF model: $hf" + echo "" + echo "" + + output=$("$PROJ_ROOT/build/bin/$bin" -hf "$hf" --image $SCRIPT_DIR/test-1.jpeg -p "what is the publisher name of the newspaper?" 
--temp 0 2>&1 | tee /dev/tty) + + echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log + + if echo "$output" | grep -iq "new york"; then + result="\033[32mOK\033[0m: $bin $hf" + else + result="\033[31mFAIL\033[0m: $bin $hf" + fi + echo -e "$result" + arr_res+=("$result") + + echo "" + echo "" + echo "" + echo "#################################################" + echo "#################################################" + echo "" + echo "" +done + +set +x + +for i in "${!arr_res[@]}"; do + echo -e "${arr_res[$i]}" +done +echo "" +echo "Output logs are saved in $SCRIPT_DIR/output" diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 588632f0432b2..80698518e3102 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -106,6 +106,8 @@ int main(int argc, char ** argv) { common_params params; + params.n_predict = 128; + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) { return 1; } @@ -405,7 +407,7 @@ int main(int argc, char ** argv) { params.prompt_file = "used built-in defaults"; } LOG_INF("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str()); - LOG_INF("Model and path used: \033[32m%s\033[0m\n\n", params.model.c_str()); + LOG_INF("Model and path used: \033[32m%s\033[0m\n\n", params.model.path.c_str()); LOG_INF("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6); LOG_INF("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index ea3a6c1fca3ee..347ea4a698f2e 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -64,7 +64,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); diff --git a/examples/rpc/CMakeLists.txt b/examples/rpc/CMakeLists.txt index ae48fb98d0913..c2c748148645e 100644 --- a/examples/rpc/CMakeLists.txt +++ b/examples/rpc/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(rpc-server rpc-server.cpp) -target_link_libraries(rpc-server PRIVATE ggml llama) +set(TARGET rpc-server) +add_executable(${TARGET} rpc-server.cpp) +target_link_libraries(${TARGET} PRIVATE ggml) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/rpc/README.md b/examples/rpc/README.md index 312bb634dc920..561f19fda6b06 100644 --- a/examples/rpc/README.md +++ b/examples/rpc/README.md @@ -72,3 +72,14 @@ $ bin/llama-cli -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name This way you can offload model layers to both local and remote devices. +### Local cache + +The RPC server can use a local cache to store large tensors and avoid transferring them over the network. +This can speed up model loading significantly, especially when using large models. +To enable the cache, use the `-c` option: + +```bash +$ bin/rpc-server -c +``` + +By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable. 
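For reference, a minimal sketch (not part of this patch) of resolving the same default cache location on a POSIX system and checking how much disk the `rpc/` subdirectory currently uses; the layout — `LLAMA_CACHE` if set, otherwise `$HOME/.cache/llama.cpp/`, plus `rpc/` — is taken from the README text above, and the program below is purely illustrative:

```cpp
// Illustrative only: locate the RPC server's local cache using the layout
// described above (LLAMA_CACHE, else $HOME/.cache/llama.cpp, plus "rpc/")
// and report how much disk space it currently uses.
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <filesystem>

int main() {
    namespace fs = std::filesystem;

    const char * env  = std::getenv("LLAMA_CACHE");
    const char * home = std::getenv("HOME");
    fs::path cache = env  ? fs::path(env)
                   : home ? fs::path(home) / ".cache" / "llama.cpp"
                          : fs::path(".");
    fs::path rpc_cache = cache / "rpc";

    if (!fs::exists(rpc_cache)) {
        std::printf("no RPC cache at %s\n", rpc_cache.c_str());
        return 0;
    }

    std::uintmax_t total = 0;
    for (const auto & entry : fs::recursive_directory_iterator(rpc_cache)) {
        if (entry.is_regular_file()) {
            total += entry.file_size();
        }
    }
    std::printf("%s uses %.2f MiB\n", rpc_cache.c_str(), total / (1024.0 * 1024.0));
    return 0;
}
```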
diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index 8b1b23edad174..3d590feb086c6 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -1,3 +1,7 @@ +#if defined(_MSC_VER) +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING +#endif + #include "ggml-cpu.h" #ifdef GGML_USE_CUDA @@ -18,26 +22,142 @@ #include "ggml-rpc.h" #ifdef _WIN32 +# define DIRECTORY_SEPARATOR '\\' +# include # include +# include +# include #else +# define DIRECTORY_SEPARATOR '/' # include +# include #endif +#include #include #include +#include +#include + +namespace fs = std::filesystem; + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +// returns true if successful, false otherwise +static bool fs_create_directory_with_parents(const std::string & path) { +#ifdef _WIN32 + std::wstring_convert> converter; + std::wstring wpath = converter.from_bytes(path); + + // if the path already exists, check whether it's a directory + const DWORD attributes = GetFileAttributesW(wpath.c_str()); + if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return true; + } + + size_t pos_slash = 0; + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { + const std::wstring subpath = wpath.substr(0, pos_slash); + const wchar_t * test = subpath.c_str(); + + const bool success = CreateDirectoryW(test, NULL); + if (!success) { + const DWORD error = GetLastError(); + + // if the path already exists, ensure that it's a directory + if (error == ERROR_ALREADY_EXISTS) { + const DWORD attributes = GetFileAttributesW(subpath.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return false; + } + } else { + return false; + } + } + + pos_slash += 1; + } + + return true; +#else + // if the path already exists, check whether it's a directory + struct stat info; + if (stat(path.c_str(), &info) == 0) { + return S_ISDIR(info.st_mode); + } + + size_t pos_slash = 1; // skip leading slashes for directory creation + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { + const std::string subpath = path.substr(0, pos_slash); + struct stat info; + + // if the path already exists, ensure that it's a directory + if (stat(subpath.c_str(), &info) == 0) { + if (!S_ISDIR(info.st_mode)) { + return false; + } + } else { + // create parent directories + const int ret = mkdir(subpath.c_str(), 0755); + if (ret != 0) { + return false; + } + } + + pos_slash += 1; + } + + return true; +#endif // _WIN32 +} + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +static std::string fs_get_cache_directory() { + std::string cache_directory = ""; + auto ensure_trailing_slash = [](std::string p) { + // Make sure to add trailing slash + if (p.back() != DIRECTORY_SEPARATOR) { + p += DIRECTORY_SEPARATOR; + } + return p; + }; + if (getenv("LLAMA_CACHE")) { + cache_directory = std::getenv("LLAMA_CACHE"); + } else { +#ifdef __linux__ + if (std::getenv("XDG_CACHE_HOME")) { + cache_directory = std::getenv("XDG_CACHE_HOME"); + } else { + cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } +#elif defined(__APPLE__) + cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); +#elif defined(_WIN32) + cache_directory = std::getenv("LOCALAPPDATA"); +#endif // __linux__ + cache_directory = 
ensure_trailing_slash(cache_directory); + cache_directory += "llama.cpp"; + } + return ensure_trailing_slash(cache_directory); +} struct rpc_server_params { std::string host = "127.0.0.1"; int port = 50052; size_t backend_mem = 0; + bool use_cache = false; }; static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { fprintf(stderr, "Usage: %s [options]\n\n", argv[0]); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str()); - fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port); - fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str()); + fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port); + fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n"); + fprintf(stderr, " -c, --cache enable local file cache\n"); fprintf(stderr, "\n"); } @@ -58,6 +178,8 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & if (params.port <= 0 || params.port > 65535) { return false; } + } else if (arg == "-c" || arg == "--cache") { + params.use_cache = true; } else if (arg == "-m" || arg == "--mem") { if (++i >= argc) { return false; @@ -164,8 +286,20 @@ int main(int argc, char * argv[]) { } else { get_backend_memory(&free_mem, &total_mem); } - printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024)); - ggml_backend_rpc_start_server(backend, endpoint.c_str(), free_mem, total_mem); + const char * cache_dir = nullptr; + std::string cache_dir_str = fs_get_cache_directory() + "rpc/"; + if (params.use_cache) { + if (!fs_create_directory_with_parents(cache_dir_str)) { + fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str()); + return 1; + } + cache_dir = cache_dir_str.c_str(); + } + printf("Starting RPC server\n"); + printf(" endpoint : %s\n", endpoint.c_str()); + printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a"); + printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024)); + ggml_backend_rpc_start_server(backend, endpoint.c_str(), cache_dir, free_mem, total_mem); ggml_backend_free(backend); return 0; } diff --git a/examples/server/httplib.h b/examples/server/httplib.h index 593beb50150b1..0f981dc89519f 100644 --- a/examples/server/httplib.h +++ b/examples/server/httplib.h @@ -8,7 +8,7 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.19.0" +#define CPPHTTPLIB_VERSION "0.20.0" /* * Configuration @@ -188,15 +188,16 @@ using ssize_t = long; #include #include +// afunix.h uses types declared in winsock2.h, so has to be included after it. 
+#include + #ifndef WSA_FLAG_NO_HANDLE_INHERIT #define WSA_FLAG_NO_HANDLE_INHERIT 0x80 #endif +using nfds_t = unsigned long; using socket_t = SOCKET; using socklen_t = int; -#ifdef CPPHTTPLIB_USE_POLL -#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) -#endif #else // not _WIN32 @@ -216,16 +217,11 @@ using socklen_t = int; #ifdef __linux__ #include #endif +#include #include -#ifdef CPPHTTPLIB_USE_POLL #include -#endif -#include #include #include -#ifndef __VMS -#include -#endif #include #include #include @@ -247,7 +243,6 @@ using socket_t = int; #include #include #include -#include #include #include #include @@ -320,6 +315,10 @@ using socket_t = int; #include #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +#include +#endif + /* * Declaration */ @@ -435,6 +434,15 @@ struct scope_exit { } // namespace detail +enum SSLVerifierResponse { + // no decision has been made, use the built-in certificate verifier + NoDecisionMade, + // connection certificate is verified and accepted + CertificateAccepted, + // connection certificate was processed but is rejected + CertificateRejected +}; + enum StatusCode { // Information responses Continue_100 = 100, @@ -670,7 +678,7 @@ struct Request { bool is_chunked_content_provider_ = false; size_t authorization_count_ = 0; std::chrono::time_point start_time_ = - std::chrono::steady_clock::time_point::min(); + (std::chrono::steady_clock::time_point::min)(); }; struct Response { @@ -736,7 +744,8 @@ class Stream { virtual ~Stream() = default; virtual bool is_readable() const = 0; - virtual bool is_writable() const = 0; + virtual bool wait_readable() const = 0; + virtual bool wait_writable() const = 0; virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; @@ -879,7 +888,7 @@ class MatcherBase { * Captures parameters in request path and stores them in Request::path_params * * Capture name is a substring of a pattern from : to /. - * The rest of the pattern is matched agains the request path directly + * The rest of the pattern is matched against the request path directly * Parameters are captured starting from the next character after * the end of the last matched static pattern fragment until the next /. 
* @@ -1109,7 +1118,7 @@ class Server { virtual bool process_and_close_socket(socket_t sock); std::atomic is_running_{false}; - std::atomic is_decommisioned{false}; + std::atomic is_decommissioned{false}; struct MountPointEntry { std::string mount_point; @@ -1483,7 +1492,8 @@ class ClientImpl { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier(std::function verifier); + void set_server_certificate_verifier( + std::function verifier); #endif void set_logger(Logger logger); @@ -1600,7 +1610,7 @@ class ClientImpl { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT bool server_certificate_verification_ = true; bool server_hostname_verification_ = true; - std::function server_certificate_verifier_; + std::function server_certificate_verifier_; #endif Logger logger_; @@ -1913,7 +1923,8 @@ class Client { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier(std::function verifier); + void set_server_certificate_verifier( + std::function verifier); #endif void set_logger(Logger logger); @@ -2046,6 +2057,10 @@ inline void duration_to_sec_and_usec(const T &duration, U callback) { callback(static_cast(sec), static_cast(usec)); } +template inline constexpr size_t str_len(const char (&)[N]) { + return N - 1; +} + inline bool is_numeric(const std::string &str) { return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); } @@ -2205,9 +2220,9 @@ inline const char *status_message(int status) { inline std::string get_bearer_token_auth(const Request &req) { if (req.has_header("Authorization")) { - static std::string BearerHeaderPrefix = "Bearer "; + constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); return req.get_header_value("Authorization") - .substr(BearerHeaderPrefix.length()); + .substr(bearer_header_prefix_len); } return ""; } @@ -2382,8 +2397,6 @@ std::string encode_query_param(const std::string &value); std::string decode_url(const std::string &s, bool convert_plus_to_space); -void read_file(const std::string &path, std::string &out); - std::string trim_copy(const std::string &s); void divide( @@ -2439,7 +2452,7 @@ ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); -enum class EncodingType { None = 0, Gzip, Brotli }; +enum class EncodingType { None = 0, Gzip, Brotli, Zstd }; EncodingType encoding_type(const Request &req, const Response &res); @@ -2449,7 +2462,8 @@ class BufferStream final : public Stream { ~BufferStream() override = default; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -2551,6 +2565,34 @@ class brotli_decompressor final : public decompressor { }; #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +class zstd_compressor : public compressor { +public: + zstd_compressor(); + ~zstd_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + ZSTD_CCtx *ctx_ = nullptr; +}; + +class zstd_decompressor : public decompressor { +public: + zstd_decompressor(); + ~zstd_decompressor(); + + 
bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + ZSTD_DCtx *ctx_ = nullptr; +}; +#endif + // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` // to store data. The call can set memory on stack for performance. class stream_line_reader { @@ -2569,7 +2611,7 @@ class stream_line_reader { char *fixed_buffer_; const size_t fixed_buffer_size_; size_t fixed_buffer_used_size_ = 0; - std::string glowable_buffer_; + std::string growable_buffer_; }; class mmap { @@ -2910,18 +2952,9 @@ inline std::string decode_url(const std::string &s, return result; } -inline void read_file(const std::string &path, std::string &out) { - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); -} - inline std::string file_extension(const std::string &path) { std::smatch m; - static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + thread_local auto re = std::regex("\\.([a-zA-Z0-9]+)$"); if (std::regex_search(path, m, re)) { return m[1].str(); } return std::string(); } @@ -3005,18 +3038,18 @@ inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, fixed_buffer_size_(fixed_buffer_size) {} inline const char *stream_line_reader::ptr() const { - if (glowable_buffer_.empty()) { + if (growable_buffer_.empty()) { return fixed_buffer_; } else { - return glowable_buffer_.data(); + return growable_buffer_.data(); } } inline size_t stream_line_reader::size() const { - if (glowable_buffer_.empty()) { + if (growable_buffer_.empty()) { return fixed_buffer_used_size_; } else { - return glowable_buffer_.size(); + return growable_buffer_.size(); } } @@ -3027,7 +3060,7 @@ inline bool stream_line_reader::end_with_crlf() const { inline bool stream_line_reader::getline() { fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); + growable_buffer_.clear(); #ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR char prev_byte = 0; @@ -3065,11 +3098,11 @@ inline void stream_line_reader::append(char c) { fixed_buffer_[fixed_buffer_used_size_++] = c; fixed_buffer_[fixed_buffer_used_size_] = '\0'; } else { - if (glowable_buffer_.empty()) { + if (growable_buffer_.empty()) { assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + growable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); } - glowable_buffer_ += c; + growable_buffer_ += c; } } @@ -3246,35 +3279,23 @@ inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, }); } +inline int poll_wrapper(struct pollfd *fds, nfds_t nfds, int timeout) { +#ifdef _WIN32 + return ::WSAPoll(fds, nfds, timeout); +#else + return ::poll(fds, nfds, timeout); +#endif +} + template inline ssize_t select_impl(socket_t sock, time_t sec, time_t usec) { -#ifdef CPPHTTPLIB_USE_POLL struct pollfd pfd; pfd.fd = sock; pfd.events = (Read ? POLLIN : POLLOUT); auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd, 1, timeout); }); -#else -#ifndef _WIN32 - if (sock >= FD_SETSIZE) { return -1; } -#endif - - fd_set fds, *rfds, *wfds; - FD_ZERO(&fds); - FD_SET(sock, &fds); - rfds = (Read ? &fds : nullptr); - wfds = (Read ? 
nullptr : &fds); - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - return handle_EINTR([&]() { - return select(static_cast(sock + 1), rfds, wfds, nullptr, &tv); - }); -#endif + return handle_EINTR([&]() { return poll_wrapper(&pfd, 1, timeout); }); } inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { @@ -3287,14 +3308,14 @@ inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { -#ifdef CPPHTTPLIB_USE_POLL struct pollfd pfd_read; pfd_read.fd = sock; pfd_read.events = POLLIN | POLLOUT; auto timeout = static_cast(sec * 1000 + usec / 1000); - auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + auto poll_res = + handle_EINTR([&]() { return poll_wrapper(&pfd_read, 1, timeout); }); if (poll_res == 0) { return Error::ConnectionTimeout; } @@ -3308,38 +3329,6 @@ inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, } return Error::Connection; -#else -#ifndef _WIN32 - if (sock >= FD_SETSIZE) { return Error::Connection; } -#endif - - fd_set fdsr; - FD_ZERO(&fdsr); - FD_SET(sock, &fdsr); - - auto fdsw = fdsr; - auto fdse = fdsr; - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - auto ret = handle_EINTR([&]() { - return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); - }); - - if (ret == 0) { return Error::ConnectionTimeout; } - - if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { - auto error = 0; - socklen_t len = sizeof(error); - auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len); - auto successful = res >= 0 && !error; - return successful ? Error::Success : Error::Connection; - } - return Error::Connection; -#endif } inline bool is_socket_alive(socket_t sock) { @@ -3359,11 +3348,12 @@ class SocketStream final : public Stream { time_t write_timeout_sec, time_t write_timeout_usec, time_t max_timeout_msec = 0, std::chrono::time_point start_time = - std::chrono::steady_clock::time_point::min()); + (std::chrono::steady_clock::time_point::min)()); ~SocketStream() override; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3378,7 +3368,7 @@ class SocketStream final : public Stream { time_t write_timeout_sec_; time_t write_timeout_usec_; time_t max_timeout_msec_; - const std::chrono::time_point start_time; + const std::chrono::time_point start_time_; std::vector read_buff_; size_t read_buff_off_ = 0; @@ -3395,11 +3385,12 @@ class SSLSocketStream final : public Stream { time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, time_t max_timeout_msec = 0, std::chrono::time_point start_time = - std::chrono::steady_clock::time_point::min()); + (std::chrono::steady_clock::time_point::min)()); ~SSLSocketStream() override; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3415,7 +3406,7 @@ class SSLSocketStream final : public Stream { time_t write_timeout_sec_; time_t 
write_timeout_usec_; time_t max_timeout_msec_; - const std::chrono::time_point start_time; + const std::chrono::time_point start_time_; }; #endif @@ -3550,7 +3541,6 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, hints.ai_flags = socket_flags; } -#ifndef _WIN32 if (hints.ai_family == AF_UNIX) { const auto addrlen = host.length(); if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } @@ -3574,11 +3564,19 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, sizeof(addr) - sizeof(addr.sun_path) + addrlen); #ifndef SOCK_CLOEXEC +#ifndef _WIN32 fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif #endif if (socket_options) { socket_options(sock); } +#ifdef _WIN32 + // Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so + // remove the option. + detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); +#endif + bool dummy; if (!bind_or_connect(sock, hints, dummy)) { close_socket(sock); @@ -3587,7 +3585,6 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, } return sock; } -#endif auto service = std::to_string(port); @@ -3993,6 +3990,12 @@ inline EncodingType encoding_type(const Request &req, const Response &res) { if (ret) { return EncodingType::Gzip; } #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + // TODO: 'Accept-Encoding' has zstd, not zstd;q=0 + ret = s.find("zstd") != std::string::npos; + if (ret) { return EncodingType::Zstd; } +#endif + return EncodingType::None; } @@ -4201,6 +4204,61 @@ inline bool brotli_decompressor::decompress(const char *data, } #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +inline zstd_compressor::zstd_compressor() { + ctx_ = ZSTD_createCCtx(); + ZSTD_CCtx_setParameter(ctx_, ZSTD_c_compressionLevel, ZSTD_fast); +} + +inline zstd_compressor::~zstd_compressor() { ZSTD_freeCCtx(ctx_); } + +inline bool zstd_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + ZSTD_EndDirective mode = last ? ZSTD_e_end : ZSTD_e_continue; + ZSTD_inBuffer input = {data, data_length, 0}; + + bool finished; + do { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_compressStream2(ctx_, &output, &input, mode); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + + finished = last ? 
(remaining == 0) : (input.pos == input.size); + + } while (!finished); + + return true; +} + +inline zstd_decompressor::zstd_decompressor() { ctx_ = ZSTD_createDCtx(); } + +inline zstd_decompressor::~zstd_decompressor() { ZSTD_freeDCtx(ctx_); } + +inline bool zstd_decompressor::is_valid() const { return ctx_ != nullptr; } + +inline bool zstd_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + std::array buff{}; + ZSTD_inBuffer input = {data, data_length, 0}; + + while (input.pos < input.size) { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_decompressStream(ctx_, &output, &input); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + } + + return true; +} +#endif + inline bool has_header(const Headers &headers, const std::string &key) { return headers.find(key) != headers.end(); } @@ -4227,6 +4285,9 @@ inline bool parse_header(const char *beg, const char *end, T fn) { p++; } + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + if (p == end) { return false; } auto key_end = p; @@ -4242,10 +4303,6 @@ inline bool parse_header(const char *beg, const char *end, T fn) { if (!key_len) { return false; } auto key = std::string(beg, key_end); - // auto val = (case_ignore::equal(key, "Location") || - // case_ignore::equal(key, "Referer")) - // ? std::string(p, end) - // : decode_url(std::string(p, end), false); auto val = std::string(p, end); if (!detail::fields::is_field_value(val)) { return false; } @@ -4341,7 +4398,8 @@ inline bool read_content_without_length(Stream &strm, uint64_t r = 0; for (;;) { auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); - if (n <= 0) { return false; } + if (n == 0) { return true; } + if (n < 0) { return false; } if (!out(buf, static_cast(n), r, 0)) { return false; } r += static_cast(n); @@ -4384,7 +4442,7 @@ inline bool read_content_chunked(Stream &strm, T &x, assert(chunk_len == 0); - // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentiones "The chunked + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked // transfer coding is complete when a chunk with a chunk-size of zero is // received, possibly followed by a trailer section, and finally terminated by // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 @@ -4394,8 +4452,8 @@ inline bool read_content_chunked(Stream &strm, T &x, // to be ok whether the final CRLF exists or not in the chunked data. // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 // - // According to the reference code in RFC 9112, cpp-htpplib now allows - // chuncked transfer coding data without the final CRLF. + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. 
if (!line_reader.getline()) { return true; } while (strcmp(line_reader.ptr(), "\r\n") != 0) { @@ -4442,6 +4500,13 @@ bool prepare_content_receiver(T &x, int &status, #else status = StatusCode::UnsupportedMediaType_415; return false; +#endif + } else if (encoding == "zstd") { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; #endif } @@ -4565,7 +4630,7 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, data_sink.write = [&](const char *d, size_t l) -> bool { if (ok) { - if (strm.is_writable() && write_data(strm, d, l)) { + if (write_data(strm, d, l)) { offset += l; } else { ok = false; @@ -4574,10 +4639,10 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; while (offset < end_offset && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { error = Error::Write; return false; } else if (!content_provider(offset, end_offset - offset, data_sink)) { @@ -4615,17 +4680,17 @@ write_content_without_length(Stream &strm, data_sink.write = [&](const char *d, size_t l) -> bool { if (ok) { offset += l; - if (!strm.is_writable() || !write_data(strm, d, l)) { ok = false; } + if (!write_data(strm, d, l)) { ok = false; } } return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; data_sink.done = [&](void) { data_available = false; }; while (data_available && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { return false; } else if (!content_provider(offset, 0, data_sink)) { return false; @@ -4660,10 +4725,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, // Emit chunked response header and footer for each chunk auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (!strm.is_writable() || - !write_data(strm, chunk.data(), chunk.size())) { - ok = false; - } + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; } } } else { ok = false; @@ -4672,7 +4734,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; auto done_with_trailer = [&](const Headers *trailer) { if (!ok) { return; } @@ -4692,17 +4754,14 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, if (!payload.empty()) { // Emit chunked response header and footer for each chunk auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (!strm.is_writable() || - !write_data(strm, chunk.data(), chunk.size())) { + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; return; } } - static const std::string done_marker("0\r\n"); - if (!write_data(strm, done_marker.data(), done_marker.size())) { - ok = false; - } + constexpr const char done_marker[] = "0\r\n"; + if (!write_data(strm, done_marker, str_len(done_marker))) { ok = false; } // Trailer if (trailer) { @@ -4714,8 +4773,8 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, } } - static const std::string crlf("\r\n"); - if (!write_data(strm, crlf.data(), crlf.size())) { ok = false; } + constexpr 
const char crlf[] = "\r\n"; + if (!write_data(strm, crlf, str_len(crlf))) { ok = false; } }; data_sink.done = [&](void) { done_with_trailer(nullptr); }; @@ -4725,7 +4784,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, }; while (data_available && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { error = Error::Write; return false; } else if (!content_provider(offset, 0, data_sink)) { @@ -4957,13 +5016,13 @@ class MultipartFormDataParser { return false; } - static const std::string header_content_type = "Content-Type:"; + constexpr const char header_content_type[] = "Content-Type:"; if (start_with_case_ignore(header, header_content_type)) { file_.content_type = - trim_copy(header.substr(header_content_type.size())); + trim_copy(header.substr(str_len(header_content_type))); } else { - static const std::regex re_content_disposition( + thread_local const std::regex re_content_disposition( R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", std::regex_constants::icase); @@ -4985,8 +5044,8 @@ class MultipartFormDataParser { it = params.find("filename*"); if (it != params.end()) { - // Only allow UTF-8 enconnding... - static const std::regex re_rfc5987_encoding( + // Only allow UTF-8 encoding... + thread_local const std::regex re_rfc5987_encoding( R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); std::smatch m2; @@ -5058,10 +5117,10 @@ class MultipartFormDataParser { file_.content_type.clear(); } - bool start_with_case_ignore(const std::string &a, - const std::string &b) const { - if (a.size() < b.size()) { return false; } - for (size_t i = 0; i < b.size(); i++) { + bool start_with_case_ignore(const std::string &a, const char *b) const { + const auto b_len = strlen(b); + if (a.size() < b_len) { return false; } + for (size_t i = 0; i < b_len; i++) { if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { return false; } @@ -5148,19 +5207,18 @@ class MultipartFormDataParser { }; inline std::string random_string(size_t length) { - static const char data[] = + constexpr const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - // std::random_device might actually be deterministic on some - // platforms, but due to lack of support in the c++ standard library, - // doing better requires either some ugly hacks or breaking portability. - static std::random_device seed_gen; - - // Request 128 bits of entropy for initialization - static std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), - seed_gen()}; - - static std::mt19937 engine(seed_sequence); + thread_local auto engine([]() { + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + std::random_device seed_gen; + // Request 128 bits of entropy for initialization + std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + return std::mt19937(seed_sequence); + }()); std::string result; for (size_t i = 0; i < length; i++) { @@ -5232,7 +5290,7 @@ serialize_multipart_formdata(const MultipartFormDataItems &items, inline bool range_error(Request &req, Response &res) { if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { - ssize_t contant_len = static_cast( + ssize_t content_len = static_cast( res.content_length_ ? 
res.content_length_ : res.body.size()); ssize_t prev_first_pos = -1; @@ -5252,12 +5310,12 @@ inline bool range_error(Request &req, Response &res) { if (first_pos == -1 && last_pos == -1) { first_pos = 0; - last_pos = contant_len; + last_pos = content_len; } if (first_pos == -1) { - first_pos = contant_len - last_pos; - last_pos = contant_len - 1; + first_pos = content_len - last_pos; + last_pos = content_len - 1; } // NOTE: RFC-9110 '14.1.2. Byte Ranges': @@ -5269,13 +5327,13 @@ inline bool range_error(Request &req, Response &res) { // with a value that is one less than the current length of the selected // representation). // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6 - if (last_pos == -1 || last_pos >= contant_len) { - last_pos = contant_len - 1; + if (last_pos == -1 || last_pos >= content_len) { + last_pos = content_len - 1; } // Range must be within content length if (!(0 <= first_pos && first_pos <= last_pos && - last_pos <= contant_len - 1)) { + last_pos <= content_len - 1)) { return true; } @@ -5674,7 +5732,8 @@ inline bool parse_www_authenticate(const Response &res, bool is_proxy) { auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; if (res.has_header(auth_key)) { - static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + thread_local auto re = + std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); auto s = res.get_header_value(auth_key); auto pos = s.find(' '); if (pos != std::string::npos) { @@ -5758,7 +5817,7 @@ inline void hosted_at(const std::string &hostname, inline std::string append_query_params(const std::string &path, const Params ¶ms) { std::string path_with_query = path; - const static std::regex re("[^?]+\\?.*"); + thread_local const std::regex re("[^?]+\\?.*"); auto delm = std::regex_match(path, re) ? 
'&' : '?'; path_with_query += delm + detail::params_to_query_str(params); return path_with_query; @@ -5987,14 +6046,14 @@ inline ssize_t Stream::write(const std::string &s) { namespace detail { -inline void calc_actual_timeout(time_t max_timeout_msec, - time_t duration_msec, time_t timeout_sec, - time_t timeout_usec, time_t &actual_timeout_sec, +inline void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, + time_t timeout_sec, time_t timeout_usec, + time_t &actual_timeout_sec, time_t &actual_timeout_usec) { auto timeout_msec = (timeout_sec * 1000) + (timeout_usec / 1000); auto actual_timeout_msec = - std::min(max_timeout_msec - duration_msec, timeout_msec); + (std::min)(max_timeout_msec - duration_msec, timeout_msec); actual_timeout_sec = actual_timeout_msec / 1000; actual_timeout_usec = (actual_timeout_msec % 1000) * 1000; @@ -6010,12 +6069,16 @@ inline SocketStream::SocketStream( read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec), - max_timeout_msec_(max_timeout_msec), start_time(start_time), + max_timeout_msec_(max_timeout_msec), start_time_(start_time), read_buff_(read_buff_size_, 0) {} inline SocketStream::~SocketStream() = default; inline bool SocketStream::is_readable() const { + return read_buff_off_ < read_buff_content_size_; +} + +inline bool SocketStream::wait_readable() const { if (max_timeout_msec_ <= 0) { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } @@ -6028,7 +6091,7 @@ inline bool SocketStream::is_readable() const { return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } -inline bool SocketStream::is_writable() const { +inline bool SocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); } @@ -6055,7 +6118,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) { } } - if (!is_readable()) { return -1; } + if (!wait_readable()) { return -1; } read_buff_off_ = 0; read_buff_content_size_ = 0; @@ -6080,7 +6143,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) { } inline ssize_t SocketStream::write(const char *ptr, size_t size) { - if (!is_writable()) { return -1; } + if (!wait_writable()) { return -1; } #if defined(_WIN32) && !defined(_WIN64) size = @@ -6104,14 +6167,16 @@ inline socket_t SocketStream::socket() const { return sock_; } inline time_t SocketStream::duration() const { return std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time) + std::chrono::steady_clock::now() - start_time_) .count(); } // Buffer stream implementation inline bool BufferStream::is_readable() const { return true; } -inline bool BufferStream::is_writable() const { return true; } +inline bool BufferStream::wait_readable() const { return true; } + +inline bool BufferStream::wait_writable() const { return true; } inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER < 1910 @@ -6141,7 +6206,7 @@ inline time_t BufferStream::duration() const { return 0; } inline const std::string &BufferStream::get_buffer() const { return buffer; } inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { - static constexpr char marker[] = "/:"; + constexpr const char marker[] = "/:"; // One past the last ending position of a path param substring std::size_t last_param_end = 0; @@ -6162,7 +6227,7 @@ inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { static_fragments_.push_back( 
pattern.substr(last_param_end, marker_pos - last_param_end + 1)); - const auto param_name_start = marker_pos + 2; + const auto param_name_start = marker_pos + str_len(marker); auto sep_pos = pattern.find(separator, param_name_start); if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } @@ -6469,12 +6534,12 @@ inline Server &Server::set_payload_max_length(size_t length) { inline bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { auto ret = bind_internal(host, port, socket_flags); - if (ret == -1) { is_decommisioned = true; } + if (ret == -1) { is_decommissioned = true; } return ret >= 0; } inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { auto ret = bind_internal(host, 0, socket_flags); - if (ret == -1) { is_decommisioned = true; } + if (ret == -1) { is_decommissioned = true; } return ret; } @@ -6488,7 +6553,7 @@ inline bool Server::listen(const std::string &host, int port, inline bool Server::is_running() const { return is_running_; } inline void Server::wait_until_ready() const { - while (!is_running_ && !is_decommisioned) { + while (!is_running_ && !is_decommissioned) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); } } @@ -6500,10 +6565,10 @@ inline void Server::stop() { detail::shutdown_socket(sock); detail::close_socket(sock); } - is_decommisioned = false; + is_decommissioned = false; } -inline void Server::decommission() { is_decommisioned = true; } +inline void Server::decommission() { is_decommissioned = true; } inline bool Server::parse_request_line(const char *s, Request &req) const { auto len = strlen(s); @@ -6526,7 +6591,7 @@ inline bool Server::parse_request_line(const char *s, Request &req) const { if (count != 3) { return false; } } - static const std::set methods{ + thread_local const std::set methods{ "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; @@ -6680,6 +6745,10 @@ Server::write_content_with_provider(Stream &strm, const Request &req, } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); #endif } else { compressor = detail::make_unique(); @@ -6862,7 +6931,7 @@ Server::create_server_socket(const std::string &host, int port, inline int Server::bind_internal(const std::string &host, int port, int socket_flags) { - if (is_decommisioned) { return -1; } + if (is_decommissioned) { return -1; } if (!is_valid()) { return -1; } @@ -6889,7 +6958,7 @@ inline int Server::bind_internal(const std::string &host, int port, } inline bool Server::listen_internal() { - if (is_decommisioned) { return false; } + if (is_decommissioned) { return false; } auto ret = true; is_running_ = true; @@ -6913,7 +6982,7 @@ inline bool Server::listen_internal() { #endif #if defined _WIN32 - // sockets conneced via WASAccept inherit flags NO_HANDLE_INHERIT, + // sockets connected via WASAccept inherit flags NO_HANDLE_INHERIT, // OVERLAPPED socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); #elif defined SOCK_CLOEXEC @@ -6955,7 +7024,7 @@ inline bool Server::listen_internal() { task_queue->shutdown(); } - is_decommisioned = !ret; + is_decommissioned = !ret; return ret; } @@ -7095,6 +7164,8 @@ inline void Server::apply_ranges(const Request &req, Response &res, res.set_header("Content-Encoding", "gzip"); } else if (type == detail::EncodingType::Brotli) { 
res.set_header("Content-Encoding", "br"); + } else if (type == detail::EncodingType::Zstd) { + res.set_header("Content-Encoding", "zstd"); } } } @@ -7134,6 +7205,11 @@ inline void Server::apply_ranges(const Request &req, Response &res, #ifdef CPPHTTPLIB_BROTLI_SUPPORT compressor = detail::make_unique(); content_encoding = "br"; +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); + content_encoding = "zstd"; #endif } @@ -7189,20 +7265,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, res.version = "HTTP/1.1"; res.headers = default_headers_; -#ifdef _WIN32 - // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL). -#else -#ifndef CPPHTTPLIB_USE_POLL - // Socket file descriptor exceeded FD_SETSIZE... - if (strm.socket() >= FD_SETSIZE) { - Headers dummy; - detail::read_headers(strm, dummy); - res.status = StatusCode::InternalServerError_500; - return write_response(strm, close_connection, req, res); - } -#endif -#endif - // Request line and headers if (!parse_request_line(line_reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { @@ -7394,6 +7456,16 @@ inline ClientImpl::ClientImpl(const std::string &host, int port, client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} inline ClientImpl::~ClientImpl() { + // Wait until all the requests in flight are handled. + size_t retry_count = 10; + while (retry_count-- > 0) { + { + std::lock_guard guard(socket_mutex_); + if (socket_requests_in_flight_ == 0) { break; } + } + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } + std::lock_guard guard(socket_mutex_); shutdown_socket(socket_); close_socket(socket_); @@ -7519,9 +7591,9 @@ inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, if (!line_reader.getline()) { return false; } #ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR - const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); #else - const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); #endif std::cmatch m; @@ -7577,7 +7649,7 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { #endif if (!is_alive) { - // Attempt to avoid sigpipe by shutting down nongracefully if it seems + // Attempt to avoid sigpipe by shutting down non-gracefully if it seems // like the other side has already closed the connection Also, there // cannot be any requests in flight from other threads since we locked // request_mutex_, so safe to close everything immediately @@ -7753,7 +7825,7 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { auto location = res.get_header_value("location"); if (location.empty()) { return false; } - const static std::regex re( + thread_local const std::regex re( R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); std::smatch m; @@ -7862,6 +7934,10 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, #ifdef CPPHTTPLIB_ZLIB_SUPPORT if (!accept_encoding.empty()) { accept_encoding += ", "; } accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; #endif req.set_header("Accept-Encoding", accept_encoding); } @@ -8213,8 +8289,7 @@ inline bool 
ClientImpl::process_socket( std::function callback) { return detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, max_timeout_msec_, start_time, - std::move(callback)); + write_timeout_usec_, max_timeout_msec_, start_time, std::move(callback)); } inline bool ClientImpl::is_ssl() const { return false; } @@ -9009,7 +9084,7 @@ inline void ClientImpl::enable_server_hostname_verification(bool enabled) { } inline void ClientImpl::set_server_certificate_verifier( - std::function verifier) { + std::function verifier) { server_certificate_verifier_ = verifier; } #endif @@ -9062,18 +9137,13 @@ inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, // Note that it is not always possible to avoid SIGPIPE, this is merely a // best-efforts. if (shutdown_gracefully) { -#ifdef _WIN32 (void)(sock); - SSL_shutdown(ssl); -#else - detail::set_socket_opt_time(sock, SOL_SOCKET, SO_RCVTIMEO, 1, 0); - - auto ret = SSL_shutdown(ssl); - while (ret == 0) { - std::this_thread::sleep_for(std::chrono::milliseconds{100}); - ret = SSL_shutdown(ssl); + // SSL_shutdown() returns 0 on first call (indicating close_notify alert + // sent) and 1 on subsequent call (indicating close_notify alert received) + if (SSL_shutdown(ssl) == 0) { + // Expected to return 1, but even if it doesn't, we free ssl + SSL_shutdown(ssl); } -#endif } std::lock_guard guard(ctx_mutex); @@ -9124,19 +9194,11 @@ inline bool process_client_socket_ssl( time_t max_timeout_msec, std::chrono::time_point start_time, T callback) { SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec, - max_timeout_msec, start_time); + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); return callback(strm); } -class SSLInit { -public: - SSLInit() { - OPENSSL_init_ssl( - OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); - } -}; - // SSL socket stream implementation inline SSLSocketStream::SSLSocketStream( socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, @@ -9147,13 +9209,17 @@ inline SSLSocketStream::SSLSocketStream( read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec), - max_timeout_msec_(max_timeout_msec), start_time(start_time) { + max_timeout_msec_(max_timeout_msec), start_time_(start_time) { SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); } inline SSLSocketStream::~SSLSocketStream() = default; inline bool SSLSocketStream::is_readable() const { + return SSL_pending(ssl_) > 0; +} + +inline bool SSLSocketStream::wait_readable() const { if (max_timeout_msec_ <= 0) { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } @@ -9166,7 +9232,7 @@ inline bool SSLSocketStream::is_readable() const { return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } -inline bool SSLSocketStream::is_writable() const { +inline bool SSLSocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); } @@ -9174,7 +9240,7 @@ inline bool SSLSocketStream::is_writable() const { inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { + } else if (wait_readable()) { auto ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret < 0) { auto err = SSL_get_error(ssl_, ret); @@ -9188,7 
+9254,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { #endif if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { + } else if (wait_readable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret >= 0) { return ret; } @@ -9205,7 +9271,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (is_writable()) { + if (wait_writable()) { auto handle_size = static_cast( std::min(size, (std::numeric_limits::max)())); @@ -9220,7 +9286,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { #else while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { #endif - if (is_writable()) { + if (wait_writable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_write(ssl_, ptr, static_cast(handle_size)); if (ret >= 0) { return ret; } @@ -9249,12 +9315,10 @@ inline socket_t SSLSocketStream::socket() const { return sock_; } inline time_t SSLSocketStream::duration() const { return std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time) + std::chrono::steady_clock::now() - start_time_) .count(); } -static SSLInit sslinit_; - } // namespace detail // SSL HTTP server implementation @@ -9623,12 +9687,18 @@ inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { } if (server_certificate_verification_) { + auto verification_status = SSLVerifierResponse::NoDecisionMade; + if (server_certificate_verifier_) { - if (!server_certificate_verifier_(ssl2)) { - error = Error::SSLServerVerification; - return false; - } - } else { + verification_status = server_certificate_verifier_(ssl2); + } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + error = Error::SSLServerVerification; + return false; + } + + if (verification_status == SSLVerifierResponse::NoDecisionMade) { verify_result_ = SSL_get_verify_result(ssl2); if (verify_result_ != X509_V_OK) { @@ -9740,8 +9810,8 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { auto type = GEN_DNS; - struct in6_addr addr6{}; - struct in_addr addr{}; + struct in6_addr addr6 = {}; + struct in_addr addr = {}; size_t addr_len = 0; #ifndef __MINGW32__ @@ -10389,7 +10459,7 @@ inline void Client::enable_server_hostname_verification(bool enabled) { } inline void Client::set_server_certificate_verifier( - std::function verifier) { + std::function verifier) { cli_->set_server_certificate_verifier(verifier); } #endif @@ -10433,8 +10503,4 @@ inline SSL_CTX *Client::ssl_context() const { } // namespace httplib -#if defined(_WIN32) && defined(CPPHTTPLIB_USE_POLL) -#undef poll -#endif - #endif // CPPHTTPLIB_HTTPLIB_H diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index d0e6da8e4a1e0..941815c22efb0 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 18caa9127662d..760c3646433ad 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -133,7 +133,8 @@ struct slot_params { auto grammar_triggers = json::array(); for (const auto & trigger : sampling.grammar_triggers) { - grammar_triggers.push_back(trigger.to_json()); + server_grammar_trigger ct(std::move(trigger)); + grammar_triggers.push_back(ct.to_json()); } return json { @@ -372,9 +373,9 @@ struct server_task { const auto 
grammar_triggers = data.find("grammar_triggers"); if (grammar_triggers != data.end()) { for (const auto & t : *grammar_triggers) { - auto ct = common_grammar_trigger::from_json(t); - if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { - const auto & word = ct.value; + server_grammar_trigger ct(t); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = ct.value.value; auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); if (ids.size() == 1) { auto token = ids[0]; @@ -392,7 +393,7 @@ struct server_task { params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); } } else { - params.sampling.grammar_triggers.push_back(ct); + params.sampling.grammar_triggers.push_back(std::move(ct.value)); } } } @@ -489,8 +490,12 @@ struct result_timings { double predicted_per_token_ms; double predicted_per_second; + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + json to_json() const { - return { + json base = { {"prompt_n", prompt_n}, {"prompt_ms", prompt_ms}, {"prompt_per_token_ms", prompt_per_token_ms}, @@ -501,6 +506,13 @@ struct result_timings { {"predicted_per_token_ms", predicted_per_token_ms}, {"predicted_per_second", predicted_per_second}, }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; } }; @@ -1299,6 +1311,10 @@ struct server_slot { std::function callback_on_release; + // Speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + void reset() { SLT_DBG(*this, "%s", "\n"); @@ -1315,6 +1331,10 @@ struct server_slot { generated_tokens.clear(); generated_token_probs.clear(); + + // clear speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; } bool is_non_causal() const { @@ -1381,6 +1401,12 @@ struct server_slot { timings.predicted_per_token_ms = t_token_generation / n_decoded; timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + return timings; } @@ -1428,6 +1454,15 @@ struct server_slot { t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, t_token_generation, n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + + if (n_draft_total > 0) { + const float draft_ratio = (float) n_draft_accepted / n_draft_total; + SLT_INF(*this, + "\n" + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total + ); + } } json to_json() const { @@ -1842,7 +1877,7 @@ struct server_context { } bool load_model(const common_params & params) { - SRV_INF("loading model '%s'\n", params.model.c_str()); + SRV_INF("loading model '%s'\n", params.model.path.c_str()); params_base = params; @@ -1852,7 +1887,7 @@ struct server_context { ctx = llama_init.context.get(); if (model == nullptr) { - SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); + SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } @@ -1863,16 +1898,13 @@ struct server_context { add_bos_token = llama_vocab_get_add_bos(vocab); has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; - if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) { - 
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); + if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); auto params_dft = params_base; params_dft.devices = params_base.speculative.devices; - params_dft.hf_file = params_base.speculative.hf_file; - params_dft.hf_repo = params_base.speculative.hf_repo; params_dft.model = params_base.speculative.model; - params_dft.model_url = params_base.speculative.model_url; params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_parallel = 1; @@ -1886,12 +1918,12 @@ struct server_context { model_dft = llama_init_dft.model.get(); if (model_dft == nullptr) { - SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str()); + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); return false; } if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { - SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str()); + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); return false; } @@ -3290,6 +3322,9 @@ struct server_context { llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + // keep track of total number of tokens generated in the draft + slot.n_draft_total += draft.size(); + // ignore small drafts if (slot.params.speculative.n_min > (int) draft.size()) { SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); @@ -3315,6 +3350,9 @@ struct server_context { slot.n_past += ids.size(); slot.n_decoded += ids.size(); + // update how many tokens out of draft was accepted + slot.n_draft_accepted += ids.size() - 1; + slot.cache_tokens.push_back(id); slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); @@ -3825,7 +3863,7 @@ int main(int argc, char ** argv) { json data = { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, - { "model_path", ctx_server.params_base.model }, + { "model_path", ctx_server.params_base.model.path }, { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, @@ -4091,7 +4129,7 @@ int main(int argc, char ** argv) { {"object", "list"}, {"data", { { - {"id", params.model_alias.empty() ? params.model : params.model_alias}, + {"id", params.model_alias.empty() ? 
params.model.path : params.model_alias}, {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, @@ -4459,15 +4497,24 @@ int main(int argc, char ** argv) { llama_backend_free(); }; - // bind HTTP listen port bool was_bound = false; - if (params.port == 0) { - int bound_port = svr->bind_to_any_port(params.hostname); - if ((was_bound = (bound_port >= 0))) { - params.port = bound_port; - } + if (string_ends_with(std::string(params.hostname), ".sock")) { + LOG_INF("%s: setting address family to AF_UNIX\n", __func__); + svr->set_address_family(AF_UNIX); + // bind_to_port requires a second arg; any value other than 0 should + // simply get ignored + was_bound = svr->bind_to_port(params.hostname, 8080); } else { - was_bound = svr->bind_to_port(params.hostname, params.port); + LOG_INF("%s: binding port with default address family\n", __func__); + // bind HTTP listen port + if (params.port == 0) { + int bound_port = svr->bind_to_any_port(params.hostname); + if ((was_bound = (bound_port >= 0))) { + params.port = bound_port; + } + } else { + was_bound = svr->bind_to_port(params.hostname, params.port); + } } if (!was_bound) { diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 58cdd6af92974..55cf3230d90ce 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -58,6 +58,32 @@ static T json_value(const json & body, const std::string & key, const T & defaul const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); +// thin wrapper around common_grammar_trigger with (de)serialization functions +struct server_grammar_trigger { + common_grammar_trigger value; + + server_grammar_trigger() = default; + server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} + server_grammar_trigger(const json & in) { + value.type = (common_grammar_trigger_type) in.at("type").get<int>(); + value.value = in.at("value").get<std::string>(); + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + value.token = (llama_token) in.at("token").get<int>(); + } + } + + json to_json() const { + json out { + {"type", (int) value.type}, + {"value", value.value}, + }; + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + out["token"] = (int) value.token; + } + return out; + } +}; + // // tokenizer and input processing utils // @@ -627,7 +653,8 @@ static json oaicompat_completion_params_parse( llama_params["grammar_lazy"] = chat_params.grammar_lazy; auto grammar_triggers = json::array(); for (const auto & trigger : chat_params.grammar_triggers) { - grammar_triggers.push_back(trigger.to_json()); + server_grammar_trigger ct(trigger); + grammar_triggers.push_back(ct.to_json()); } llama_params["grammar_triggers"] = grammar_triggers; llama_params["preserved_tokens"] = chat_params.preserved_tokens; diff --git a/examples/server/webui/package-lock.json b/examples/server/webui/package-lock.json index 056592dd4775d..b2e3cf94aca41 100644 --- a/examples/server/webui/package-lock.json +++ b/examples/server/webui/package-lock.json @@ -10,9 +10,11 @@ "dependencies": { "@heroicons/react": "^2.2.0", "@sec-ant/readable-stream": "^0.6.0", + "@tailwindcss/postcss": "^4.1.1", + "@tailwindcss/vite": "^4.1.1", "@vscode/markdown-it-katex": "^1.1.1", "autoprefixer": "^10.4.20", - "daisyui": "^4.12.14", + "daisyui": "^5.0.12", "dexie": "^4.0.11", "highlight.js": "^11.10.0", "katex": "^0.16.15", @@ -26,7 +28,7 @@ "remark-breaks": "^4.0.0", "remark-gfm": "^4.0.0", "remark-math": "^6.0.0", - "tailwindcss": "^3.4.15", + "tailwindcss": "^4.1.1", "textlinestream": 
"^1.1.1", "vite-plugin-singlefile": "^2.0.3" }, @@ -979,27 +981,11 @@ "url": "https://github.com/sponsors/nzakas" } }, - "node_modules/@isaacs/cliui": { - "version": "8.0.2", - "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", - "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==", - "license": "ISC", - "dependencies": { - "string-width": "^5.1.2", - "string-width-cjs": "npm:string-width@^4.2.0", - "strip-ansi": "^7.0.1", - "strip-ansi-cjs": "npm:strip-ansi@^6.0.1", - "wrap-ansi": "^8.1.0", - "wrap-ansi-cjs": "npm:wrap-ansi@^7.0.0" - }, - "engines": { - "node": ">=12" - } - }, "node_modules/@jridgewell/gen-mapping": { "version": "0.3.8", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz", "integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==", + "dev": true, "license": "MIT", "dependencies": { "@jridgewell/set-array": "^1.2.1", @@ -1014,6 +1000,7 @@ "version": "3.1.2", "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "dev": true, "license": "MIT", "engines": { "node": ">=6.0.0" @@ -1023,6 +1010,7 @@ "version": "1.2.1", "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", + "dev": true, "license": "MIT", "engines": { "node": ">=6.0.0" @@ -1032,12 +1020,14 @@ "version": "1.5.0", "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "dev": true, "license": "MIT" }, "node_modules/@jridgewell/trace-mapping": { "version": "0.3.25", "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "dev": true, "license": "MIT", "dependencies": { "@jridgewell/resolve-uri": "^3.1.0", @@ -1048,6 +1038,7 @@ "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dev": true, "license": "MIT", "dependencies": { "@nodelib/fs.stat": "2.0.5", @@ -1061,6 +1052,7 @@ "version": "2.0.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "dev": true, "license": "MIT", "engines": { "node": ">= 8" @@ -1070,6 +1062,7 @@ "version": "1.2.8", "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dev": true, "license": "MIT", "dependencies": { "@nodelib/fs.scandir": "2.1.5", @@ -1079,16 +1072,6 @@ "node": ">= 8" } }, - "node_modules/@pkgjs/parseargs": { - "version": "0.11.0", - "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz", - "integrity": "sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==", - "license": "MIT", - "optional": true, - 
"engines": { - "node": ">=14" - } - }, "node_modules/@rollup/rollup-android-arm-eabi": { "version": "4.34.2", "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.34.2.tgz", @@ -1342,6 +1325,243 @@ "integrity": "sha512-uiBh8DrB5FN35gP6/o8JEhEQ7/ci1jUsOZO/VMUjyvTpjtV54VstOXVj1TvTj/wsT23pfX6butxxh3qufsW3+g==", "license": "MIT" }, + "node_modules/@tailwindcss/node": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/node/-/node-4.1.1.tgz", + "integrity": "sha512-xvlh4pvfG/bkv0fEtJDABAm1tjtSmSyi2QmS4zyj1EKNI1UiOYiUq1IphSwDsNJ5vJ9cWEGs4rJXpUdCN2kujQ==", + "license": "MIT", + "dependencies": { + "enhanced-resolve": "^5.18.1", + "jiti": "^2.4.2", + "lightningcss": "1.29.2", + "tailwindcss": "4.1.1" + } + }, + "node_modules/@tailwindcss/oxide": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide/-/oxide-4.1.1.tgz", + "integrity": "sha512-7+YBgnPQ4+jv6B6WVOerJ6WOzDzNJXrRKDts674v6TKAqFqYRr9+EBtSziO7nNcwQ8JtoZNMeqA+WJDjtCM/7w==", + "license": "MIT", + "engines": { + "node": ">= 10" + }, + "optionalDependencies": { + "@tailwindcss/oxide-android-arm64": "4.1.1", + "@tailwindcss/oxide-darwin-arm64": "4.1.1", + "@tailwindcss/oxide-darwin-x64": "4.1.1", + "@tailwindcss/oxide-freebsd-x64": "4.1.1", + "@tailwindcss/oxide-linux-arm-gnueabihf": "4.1.1", + "@tailwindcss/oxide-linux-arm64-gnu": "4.1.1", + "@tailwindcss/oxide-linux-arm64-musl": "4.1.1", + "@tailwindcss/oxide-linux-x64-gnu": "4.1.1", + "@tailwindcss/oxide-linux-x64-musl": "4.1.1", + "@tailwindcss/oxide-win32-arm64-msvc": "4.1.1", + "@tailwindcss/oxide-win32-x64-msvc": "4.1.1" + } + }, + "node_modules/@tailwindcss/oxide-android-arm64": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-android-arm64/-/oxide-android-arm64-4.1.1.tgz", + "integrity": "sha512-gTyRzfdParpoCU1yyUC/iN6XK6T0Ra4bDlF8Aeul5NP9cLzKEZDogdNVNGv5WZmCDkVol7qlex7TMmcfytMmmw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-darwin-arm64": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-arm64/-/oxide-darwin-arm64-4.1.1.tgz", + "integrity": "sha512-dI0QbdMWBvLB3MtaTKetzUKG9CUUQow8JSP4Nm+OxVokeZ+N+f1OmZW/hW1LzMxpx9RQCBgSRL+IIvKRat5Wdg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-darwin-x64": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-x64/-/oxide-darwin-x64-4.1.1.tgz", + "integrity": "sha512-2Y+NPQOTRBCItshPgY/CWg4bKi7E9evMg4bgdb6h9iZObCZLOe3doPcuSxGS3DB0dKyMFKE8pTdWtFUbxZBMSA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-freebsd-x64": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-freebsd-x64/-/oxide-freebsd-x64-4.1.1.tgz", + "integrity": "sha512-N97NGMsB/7CHShbc5ube4dcsW/bYENkBrg8yWi8ieN9boYVRdw3cZviVryV/Nfu9bKbBV9kUvduFF2qBI7rEqg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-arm-gnueabihf": { + "version": "4.1.1", + "resolved": 
"https://registry.npmjs.org/@tailwindcss/oxide-linux-arm-gnueabihf/-/oxide-linux-arm-gnueabihf-4.1.1.tgz", + "integrity": "sha512-33Lk6KbHnUZbXqza6RWNFo9wqPQ4+H5BAn1CkUUfC1RZ1vYbyDN6+iJPj53wmnWJ3mhRI8jWt3Jt1fO02IVdUQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-arm64-gnu": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-gnu/-/oxide-linux-arm64-gnu-4.1.1.tgz", + "integrity": "sha512-LyW35RzSUy+80WYScv03HKasAUmMFDaSbNpWfk1gG5gEE9kuRGnDzSrqMoLAmY/kzMCYP/1kqmUiAx8EFLkI2A==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-arm64-musl": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-musl/-/oxide-linux-arm64-musl-4.1.1.tgz", + "integrity": "sha512-1KPnDMlHdqjPTUSFjx55pafvs8RZXRgxfeRgUrukwDKkuj7gFk28vW3Mx65YdiugAc9NWs3VgueZWaM1Po6uGw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-x64-gnu": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-gnu/-/oxide-linux-x64-gnu-4.1.1.tgz", + "integrity": "sha512-4WdzA+MRlsinEEE6yxNMLJxpw0kE9XVipbAKdTL8BeUpyC2TdA3TL46lBulXzKp3BIxh3nqyR/UCqzl5o+3waQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-x64-musl": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-musl/-/oxide-linux-x64-musl-4.1.1.tgz", + "integrity": "sha512-q7Ugbw3ARcjCW2VMUYrcMbJ6aMQuWPArBBE2EqC/swPZTdGADvMQSlvR0VKusUM4HoSsO7ZbvcZ53YwR57+AKw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-win32-arm64-msvc": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.1.tgz", + "integrity": "sha512-0KpqsovgHcIzm7eAGzzEZsEs0/nPYXnRBv+aPq/GehpNQuE/NAQu+YgZXIIof+VflDFuyXOEnaFr7T5MZ1INhA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-win32-x64-msvc": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-x64-msvc/-/oxide-win32-x64-msvc-4.1.1.tgz", + "integrity": "sha512-B1mjeXNS26kBOHv5sXARf6Wd0PWHV9x1TDlW0ummrBUOUAxAy5wcy4Nii1wzNvCdvC448hgiL06ylhwAbNthmg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/postcss": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@tailwindcss/postcss/-/postcss-4.1.1.tgz", + "integrity": "sha512-GX9AEM+msH0i2Yh1b6CuDRaZRo3kmbvIrLbSfvJ53C3uaAgsQ//fTQAh9HMQ6t1a9zvoUptlYqG//plWsBQTCw==", + "license": "MIT", + "dependencies": { + "@alloc/quick-lru": "^5.2.0", + "@tailwindcss/node": "4.1.1", + "@tailwindcss/oxide": "4.1.1", + "postcss": "^8.4.41", + "tailwindcss": "4.1.1" + } + }, + "node_modules/@tailwindcss/vite": { + "version": "4.1.1", + "resolved": 
"https://registry.npmjs.org/@tailwindcss/vite/-/vite-4.1.1.tgz", + "integrity": "sha512-tFTkRZwXq4XKr3S2dUZBxy80wbWYHdDSsu4QOB1yE1HJFKjfxKVpXtup4dyTVdQcLInoHC9lZXFPHnjoBP774g==", + "license": "MIT", + "dependencies": { + "@tailwindcss/node": "4.1.1", + "@tailwindcss/oxide": "4.1.1", + "tailwindcss": "4.1.1" + }, + "peerDependencies": { + "vite": "^5.2.0 || ^6" + } + }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", @@ -1815,22 +2035,11 @@ "url": "https://github.com/sponsors/epoberezkin" } }, - "node_modules/ansi-regex": { - "version": "6.1.0", - "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.1.0.tgz", - "integrity": "sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA==", - "license": "MIT", - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/chalk/ansi-regex?sponsor=1" - } - }, "node_modules/ansi-styles": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, "license": "MIT", "dependencies": { "color-convert": "^2.0.1" @@ -1842,31 +2051,6 @@ "url": "https://github.com/chalk/ansi-styles?sponsor=1" } }, - "node_modules/any-promise": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/any-promise/-/any-promise-1.3.0.tgz", - "integrity": "sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==", - "license": "MIT" - }, - "node_modules/anymatch": { - "version": "3.1.3", - "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", - "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", - "license": "ISC", - "dependencies": { - "normalize-path": "^3.0.0", - "picomatch": "^2.0.4" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/arg": { - "version": "5.0.2", - "resolved": "https://registry.npmjs.org/arg/-/arg-5.0.2.tgz", - "integrity": "sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg==", - "license": "MIT" - }, "node_modules/argparse": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", @@ -1925,20 +2109,9 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "dev": true, "license": "MIT" }, - "node_modules/binary-extensions": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", - "integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==", - "license": "MIT", - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/brace-expansion": { "version": "1.1.11", "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", @@ -2011,15 +2184,6 @@ "node": ">=6" } }, - "node_modules/camelcase-css": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/camelcase-css/-/camelcase-css-2.0.1.tgz", - "integrity": "sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==", - "license": "MIT", - "engines": { - "node": ">= 
6" - } - }, "node_modules/caniuse-lite": { "version": "1.0.30001697", "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001697.tgz", @@ -2107,46 +2271,11 @@ "url": "https://github.com/sponsors/wooorm" } }, - "node_modules/chokidar": { - "version": "3.6.0", - "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", - "integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==", - "license": "MIT", - "dependencies": { - "anymatch": "~3.1.2", - "braces": "~3.0.2", - "glob-parent": "~5.1.2", - "is-binary-path": "~2.1.0", - "is-glob": "~4.0.1", - "normalize-path": "~3.0.0", - "readdirp": "~3.6.0" - }, - "engines": { - "node": ">= 8.10.0" - }, - "funding": { - "url": "https://paulmillr.com/funding/" - }, - "optionalDependencies": { - "fsevents": "~2.3.2" - } - }, - "node_modules/chokidar/node_modules/glob-parent": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", - "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", - "license": "ISC", - "dependencies": { - "is-glob": "^4.0.1" - }, - "engines": { - "node": ">= 6" - } - }, "node_modules/color-convert": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, "license": "MIT", "dependencies": { "color-name": "~1.1.4" @@ -2159,6 +2288,7 @@ "version": "1.1.4", "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true, "license": "MIT" }, "node_modules/colorjs.io": { @@ -2214,6 +2344,7 @@ "version": "7.0.6", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "dev": true, "license": "MIT", "dependencies": { "path-key": "^3.1.0", @@ -2224,60 +2355,19 @@ "node": ">= 8" } }, - "node_modules/css-selector-tokenizer": { - "version": "0.8.0", - "resolved": "https://registry.npmjs.org/css-selector-tokenizer/-/css-selector-tokenizer-0.8.0.tgz", - "integrity": "sha512-Jd6Ig3/pe62/qe5SBPTN8h8LeUg/pT4lLgtavPf7updwwHpvFzxvOQBHYj2LZDMjUnBzgvIUSjRcf6oT5HzHFg==", - "license": "MIT", - "dependencies": { - "cssesc": "^3.0.0", - "fastparse": "^1.1.2" - } - }, - "node_modules/cssesc": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", - "integrity": "sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg==", - "license": "MIT", - "bin": { - "cssesc": "bin/cssesc" - }, - "engines": { - "node": ">=4" - } - }, "node_modules/csstype": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", "integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==", "license": "MIT" }, - "node_modules/culori": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/culori/-/culori-3.3.0.tgz", - "integrity": "sha512-pHJg+jbuFsCjz9iclQBqyL3B2HLCBF71BwVNujUYEvCeQMvV97R59MNK3R2+jgJ3a1fcZgI9B3vYgz8lzr/BFQ==", - "license": "MIT", - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - } - }, "node_modules/daisyui": { - "version": "4.12.23", - "resolved": 
"https://registry.npmjs.org/daisyui/-/daisyui-4.12.23.tgz", - "integrity": "sha512-EM38duvxutJ5PD65lO/AFMpcw+9qEy6XAZrTpzp7WyaPeO/l+F/Qiq0ECHHmFNcFXh5aVoALY4MGrrxtCiaQCQ==", + "version": "5.0.12", + "resolved": "https://registry.npmjs.org/daisyui/-/daisyui-5.0.12.tgz", + "integrity": "sha512-01DU0eYBcHgPtuf5fxcrkGkIN6/Uyaqmkle5Yo3ZyW9YVAu036ALZbjv2KH5euvUbeQ4r9q3gAarGcf7Tywhng==", "license": "MIT", - "dependencies": { - "css-selector-tokenizer": "^0.8", - "culori": "^3", - "picocolors": "^1", - "postcss-js": "^4" - }, - "engines": { - "node": ">=16.9.0" - }, "funding": { - "type": "opencollective", - "url": "https://opencollective.com/daisyui" + "url": "https://github.com/saadeghi/daisyui?sponsor=1" } }, "node_modules/debug": { @@ -2326,6 +2416,15 @@ "node": ">=6" } }, + "node_modules/detect-libc": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.3.tgz", + "integrity": "sha512-bwy0MGW55bG41VqxxypOsdSdGqLwXPI/focwgTYCFMbdUiBAxLg9CFzG08sz2aqzknwiX7Hkl0bQENjg8iLByw==", + "license": "Apache-2.0", + "engines": { + "node": ">=8" + } + }, "node_modules/devlop": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/devlop/-/devlop-1.1.0.tgz", @@ -2345,35 +2444,24 @@ "integrity": "sha512-SOKO002EqlvBYYKQSew3iymBoN2EQ4BDw/3yprjh7kAfFzjBYkaMNa/pZvcA7HSWlcKSQb9XhPe3wKyQ0x4A8A==", "license": "Apache-2.0" }, - "node_modules/didyoumean": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", - "integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==", - "license": "Apache-2.0" - }, - "node_modules/dlv": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/dlv/-/dlv-1.1.3.tgz", - "integrity": "sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA==", - "license": "MIT" - }, - "node_modules/eastasianwidth": { - "version": "0.2.0", - "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz", - "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==", - "license": "MIT" - }, "node_modules/electron-to-chromium": { "version": "1.5.91", "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.91.tgz", "integrity": "sha512-sNSHHyq048PFmZY4S90ax61q+gLCs0X0YmcOII9wG9S2XwbVr+h4VW2wWhnbp/Eys3cCwTxVF292W3qPaxIapQ==", "license": "ISC" }, - "node_modules/emoji-regex": { - "version": "9.2.2", - "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", - "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==", - "license": "MIT" + "node_modules/enhanced-resolve": { + "version": "5.18.1", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.1.tgz", + "integrity": "sha512-ZSW3ma5GkcQBIpwZTSRAI8N71Uuwgs93IezB7mf7R60tC8ZbJideoDNKjHn2O9KIlx6rkGTTEk1xUCK2E1Y2Yg==", + "license": "MIT", + "dependencies": { + "graceful-fs": "^4.2.4", + "tapable": "^2.2.0" + }, + "engines": { + "node": ">=10.13.0" + } }, "node_modules/entities": { "version": "4.5.0", @@ -2653,6 +2741,7 @@ "version": "3.3.3", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", + "dev": true, "license": "MIT", "dependencies": { "@nodelib/fs.stat": "^2.0.2", @@ -2669,6 +2758,7 @@ "version": "5.1.2", "resolved": 
"https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, "license": "ISC", "dependencies": { "is-glob": "^4.0.1" @@ -2691,16 +2781,11 @@ "dev": true, "license": "MIT" }, - "node_modules/fastparse": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/fastparse/-/fastparse-1.1.2.tgz", - "integrity": "sha512-483XLLxTVIwWK3QTrMGRqUfUpoOs/0hbQrl2oz4J0pAcm3A3bu84wxTFqGqkJzewCLdME38xJLJAxBABfQT8sQ==", - "license": "MIT" - }, "node_modules/fastq": { "version": "1.19.0", "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.0.tgz", "integrity": "sha512-7SFSRCNjBQIZH/xZR3iy5iQYR8aGBE0h3VG6/cwlbrpdciNYBMotQav8c1XI3HjHH+NikUpP53nPdlZSdWmFzA==", + "dev": true, "license": "ISC", "dependencies": { "reusify": "^1.0.4" @@ -2769,22 +2854,6 @@ "dev": true, "license": "ISC" }, - "node_modules/foreground-child": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.3.0.tgz", - "integrity": "sha512-Ld2g8rrAyMYFXBhEqMz8ZAHBi4J4uS1i/CxGMDnjyFWddMXLVcDp051DZfu+t7+ab7Wv6SMqpWmyFIj5UbfFvg==", - "license": "ISC", - "dependencies": { - "cross-spawn": "^7.0.0", - "signal-exit": "^4.0.1" - }, - "engines": { - "node": ">=14" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/fraction.js": { "version": "4.3.7", "resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.3.7.tgz", @@ -2812,15 +2881,6 @@ "node": "^8.16.0 || ^10.6.0 || >=11.0.0" } }, - "node_modules/function-bind": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", - "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", - "license": "MIT", - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/gensync": { "version": "1.0.0-beta.2", "resolved": "https://registry.npmjs.org/gensync/-/gensync-1.0.0-beta.2.tgz", @@ -2831,30 +2891,11 @@ "node": ">=6.9.0" } }, - "node_modules/glob": { - "version": "10.4.5", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", - "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "license": "ISC", - "dependencies": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "bin": { - "glob": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "dev": true, "license": "ISC", "dependencies": { "is-glob": "^4.0.3" @@ -2863,30 +2904,6 @@ "node": ">=10.13.0" } }, - "node_modules/glob/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", - "license": "MIT", - "dependencies": { - "balanced-match": "^1.0.0" - } - }, - "node_modules/glob/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": 
"sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.1" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/globals": { "version": "15.14.0", "resolved": "https://registry.npmjs.org/globals/-/globals-15.14.0.tgz", @@ -2900,6 +2917,12 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "license": "ISC" + }, "node_modules/graphemer": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", @@ -2917,18 +2940,6 @@ "node": ">=8" } }, - "node_modules/hasown": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", - "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", - "license": "MIT", - "dependencies": { - "function-bind": "^1.1.2" - }, - "engines": { - "node": ">= 0.4" - } - }, "node_modules/hast-util-from-dom": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/hast-util-from-dom/-/hast-util-from-dom-5.0.1.tgz", @@ -3190,33 +3201,6 @@ "url": "https://github.com/sponsors/wooorm" } }, - "node_modules/is-binary-path": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", - "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", - "license": "MIT", - "dependencies": { - "binary-extensions": "^2.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/is-core-module": { - "version": "2.16.1", - "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.16.1.tgz", - "integrity": "sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==", - "license": "MIT", - "dependencies": { - "hasown": "^2.0.2" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/is-decimal": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/is-decimal/-/is-decimal-2.0.1.tgz", @@ -3231,24 +3215,17 @@ "version": "2.1.1", "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, "license": "MIT", "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-fullwidth-code-point": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", - "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", - "license": "MIT", - "engines": { - "node": ">=8" + "node": ">=0.10.0" } }, "node_modules/is-glob": { "version": "4.0.3", "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, "license": "MIT", "dependencies": { "is-extglob": "^2.1.1" @@ -3292,30 +3269,16 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", "integrity": 
"sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "dev": true, "license": "ISC" }, - "node_modules/jackspeak": { - "version": "3.4.3", - "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", - "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", - "license": "BlueOak-1.0.0", - "dependencies": { - "@isaacs/cliui": "^8.0.2" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - }, - "optionalDependencies": { - "@pkgjs/parseargs": "^0.11.0" - } - }, "node_modules/jiti": { - "version": "1.21.7", - "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz", - "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/jiti/-/jiti-2.4.2.tgz", + "integrity": "sha512-rg9zJN+G4n2nfJl5MW3BMygZX56zKPNVEYYqq7adpmMh4Jn2QNEwhvQlFy6jPVdcod7txZtKHWnyZiA3a0zP7A==", "license": "MIT", "bin": { - "jiti": "bin/jiti.js" + "jiti": "lib/jiti-cli.mjs" } }, "node_modules/js-tokens": { @@ -3424,23 +3387,233 @@ "node": ">= 0.8.0" } }, - "node_modules/lilconfig": { - "version": "3.1.3", - "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.3.tgz", - "integrity": "sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==", - "license": "MIT", + "node_modules/lightningcss": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss/-/lightningcss-1.29.2.tgz", + "integrity": "sha512-6b6gd/RUXKaw5keVdSEtqFVdzWnU5jMxTUjA2bVcMNPLwSQ08Sv/UodBVtETLCn7k4S1Ibxwh7k68IwLZPgKaA==", + "license": "MPL-2.0", + "dependencies": { + "detect-libc": "^2.0.3" + }, "engines": { - "node": ">=14" + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + }, + "optionalDependencies": { + "lightningcss-darwin-arm64": "1.29.2", + "lightningcss-darwin-x64": "1.29.2", + "lightningcss-freebsd-x64": "1.29.2", + "lightningcss-linux-arm-gnueabihf": "1.29.2", + "lightningcss-linux-arm64-gnu": "1.29.2", + "lightningcss-linux-arm64-musl": "1.29.2", + "lightningcss-linux-x64-gnu": "1.29.2", + "lightningcss-linux-x64-musl": "1.29.2", + "lightningcss-win32-arm64-msvc": "1.29.2", + "lightningcss-win32-x64-msvc": "1.29.2" + } + }, + "node_modules/lightningcss-darwin-arm64": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-arm64/-/lightningcss-darwin-arm64-1.29.2.tgz", + "integrity": "sha512-cK/eMabSViKn/PG8U/a7aCorpeKLMlK0bQeNHmdb7qUnBkNPnL+oV5DjJUo0kqWsJUapZsM4jCfYItbqBDvlcA==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 12.0.0" }, "funding": { - "url": "https://github.com/sponsors/antonk52" + "type": "opencollective", + "url": "https://opencollective.com/parcel" } }, - "node_modules/lines-and-columns": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz", - "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==", - "license": "MIT" + "node_modules/lightningcss-darwin-x64": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-x64/-/lightningcss-darwin-x64-1.29.2.tgz", + "integrity": "sha512-j5qYxamyQw4kDXX5hnnCKMf3mLlHvG44f24Qyi2965/Ycz829MYqjrVg2H8BidybHBp9kom4D7DR5VqCKDXS0w==", + "cpu": [ + "x64" + ], + 
"license": "MPL-2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-freebsd-x64": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-freebsd-x64/-/lightningcss-freebsd-x64-1.29.2.tgz", + "integrity": "sha512-wDk7M2tM78Ii8ek9YjnY8MjV5f5JN2qNVO+/0BAGZRvXKtQrBC4/cn4ssQIpKIPP44YXw6gFdpUF+Ps+RGsCwg==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-arm-gnueabihf": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm-gnueabihf/-/lightningcss-linux-arm-gnueabihf-1.29.2.tgz", + "integrity": "sha512-IRUrOrAF2Z+KExdExe3Rz7NSTuuJ2HvCGlMKoquK5pjvo2JY4Rybr+NrKnq0U0hZnx5AnGsuFHjGnNT14w26sg==", + "cpu": [ + "arm" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-arm64-gnu": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-gnu/-/lightningcss-linux-arm64-gnu-1.29.2.tgz", + "integrity": "sha512-KKCpOlmhdjvUTX/mBuaKemp0oeDIBBLFiU5Fnqxh1/DZ4JPZi4evEH7TKoSBFOSOV3J7iEmmBaw/8dpiUvRKlQ==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-arm64-musl": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-musl/-/lightningcss-linux-arm64-musl-1.29.2.tgz", + "integrity": "sha512-Q64eM1bPlOOUgxFmoPUefqzY1yV3ctFPE6d/Vt7WzLW4rKTv7MyYNky+FWxRpLkNASTnKQUaiMJ87zNODIrrKQ==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-x64-gnu": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-gnu/-/lightningcss-linux-x64-gnu-1.29.2.tgz", + "integrity": "sha512-0v6idDCPG6epLXtBH/RPkHvYx74CVziHo6TMYga8O2EiQApnUPZsbR9nFNrg2cgBzk1AYqEd95TlrsL7nYABQg==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-x64-musl": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-musl/-/lightningcss-linux-x64-musl-1.29.2.tgz", + "integrity": "sha512-rMpz2yawkgGT8RULc5S4WiZopVMOFWjiItBT7aSfDX4NQav6M44rhn5hjtkKzB+wMTRlLLqxkeYEtQ3dd9696w==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-win32-arm64-msvc": { + "version": "1.29.2", + "resolved": 
"https://registry.npmjs.org/lightningcss-win32-arm64-msvc/-/lightningcss-win32-arm64-msvc-1.29.2.tgz", + "integrity": "sha512-nL7zRW6evGQqYVu/bKGK+zShyz8OVzsCotFgc7judbt6wnB2KbiKKJwBE4SGoDBQ1O94RjW4asrCjQL4i8Fhbw==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-win32-x64-msvc": { + "version": "1.29.2", + "resolved": "https://registry.npmjs.org/lightningcss-win32-x64-msvc/-/lightningcss-win32-x64-msvc-1.29.2.tgz", + "integrity": "sha512-EdIUW3B2vLuHmv7urfzMI/h2fmlnOQBk1xlsDxkN1tCWKjNFjfLhGxYk8C8mzpSfr+A6jFFIi8fU6LbQGsRWjA==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } }, "node_modules/locate-path": { "version": "6.0.0", @@ -3841,6 +4014,7 @@ "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "dev": true, "license": "MIT", "engines": { "node": ">= 8" @@ -4454,32 +4628,12 @@ "node": "*" } }, - "node_modules/minipass": { - "version": "7.1.2", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz", - "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==", - "license": "ISC", - "engines": { - "node": ">=16 || 14 >=14.17" - } - }, "node_modules/ms": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "license": "MIT" }, - "node_modules/mz": { - "version": "2.7.0", - "resolved": "https://registry.npmjs.org/mz/-/mz-2.7.0.tgz", - "integrity": "sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==", - "license": "MIT", - "dependencies": { - "any-promise": "^1.0.0", - "object-assign": "^4.0.1", - "thenify-all": "^1.0.0" - } - }, "node_modules/nanoid": { "version": "3.3.8", "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.8.tgz", @@ -4511,15 +4665,6 @@ "integrity": "sha512-xxOWJsBKtzAq7DY0J+DTzuz58K8e7sJbdgwkbMWQe8UYB6ekmsQ45q0M/tJDsGaZmbC+l7n57UV8Hl5tHxO9uw==", "license": "MIT" }, - "node_modules/normalize-path": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", - "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/normalize-range": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/normalize-range/-/normalize-range-0.1.2.tgz", @@ -4529,24 +4674,6 @@ "node": ">=0.10.0" } }, - "node_modules/object-assign": { - "version": "4.1.1", - "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", - "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/object-hash": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-3.0.0.tgz", - "integrity": 
"sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw==", - "license": "MIT", - "engines": { - "node": ">= 6" - } - }, "node_modules/optionator": { "version": "0.9.4", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", @@ -4597,12 +4724,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/package-json-from-dist": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", - "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==", - "license": "BlueOak-1.0.0" - }, "node_modules/parent-module": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", @@ -4667,182 +4788,42 @@ "version": "3.1.1", "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true, "license": "MIT", "engines": { "node": ">=8" } }, - "node_modules/path-parse": { - "version": "1.0.7", - "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", - "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", - "license": "MIT" - }, - "node_modules/path-scurry": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", - "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", - "license": "BlueOak-1.0.0", - "dependencies": { - "lru-cache": "^10.2.0", - "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" - }, - "engines": { - "node": ">=16 || 14 >=14.18" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/path-scurry/node_modules/lru-cache": { - "version": "10.4.3", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", - "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==", - "license": "ISC" - }, "node_modules/picocolors": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", - "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", - "license": "ISC" - }, - "node_modules/picomatch": { - "version": "2.3.1", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", - "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", - "license": "MIT", - "engines": { - "node": ">=8.6" - }, - "funding": { - "url": "https://github.com/sponsors/jonschlinkert" - } - }, - "node_modules/pify": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", - "integrity": "sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog==", - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/pirates": { - "version": "4.0.6", - "resolved": "https://registry.npmjs.org/pirates/-/pirates-4.0.6.tgz", - "integrity": "sha512-saLsH7WeYYPiD25LDuLRRY/i+6HaPYr6G1OUlN39otzkSTxKnubR9RTxS3/Kk50s1g2JTgFwWQDQyplC5/SHZg==", - "license": "MIT", - "engines": { - "node": ">= 6" - } - }, - "node_modules/postcss": { - "version": "8.5.1", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.1.tgz", - "integrity": 
"sha512-6oz2beyjc5VMn/KV1pPw8fliQkhBXrVn1Z3TVyqZxU8kZpzEKhBdmCFqI6ZbmGtamQvQGuU1sgPTk8ZrXDD7jQ==", - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/postcss/" - }, - { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/postcss" - }, - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "dependencies": { - "nanoid": "^3.3.8", - "picocolors": "^1.1.1", - "source-map-js": "^1.2.1" - }, - "engines": { - "node": "^10 || ^12 || >=14" - } - }, - "node_modules/postcss-import": { - "version": "15.1.0", - "resolved": "https://registry.npmjs.org/postcss-import/-/postcss-import-15.1.0.tgz", - "integrity": "sha512-hpr+J05B2FVYUAXHeK1YyI267J/dDDhMU6B6civm8hSY1jYJnBXxzKDKDswzJmtLHryrjhnDjqqp/49t8FALew==", - "license": "MIT", - "dependencies": { - "postcss-value-parser": "^4.0.0", - "read-cache": "^1.0.0", - "resolve": "^1.1.7" - }, - "engines": { - "node": ">=14.0.0" - }, - "peerDependencies": { - "postcss": "^8.0.0" - } + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", + "license": "ISC" }, - "node_modules/postcss-js": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/postcss-js/-/postcss-js-4.0.1.tgz", - "integrity": "sha512-dDLF8pEO191hJMtlHFPRa8xsizHaM82MLfNkUHdUtVEV3tgTp5oj+8qbEqYM57SLfc74KSbw//4SeJma2LRVIw==", + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", "license": "MIT", - "dependencies": { - "camelcase-css": "^2.0.1" - }, "engines": { - "node": "^12 || ^14 || >= 16" + "node": ">=8.6" }, "funding": { - "type": "opencollective", - "url": "https://opencollective.com/postcss/" - }, - "peerDependencies": { - "postcss": "^8.4.21" + "url": "https://github.com/sponsors/jonschlinkert" } }, - "node_modules/postcss-load-config": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/postcss-load-config/-/postcss-load-config-4.0.2.tgz", - "integrity": "sha512-bSVhyJGL00wMVoPUzAVAnbEoWyqRxkjv64tUl427SKnPrENtq6hJwUojroMz2VB+Q1edmi4IfrAPpami5VVgMQ==", + "node_modules/postcss": { + "version": "8.5.1", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.1.tgz", + "integrity": "sha512-6oz2beyjc5VMn/KV1pPw8fliQkhBXrVn1Z3TVyqZxU8kZpzEKhBdmCFqI6ZbmGtamQvQGuU1sgPTk8ZrXDD7jQ==", "funding": [ { "type": "opencollective", "url": "https://opencollective.com/postcss/" }, { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "dependencies": { - "lilconfig": "^3.0.0", - "yaml": "^2.3.4" - }, - "engines": { - "node": ">= 14" - }, - "peerDependencies": { - "postcss": ">=8.0.9", - "ts-node": ">=9.0.0" - }, - "peerDependenciesMeta": { - "postcss": { - "optional": true - }, - "ts-node": { - "optional": true - } - } - }, - "node_modules/postcss-nested": { - "version": "6.2.0", - "resolved": "https://registry.npmjs.org/postcss-nested/-/postcss-nested-6.2.0.tgz", - "integrity": "sha512-HQbt28KulC5AJzG+cZtj9kvKB93CFCdLvog1WFLf1D+xmMvPGlBstkpTEZfK5+AN9hfJocyBFCNiqyS48bpgzQ==", - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/postcss/" + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" }, { "type": "github", @@ -4851,26 +4832,12 @@ ], 
"license": "MIT", "dependencies": { - "postcss-selector-parser": "^6.1.1" - }, - "engines": { - "node": ">=12.0" - }, - "peerDependencies": { - "postcss": "^8.2.14" - } - }, - "node_modules/postcss-selector-parser": { - "version": "6.1.2", - "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.1.2.tgz", - "integrity": "sha512-Q8qQfPiZ+THO/3ZrOrO0cJJKfpYCagtMUkXbnEfmgUjwXg6z/WBeOyS9APBBPCTSiDV+s4SwQGu8yFsiMRIudg==", - "license": "MIT", - "dependencies": { - "cssesc": "^3.0.0", - "util-deprecate": "^1.0.2" + "nanoid": "^3.3.8", + "picocolors": "^1.1.1", + "source-map-js": "^1.2.1" }, "engines": { - "node": ">=4" + "node": "^10 || ^12 || >=14" } }, "node_modules/postcss-value-parser": { @@ -4929,6 +4896,7 @@ "version": "1.2.3", "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "dev": true, "funding": [ { "type": "github", @@ -5030,27 +4998,6 @@ } } }, - "node_modules/read-cache": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", - "integrity": "sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA==", - "license": "MIT", - "dependencies": { - "pify": "^2.3.0" - } - }, - "node_modules/readdirp": { - "version": "3.6.0", - "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", - "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", - "license": "MIT", - "dependencies": { - "picomatch": "^2.2.1" - }, - "engines": { - "node": ">=8.10.0" - } - }, "node_modules/rehype-highlight": { "version": "7.0.2", "resolved": "https://registry.npmjs.org/rehype-highlight/-/rehype-highlight-7.0.2.tgz", @@ -5184,26 +5131,6 @@ "url": "https://opencollective.com/unified" } }, - "node_modules/resolve": { - "version": "1.22.10", - "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.10.tgz", - "integrity": "sha512-NPRy+/ncIMeDlTAsuqwKIiferiawhefFJtkNSW0qZJEqMEb+qBt/77B/jGeeek+F0uOeN05CDa6HXbbIgtVX4w==", - "license": "MIT", - "dependencies": { - "is-core-module": "^2.16.0", - "path-parse": "^1.0.7", - "supports-preserve-symlinks-flag": "^1.0.0" - }, - "bin": { - "resolve": "bin/resolve" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/resolve-from": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", @@ -5218,6 +5145,7 @@ "version": "1.0.4", "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "dev": true, "license": "MIT", "engines": { "iojs": ">=1.0.0", @@ -5266,6 +5194,7 @@ "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "dev": true, "funding": [ { "type": "github", @@ -5705,6 +5634,7 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, "license": "MIT", "dependencies": { "shebang-regex": "^3.0.0" @@ -5717,23 +5647,12 @@ "version": "3.0.0", "resolved": 
"https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true, "license": "MIT", "engines": { "node": ">=8" } }, - "node_modules/signal-exit": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", - "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", - "license": "ISC", - "engines": { - "node": ">=14" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/source-map-js": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", @@ -5753,65 +5672,6 @@ "url": "https://github.com/sponsors/wooorm" } }, - "node_modules/string-width": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz", - "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==", - "license": "MIT", - "dependencies": { - "eastasianwidth": "^0.2.0", - "emoji-regex": "^9.2.2", - "strip-ansi": "^7.0.1" - }, - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/string-width-cjs": { - "name": "string-width", - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", - "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", - "license": "MIT", - "dependencies": { - "emoji-regex": "^8.0.0", - "is-fullwidth-code-point": "^3.0.0", - "strip-ansi": "^6.0.1" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/string-width-cjs/node_modules/ansi-regex": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", - "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/string-width-cjs/node_modules/emoji-regex": { - "version": "8.0.0", - "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", - "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", - "license": "MIT" - }, - "node_modules/string-width-cjs/node_modules/strip-ansi": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", - "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", - "license": "MIT", - "dependencies": { - "ansi-regex": "^5.0.1" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/stringify-entities": { "version": "4.0.4", "resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.4.tgz", @@ -5826,43 +5686,6 @@ "url": "https://github.com/sponsors/wooorm" } }, - "node_modules/strip-ansi": { - "version": "7.1.0", - "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz", - "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==", - "license": "MIT", - "dependencies": { - "ansi-regex": "^6.0.1" - }, - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/chalk/strip-ansi?sponsor=1" - } - }, - "node_modules/strip-ansi-cjs": { - "name": "strip-ansi", - "version": "6.0.1", - "resolved": 
"https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", - "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", - "license": "MIT", - "dependencies": { - "ansi-regex": "^5.0.1" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/strip-ansi-cjs/node_modules/ansi-regex": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", - "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", - "license": "MIT", - "engines": { - "node": ">=8" - } - }, "node_modules/strip-json-comments": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", @@ -5885,37 +5708,6 @@ "inline-style-parser": "0.2.4" } }, - "node_modules/sucrase": { - "version": "3.35.0", - "resolved": "https://registry.npmjs.org/sucrase/-/sucrase-3.35.0.tgz", - "integrity": "sha512-8EbVDiu9iN/nESwxeSxDKe0dunta1GOlHufmSSXxMD2z2/tMZpDMpvXQGsc+ajGo8y2uYUmixaSRUc/QPoQ0GA==", - "license": "MIT", - "dependencies": { - "@jridgewell/gen-mapping": "^0.3.2", - "commander": "^4.0.0", - "glob": "^10.3.10", - "lines-and-columns": "^1.1.6", - "mz": "^2.7.0", - "pirates": "^4.0.1", - "ts-interface-checker": "^0.1.9" - }, - "bin": { - "sucrase": "bin/sucrase", - "sucrase-node": "bin/sucrase-node" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - } - }, - "node_modules/sucrase/node_modules/commander": { - "version": "4.1.1", - "resolved": "https://registry.npmjs.org/commander/-/commander-4.1.1.tgz", - "integrity": "sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==", - "license": "MIT", - "engines": { - "node": ">= 6" - } - }, "node_modules/supports-color": { "version": "7.2.0", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", @@ -5929,18 +5721,6 @@ "node": ">=8" } }, - "node_modules/supports-preserve-symlinks-flag": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", - "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", - "license": "MIT", - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/sync-child-process": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/sync-child-process/-/sync-child-process-1.0.2.tgz", @@ -5965,40 +5745,18 @@ } }, "node_modules/tailwindcss": { - "version": "3.4.17", - "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.17.tgz", - "integrity": "sha512-w33E2aCvSDP0tW9RZuNXadXlkHXqFzSkQew/aIa2i/Sj8fThxwovwlXHSPXTbAHwEIhBFXAedUhP2tueAKP8Og==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.1.tgz", + "integrity": "sha512-QNbdmeS979Efzim2g/bEvfuh+fTcIdp1y7gA+sb6OYSW74rt7Cr7M78AKdf6HqWT3d5AiTb7SwTT3sLQxr4/qw==", + "license": "MIT" + }, + "node_modules/tapable": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", + "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", "license": "MIT", - "dependencies": { - "@alloc/quick-lru": "^5.2.0", - "arg": "^5.0.2", - "chokidar": "^3.6.0", - "didyoumean": "^1.2.2", - "dlv": "^1.1.3", - "fast-glob": "^3.3.2", - "glob-parent": "^6.0.2", - "is-glob": "^4.0.3", - "jiti": "^1.21.6", - 
"lilconfig": "^3.1.3", - "micromatch": "^4.0.8", - "normalize-path": "^3.0.0", - "object-hash": "^3.0.0", - "picocolors": "^1.1.1", - "postcss": "^8.4.47", - "postcss-import": "^15.1.0", - "postcss-js": "^4.0.1", - "postcss-load-config": "^4.0.2", - "postcss-nested": "^6.2.0", - "postcss-selector-parser": "^6.1.2", - "resolve": "^1.22.8", - "sucrase": "^3.35.0" - }, - "bin": { - "tailwind": "lib/cli.js", - "tailwindcss": "lib/cli.js" - }, "engines": { - "node": ">=14.0.0" + "node": ">=6" } }, "node_modules/textlinestream": { @@ -6007,27 +5765,6 @@ "integrity": "sha512-iBHbi7BQxrFmwZUQJsT0SjNzlLLsXhvW/kg7EyOMVMBIrlnj/qYofwo1LVLZi+3GbUEo96Iu2eqToI2+lZoAEQ==", "license": "MIT" }, - "node_modules/thenify": { - "version": "3.3.1", - "resolved": "https://registry.npmjs.org/thenify/-/thenify-3.3.1.tgz", - "integrity": "sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==", - "license": "MIT", - "dependencies": { - "any-promise": "^1.0.0" - } - }, - "node_modules/thenify-all": { - "version": "1.6.0", - "resolved": "https://registry.npmjs.org/thenify-all/-/thenify-all-1.6.0.tgz", - "integrity": "sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==", - "license": "MIT", - "dependencies": { - "thenify": ">= 3.1.0 < 4" - }, - "engines": { - "node": ">=0.8" - } - }, "node_modules/to-regex-range": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", @@ -6073,12 +5810,6 @@ "typescript": ">=4.8.4" } }, - "node_modules/ts-interface-checker": { - "version": "0.1.13", - "resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz", - "integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==", - "license": "Apache-2.0" - }, "node_modules/tslib": { "version": "2.8.1", "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", @@ -6304,12 +6035,6 @@ "punycode": "^2.1.0" } }, - "node_modules/util-deprecate": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", - "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", - "license": "MIT" - }, "node_modules/varint": { "version": "6.0.0", "resolved": "https://registry.npmjs.org/varint/-/varint-6.0.0.tgz", @@ -6460,6 +6185,7 @@ "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, "license": "ISC", "dependencies": { "isexe": "^2.0.0" @@ -6481,94 +6207,6 @@ "node": ">=0.10.0" } }, - "node_modules/wrap-ansi": { - "version": "8.1.0", - "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz", - "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==", - "license": "MIT", - "dependencies": { - "ansi-styles": "^6.1.0", - "string-width": "^5.0.1", - "strip-ansi": "^7.0.1" - }, - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/chalk/wrap-ansi?sponsor=1" - } - }, - "node_modules/wrap-ansi-cjs": { - "name": "wrap-ansi", - "version": "7.0.0", - "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", - "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", - "license": "MIT", - "dependencies": { - 
"ansi-styles": "^4.0.0", - "string-width": "^4.1.0", - "strip-ansi": "^6.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/wrap-ansi?sponsor=1" - } - }, - "node_modules/wrap-ansi-cjs/node_modules/ansi-regex": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", - "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/wrap-ansi-cjs/node_modules/emoji-regex": { - "version": "8.0.0", - "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", - "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", - "license": "MIT" - }, - "node_modules/wrap-ansi-cjs/node_modules/string-width": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", - "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", - "license": "MIT", - "dependencies": { - "emoji-regex": "^8.0.0", - "is-fullwidth-code-point": "^3.0.0", - "strip-ansi": "^6.0.1" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/wrap-ansi-cjs/node_modules/strip-ansi": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", - "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", - "license": "MIT", - "dependencies": { - "ansi-regex": "^5.0.1" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/wrap-ansi/node_modules/ansi-styles": { - "version": "6.2.1", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.1.tgz", - "integrity": "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==", - "license": "MIT", - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, "node_modules/yallist": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", @@ -6581,6 +6219,8 @@ "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.7.0.tgz", "integrity": "sha512-+hSoy/QHluxmC9kCIJyL/uyFmLmc+e5CFR5Wa+bpIhIj85LVb9ZH2nVnqrHoSvKogwODv0ClqZkmiSSaIH5LTA==", "license": "ISC", + "optional": true, + "peer": true, "bin": { "yaml": "bin.mjs" }, diff --git a/examples/server/webui/package.json b/examples/server/webui/package.json index 8c833d0241e1c..6ac06b1a49bd3 100644 --- a/examples/server/webui/package.json +++ b/examples/server/webui/package.json @@ -13,9 +13,11 @@ "dependencies": { "@heroicons/react": "^2.2.0", "@sec-ant/readable-stream": "^0.6.0", + "@tailwindcss/postcss": "^4.1.1", + "@tailwindcss/vite": "^4.1.1", "@vscode/markdown-it-katex": "^1.1.1", "autoprefixer": "^10.4.20", - "daisyui": "^4.12.14", + "daisyui": "^5.0.12", "dexie": "^4.0.11", "highlight.js": "^11.10.0", "katex": "^0.16.15", @@ -29,7 +31,7 @@ "remark-breaks": "^4.0.0", "remark-gfm": "^4.0.0", "remark-math": "^6.0.0", - "tailwindcss": "^3.4.15", + "tailwindcss": "^4.1.1", "textlinestream": "^1.1.1", "vite-plugin-singlefile": "^2.0.3" }, diff --git a/examples/server/webui/postcss.config.js b/examples/server/webui/postcss.config.js index 2e7af2b7f1a6f..fb05b5692bba7 100644 --- a/examples/server/webui/postcss.config.js +++ b/examples/server/webui/postcss.config.js @@ -1,6 +1,5 @@ export default { plugins: { - tailwindcss: {}, - 
autoprefixer: {}, + "@tailwindcss/postcss": {}, }, } diff --git a/examples/server/webui/src/App.tsx b/examples/server/webui/src/App.tsx index 2ce734682cff0..cc4659e1529fe 100644 --- a/examples/server/webui/src/App.tsx +++ b/examples/server/webui/src/App.tsx @@ -28,7 +28,7 @@ function AppLayout() { <>
diff --git a/examples/server/webui/src/Config.ts b/examples/server/webui/src/Config.ts index 779ed9bf7840c..dd1cc0e100a2e 100644 --- a/examples/server/webui/src/Config.ts +++ b/examples/server/webui/src/Config.ts @@ -1,4 +1,4 @@ -import daisyuiThemes from 'daisyui/src/theming/themes'; +import daisyuiThemes from 'daisyui/theme/object'; import { isNumeric } from './utils/misc'; export const isDev = import.meta.env.MODE === 'development'; diff --git a/examples/server/webui/src/components/Header.tsx b/examples/server/webui/src/components/Header.tsx index cbee394ba2e0c..4c6b291e61bcb 100644 --- a/examples/server/webui/src/components/Header.tsx +++ b/examples/server/webui/src/components/Header.tsx @@ -2,7 +2,7 @@ import { useEffect, useState } from 'react'; import StorageUtils from '../utils/storage'; import { useAppContext } from '../utils/app.context'; import { classNames } from '../utils/misc'; -import daisyuiThemes from 'daisyui/src/theming/themes'; +import daisyuiThemes from 'daisyui/theme/object'; import { THEMES } from '../Config'; import { useNavigate } from 'react-router'; @@ -20,7 +20,6 @@ export default function Header() { document.body.setAttribute('data-theme', selectedTheme); document.body.setAttribute( 'data-color-scheme', - // @ts-expect-error daisyuiThemes complains about index type, but it should work daisyuiThemes[selectedTheme]?.['color-scheme'] ?? 'auto' ); }, [selectedTheme]); diff --git a/examples/server/webui/src/index.scss b/examples/server/webui/src/index.scss index 314b6823222da..a18f094542eb2 100644 --- a/examples/server/webui/src/index.scss +++ b/examples/server/webui/src/index.scss @@ -1,8 +1,13 @@ @use 'sass:meta'; +@use 'tailwindcss'; -@tailwind base; -@tailwind components; -@tailwind utilities; +@plugin 'daisyui' { + themes: all; +} + +html { + scrollbar-gutter: auto; +} .markdown { h1, diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a5d2bc9d09de7..0783ed4a4c43e 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -24,7 +24,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.speculative.model.empty()) { + if (params.speculative.model.path.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 627d01bbcb5ad..561c308830351 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -46,7 +46,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.speculative.model.empty()) { + if (params.speculative.model.path.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 4cc42e1674ccc..0f047986965f8 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -577,12 +577,7 @@ int main(int argc, char ** argv) { const llama_vocab * vocab = llama_model_get_vocab(model_ttc); - // TODO: refactor in a common struct - params.model = params.vocoder.model; - params.model_url = params.vocoder.model_url; - params.hf_repo = params.vocoder.hf_repo; - params.hf_file = params.vocoder.hf_file; - + params.model = params.vocoder.model; params.embedding = true; common_init_result llama_init_cts = common_init_from_params(params); @@ -699,11 +694,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 const std::string voice_data = audio_data; auto tmp = 
common_tokenize(vocab, voice_data, false, true); - printf("\n\n"); + + std::ostringstream tokens_oss; for (size_t i = 0; i < tmp.size(); ++i) { - printf("%d, ", tmp[i]); + tokens_oss << tmp[i] << ", "; } - printf("\n\n"); + LOG_INF("\n\n%s: llama tokens: %s\n\n", __func__, tokens_oss.str().c_str()); + prompt_add(prompt_inp, tmp); #else prompt_add(prompt_inp, llama_tokens { diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 1e180aa8cb0e4..e7c29bece7cc7 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -100,6 +100,10 @@ else() set(INS_ENB ON) endif() +message(DEBUG "GGML_NATIVE : ${GGML_NATIVE}") +message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}") +message(DEBUG "INS_ENB : ${INS_ENB}") + option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF) option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON) option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF) @@ -123,10 +127,12 @@ endif() option(GGML_LASX "ggml: enable lasx" ON) option(GGML_LSX "ggml: enable lsx" ON) option(GGML_RVV "ggml: enable rvv" ON) +option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) option(GGML_VXE "ggml: enable vxe" ON) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) -set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") +set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") +set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC") if (WIN32) diff --git a/ggml/cmake/GitVars.cmake b/ggml/cmake/GitVars.cmake new file mode 100644 index 0000000000000..1a4c24ebf6ade --- /dev/null +++ b/ggml/cmake/GitVars.cmake @@ -0,0 +1,22 @@ +find_package(Git) + +# the commit's SHA1 +execute_process(COMMAND + "${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8 + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_SHA1 + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the date of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%ad --date=local + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_DATE + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the subject of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%s + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_COMMIT_SUBJECT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index 823eb797b7007..8c2dc31c6da5b 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -5,7 +5,7 @@ set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@") set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@") -set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") +#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") find_package(Threads REQUIRED) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index ade6c3b0ef13b..4e0d210f8ec97 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -17,7 +17,9 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); -GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem); +GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, + const char * cache_dir, + 
size_t free_mem, size_t total_mem); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index cb3edb10d4702..452c967b0a637 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1791,11 +1791,11 @@ extern "C" { #define GGML_KQ_MASK_PAD 64 - // q: [n_embd, n_batch, n_head, 1] - // k: [n_embd, n_kv, n_head_kv, 1] - // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! - // res: [n_embd, n_head, n_batch, 1] !! permuted !! + // q: [n_embd_k, n_batch, n_head, 1] + // k: [n_embd_k, n_kv, n_head_kv, 1] + // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd_v, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index ab15d45dcab38..776dcf4cfc85b 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -65,7 +65,7 @@ if (GGML_LTO) endif() endif() -if (GGML_CCACHE) +if (GGML_CCACHE AND NOT CMAKE_C_COMPILER_LAUNCHER AND NOT CMAKE_CXX_COMPILER_LAUNCHER) find_program(GGML_CCACHE_FOUND ccache) find_program(GGML_SCCACHE_FOUND sccache) diff --git a/ggml/src/ggml-cann/.clang-format b/ggml/src/ggml-cann/.clang-format deleted file mode 100644 index 2ad03d7391936..0000000000000 --- a/ggml/src/ggml-cann/.clang-format +++ /dev/null @@ -1,168 +0,0 @@ ---- -Language: Cpp -# BasedOnStyle: Google -AccessModifierOffset: -1 -AlignAfterOpenBracket: Align -AlignConsecutiveMacros: false -AlignConsecutiveAssignments: false -AlignConsecutiveDeclarations: false -AlignEscapedNewlines: Left -AlignOperands: true -AlignTrailingComments: true -AllowAllArgumentsOnNextLine: true -AllowAllConstructorInitializersOnNextLine: true -AllowAllParametersOfDeclarationOnNextLine: true -AllowShortBlocksOnASingleLine: Never -AllowShortCaseLabelsOnASingleLine: false -AllowShortFunctionsOnASingleLine: All -AllowShortLambdasOnASingleLine: All -AllowShortIfStatementsOnASingleLine: WithoutElse -AllowShortLoopsOnASingleLine: true -AlwaysBreakAfterDefinitionReturnType: None -AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: true -AlwaysBreakTemplateDeclarations: Yes -BinPackArguments: true -BinPackParameters: true -BraceWrapping: - AfterCaseLabel: false - AfterClass: false - AfterControlStatement: false - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - AfterExternBlock: false - BeforeCatch: false - BeforeElse: false - IndentBraces: false - SplitEmptyFunction: true - SplitEmptyRecord: true - SplitEmptyNamespace: true -BreakBeforeBinaryOperators: None -BreakBeforeBraces: Attach -BreakBeforeInheritanceComma: false -BreakInheritanceList: BeforeColon -BreakBeforeTernaryOperators: true -BreakConstructorInitializersBeforeComma: false -BreakConstructorInitializers: BeforeColon -BreakAfterJavaFieldAnnotations: false -BreakStringLiterals: true -ColumnLimit: 80 -CommentPragmas: '^ IWYU pragma:' -CompactNamespaces: false -ConstructorInitializerAllOnOneLineOrOnePerLine: true -ConstructorInitializerIndentWidth: 4 -ContinuationIndentWidth: 4 -Cpp11BracedListStyle: true -DeriveLineEnding: true -DerivePointerAlignment: true -DisableFormat: false -ExperimentalAutoDetectBinPacking: false -FixNamespaceComments: true 
-ForEachMacros: - - foreach - - Q_FOREACH - - BOOST_FOREACH -IncludeBlocks: Regroup -IncludeCategories: - - Regex: '^' - Priority: 2 - SortPriority: 0 - - Regex: '^<.*\.h>' - Priority: 1 - SortPriority: 0 - - Regex: '^<.*' - Priority: 2 - SortPriority: 0 - - Regex: '.*' - Priority: 3 - SortPriority: 0 -IncludeIsMainRegex: '([-_](test|unittest))?$' -IncludeIsMainSourceRegex: '' -IndentCaseLabels: true -IndentGotoLabels: true -IndentPPDirectives: None -IndentWidth: 4 -IndentWrappedFunctionNames: false -JavaScriptQuotes: Leave -JavaScriptWrapImports: true -KeepEmptyLinesAtTheStartOfBlocks: false -MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -ObjCBinPackProtocolList: Never -ObjCBlockIndentWidth: 2 -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: true -PenaltyBreakAssignment: 2 -PenaltyBreakBeforeFirstCallParameter: 1 -PenaltyBreakComment: 300 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakString: 1000 -PenaltyBreakTemplateDeclaration: 10 -PenaltyExcessCharacter: 1000000 -PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left -RawStringFormats: - - Language: Cpp - Delimiters: - - cc - - CC - - cpp - - Cpp - - CPP - - 'c++' - - 'C++' - CanonicalDelimiter: '' - BasedOnStyle: google - - Language: TextProto - Delimiters: - - pb - - PB - - proto - - PROTO - EnclosingFunctions: - - EqualsProto - - EquivToProto - - PARSE_PARTIAL_TEXT_PROTO - - PARSE_TEST_PROTO - - PARSE_TEXT_PROTO - - ParseTextOrDie - - ParseTextProtoOrDie - CanonicalDelimiter: '' - BasedOnStyle: google -ReflowComments: true -SortIncludes: true -SortUsingDeclarations: true -SpaceAfterCStyleCast: false -SpaceAfterLogicalNot: false -SpaceAfterTemplateKeyword: true -SpaceBeforeAssignmentOperators: true -SpaceBeforeCpp11BracedList: false -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SpaceInEmptyBlock: false -SpaceInEmptyParentheses: false -SpacesBeforeTrailingComments: 2 -SpacesInAngles: false -SpacesInConditionalStatement: false -SpacesInContainerLiterals: true -SpacesInCStyleCastParentheses: false -SpacesInParentheses: false -SpacesInSquareBrackets: false -SpaceBeforeSquareBrackets: false -Standard: Auto -StatementMacros: - - Q_UNUSED - - QT_REQUIRE_VERSION -TabWidth: 8 -UseCRLF: false -UseTab: Never -... - diff --git a/ggml/src/ggml-cann/CMakeLists.txt b/ggml/src/ggml-cann/CMakeLists.txt index 05cf06bfab4fc..0d8e483b291c7 100644 --- a/ggml/src/ggml-cann/CMakeLists.txt +++ b/ggml/src/ggml-cann/CMakeLists.txt @@ -51,13 +51,11 @@ if (CANN_INSTALL_DIR) ${CANN_INSTALL_DIR}/acllib/include ) - add_subdirectory(kernels) list(APPEND CANN_LIBRARIES ascendcl nnopbase opapi acl_op_compiler - ascendc_kernels ) file(GLOB GGML_SOURCES_CANN "*.cpp") diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp index d120ce6acf8a7..9b6553c500129 100644 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -54,9 +54,7 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne, // added. int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2]; - int64_t acl_storage_len = 0; if (ne == nullptr) { - acl_storage_len = ggml_nbytes(tensor); for (int i = 0; i < GGML_MAX_DIMS; i++) { acl_ne[i] = tensor->ne[i]; // The step size of acl is in elements. 
@@ -65,14 +63,18 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne, } else { // With bcast for (int i = 0; i < dims; i++) { - acl_storage_len += (ne[i] - 1) * nb[i]; acl_ne[i] = ne[i]; acl_stride[i] = nb[i] / ggml_element_size(tensor); } } - // Reverse ne and stride. int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims); + int64_t acl_storage_len = 1; + for (int i = 0; i < final_dims; i++) { + acl_storage_len += (acl_ne[i] - 1) * acl_stride[i]; + } + + // Reverse ne and stride. std::reverse(acl_ne, acl_ne + final_dims); std::reverse(acl_stride, acl_stride + final_dims); diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h index 4734a9cb8c301..93f09937efb31 100644 --- a/ggml/src/ggml-cann/acl_tensor.h +++ b/ggml/src/ggml-cann/acl_tensor.h @@ -101,14 +101,14 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype, tmp_stride[i] = nb[i] / type_size; } - std::reverse(tmp_ne, tmp_ne + dims); - std::reverse(tmp_stride, tmp_stride + dims); - - int64_t acl_storage_len = 0; + int64_t acl_storage_len = 1; for (int i = 0; i < dims; i++) { - acl_storage_len += (ne[i] - 1) * nb[i]; + acl_storage_len += (tmp_ne[i] - 1) * tmp_stride[i]; } + std::reverse(tmp_ne, tmp_ne + dims); + std::reverse(tmp_stride, tmp_stride + dims); + aclTensor* acl_tensor = aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr); diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 6bb5d08349197..fee66aea916c0 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -28,8 +28,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -44,12 +44,19 @@ #include #include #include -#include #include #include #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -58,12 +65,27 @@ #include #include "ggml-impl.h" -#include "kernels/ascendc_kernels.h" #define GGML_COMMON_DECL_C #include "../ggml-common.h" +void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0, + aclTensor ** acl_src1, aclTensor ** acl_dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0)); + // Need bcast + if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) { + BCAST_SHAPE(src0, src1) + *acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0)); + *acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1)); + *acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0)); + } else { + *acl_src0 = ggml_cann_create_tensor(src0); + *acl_src1 = ggml_cann_create_tensor(src1); + *acl_dst = ggml_cann_create_tensor(dst); + } +} + /** * @brief Repeats elements of a tensor along each dimension according to the * specified repeat array. @@ -79,24 +101,43 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src, // repeat tensor along each dim with repeat_array aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; + GGML_CANN_CALL_ACLNN_OP(Repeat, acl_src, repeats, acl_dst); + ACL_CHECK(aclDestroyIntArray(repeats)); +} - ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst, - &workspaceSize, &executor)); +/** + * @brief Casts the data type of a source tensor to a destination tensor. 
+ * + * This function casts the data type of the source tensor `acl_src` to the + * specified data type `cast_data_type` and stores the result in the destination + * tensor `acl_dst`. + * + * @param ctx The context for the CANN backend operations. + * @param acl_src The source tensor whose data type will be casted. + * @param acl_dst The destination tensor where the casted result will be stored. + * @param cast_data_type The target data type to which the source tensor will be + * casted. + */ +static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_dst, aclDataType cast_data_type) { + GGML_CANN_CALL_ACLNN_OP(Cast, acl_src, cast_data_type, acl_dst); +} - if (workspaceSize > 0) { - // Memory from allocator will "free" immediately, and this memory - // will be alloced to other pointers, but it won't access before - // this async task end because all tasks in same stream will execute - // in queue. - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnRepeat(workspaceAddr, workspaceSize, executor, ctx.stream())); - ACL_CHECK(aclDestroyIntArray(repeats)); +/** + * @brief Casts the elements of a tensor to a specified data type using the CANN backend. + * + * @details This function performs a type conversion on the elements of the input tensor `acl_src` + * and stores the results in the destination tensor `acl_dst`. The conversion type is + * determined based on the `dst` tensor's data type. + * + * @param ctx The context for the CANN backend operations. + * @param acl_src The source tensor whose elements will be cast. + * @param acl_dst The destination tensor that will store the casted elements. + * @param dst The ggml tensor specifying the target data type. + */ +static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_dst, ggml_tensor* dst) { + aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type)); } void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -114,69 +155,75 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ACL_CHECK(aclDestroyTensor(acl_dst)); } -/** - * @brief Adds two tensors element-wise and stores the result in a destination - * tensor. - * - * This function performs the operation: - * \f[ - * dst = acl\_src0 + alpha \times acl\_src1 - * \f] - * where alpha is a scalar value and defaults to 1.0f. - * - * @param ctx The context for the CANN backend operations. - * @param acl_src0 The first source tensor. - * @param acl_src1 The second source tensor. - * @param acl_dst The destination tensor where the result will be stored. 
- */ -static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, +void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, aclTensor* acl_src1, aclTensor* acl_dst) { - aclScalar* alpha = nullptr; float alphaValue = 1.0f; - alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream())); + aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + if (acl_dst != nullptr) + GGML_CANN_CALL_ACLNN_OP(Add, acl_src0, acl_src1, alpha, acl_dst); + else + GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_src0, acl_src1, alpha); + ACL_CHECK(aclDestroyScalar(alpha)); +} +void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0, + aclTensor* acl_src1, aclTensor* acl_dst) { + float alphaValue = 1.0f; + aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + if (acl_dst != nullptr) + GGML_CANN_CALL_ACLNN_OP(Sub, acl_src0, acl_src1, alpha, acl_dst); + else + GGML_CANN_CALL_ACLNN_OP(InplaceSub, acl_src0, acl_src1, alpha); ACL_CHECK(aclDestroyScalar(alpha)); } -void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; - ggml_tensor* src1 = dst->src[1]; - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); +void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_other, aclTensor* acl_dst) { + if (acl_dst != nullptr) + GGML_CANN_CALL_ACLNN_OP(Mul, acl_src, acl_other, acl_dst); + else + GGML_CANN_CALL_ACLNN_OP(InplaceMul, acl_src, acl_other); +} - aclTensor* acl_src0; - aclTensor* acl_src1; - aclTensor* acl_dst; +void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_other, aclTensor* acl_dst) { + if (acl_dst != nullptr) + GGML_CANN_CALL_ACLNN_OP(Div, acl_src, acl_other, acl_dst); + else + GGML_CANN_CALL_ACLNN_OP(InplaceDiv, acl_src, acl_other); +} - // Need bcast - if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) { - BCAST_SHAPE(src0, src1) - acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0)); - acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1)); - acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0)); +/** + * @brief Multiplies elements of a tensor by a scalar value, optionally + * in-place. + * + * This function multiplies each element of the source tensor `acl_src` by the + * scalar `scale` and stores the result in the destination tensor `acl_dst`. If + * `inplace` is true, `acl_dst` will not be used and the operation is performed + * in-place on `acl_src`. + * The operation is defined as: + * \f[ + * \text {acl_dst }_i=\text {acl_src }_i \times \text {scale} + * \f] + * + * @param ctx The context for the CANN backend operations. + * @param acl_src The source tensor whose elements will be multiplied. + * @param scale The scalar value by which each element of `acl_src` will be + * multiplied. + * @param acl_dst The destination tensor where the result will be stored if + * `inplace` is false. + * @param inplace Flag indicating whether to perform the operation in-place on + * `acl_src`. 
+ */ +static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, + float scale, aclTensor* acl_dst, bool inplace) { + aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); + if (inplace) { + GGML_CANN_CALL_ACLNN_OP(InplaceMuls, acl_src, acl_scale); } else { - acl_src0 = ggml_cann_create_tensor(src0); - acl_src1 = ggml_cann_create_tensor(src1); - acl_dst = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(Muls, acl_src, acl_scale, acl_dst); } - - aclnn_add(ctx, acl_src0, acl_src1, acl_dst); - - ACL_CHECK(aclDestroyTensor(acl_src0)); - ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ACL_CHECK(aclDestroyScalar(acl_scale)); } void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -193,19 +240,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclScalar* acl_negative_slope = aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnLeakyReluGetWorkspaceSize( - acl_src, acl_negative_slope, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(LeakyRelu, acl_src, acl_negative_slope, acl_dst); ACL_CHECK(aclDestroyScalar(acl_negative_slope)); ACL_CHECK(aclDestroyTensor(acl_src)); @@ -225,18 +260,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { static void aclnn_concat(ggml_backend_cann_context& ctx, aclTensorList* tensorList, aclTensor* acl_dst, int64_t concat_dim) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnCatGetWorkspaceSize(tensorList, concat_dim, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnCat(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(Cat, tensorList, concat_dim, acl_dst); } void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -282,24 +306,11 @@ static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst, int64_t steps = (int64_t)std::ceil((stop - start) / step); GGML_ASSERT(n_elements == steps); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - aclScalar* acl_start = aclCreateScalar(&start, aclDataType::ACL_FLOAT); aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT); aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT); - ACL_CHECK(aclnnArangeGetWorkspaceSize(acl_start, acl_end, acl_step, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnArange(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(Arange, acl_start, acl_end, acl_step, acl_dst); ACL_CHECK(aclDestroyScalar(acl_start)); ACL_CHECK(aclDestroyScalar(acl_end)); ACL_CHECK(aclDestroyScalar(acl_step)); @@ -322,15 +333,8 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ACL_CHECK(aclDestroyTensor(acl_dst)); } -void 
ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - dst->src[1] = dst->src[0]; - ggml_cann_mul_div(ctx, dst); -} - void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src = dst->src[0]; - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); float min; float max; @@ -343,19 +347,7 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT); aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnClampGetWorkspaceSize(acl_src, acl_min, acl_max, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnClamp(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(Clamp, acl_src, acl_min, acl_max, acl_dst); ACL_CHECK(aclDestroyScalar(acl_min)); ACL_CHECK(aclDestroyScalar(acl_max)); ACL_CHECK(aclDestroyTensor(acl_src)); @@ -373,19 +365,7 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_src = ggml_cann_create_tensor(src); aclTensor* acl_dst = ggml_cann_create_tensor(dst); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, scale, acl_dst, &workspaceSize, - &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(Muls, acl_src, scale, acl_dst); ACL_CHECK(aclDestroyScalar(scale)); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); @@ -403,33 +383,9 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* tmp_tensor = ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type), dst->ne, dst->nb, GGML_MAX_DIMS); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnArgsortGetWorkspaceSize( - acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), tmp_tensor, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnArgsort(workspaceAddr, workspaceSize, executor, ctx.stream())); - - workspaceSize = 0; - ACL_CHECK(aclnnCastGetWorkspaceSize(tmp_tensor, - ggml_cann_type_mapping(dst->type), - acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? 
true : false), + tmp_tensor); + GGML_CANN_CALL_ACLNN_OP(Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(tmp_tensor)); ACL_CHECK(aclDestroyTensor(acl_dst)); @@ -444,24 +400,10 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - std::vector normData = {dst->ne[0]}; aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size()); - ACL_CHECK(aclnnLayerNormGetWorkspaceSize(acl_src, norm, nullptr, nullptr, - eps, acl_dst, nullptr, nullptr, - &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnLayerNorm(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(LayerNorm, acl_src, norm, nullptr, nullptr, + eps, acl_dst, nullptr, nullptr); ACL_CHECK(aclDestroyIntArray(norm)); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); @@ -478,10 +420,6 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params + 1, sizeof(float)); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - int64_t N = src->ne[3]; int64_t C = src->ne[2]; int64_t HxW = src->ne[1] * src->ne[0]; @@ -498,18 +436,8 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_rstd_out = ggml_cann_create_tensor( (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); - ACL_CHECK(aclnnGroupNormGetWorkspaceSize( - acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst, - acl_mean_out, acl_rstd_out, &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnGroupNorm(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, + acl_dst, acl_mean_out, acl_rstd_out); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); ACL_CHECK(aclDestroyTensor(acl_mean_out)); @@ -536,68 +464,57 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float alphaValue = 1.0f; alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - if (!inplace) { size_t cpy_size = ggml_nbytes(dst); ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); aclTensor* acl_src0 = ggml_cann_create_tensor( src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); - ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream())); + + GGML_CANN_CALL_ACLNN_OP(Add, acl_src0, acl_src1, alpha, acl_dst); ACL_CHECK(aclDestroyTensor(acl_src0)); } else { - ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src1, alpha, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc 
workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK(aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, - ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_dst, acl_src1, alpha); } ACL_CHECK(aclDestroyTensor(acl_src1)); ACL_CHECK(aclDestroyTensor(acl_dst)); } -void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; - - aclTensor* acl_src = ggml_cann_create_tensor(src); +/** + * @brief Performs sum reduction on a given tensor along specified dimensions. + * + * This function reduces the input tensor by summing along the specified dimensions. + * + * @param ctx The context for the CANN backend operations. + * @param dst The destination tensor where the reduced result will be stored. + * @param dim An array of dimension indices. + * @param dim_size The number of dimensions. + */ +static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst, + int64_t* dim, size_t dim_size) { GGML_ASSERT(dst->ne[0] == 1); + ggml_tensor* src = dst->src[0]; + aclTensor* acl_src = ggml_cann_create_tensor(src); aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclIntArray* reduce_dims = aclCreateIntArray(dim, dim_size); - int64_t reduce_dims_host[] = {3}; - aclIntArray* reduce_dims = aclCreateIntArray(reduce_dims_host, 1); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnReduceSumGetWorkspaceSize( - acl_src, reduce_dims, true, ggml_cann_type_mapping(src->type), acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnReduceSum(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(ReduceSum, acl_src, reduce_dims, true, + ggml_cann_type_mapping(dst->type), acl_dst); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); + ACL_CHECK(aclDestroyIntArray(reduce_dims)); +} + +void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + int64_t reduce_dims[] = {3}; + aclnn_reduce_sum(ctx, dst, reduce_dims, 1); +} + +void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + int64_t reduce_dims[] = {0, 1, 2, 3}; + aclnn_reduce_sum(ctx, dst, reduce_dims, 4); } void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, @@ -611,20 +528,7 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, std::vector output_size{dst->ne[1], dst->ne[0]}; auto output_size_array = aclCreateIntArray(output_size.data(), 2); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnUpsampleNearest2dGetWorkspaceSize( - acl_src, output_size_array, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnUpsampleNearest2d(workspaceAddr, workspaceSize, executor, - ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(UpsampleNearest2d, acl_src, output_size_array, acl_dst); ACL_CHECK(aclDestroyIntArray(output_size_array)); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); @@ -650,21 +554,7 @@ static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2); aclScalar* acl_value = aclCreateScalar(&value, 
aclDataType::ACL_FLOAT); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnConstantPadNdGetWorkspaceSize( - acl_src, acl_pad, acl_value, acl_dst, &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnConstantPadNd(workspaceAddr, workspaceSize, executor, - ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst); ACL_CHECK(aclDestroyIntArray(acl_pad)); ACL_CHECK(aclDestroyScalar(acl_value)); } @@ -730,23 +620,9 @@ static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx, bool count_include_pad = true; int64_t divisor_override = 0; int8_t cube_math_type = 0; - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnAvgPool2dGetWorkspaceSize( - acl_src, kernel_size, strides, paddings_avg, ceil_mode, - count_include_pad, divisor_override, cube_math_type, acl_dst, - &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnAvgPool2d(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(AvgPool2d, acl_src, kernel_size, strides, paddings_avg, + ceil_mode, count_include_pad, divisor_override, + cube_math_type, acl_dst); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); ACL_CHECK(aclDestroyIntArray(kernel_size)); @@ -819,22 +695,8 @@ static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx, bool ceil_mode = false; int64_t auto_pads = 0; - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMaxPoolGetWorkspaceSize( - tmp_tensor, kernel_size, strides, auto_pads, paddings_max, dilations, - ceil_mode, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnMaxPool(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(MaxPool, tmp_tensor, kernel_size, strides, auto_pads, + paddings_max, dilations, ceil_mode, acl_dst); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); ACL_CHECK(aclDestroyTensor(tmp_tensor)); @@ -872,206 +734,81 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) { */ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceCopyGetWorkspaceSize(acl_dst, acl_src, &workspaceSize, - &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceCopy(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(InplaceCopy, acl_dst, acl_src); } void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; + ggml_tensor* src0 = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); + aclTensor* acl_src = ggml_cann_create_tensor(src0); aclTensor* acl_dst = ggml_cann_create_tensor(dst); - - ggml_cann_pool_alloc src_extra_allocator(ctx.pool(), sizeof(ggml_tensor)); - 
ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor)); - src->extra = src_extra_allocator.get(); - dst->extra = dst_extra_allocator.get(); - ACL_CHECK(aclrtMemcpyAsync(src->extra, sizeof(ggml_tensor), src, - sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE, - ctx.stream())); - ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst, - sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE, - ctx.stream())); - - if ((dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32) && - ggml_are_same_shape(src, dst)) { - cann_copy(ctx, acl_src, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - return; - } - // TODO: simplify - if (src->type == GGML_TYPE_F16) { - if (dst->type == GGML_TYPE_Q8_0) { - aclrtlaunch_ascendc_quantize_f16_q8_0( - 24, ctx.stream(), src->data, dst->data, - ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb, - ((ggml_tensor*)dst->extra)->ne); - return; - } - if (dst->type == GGML_TYPE_Q4_0) { - aclrtlaunch_ascendc_quantize_f16_to_q4_0( - 24, ctx.stream(), src->data, dst->data, - ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb, - ((ggml_tensor*)dst->extra)->ne); - return; - } - if (dst->type == GGML_TYPE_F16) { - if (ggml_are_same_shape(src, dst)) { - cann_copy(ctx, acl_src, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - return; - } - if (ggml_is_contiguous(dst)) { - const size_t src_type_size = ggml_type_size(src->type); - if (src->nb[0] == src_type_size) { - // src0 is contigous on first dimension, copy by rows - int64_t rows_num = ggml_nrows(src); - - aclrtlaunch_ascendc_dup_by_rows_fp16( - rows_num, ctx.stream(), src->data, dst->data, - ((ggml_tensor*)src->extra)->ne, - ((ggml_tensor*)src->extra)->nb, - ((ggml_tensor*)dst->extra)->ne, - ((ggml_tensor*)dst->extra)->nb); - return; - } - GGML_ABORT("fatal error"); - } - GGML_ABORT("fatal error"); - } - if (dst->type == GGML_TYPE_F32) { - if (ggml_are_same_shape(src, dst)) { - cann_copy(ctx, acl_src, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - return; - } - if (ggml_is_contiguous(dst)) { - const size_t src_type_size = ggml_type_size(src->type); - if (src->nb[0] == src_type_size) { - // src0 is contigous on first dimension, copy by rows - int64_t rows_num = ggml_nrows(src); - aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32( - rows_num, ctx.stream(), src->data, dst->data, - ((ggml_tensor*)src->extra)->ne, - ((ggml_tensor*)src->extra)->nb, - ((ggml_tensor*)dst->extra)->ne, - ((ggml_tensor*)dst->extra)->nb); - return; - } - GGML_ABORT("fatal error"); - } - GGML_ABORT("fatal error"); - } - // TODO - GGML_ABORT("fatal error"); - } else if (src->type == GGML_TYPE_F32) { - // TODO: if (src0->type == dst->type && ne00 == ne0 && nb00 == type_size - // && nb0 == type_size) - if (dst->type == GGML_TYPE_Q8_0) { - aclrtlaunch_ascendc_quantize_f32_q8_0( - 24, ctx.stream(), src->data, dst->data, - ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb, - ((ggml_tensor*)dst->extra)->ne); - return; - } - if (dst->type == GGML_TYPE_Q4_0) { - aclrtlaunch_ascendc_quantize_f32_to_q4_0( - 24, ctx.stream(), src->data, dst->data, - ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb, - ((ggml_tensor*)dst->extra)->ne); - return; + if (ggml_are_same_shape(src0, dst)) { + if (dst->type == src0->type) { + cann_copy(ctx, acl_src, acl_dst); + } else { + aclnn_cast(ctx, acl_src, acl_dst, dst); } - if (dst->type == GGML_TYPE_F32) { - if (ggml_are_same_shape(src, 
dst)) { - cann_copy(ctx, acl_src, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + } else { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { + if (dst->type == src0->type) { + size_t cpy_size = ggml_nbytes(dst); + ACL_CHECK(aclrtMemcpyAsync( + dst->data, cpy_size, src0->data, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); return; - } - if (ggml_is_contiguous(dst)) { - const size_t src_type_size = ggml_type_size(src->type); - if (src->nb[0] == src_type_size) { - // src0 is contigous on first dimension, copy by rows - int64_t rows_num = ggml_nrows(src); - aclrtlaunch_ascendc_dup_by_rows_fp32( - rows_num, ctx.stream(), src->data, dst->data, - ((ggml_tensor*)src->extra)->ne, - ((ggml_tensor*)src->extra)->nb, - ((ggml_tensor*)dst->extra)->ne, - ((ggml_tensor*)dst->extra)->nb); - return; - } - GGML_ABORT("fatal error"); } else { - // TODO: dst not contiguous - GGML_ABORT("fatal error"); - } - } - if (dst->type == GGML_TYPE_F16) { - if (ggml_are_same_shape(src, dst)) { - cann_copy(ctx, acl_src, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_pool_alloc src_buffer_allocator( + ctx.pool(), + ggml_nelements(dst) * ggml_type_size(dst->type)); + void* src_trans_buffer = src_buffer_allocator.get(); + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = ggml_type_size(dst->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + } + aclTensor* src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), src0->ne, src_trans_nb, + GGML_MAX_DIMS); + + aclnn_cast(ctx, acl_src, src_trans_tensor, dst); + size_t cpy_size = ggml_nbytes(dst); + ACL_CHECK(aclrtMemcpyAsync( + dst->data, cpy_size, src_trans_buffer, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ACL_CHECK(aclDestroyTensor(src_trans_tensor)); return; } - if (ggml_is_contiguous(dst)) { - const size_t src_type_size = ggml_type_size(src->type); - if (src->nb[0] == src_type_size) { - // src0 is contigous on first dimension, copy by rows - int64_t rows_num = ggml_nrows(src); - aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16( - rows_num, ctx.stream(), src->data, dst->data, - ((ggml_tensor*)src->extra)->ne, - ((ggml_tensor*)src->extra)->nb, - ((ggml_tensor*)dst->extra)->ne, - ((ggml_tensor*)dst->extra)->nb); - return; - } - GGML_ABORT("fatal error"); + } else if (ggml_is_contiguous(dst)) { + ggml_cann_pool_alloc src_buffer_allocator( + ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type)); + void* src_trans_buffer = src_buffer_allocator.get(); + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = ggml_type_size(dst->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; } - } - // TODO - GGML_ABORT("fatal error"); - } else { - if (ggml_are_same_shape(src, dst)) { - cann_copy(ctx, acl_src, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + aclTensor* src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), src0->ne, src_trans_nb, + GGML_MAX_DIMS); + + aclnn_cast(ctx, acl_src, src_trans_tensor, dst); + + size_t cpy_size = ggml_nbytes(dst); + ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src_trans_buffer, + cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, + ctx.stream())); + ACL_CHECK(aclDestroyTensor(src_trans_tensor)); return; + } else { + 
GGML_ABORT("Unsupport dst is not tontiguous."); } - GGML_ABORT("fatal error"); } -} -#ifdef __cplusplus -extern "C" { -#endif -aclnnStatus aclnnRmsNormGetWorkspaceSize(const aclTensor* x, - const aclTensor* gamma, double epsilon, - const aclTensor* yOut, - const aclTensor* rstdOout, - uint64_t* workspaceSize, - aclOpExecutor** executor); -aclnnStatus aclnnRmsNorm(void* workspace, uint64_t workspaceSize, - aclOpExecutor* executor, aclrtStream stream); -#ifdef __cplusplus + ACL_CHECK(aclDestroyTensor(acl_src)); + ACL_CHECK(aclDestroyTensor(acl_dst)); } -#endif /** * @brief Creates an ACL tensor initialized with zeros using a provided buffer. @@ -1131,21 +868,7 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer, float alpha_host = 1.0f; aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT); aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceAddsGetWorkspaceSize(acl_tensor, other, alpha, - &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnInplaceAdds(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(InplaceAdds, acl_tensor, other, alpha); return acl_tensor; } @@ -1157,13 +880,6 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); - - GGML_ASSERT(eps > 0.0f); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src); ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes); @@ -1178,18 +894,7 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes, src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type), ggml_element_size(src)); - - ACL_CHECK(aclnnRmsNormGetWorkspaceSize( - acl_src, acl_gamma, eps, acl_dst, acl_rstd, &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnRmsNorm(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); ACL_CHECK(aclDestroyTensor(acl_gamma)); @@ -1215,77 +920,19 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type), ggml_element_size(src), value); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceTriuGetWorkspaceSize(mask_tensor, n_past + 1, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceTriu(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclnnTrilGetWorkspaceSize(acl_src, n_past + 1, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - 
ACL_CHECK(aclnnTril(workspaceAddr, workspaceSize, executor, ctx.stream())); - aclScalar* alpha = nullptr; float alphaValue = 1.0f; alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, mask_tensor, alpha, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(InplaceTriu, mask_tensor, n_past + 1); + GGML_CANN_CALL_ACLNN_OP(Tril, acl_src, n_past + 1, acl_dst); + GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_dst, mask_tensor, alpha); ACL_CHECK(aclDestroyScalar(alpha)); ACL_CHECK(aclDestroyTensor(mask_tensor)); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); } -/** - * @brief Casts the data type of a source tensor to a destination tensor. - * - * This function casts the data type of the source tensor `acl_src` to the - * specified data type `cast_data_type` and stores the result in the destination - * tensor `acl_dst`. - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor whose data type will be casted. - * @param acl_dst The destination tensor where the casted result will be stored. - * @param cast_data_type The target data type to which the source tensor will be - * casted. - */ -static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, aclDataType cast_data_type) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnCastGetWorkspaceSize(acl_src, cast_data_type, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - /** * @brief Permutes the dimensions of a tensor according to a specified order. * @@ -1302,41 +949,10 @@ static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, * @param dims The number of dimensions in the tensor. 
*/ static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) { - aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnPermuteGetWorkspaceSize(acl_src, acl_dims, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnPermute(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyIntArray(acl_dims)); -} - -#ifdef __cplusplus -extern "C" { -#endif -aclnnStatus aclnnIm2colGetWorkspaceSize(const aclTensor* self, - const aclIntArray* kernelSize, - const aclIntArray* dilation, - const aclIntArray* padding, - const aclIntArray* stride, - aclTensor* out, uint64_t* workspaceSize, - aclOpExecutor** executor); -aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize, - aclOpExecutor* executor, aclrtStream stream); -#ifdef __cplusplus + aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) { + aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims); + GGML_CANN_CALL_ACLNN_OP(Permute, acl_src, acl_dims, acl_dst); } -#endif static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx, ggml_tensor* dst, @@ -1501,23 +1117,8 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { auto* dilations = aclCreateIntArray(dilation_size.data(), 2); auto* paddings = aclCreateIntArray(padding_dims.data(), 2); auto* strides = aclCreateIntArray(stride_dims.data(), 2); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnIm2colGetWorkspaceSize(acl_src1, kernel_size, dilations, - paddings, strides, tmp_im2col_tensor, - &workspaceSize, &executor)); - - ggml_cann_pool_alloc workspace_allocator(ctx.pool()); - if (workspaceSize > 0) { - workspace_allocator.alloc(workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnIm2col(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(Im2col, acl_src1, kernel_size, dilations, + paddings, strides, tmp_im2col_tensor); // Cast if dst is f16. aclTensor* tmp_cast_tensor = nullptr; @@ -1536,8 +1137,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { tmp_cast_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND); - aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, - ggml_cann_type_mapping(dst->type)); + aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, dst); } // post-processing @@ -1575,285 +1175,17 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param acl_src The tensor on which the exponential function will be applied. */ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK( - aclnnInplaceExpGetWorkspaceSize(acl_src, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Multiplies elements of a tensor by a scalar value, optionally - * in-place. 
- * - * This function multiplies each element of the source tensor `acl_src` by the - * scalar `scale` and stores the result in the destination tensor `acl_dst`. If - * `inplace` is true, `acl_dst` will not be used and the operation is performed - * in-place on `acl_src`. - * The operation is defined as: - * \f[ - * \text {acl_dst }_i=\text {acl_src }_i \times \text {scale} - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor whose elements will be multiplied. - * @param scale The scalar value by which each element of `acl_src` will be - * multiplied. - * @param acl_dst The destination tensor where the result will be stored if - * `inplace` is false. - * @param inplace Flag indicating whether to perform the operation in-place on - * `acl_src`. - */ -static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, - float scale, aclTensor* acl_dst, bool inplace) { - aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - if (inplace) { - ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceMuls(workspaceAddr, workspaceSize, executor, - ctx.stream())); - } else { - ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, acl_scale, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream())); - } - - ACL_CHECK(aclDestroyScalar(acl_scale)); -} - -/** - * @brief Performs an in-place element-wise multiplication of two tensors. - * - * This function performs an element-wise multiplication of the tensors - * `acl_src` and `acl_other` and stores the result in `acl_src`. - * The operation is defined as: - * \f[ - * \text {acl_src }_i=\text {acl_src }_i \times \text {acl_other }_i - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor where the multiplication result will be - * stored. - * @param acl_other The tensor whose elements will be multiplied with `acl_src`. - */ -static void aclnn_inplace_mul(ggml_backend_cann_context& ctx, - aclTensor* acl_src, aclTensor* acl_other) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceMulGetWorkspaceSize(acl_src, acl_other, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceMul(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Performs element-wise multiplication of two tensors and stores the - * result in a destination tensor. - * - * This function performs element-wise multiplication of the tensors `acl_src` - * and `acl_other` and stores the result in the destination tensor `acl_dst`. - * The operation is defined as: - * \f[ - * \text {acl_dst }_i=\text {acl_src }_i \times \text {acl_other }_i - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The first tensor for element-wise multiplication. 
- * @param acl_other The second tensor for element-wise multiplication. - * @param acl_dst The destination tensor where the result will be stored. - */ -static void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_other, aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMulGetWorkspaceSize(acl_src, acl_other, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnMul(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(InplaceExp, acl_src); } -/** - * @brief Applies element-wise cosine function to the elements of a tensor. - * - * This function computes the cosine of each element in the source tensor - * `acl_src` and stores the result in the destination tensor `acl_dst`. The - * operation is defined as: \f[ \text {acl_dst }_i=\cos \left(\text {acl_src - * }_i\right) \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor on which the cosine function will be - * applied. - * @param acl_dst The destination tensor where the cosine results will be - * stored. - */ -static void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, +void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK( - aclnnCosGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnCos(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(Cos, acl_src, acl_dst); } -/** - * @brief Applies element-wise sine function to the elements of a tensor. - * - * This function computes the sine of each element in the source tensor - `acl_src` - * and stores the result in the destination tensor `acl_dst`. - * The operation is defined as: - * \f[ - * \text {acl_dst }_i=\sin \left(\text {acl_src }_i\right) - * \f] - - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor on which the sine function will be applied. - * @param acl_dst The destination tensor where the sine results will be stored. - */ -static void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, +void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK( - aclnnSinGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnSin(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Performs element-wise division of tensor1 by tensor2 , multiplies the - result by the scalar value and adds it to self . - * - * Performs element-wise division of tensor1 by tensor2, - * multiplies the result by the scalar value and adds it to self . 
- * The operation is defined as: - * \f[ - * \text{out}_i = \text{selft}_i + \text{value} \times - \frac{\text{tensor1}_i}{\text{tensor2}_i} - * \f] - - * @param ctx The context for the CANN backend operations. - * @param acl_self The source tensor on which the addcdiv function will be - applied. - * @param tensor1 Numerator tensor. - * @param tensor2 Denominator tensor. - * @param value The value to be used for coefficient. - */ -static void aclnn_inplace_addcdiv(ggml_backend_cann_context& ctx, - aclTensor* acl_self, aclTensor* tensor1, - aclTensor* tensor2, float value) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - - ACL_CHECK(aclnnInplaceAddcdivGetWorkspaceSize( - acl_self, tensor1, tensor2, acl_value, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceAddcdiv(workspaceAddr, workspaceSize, executor, - ctx.stream())); -} - -/** - * @brief Matrix division, optionally in-place. - * - * This function division each element of the source tensor `acl_src` by the - * tensor `acl_other` and stores the result in the destination tensor `acl_dst`. - * If `inplace` is true, `acl_dst` will not be used and the operation is - * performed in-place on `acl_src`. The operation is defined as: \f[ - * \text{dst}_i = \frac{\text{acl_src}_i}{\text{acl_other}_i} - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src Numerator tensor.. - * @param acl_other Denominator tensor. - * @param acl_dst The destination tensor where the result will be stored if - * `inplace` is false. - * @param inplace Flag indicating whether to perform the operation in-place on - * `acl_src`. 
- */ -static void aclnn_div_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_other, aclTensor* acl_dst, - bool inplace) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - if (inplace) { - ACL_CHECK(aclnnInplaceDivGetWorkspaceSize(acl_src, acl_other, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceDiv(workspaceAddr, workspaceSize, executor, - ctx.stream())); - } else { - ACL_CHECK(aclnnDivGetWorkspaceSize(acl_src, acl_other, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnDiv(workspaceAddr, workspaceSize, executor, ctx.stream())); - } + GGML_CANN_CALL_ACLNN_OP(Sin, acl_src, acl_dst); } void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, @@ -1983,20 +1315,7 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, aclTensor* acl_dst) { auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceFillScalarGetWorkspaceSize( - acl_dst, acl_scalar, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceFillScalar(workspaceAddr, workspaceSize, executor, - ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(InplaceFillScalar, acl_dst, acl_scalar); ACL_CHECK(aclDestroyScalar(acl_scalar)); } @@ -2018,19 +1337,7 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, */ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst, aclTensor* acl_exp) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplacePowTensorTensorGetWorkspaceSize( - acl_dst, acl_exp, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplacePowTensorTensor(workspaceAddr, workspaceSize, - executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(InplacePowTensorTensor, acl_dst, acl_exp); } /** @@ -2197,41 +1504,6 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_dup(ctx, dst); } -/** - * @brief Performs element-wise addition of two tensors in place. - * - * This function adds the source tensor `acl_src` to the destination tensor - * `acl_dst` element-wise and stores the result in the destination tensor - * `acl_dst`. - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor to be added. - * @param acl_dst The destination tensor which will hold the result of the - * addition. 
- */ -static void aclnn_inplace_add(ggml_backend_cann_context& ctx, - aclTensor* acl_src, aclTensor* acl_dst) { - aclScalar* alpha = nullptr; - float alphaValue = 1.0f; - alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src, alpha, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyScalar(alpha)); -} - /** * @brief Applies the softmax function to a tensor along a specified dimension. * @@ -2248,20 +1520,7 @@ static void aclnn_inplace_add(ggml_backend_cann_context& ctx, */ static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src, int64_t dim, aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(acl_src, dim, acl_dst, - &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - aclrtStream stream = ctx.stream(); - ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream)); + GGML_CANN_CALL_ACLNN_OP(Softmax, acl_src, dim, acl_dst); } void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -2378,85 +1637,152 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ACL_CHECK(aclDestroyTensor(tmp_mask_tensor)); } -void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; - ggml_tensor* src1 = dst->src[1]; +/** + * @brief Performs embedding operation on a 4D tensor using the CANN backend. + * + * This function extracts slices from the source tensor (`src_buffer`), + * index tensor (`index`), and destination tensor (`dst`), and performs an + * embedding operation on them. The embedding operation is applied by iterating + * over the last two dimensions of the source tensor, creating the necessary + * tensors for the source, index, and output, and executing the embedding operation. + * + * @param ctx The context for CANN backend operations. + * @param src_buffer The source buffer holding the data for the source tensor. + * @param src_ne The dimensions of the source tensor. + * @param src_nb The strides (byte offsets) of the source tensor. + * @param index The index tensor used in the embedding operation. + * @param dst The destination tensor where the result will be stored. 
+ */ +static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer, + int64_t* src_ne, size_t* src_nb, ggml_tensor* index, + ggml_tensor* dst) { + for (int64_t i = 0; i < src_ne[3]; i++) { + for (int64_t j = 0; j < src_ne[2]; j++) { + // src + int64_t acl_src_ne[2] = {src_ne[0], src_ne[1]}; + size_t acl_src_nb[2] = {src_nb[0], src_nb[1]}; + aclTensor* acl_src_tensor = ggml_cann_create_tensor( + (char*)src_buffer + i * src_nb[3] + j * src_nb[2], + ggml_cann_type_mapping(dst->type), ggml_element_size(dst), + acl_src_ne, acl_src_nb, 2); + + // index + int64_t acl_index_ne[1] = {index->ne[0]}; + size_t acl_index_nb[1] = {index->nb[0]}; + aclTensor* acl_index = ggml_cann_create_tensor( + (char*)index->data + i * index->nb[2] + j * index->nb[1], + ggml_cann_type_mapping(index->type), ggml_element_size(index), + acl_index_ne, acl_index_nb, 1); + + // out + int64_t acl_out_ne[2] = {dst->ne[0], dst->ne[1]}; + size_t acl_out_nb[2] = {dst->nb[0], dst->nb[1]}; + aclTensor* acl_out = ggml_cann_create_tensor( + (char*)dst->data + i * dst->nb[3] + j * dst->nb[2], + ggml_cann_type_mapping(dst->type), ggml_element_size(dst), + acl_out_ne, acl_out_nb, 2); + GGML_CANN_CALL_ACLNN_OP(Embedding, acl_src_tensor, acl_index, acl_out); + ACL_CHECK(aclDestroyTensor(acl_src_tensor)); + ACL_CHECK(aclDestroyTensor(acl_index)); + ACL_CHECK(aclDestroyTensor(acl_out)); + } + } +} - ggml_cann_pool_alloc src0_extra_allocator(ctx.pool(), sizeof(ggml_tensor)); - ggml_cann_pool_alloc src1_extra_allocator(ctx.pool(), sizeof(ggml_tensor)); - ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor)); - src0->extra = src0_extra_allocator.get(); - src1->extra = src1_extra_allocator.get(); - dst->extra = dst_extra_allocator.get(); - ACL_CHECK(aclrtMemcpyAsync(src0->extra, sizeof(ggml_tensor), src0, - sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE, - ctx.stream())); - ACL_CHECK(aclrtMemcpyAsync(src1->extra, sizeof(ggml_tensor), src1, - sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE, - ctx.stream())); - ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst, - sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE, - ctx.stream())); +void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + ggml_tensor* src0 = dst->src[0]; // src + ggml_tensor* src1 = dst->src[1]; // index switch (src0->type) { case GGML_TYPE_F32: { -#ifdef ASCEND_310P - // Special operation for get_row_f32 kernel of 310P: clear the - // content of dest data buffer when row is not aligned to 32 bytes - if ((src0->ne[0] % 8) != 0) { - size_t dst_len = src1->ne[0] * src1->ne[1] * src1->ne[2] * - src0->ne[0] * ggml_type_size(GGML_TYPE_F32); - ACL_CHECK(aclrtMemset((char*)dst->data, dst_len, 0, dst_len)); - } -#endif - aclrtlaunch_ascendc_get_row_f32( - 24, ctx.stream(), src0->data, src1->data, dst->data, - ((ggml_tensor*)src0->extra)->ne, - ((ggml_tensor*)src0->extra)->nb, - ((ggml_tensor*)src1->extra)->ne, - ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne, - ((ggml_tensor*)dst->extra)->nb); + aclnn_embedding_4d(ctx, src0->data, src0->ne, src0->nb, src1, + dst); break; } case GGML_TYPE_F16: { -#ifdef ASCEND_310P - // Special operation for get_row_f16 kernel of 310P: clear the - // content of dest data buffer when row is not aligned to 32 bytes - if ((src0->ne[0] % 16) != 0) { - size_t dst_len = - src1->ne[0] * src1->ne[1] * src1->ne[2] * src0->ne[0] * - ggml_type_size( - GGML_TYPE_F32); // out is also f32, even input is f16 - ACL_CHECK(aclrtMemset((char*)dst->data, dst_len, 0, dst_len)); + aclTensor* 
acl_src0 = ggml_cann_create_tensor(src0); + ggml_cann_pool_alloc src_buffer_allocator( + ctx.pool(), ggml_nelements(src0) * sizeof(float_t)); + void* src_trans_buffer = src_buffer_allocator.get(); + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; } -#endif - aclrtlaunch_ascendc_get_row_f16( - 24, ctx.stream(), src0->data, src1->data, dst->data, - ((ggml_tensor*)src0->extra)->ne, - ((ggml_tensor*)src0->extra)->nb, - ((ggml_tensor*)src1->extra)->ne, - ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne, - ((ggml_tensor*)dst->extra)->nb); + aclTensor* src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type), + src0->ne, src_trans_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0, src_trans_tensor, dst); + aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne, + src_trans_nb, src1, dst); + ACL_CHECK(aclDestroyTensor(acl_src0)); + ACL_CHECK(aclDestroyTensor(src_trans_tensor)); break; } - case GGML_TYPE_Q4_0: - aclrtlaunch_ascendc_get_row_q4_0( - 24, ctx.stream(), src0->data, src1->data, dst->data, - ((ggml_tensor*)src0->extra)->ne, - ((ggml_tensor*)src1->extra)->ne, - ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne, - ((ggml_tensor*)dst->extra)->nb); - break; - case GGML_TYPE_Q8_0: - aclrtlaunch_ascendc_get_row_q8_0( - 24, ctx.stream(), src0->data, src1->data, dst->data, - ((ggml_tensor*)src0->extra)->ne, - ((ggml_tensor*)src1->extra)->ne, - ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne, - ((ggml_tensor*)dst->extra)->nb); + case GGML_TYPE_Q8_0: { + // add 1 dim for bcast mul. + size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], + dequant_nb[GGML_MAX_DIMS + 1]; + int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], + *dequant_ne; + int64_t scale_offset = 0; + + // [3,4,5,64] -> [3,4,5,2,32] + weight_ne[0] = QK8_0; + weight_ne[1] = src0->ne[0] / QK8_0; + weight_nb[0] = sizeof(int8_t); + weight_nb[1] = weight_nb[0] * weight_ne[0]; + for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { + weight_ne[i] = src0->ne[i - 1]; + weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1]; + } + + // [3,4,5,64] -> [3,4,5,2,1] + scale_ne[0] = 1; + scale_ne[1] = src0->ne[0] / QK8_0; + scale_nb[0] = sizeof(uint16_t); + scale_nb[1] = scale_nb[0] * scale_ne[0]; + for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { + scale_ne[i] = src0->ne[i - 1]; + scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1]; + } + + // [3,4,5,64] -> [3,4,5,2,32] + dequant_ne = weight_ne; + dequant_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { + dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; + } + + scale_offset = ggml_nelements(src0) * sizeof(int8_t); + ggml_cann_pool_alloc dequant_buffer_allocator( + ctx.pool(), ggml_nelements(src0) * sizeof(float_t)); + + aclTensor* acl_weight_tensor = ggml_cann_create_tensor( + src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, + GGML_MAX_DIMS + 1); + aclTensor* acl_scale_tensor = ggml_cann_create_tensor( + src0->data, ACL_FLOAT16, sizeof(float16_t), scale_ne, scale_nb, + GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); + aclTensor* dequant_tensor = ggml_cann_create_tensor( + dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t), + dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); + + aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); + dequant_nb[0] = sizeof(float_t); + dequant_ne = src0->ne; + for (int i = 1; i < GGML_MAX_DIMS; 
i++) { + dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; + } + + aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(), + dequant_ne, dequant_nb, src1, dst); + + ACL_CHECK(aclDestroyTensor(dequant_tensor)); break; + } default: - GGML_ABORT("fatal error"); + GGML_ABORT("Unsupported tensor type for GGML_OP_GET_ROWS"); break; } } @@ -2480,133 +1806,8 @@ static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, int64_t dim, int64_t repeats, int64_t output_size) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnRepeatInterleaveIntWithDimGetWorkspaceSize( - acl_src, repeats, dim, output_size, acl_dst, &workspaceSize, - &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnRepeatInterleaveIntWithDim(workspaceAddr, workspaceSize, - executor, ctx.stream())); -} - -/** - * @brief Performs matrix multiplication of two tensors. - * - * This function computes the matrix multiplication of the input tensor - * `acl_input` and the weight tensor `acl_weight`, and stores the result in the - * destination tensor `acl_dst`. - * The operation is defined as: - * \f[ - * \text {acl_dst}=\text {acl_input@acl_weight} - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_input The input tensor for the matrix multiplication. - * @param acl_weight The weight tensor for the matrix multiplication. - * @param acl_dst The destination tensor where the result of the matrix - * multiplication will be stored. - */ -static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input, - aclTensor* acl_weight, aclTensor* acl_dst) { - int8_t cube_math_type = 1; // ALLOW_FP32_DOWN_PRECISION, when input is - // fp32, atlas a2 will transpose it to HFLOAT32. - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMatmulGetWorkspaceSize(acl_input, acl_weight, acl_dst, - cube_math_type, &workspaceSize, - &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Performs matrix multiplication of two 2D tensors. - * - * This function computes the matrix multiplication of the input tensor - * `acl_input` and the weight tensor `acl_weight`, and stores the result in the - * destination tensor `acl_dst`. - * The operation is defined as: - * \f[ - * \text {acl_dst}=\text {acl_input@acl_weight} - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_input The input tensor for the matrix multiplication. - * @param acl_weight The weight tensor for the matrix multiplication. - * @param acl_dst The destination tensor where the result of the matrix - * multiplication will be stored. 
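A note on the Q8_0 `get_rows` branch a few hunks above: the CANN buffer keeps all int8 quants of a Q8_0 tensor first and appends the fp16 block scales at the end (hence `scale_offset = ggml_nelements(src0) * sizeof(int8_t)`), so dequantization reduces to one broadcast multiply of a `[..., n/32, 32]` weight view by a `[..., n/32, 1]` scale view. A scalar sketch of the same computation — the function name and the pre-converted fp32 scales are illustrative assumptions, not code from the patch:

```cpp
#include <cstdint>

// Scalar equivalent of the broadcast dequantization: `weights` holds the
// int8 quants, `scales` one (already fp32-converted) scale per QK8_0 == 32
// values, laid out block after block.
static void dequant_q8_0(const int8_t* weights, const float* scales,
                         float* out, int64_t n_values) {
    const int64_t n_blocks = n_values / 32;
    for (int64_t b = 0; b < n_blocks; ++b) {
        const float d = scales[b];               // per-block scale
        for (int64_t i = 0; i < 32; ++i) {
            // the [.., n/32, 1] scale view broadcasts across all 32 quants
            out[b * 32 + i] = d * (float) weights[b * 32 + i];
        }
    }
}
```

Once the fp32 buffer exists, the row lookup itself is the same `aclnn_embedding_4d` gather that the F32 path uses.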
- */ -static void aclnn_mat_mul_2d(ggml_backend_cann_context& ctx, - aclTensor* acl_input, aclTensor* acl_weight, - aclTensor* acl_dst) { - int8_t cube_math_type = 2; - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMmGetWorkspaceSize(acl_input, acl_weight, acl_dst, - cube_math_type, &workspaceSize, - &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnMm(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Performs matrix multiplication of two 3D tensors. - * - * This function computes the matrix multiplication of the input tensor - * `acl_input` and the weight tensor `acl_weight`, and stores the result in the - * destination tensor `acl_dst`. - * The operation is defined as: - * \f[ - * \text {acl_dst}=\text {acl_input@acl_weight} - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_input The input tensor for the matrix multiplication. - * @param acl_weight The weight tensor for the matrix multiplication. - * @param acl_dst The destination tensor where the result of the matrix - * multiplication will be stored. - */ -static void aclnn_mat_mul_3d(ggml_backend_cann_context& ctx, - aclTensor* acl_input, aclTensor* acl_weight, - aclTensor* acl_dst) { - int8_t cube_math_type = 2; - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnBatchMatMulGetWorkspaceSize(acl_input, acl_weight, acl_dst, - cube_math_type, &workspaceSize, - &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(RepeatInterleaveIntWithDim, acl_src, repeats, dim, + output_size, acl_dst); } /** @@ -2654,13 +1855,15 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, switch (n_dims) { case 2: - aclnn_mat_mul_2d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + GGML_CANN_CALL_ACLNN_OP(Mm, acl_input_tensor, acl_weight_tensor, acl_dst, 2); break; case 3: - aclnn_mat_mul_3d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + GGML_CANN_CALL_ACLNN_OP(BatchMatMul, acl_input_tensor, acl_weight_tensor, acl_dst, 2); break; default: - aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + // ALLOW_FP32_DOWN_PRECISION, when input is + // fp32, atlas a2 will transpose it to HFLOAT32. 
+ GGML_CANN_CALL_ACLNN_OP(Matmul, acl_input_tensor, acl_weight_tensor, acl_dst, 1); break; } @@ -2753,9 +1956,6 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, int64_t max_elem_size = 65535; int64_t split_size = (src0->ne[1] / max_elem_size) + 1; ggml_cann_pool_alloc workspace_allocator(ctx.pool()); - aclOpExecutor* executor = nullptr; - uint64_t workspaceSize = 0; - void* workspaceAddr = nullptr; for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) { for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) { int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]); @@ -2794,17 +1994,10 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, if (src0->ne[0] > QK8_0) { antiquantGroupSize = QK8_0; } - - ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize( - acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr, - nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor, - &workspaceSize, &executor)); - if (workspaceAddr == nullptr) { - workspaceAddr = workspace_allocator.alloc(workspaceSize); - } - ACL_CHECK(aclnnWeightQuantBatchMatmulV2( - workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(WeightQuantBatchMatmulV2, acl_input_tensor, + acl_weight_tensor, acl_scale_tensor, nullptr, + nullptr, nullptr, nullptr, antiquantGroupSize, + acl_output_tensor); ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); ACL_CHECK(aclDestroyTensor(acl_scale_tensor)); ACL_CHECK(aclDestroyTensor(acl_output_tensor)); @@ -2834,14 +2027,10 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16, output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, output_ne_offset); - - ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize( - acl_input_tensor, acl_weight_tensor, acl_scale_tensor, - nullptr, nullptr, nullptr, nullptr, antiquantGroupSize, - acl_output_tensor, &workspaceSize, &executor)); - ACL_CHECK(aclnnWeightQuantBatchMatmulV2( - workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(WeightQuantBatchMatmulV2, acl_input_tensor, + acl_weight_tensor, acl_scale_tensor, nullptr, + nullptr, nullptr, nullptr, antiquantGroupSize, + acl_output_tensor); ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); ACL_CHECK(aclDestroyTensor(acl_scale_tensor)); ACL_CHECK(aclDestroyTensor(acl_output_tensor)); @@ -2864,8 +2053,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne, output_cast_nb, GGML_MAX_DIMS); aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, - ggml_cann_type_mapping(dst->type)); + aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, dst); ACL_CHECK(aclDestroyTensor(acl_output_tensor)); ACL_CHECK(aclDestroyTensor(acl_dst_tensor)); @@ -2884,7 +2072,7 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_mul_mat_quant(ctx, dst, type); break; default: - GGML_ABORT("fatal error"); + GGML_ABORT("Unsupported type for mul_mat"); break; } } @@ -2909,20 +2097,7 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, int64_t* shifts, int64_t* dims) { aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1); aclIntArray* acl_dims = aclCreateIntArray(dims, 1); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnRollGetWorkspaceSize(acl_src, acl_shifts, acl_dims, acl_dst, - &workspaceSize, &executor)); - 
if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnRoll(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(Roll, acl_src, acl_shifts, acl_dims, acl_dst); ACL_CHECK(aclDestroyIntArray(acl_shifts)); ACL_CHECK(aclDestroyIntArray(acl_dims)); } @@ -2946,21 +2121,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, float value) { aclIntArray* acl_index = aclCreateIntArray(index, index_num); aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceIndexFillTensorGetWorkspaceSize( - acl_src, dim, acl_index, acl_value, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceIndexFillTensor(workspaceAddr, workspaceSize, - executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value); ACL_CHECK(aclDestroyIntArray(acl_index)); ACL_CHECK(aclDestroyScalar(acl_value)); } @@ -3019,8 +2180,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( src2->data, ggml_cann_type_mapping(src2->type), ggml_type_size(src2->type), arange_ne, arange_nb, GGML_MAX_DIMS); - aclnn_div_tensor(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, - nullptr, true); + aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor); ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor)); } @@ -3137,7 +2297,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // TODO: use ascendc // Only test with LLAMA model. ggml_tensor* src0 = dst->src[0]; // input - ggml_tensor* src2 = dst->src[2]; // freq_factors + // ggml_tensor* src2 = dst->src[2]; // freq_factors, not used now. 
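The `aclnn_roll` helper simplified above keeps the semantics of the `Roll` operator: values are rotated circularly along the listed dimensions, which `ggml_cann_rope` uses below to pair each element with its rotated counterpart before the sin/cos multiplies. For a single dimension the effect is (illustrative scalar sketch, not code from the patch):

```cpp
#include <cstdint>

// Circular rotation of a 1-D array by `shift`, mirroring what
// aclnn_roll(ctx, acl_src, acl_dst, shifts, dims) computes per dimension.
static void roll_1d(const float* src, float* dst, int64_t n, int64_t shift) {
    shift = ((shift % n) + n) % n;         // normalize shift into [0, n)
    for (int64_t i = 0; i < n; ++i) {
        dst[(i + shift) % n] = src[i];     // element i moves forward by shift
    }
}
```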
// param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; @@ -3319,8 +2479,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // output void* output_fp32_buffer; if (src0->type == GGML_TYPE_F32) { - aclnn_inplace_mul(ctx, acl_src, acl_cos_reshape_tensor); - aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor, + aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor); + aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor); aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst); // TODO: ne0 != n_dims in mode2 @@ -3393,39 +2553,35 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_type_size(src0->type), sin_final_ne, sin_final_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor, - ggml_cann_type_mapping(src0->type)); - aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor, - ggml_cann_type_mapping(src0->type)); + aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor, dst); + aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor, dst); ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); acl_sin_reshape_tensor = acl_sin_final_tensor; acl_cos_reshape_tensor = acl_cos_final_tensor; } - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - - void* workspaceAddr = nullptr; - int acl_mode = mode; if (mode == 0) { acl_mode = 1; } - ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize( - acl_src, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, - acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize, - executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor, + acl_sin_reshape_tensor, acl_mode, acl_dst); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); ACL_CHECK(aclDestroyTensor(acl_dst)); } + + + void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + + aclTensor* acl_src = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ArgMax, acl_src, 3, false, acl_dst); + ACL_CHECK(aclDestroyTensor(acl_src)); + ACL_CHECK(aclDestroyTensor(acl_dst)); +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 680129c76de68..116ddf0fb307b 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -31,20 +31,25 @@ * IN THE SOFTWARE. */ -#include +#include +#include +#include #include #include #include #include -#include #include +#include +#include #include #include #include -#include #include #include #include +#include +#include +#include #include "acl_tensor.h" #include "common.h" @@ -63,23 +68,6 @@ */ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst); -/** - * @brief Adds two ggml tensors using the CANN backend. - * - * @details This function performs an element-wise addition of two tensors. 
In - * case the tensors do not have the same shape, one or both tensors - * will be broadcasted to match the shape of the other before the - * addition is performed.The formula for the operation is given by: - * \f[ - * \text{dst} = \text{acl_src0} + \alpha \cdot \text{acl_src1} - * \f] - * - * @param ctx The CANN context used for operations. - * @param dst The ggml tensor representing the destination, result of the - * addition is stored at dst->data, and dst->op is `GGML_OP_ADD` - */ -void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst); - /** * @brief Applies the Leaky ReLU activation function to a tensor using the CANN * backend. @@ -131,19 +119,6 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst); -/** - * @brief Computes the square of the elements of a ggml tensor using the CANN - * backend. - * @details The function sets the second source tensor of the destination - * tensor `dst` to be equal to the first source tensor. This is - * effectively squaring the elements since the multiplication becomes - * `element * element`. - * @param ctx The CANN context used for operations. - * @param dst The destination tensor where the squared values will be stored, - * which dst->op is `GGML_OP_SQR`. - */ -void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst); - /** * @brief Applies a clamp operation to the elements of a ggml tensor using the * CANN backend. @@ -275,6 +250,20 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); +/** + * @brief Computes the sum of elements in a ggml tensor. + * + * @details This function performs a reduction sum operation along the last + * dimension of the input tensor `src`. The result of the sum is stored + * in the destination tensor `dst`. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the reduced values will be stored. + */ +void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /** * @brief Upsamples a ggml tensor using nearest neighbor interpolation using * the CANN backend. @@ -484,109 +473,263 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst); -template -void ggml_cann_mul_div(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +/** + * @brief Computes the index of the maximum value along the specified dimension + * of a ggml tensor using the CANN backend. + * + * @details This function performs an argmax operation on the input tensor. + * It finds the index of the maximum value along the specified axis + * and stores these indices in the destination tensor `dst`. The + * operation is executed using the CANN backend for optimized performance. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the indices of the maximum values will be stored. + * dst->op is `GGML_OP_ARGMAX`. + */ +void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Adds two tensors element-wise and stores the result in a destination + * tensor. + * + * This function performs the operation: + * \f[ + * dst = acl\_src0 + alpha \times acl\_src1 + * \f] + * where alpha is a scalar value and defaults to 1.0f. + * + * @param ctx The context for the CANN backend operations. 
+ * @param acl_src0 The first source tensor. + * @param acl_src1 The second source tensor. + * @param acl_dst The destination tensor where the result will be stored; + * if it is null, the result is written in place to `acl_src0`. + */ +void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, + aclTensor* acl_src1, aclTensor* acl_dst = nullptr); + +/** + * @brief Subtracts two tensors element-wise and stores the result in a destination + * tensor. + * + * This function performs the operation: + * \f[ + * dst = acl\_src0 - alpha \times acl\_src1 + * \f] + * where alpha is a scalar value and defaults to 1.0f. + * + * @param ctx The context for the CANN backend operations. + * @param acl_src0 The first source tensor. + * @param acl_src1 The second source tensor. + * @param acl_dst The destination tensor where the result will be stored; + * if it is null, the result is written in place to `acl_src0`. + */ +void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0, + aclTensor* acl_src1, aclTensor* acl_dst = nullptr); + +/** + * @brief Performs element-wise multiplication of two tensors and stores the + * result in a destination tensor. + * + * This function performs element-wise multiplication of the tensors `acl_src` + * and `acl_other` and stores the result in the destination tensor `acl_dst`. + * The operation is defined as: + * \f[ + * \text {acl_dst }_i=\text {acl_src }_i \times \text {acl_other }_i + * \f] + * + * @param ctx The context for the CANN backend operations. + * @param acl_src The first tensor for element-wise multiplication. + * @param acl_other The second tensor for element-wise multiplication. + * @param acl_dst The destination tensor where the result will be stored; + * if it is null, the result is written in place to `acl_src`. + */ +void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_other, aclTensor* acl_dst = nullptr); + +/** + * @brief Element-wise division, optionally in-place. + * + * This function divides each element of the source tensor `acl_src` by the + * tensor `acl_other` and stores the result in the destination tensor + * `acl_dst`. If `acl_dst` is null, the operation is performed in place on + * `acl_src`. The operation is defined as: \f[ + * \text{dst}_i = \frac{\text{acl_src}_i}{\text{acl_other}_i} + * \f] + * + * @param ctx The context for the CANN backend operations. + * @param acl_src Numerator tensor. + * @param acl_other Denominator tensor. + * @param acl_dst The destination tensor where the result will be stored; + * may be null for an in-place operation. + */ +void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_other, aclTensor* acl_dst = nullptr); + +/** + * @brief Applies element-wise cosine function to the elements of a tensor. + * + * This function computes the cosine of each element in the source tensor + * `acl_src` and stores the result in the destination tensor `acl_dst`. The + * operation is defined as: \f[ \text {acl_dst }_i=\cos \left(\text {acl_src + * }_i\right) \f] + * + * @param ctx The context for the CANN backend operations. + * @param acl_src The source tensor on which the cosine function will be + * applied. + * @param acl_dst The destination tensor where the cosine results will be + * stored. + */ +void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_dst); + +/** + * @brief Applies element-wise sine function to the elements of a tensor. + * + * This function computes the sine of each element in the source tensor + * `acl_src` and stores the result in the destination tensor `acl_dst`.
+ * The operation is defined as: + * \f[ + * \text {acl_dst }_i=\sin \left(\text {acl_src }_i\right) + * \f] + * + * @param ctx The context for the CANN backend operations. + * @param acl_src The source tensor on which the sine function will be applied. + * @param acl_dst The destination tensor where the sine results will be stored. + */ +void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_dst); + +/** + * @brief Launches an asynchronous task using the memory allocator. + * + * This macro submits an asynchronous task on the specified stream. + * The task uses memory allocated by the allocator. It is guaranteed + * that the memory will not be accessed by other tasks until this task + * completes, due to the sequential execution order within the same stream. + * + * @param OP_NAME aclnn operator name. + * @param args Additional arguments required by the task. + * + * @note + * Memory from the allocator will be "freed" immediately and can be + * reallocated to other pointers. However, it won't be accessed by any + * other task before this asynchronous task ends, because all tasks in the + * same stream are executed in queue order. + */ +#define GGML_CANN_CALL_ACLNN_OP(OP_NAME, ...) \ + do { \ + uint64_t workspaceSize = 0; \ + aclOpExecutor * executor; \ + void * workspaceAddr = nullptr; \ + \ + ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \ + \ + if (workspaceSize > 0) { \ + ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); \ + workspaceAddr = workspace_allocator.get(); \ + } \ + ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, ctx.stream())); \ + } while (0) + + +/** + * @brief Prepares broadcast-compatible ACL tensors for two input tensors and one output tensor. + * + * This function checks whether broadcasting is needed between `src0` and `src1`. + * If broadcasting is required, it calculates the proper shapes and creates + * ACL tensors with broadcast parameters. Otherwise, it directly creates ACL tensors + * based on the original tensor shapes. + * + * @param src0 The first input tensor (reference shape). + * @param src1 The second input tensor (possibly broadcasted). + * @param dst The destination/output tensor. + * @param acl_src0 Output pointer to the created ACL tensor corresponding to src0. + * @param acl_src1 Output pointer to the created ACL tensor corresponding to src1. + * @param acl_dst Output pointer to the created ACL tensor corresponding to dst. + */ +void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0, + aclTensor ** acl_src1, aclTensor ** acl_dst); + +/** + * @brief Applies an element-wise operation to two input tensors using the CANN backend. + * + * This templated function takes a binary operator and applies it to two source tensors + * associated with the destination tensor. The function handles broadcasting as needed. + * + * @tparam binary_op A callable object (e.g., lambda or function pointer) representing + * the binary operation to be performed. It must take four arguments: + * (ggml_backend_cann_context&, aclTensor*, aclTensor*, aclTensor*). + * + * @param ctx The CANN backend context used to manage execution and resources. + * @param dst The destination tensor.
+ */ +template +void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src0 = dst->src[0]; ggml_tensor* src1 = dst->src[1]; - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); aclTensor* acl_src0; aclTensor* acl_src1; aclTensor* acl_dst; // Need bcast - if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) { - BCAST_SHAPE(src0, src1) - acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0)); - acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1)); - acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0)); - } else { - acl_src0 = ggml_cann_create_tensor(src0); - acl_src1 = ggml_cann_create_tensor(src1); - acl_dst = ggml_cann_create_tensor(dst); - } - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(getWorkspaceSize(acl_src0, acl_src1, acl_dst, &workspaceSize, - &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - aclrtStream main_stream = ctx.stream(); - ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream)); + bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst); + binary_op(ctx, acl_src0, acl_src1, acl_dst); ACL_CHECK(aclDestroyTensor(acl_src0)); ACL_CHECK(aclDestroyTensor(acl_src1)); ACL_CHECK(aclDestroyTensor(acl_dst)); } -// Activation functions template. -template -void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +/** + * @brief Applies a unary operation to an input tensor using the CANN backend. + * + * This templated function applies a unary operator to the source tensor of `dst` + * and stores the result in the destination tensor. + * + * @tparam unary_op A callable with the signature: + * void(ggml_backend_cann_context&, aclTensor*, aclTensor*) + * where the first aclTensor is the source and the second is the destination. + * + * @param ctx The CANN backend context for managing resources and execution. + * @param dst The destination tensor. Its src[0] is treated as the input tensor. + */ +template + void ggml_cann_unary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src = dst->src[0]; - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - aclTensor* acl_src = ggml_cann_create_tensor(src); aclTensor* acl_dst = ggml_cann_create_tensor(dst); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - aclrtStream main_stream = ctx.stream(); - ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream)); - + unary_op(ctx, acl_src, acl_dst); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); } -// Activation functions template for const aclTensors. 
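The two templates above replace the old `ggml_cann_mul_div`/`ggml_cann_activation` helpers, which were parameterized directly on `...GetWorkspaceSize`/execute function-pointer pairs; callers now supply any callable with the documented tensor signature. A sketch of the dispatch as used later in `ggml_cann_compute_forward`, assuming the `aclnn_*` wrappers declared above are passed as the template arguments:

```cpp
// Inside the ggml_cann_compute_forward dispatch switch (sketch):
switch (dst->op) {
    case GGML_OP_MUL:                          // broadcastable binary op
        ggml_cann_binary_op<aclnn_mul>(ctx, dst);
        break;
    case GGML_OP_SQR:                          // x*x through the same path:
        dst->src[1] = dst->src[0];             // alias src1 to src0, then mul
        ggml_cann_binary_op<aclnn_mul>(ctx, dst);
        break;
    case GGML_OP_COS:                          // unary ops use the 2-tensor form
        ggml_cann_unary_op<aclnn_cos>(ctx, dst);
        break;
    default:
        break;
}
```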
-template -void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; - - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - aclrtStream main_stream = ctx.stream(); - ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream)); - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); -} +/** + * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op. + * + * This macro defines an inline lambda wrapping a specific ACL operation name, + * and passes it to the templated ggml_cann_unary_op function. It simplifies + * calling unary ops by hiding the lambda boilerplate. + * + * Internally, the lambda will call: + * @code + * GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst); + * @endcode + * + * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP. + * + * @see ggml_cann_unary_op + * @see GGML_CANN_CALL_ACLNN_OP + */ +#define GGML_CANN_CALL_UNARY_OP(OP_NAME) \ + do { \ + auto lambda = [](auto ctx, auto acl_src, auto acl_dst) { \ + GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst); \ + }; \ + ggml_cann_unary_op(ctx, dst); \ + } \ + while (0) #endif // CANN_ACLNN_OPS diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 68cd9920d1ace..bcd05c9ae75de 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1300,47 +1300,59 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_dup(ctx, dst); break; case GGML_OP_ADD: - ggml_cann_add(ctx, dst); + case GGML_OP_ADD1: + ggml_cann_binary_op(ctx, dst); + break; + case GGML_OP_SUB: + ggml_cann_binary_op(ctx, dst); break; case GGML_OP_ACC: ggml_cann_acc(ctx, dst); break; case GGML_OP_MUL: - ggml_cann_mul_div(ctx, dst); + ggml_cann_binary_op(ctx, dst); break; case GGML_OP_DIV: - ggml_cann_mul_div(ctx, dst); + ggml_cann_binary_op(ctx, dst); break; case GGML_OP_UNARY: switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_ABS: + GGML_CANN_CALL_UNARY_OP(Abs); + break; + case GGML_UNARY_OP_NEG: + GGML_CANN_CALL_UNARY_OP(Neg); + break; case GGML_UNARY_OP_GELU: - ggml_cann_activation( - ctx, dst); + GGML_CANN_CALL_UNARY_OP(Gelu); break; case GGML_UNARY_OP_SILU: - ggml_cann_activation( - ctx, dst); + GGML_CANN_CALL_UNARY_OP(Silu); break; - // TODO: Use faster gelu?? 
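Substituting a concrete operator into the helper macro defined above shows how the two macro layers cooperate: the outer macro builds a lambda with the required `(ctx, src, dst)` shape, and the inner `GGML_CANN_CALL_ACLNN_OP` performs the workspace query and asynchronous launch. A sketch of `GGML_CANN_CALL_UNARY_OP(Relu)` after one level of expansion:

```cpp
// GGML_CANN_CALL_UNARY_OP(Relu), expanded per the macro body above:
do {
    auto lambda = [](auto ctx, auto acl_src, auto acl_dst) {
        // inner layer: aclnnReluGetWorkspaceSize(...) + pool allocation
        // + aclnnRelu(...) launched on ctx.stream()
        GGML_CANN_CALL_ACLNN_OP(Relu, acl_src, acl_dst);
    };
    ggml_cann_unary_op(ctx, dst);  // lambda forwarded per the header template
} while (0);
```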
- case GGML_UNARY_OP_GELU_QUICK: - ggml_cann_activation( - ctx, dst); + case GGML_UNARY_OP_GELU_QUICK: { + auto lambda = [](auto ctx, auto acl_src, auto acl_dst) { + GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst); + }; + ggml_cann_unary_op(ctx, dst); + } break; case GGML_UNARY_OP_TANH: - ggml_cann_activation( - ctx, dst); + GGML_CANN_CALL_UNARY_OP(Tanh); break; case GGML_UNARY_OP_RELU: - ggml_cann_activation( - ctx, dst); + GGML_CANN_CALL_UNARY_OP(Relu); + break; + case GGML_UNARY_OP_SIGMOID: + GGML_CANN_CALL_UNARY_OP(Sigmoid); break; case GGML_UNARY_OP_HARDSIGMOID: - ggml_cann_activation(ctx, dst); + GGML_CANN_CALL_UNARY_OP(Hardsigmoid); break; case GGML_UNARY_OP_HARDSWISH: - ggml_cann_activation(ctx, dst); + GGML_CANN_CALL_UNARY_OP(Hardswish); + break; + case GGML_UNARY_OP_EXP: + GGML_CANN_CALL_UNARY_OP(Exp); break; default: return false; @@ -1382,7 +1394,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_scale(ctx, dst); break; case GGML_OP_SQR: - ggml_cann_sqr(ctx, dst); + GGML_ASSERT(dst->src[1] == nullptr); + dst->src[1] = dst->src[0]; + ggml_cann_binary_op(ctx, dst); + break; + case GGML_OP_SQRT: + GGML_CANN_CALL_UNARY_OP(Sqrt); break; case GGML_OP_CLAMP: ggml_cann_clamp(ctx, dst); @@ -1414,12 +1431,24 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_POOL_2D: ggml_cann_pool2d(ctx, dst); break; + case GGML_OP_SUM: + ggml_cann_sum(ctx, dst); + break; case GGML_OP_SUM_ROWS: ggml_cann_sum_rows(ctx, dst); break; case GGML_OP_ARGSORT: ggml_cann_argsort(ctx, dst); break; + case GGML_OP_ARGMAX: + ggml_cann_argmax(ctx, dst); + break; + case GGML_OP_COS: + ggml_cann_unary_op(ctx, dst); + break; + case GGML_OP_SIN: + ggml_cann_unary_op(ctx, dst); + break; default: return false; } @@ -1458,11 +1487,6 @@ static void ggml_backend_cann_free(ggml_backend_t backend) { ACL_CHECK(aclrtSynchronizeDevice()); ACL_CHECK(aclrtResetDevice(cann_ctx->device)); - // finalize when last backend freed. - if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) { - ACL_CHECK(aclFinalize()); - } - delete cann_ctx; delete backend; } @@ -1675,24 +1699,31 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_EXP: return true; default: return false; } case GGML_OP_MUL_MAT: { switch (op->src[0]->type) { - case GGML_TYPE_Q8_0: case GGML_TYPE_F16: case GGML_TYPE_F32: - case GGML_TYPE_Q4_0: return true; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + // only support contiguous for quantized types. 
+                    // quantized types are only supported for contiguous input tensors.
+                    return ggml_is_contiguous(op->src[0]) &&
+                           ggml_is_contiguous(op->src[1]);
                 default:
                     return false;
             }
@@ -1704,7 +1735,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             switch (op->src[0]->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
-                case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q8_0:
                     return true;
                 default:
@@ -1712,16 +1742,21 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
         } break;
         case GGML_OP_CPY: {
-            switch (op->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                case GGML_TYPE_Q8_0:
-                case GGML_TYPE_Q4_0:
-                    return true;
-                default:
-                    return false;
+            ggml_tensor *src = op->src[0];
+            if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
+                (src->type != GGML_TYPE_F32 &&
+                 src->type != GGML_TYPE_F16)) {
+                // only F32 and F16 are supported.
+                return false;
             }
-        }
+
+            if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
+                // dst must be contiguous when the shapes differ.
+                return false;
+            }
+
+            return true;
+        } break;
         case GGML_OP_CONT: {
             // TODO: support GGML_TYPE_BF16
             switch (op->src[0]->type) {
@@ -1734,13 +1769,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
         case GGML_OP_ROPE: {
             // TODO: with ops-test v == 1
-            float * ext_factor = (float*)((int32_t*)op->op_params + 7);
+            float ext_factor = 0.0f;
+            memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
             // TODO: n_dims <= ne0
             if (op->src[0]->ne[0] != op->op_params[1]) {
                 return false;
             }
             // TODO: ext_factor != 0
-            if (*ext_factor != 0) {
+            if (ext_factor != 0) {
                 return false;
             }
@@ -1762,9 +1798,20 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
             return true;
         }
+        case GGML_OP_POOL_2D: {
+            const int32_t * opts = (const int32_t *) op->op_params;
+            const int k0 = opts[1];
+            const int k1 = opts[2];
+            const int p0 = opts[5];
+            const int p1 = opts[6];
+            // paddingH must be at most half of kernelH, and
+            // paddingW must be at most half of kernelW.
+            return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
+        }
+        case GGML_OP_SUM:
+        case GGML_OP_DUP:
         case GGML_OP_IM2COL:
         case GGML_OP_CONCAT:
-        case GGML_OP_DUP:
         case GGML_OP_REPEAT:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -1773,15 +1820,17 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
-        case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
@@ -1790,6 +1839,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_ARGMAX:
+        case GGML_OP_COS:
+        case GGML_OP_SIN:
             return true;
         default:
             return false;
diff --git a/ggml/src/ggml-cann/kernels/CMakeLists.txt b/ggml/src/ggml-cann/kernels/CMakeLists.txt
deleted file mode 100644
index d687220c3c57e..0000000000000
--- a/ggml/src/ggml-cann/kernels/CMakeLists.txt
+++ /dev/null
@@ -1,30 +0,0 @@
-file(GLOB SRC_FILES
-    get_row_f32.cpp
-    get_row_f16.cpp
-    get_row_q4_0.cpp
-    get_row_q8_0.cpp
-    quantize_f32_q8_0.cpp
-    quantize_f16_q8_0.cpp
-    quantize_float_to_q4_0.cpp
-    dup.cpp
-)
-
-set(ASCEND_CANN_PACKAGE_PATH ${CANN_INSTALL_DIR})
-set(RUN_MODE "npu" CACHE STRING "run mode: npu/sim")
-
-if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
-    set(ASCENDC_CMAKE_DIR
${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake) -elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake) - set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake) -else() - message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the compiler package is installed.") -endif() -include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) - -ascendc_library(ascendc_kernels STATIC - ${SRC_FILES} -) - -message(STATUS "CANN: compile ascend kernels witch SOC_TYPE:${SOC_TYPE}, SOC_VERSION:${SOC_VERSION}, compile macro:-D${SOC_TYPE_COMPILE_OPTION}.") -ascendc_compile_definitions(ascendc_kernels PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}") -# ascendc_compile_definitions(ascendc_kernels PRIVATE -DASCENDC_DUMP) diff --git a/ggml/src/ggml-cann/kernels/ascendc_kernels.h b/ggml/src/ggml-cann/kernels/ascendc_kernels.h deleted file mode 100644 index 7e153208cfdbc..0000000000000 --- a/ggml/src/ggml-cann/kernels/ascendc_kernels.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef ASCENDC_KERNELS_H -#define ASCENDC_KERNELS_H - -#include "aclrtlaunch_ascendc_get_row_f32.h" -#include "aclrtlaunch_ascendc_get_row_f16.h" -#include "aclrtlaunch_ascendc_get_row_q8_0.h" -#include "aclrtlaunch_ascendc_get_row_q4_0.h" - -#include "aclrtlaunch_ascendc_quantize_f32_q8_0.h" -#include "aclrtlaunch_ascendc_quantize_f16_q8_0.h" -#include "aclrtlaunch_ascendc_quantize_f16_to_q4_0.h" -#include "aclrtlaunch_ascendc_quantize_f32_to_q4_0.h" - -#include "aclrtlaunch_ascendc_dup_by_rows_fp16.h" -#include "aclrtlaunch_ascendc_dup_by_rows_fp32.h" -#include "aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16.h" -#include "aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32.h" - -#endif // ASCENDC_KERNELS_H diff --git a/ggml/src/ggml-cann/kernels/dup.cpp b/ggml/src/ggml-cann/kernels/dup.cpp deleted file mode 100644 index d9b9574494b72..0000000000000 --- a/ggml/src/ggml-cann/kernels/dup.cpp +++ /dev/null @@ -1,234 +0,0 @@ -#include "kernel_operator.h" - -using namespace AscendC; - -#define BUFFER_NUM 2 -const int64_t SUPPORTED_MAX_DIM = 65535; // currently the limit of max block dim supportted by dup kernel is 65535template - -template -class DupByRows { - public: - __aicore__ inline DupByRows() {} - __aicore__ inline void init(GM_ADDR src, GM_ADDR dst, int64_t *input_ne_ub, - size_t *input_nb_ub) { - /* Dup by rows when src is contigous on first dimension and dst is - contiguous, each kernel process one row. - */ - - // Input has four dims. 
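The init body that follows decomposes the flat block index into (ne1, ne2, ne3) row coordinates; this is the standard mixed-radix decomposition. A self-contained sketch of the same arithmetic:

```cpp
#include <cassert>
#include <cstdint>

struct Coord3 { int64_t i1, i2, i3; };

// Recover 3D row coordinates from a flat index, given extents ne1 and ne2,
// exactly as the deleted DupByRows::init does with its block index.
static Coord3 unflatten(int64_t idx, int64_t ne1, int64_t ne2) {
    Coord3 c;
    c.i3 = idx / (ne1 * ne2);
    c.i2 = (idx - c.i3 * ne1 * ne2) / ne1;
    c.i1 = idx - c.i3 * ne1 * ne2 - c.i2 * ne1;
    return c;
}

int main() {
    // ne = {*, 3, 4, 5}: 60 rows, one kernel block per row
    Coord3 c = unflatten(23, 3, 4);
    assert(c.i3 == 1 && c.i2 == 3 && c.i1 == 2);  // 23 = 1*12 + 3*3 + 2
}
```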
- int64_t op_block_num = GetBlockNum(); - int64_t op_block_idx = GetBlockIdx(); - - // param - num_rows = input_ne_ub[1] * input_ne_ub[2] * input_ne_ub[3]; - num_elem = input_ne_ub[0]; - - // index for (ne[1], ne[2], ne[3]): (idx_ne1, idx_ne2, idx_ne3) - idx_ne3 = op_block_idx / (input_ne_ub[1] * input_ne_ub[2]); - idx_ne2 = (op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2])) - / (input_ne_ub[1]); - idx_ne1 = op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2]) - - idx_ne2 * input_ne_ub[1]; - - // src may not contiguous in dim [1,2,3], so stride decited by ne&nb - src_stride = input_nb_ub[3] * idx_ne3 + input_nb_ub[2] * idx_ne2 - + input_nb_ub[1] * idx_ne1; - - // dst is contiguous - dst_stride = op_block_idx * (input_ne_ub[0] * sizeof(DST_T)); - - src_gm.SetGlobalBuffer(reinterpret_cast<__gm__ SRC_T *>(src + - src_stride)); - dst_gm.SetGlobalBuffer(reinterpret_cast<__gm__ DST_T *>(dst + - dst_stride)); - - pipe.InitBuffer(src_queue, BUFFER_NUM, (sizeof(SRC_T) * num_elem + - 32 - 1) / 32 * 32); - pipe.InitBuffer(dst_queue, BUFFER_NUM, (sizeof(DST_T) * num_elem + - 32 - 1) / 32 * 32); - } - - __aicore__ inline void copy_in() { - LocalTensor src_local = src_queue.AllocTensor(); - const size_t elem_per_block = 32 / sizeof(SRC_T); - size_t tail = num_elem % elem_per_block; - size_t cpy_elements_len = tail > 0 ? num_elem + 1 : num_elem; - DataCopy(src_local, src_gm, cpy_elements_len); - src_queue.EnQue(src_local); - } - - __aicore__ inline void copy_out() { - LocalTensor dst_local = dst_queue.DeQue(); -#ifdef ASCEND_310P - const size_t elem_per_block = 32 / sizeof(DST_T); - size_t tail = num_elem % elem_per_block; - size_t len = num_elem & ~(elem_per_block - 1); - if (len > 0) { - DataCopy(dst_gm, dst_local, len); - } - if(tail != 0) { - for (size_t i = tail; i < elem_per_block; i++) { - dst_local[len + i].SetValue(0, 0); - } - SetAtomicAdd(); - DataCopy(dst_gm[len], dst_local[len], elem_per_block); - SetAtomicNone(); - } -#else - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = num_elem * sizeof(DST_T); - DataCopyPad(dst_gm, dst_local, dataCopyParams); -#endif - dst_queue.FreeTensor(dst_local); - } - - __aicore__ inline void dup() { - // main process, copy one row data from src to dst. - copy_in(); - - LocalTensor src_local = src_queue.DeQue(); - LocalTensor dst_local = dst_queue.AllocTensor(); - - int32_t BLOCK_NUM = 32 / sizeof(DST_T); - DataCopy(dst_local, src_local, (num_elem + BLOCK_NUM - 1) - / BLOCK_NUM * BLOCK_NUM); - dst_queue.EnQue(dst_local); - - src_queue.FreeTensor(src_local); - copy_out(); - } - - __aicore__ inline void dup_with_cast() { - // main process, copy one row data from src to dst. - // cast dtype from src to dst. 
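The InitBuffer calls above size their queues with the expression (n + 32 - 1) / 32 * 32, rounding byte counts up to the 32-byte granularity the hardware queues require. An equivalent host-side helper, for reference:

```cpp
#include <cassert>
#include <cstddef>

// Round a byte count up to the next multiple of 32.
static size_t round_up_32(size_t bytes) {
    return (bytes + 32 - 1) / 32 * 32;  // same as (bytes + 31) & ~size_t(31)
}

int main() {
    assert(round_up_32(1)  == 32);
    assert(round_up_32(32) == 32);
    assert(round_up_32(33) == 64);
}
```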
- copy_in(); - - LocalTensor src_local = src_queue.DeQue(); - LocalTensor dst_local = dst_queue.AllocTensor(); - - Cast(dst_local, src_local, RoundMode::CAST_NONE, num_elem); - dst_queue.EnQue(dst_local); - - src_queue.FreeTensor(src_local); - copy_out(); - } - - private: - - TPipe pipe; - GlobalTensor src_gm; - GlobalTensor dst_gm; - - int64_t num_rows; - int64_t num_elem; - int64_t idx_ne3; - int64_t idx_ne2; - int64_t idx_ne1; - int64_t src_stride; - int64_t dst_stride; - - TQue src_queue; - TQue dst_queue; -}; - -template -__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { - auto gm_ptr = (__gm__ uint8_t *)gm; - auto ub_ptr = (uint8_t *)(ub); - for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { - *ub_ptr = *gm_ptr; - } -} - -extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16( - GM_ADDR src_gm, - GM_ADDR dst_gm, - GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, - GM_ADDR output_ne_gm, - GM_ADDR output_nb_gm) { - - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t output_ne_ub[4]; - size_t output_nb_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - copy_to_ub(output_nb_gm, output_nb_ub, 32); - - DupByRows op; - op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); - op.dup(); -} - -extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32( - GM_ADDR src_gm, - GM_ADDR dst_gm, - GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, - GM_ADDR output_ne_gm, - GM_ADDR output_nb_gm) { - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t output_ne_ub[4]; - size_t output_nb_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - copy_to_ub(output_nb_gm, output_nb_ub, 32); - - DupByRows op; - op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); - op.dup(); -} - -extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32_to_fp16( - GM_ADDR src_gm, - GM_ADDR dst_gm, - GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, - GM_ADDR output_ne_gm, - GM_ADDR output_nb_gm) { - - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t output_ne_ub[4]; - size_t output_nb_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - copy_to_ub(output_nb_gm, output_nb_ub, 32); - - DupByRows op; - op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); - op.dup_with_cast(); -} - -extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16_to_fp32( - GM_ADDR src_gm, - GM_ADDR dst_gm, - GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, - GM_ADDR output_ne_gm, - GM_ADDR output_nb_gm) { - - // copy params from gm to ub. - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t output_ne_ub[4]; - size_t output_nb_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - copy_to_ub(output_nb_gm, output_nb_ub, 32); - - DupByRows op; - op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); - op.dup_with_cast(); -} diff --git a/ggml/src/ggml-cann/kernels/get_row_f16.cpp b/ggml/src/ggml-cann/kernels/get_row_f16.cpp deleted file mode 100644 index 416b45104de5b..0000000000000 --- a/ggml/src/ggml-cann/kernels/get_row_f16.cpp +++ /dev/null @@ -1,197 +0,0 @@ -#include "kernel_operator.h" - -// optimize me. Use template to avoid copy code. 
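A recurring pattern in the copy_out routines of these deleted kernels: each row is split into a block-aligned bulk part, copied directly, and a short tail, copied with padding (or via the atomic-add workaround on 310P). A host-side sketch of just the split arithmetic:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    const size_t elem_per_block = 32 / sizeof(float);  // 8 floats per block
    size_t len  = 37;                                  // elements in the row
    size_t tail = len % elem_per_block;                // 5 leftover elements
    size_t body = len & ~(elem_per_block - 1);         // 32 aligned elements
    std::printf("bulk copy %zu elems, padded copy %zu elems\n", body, tail);
    return 0;
}
```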
-using namespace AscendC; - -#define BUFFER_NUM 2 - -class GET_ROW_F16 { - public: - __aicore__ inline GET_ROW_F16() {} - __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output, - int64_t *input_ne_ub, size_t *input_nb_ub, - int64_t *indices_ne_ub, size_t *indices_nb_ub, - int64_t *output_ne_ub, size_t *output_nb_ub) { - // TODO, use template for F16/f32 - int64_t op_block_num = GetBlockNum(); - op_block_idx = GetBlockIdx(); - - for (int i = 0; i < 4; i++) { - input_ne[i] = input_ne_ub[i]; - input_stride[i] = input_nb_ub[i] / input_nb_ub[0]; - - indices_ne[i] = indices_ne_ub[i]; - indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0]; - - output_ne[i] = output_ne_ub[i]; - output_stride[i] = output_nb_ub[i] / output_nb_ub[0]; - } - - // Indices has two dims. n_elements = all rows should get. - // dr, all rows should this thread get. - uint64_t n_elements = - indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3]; - dr = n_elements / op_block_num; - - uint64_t tails = n_elements % op_block_num; - if (op_block_idx < tails) { - dr += 1; - ir = dr * op_block_idx; - } else { - ir = dr * op_block_idx + tails; - } - - input_gm.SetGlobalBuffer((__gm__ half *)input); - indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices); - output_gm.SetGlobalBuffer((__gm__ float *)output); - - uint64_t input_local_buffer_size = ((input_ne[0] * sizeof(half) + 31) - & ~31); - uint64_t output_local_buffer_size = ((input_ne[0] * sizeof(float) + 31) - & ~31); - - local_buffer_elems = input_local_buffer_size / sizeof(half); - - // TODO, consider long row that can't put in UB. - // All data should asign to 32. It's ok because all data is align to 32. - pipe.InitBuffer(input_queue, BUFFER_NUM, input_local_buffer_size); - pipe.InitBuffer(output_queue, BUFFER_NUM, output_local_buffer_size); - } - - __aicore__ inline void copy_in(uint32_t offset, size_t len) { - size_t origin_len = len; - LocalTensor input_local = input_queue.AllocTensor(); - const size_t elem_per_block = 32 / sizeof(half); - size_t tail = len % elem_per_block; - len = len & ~(elem_per_block - 1); - if(tail != 0) { - len += elem_per_block; - } - DataCopy(input_local, input_gm[offset], len); - input_queue.EnQue(input_local); - } - - __aicore__ inline void copy_out(uint32_t offset, size_t len) { - LocalTensor output_local = output_queue.DeQue(); - const size_t elem_per_block = 32 / sizeof(float); - size_t tail = len % elem_per_block; - len = len & ~(elem_per_block - 1); - if (len > 0) { - DataCopy(output_gm[offset], output_local, len); - } - - if(tail != 0) { -#ifdef ASCEND_310P - for (size_t i = tail; i < elem_per_block; i++) { - output_local[len + i].SetValue(0, 0); - } - SetAtomicAdd(); - DataCopy(output_gm[offset + len], output_local[len], elem_per_block); - SetAtomicNone(); -#else - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = tail * sizeof(float); - DataCopyPad(output_gm[offset + len], output_local[len], - dataCopyParams); -#endif - } - output_queue.FreeTensor(output_local); - } - - __aicore__ inline void calculate_row(int64_t idx) { - const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]); - const int64_t indices_ne1_idx = - (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) / - indices_ne[0]; - const int64_t indices_ne0_idx = - (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] - - indices_ne1_idx * indices_ne[0]); - - const int64_t indices_offset = indices_ne0_idx * indices_stride[0] + - indices_ne1_idx * indices_stride[1] + - indices_ne2_idx * 
indices_stride[2]; - const int32_t selected_row_idx = indices_gm.GetValue(indices_offset); - - const int64_t input_offset = selected_row_idx * input_stride[1] + - indices_ne1_idx * input_stride[2] + - indices_ne2_idx * input_stride[3]; - - const int64_t output_offset = indices_ne0_idx * output_stride[1] + - indices_ne1_idx * output_stride[2] + - indices_ne2_idx * output_stride[3]; - - copy_in(input_offset, input_ne[0]); - LocalTensor input_local = input_queue.DeQue(); - LocalTensor output_local = output_queue.AllocTensor(); - - Cast(output_local, input_local, RoundMode::CAST_NONE, - local_buffer_elems); - output_queue.EnQue(output_local); - copy_out(output_offset, input_ne[0]); - - input_queue.FreeTensor(input_local); - } - - __aicore__ inline void calculate() { - for (int64_t i = ir; i < ir + dr; i++) { - calculate_row(i); - } - } - - private: - int64_t input_ne[4]; - size_t input_stride[4]; - - int64_t indices_ne[4]; - size_t indices_stride[4]; - - int64_t output_ne[4]; - size_t output_stride[4]; - - size_t local_buffer_elems; - - int64_t ir; - int64_t dr; - - TPipe pipe; - GlobalTensor input_gm; - GlobalTensor indices_gm; - GlobalTensor output_gm; - TQue input_queue; - TQue output_queue; - int64_t op_block_idx; -}; - -template -__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { - auto gm_ptr = (__gm__ uint8_t *)gm; - auto ub_ptr = (uint8_t *)(ub); - for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { - *ub_ptr = *gm_ptr; - } -} - -extern "C" __global__ __aicore__ void ascendc_get_row_f16( - GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, - GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm, - GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t indices_ne_ub[4]; - size_t indices_nb_ub[4]; - int64_t output_ne_ub[4]; - size_t output_nb_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(indices_ne_gm, indices_ne_ub, 32); - copy_to_ub(indices_nb_gm, indices_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - copy_to_ub(output_nb_gm, output_nb_ub, 32); - - GET_ROW_F16 op; - op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub, - indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub); - op.calculate(); -} diff --git a/ggml/src/ggml-cann/kernels/get_row_f32.cpp b/ggml/src/ggml-cann/kernels/get_row_f32.cpp deleted file mode 100644 index 02116905b18e4..0000000000000 --- a/ggml/src/ggml-cann/kernels/get_row_f32.cpp +++ /dev/null @@ -1,190 +0,0 @@ -#include "kernel_operator.h" - -// optimize me. Use template to avoid copy code. -using namespace AscendC; - -#define BUFFER_NUM 2 - -class GET_ROW_F32 { - public: - __aicore__ inline GET_ROW_F32() {} - __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output, - int64_t *input_ne_ub, size_t *input_nb_ub, - int64_t *indices_ne_ub, size_t *indices_nb_ub, - int64_t *output_ne_ub, size_t *output_nb_ub) { - int64_t op_block_num = GetBlockNum(); - op_block_idx = GetBlockIdx(); - - for (int i = 0; i < 4; i++) { - input_ne[i] = input_ne_ub[i]; - input_stride[i] = input_nb_ub[i] / input_nb_ub[0]; - - indices_ne[i] = indices_ne_ub[i]; - indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0]; - - output_ne[i] = output_ne_ub[i]; - output_stride[i] = output_nb_ub[i] / output_nb_ub[0]; - } - - // Indices has two dims. n_elements = all rows should get. - // dr, all rows should this thread get. 
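The dr/ir computation used throughout these deleted get_row kernels is a balanced partition of n rows over the available kernel blocks: the first n % p blocks take one extra row. A standalone version with the same arithmetic:

```cpp
#include <cassert>
#include <cstdint>

// Give worker `idx` of `p` its start row (*ir) and row count (*dr)
// out of n rows total, spreading the remainder over the first workers.
static void split(int64_t n, int64_t p, int64_t idx,
                  int64_t* ir, int64_t* dr) {
    *dr = n / p;
    int64_t tails = n % p;
    if (idx < tails) {
        *dr += 1;
        *ir = *dr * idx;
    } else {
        *ir = *dr * idx + tails;
    }
}

int main() {
    int64_t ir, dr;
    split(10, 4, 0, &ir, &dr); assert(ir == 0 && dr == 3);  // rows 0..2
    split(10, 4, 2, &ir, &dr); assert(ir == 6 && dr == 2);  // rows 6..7
    split(10, 4, 3, &ir, &dr); assert(ir == 8 && dr == 2);  // rows 8..9
}
```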
- uint64_t n_elements = - indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3]; - dr = n_elements / op_block_num; - - uint64_t tails = n_elements % op_block_num; - if (op_block_idx < tails) { - dr += 1; - ir = dr * op_block_idx; - } else { - ir = dr * op_block_idx + tails; - } - - input_gm.SetGlobalBuffer((__gm__ float *)input); - indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices); - output_gm.SetGlobalBuffer((__gm__ float *)output); - - uint64_t local_buffer_size = ((input_ne[0] * sizeof(float) + 31) & ~31); - local_buffer_elems = local_buffer_size / sizeof(float); - - // TODO, consider long row that can't put in UB. - // All data should asign to 32. It's ok because all data is align to 32. - pipe.InitBuffer(input_queue, BUFFER_NUM, local_buffer_size); - pipe.InitBuffer(output_queue, BUFFER_NUM, local_buffer_size); - } - - __aicore__ inline void copy_in(uint32_t offset, size_t len) { - LocalTensor input_local = input_queue.AllocTensor(); - const size_t elem_per_block = 32 / sizeof(float); - size_t tail = len % elem_per_block; - len = len & ~(elem_per_block - 1); - if(tail != 0) { - len += elem_per_block; - } - DataCopy(input_local, input_gm[offset], len); - input_queue.EnQue(input_local); - } - - __aicore__ inline void copy_out(uint32_t offset, size_t len) { - LocalTensor output_local = output_queue.DeQue(); - const size_t elem_per_block = 32 / sizeof(float); - size_t tail = len % elem_per_block; - len = len & ~(elem_per_block - 1); - if (len > 0) { - DataCopy(output_gm[offset], output_local, len); - } - - if(tail != 0) { -#ifdef ASCEND_310P - for (size_t i = tail; i < elem_per_block; i++) { - output_local[len + i].SetValue(0, 0); - } - SetAtomicAdd(); - DataCopy(output_gm[offset + len], output_local[len], elem_per_block); - SetAtomicNone(); -#else - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = tail * sizeof(float); - DataCopyPad(output_gm[offset + len], output_local[len], - dataCopyParams); -#endif - } - output_queue.FreeTensor(output_local); - } - - __aicore__ inline void calculate_row(int64_t idx) { - const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]); - const int64_t indices_ne1_idx = - (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) / - indices_ne[0]; - const int64_t indices_ne0_idx = - (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] - - indices_ne1_idx * indices_ne[0]); - - const int64_t indices_offset = indices_ne0_idx * indices_stride[0] + - indices_ne1_idx * indices_stride[1] + - indices_ne2_idx * indices_stride[2]; - const int32_t selected_row_idx = indices_gm.GetValue(indices_offset); - - const int64_t input_offset = selected_row_idx * input_stride[1] + - indices_ne1_idx * input_stride[2] + - indices_ne2_idx * input_stride[3]; - - const int64_t output_offset = indices_ne0_idx * output_stride[1] + - indices_ne1_idx * output_stride[2] + - indices_ne2_idx * output_stride[3]; - - copy_in(input_offset, input_ne[0]); - LocalTensor input_local = input_queue.DeQue(); - LocalTensor output_local = output_queue.AllocTensor(); - - DataCopy(output_local, input_local, local_buffer_elems); - output_queue.EnQue(output_local); - copy_out(output_offset, input_ne[0]); - - input_queue.FreeTensor(input_local); - } - - __aicore__ inline void calculate() { - for (int64_t i = ir; i < ir + dr; i++) { - calculate_row(i); - } - } - - private: - int64_t input_ne[4]; - size_t input_stride[4]; - - int64_t indices_ne[4]; - size_t indices_stride[4]; - - int64_t output_ne[4]; - size_t output_stride[4]; - - size_t 
local_buffer_elems; - - int64_t ir; - int64_t dr; - - TPipe pipe; - GlobalTensor input_gm; - GlobalTensor indices_gm; - GlobalTensor output_gm; - TQue input_queue; - TQue output_queue; - int64_t op_block_idx; -}; - -template -__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { - auto gm_ptr = (__gm__ uint8_t *)gm; - auto ub_ptr = (uint8_t *)(ub); - for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { - *ub_ptr = *gm_ptr; - } -} - -extern "C" __global__ __aicore__ void ascendc_get_row_f32( - GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, - GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm, - GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t indices_ne_ub[4]; - size_t indices_nb_ub[4]; - int64_t output_ne_ub[4]; - size_t output_nb_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(indices_ne_gm, indices_ne_ub, 32); - copy_to_ub(indices_nb_gm, indices_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - copy_to_ub(output_nb_gm, output_nb_ub, 32); - - GET_ROW_F32 op; - op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub, - indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub); - op.calculate(); -} diff --git a/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp b/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp deleted file mode 100644 index 4fbe722086cf0..0000000000000 --- a/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +++ /dev/null @@ -1,204 +0,0 @@ -#include "kernel_operator.h" - -// optimize me. Use template to avoid copy code. -using namespace AscendC; -#ifdef ASCEND_310P // 310P not support 4bit get row - extern "C" __global__ __aicore__ void ascendc_get_row_q4_0( - GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, - GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm, - GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { - // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. - printf("Ascend310P not support 4bit get row.\n"); - } -#else - -#define BUFFER_NUM 2 - -#define QK4_0 32 - -class GET_ROW_Q4_0 { - public: - __aicore__ inline GET_ROW_Q4_0() {} - __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output, - int64_t *input_ne_ub, int64_t *indices_ne_ub, - size_t *indices_nb_ub, int64_t *output_ne_ub, - size_t *output_nb_ub) { - int64_t op_block_num = GetBlockNum(); - int64_t op_block_idx = GetBlockIdx(); - - for (int i = 0; i < 4; i++) { - input_ne[i] = input_ne_ub[i]; - indices_ne[i] = indices_ne_ub[i]; - indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0]; - scale_ne[i] = input_ne_ub[i]; - output_ne[i] = output_ne_ub[i]; - output_stride[i] = output_nb_ub[i] / output_nb_ub[0]; - } - - // one scale for a group. - scale_ne[0] /= QK4_0; - - input_stride[0] = 1; - scale_stride[0] = 1; - output_stride[0] = 1; - for (int i = 1; i < 4; i++) { - input_stride[i] = input_stride[i - 1] * input_ne[i - 1]; - scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; - } - - group_size_in_row = input_ne[0] / QK4_0; - int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] * - input_ne[3] / 2; - - // Indices has two dims. n_elements = all rows should get. - // dr, all rows should this thread get. 
- uint64_t n_elements = - indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3]; - dr = n_elements / op_block_num; - - uint64_t tails = n_elements % op_block_num; - if (op_block_idx < tails) { - dr += 1; - ir = dr * op_block_idx; - } else { - ir = dr * op_block_idx + tails; - } - - input_gm.SetGlobalBuffer((__gm__ int4b_t *)input); - scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset)); - indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices); - output_gm.SetGlobalBuffer((__gm__ float *)output); - - pipe.InitBuffer(input_queue, BUFFER_NUM, QK4_0 * sizeof(int4b_t)); - pipe.InitBuffer(cast_queue, BUFFER_NUM, QK4_0 * sizeof(half)); - pipe.InitBuffer(output_queue, BUFFER_NUM, QK4_0 * sizeof(float)); - } - - __aicore__ inline void copy_in(uint32_t offset) { - LocalTensor input_local = input_queue.AllocTensor(); - // 32 * sizeof(int4b_t) = 16, which is not aligned to 32, why no error? - DataCopy(input_local, input_gm[offset], QK4_0); - input_queue.EnQue(input_local); - } - - __aicore__ inline void copy_out(uint32_t offset) { - LocalTensor output_local = output_queue.DeQue(); - DataCopy(output_gm[offset], output_local, QK4_0); - output_queue.FreeTensor(output_local); - } - - __aicore__ inline void calculate_group(int64_t idx, int64_t group) { - const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]); - const int64_t indices_ne1_idx = - (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) / - indices_ne[0]; - const int64_t indices_ne0_idx = - (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] - - indices_ne1_idx * indices_ne[0]); - - const int64_t indices_offset = indices_ne0_idx * indices_stride[0] + - indices_ne1_idx * indices_stride[1] + - indices_ne2_idx * indices_stride[2]; - const int32_t selected_row_idx = indices_gm.GetValue(indices_offset); - - const int64_t input_offset = selected_row_idx * input_stride[1] + - indices_ne1_idx * input_stride[2] + - indices_ne2_idx * input_stride[3] + - group * QK4_0; - const int64_t scale_offset = selected_row_idx * scale_stride[1] + - indices_ne1_idx * scale_stride[2] + - indices_ne2_idx * scale_stride[3] + group; - const int64_t output_offset = indices_ne0_idx * output_stride[1] + - indices_ne1_idx * output_stride[2] + - indices_ne2_idx * output_stride[3] + - group * QK4_0; - - copy_in(input_offset); - LocalTensor input_local = input_queue.DeQue(); - LocalTensor cast_local = cast_queue.AllocTensor(); - LocalTensor output_local = output_queue.AllocTensor(); - - // TODO: cast more data to speed up. - Cast(cast_local, input_local, RoundMode::CAST_NONE, QK4_0); - Cast(output_local, cast_local, RoundMode::CAST_NONE, QK4_0); - - // Only mul need compile by group. 
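The scale_gm base set up earlier points past the packed quants: this kernel assumes a flat layout of all 4-bit values first (two per byte), followed by one half-precision scale per 32-element group. A sketch of the offset arithmetic under that assumption, with illustrative sizes:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const int64_t QK4_0 = 32;             // elements per quant group
    const int64_t ne[4] = {64, 2, 1, 1};  // 64 elements per row, 2 rows
    const int64_t n_elems = ne[0] * ne[1] * ne[2] * ne[3];

    const int64_t data_bytes     = n_elems / 2;     // 4 bits each, packed
    const int64_t groups_per_row = ne[0] / QK4_0;   // 2 groups per row

    // scale for (row r, group g) lives after the data, one half (2 bytes) each
    const int64_t r = 1, g = 1;
    const int64_t scale_byte = data_bytes + (r * groups_per_row + g) * 2;
    assert(data_bytes == 64 && scale_byte == 70);
}
```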
- half scale = scale_gm.GetValue(scale_offset); - - Muls(output_local, output_local, (float)scale, QK4_0); - - input_queue.FreeTensor(input_local); - cast_queue.FreeTensor(cast_local); - output_queue.EnQue(output_local); - - copy_out(output_offset); - } - - __aicore__ inline void calculate() { - for (int64_t i = ir; i < ir + dr; i++) { - for (int64_t j = 0; j < group_size_in_row; j++) { - calculate_group(i, j); - } - } - } - - private: - int64_t input_ne[4]; - size_t input_stride[4]; - - int64_t scale_ne[4]; - size_t scale_stride[4]; - - int64_t indices_ne[4]; - size_t indices_stride[4]; - - int64_t output_ne[4]; - size_t output_stride[4]; - - int64_t ir; - int64_t dr; - - int64_t group_size_in_row; - - TPipe pipe; - GlobalTensor input_gm; - GlobalTensor scale_gm; - GlobalTensor indices_gm; - GlobalTensor output_gm; - TQue input_queue; - TQue output_queue; - TQue cast_queue; -}; - -template -__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { - auto gm_ptr = (__gm__ uint8_t *)gm; - auto ub_ptr = (uint8_t *)(ub); - for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { - *ub_ptr = *gm_ptr; - } -} - -extern "C" __global__ __aicore__ void ascendc_get_row_q4_0( - GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, - GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm, - GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { - int64_t input_ne_ub[4]; - int64_t indices_ne_ub[4]; - size_t indices_nb_ub[4]; - int64_t output_ne_ub[4]; - size_t output_nb_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(indices_ne_gm, indices_ne_ub, 32); - copy_to_ub(indices_nb_gm, indices_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - copy_to_ub(output_nb_gm, output_nb_ub, 32); - - GET_ROW_Q4_0 op; - op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub, - indices_nb_ub, output_ne_ub, output_nb_ub); - op.calculate(); -} - -#endif // #ifdef ASCEND_310P diff --git a/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp b/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp deleted file mode 100644 index ba9ab3c04832f..0000000000000 --- a/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +++ /dev/null @@ -1,191 +0,0 @@ -#include "kernel_operator.h" - -// optimize me. Use template to avoid copy code. -using namespace AscendC; - -#define BUFFER_NUM 2 - -#define QK8_0 32 - -class GET_ROW_Q8_0 { - public: - __aicore__ inline GET_ROW_Q8_0() {} - __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output, - int64_t *input_ne_ub, int64_t *indices_ne_ub, - size_t *indices_nb_ub, int64_t *output_ne_ub, - size_t *output_nb_ub) { - int64_t op_block_num = GetBlockNum(); - int64_t op_block_idx = GetBlockIdx(); - - for (int i = 0; i < 4; i++) { - input_ne[i] = input_ne_ub[i]; - indices_ne[i] = indices_ne_ub[i]; - indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0]; - scale_ne[i] = input_ne_ub[i]; - output_ne[i] = output_ne_ub[i]; - output_stride[i] = output_nb_ub[i] / output_nb_ub[0]; - } - - // one scale for a group. - scale_ne[0] /= QK8_0; - - input_stride[0] = 1; - scale_stride[0] = 1; - output_stride[0] = 1; - for (int i = 1; i < 4; i++) { - input_stride[i] = input_stride[i - 1] * input_ne[i - 1]; - scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; - } - - group_size_in_row = input_ne[0] / QK8_0; - int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] * - input_ne[3] * sizeof(int8_t); - - // Indices has two dims. n_elements = all rows should get. - // dr, all rows should this thread get. 
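Both quantized get_row kernels dequantize the same way: widen each quant through the Cast chain, then multiply the whole group by its scale with Muls. In scalar form, with the quants already unpacked:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int QK4_0 = 32;
    int8_t q[QK4_0];           // signed 4-bit quants, already unpacked
    for (int i = 0; i < QK4_0; ++i) q[i] = (int8_t)((i % 16) - 8);

    float scale = 0.125f;      // per-group scale (stored as half on device)
    float out[QK4_0];
    for (int i = 0; i < QK4_0; ++i) out[i] = (float) q[i] * scale;

    std::printf("%g %g\n", out[0], out[15]);  // prints: -1 0.875
}
```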
- uint64_t n_elements = - indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3]; - dr = n_elements / op_block_num; - - uint64_t tails = n_elements % op_block_num; - if (op_block_idx < tails) { - dr += 1; - ir = dr * op_block_idx; - } else { - ir = dr * op_block_idx + tails; - } - - input_gm.SetGlobalBuffer((__gm__ int8_t *)input); - scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset)); - indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices); - output_gm.SetGlobalBuffer((__gm__ float *)output); - - pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t)); - pipe.InitBuffer(cast_queue, BUFFER_NUM, QK8_0 * sizeof(half)); - pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(float)); - } - - __aicore__ inline void copy_in(uint32_t offset) { - LocalTensor input_local = input_queue.AllocTensor(); - DataCopy(input_local, input_gm[offset], QK8_0); - input_queue.EnQue(input_local); - } - - __aicore__ inline void copy_out(uint32_t offset) { - LocalTensor output_local = output_queue.DeQue(); - DataCopy(output_gm[offset], output_local, QK8_0); - output_queue.FreeTensor(output_local); - } - - __aicore__ inline void calculate_group(int64_t idx, int64_t group) { - const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]); - const int64_t indices_ne1_idx = - (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) / - indices_ne[0]; - const int64_t indices_ne0_idx = - (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] - - indices_ne1_idx * indices_ne[0]); - - const int64_t indices_offset = indices_ne0_idx * indices_stride[0] + - indices_ne1_idx * indices_stride[1] + - indices_ne2_idx * indices_stride[2]; - const int32_t selected_row_idx = indices_gm.GetValue(indices_offset); - - const int64_t input_offset = selected_row_idx * input_stride[1] + - indices_ne1_idx * input_stride[2] + - indices_ne2_idx * input_stride[3] + - group * QK8_0; - const int64_t scale_offset = selected_row_idx * scale_stride[1] + - indices_ne1_idx * scale_stride[2] + - indices_ne2_idx * scale_stride[3] + group; - const int64_t output_offset = indices_ne0_idx * output_stride[1] + - indices_ne1_idx * output_stride[2] + - indices_ne2_idx * output_stride[3] + - group * QK8_0; - - copy_in(input_offset); - LocalTensor input_local = input_queue.DeQue(); - LocalTensor cast_local = cast_queue.AllocTensor(); - LocalTensor output_local = output_queue.AllocTensor(); - - // TODO: cast more data to speed up. - Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0); - Cast(output_local, cast_local, RoundMode::CAST_NONE, QK8_0); - - // Only mul need compile by group. 
- half scale = scale_gm.GetValue(scale_offset); - Muls(output_local, output_local, (float)scale, QK8_0); - - input_queue.FreeTensor(input_local); - cast_queue.FreeTensor(cast_local); - output_queue.EnQue(output_local); - - copy_out(output_offset); - } - - __aicore__ inline void calculate() { - for (int64_t i = ir; i < ir + dr; i++) { - for (int64_t j = 0; j < group_size_in_row; j++) { - calculate_group(i, j); - } - } - } - - private: - int64_t input_ne[4]; - size_t input_stride[4]; - - int64_t scale_ne[4]; - size_t scale_stride[4]; - - int64_t indices_ne[4]; - size_t indices_stride[4]; - - int64_t output_ne[4]; - size_t output_stride[4]; - - int64_t ir; - int64_t dr; - - int64_t group_size_in_row; - - TPipe pipe; - GlobalTensor input_gm; - GlobalTensor scale_gm; - GlobalTensor indices_gm; - GlobalTensor output_gm; - TQue input_queue; - TQue output_queue; - TQue cast_queue; -}; - -template -__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { - auto gm_ptr = (__gm__ uint8_t *)gm; - auto ub_ptr = (uint8_t *)(ub); - for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { - *ub_ptr = *gm_ptr; - } -} - -extern "C" __global__ __aicore__ void ascendc_get_row_q8_0( - GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, - GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm, - GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { - int64_t input_ne_ub[4]; - int64_t indices_ne_ub[4]; - size_t indices_nb_ub[4]; - int64_t output_ne_ub[4]; - size_t output_nb_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(indices_ne_gm, indices_ne_ub, 32); - copy_to_ub(indices_nb_gm, indices_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - copy_to_ub(output_nb_gm, output_nb_ub, 32); - - GET_ROW_Q8_0 op; - op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub, - indices_nb_ub, output_ne_ub, output_nb_ub); - op.calculate(); -} diff --git a/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp b/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp deleted file mode 100644 index 504b43afaa1f4..0000000000000 --- a/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +++ /dev/null @@ -1,218 +0,0 @@ -#include "kernel_operator.h" - -using namespace AscendC; -#ifdef ASCEND_310P - extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0( - GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { - // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. - printf("Ascend310P not support f16->8bit quantization.\n"); - } -#else - -#define BUFFER_NUM 2 -#define QK8_0 32 - -class QUANTIZE_F16_Q8_0 { - public: - __aicore__ inline QUANTIZE_F16_Q8_0() {} - __aicore__ inline void init(GM_ADDR input, GM_ADDR output, - int64_t *input_ne_ub, size_t *input_nb_ub, - int64_t *output_ne_ub) { - int64_t op_block_num = GetBlockNum(); - int64_t op_block_idx = GetBlockIdx(); - - for (int i = 0; i < 4; i++) { - input_ne[i] = input_ne_ub[i]; - input_stride[i] = input_nb_ub[i] / input_nb_ub[0]; - - output_ne[i] = output_ne_ub[i]; - } - - output_stride[0] = 1; - for (int i = 1; i < 4; i++) { - output_stride[i] = output_stride[i - 1] * output_ne[i - 1]; - } - - scale_ne = input_ne; - scale_stride[0] = 1; - scale_stride[1] = input_ne[0] / QK8_0; - for (int i = 2; i < 4; i++) { - scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; - } - - // split input tensor by rows. 
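The quantize kernels that follow compute a per-group scale d = max|x| / 127 and round each element to round(x / d). A scalar equivalent of that group step:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    const int QK8_0 = 32;
    float x[QK8_0];
    for (int i = 0; i < QK8_0; ++i) x[i] = 0.1f * (i - 16);

    float amax = 0.0f;                         // ReduceMax over |x|
    for (float v : x) amax = std::fmax(amax, std::fabs(v));
    float d = amax / ((1 << 7) - 1);           // same constant as the kernel

    int8_t q[QK8_0];
    for (int i = 0; i < QK8_0; ++i)
        q[i] = (int8_t) std::lround(d != 0.0f ? x[i] / d : 0.0f);

    std::printf("d=%g q[0]=%d\n", d, q[0]);    // q[0] is near -127
}
```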
- uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3]; - dr = nr / op_block_num; - - uint64_t tails = nr % op_block_num; - if (op_block_idx < tails) { - dr += 1; - ir = dr * op_block_idx; - } else { - ir = dr * op_block_idx + tails; - } - - group_size_in_row = scale_stride[1]; - int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] * - output_ne[3] * sizeof(uint8_t); - - input_gm.SetGlobalBuffer((__gm__ half *)input); - output_gm.SetGlobalBuffer((__gm__ int8_t *)output); - scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size + ir * - group_size_in_row * - sizeof(half))); - - pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(half)); - pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t)); - pipe.InitBuffer(work_queue, 1, 32); - pipe.InitBuffer(max_queue, 1, 32); - pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float)); - pipe.InitBuffer(scale_queue, 1, 32); - pipe.InitBuffer(cast_queue ,1 ,QK8_0 * sizeof(float)); - } - - __aicore__ inline void copy_in(uint32_t offset) { - LocalTensor input_local = input_queue.AllocTensor(); - DataCopy(input_local, input_gm[offset], QK8_0); - input_queue.EnQue(input_local); - } - - __aicore__ inline void copy_out(uint32_t offset) { - LocalTensor output_local = output_queue.DeQue(); - DataCopy(output_gm[offset], output_local, QK8_0); - output_queue.FreeTensor(output_local); - } - - __aicore__ inline half calculate_group(int64_t row, int64_t group) { - const int64_t i3 = row / (input_ne[1] * input_ne[2]); - const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1]; - const int64_t i1 = - row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1]; - - const int64_t input_offset = i1 * input_stride[1] + - i2 * input_stride[2] + - i3 * input_stride[3] + QK8_0 * group; - - const int64_t output_offset = i1 * output_stride[1] + - i2 * output_stride[2] + - i3 * output_stride[3] + QK8_0 * group; - - copy_in(input_offset); - LocalTensor input_local = input_queue.DeQue(); - LocalTensor output_local = output_queue.AllocTensor(); - LocalTensor work_local = work_queue.AllocTensor(); - LocalTensor abs_local = abs_queue.AllocTensor(); - LocalTensor max_local = max_queue.AllocTensor(); - LocalTensor cast_local = cast_queue.AllocTensor(); - - Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0); - Abs(abs_local, cast_local, QK8_0); - ReduceMax(max_local, abs_local, work_local, QK8_0); - - pipe_barrier(PIPE_ALL); - float d = max_local.GetValue(0); - d = d / ((1 << 7) - 1); - if (d != 0) { - Muls(cast_local, cast_local, 1.0f / d, QK8_0); - } - - Cast(cast_local, cast_local, RoundMode::CAST_ROUND, QK8_0); - Cast(input_local, cast_local, RoundMode::CAST_ROUND, QK8_0); - Cast(output_local, input_local, RoundMode::CAST_ROUND, QK8_0); - output_queue.EnQue(output_local); - copy_out(output_offset); - - input_queue.FreeTensor(input_local); - work_queue.FreeTensor(work_local); - abs_queue.FreeTensor(abs_local); - max_queue.FreeTensor(max_local); - cast_queue.FreeTensor(cast_local); - return (half)d; - } - - __aicore__ inline void calculate() { - LocalTensor scale_local = scale_queue.AllocTensor(); - uint32_t scale_local_offset = 0; - uint32_t scale_global_offset = 0; - for (int64_t i = ir; i < ir + dr; i++) { - for (int64_t j = 0; j < group_size_in_row; j++) { - half scale = calculate_group(i, j); - scale_local.SetValue(scale_local_offset++, scale); - if (scale_local_offset == 16) { - scale_local_offset = 0; - // TODO: OPTIMIZE ME - pipe_barrier(PIPE_ALL); - DataCopy(scale_gm[scale_global_offset], scale_local, 16); - 
pipe_barrier(PIPE_ALL); - scale_global_offset += 16; - } - } - } - - if (scale_local_offset != 0) { - pipe_barrier(PIPE_ALL); - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = scale_local_offset * sizeof(half); - DataCopyPad(scale_gm[scale_global_offset], scale_local, - dataCopyParams); - pipe_barrier(PIPE_ALL); - } - } - - private: - int64_t input_ne[4]; - size_t input_stride[4]; - - int64_t *scale_ne; - size_t scale_stride[4]; - - int64_t output_ne[4]; - size_t output_stride[4]; - - int64_t group_size_in_row; - - int64_t ir; - int64_t dr; - - TPipe pipe; - GlobalTensor input_gm; - GlobalTensor scale_gm; - GlobalTensor output_gm; - TQue input_queue; - TQue output_queue; - TQue work_queue; - TQue max_queue; - TQue abs_queue; - TQue scale_queue; - TQue cast_queue; - -}; - -template -__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { - auto gm_ptr = (__gm__ uint8_t *)gm; - auto ub_ptr = (uint8_t *)(ub); - for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { - *ub_ptr = *gm_ptr; - } -} - -extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0( - GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t output_ne_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - - QUANTIZE_F16_Q8_0 op; - op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub); - op.calculate(); -} - -#endif // #ifdef ASCEND_310P diff --git a/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp b/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp deleted file mode 100644 index 05b0bc1df59af..0000000000000 --- a/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +++ /dev/null @@ -1,216 +0,0 @@ -#include "kernel_operator.h" - -using namespace AscendC; -#ifdef ASCEND_310P // 310P not support f32->8bit quantization - extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0( - GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { - // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. - printf("Ascend310P not support f32->8bit quantization.\n"); - } -#else - -#define BUFFER_NUM 2 -#define QK8_0 32 - -class QUANTIZE_F32_Q8_0 { - public: - __aicore__ inline QUANTIZE_F32_Q8_0() {} - __aicore__ inline void init(GM_ADDR input, GM_ADDR output, - int64_t *input_ne_ub, size_t *input_nb_ub, - int64_t *output_ne_ub) { - int64_t op_block_num = GetBlockNum(); - int64_t op_block_idx = GetBlockIdx(); - - for (int i = 0; i < 4; i++) { - input_ne[i] = input_ne_ub[i]; - input_stride[i] = input_nb_ub[i] / input_nb_ub[0]; - - output_ne[i] = output_ne_ub[i]; - } - - output_stride[0] = 1; - for (int i = 1; i < 4; i++) { - output_stride[i] = output_stride[i - 1] * output_ne[i - 1]; - } - - scale_ne = input_ne; - scale_stride[0] = 1; - scale_stride[1] = input_ne[0] / QK8_0; - for (int i = 2; i < 4; i++) { - scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; - } - - // split input tensor by rows. 
- uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3]; - dr = nr / op_block_num; - - uint64_t tails = nr % op_block_num; - if (op_block_idx < tails) { - dr += 1; - ir = dr * op_block_idx; - } else { - ir = dr * op_block_idx + tails; - } - - group_size_in_row = scale_stride[1]; - int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] * - output_ne[3] * sizeof(uint8_t); - - input_gm.SetGlobalBuffer((__gm__ float *)input); - output_gm.SetGlobalBuffer((__gm__ int8_t *)output); - scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size + - ir * group_size_in_row * - sizeof(half))); - - pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(float)); - pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t)); - pipe.InitBuffer(work_queue, 1, 32); - pipe.InitBuffer(max_queue, 1, 32); - pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float)); - pipe.InitBuffer(cast_queue, 1, QK8_0 * sizeof(half)); - pipe.InitBuffer(scale_queue, 1, 32); - } - - __aicore__ inline void copy_in(uint32_t offset) { - LocalTensor input_local = input_queue.AllocTensor(); - DataCopy(input_local, input_gm[offset], QK8_0); - input_queue.EnQue(input_local); - } - - __aicore__ inline void copy_out(uint32_t offset) { - LocalTensor output_local = output_queue.DeQue(); - DataCopy(output_gm[offset], output_local, QK8_0); - output_queue.FreeTensor(output_local); - } - - __aicore__ inline half calculate_group(int64_t row, int64_t group) { - const int64_t i3 = row / (input_ne[1] * input_ne[2]); - const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1]; - const int64_t i1 = - row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1]; - - const int64_t input_offset = i1 * input_stride[1] + - i2 * input_stride[2] + - i3 * input_stride[3] + QK8_0 * group; - - const int64_t output_offset = i1 * output_stride[1] + - i2 * output_stride[2] + - i3 * output_stride[3] + QK8_0 * group; - - copy_in(input_offset); - LocalTensor input_local = input_queue.DeQue(); - LocalTensor output_local = output_queue.AllocTensor(); - LocalTensor work_local = work_queue.AllocTensor(); - LocalTensor abs_local = abs_queue.AllocTensor(); - LocalTensor max_local = max_queue.AllocTensor(); - LocalTensor cast_local = cast_queue.AllocTensor(); - - Abs(abs_local, input_local, QK8_0); - ReduceMax(max_local, abs_local, work_local, QK8_0); - pipe_barrier(PIPE_ALL); - float d = max_local.GetValue(0); - d = d / ((1 << 7) - 1); - if (d != 0) { - Muls(input_local, input_local, 1.0f / d, QK8_0); - } - - Cast(input_local, input_local, RoundMode::CAST_ROUND, QK8_0); - Cast(cast_local, input_local, RoundMode::CAST_ROUND, QK8_0); - Cast(output_local, cast_local, RoundMode::CAST_ROUND, QK8_0); - output_queue.EnQue(output_local); - copy_out(output_offset); - - input_queue.FreeTensor(input_local); - work_queue.FreeTensor(work_local); - abs_queue.FreeTensor(abs_local); - max_queue.FreeTensor(max_local); - cast_queue.FreeTensor(cast_local); - - return (half)d; - } - - __aicore__ inline void calculate() { - LocalTensor scale_local = scale_queue.AllocTensor(); - uint32_t scale_local_offset = 0; - uint32_t scale_global_offset = 0; - for (int64_t i = ir; i < ir + dr; i++) { - for (int64_t j = 0; j < group_size_in_row; j++) { - half scale = calculate_group(i, j); - scale_local.SetValue(scale_local_offset++, scale); - if (scale_local_offset == 16) { - scale_local_offset = 0; - // TODO: OPTIMIZE ME - pipe_barrier(PIPE_ALL); - DataCopy(scale_gm[scale_global_offset], scale_local, 16); - pipe_barrier(PIPE_ALL); - scale_global_offset += 16; - } - } - } - - if 
(scale_local_offset != 0) { - pipe_barrier(PIPE_ALL); - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = scale_local_offset * sizeof(half); - DataCopyPad(scale_gm[scale_global_offset], scale_local, - dataCopyParams); - pipe_barrier(PIPE_ALL); - } - } - - private: - int64_t input_ne[4]; - size_t input_stride[4]; - - int64_t *scale_ne; - size_t scale_stride[4]; - - int64_t output_ne[4]; - size_t output_stride[4]; - - int64_t group_size_in_row; - - int64_t ir; - int64_t dr; - - TPipe pipe; - GlobalTensor input_gm; - GlobalTensor scale_gm; - GlobalTensor output_gm; - TQue input_queue; - TQue output_queue; - TQue work_queue; - TQue max_queue; - TQue abs_queue; - TQue cast_queue; - TQue scale_queue; -}; - -template -__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { - auto gm_ptr = (__gm__ uint8_t *)gm; - auto ub_ptr = (uint8_t *)(ub); - for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { - *ub_ptr = *gm_ptr; - } -} - -extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0( - GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t output_ne_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - - QUANTIZE_F32_Q8_0 op; - op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub); - op.calculate(); -} - -#endif // #ifdef ASCEND_310P diff --git a/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp b/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp deleted file mode 100644 index 1188937b74461..0000000000000 --- a/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +++ /dev/null @@ -1,295 +0,0 @@ -#include "kernel_operator.h" - -using namespace AscendC; -#ifdef ASCEND_310P // 310P not support float->4bit quantization - extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0( - GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { - // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. - printf("Ascend310P not support f32->4bit quantization.\n"); - } - - extern "C" __global__ __aicore__ void ascendc_quantize_f16_to_q4_0( - GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { - // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. 
- printf("Ascend310P not support f16->4bit quantization.\n"); - } -#else - -#define BUFFER_NUM 2 -#define Group_Size 32 - -template -class QUANTIZE_FLOAT_TO_Q4_0 { - public: - __aicore__ inline QUANTIZE_FLOAT_TO_Q4_0() {} - __aicore__ inline void init(GM_ADDR input, GM_ADDR output, - int64_t *input_ne_ub, size_t *input_nb_ub, - int64_t *output_ne_ub) { - // TODO: fix test_case CPY(type_src=f16,type_dst=q4_0,ne=[256,4,4,4], - // permute=[0,0,0,0]): - // [CPY] NMSE = 0.000008343 > 0.000001000 FAIL - int64_t op_block_num = GetBlockNum(); - int64_t op_block_idx = GetBlockIdx(); - - // input stride of data elements - for (int i = 0; i < 4; i++) { - input_ne[i] = input_ne_ub[i]; - input_stride[i] = input_nb_ub[i] / input_nb_ub[0]; - output_ne[i] = output_ne_ub[i]; - } - - // output stride of data elements - output_stride[0] = 1; - for (int i = 1; i < 4; i++) { - output_stride[i] = output_stride[i - 1] * output_ne[i - 1]; - } - - // scale saved one by one after data:. [group1_scale, group2_scale, ...] - scale_ne = input_ne; - scale_stride[0] = 1; - scale_stride[1] = input_ne[0] / Group_Size; - for (int i = 2; i < 4; i++) { - scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; - } - - // split input tensor by rows. - uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3]; - dr = nr / op_block_num; - - uint64_t tails = nr % op_block_num; - if (op_block_idx < tails) { - dr += 1; - ir = dr * op_block_idx; - } else { - ir = dr * op_block_idx + tails; - } - - group_size_in_row = scale_stride[1]; - int64_t scale_offset = output_ne[0] * output_ne[1] * output_ne[2] * - output_ne[3] * sizeof(uint8_t) / 2; - - input_gm.SetGlobalBuffer((__gm__ SRC_T *)input); - output_gm.SetGlobalBuffer((__gm__ int8_t *)output); - scale_gm.SetGlobalBuffer((__gm__ half *)(output + scale_offset + ir * - group_size_in_row * - sizeof(half))); - - pipe.InitBuffer(input_queue, BUFFER_NUM, Group_Size * sizeof(SRC_T)); - pipe.InitBuffer(output_queue, BUFFER_NUM, - Group_Size * sizeof(int8_t) / 2); - pipe.InitBuffer(cast_queue , 1, Group_Size * sizeof(float)); - pipe.InitBuffer(work_queue, 1, Group_Size * sizeof(float)); - pipe.InitBuffer(max_queue, 1, Group_Size * sizeof(float)); - pipe.InitBuffer(min_queue, 1, Group_Size * sizeof(float)); - pipe.InitBuffer(scale_queue, 1, Group_Size / 2 * sizeof(half)); - pipe.InitBuffer(int8_queue, 1, Group_Size * sizeof(int8_t)); - pipe.InitBuffer(half_queue, 1, Group_Size * sizeof(half)); - } - - __aicore__ inline void copy_in(uint32_t offset) { - LocalTensor input_local = input_queue.AllocTensor(); - DataCopy(input_local, input_gm[offset], Group_Size); - input_queue.EnQue(input_local); - } - - __aicore__ inline void copy_out(uint32_t offset) { - // reinterpretcast Group_Size(32) * int4b_t to Group_Size / 2 * int8_t, - // and using DataCopyPad to avoid 32 bits align. 
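The reinterpret-cast described in the comment above works because 32 4-bit values occupy exactly 16 bytes. A host-side sketch of packing two quants per byte; this simple pair-wise nibble order is illustrative only, not necessarily the exact nibble layout ggml uses for q4_0:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const int QK4_0 = 32;
    int8_t q[QK4_0];           // values already clamped to [-8, 7]
    for (int i = 0; i < QK4_0; ++i) q[i] = (int8_t)((i % 16) - 8);

    uint8_t packed[QK4_0 / 2]; // two quants per byte, low nibble first
    for (int i = 0; i < QK4_0 / 2; ++i) {
        packed[i] = (uint8_t)((q[2 * i] & 0x0F) |
                              ((q[2 * i + 1] & 0x0F) << 4));
    }

    // unpack the low nibble and sign-extend it back to check the round trip
    int8_t lo = (int8_t)(((packed[0] & 0x0F) ^ 8) - 8);
    assert(lo == q[0]);
}
```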
- LocalTensor output_local = output_queue.DeQue(); - LocalTensor output_int8_local = - output_local.ReinterpretCast(); - - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = Group_Size / 2 * sizeof(int8_t); - DataCopyPad(output_gm[offset], output_int8_local, dataCopyParams); - - output_queue.FreeTensor(output_local); - } - - __aicore__ inline void input_to_cast(LocalTensor cast_local, - LocalTensor input_local) { - DataCopy(cast_local, input_local, Group_Size); - } - - __aicore__ inline void input_to_cast(LocalTensor cast_local, - LocalTensor input_local) { - Cast(cast_local, input_local, RoundMode::CAST_NONE, Group_Size); - } - - __aicore__ inline half calculate_group(int64_t row, int64_t group) { - const int64_t i3 = row / (input_ne[1] * input_ne[2]); - const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1]; - const int64_t i1 = - row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1]; - - const int64_t input_offset = i1 * input_stride[1] + - i2 * input_stride[2] + - i3 * input_stride[3] + Group_Size * group; - - // output_offset is stride for output_gm which datatype is int8_t and - // divided by 2 is needed for int4b_t. - const int64_t output_offset = (i1 * output_stride[1] + - i2 * output_stride[2] + - i3 * output_stride[3] + - Group_Size * group) / 2; - copy_in(input_offset); - - LocalTensor input_local = input_queue.DeQue(); - LocalTensor output_local = output_queue.AllocTensor(); - LocalTensor cast_local = cast_queue.AllocTensor(); - LocalTensor work_local = work_queue.AllocTensor(); - LocalTensor max_local = max_queue.AllocTensor(); - LocalTensor min_local = min_queue.AllocTensor(); - LocalTensor int8_local = int8_queue.AllocTensor(); - LocalTensor half_local = half_queue.AllocTensor(); - - input_to_cast(cast_local, input_local); - - ReduceMax(max_local, cast_local, work_local, Group_Size); - ReduceMin(min_local, cast_local, work_local, Group_Size); - const float max_value = max_local.GetValue(0); - const float min_value = min_local.GetValue(0); - float d = max_value; - if (min_value < 0 && (-1 * min_value) > max_value) { - d = min_value; - } - - d = d / (-8); - if (d != 0) { - Muls(cast_local, cast_local, 1.0f / d, Group_Size); - } - - // range: [-8,8] -> [0.5,16.5] -> [0,16] -> [0,15] -> [-8,7] - float scalar = 8.5f; - Adds(cast_local, cast_local, scalar, Group_Size); - Cast(cast_local, cast_local, RoundMode::CAST_FLOOR, Group_Size); - scalar = 15.0f; - Mins(cast_local, cast_local, scalar, Group_Size); - scalar = -8.0f; - Adds(cast_local, cast_local, scalar, Group_Size); - - // float->half->int4b - Cast(half_local, cast_local, RoundMode::CAST_NONE, Group_Size); - Cast(output_local, half_local, RoundMode::CAST_NONE, Group_Size); - - output_queue.EnQue(output_local); - copy_out(output_offset); - - input_queue.FreeTensor(input_local); - work_queue.FreeTensor(work_local); - max_queue.FreeTensor(max_local); - min_queue.FreeTensor(min_local); - int8_queue.FreeTensor(int8_local); - half_queue.FreeTensor(half_local); - cast_queue.FreeTensor(cast_local); - return (half)d; - } - - __aicore__ inline void calculate() { - LocalTensor scale_local = scale_queue.AllocTensor(); - uint32_t scale_local_offset = 0; - uint32_t scale_global_offset = 0; - for (int64_t i = ir; i < ir + dr; i++) { - for (int64_t j = 0; j < group_size_in_row; j++) { - half scale = calculate_group(i, j); - scale_local.SetValue(scale_local_offset++, scale); - // Copy Group_Size/2 length data each time. 
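The Adds/Cast/Mins/Adds chain above maps each value into the signed 4-bit range: divide by d = extreme / (-8), shift by 8.5, floor, clamp to [0, 15], then subtract 8, so the chosen extreme lands exactly on -8. A scalar walk-through:

```cpp
#include <cmath>
#include <cstdio>

// One element of the q4_0 range mapping used by the deleted kernel.
static int quantize_one(float x, float d) {
    float v = (d != 0.0f) ? x / d : 0.0f;  // v is in [-8, 8]
    v = std::floor(v + 8.5f);              // round and shift into [0, 16]
    if (v > 15.0f) v = 15.0f;              // the Mins(..., 15.0f) clamp
    return (int) v - 8;                    // back to the signed range
}

int main() {
    float min_v = -2.0f, max_v = 1.5f;
    float d = (-min_v > max_v ? min_v : max_v) / -8.0f;  // pick the extreme
    std::printf("%d %d %d\n",
                quantize_one(min_v, d),   // -8: the extreme maps to -8
                quantize_one(0.0f,  d),   //  0
                quantize_one(max_v, d));  //  6
}
```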
- if (scale_local_offset == Group_Size / 2) { - scale_local_offset = 0; - // TODO: OPTIMIZE ME - pipe_barrier(PIPE_ALL); - DataCopy(scale_gm[scale_global_offset], scale_local, - Group_Size / 2); - pipe_barrier(PIPE_ALL); - scale_global_offset += Group_Size / 2; - } - } - } - - if (scale_local_offset != 0) { - pipe_barrier(PIPE_ALL); - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = scale_local_offset * sizeof(half); - DataCopyPad(scale_gm[scale_global_offset], scale_local, - dataCopyParams); - pipe_barrier(PIPE_ALL); - } - scale_queue.FreeTensor(scale_local); - } - - private: - int64_t input_ne[4]; - size_t input_stride[4]; - - int64_t *scale_ne; - size_t scale_stride[4]; - - int64_t output_ne[4]; - size_t output_stride[4]; - - int64_t group_size_in_row; - - int64_t ir; - int64_t dr; - - TPipe pipe; - GlobalTensor input_gm; - GlobalTensor scale_gm; - GlobalTensor output_gm; - TQue input_queue; - TQue output_queue; - TQue work_queue; - TQue max_queue; - TQue min_queue; - TQue scale_queue; - TQue cast_queue; - TQue int8_queue; - TQue half_queue; -}; - -template -__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { - auto gm_ptr = (__gm__ uint8_t *)gm; - auto ub_ptr = (uint8_t *)(ub); - for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { - *ub_ptr = *gm_ptr; - } -} - -extern "C" __global__ __aicore__ void ascendc_quantize_f16_to_q4_0( - GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t output_ne_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - - QUANTIZE_FLOAT_TO_Q4_0 op; - op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub); - op.calculate(); -} - -extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0( - GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, - GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { - int64_t input_ne_ub[4]; - size_t input_nb_ub[4]; - int64_t output_ne_ub[4]; - - copy_to_ub(input_ne_gm, input_ne_ub, 32); - copy_to_ub(input_nb_gm, input_nb_ub, 32); - copy_to_ub(output_ne_gm, output_ne_ub, 32); - - QUANTIZE_FLOAT_TO_Q4_0 op; - op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub); - op.calculate(); -} - -#endif // #ifdef ASCEND_310P diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 6c02b69ea239a..086c822d73a89 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -158,6 +158,12 @@ typedef sycl::half2 ggml_half2; #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP +#ifdef _MSC_VER +#define GGML_EXTENSION +#else // _MSC_VER +#define GGML_EXTENSION __extension__ +#endif // _MSC_VER + #define QK4_0 32 typedef struct { ggml_half d; // delta @@ -167,7 +173,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 b #define QK4_1 32 typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // delta ggml_half m; // min @@ -188,7 +194,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 #define QK5_1 32 typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // delta ggml_half m; // min @@ -209,7 +215,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block #define QK8_1 32 typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // delta ggml_half s; // d * sum(qs[i]) @@ -250,7 +256,7 @@ 
static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants - union { + GGML_EXTENSION union { struct { ggml_half d; // super-block scale for quantized scales ggml_half dmin; // super-block scale for quantized mins @@ -277,7 +283,7 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12 // weight is represented as x = a * q + b // Effectively 4.5 bits per weight typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // super-block scale for quantized scales ggml_half dmin; // super-block scale for quantized mins @@ -294,7 +300,7 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, // weight is represented as x = a * q + b // Effectively 5.5 bits per weight typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // super-block scale for quantized scales ggml_half dmin; // super-block scale for quantized mins diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index cb71e9b396089..c8cc32fa5afcc 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -23,6 +23,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/amx/mmq.cpp ggml-cpu/amx/mmq.h ggml-cpu/ggml-cpu-impl.h + ggml-cpu/common.h + ggml-cpu/binary-ops.h + ggml-cpu/binary-ops.cpp + ggml-cpu/unary-ops.h + ggml-cpu/unary-ops.cpp ) target_compile_features(${GGML_CPU_NAME} PRIVATE c_std_11 cxx_std_17) @@ -289,23 +294,29 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ") message(STATUS "PowerPC detected") - if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") - file(READ "/proc/cpuinfo" POWER10_M) - elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc") - execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M) - endif() + if (GGML_NATIVE) + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") + file(READ "/proc/cpuinfo" POWER10_M) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc") + execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M) + endif() - string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}") - string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}") + string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}") + string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}") - if (EXTRACTED_NUMBER GREATER_EQUAL 10) - list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64) - elseif (EXTRACTED_NUMBER EQUAL 9) - list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64) - elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") - list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native) + if (EXTRACTED_NUMBER GREATER_EQUAL 10) + list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64) + elseif (EXTRACTED_NUMBER EQUAL 9) + list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") + list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native) + else() + list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64) + endif() else() - list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64) + if (GGML_CPU_POWERPC_CPUTYPE) + list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE}) + endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES 
"loongarch64") message(STATUS "loongarch64 detected") @@ -320,7 +331,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name) elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64") message(STATUS "RISC-V detected") if (GGML_RVV) - list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) + if (GGML_RV_ZFH) + list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -DGGML_RV_ZFH -mabi=lp64d) + else() + list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) + endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") message(STATUS "s390x detected") diff --git a/ggml/src/ggml-cpu/binary-ops.cpp b/ggml/src/ggml-cpu/binary-ops.cpp new file mode 100644 index 0000000000000..14f5b43ae0eb1 --- /dev/null +++ b/ggml/src/ggml-cpu/binary-ops.cpp @@ -0,0 +1,158 @@ +#include "binary-ops.h" + +#if defined(GGML_USE_ACCELERATE) +#include + +using vDSP_fn_t = void (*)(const float *, vDSP_Stride, const float *, vDSP_Stride, float *, vDSP_Stride, vDSP_Length); +#endif + +static inline float op_add(float a, float b) { + return a + b; +} + +static inline float op_sub(float a, float b) { + return a - b; +} + +static inline float op_mul(float a, float b) { + return a * b; +} + +static inline float op_div(float a, float b) { + return a / b; +} + +template +static inline void vec_binary_op_contiguous(const int64_t n, dst_t * z, const src0_t * x, const src1_t * y) { + constexpr auto src0_to_f32 = type_conversion_table::to_f32; + constexpr auto src1_to_f32 = type_conversion_table::to_f32; + constexpr auto f32_to_dst = type_conversion_table::from_f32; + + for (int i = 0; i < n; i++) { + z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(y[i]))); + } +} + +template +static inline void vec_binary_op_non_contiguous(const int64_t n, const int64_t ne10, const int64_t nb10, dst_t * z, const src0_t * x, const src1_t * y) { + constexpr auto src0_to_f32 = type_conversion_table::to_f32; + constexpr auto src1_to_f32 = type_conversion_table::to_f32; + constexpr auto f32_to_dst = type_conversion_table::from_f32; + + for (int i = 0; i < n; i++) { + int i10 = i % ne10; + const src1_t * y_ptr = (const src1_t *)((const char *)y + i10*nb10); + z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(*y_ptr))); + } +} + +template +static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(dst_t)); + GGML_ASSERT(nb00 == sizeof(src0_t)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + const bool is_src1_contiguous = (nb10 == sizeof(src1_t)); + + if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + } + +#ifdef GGML_USE_ACCELERATE + vDSP_fn_t vDSP_op = nullptr; + // TODO - avoid the f32-only check using type 'trait' lookup tables and row-based src-to-float conversion functions + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (op == op_add) { + vDSP_op = vDSP_vadd; + } else if (op == op_sub) { + vDSP_op = vDSP_vsub; + } else if (op == op_mul) { + vDSP_op = vDSP_vmul; + } else if (op == op_div) { + vDSP_op = vDSP_vdiv; + } + } +#endif + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t 
i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + if (is_src1_contiguous) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t nr0 = ne00 / ne10; + + for (int64_t r = 0; r < nr0; ++r) { +#ifdef GGML_USE_ACCELERATE + if constexpr (std::is_same_v && std::is_same_v && std::is_same_v) { + if (vDSP_op != nullptr) { + vDSP_op(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); + continue; + } + } +#endif + vec_binary_op_contiguous(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + } + } else { + vec_binary_op_non_contiguous(ne0, ne10, nb10, dst_ptr, src0_ptr, src1_ptr); + } + } +} + +// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates +template +static void binary_op(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + /* */ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_binary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_binary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_binary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) { + apply_binary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + apply_binary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + apply_binary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + apply_binary_op(params, dst); + } else { + GGML_ABORT("%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + } +} + +void ggml_compute_forward_add_non_quantized(const ggml_compute_params * params, ggml_tensor * dst) { + binary_op(params, dst); +} + +void ggml_compute_forward_sub(const ggml_compute_params * params, ggml_tensor * dst) { + binary_op(params, dst); +} + +void ggml_compute_forward_mul(const ggml_compute_params * params, ggml_tensor * dst) { + binary_op(params, dst); +} + +void ggml_compute_forward_div(const ggml_compute_params * params, ggml_tensor * dst) { + binary_op(params, dst); +} diff --git a/ggml/src/ggml-cpu/binary-ops.h b/ggml/src/ggml-cpu/binary-ops.h new file mode 100644 index 0000000000000..aca1d89be7e53 --- /dev/null +++ b/ggml/src/ggml-cpu/binary-ops.h @@ -0,0 +1,16 @@ +#pragma once + +#include "common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ggml_compute_forward_add_non_quantized(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_sub(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_mul(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_div(const struct 
ggml_compute_params * params, struct ggml_tensor * dst); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h new file mode 100644 index 0000000000000..3df01c1edffeb --- /dev/null +++ b/ggml/src/ggml-cpu/common.h @@ -0,0 +1,72 @@ +#pragma once + +#include "ggml.h" +#include "ggml-cpu-traits.h" +#include "ggml-cpu-impl.h" +#include "ggml-impl.h" + +#ifdef __cplusplus + +#include + +// convenience functions/macros for use in template calls +// note: these won't be required after the 'traits' lookup table is used. +static inline ggml_fp16_t f32_to_f16(float x) { + return GGML_FP32_TO_FP16(x); +} + +static inline float f16_to_f32(ggml_fp16_t x) { + return GGML_FP16_TO_FP32(x); +} + +static inline ggml_bf16_t f32_to_bf16(float x) { + return GGML_FP32_TO_BF16(x); +} + +static inline float bf16_to_f32(ggml_bf16_t x) { + return GGML_BF16_TO_FP32(x); +} + +static inline float f32_to_f32(float x) { + return x; +} + +// TODO - merge this into the traits table, after using row-based conversions +template +struct type_conversion_table; + +template <> +struct type_conversion_table { + static constexpr float (*to_f32)(ggml_fp16_t) = f16_to_f32; + static constexpr ggml_fp16_t (*from_f32)(float) = f32_to_f16; +}; + +template <> +struct type_conversion_table { + static constexpr float (*to_f32)(float) = f32_to_f32; + static constexpr float (*from_f32)(float) = f32_to_f32; +}; + +template <> +struct type_conversion_table { + static constexpr float (*to_f32)(ggml_bf16_t) = bf16_to_f32; + static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16; +}; + +static std::pair get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) { + const int64_t ith = params->ith; + const int64_t nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + return {ir0, ir1}; +} + +#endif diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index a079ef0249fb8..faacc1e4cbdc8 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -250,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +static void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -344,7 +344,7 @@ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRIC #endif } -static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -559,7 +559,7 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC #endif } -static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); const int nb = k / QK_K; @@ -811,7 
+811,7 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on for (int j = 0; j < QK_K * 4; j++) { int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; src_offset += (j % blck_size_interleave); int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3); @@ -823,26 +823,25 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC #endif } -static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { +template +void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row); + +template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); - if (blck_size_interleave == 4) { - quantize_q8_0_4x4(x, vy, n_per_row); - } else if (blck_size_interleave == 8) { - quantize_q8_0_4x8(x, vy, n_per_row); - } else { - assert(false); - } + ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row); } -static void quantize_mat_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); - if (blck_size_interleave == 8) { - quantize_q8_K_4x8(x, vy, n_per_row); - } else { - assert(false); - } + ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); +} + +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -5276,52 +5275,50 @@ template <> int repack(struct ggml_tensor * t, const void * //} // gemv -template +template void gemv(int, float *, size_t, const void *, const void *, int, int); -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> -void gemv(int n, float * s, size_t bs, 
const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } // gemm -template +template void gemm(int, float *, size_t, const void *, const void *, int, int); -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> -void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -5335,32 +5332,32 @@ template op) { - case GGML_OP_MUL_MAT: - size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); - return true; - case GGML_OP_MUL_MAT_ID: - size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); - size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. - size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2]; - return true; - default: - // GGML_ABORT("fatal error"); - break; + case GGML_OP_MUL_MAT: + size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); + return true; + case GGML_OP_MUL_MAT_ID: + size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); + size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. 
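A standalone sketch of the dispatch pattern introduced by `ggml_quantize_mat_t` and the reworked `gemv`/`gemm` templates above: the primary template is declared but never defined, and only the supported combinations get explicit specializations, so an unsupported combination fails at compile or link time instead of reaching a runtime `assert(false)`. The enum and names below are illustrative only:

```cpp
#include <cstdint>

enum class q_type { q8_0, q8_K };

// declared, intentionally never defined
template <int64_t INTER_SIZE, q_type TYPE>
void quantize_mat(const float * x, void * y, int64_t n_per_row);

template <> void quantize_mat<4, q_type::q8_0>(const float *, void *, int64_t) { /* 4x4 kernel */ }
template <> void quantize_mat<8, q_type::q8_0>(const float *, void *, int64_t) { /* 4x8 kernel */ }
template <> void quantize_mat<8, q_type::q8_K>(const float *, void *, int64_t) { /* K 4x8 kernel */ }

// quantize_mat<8, q_type::q8_0>(src, dst, 256);  // resolves to the 4x8 kernel at compile time
// quantize_mat<4, q_type::q8_K>(src, dst, 256);  // no such specialization -> link error
```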
+ size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2]; + return true; + default: + // GGML_ABORT("fatal error"); + break; } return false; } bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { switch (op->op) { - case GGML_OP_MUL_MAT: - forward_mul_mat(params, op); - return true; - case GGML_OP_MUL_MAT_ID: - forward_mul_mat_id(params, op); - return true; - default: - // GGML_ABORT("fatal error"); - break; + case GGML_OP_MUL_MAT: + forward_mul_mat(params, op); + return true; + case GGML_OP_MUL_MAT_ID: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error"); + break; } return false; } @@ -5399,17 +5396,10 @@ template from_float; int64_t i11_processed = 0; - if(PARAM_TYPE == GGML_TYPE_Q8_K) { - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - quantize_mat_q8_K((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10, - INTER_SIZE); - } - } else { - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10, - INTER_SIZE); - } + for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { + ggml_quantize_mat_t((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10); } + i11_processed = ne11 - ne11 % 4; for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10); @@ -5422,22 +5412,24 @@ template = src0_end) { return; } // If there are more than three rows in src1, use gemm; otherwise, use gemv. if (ne11 > 3) { - gemm(ne00, (float *) ((char *) dst->data) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + gemm(ne00, + (float *) ((char *) dst->data) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); } for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) { - gemv(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata + (src1_col_stride * iter), 1, - src0_end - src0_start); + gemv(ne00, + (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata + (src1_col_stride * iter), 1, + src0_end - src0_start); } } @@ -5452,7 +5444,7 @@ template ith; const int nth = params->nth; - const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float; + const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(src0->type)); @@ -5474,7 +5466,7 @@ template ne[0]; // n_expert_used const int n_as = ne02; // n_expert - const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10); + const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; @@ -5486,12 +5478,13 @@ template wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) + n_as * ne12 * sizeof(mmid_row_mapping))); - auto wdata = (char *) params->wdata; - auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t)); - int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + auto * wdata = (char *) params->wdata; 
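The pointers carved out of `params->wdata` here follow the layout checked by the `wsize` assert above: the quantized `src1` rows (`nbw3` bytes) padded to `int64_t`, then one row counter per expert, then the per-expert row mappings. A sketch of that sizing, with an illustrative stand-in for `mmid_row_mapping`:

```cpp
#include <cstddef>
#include <cstdint>

struct mmid_row_mapping_sketch { int32_t i1; int32_t i2; };  // illustrative stand-in

// nbw3 = bytes of quantized src1 data, n_as = number of experts, ne12 = src1 batch dim
static size_t mul_mat_id_wsize(size_t nbw3, int64_t n_as, int64_t ne12) {
    size_t size = (nbw3 + sizeof(int64_t) - 1) / sizeof(int64_t) * sizeof(int64_t); // GGML_PAD(nbw3, sizeof(int64_t))
    size += size_t(n_as) * sizeof(int64_t);                                         // matrix_row_counts[n_as]
    size += size_t(n_as) * size_t(ne12) * sizeof(mmid_row_mapping_sketch);          // matrix_rows[n_as][ne12]
    return size;
}
```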
+ auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t)); + auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] - // src1: float32 => block_q8_0 + // src1: float32 => param type for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = ith; i11 < ne11; i11 += nth) { from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11), @@ -5530,34 +5523,37 @@ template data + cur_a*nb02; + const auto * src0_cur = (const char *) src0->data + cur_a*nb02; //const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows int64_t src0_cur_start = (ith * ne01) / nth; int64_t src0_cur_end = ((ith + 1) * ne01) / nth; - src0_cur_start = - (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; - src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; - if (src0_cur_start >= src0_cur_end) return; + src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + if (src0_cur_start >= src0_cur_end) { + return; + } for (int ir1 = 0; ir1 < nr1; ir1++) { struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); - const int id = row_mapping.i1; // selected expert index - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row - auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2); + const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2); - gemv( - ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, - ne01, src0_cur + src0_cur_start * nb01, + gemv(ne00, + (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); } } @@ -5578,7 +5574,7 @@ static const tensor_traits q4_0_8x8_q8_0; static const tensor_traits q4_K_8x8_q8_K; // instance for IQ4 -static const tensor_traits iq4_nl_4x4_q8_0; +static const tensor_traits iq4_nl_4x4_q8_0; } // namespace ggml::cpu::aarch64 diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 4e0ae057244c9..91a81bdc3ccd0 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -891,15 +891,15 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i } #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e32m4(QK8_0); + size_t vl = QK8_0; for (int i = 0; i < nb; i++) { // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); + vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl); - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); 
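For context, a scalar sketch of the q8_0 block quantization that the widened (LMUL `m8`) RVV sequence above vectorizes, 32 floats per block: reduce the absolute maximum, derive the scale `d = amax / 127`, then store the rescaled values as `int8_t` (rounding-mode details aside):

```cpp
#include <cmath>
#include <cstdint>

// quantize one q8_0 block; returns the scale that is stored as fp16 in block_q8_0.d
static float quantize_block_q8_0_ref(const float x[32], int8_t qs[32]) {
    float amax = 0.0f;
    for (int j = 0; j < 32; ++j) amax = fmaxf(amax, fabsf(x[j]));

    const float d  = amax / 127.0f;                // (1 << 7) - 1
    const float id = d != 0.0f ? 1.0f / d : 0.0f;

    for (int j = 0; j < 32; ++j) qs[j] = (int8_t) roundf(x[j] * id);
    return d;
}
```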
const float d = amax / ((1 << 7) - 1); @@ -907,14 +907,14 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_FP32_TO_FP16(d); - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); + vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + __riscv_vse8_v_i8m2(y[i].qs , vs, vl); } #elif defined(__POWER9_VECTOR__) @@ -1229,15 +1229,15 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i } #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e32m4(QK8_1); + size_t vl = QK8_1; for (int i = 0; i < nb; i++) { // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); + vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl); - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); const float d = amax / ((1 << 7) - 1); @@ -1245,18 +1245,18 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_FP32_TO_FP16(d); - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); + vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + __riscv_vse8_v_i8m2(y[i].qs , vs, vl); // compute sum for y[i].s vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); - vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); + vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl); // set y[i].s int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); @@ -2391,33 +2391,31 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); + size_t vl = qk / 2; for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); // subtract offset - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); + vint8m1_t v0 = 
__riscv_vsub_vx_i8m1(x_ai, 8, vl); + vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); @@ -2783,29 +2781,27 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); + size_t vl = qk / 2; for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); @@ -3132,65 +3128,33 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc); #elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // These temporary registers are for masking and shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); - - vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); - vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + size_t vl; + size_t vlenb = __riscv_vlenb(); for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); - - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); - vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 
= __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + vl = qk / 2; + vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); + vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); + vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); + vint8m2_t v0c; + if (vlenb == 16) { + v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); + } else { + v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); + v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); + } + + vl = qk; + vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); + qh = __riscv_vmnand_mm_b4(qh, qh, vl); + vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); + vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); + vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); + vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); + int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); + + sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi; } #elif defined(__POWER9_VECTOR__) @@ -3503,60 +3467,30 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // temporary registers for shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + size_t vl; + size_t vlenb = __riscv_vlenb(); for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); - - // load qh - vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); - - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - - // ((qh >> (j + 12)) ) & 0x10; - vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = 
__riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + vl = qk / 2; + vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); + vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); + vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); + vint8m2_t v0c; + if (vlenb == 16) { + v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); + } else { + v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); + v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); + } + + vl = qk; + vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); + vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); + vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); + vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); + vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); + int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); } @@ -3970,17 +3904,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(accum); #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk); + size_t vl = qk; for (; ib < nb; ++ib) { // load elements - vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl); - vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl); + vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl); - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); + vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl); vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl); int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); @@ -5174,84 +5108,182 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #elif defined __riscv_v_intrinsic + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; + uint8_t atmp[16]; - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - size_t vl = 16; + const float dall = y[i].d * 
GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + size_t vl = 16; - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - vl = 32; + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + vl = 32; - uint8_t is=0; - int isum=0; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - for (int j = 0; j < QK_K/128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + uint8_t is = 0; + int isum = 0; - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); + for (int j = 0; j < QK_K / 128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = 
__riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); - q2+=32; q8+=128; is=8; + isum += __riscv_vmv_x_s_i32m1_i32(isum1); - } + q2 += 32; + q8 += 128; + is = 8; + } - sumf += dall * isum; + sumf += dall * isum; + } + break; + case 128: + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "vle8.v v1, (%[sc])\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vle16.v v2, (%[bsums])\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v0, (%[q2])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v2, v2, 
0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v8, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vle8.v v15, (%[scale])\n\t" + "vzext.vf4 v12, v15\n\t" + "vmul.vv v10, v10, v12\n\t" + "vredsum.vs v0, v10, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } + sumf += dall * isum; + } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; @@ -6116,97 +6148,221 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint32_t aux[3]; uint32_t utmp[4]; + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; - size_t vl = 32; - uint8_t m = 1; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + size_t vl = 32; + uint8_t m = 1; - int sum_t = 0; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - for (int j = 0; j < QK_K; j += 128) { + int sum_t = 0; - vl = 32; + for (int j = 0; j < QK_K; j += 128) { - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + vl = 32; - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); 
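The masked subtract in this branch reconstructs signed 3-bit values: each 2-bit quant is lowered by 4 unless its matching `hmask` bit is set. A scalar sketch of that step (helper name illustrative):

```cpp
#include <cstdint>

// q3_byte: packed 2-bit quants, shift: 0/2/4/6, m: current hmask bit
static inline int8_t q3_value(uint8_t q3_byte, int shift, uint8_t hmask_byte, uint8_t m) {
    const int8_t q = int8_t((q3_byte >> shift) & 3);
    return (hmask_byte & m) ? q : int8_t(q - 4);
}
```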
- vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; - vl = 16; + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - 
vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + vl = 16; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - q3 += 32; q8 += 128; scale += 8; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - } + q3 += 32; q8 += 128; scale += 8; - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + break; + case 128: + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + int8_t * scale = (int8_t *)utmp; + int tmp; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v0, (%[s6b])\n\t" + "vmv1r.v v2, v0\n\t" + "vsetivli zero, 2, e64, m1\n\t" + "vmv.v.x v9, %[sh]\n\t"\ + "vslidedown.vi v1, v0, 1\n\t" + "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "vsetivli zero, 4, e32, m1\n\t" + "vid.v v9\n\t" + "vmv.x.s %[tmp], v1\n\t" + "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "vsrl.vv v4, v1, v9\n\t" + "vsrl.vv v2, v0, v8\n\t" + "vand.vx v5, v4, %[kmask1]\n\t" + "vand.vx v3, v2, %[kmask2]\n\t" + "vsll.vi v6, v5, 4\n\t" + "vor.vv v7, v6, v3\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vsub.vx v0, v7, %[c]\n\t" + "vse8.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" 
(32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); - sumf += d*sum_t; + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" + "vle8.v v8, (%[q3])\n\t" + "vsrl.vi v10, v8, 2\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v8, 6\n\t" + "vand.vi v8, v8, 3\n\t" + "vand.vi v10, v10, 3\n\t" + "vand.vi v12, v12, 3\n\t" + "vle8.v v2, (%[qh])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v8, v8, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v10, v10, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v12, v12, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v14, v14, -4, v0.t\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t"\ + "vle8.v v15, (%[scale])\n\t" + "vsext.vf4 v12, v15\n\t" + "vmul.vv v10, v10, v12\n\t" + "vredsum.vs v0, v10, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; @@ -6924,69 +7080,181 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - for (int i = 0; i < nb; ++i) { + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { - size_t vl = 8; + size_t vl = 8; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = 
__riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - vl = 32; + vl = 32; - int32_t sum_1 = 0; - int32_t sum_2 = 0; + int32_t sum_1 = 0; + int32_t sum_2 = 0; - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + sum_2 += 
__riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; - q4 += 32; q8 += 64; + q4 += 32; q8 += 64; - } + } + + sumf += d*(sum_1 + sum_2); + + } + break; + case 128: + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int tmp, tmp2, sumi; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]} + "vsetivli zero, 4, e32, m1\n\t" + "vslidedown.vi v2, v1, 2\n\t" + "vmv1r.v v3, v2\n\t" + "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "vsetivli zero, 2, e32, m1\n\t" + "vmv.v.i v4, 4\n\t" + "vand.vx v8, v1, %[kmask1]\n\t" + "vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "vsrl.vi v6, v1, 6\n\t" + "vsrl.vv v7, v2, v5\n\t" + "vand.vx v0, v6, %[kmask3]\n\t" + "vand.vx v2, v7, %[kmask2]\n\t" + "vsll.vi v6, v0, 4\n\t" + "li %[t2], 8\n\t" + "addi %[t1], %[utmp], 4\n\t" + "vor.vv v1, v6, v2\n\t" + "vsse32.v v8, (%[utmp]), %[t2]\n\t" + "vsse32.v v1, (%[t1]), %[t2]\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vle32.v v2, (%[bsums])\n\t" + "vnsrl.wi v0, v2, 0\n\t" + "vnsrl.wi v1, v2, 16\n\t" + "vadd.vv v2, v0, v1\n\t" + "vle8.v v3, (%[mins])\n\t" + "vzext.vf2 v4, v3\n\t" + "vwmul.vv v6, v4, v2\n\t" + "vmv.v.x v0, zero\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vredsum.vs v0, v6, v0\n\t" + "vmv.x.s %[sumi], v0" + : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi) + : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf -= dmin * sumi; - sumf += d*(sum_1 + sum_2); + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + sumi = 0; + const uint8_t * scale = scales; + + for (int j = 0; j < QK_K/128; ++j) { + int vl128 = 128, vl64 = 64, vl32 = 32; + __asm__ __volatile__( + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v8, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vle8.v v0, (%[q4])\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vle8.v v2, (%[scale])\n\t" + "vmv.v.x v0, zero\n\t" + "vzext.vf4 v1, v2\n\t" + "vsetvli zero, %[vl32], e16, m4\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v6, v7, 1\n\t" + "vslideup.vi v4, v5, 1\n\t" + "vslideup.vi v4, v6, 2\n\t" + "vmul.vv v8, v4, v1\n\t" + "vredsum.vs v0, v8, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[sumi], %[sumi], %[tmp]" + : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi) + : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32) + , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + q4 += 64; q8 += 128; scale += 4; + } + + sumf += d * sumi; + } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; @@ -7722,9 
+7990,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl); + vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl); + vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl); memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); @@ -7733,11 +8001,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi utmp[2] = uaux; utmp[0] &= kmask1; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl); + vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl); - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); vl = 32; @@ -7746,43 +8014,42 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint8_t m = 1; vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl); for (int j = 0; j < QK_K/64; ++j) { // load Q5 and Q8 - vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); - vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl); + vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl); + vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl); // compute mask for addition - vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl); + vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl)); + vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl); + vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl); m <<= 1; - vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl); + vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl)); + vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl); + vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl); m <<= 1; - vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); - vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl); + vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, 
vl); - vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); - vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl); + vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl); - vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); - vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl); - aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); q5 += 32; q8 += 64; } - vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); - sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + sums += aux32 * d; } @@ -8667,85 +8934,168 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #elif defined __riscv_v_intrinsic + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const int8_t * GGML_RESTRICT scale = x[i].scales; + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - size_t vl; + const int8_t * GGML_RESTRICT scale = x[i].scales; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + size_t vl; - int sum_t = 0; - int is = 0; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - for (int j = 0; j < QK_K/128; ++j) { + int sum_t = 0; + int is = 0; - vl = 32; + for (int j = 0; j < QK_K/128; ++j) { - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + vl = 32; - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = 
__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - vl = 16; + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + vl = 16; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + 
vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - q6 += 64; qh += 32; q8 += 128; is=8; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - } + q6 += 64; qh += 32; q8 += 128; is=8; + + } - sumf += d * sum_t; + sumf += d * sum_t; + + } + break; + case 128: + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int sum_t = 0; + int t0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v4, (%[qh])\n\t" + "vsll.vi v0, v4, 4\n\t" + "vsll.vi v2, v4, 2\n\t" + "vsrl.vi v6, v4, 2\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vle8.v v8, (%[q6])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vand.vx v0, v0, %[mask]\n\t" + "vor.vv v8, v8, v0\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsub.vx v8, v8, %[vl32]\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vle8.v v2, (%[scale])\n\t" + "vsext.vf4 v4, v2\n\t" + "vmul.vv v2, v4, v10\n\t" + "vredsum.vs v0, v2, v0\n\t" + "vmv.x.s %[t0], v0\n\t" + "add %[sumi], %[sumi], %[t0]" + : [sumi] "+&r" (sum_t), [t0] "=&r" (t0) + : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q6 += 64; qh += 32; q8 += 128; scale += 8; + } + sumf += d * sum_t; + + } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c 
b/ggml/src/ggml-cpu/ggml-cpu.c index 2dbe835586d4c..f3925e17a2b8e 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -9,6 +9,8 @@ #include "ggml-impl.h" #include "ggml-cpu-quants.h" #include "ggml-threading.h" +#include "ggml-cpu/unary-ops.h" +#include "ggml-cpu/binary-ops.h" #include "ggml.h" #if defined(_MSC_VER) || defined(__MINGW32__) @@ -901,24 +903,24 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { #define GGML_F16x4_FMA GGML_F32x4_FMA #define GGML_F16x4_ADD wasm_f32x4_add #define GGML_F16x4_MUL wasm_f32x4_mul -#define GGML_F16x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F16_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ +#define GGML_F16x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = (ggml_float) (wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3)); \ } #define GGML_F16_VEC GGML_F16x4 @@ -4289,340 +4291,6 @@ static void ggml_compute_forward_dup( // ggml_compute_forward_add -static void ggml_compute_forward_add_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0; r < nr0; ++r) { -#ifdef GGML_USE_ACCELERATE - vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); -#else - ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = 
ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; - } - } - } -} - -static void ggml_compute_forward_add_f16_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - if (dst->type == GGML_TYPE_F32) { - GGML_ASSERT( nb0 == sizeof(float)); - } - else { - GGML_ASSERT(dst->type == GGML_TYPE_F16); - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - } - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - if (dst->type == GGML_TYPE_F16) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); - } - } - } else { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; - } - } - } - } - else { - // src1 is not contiguous - GGML_ABORT("fatal error"); - } -} - -static void ggml_compute_forward_add_bf16_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_BF16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - if (dst->type == GGML_TYPE_F32) { - GGML_ASSERT( nb0 == sizeof(float)); - } - else { - GGML_ASSERT(dst->type == GGML_TYPE_BF16); 
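(Annotation) The per-type elementwise implementations being deleted through this stretch of the diff (`ggml_compute_forward_add_f32/_f16_f32/_bf16_f32/_f16_f16/_bf16_bf16`, and further down the analogous `sub`/`mul`/`div` and `sqr`/`sqrt` variants) all shared the same row loop and differed only in how elements are loaded, stored, and combined. They are replaced by shared implementations pulled in via the new `#include "ggml-cpu/unary-ops.h"` and `#include "ggml-cpu/binary-ops.h"` lines at the top of this file's hunk, so the F32/F16/BF16 cases of the dispatch collapse into a single `ggml_compute_forward_add_non_quantized()` call. The actual contents of `binary-ops.h` are not shown in this diff; the sketch below only illustrates the consolidation idea under that assumption, with all names hypothetical.

```c
// Sketch of the consolidation idea only -- the real implementation lives in
// ggml-cpu/binary-ops.h, which this diff includes but does not show. The row
// loop is written once and parameterized over the element load/store
// conversions and the scalar operation.
#include <stddef.h>

typedef float (*load_elem_fn)(const void * row, size_t i);   // e.g. f32/f16/bf16 -> f32
typedef void  (*store_elem_fn)(void * row, size_t i, float v);
typedef float (*scalar_op_fn)(float a, float b);              // add, sub, mul, div

static void binary_op_row_sketch(size_t n, void * dst,
                                 const void * src0, const void * src1,
                                 load_elem_fn load0, load_elem_fn load1,
                                 store_elem_fn store, scalar_op_fn op) {
    for (size_t i = 0; i < n; ++i) {
        store(dst, i, op(load0(src0, i), load1(src1, i)));
    }
}

// One concrete instantiation: f32 converters plus the add op.
static float load_f32 (const void * row, size_t i)    { return ((const float *) row)[i]; }
static void  store_f32(void * row, size_t i, float v) { ((float *) row)[i] = v; }
static float op_add   (float a, float b)              { return a + b; }

// usage (hypothetical): replaces the hand-written f32 add row body deleted above
// binary_op_row_sketch(ne10, dst_ptr, src0_ptr, src1_ptr,
//                      load_f32, load_f32, store_f32, op_add);
```

In this shape, adding a new dtype only requires a pair of converters rather than a full copy of every binary-op function, which is what allows the large deletions in this file.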
- GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); - } - - GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - if (dst->type == GGML_TYPE_BF16) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); - } - } - } else { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; - } - } - } - } - else { - // src1 is not contiguous - GGML_ABORT("fatal error"); - } -} - -static void ggml_compute_forward_add_f16_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(ggml_fp16_t)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0; r < nr0; ++r) { - ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); - } - } - } - else { - // src1 is not contiguous - GGML_ABORT("fatal error"); - } -} - -static void ggml_compute_forward_add_bf16_bf16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - 
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_BF16); - GGML_ASSERT(src1->type == GGML_TYPE_BF16); - GGML_ASSERT(dst->type == GGML_TYPE_BF16); - - GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(ggml_bf16_t)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i])); - } - } - } - else { - // src1 is not contiguous - GGML_ABORT("fatal error"); - } -} - static void ggml_compute_forward_add_q_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -4704,41 +4372,13 @@ static void ggml_compute_forward_add( struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; switch (src0->type) { case GGML_TYPE_F32: - { - if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add_f32(params, dst); - } - else { - GGML_ABORT("fatal error"); - } - } break; case GGML_TYPE_F16: - { - if (src1->type == GGML_TYPE_F16) { - ggml_compute_forward_add_f16_f16(params, dst); - } - else if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add_f16_f32(params, dst); - } - else { - GGML_ABORT("fatal error"); - } - } break; case GGML_TYPE_BF16: { - if (src1->type == GGML_TYPE_BF16) { - ggml_compute_forward_add_bf16_bf16(params, dst); - } - else if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add_bf16_f32(params, dst); - } - else { - GGML_ABORT("fatal error"); - } + ggml_compute_forward_add_non_quantized(params, dst); } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -5272,140 +4912,107 @@ static void ggml_compute_forward_acc( } } -// ggml_compute_forward_sub +// ggml_compute_forward_sum -static void ggml_compute_forward_sub_f32( +static void ggml_compute_forward_sum_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + if (params->ith != 0) { + return; + } - const int ith = params->ith; - const int nth = params->nth; + assert(ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(float)); - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - // rows per thread - const int dr = (nr + nth - 1)/nth; + ggml_float sum = 0; + ggml_float row_sum = 0; - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); + for (int64_t i03 = 0; i03 < 
ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32_ggf(ne00, + &row_sum, + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + sum += row_sum; + } + } + } + ((float *) dst->data)[0] = sum; +} - if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); +static void ggml_compute_forward_sum_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; + const struct ggml_tensor * src0 = dst->src[0]; - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + if (params->ith != 0) { + return; + } - for (int64_t r = 0; r < nr0; ++r) { -#ifdef GGML_USE_ACCELERATE - vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); -#else - ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + assert(ggml_is_scalar(dst)); - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + assert(src0->nb[0] == sizeof(ggml_fp16_t)); - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + float sum = 0; + float row_sum = 0; - dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f16_ggf(ne00, + &row_sum, + (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; } } } + ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); } -static void ggml_compute_forward_sub_f16( +static void ggml_compute_forward_sum_bf16( const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - // rows per thread - const int dr = (nr + nth - 1)/nth; + if (params->ith != 0) { + return; + } - // row range for this thread - 
const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); + assert(ggml_is_scalar(dst)); - if (nb10 == sizeof(ggml_fp16_t)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + assert(src0->nb[0] == sizeof(ggml_bf16_t)); - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + float sum = 0; + float row_sum = 0; - for (int64_t r = 0; r < nr0; ++r) { - ggml_vec_sub_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_bf16_ggf(ne00, + &row_sum, + (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; } } - } else { - // src1 is not contiguous - GGML_ABORT("unimplemented error"); } + ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum); } -static void ggml_compute_forward_sub( +static void ggml_compute_forward_sum( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5414,11 +5021,15 @@ static void ggml_compute_forward_sub( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_sub_f32(params, dst); + ggml_compute_forward_sum_f32(params, dst); } break; case GGML_TYPE_F16: { - ggml_compute_forward_sub_f16(params, dst); + ggml_compute_forward_sum_f16(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_sum_bf16(params, dst); } break; default: { @@ -5427,145 +5038,51 @@ static void ggml_compute_forward_sub( } } -// ggml_compute_forward_mul +// ggml_compute_forward_sum_rows -static void ggml_compute_forward_mul_f32( +static void ggml_compute_forward_sum_rows_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0 ; r < nr0; ++r) { -#ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_mul_f32); - - vDSP_vmul(src0_ptr + 
r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); -#else - ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne00; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); - } - } + if (params->ith != 0) { + return; } -} - -static void ggml_compute_forward_mul_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - if (nb10 == sizeof(ggml_fp16_t)) { - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(dst->nb[0] == sizeof(float)); - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; + GGML_TENSOR_UNARY_OP_LOCALS - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); - for (int64_t r = 0 ; r < nr0; ++r) { - ggml_vec_mul_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + float row_sum = 0; + ggml_vec_sum_f32(ne00, &row_sum, src_row); + dst_row[0] = row_sum; } } - } else { - // src1 is not contiguous - GGML_ABORT("unimplemented error"); } } -static void ggml_compute_forward_mul( +static void ggml_compute_forward_sum_rows( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT((src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) && "only f32/f16 src1 
supported for now"); switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_mul_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_mul_f16(params, dst); + ggml_compute_forward_sum_rows_f32(params, dst); } break; default: { @@ -5574,129 +5091,46 @@ static void ggml_compute_forward_mul( } } -// ggml_compute_forward_div +// ggml_compute_forward_mean -static void ggml_compute_forward_div_f32( +static void ggml_compute_forward_mean_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0; r < nr0; ++r) { -#ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_div_f32); - - vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); -#else - ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne00; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); - } - } + if (params->ith != 0) { + return; } -} - -static void ggml_compute_forward_div_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - const int64_t nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F16); + assert(src0->nb[0] == sizeof(float)); - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_TENSOR_UNARY_OP_LOCALS - if (nb10 == sizeof(ggml_fp16_t)) { - for (int64_t ir = ith; ir < 
nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + assert(ne0 == 1); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32(ne00, + (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - for (int64_t r = 0; r < nr0; ++r) { - ggml_vec_div_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; } } - } else { - // src1 is not contiguous - GGML_ABORT("unimplemented error"); } } -static void ggml_compute_forward_div( +static void ggml_compute_forward_mean( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5705,11 +5139,7 @@ static void ggml_compute_forward_div( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_div_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_div_f16(params, dst); + ggml_compute_forward_mean_f32(params, dst); } break; default: { @@ -5718,9 +5148,9 @@ static void ggml_compute_forward_div( } } -// ggml_compute_forward_sqr +// ggml_compute_forward_argmax -static void ggml_compute_forward_sqr_f32( +static void ggml_compute_forward_argmax_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5730,47 +5160,25 @@ static void ggml_compute_forward_sqr_f32( return; } - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); assert(src0->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_vec_sqr_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sqr_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; - assert( dst->nb[0] == sizeof(ggml_fp16_t)); - assert(src0->nb[0] == sizeof(ggml_fp16_t)); + const size_t nb01 = src0->nb[1]; + const size_t nb0 = dst->nb[0]; - for (int i = 0; i < n; i++) { - ggml_vec_sqr_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src = (float *) ((char *) src0->data + i1*nb01); + int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); + int v = 0; + ggml_vec_argmax_f32(ne00, &v, src); + dst_[0] = v; } } -static void 
ggml_compute_forward_sqr( +static void ggml_compute_forward_argmax( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5779,11 +5187,7 @@ static void ggml_compute_forward_sqr( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_sqr_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_sqr_f16(params, dst); + ggml_compute_forward_argmax_f32(params, dst); } break; default: { @@ -5792,72 +5196,78 @@ static void ggml_compute_forward_sqr( } } -// ggml_compute_forward_sqrt +// ggml_compute_forward_count_equal -static void ggml_compute_forward_sqrt_f32( +static void ggml_compute_forward_count_equal_i32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; - if (params->ith != 0) { - return; - } + GGML_TENSOR_BINARY_OP_LOCALS; - assert(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->type == GGML_TYPE_I32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_scalar(dst)); + GGML_ASSERT(dst->type == GGML_TYPE_I64); - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; + const int64_t nr = ggml_nrows(src0); - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); + const int ith = params->ith; + const int nth = params->nth; - for (int i = 0; i < n; i++) { - ggml_vec_sqrt_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} + int64_t * sums = (int64_t *) params->wdata; + int64_t sum_thread = 0; -static void ggml_compute_forward_sqrt_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; - const struct ggml_tensor * src0 = dst->src[0]; + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); - if (params->ith != 0) { - return; - } + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02*ne01); + const int64_t i02 = (ir - i03*ne03) / ne01; + const int64_t i01 = ir - i03*ne03 - i02*ne02; - assert(ggml_are_same_shape(src0, dst)); + const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01; + const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11; - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; + for (int64_t i00 = 0; i00 < ne00; ++i00) { + const int32_t val0 = *((const int32_t *) (data0 + i00*nb00)); + const int32_t val1 = *((const int32_t *) (data1 + i00*nb10)); - assert( dst->nb[0] == sizeof(ggml_fp16_t)); - assert(src0->nb[0] == sizeof(ggml_fp16_t)); + sum_thread += val0 == val1; + } + } + if (ith != 0) { + sums[ith] = sum_thread; + } + ggml_barrier(params->threadpool); - for (int i = 0; i < n; i++) { - ggml_vec_sqrt_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + if (ith != 0) { + return; + } + + for (int ith_other = 1; ith_other < nth; ++ith_other) { + sum_thread += sums[ith_other]; } + *((int64_t *) dst->data) = sum_thread; } -static void ggml_compute_forward_sqrt( +static void ggml_compute_forward_count_equal( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sqrt_f32(params, dst); - } break; - case GGML_TYPE_F16: + case GGML_TYPE_I32: 
{ - ggml_compute_forward_sqrt_f16(params, dst); + ggml_compute_forward_count_equal_i32(params, dst); } break; default: { @@ -5866,9 +5276,9 @@ static void ggml_compute_forward_sqrt( } } -// ggml_compute_forward_log +// ggml_compute_forward_repeat -static void ggml_compute_forward_log_f32( +static void ggml_compute_forward_repeat_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5878,24 +5288,43 @@ static void ggml_compute_forward_log_f32( return; } - GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat(src0, dst)); - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; + GGML_TENSOR_UNARY_OP_LOCALS - GGML_ASSERT( dst->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); - for (int i = 0; i < n; i++) { - ggml_vec_log_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_vec_cpy_f32(ne00, + (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), + (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); + } + } + } + } + } + } } } -static void ggml_compute_forward_log_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { +static void ggml_compute_forward_repeat_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; @@ -5903,35 +5332,60 @@ static void ggml_compute_forward_log_f16( return; } - GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat(src0, dst)); - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; + GGML_TENSOR_UNARY_OP_LOCALS - GGML_ASSERT( dst->nb[0] == sizeof(ggml_fp16_t)); - GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); - for (int i = 0; i < n; i++) { - ggml_vec_log_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // TODO: maybe this is not optimal? 
+ for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); + ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); + // ggml_vec_cpy_f16(ne00, y, x) + for (int i = 0; i < ne00; ++i) { + y[i] = x[i]; + } + } + } + } + } + } + } } } -static void ggml_compute_forward_log( +static void ggml_compute_forward_repeat( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; switch (src0->type) { - case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_I16: { - ggml_compute_forward_log_f32(params, dst); + ggml_compute_forward_repeat_f16(params, dst); } break; - case GGML_TYPE_F16: + case GGML_TYPE_F32: + case GGML_TYPE_I32: { - ggml_compute_forward_log_f16(params, dst); + ggml_compute_forward_repeat_f32(params, dst); } break; default: { @@ -5940,9 +5394,9 @@ static void ggml_compute_forward_log( } } -// ggml_compute_forward_sin +// ggml_compute_forward_repeat_back -static void ggml_compute_forward_sin_f32( +static void ggml_compute_forward_repeat_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5952,1172 +5406,55 @@ static void ggml_compute_forward_sin_f32( return; } - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - GGML_ASSERT( dst->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(ggml_can_repeat(dst, src0)); - for (int i = 0; i < n; i++) { - ggml_vec_sin_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} + GGML_TENSOR_UNARY_OP_LOCALS -static void ggml_compute_forward_sin_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne00/ne0); + const int nr1 = (int)(ne01/ne1); + const int nr2 = (int)(ne02/ne2); + const int nr3 = (int)(ne03/ne3); - const struct ggml_tensor * src0 = dst->src[0]; + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); - if (params->ith != 0) { - return; + if (ggml_is_contiguous(dst)) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } else { + for (int k3 = 0; k3 < ne3; k3++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int k1 = 0; k1 < ne1; k1++) { + ggml_vec_set_f32(ne0, + (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), + 0); + } + } + } } - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - GGML_ASSERT( dst->nb[0] == sizeof(ggml_fp16_t)); - GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); - - for (int i = 0; i < n; i++) { - ggml_vec_sin_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sin( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sin_f32(params, dst); - } break; - 
case GGML_TYPE_F16: - { - ggml_compute_forward_sin_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_cos - -static void ggml_compute_forward_cos_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - GGML_ASSERT( dst->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_cos_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_cos_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - GGML_ASSERT( dst->nb[0] == sizeof(ggml_fp16_t)); - GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); - - for (int i = 0; i < n; i++) { - ggml_vec_cos_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_cos( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_cos_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_cos_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sum - -static void ggml_compute_forward_sum_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_scalar(dst)); - assert(src0->nb[0] == sizeof(float)); - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - ggml_float sum = 0; - ggml_float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32_ggf(ne00, - &row_sum, - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - sum += row_sum; - } - } - } - ((float *) dst->data)[0] = sum; -} - -static void ggml_compute_forward_sum_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_scalar(dst)); - - assert(src0->nb[0] == sizeof(ggml_fp16_t)); - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - float sum = 0; - float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f16_ggf(ne00, - &row_sum, - (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); - sum += row_sum; - } - } - } - ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); -} - -static void ggml_compute_forward_sum_bf16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - 
assert(ggml_is_scalar(dst)); - - assert(src0->nb[0] == sizeof(ggml_bf16_t)); - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - float sum = 0; - float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_bf16_ggf(ne00, - &row_sum, - (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); - sum += row_sum; - } - } - } - ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum); -} - -static void ggml_compute_forward_sum( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sum_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_sum_f16(params, dst); - } break; - case GGML_TYPE_BF16: - { - ggml_compute_forward_sum_bf16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sum_rows - -static void ggml_compute_forward_sum_rows_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT(dst->nb[0] == sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(ne0 == 1); - GGML_ASSERT(ne1 == ne01); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne3 == ne03); - - for (int64_t i3 = 0; i3 < ne03; i3++) { - for (int64_t i2 = 0; i2 < ne02; i2++) { - for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); - float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); - float row_sum = 0; - ggml_vec_sum_f32(ne00, &row_sum, src_row); - dst_row[0] = row_sum; - } - } - } -} - -static void ggml_compute_forward_sum_rows( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sum_rows_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_mean - -static void ggml_compute_forward_mean_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - assert(ne0 == 1); - assert(ne1 == ne01); - assert(ne2 == ne02); - assert(ne3 == ne03); - - UNUSED(ne0); - UNUSED(ne1); - UNUSED(ne2); - UNUSED(ne3); - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32(ne00, - (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - - *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; - } - } - } -} - -static void ggml_compute_forward_mean( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_mean_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_argmax - -static void 
ggml_compute_forward_argmax_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - assert(dst->nb[0] == sizeof(float)); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - - const size_t nb01 = src0->nb[1]; - const size_t nb0 = dst->nb[0]; - - for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src = (float *) ((char *) src0->data + i1*nb01); - int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); - int v = 0; - ggml_vec_argmax_f32(ne00, &v, src); - dst_[0] = v; - } -} - -static void ggml_compute_forward_argmax( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_argmax_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_count_equal - -static void ggml_compute_forward_count_equal_i32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS; - - GGML_ASSERT(src0->type == GGML_TYPE_I32); - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - GGML_ASSERT(ggml_is_scalar(dst)); - GGML_ASSERT(dst->type == GGML_TYPE_I64); - - const int64_t nr = ggml_nrows(src0); - - const int ith = params->ith; - const int nth = params->nth; - - int64_t * sums = (int64_t *) params->wdata; - int64_t sum_thread = 0; - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - for (int64_t ir = ir0; ir < ir1; ++ir) { - const int64_t i03 = ir / (ne02*ne01); - const int64_t i02 = (ir - i03*ne03) / ne01; - const int64_t i01 = ir - i03*ne03 - i02*ne02; - - const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01; - const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11; - - for (int64_t i00 = 0; i00 < ne00; ++i00) { - const int32_t val0 = *((const int32_t *) (data0 + i00*nb00)); - const int32_t val1 = *((const int32_t *) (data1 + i00*nb10)); - - sum_thread += val0 == val1; - } - } - if (ith != 0) { - sums[ith] = sum_thread; - } - ggml_barrier(params->threadpool); - - if (ith != 0) { - return; - } - - for (int ith_other = 1; ith_other < nth; ++ith_other) { - sum_thread += sums[ith_other]; - } - *((int64_t *) dst->data) = sum_thread; -} - -static void ggml_compute_forward_count_equal( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_I32: - { - ggml_compute_forward_count_equal_i32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_repeat - -static void ggml_compute_forward_repeat_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_can_repeat(src0, dst)); - - GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in ggml_can_repeat - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = 
(int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - ggml_vec_cpy_f32(ne00, - (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), - (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_repeat_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_can_repeat(src0, dst)); - - GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in ggml_can_repeat - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); - ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); - // ggml_vec_cpy_f16(ne00, y, x) - for (int i = 0; i < ne00; ++i) { - y[i] = x[i]; - } - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_repeat( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_I16: - { - ggml_compute_forward_repeat_f16(params, dst); - } break; - case GGML_TYPE_F32: - case GGML_TYPE_I32: - { - ggml_compute_forward_repeat_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_repeat_back - -static void ggml_compute_forward_repeat_back_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_can_repeat(dst, src0)); - - GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in ggml_can_repeat - const int nr0 = (int)(ne00/ne0); - const int nr1 = (int)(ne01/ne1); - const int nr2 = (int)(ne02/ne2); - const int nr3 = (int)(ne03/ne3); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (ggml_is_contiguous(dst)) { - ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - } else { - for (int k3 = 0; k3 < ne3; k3++) { - for (int k2 = 0; k2 < ne2; k2++) { - for (int k1 = 0; k1 < ne1; k1++) { - ggml_vec_set_f32(ne0, - (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), - 0); - } - } - } - } - - // TODO: maybe this is not optimal? 
- for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne3; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne2; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne1; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - ggml_vec_acc_f32(ne0, - (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), - (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_repeat_back( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_repeat_back_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_concat - -static void ggml_compute_forward_concat_any( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - const size_t len = ggml_type_size(src0->type); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int32_t dim = ggml_get_op_params_i32(dst, 0); - - GGML_ASSERT(dim >= 0 && dim < 4); - - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = src0->ne[dim]; - - const char * x; - - // TODO: smarter multi-theading - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ith; i2 < ne2; i2 += nth) { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03; - } else { - x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13; - } - - char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3; - - memcpy(y, x, len); - } - } - } - } -} - -static void ggml_compute_forward_concat_i8( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int32_t dim = ggml_get_op_params_i32(dst, 0); - - GGML_ASSERT(dim >= 0 && dim < 4); - - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = src0->ne[dim]; - - const int8_t * x; - - // TODO: smarter multi-theading - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ith; i2 < ne2; i2 += nth) { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); - } else { - x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); - } - - int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - - *y = *x; - } - } - } - } -} - -static void ggml_compute_forward_concat_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_BINARY_OP_LOCALS - - 
const int32_t dim = ggml_get_op_params_i32(dst, 0); - - GGML_ASSERT(dim >= 0 && dim < 4); - - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = src0->ne[dim]; - - const ggml_fp16_t * x; - - // TODO: smarter multi-theading - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ith; i2 < ne2; i2 += nth) { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); - } else { - x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); - } - - ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - - *y = *x; - } - } - } - } -} - -static void ggml_compute_forward_concat_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int32_t dim = ggml_get_op_params_i32(dst, 0); - - GGML_ASSERT(dim >= 0 && dim < 4); - - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = src0->ne[dim]; - - const float * x; - - // TODO: smarter multi-theading - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ith; i2 < ne2; i2 += nth) { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); - } else { - x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); - } - - float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - - *y = *x; - } - } - } - } -} - -static void ggml_compute_forward_concat( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_I16: - { - ggml_compute_forward_concat_f16(params, dst); - } break; - case GGML_TYPE_I8: - { - ggml_compute_forward_concat_i8(params, dst); - } break; - case GGML_TYPE_F32: - case GGML_TYPE_I32: - { - ggml_compute_forward_concat_f32(params, dst); - } break; - default: - { - ggml_compute_forward_concat_any(params, dst); - } - } -} - -// ggml_compute_forward_abs - -static void ggml_compute_forward_abs_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_abs_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_abs_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; 
- - for (int i = 0; i < n; i++) { - ggml_vec_abs_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_abs( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_abs_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_abs_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sgn - -static void ggml_compute_forward_sgn_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_sgn_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sgn_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_sgn_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sgn( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sgn_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_sgn_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_neg - -static void ggml_compute_forward_neg_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_neg_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_neg_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_neg_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_neg( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) 
{ - case GGML_TYPE_F32: - { - ggml_compute_forward_neg_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_neg_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_step - -static void ggml_compute_forward_step_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_step_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_step_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_step_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne3; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne1; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_vec_acc_f32(ne0, + (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), + (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); + } + } + } + } + } + } } } -static void ggml_compute_forward_step( +static void ggml_compute_forward_repeat_back( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -7126,11 +5463,7 @@ static void ggml_compute_forward_step( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_step_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_step_f16(params, dst); + ggml_compute_forward_repeat_back_f32(params, dst); } break; default: { @@ -7139,290 +5472,205 @@ static void ggml_compute_forward_step( } } -// ggml_compute_forward_tanh - -static void ggml_compute_forward_tanh_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_tanh_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} +// ggml_compute_forward_concat -static void ggml_compute_forward_tanh_f16( +static void ggml_compute_forward_concat_any( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - 
const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_tanh_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_tanh( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + const size_t len = ggml_type_size(src0->type); - const struct ggml_tensor * src0 = dst->src[0]; + const int ith = params->ith; + const int nth = params->nth; - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_tanh_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_tanh_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} + GGML_TENSOR_BINARY_OP_LOCALS -// ggml_compute_forward_elu + const int32_t dim = ggml_get_op_params_i32(dst, 0); -static void ggml_compute_forward_elu_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + GGML_ASSERT(dim >= 0 && dim < 4); - const struct ggml_tensor * src0 = dst->src[0]; + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; - if (params->ith != 0) { - return; - } + const char * x; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03; + } else { + x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13; + } - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; + char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3; - for (int i = 0; i < n; i++) { - ggml_vec_elu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); + memcpy(y, x, len); + } + } + } } } -static void ggml_compute_forward_elu_f16( +static void ggml_compute_forward_concat_i8( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_elu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_elu( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t)); - const struct ggml_tensor * src0 = dst->src[0]; + const int ith = params->ith; + const int nth = params->nth; - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_elu_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_elu_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} + GGML_TENSOR_BINARY_OP_LOCALS -// ggml_compute_forward_relu + const int32_t dim = ggml_get_op_params_i32(dst, 0); -static void ggml_compute_forward_relu_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + GGML_ASSERT(dim >= 0 && dim < 4); - const struct 
ggml_tensor * src0 = dst->src[0]; + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; - if (params->ith != 0) { - return; - } + const int8_t * x; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); + } - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; + int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - for (int i = 0; i < n; i++) { - ggml_vec_relu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); + *y = *x; + } + } + } } } -static void ggml_compute_forward_relu_f16( +static void ggml_compute_forward_concat_f16( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_relu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_relu( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t)); - const struct ggml_tensor * src0 = dst->src[0]; + const int ith = params->ith; + const int nth = params->nth; - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_relu_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_relu_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} + GGML_TENSOR_BINARY_OP_LOCALS -// ggml_compute_forward_sigmoid + const int32_t dim = ggml_get_op_params_i32(dst, 0); -static void ggml_compute_forward_sigmoid_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + GGML_ASSERT(dim >= 0 && dim < 4); - const struct ggml_tensor * src0 = dst->src[0]; + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; - if (params->ith != 0) { - return; - } + const ggml_fp16_t * x; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); + } - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; + ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - for (int i = 0; i < n; i++) { - 
ggml_vec_sigmoid_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); + *y = *x; + } + } + } } } -static void ggml_compute_forward_sigmoid_f16( +static void ggml_compute_forward_concat_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; - if (params->ith != 0) { - return; - } + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); + const int ith = params->ith; + const int nth = params->nth; - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; + GGML_TENSOR_BINARY_OP_LOCALS - for (int i = 0; i < n; i++) { - ggml_vec_sigmoid_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const float * x; + + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); + } + + float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } } } -static void ggml_compute_forward_sigmoid( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { +static void ggml_compute_forward_concat( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; switch (src0->type) { - case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_I16: { - ggml_compute_forward_sigmoid_f32(params, dst); + ggml_compute_forward_concat_f16(params, dst); } break; - case GGML_TYPE_F16: + case GGML_TYPE_I8: { - ggml_compute_forward_sigmoid_f16(params, dst); + ggml_compute_forward_concat_i8(params, dst); + } break; + case GGML_TYPE_F32: + case GGML_TYPE_I32: + { + ggml_compute_forward_concat_f32(params, dst); } break; default: { - GGML_ABORT("fatal error"); + ggml_compute_forward_concat_any(params, dst); } } } @@ -7930,217 +6178,6 @@ static void ggml_compute_forward_silu_back( } } -static void ggml_compute_forward_hardswish_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_hardswish_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_hardswish_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - 
assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_hardswish_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_hardswish( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_hardswish_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_hardswish_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -static void ggml_compute_forward_hardsigmoid_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_hardsigmoid_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_hardsigmoid_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_hardsigmoid_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_hardsigmoid( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_hardsigmoid_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_hardsigmoid_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -static void ggml_compute_forward_exp_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_exp_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_exp_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_exp_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_exp( - const struct 
ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_exp_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_exp_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - - // ggml_compute_forward_norm static void ggml_compute_forward_norm_f32( @@ -12238,10 +10275,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ith = params->ith; const int nth = params->nth; - const int64_t D = neq0; - const int64_t N = neq1; + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; - GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne0 == DV); GGML_ASSERT(ne2 == N); // input tensor rows must be contiguous @@ -12249,12 +10287,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -12320,15 +10357,15 @@ static void ggml_compute_forward_flash_attn_ext_f16( float S = 0.0f; // sum float M = -INFINITY; // maximum KQ value - float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator - float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer - ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator - ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 + float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator + float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer + ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator + ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16 if (v->type == GGML_TYPE_F16) { - memset(VKQ16, 0, D*sizeof(ggml_fp16_t)); + memset(VKQ16, 0, DV*sizeof(ggml_fp16_t)); } else { - memset(VKQ32, 0, D*sizeof(float)); + memset(VKQ32, 0, DV*sizeof(float)); } const ggml_fp16_t * mp = mask ? 
(ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; @@ -12342,7 +10379,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int iv2 = iq2 / rv2; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, D); + q_to_vec_dot(pq, Q_q, DK); // online softmax / attention // loop over n_kv and n_head_kv @@ -12356,7 +10393,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( float s; // KQ value const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); + kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1); s = s*scale; // scale KQ value @@ -12380,14 +10417,14 @@ static void ggml_compute_forward_flash_attn_ext_f16( ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f16(D, VKQ16, ms); + ggml_vec_scale_f16(DV, VKQ16, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } // V += v*expf(s - M) - ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); + ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs); } else { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f @@ -12395,30 +10432,30 @@ static void ggml_compute_forward_flash_attn_ext_f16( ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f32(D, VKQ32, ms); + ggml_vec_scale_f32(DV, VKQ32, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } - v_to_float(v_data, V32, D); + v_to_float(v_data, V32, DV); // V += v*expf(s - M) - ggml_vec_mad_f32(D, VKQ32, V32, vs); + ggml_vec_mad_f32(DV, VKQ32, V32, vs); } S = S*ms + vs; // scale and increment sum with partial sum } if (v->type == GGML_TYPE_F16) { - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < DV; ++d) { VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); } } // V /= S const float S_inv = 1.0f/S; - ggml_vec_scale_f32(D, VKQ32, S_inv); + ggml_vec_scale_f32(DV, VKQ32, S_inv); // dst indices const int i1 = iq1; @@ -15277,7 +13314,6 @@ struct ggml_cplan ggml_graph_plan( size_t cur = 0; if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) { - switch (node->op) { case GGML_OP_CPY: case GGML_OP_DUP: @@ -15386,9 +13422,10 @@ struct ggml_cplan ggml_graph_plan( } break; case GGML_OP_FLASH_ATTN_EXT: { - const int64_t ne00 = node->src[0]->ne[0]; // D + const int64_t ne10 = node->src[1]->ne[0]; // DK + const int64_t ne20 = node->src[2]->ne[0]; // DV - cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread) } break; case GGML_OP_FLASH_ATTN_BACK: { diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index e0482c59377fd..f6374f7894a08 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -55,6 +55,7 @@ #include #include +#include #ifdef _MSC_VER #define NOINLINE __declspec(noinline) @@ -1092,13 +1093,403 @@ class tinyBLAS_Q0_PPC { } } - template - void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + template + void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array& comparray) { int64_t i, j; TA *aoffset = NULL; VA *vecOffset = NULL; TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; + VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; + VB t1, t2, t3, t4, t5, t6, t7, t8; + 
const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + aoffset = const_cast(a); + vecOffset = vec; + vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + + j = (rows >> 3); + if (j > 0) { + do { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset += 8 * lda; + + i = (cols >> 2); + if (i > 0) { + do { + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); + c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); + c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); + c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); + + c1[0] = vec_and(c1[1], lowMask); + c1[1] = vec_sr(c1[1], v4); + c1[0] = vec_sub(c1[0], v8); + c1[1] = vec_sub(c1[1], v8); + vsum = vec_sum4s(c1[0], vsum); + vsum2 = vec_sum4s(c1[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c2[0] = vec_and(c2[1], lowMask); + c2[1] = vec_sr(c2[1], v4); + c2[0] = vec_sub(c2[0], v8); + c2[1] = vec_sub(c2[1], v8); + vsum = vec_sum4s(c2[0], vsum); + vsum2 = vec_sum4s(c2[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c3[0] = vec_and(c3[1], lowMask); + c3[1] = vec_sr(c3[1], v4); + c3[0] = vec_sub(c3[0], v8); + c3[1] = vec_sub(c3[1], v8); + vsum = vec_sum4s(c3[0], vsum); + vsum2 = vec_sum4s(c3[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c4[0] = vec_and(c4[1], lowMask); + c4[1] = vec_sr(c4[1], v4); + c4[0] = vec_sub(c4[0], v8); + c4[1] = vec_sub(c4[1], v8); + vsum = vec_sum4s(c4[0], vsum); + vsum2 = vec_sum4s(c4[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c5[0] = vec_and(c5[1], lowMask); + c5[1] = vec_sr(c5[1], v4); + c5[0] = vec_sub(c5[0], v8); + c5[1] = vec_sub(c5[1], v8); + vsum = vec_sum4s(c5[0], vsum); + vsum2 = vec_sum4s(c5[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c6[0] = vec_and(c6[1], lowMask); + c6[1] = vec_sr(c6[1], v4); + c6[0] = vec_sub(c6[0], v8); + c6[1] = vec_sub(c6[1], v8); + vsum = vec_sum4s(c6[0], vsum); + vsum2 = vec_sum4s(c6[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c7[0] = vec_and(c7[1], lowMask); + c7[1] = vec_sr(c7[1], v4); + c7[0] = vec_sub(c7[0], v8); + c7[1] = vec_sub(c7[1], v8); + vsum = 
vec_sum4s(c7[0], vsum); + vsum2 = vec_sum4s(c7[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c8[0] = vec_and(c8[1], lowMask); + c8[1] = vec_sr(c8[1], v4); + c8[0] = vec_sub(c8[0], v8); + c8[1] = vec_sub(c8[1], v8); + vsum = vec_sum4s(c8[0], vsum); + vsum2 = vec_sum4s(c8[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + t1 = vec_perm(c5[0], c6[0], swiz1); + t2 = vec_perm(c5[0], c6[0], swiz2); + t3 = vec_perm(c7[0], c8[0], swiz1); + t4 = vec_perm(c7[0], c8[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+128); + vec_xst(t6, 0, vecOffset+144); + vec_xst(t7, 0, vecOffset+160); + vec_xst(t8, 0, vecOffset+176); + + t1 = vec_perm(c5[1], c6[1], swiz1); + t2 = vec_perm(c5[1], c6[1], swiz2); + t3 = vec_perm(c7[1], c8[1], swiz1); + t4 = vec_perm(c7[1], c8[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+192); + vec_xst(t6, 0, vecOffset+208); + vec_xst(t7, 0, vecOffset+224); + vec_xst(t8, 0, vecOffset+240); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + aoffset4 += lda; + aoffset5 += lda; + aoffset6 += lda; + aoffset7 += lda; + aoffset8 += lda; + vecOffset += 256; + i--; + } while (i > 0); + } + j--; + } while (j > 0); + } + + if (rows & 4) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset += 4 * lda; + + i = (cols >> 2); + if (i > 0) { + do { + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + + c1[0] = vec_and(c1[1], lowMask); + c1[1] = vec_sr(c1[1], v4); + c1[0] = vec_sub(c1[0], v8); + c1[1] = vec_sub(c1[1], v8); + vsum = vec_sum4s(c1[0], vsum); + vsum2 = vec_sum4s(c1[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c2[0] = vec_and(c2[1], lowMask); + c2[1] = vec_sr(c2[1], v4); + c2[0] = vec_sub(c2[0], v8); + c2[1] = vec_sub(c2[1], v8); + vsum = vec_sum4s(c2[0], vsum); + vsum2 = vec_sum4s(c2[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c3[0] = vec_and(c3[1], lowMask); + c3[1] = vec_sr(c3[1], v4); + c3[0] = 
vec_sub(c3[0], v8); + c3[1] = vec_sub(c3[1], v8); + vsum = vec_sum4s(c3[0], vsum); + vsum2 = vec_sum4s(c3[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c4[0] = vec_and(c4[1], lowMask); + c4[1] = vec_sr(c4[1], v4); + c4[0] = vec_sub(c4[0], v8); + c4[1] = vec_sub(c4[1], v8); + vsum = vec_sum4s(c4[0], vsum); + vsum2 = vec_sum4s(c4[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats( 0); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + aoffset4 += lda; + vecOffset += 128; + i--; + } while (i > 0); + } + } + + if (rows & 3) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + i = (cols >> 2); + if (i > 0) { + do { + switch(rows) { + case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + break; + } + c1[0] = vec_and(c1[1], lowMask); + c1[1] = vec_sr(c1[1], v4); + c1[0] = vec_sub(c1[0], v8); + c1[1] = vec_sub(c1[1], v8); + vsum = vec_sum4s(c1[0], vsum); + vsum2 = vec_sum4s(c1[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c2[0] = vec_and(c2[1], lowMask); + c2[1] = vec_sr(c2[1], v4); + c2[0] = vec_sub(c2[0], v8); + c2[1] = vec_sub(c2[1], v8); + vsum = vec_sum4s(c2[0], vsum); + vsum2 = vec_sum4s(c2[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c3[0] = vec_and(c3[1], lowMask); + c3[1] = vec_sr(c3[1], v4); + c3[0] = vec_sub(c3[0], v8); + c3[1] = vec_sub(c3[1], v8); + vsum = vec_sum4s(c3[0], vsum); + vsum2 = vec_sum4s(c3[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c4[0] = vec_and(c4[1], lowMask); + c4[1] = vec_sr(c4[1], v4); + c4[0] = vec_sub(c4[0], v8); + c4[1] = vec_sub(c4[1], v8); + vsum = vec_sum4s(c4[0], vsum); + vsum2 = vec_sum4s(c4[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + 
vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + vecOffset += 128; + i--; + } while(i > 0); + } + } + } + + template + void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + int64_t i, j; + TB *aoffset = NULL; + VA *vecOffset = NULL; + TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0}; VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0}; @@ -1111,24 +1502,24 @@ class tinyBLAS_Q0_PPC { vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - aoffset = const_cast(a); + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; - aoffset += 8 * lda; + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset += 8 * lda; - i = (cols >> 3); - if (i > 0) { - do { + i = (cols >> 3); + if (i > 0) { + do { C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); @@ -1156,10 +1547,10 @@ class tinyBLAS_Q0_PPC { t7 = vec_perm(t2, t4, swiz3); t8 = vec_perm(t2, t4, swiz4); if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); } vec_xst(t5, 0, vecOffset); vec_xst(t6, 0, vecOffset+16); @@ -1175,10 +1566,10 @@ class tinyBLAS_Q0_PPC { t7 = vec_perm(t2, t4, swiz3); t8 = vec_perm(t2, t4, swiz4); if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); } vec_xst(t5, 0, vecOffset+64); vec_xst(t6, 0, vecOffset+80); @@ -1194,10 +1585,10 @@ class tinyBLAS_Q0_PPC { t7 = vec_perm(t2, t4, swiz3); t8 = vec_perm(t2, t4, swiz4); if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); } vec_xst(t5, 0, vecOffset+128); vec_xst(t6, 0, vecOffset+144); @@ -1213,10 +1604,10 @@ class tinyBLAS_Q0_PPC { t7 = vec_perm(t2, t4, swiz3); 
t8 = vec_perm(t2, t4, swiz4); if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); } vec_xst(t5, 0, vecOffset+192); vec_xst(t6, 0, vecOffset+208); @@ -1240,11 +1631,11 @@ class tinyBLAS_Q0_PPC { } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset += 4 * lda; + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { @@ -1311,7 +1702,7 @@ class tinyBLAS_Q0_PPC { aoffset2 = aoffset1 + lda; aoffset3 = aoffset2 + lda; i = (cols >> 3); - if (i > 0) { + if (i > 0) { do { switch(rows) { case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); @@ -1527,13 +1918,18 @@ class tinyBLAS_Q0_PPC { void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; - std::array comparray; + std::array comparray {}; vector float fin_res[8] = {0}; vector float vs[8] = {0}; + bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); - packNormal((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + if (std::is_same_v) { + packNormalInt4((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + } else { + packNormal((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); @@ -1545,15 +1941,17 @@ class tinyBLAS_Q0_PPC { *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); } } - auto aoffset = A+(ii*lda)+l; - for (int i = 0; i < 4; i++) { - comparray[i] = 0; - int ca = 0; - const int8_t *at = aoffset->qs; - for (int j = 0; j < 32; j++) - ca += (int)*at++; - comparray[i] = ca; - aoffset += lda; + if (!isAblock_q4) { + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 4; i++) { + comparray[i] = 0; + int ca = 0; + auto *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } } compute<4>(&acc_0, 0, 0, comparray, vs, fin_res); compute<4>(&acc_1, 0, 4, comparray, vs, fin_res); @@ -1565,13 +1963,18 @@ class tinyBLAS_Q0_PPC { void KERNEL_8x4(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[8] = {0}; acc_t acc_0, acc_1; - std::array comparray; + std::array comparray {}; vector float fin_res[8] = {0}; vector float vs[8] = {0}; + bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); - packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + if (std::is_same_v) { + packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + } else { + packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + } packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); @@ -1582,15 +1985,17 @@ class tinyBLAS_Q0_PPC { *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); } } - auto aoffset = A+(ii*lda)+l; - for (int i = 0; i < 8; i++) { - comparray[i] = 0; - int ca = 0; - const int8_t *at = aoffset->qs; - for (int j = 0; j < 32; j++) - ca += (int)*at++; - 
comparray[i] = ca; - aoffset += lda; + if (!isAblock_q4) { + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 8; i++) { + comparray[i] = 0; + int ca = 0; + auto *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } } compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); @@ -1602,15 +2007,20 @@ class tinyBLAS_Q0_PPC { void KERNEL_8x8(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[16] = {0}; acc_t acc_0, acc_1, acc_2, acc_3; - std::array comparray; + std::array comparray {}; vector float fin_res[16] = {0}; vector float vs[16] = {0}; + bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); - packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + if (std::is_same_v) { + packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + } else { + packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); @@ -1624,15 +2034,17 @@ class tinyBLAS_Q0_PPC { *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); } } - auto aoffset = A+(ii*lda)+l; - for (int i = 0; i < 8; i++) { - comparray[i] = 0; - int ca = 0; - const int8_t *at = aoffset->qs; - for (int j = 0; j < 32; j++) - ca += (int)*at++; - comparray[i] = ca; - aoffset += lda; + if (!isAblock_q4) { + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 8; i++) { + comparray[i] = 0; + int ca = 0; + auto *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } } compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); @@ -1653,16 +2065,17 @@ class tinyBLAS_Q0_PPC { int64_t duty = (tiles + nth - 1) / nth; int64_t start = duty * ith; int64_t end = start + duty; - vec_t vec_A[8], vec_B[8] = {0}; + vec_t vec_A[8] = {0}, vec_B[8] = {0}; vector signed int vec_C[4]; acc_t acc_0; + bool isAblock_q4 = std::is_same_v; if (end > tiles) end = tiles; for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - std::array comparray; + std::array comparray{}; vector float res[4] = {0}; vector float fin_res[4] = {0}; vector float vs[4] = {0}; @@ -1673,7 +2086,11 @@ class tinyBLAS_Q0_PPC { __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead __builtin_mma_xxsetaccz(&acc_0); - packNormal((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + if (isAblock_q4) { + packNormalInt4((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + } else { + packNormal((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + } packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x+=4) { __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); @@ -1687,17 +2104,18 @@ class tinyBLAS_Q0_PPC { } } __builtin_mma_disassemble_acc(vec_C, &acc_0); - auto aoffset = A+(ii*lda)+l; - for (int i = 0; i < RM; i++) { - comparray[i] = 0; - int ca = 0; - const int8_t *at = aoffset->qs; - for (int j = 0; j < 32; j++) - ca += (int)*at++; - comparray[i] = ca; - aoffset += lda; + if (!isAblock_q4) { + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < RM; 
i++) { + comparray[i] = 0; + int ca = 0; + auto *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } } - for (int i = 0; i < RM; i++) { CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0)); res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); @@ -2013,6 +2431,7 @@ class tinyBLAS_PPC { } } } + void KERNEL_4x4(int64_t ii, int64_t jj) { vec_t vec_A[4], vec_B[4], vec_C[4]; acc_t acc_0; @@ -2259,15 +2678,27 @@ class tinyBLAS_PPC { vec_t vec_C[4]; acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); - vec_t vec_A[4], vec_B[4]; + vec_t vec_A[4] {0}, vec_B[4] = {0}; for (int l=0; l= 4 && RM == 1) { + /* 'GEMV Forwarding' concept is used in first two conditional loops. + * when one of the matrix has a single row/column, the elements are + * broadcasted, instead of using packing routine to prepack the + * matrix elements. + */ + if (RM == 1) { TA* a = const_cast(A+(ii)*lda+l); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); vec_A[0] = (vec_t)vec_xl(0,a); vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); + } else if (RN == 1) { + packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); + TB* b = const_cast(B+(jj)*ldb+l); + vec_B[0] = (vec_t)vec_xl(0,b); + vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1)); + vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2)); + vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3)); } else { packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); @@ -2371,8 +2802,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 assert(params->ith < params->nth); // only enable sgemm for prompt processing +#if !defined(__MMA__) if (n < 2) return false; +#endif if (Ctype != GGML_TYPE_F32) return false; @@ -2503,8 +2936,8 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; - #elif defined(__MMA__) + //TO-DO: Remove this condition once gemv forwarding is enabled. if (n < 8 && n != 4) return false; if (m < 8 && m != 4) @@ -2516,7 +2949,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; - #else return false; #endif @@ -2541,6 +2973,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; +#elif defined(__MMA__) + //TO-DO: Remove this condition once gemv forwarding is enabled. + if (n < 8 && n != 4) + return false; + if (m < 8 && m != 4) + return false; + tinyBLAS_Q0_PPC tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; #else return false; #endif diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp new file mode 100644 index 0000000000000..4fce569b3bfc8 --- /dev/null +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -0,0 +1,186 @@ +#include "unary-ops.h" + +static inline float op_abs(float x) { + return fabsf(x); +} + +static inline float op_sgn(float x) { + return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f); +} + +static inline float op_neg(float x) { + return -x; +} + +static inline float op_step(float x) { + return (x > 0.f) ? 
1.f : 0.f; +} + +static inline float op_tanh(float x) { + return tanhf(x); +} + +static inline float op_elu(float x) { + return (x > 0.f) ? x : expm1f(x); +} + +static inline float op_relu(float x) { + return (x > 0.f) ? x : 0.f; +} + +static inline float op_sigmoid(float x) { + return 1.f / (1.f + expf(-x)); +} + +static inline float op_hardsigmoid(float x) { + return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); +} + +static inline float op_exp(float x) { + return expf(x); +} + +static inline float op_hardswish(float x) { + return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); +} + +static inline float op_sqr(float x) { + return x * x; +} + +static inline float op_sqrt(float x) { + return sqrtf(x); +} + +static inline float op_sin(float x) { + return sinf(x); +} + +static inline float op_cos(float x) { + return cosf(x); +} + +static inline float op_log(float x) { + return logf(x); +} + +template <float (*op)(float), typename src0_t, typename dst_t> +static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) { + constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32; + constexpr auto f32_to_dst = type_conversion_table<dst_t>::from_f32; + + for (int i = 0; i < n; i++) { + y[i] = f32_to_dst(op(src0_to_f32(x[i]))); + } +} + +template <float (*op)(float), typename src0_t, typename dst_t> +static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(dst_t)); + GGML_ASSERT(nb00 == sizeof(src0_t)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + vec_unary_op<op>(ne0, dst_ptr, src0_ptr); + } +} + +// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates +template <float (*op)(float)> +static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_unary_op<op, float, float>(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { + apply_unary_op<op, ggml_bf16_t, float>(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op<op, ggml_fp16_t, float>(params, dst); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + +void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_abs>(params, dst); +} + +void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_sgn>(params, dst); +} + +void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_neg>(params, dst); +} + +void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_step>(params, dst); +} + +void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_tanh>(params, dst); +} + +void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_elu>(params, dst); +} + +void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_relu>(params, dst); +} + +void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_sigmoid>(params, dst); +} + +void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_hardsigmoid>(params, dst); +} + +void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_exp>(params, dst); +} + +void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_hardswish>(params, dst); +} + +void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_sqr>(params, dst); +} + +void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_sqrt>(params, dst); +} + +void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_sin>(params, dst); +} + +void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_cos>(params, dst); +} + +void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_log>(params, dst); +} diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h new file mode 100644 index 0000000000000..b1ade2c8e341f --- /dev/null +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -0,0 +1,28 @@ +#pragma once + +#include "common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ggml_compute_forward_abs(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_sgn(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_neg(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_step(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_hardswish(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 954ff5f16924b..8284a0017d272 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -52,7 +52,7
@@ #define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS) // AMD -// GCN/CNDA, wave size is 64 +// GCN/CDNA, wave size is 64 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a @@ -60,16 +60,18 @@ #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 -// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32 +// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000 #define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) -#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA) #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1) @@ -209,9 +211,9 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA -#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) +#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4)) #define FP16_MMA_AVAILABLE -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4)) #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE @@ -244,14 +246,14 @@ static bool fp16_mma_available(const int cc) { return false; #else return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc); + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc); + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. 
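A note on the compute-capability checks adjusted above: each AMD generation macro tests a half-open interval over the offset cc value, so introducing `GGML_CUDA_CC_RDNA4` also required giving `GGML_CUDA_CC_IS_RDNA3` an upper bound; otherwise RX 9000 parts would keep matching every RDNA3 branch. Below is a minimal, self-contained C++ sketch of that range logic. The `0x1010`/`0x1030`/`0x1100`/`0x1200` constants mirror the hunk above, but `OFFSET_AMD`'s value, the helper names, and the `main` driver are illustrative only, not taken from `common.cuh`.

```cpp
// Sketch of the half-open cc ranges used by the GGML_CUDA_CC_IS_RDNA* macros.
constexpr int OFFSET_AMD = 0x0000000;          // assumption: any base offset works for the range logic
constexpr int CC_RDNA1   = OFFSET_AMD + 0x1010; // RX 5000
constexpr int CC_RDNA2   = OFFSET_AMD + 0x1030; // RX 6000
constexpr int CC_RDNA3   = OFFSET_AMD + 0x1100; // RX 7000
constexpr int CC_RDNA4   = OFFSET_AMD + 0x1200; // RX 9000

// RDNA3 is now a bounded range; RDNA4 remains open-ended until a successor exists.
constexpr bool is_rdna3(int cc) { return cc >= CC_RDNA3 && cc < CC_RDNA4; }
constexpr bool is_rdna4(int cc) { return cc >= CC_RDNA4; }

int main() {
    // An RX 7000-class cc matches only the RDNA3 range ...
    static_assert( is_rdna3(CC_RDNA3) && !is_rdna4(CC_RDNA3), "gfx110x classified as RDNA3");
    // ... while an RX 9000-class cc falls through to the new RDNA4 branch.
    static_assert(!is_rdna3(CC_RDNA4) &&  is_rdna4(CC_RDNA4), "gfx120x classified as RDNA4");
    return 0;
}
```

The same narrowed ranges feed the `fp16_mma_available()` checks and the `__builtin_amdgcn_sudot4` selection in the surrounding hunks, which is why RDNA4 had to be added to each of those sites explicitly.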
@@ -286,6 +288,10 @@ static __device__ void no_device_code( __trap(); GGML_UNUSED(no_device_code); // suppress unused function warning + +#if defined(GGML_USE_MUSA) + __builtin_unreachable(); +#endif // defined(GGML_USE_MUSA) } #ifdef __CUDA_ARCH__ @@ -409,7 +415,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__) c = __builtin_amdgcn_sdot4(a, b, c, false); -#elif defined(RDNA3) +#elif defined(RDNA3) || defined(RDNA4) c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); #elif defined(RDNA1) || defined(__gfx900__) int tmp1; @@ -723,7 +729,13 @@ struct ggml_cuda_graph { bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; std::vector<ggml_graph_node_properties> ggml_graph_properties; - std::vector<char **> updated_kernel_arg; + bool use_cpy_indirection = false; + std::vector<char *> cpy_dest_ptrs; + char ** dest_ptrs_d; + int dest_ptrs_size = 0; + // Index to allow each cpy kernel to be aware of its position within the graph + // relative to other cpy nodes. + int graph_cpynode_index = -1; #endif }; diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index aafbaf803b40c..e9ffd274b9966 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -38,7 +38,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float * blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; - if (blockIdx.y < ne01) { // src0 + if (blockIdx.y < (unsigned)ne01) { // src0 int offset_src = nidx + blockIdx.y * ne0 + @@ -64,7 +64,7 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float * blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; - if (blockIdx.z < ne02) { // src0 + if (blockIdx.z < (unsigned)ne02) { // src0 int offset_src = nidx + blockIdx.y * ne0 + diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cu b/ggml/src/ggml-cuda/conv-transpose-1d.cu index b1e94d6f770aa..fe4caf674d4d9 100644 --- a/ggml/src/ggml-cuda/conv-transpose-1d.cu +++ b/ggml/src/ggml-cuda/conv-transpose-1d.cu @@ -34,6 +34,10 @@ static __global__ void conv_transpose_1d_kernel( } } dst[global_index] = accumulator; + GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3); + GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3); + GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1); + GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2); } static void conv_transpose_1d_f32_f32_cuda( @@ -75,8 +79,6 @@ void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor const int p0 = 0;//opts[3]; const int d0 = 1;//opts[4]; - const int64_t kernel_size = ggml_nelements(src0); - const int64_t input_size = ggml_nelements(src1); const int64_t output_size = ggml_nelements(dst); conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size, diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 795b720d60bcb..2997e2b4d5bd5 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -577,7 +577,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res return; } - const src_t * x = (src_t *) vx; + const src_t * x = (const src_t *) vx; y[i] = x[i]; } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index cca2bee0b2791..ed853ee6c15a2 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -32,16 +32,18 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { } template <cpy_kernel_t cpy_1> -static __global__ void cpy_f32_f16(const char * cx, char *
cdst, const int ne, +static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; } + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; + // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // then combine those indices with the corresponding byte offsets to get the total offsets const int64_t i03 = i/(ne00 * ne01 * ne02); @@ -288,16 +290,18 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, +static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; + const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -314,16 +318,18 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, +static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } + char * cdst = (cdst_indirect != nullptr) ? 
cdst_indirect[graph_cpynode_index]: cdst_direct; + const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -339,66 +345,87 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, cpy_blck(cx + x_offset, cdst + dst_offset); } +// Copy destination pointers to GPU to be available when pointer indirection is in use + +void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) { +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) + if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers + CUDA_CHECK(cudaStreamSynchronize(stream)); + if (cuda_graph->dest_ptrs_d != nullptr) { + CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d)); + } + CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *))); + cuda_graph->dest_ptrs_size = host_dest_ptrs_size; + } + // copy destination pointers to GPU + CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream)); + cuda_graph->graph_cpynode_index = 0; // reset index +#else + GGML_UNUSED(cuda_graph); GGML_UNUSED(host_dest_ptrs); + GGML_UNUSED(host_dest_ptrs_size); GGML_UNUSED(stream); +#endif +} + static void ggml_cpy_f16_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int 
nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q8_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q8_0_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q4_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK4_0 == 0); const int num_blocks = ne / QK4_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q4_0_f32_cuda( @@ -407,22 +434,22 @@ static void ggml_cpy_q4_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream) { + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32, QK4_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13); + ne10, ne11, ne12, nb10, nb11, nb12, 
nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK4_1 == 0); const int num_blocks = ne / QK4_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q4_1_f32_cuda( @@ -431,22 +458,22 @@ static void ggml_cpy_q4_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream) { + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32, QK4_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13); + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK5_0 == 0); const int num_blocks = ne / QK5_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q5_0_f32_cuda( @@ -455,22 +482,22 @@ static void ggml_cpy_q5_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream) { + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32, QK5_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13); + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK5_1 == 0); const int num_blocks = 
ne / QK5_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q5_1_f32_cuda( @@ -479,32 +506,32 @@ static void ggml_cpy_q5_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream) { + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32, QK5_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13); + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK4_NL == 0); const int num_blocks = ne / QK4_NL; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f16_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { @@ -541,46 +568,60 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src0_ddc = (char *) src0->data; char * src1_ddc = (char *) src1->data; + char ** dest_ptrs_d = nullptr; + int graph_cpynode_index = -1; +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) + if(ctx.cuda_graph->use_cpy_indirection) { + dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d; + graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; + } +#endif if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, 
ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, 
nb12, nb13, main_stream); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); } +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) + if(ctx.cuda_graph->use_cpy_indirection) { + ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; + } +#endif + } void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index 28b06cddaa87b..6bed0564df27a 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -7,3 +7,5 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1); + +void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 1c2a2a138f9b3..56121705bdf3f 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -62,7 +62,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < 
int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -102,7 +102,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -146,7 +146,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -193,7 +193,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -244,7 +244,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_0; @@ -315,14 +315,14 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( float vals[sizeof(int)] = {0.0f}; #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { vals[l] = scale * x[4*threadIdx.x + l]; } float amax = fabsf(vals[0]); float sum = vals[0]; #pragma unroll - for (int l = 1; l < sizeof(int); ++l) { + for (int l = 1; l < int(sizeof(int)); ++l) { amax = fmaxf(amax, fabsf(vals[l])); sum += vals[l]; } @@ -338,7 +338,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( if (d != 0.0f) { #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { q8[l] = roundf(vals[l] / d); } } @@ -638,7 +638,7 @@ static __global__ void flash_attn_combine_results( float VKQ_denominator = 0.0f; for (int l = 0; l < parallel_blocks; ++l) { const float diff = meta[l].x - kqmax; - const float KQ_max_scale = expf(diff); + float KQ_max_scale = expf(diff); const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; @@ -649,6 +649,7 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; } +[[noreturn]] static void on_no_fattn_vec_case(const int D) { if (D == 64) { fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 024032f6221bc..04804a15c9d51 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -406,6 +406,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // CP_ASYNC_AVAILABLE #else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask);
GGML_UNUSED(Q_B); + GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); + GGML_UNUSED(kb0); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); } #else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask); + GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16( (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } @@ -985,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8); - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16); - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32); - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16) 
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64) // Kernels with ncols == 128 are only 4% faster due to register pressure. -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory. +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory. diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 77455d8e4f1a2..e0039e1755c2b 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -282,7 +282,19 @@ static __global__ void flash_attn_tile_ext_f16( } } #else - NO_DEVICE_CODE; + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 85fea4404d07c..fcb6f848fe003 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -52,6 +52,18 @@ static __global__ void flash_attn_tile_ext_f32( return; #endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; return; } @@ -281,6 +293,18 @@ static __global__ 
void flash_attn_tile_ext_f32( } } #else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 32c52ebe33e9c..e17d2d0e4fb39 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -292,7 +292,19 @@ static __global__ void flash_attn_vec_ext_f16( dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else - NO_DEVICE_CODE; + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 336c136d19d7d..d42ddca49f66d 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -45,6 +45,18 @@ static __global__ void flash_attn_vec_ext_f32( // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; return; } @@ -114,7 +126,7 @@ static __global__ void flash_attn_vec_ext_f32( // Set memory to zero if out of bounds: if (ncols > 2 && ic0 + j >= ne01) { #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; tmp_q_i32[i] = 0; @@ -127,7 +139,7 @@ static __global__ void flash_attn_vec_ext_f32( const float * Q_f = (const float *) (Q + j*nb01); 
#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); } } @@ -140,7 +152,7 @@ static __global__ void flash_attn_vec_ext_f32( float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; @@ -277,6 +289,16 @@ static __global__ void flash_attn_vec_ext_f32( dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 5c214ea3109d2..bc21b27a0cc07 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -430,7 +430,17 @@ static __global__ void flash_attn_ext_f16( dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val; } #else - NO_DEVICE_CODE; + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 8edc12649aa63..7a2d1e45365af 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -299,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128); + const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0; 
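The recurring `int(D/sizeof(int))` casts in the hunks above all address the same `-Wsign-compare` warning: `sizeof` yields `size_t`, so without the cast the signed loop counter is silently converted to unsigned for the comparison. The stub branches likewise reference every parameter once through `GGML_UNUSED` so that `-Wunused-parameter` stays quiet when no device code is emitted. A minimal standalone sketch of both patterns (the names `MY_UNUSED`, `DEVICE_CODE_AVAILABLE`, and `kernel_stub` are hypothetical, not part of the patch):

```cuda
#include <cstdio>

#define MY_UNUSED(x) (void)(x) // same idea as ggml's GGML_UNUSED

template <int D>
void kernel_stub(const float * Q, float scale) {
#ifdef DEVICE_CODE_AVAILABLE
    // ... real kernel body would go here ...
#else
    // Stub branch: referencing each parameter once silences
    // -Wunused-parameter without generating any code.
    MY_UNUSED(Q); MY_UNUSED(scale);
#endif
}

int main() {
    constexpr int D = 128;
    int acc = 0;
    // D/sizeof(int) has type size_t; without the int(...) cast the loop
    // condition compares signed k against an unsigned value (-Wsign-compare).
    for (int k = 0; k < int(D / sizeof(int)); ++k) {
        acc += k;
    }
    printf("acc = %d\n", acc); // 0 + 1 + ... + 31 = 496
    kernel_stub<D>(nullptr, 1.0f);
    return 0;
}
```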
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6dd5dcb85e15c..57319bafdb70a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -31,6 +31,8 @@ #include "ggml-cuda/rope.cuh" #include "ggml-cuda/scale.cuh" #include "ggml-cuda/softmax.cuh" +#include "ggml-cuda/ssm-conv.cuh" +#include "ggml-cuda/ssm-scan.cuh" #include "ggml-cuda/sum.cuh" #include "ggml-cuda/sumrows.cuh" #include "ggml-cuda/tsembd.cuh" @@ -1216,7 +1218,7 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - if (GGML_CUDA_CC_IS_CDNA(cc)) { + if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { const float alpha = 1.0f; const float beta = 0.0f; CUBLAS_CHECK( @@ -1759,7 +1761,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } - if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) { + int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { cu_compute_type = CUBLAS_COMPUTE_32F; alpha = &alpha_f32; beta = &beta_f32; @@ -1836,7 +1840,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } #endif - if (dst->op_params[0] == GGML_PREC_DEFAULT) { + if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) { const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream); } @@ -2294,6 +2298,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SUM_ROWS: ggml_cuda_op_sum_rows(ctx, dst); break; + case GGML_OP_SSM_CONV: + ggml_cuda_op_ssm_conv(ctx, dst); + break; + case GGML_OP_SSM_SCAN: + ggml_cuda_op_ssm_scan(ctx, dst); + break; case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; @@ -2431,10 +2441,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { #ifdef USE_CUDA_GRAPH static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - std::vector & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) { + bool use_cuda_graph) { // Loop over nodes in GGML graph to obtain info needed for CUDA graph - cuda_ctx->cuda_graph->updated_kernel_arg.clear(); + cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2466,8 +2477,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud } if (node->op == GGML_OP_CPY) { - // store the copy op parameter which changes with each token. 
- cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); + + // Store the pointers which are updated for each token, such that these can be sent + // to the device and accessed using indirection from CUDA graph + cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data); + // store a pointer to each copy op CUDA kernel to identify it later void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); if (!ptr) { @@ -2475,10 +2489,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); #endif - } else { - if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { - ggml_cuda_cpy_fn_ptrs.push_back(ptr); - } } } @@ -2487,6 +2497,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud } } + if (use_cuda_graph) { + cuda_ctx->cuda_graph->use_cpy_indirection = true; + // copy pointers to GPU so they can be accessed via indirection within CUDA graph + ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream()); + } + return use_cuda_graph; } @@ -2541,51 +2557,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return true; } -static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) { - - if (cuda_graph_update_required) { - // Extract nodes from graph - // First call with null argument gets number of nodes in graph - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); - // Subsequent call with non-null argument gets nodes - cuda_ctx->cuda_graph->nodes.clear(); - cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); - cuda_ctx->cuda_graph->params.clear(); - cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); - if (cuda_ctx->cuda_graph->num_nodes > 0) { - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); - - // Loop over nodes, and extract kernel parameters from each node - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - cudaGraphNodeType node_type; - CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); - if (node_type == cudaGraphNodeTypeKernel) { - cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime - if (stat == cudaErrorInvalidDeviceFunction) { - // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. - // We don't need to update blas nodes, so clear error and move on. 
- (void)cudaGetLastError(); - } else { - GGML_ASSERT(stat == cudaSuccess); - } - } - } - } - } else { - // One of the arguments to the copy kernel is updated for each token, hence we need to - // replace that argument with the updated value in the CUDA graph - // on update steps, the live parameters will already be captured - int k = 0; - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { - char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); - *(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr; - CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); - } - } - } -} - static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { bool cuda_graph_update_required = false; @@ -2645,8 +2616,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { #endif static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - [[maybe_unused]] std::vector & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph, - bool & cuda_graph_update_required) { + bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. @@ -2696,13 +2666,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } - - // Perform update to graph (if required for this token), and change copy parameter (required for every token) - maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required); - - // Update graph executable - update_cuda_graph_executable(cuda_ctx); - + if (cuda_graph_update_required) { // Update graph executable + update_cuda_graph_executable(cuda_ctx); + } // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); #else @@ -2716,10 +2682,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cuda_set_device(cuda_ctx->device); - // vector of pointers to CUDA cpy kernels, which are required to identify - // kernel parameters which need updated in the graph for each token - std::vector ggml_cuda_cpy_fn_ptrs; - #ifdef USE_CUDA_GRAPH static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); @@ -2753,8 +2715,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, if (use_cuda_graph) { cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); - use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, - ggml_cuda_cpy_fn_ptrs, use_cuda_graph); + use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph); // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. 
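For context on the replacement above: once a CUDA graph has been instantiated, its kernel arguments are frozen, which is why the removed `maintain_cuda_graph` had to patch `kernelParams` on every token. The new scheme instead hands the copy kernels a fixed device-side array of destination pointers plus an index, so between tokens only the array contents change while the captured graph stays valid as-is. A simplified sketch of the idea (`cpy_indirect` and `upload_dest_ptrs` are illustrative names, not the patch verbatim):

```cuda
#include <cuda_runtime.h>

// The captured graph always launches this kernel with the same dest_ptrs
// argument; the per-token destination is read through the table at run
// time, so the instantiated graph never needs its parameters rewritten.
__global__ void cpy_indirect(const char * src, char * dst_direct,
                             char ** dest_ptrs, int node_index, int n) {
    char * dst = dest_ptrs ? dest_ptrs[node_index] : dst_direct; // indirection
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = src[i];
    }
}

// Host side, once per token before replaying the graph: refresh only the
// contents of the persistent device-side pointer table. This mirrors the
// role of the newly introduced ggml_cuda_cpy_dest_ptrs_copy.
static void upload_dest_ptrs(char ** dest_ptrs_d, char ** host_dest_ptrs,
                             int count, cudaStream_t stream) {
    (void) cudaMemcpyAsync(dest_ptrs_d, host_dest_ptrs, count * sizeof(char *),
                           cudaMemcpyHostToDevice, stream); // error handling omitted
}
```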
if (use_cuda_graph && cuda_graph_update_required) { @@ -2775,6 +2736,10 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } + if (!use_cuda_graph) { + cuda_ctx->cuda_graph->use_cpy_indirection = false; + } + #else bool use_cuda_graph = false; bool cuda_graph_update_required = false; @@ -2782,7 +2747,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool graph_evaluated_or_captured = false; - evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); return GGML_STATUS_SUCCESS; } @@ -3191,6 +3156,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_LOG: + case GGML_OP_SSM_SCAN: + case GGML_OP_SSM_CONV: return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; @@ -3230,6 +3197,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g #ifndef FLASH_ATTN_AVAILABLE return false; #endif // FLASH_ATTN_AVAILABLE + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } + if (op->src[0]->ne[0] == 192) { + return false; + } if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 9206bfeba3de4..2af63355a195e 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -26,6 +26,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(ret) : "r"(x)); #else + GGML_UNUSED(x); NO_DEVICE_CODE; #endif // defined(NEW_MMA_AVAILABLE) return ret; @@ -178,6 +179,7 @@ namespace ggml_cuda_mma { : "l"(xs)); #else load_generic(xs0, stride); + GGML_UNUSED(t); #endif // NEW_MMA_AVAILABLE } diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 2c19485d51a92..b36b43d5417ba 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } - return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index ee01154254e3d..532358018f410 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -945,7 +945,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -1024,7 +1024,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1035,19 +1035,34 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { 
const int i = i0 + threadIdx.x; - if (k01 < WARP_SIZE/2) { - constexpr int ns = 2; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, - &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); - } else { - constexpr int ns = 1; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, - &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); - } + constexpr int ns = 2; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } + + // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. + // As a workaround 2 separate loops are used instead. +#pragma unroll + for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 1; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } } } @@ -1176,7 +1191,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -1253,7 +1268,7 @@ template static __device__ __forceinlin const float d = bxi->d; #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l]; } #else @@ -1376,7 +1391,7 @@ template static __device__ __forceinlin const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); } } @@ -1517,7 +1532,7 @@ template static __device__ __forceinlin const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); } } @@ -1810,7 +1825,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -2570,6 +2585,8 @@ static __device__ void mul_mat_q_process_tile( } else { write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j); } + + GGML_UNUSED(ne00); GGML_UNUSED(ne10); } @@ -2577,9 +2594,9 @@ static __device__ void mul_mat_q_process_tile( template #if defined(GGML_USE_HIP) && 
defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) __launch_bounds__(WARP_SIZE*nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) #else #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA __launch_bounds__(WARP_SIZE*nwarps, 1) @@ -2695,7 +2712,7 @@ static __global__ void mul_mat_q_stream_k_fixup( const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block: - if (it != blockIdx.x || jt != blockIdx.y) { + if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) { continue; } @@ -2825,7 +2842,6 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { const int id = ggml_cuda_get_device(); - const int nsm = ggml_cuda_info().devices[id].nsm; const int cc = ggml_cuda_info().devices[id].cc; const int smpbo = ggml_cuda_info().devices[id].smpbo; diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index f89ed03b578d1..b39961cd1154d 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -29,7 +29,7 @@ static __global__ void mul_mat_vec( __syncthreads(); } - float sumf; + float sumf = 0.0f; if constexpr (std::is_same::value) { const half2 * x2 = (const half2 *) x; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index a7d518a574ddc..eef8585a7380a 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -54,7 +54,7 @@ enum mmvq_parameter_table_id { }; static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { -#if defined(RDNA2) || defined(RDNA3) +#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; @@ -64,7 +64,7 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { } static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { - if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) { + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { return MMVQ_PARAMETERS_RDNA2; } if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { @@ -151,7 +151,7 @@ static __global__ void mul_mat_vec_q( constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; // partial sum for each thread - float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; + float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = (const block_q8_1 *) vy; @@ -197,10 +197,12 @@ static __global__ void mul_mat_vec_q( tmp[j][i] = warp_reduce_sum(tmp[j][i]); } - if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) { dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; } } + + GGML_UNUSED(nrows_x); } static std::pair calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) { diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu index aba539e8dad10..77432b04689be 100644 --- a/ggml/src/ggml-cuda/pad.cu +++ b/ggml/src/ggml-cuda/pad.cu @@ -14,7 +14,7 
@@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons nidx + blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; - if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) { + if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) { int offset_src = nidx + blockIdx.y * ne00 + diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu new file mode 100644 index 0000000000000..f637571963730 --- /dev/null +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -0,0 +1,148 @@ +#include "ssm-conv.cuh" + +template <size_t split_d_inner, size_t d_conv> +static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, + float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, + const int64_t n_t) { + GGML_UNUSED(src0_nb0); + const int tid = threadIdx.x; + const int bidx = blockIdx.x; + const int bidy = blockIdx.y; + + const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1); + const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1); + float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0); + + const int stride_x = src0_nb1 / sizeof(float); + const int stride_w = src1_nb1 / sizeof(float); + const int stride_y = dst_nb1 / sizeof(float); + + float x[d_conv] = { 0.0f }; + float w[d_conv] = { 0.0f }; + +#pragma unroll + for (size_t j = 0; j < d_conv; j++) { + w[j] = w_block[tid * stride_w + j]; + } + + for (int64_t i = 0; i < n_t; i++) { + float sumf = 0.0f; + + if (i == 0) { + for (size_t j = 0; j < d_conv; j++) { + x[j] = x_block[tid * stride_x + j]; + } + } else { + x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1]; + } + +#pragma unroll + for (size_t j = 0; j < d_conv; j++) { + sumf += x[(i + j) % d_conv] * w[j]; + } + y_block[i * stride_y + tid] = sumf; + } +} + +template <size_t split_d_inner, size_t d_conv, int64_t split_n_t> +static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const int src0_nb0, const int src0_nb1, const int src0_nb2, + const int src1_nb1, float * __restrict__ dst, const int dst_nb0, + const int dst_nb1, const int dst_nb2, const int64_t n_t) { + const int tid = threadIdx.x; + const int bidx = blockIdx.x; + const int bidy = blockIdx.y; + const int bidz = blockIdx.z; + + const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 + + bidz * split_n_t * src0_nb0); + const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1); + float * y_block = + (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0); + + const int stride_x = src0_nb1 / sizeof(float); + const int stride_w = src1_nb1 / sizeof(float); + const int stride_y = dst_nb1 / sizeof(float); + + float x[d_conv] = { 0.0f }; + float w[d_conv] = { 0.0f }; + +#pragma unroll + for (size_t j = 0; j < d_conv; j++) { + w[j] = w_block[tid * stride_w + j]; + } + +#pragma unroll + for (int64_t i = 0; i < split_n_t; i++) { + if (bidz * split_n_t + i < n_t) { + float sumf = 0.0f; + + if (i == 0) { + for (size_t j = 0; j < d_conv; j++) { + x[j] = x_block[tid * stride_x + j]; + } + } else { + x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1]; + } + +#pragma unroll + for (size_t j = 0; j < d_conv; j++) { + sumf += x[(i + j) % d_conv] * w[j]; + } + y_block[i * stride_y + tid] = sumf; + } + } +} + +static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, + const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, + const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, + const int64_t n_s, cudaStream_t stream) { + const int threads = 128; + GGML_ASSERT(nr % threads == 0); + + if (n_t <= 32) { + const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); + if (nc == 4) { + ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + dst, dst_nb0, dst_nb1, dst_nb2, n_t); + } else { + GGML_ABORT("Only kernel size 4 is supported for now."); + } + } else { + if (nc == 4) { + const int64_t split_n_t = 32; + dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); + ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>( + src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + } else { + GGML_ABORT("Only kernel size 4 is supported for now."); + } + } +} + +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; // conv_x + const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + + const int64_t nc = src1->ne[0]; // d_conv + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = dst->ne[1]; // tokens per sequence + const int64_t n_s = dst->ne[2]; // number of sequences in the batch + + GGML_ASSERT(dst->ne[0] == nr); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], + dst->nb[2], nc, nr, n_t, n_s, stream); +} diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh new file mode 100644 index 0000000000000..8e6c1f00bfa03 --- /dev/null +++ b/ggml/src/ggml-cuda/ssm-conv.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu new file mode 100644 index 0000000000000..37ee208c09d46 --- /dev/null +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -0,0 +1,153 @@ +#include "ssm-scan.cuh" + +template <size_t splitD, size_t N> +__global__ void __launch_bounds__(splitD, 2) + ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, + const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, + const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, + const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, + float * __restrict__ dst, const int64_t L) { + GGML_UNUSED(src1_nb0); + GGML_UNUSED(src2_nb0); + const int bidx = blockIdx.x; // split along B + const int bidy = blockIdx.y; // split along D + const int tid = threadIdx.x; + const int wid = tid / 32; + const int wtid = tid % 32; + + extern __shared__ float smem[]; + const int stride_sA = N + 1; + const int stride_ss0 = N + 1;
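+    // Note: the N + 1 strides pad each shared-memory row by one float so that consecutive rows start in different banks, the usual guard against shared-memory bank conflicts. +    // The main loop below implements, per channel and per time step, the diagonal state-space (selective-scan) recurrence: +    //   dt' = softplus(dt);  s[n] = s[n]*expf(dt'*A[n]) + B[n]*(x*dt');  y = sum_n s[n]*C[n]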
+ float * smem_A = smem; + float * smem_s0 = smem_A + splitD * stride_sA; + + const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); + const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); + const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); + const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1); + const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2)); + const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2)); + float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); + float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); + + const int stride_s0 = src0_nb1 / sizeof(float); + const int stride_x = src1_nb1 / sizeof(float); + const int stride_dt = src2_nb1 / sizeof(float); + const int stride_A = src3_nb1 / sizeof(float); + const int stride_B = src4_nb1 / sizeof(float); + const int stride_C = src5_nb1 / sizeof(float); + const int stride_s = stride_s0; + const int stride_y = stride_x; + + // TODO: generalize to other state sizes, e.g. N == 32 (only N == 16 is handled below) + if (N == 16) { +#pragma unroll + for (size_t i = 0; i < splitD / 4; i += 2) { + float value = A_block[(wid * warpSize + i) * stride_A + wtid]; + // TODO: residual shared-memory bank conflicts remain here; + // a swizzled layout could remove them. + smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; + } +#pragma unroll + for (size_t i = 0; i < splitD / 4; i += 2) { + float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid]; + smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; + } + } + + __syncthreads(); + + for (int64_t i = 0; i < L; i++) { + float dt_soft_plus = dt_block[i * stride_dt + tid]; + if (dt_soft_plus <= 20.0f) { + dt_soft_plus = log1pf(expf(dt_soft_plus)); + } + float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; + float sumf = 0.0f; +#pragma unroll + for (size_t j = 0; j < N; j++) { + float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + + (B_block[i * stride_B + j] * x_dt); + sumf += state * C_block[i * stride_C + j]; + if (i == L - 1) { + s_block[tid * stride_s + j] = state; + } else { + smem_s0[tid * stride_ss0 + j] = state; + } + } + __syncthreads(); + y_block[i * stride_y + tid] = sumf; + } +} + +static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, + const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, + const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, + const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, + float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B, + cudaStream_t stream) { + const int threads = 128; + // TODO: handle the case where D is not divisible by the number of threads
+ GGML_ASSERT(D % threads == 0); + const dim3 blocks(B, (D + threads - 1) / threads, 1); + const int smem_size = (threads * (N + 1) * 2) * sizeof(float); + if (N == 16) { + ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, + src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L); + } else { + GGML_ABORT("Only N == 16 is supported."); + } +} + +void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; // s + const struct ggml_tensor * src1 = dst->src[1]; // x + const struct ggml_tensor * src2 = dst->src[2]; // dt + const struct ggml_tensor * src3 = dst->src[3]; // A + const struct ggml_tensor * src4 = dst->src[4]; // B + const struct ggml_tensor * src5 = dst->src[5]; // C + + // const int64_t d_state = src0->ne[0]; + // const int64_t d_inner = src0->ne[1]; + // const int64_t l = src1->ne[1]; + // const int64_t b = src0->ne[2]; + + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens per sequence + const int64_t n_s = src0->ne[2]; // number of sequences in the batch + + GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(float)); + GGML_ASSERT(src4->nb[0] == sizeof(float)); + GGML_ASSERT(src5->nb[0] == sizeof(float)); + // required for the dot product between s and C + GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); + // required for per-sequence offsets for states + GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float)); + // required to get correct offset for state destination (i.e.
src1->nb[3]) + GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float)); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + const float * src2_d = (const float *) src2->data; + const float * src3_d = (const float *) src3->data; + const float * src4_d = (const float *) src4->data; + const float * src5_d = (const float *) src5->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0], + src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1], + src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream); +} diff --git a/ggml/src/ggml-cuda/ssm-scan.cuh b/ggml/src/ggml-cuda/ssm-scan.cuh new file mode 100644 index 0000000000000..ee078f5ebb8c0 --- /dev/null +++ b/ggml/src/ggml-cuda/ssm-scan.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index cf513c3ade7c4..524e979574266 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -19,7 +19,7 @@ static __global__ void upscale_f32(const float * x, float * dst, int i02 = i12 / sf2; int i03 = i13 / sf3; - dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); + dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) ); } static void upscale_f32_cuda(const float * x, float * dst, diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index a4c717a321cfb..3983ce5b423c0 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -151,6 +151,10 @@ #define CDNA #endif +#if defined(__GFX12__) +#define RDNA4 +#endif + #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ defined(__gfx1150__) || defined(__gfx1151__) #define RDNA3 diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 1fbcbd0456e99..be2e3fc915551 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -381,6 +381,35 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); return r; } +#elif defined(__riscv) && defined(GGML_RV_ZFH) + + static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + float f; + __asm__( + "fmv.h.x %[f], %[h]\n\t" + "fcvt.s.h %[f], %[f]" + : [f] "=&f" (f) + : [h] "r" (h) + ); + return f; + } + + static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + ggml_fp16_t res; + __asm__( + "fcvt.h.s %[f], %[f]\n\t" + "fmv.x.h %[h], %[f]" + : [h] "=&r" (res) + : [f] "f" (f) + ); + return res; + } + + #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) + #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + #else // FP16 <-> FP32 diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 1e954b4ceabd7..8721b272de6a9 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1,6 +1,70 @@ #ifndef GGML_METAL_IMPL #define GGML_METAL_IMPL +// kernel parameters for mat-vec threadgroups +// +// N_R0: number of src0 rows 
to process per simdgroup +// N_SG: number of simdgroups per threadgroup +// +// TODO: for optimal performance, become function of the device and work size + +#define N_R0_Q4_0 4 +#define N_SG_Q4_0 2 + +#define N_R0_Q4_1 4 +#define N_SG_Q4_1 2 + +#define N_R0_Q5_0 4 +#define N_SG_Q5_0 2 + +#define N_R0_Q5_1 4 +#define N_SG_Q5_1 2 + +#define N_R0_Q8_0 4 +#define N_SG_Q8_0 2 + +#define N_R0_Q2_K 4 +#define N_SG_Q2_K 2 + +#define N_R0_Q3_K 2 +#define N_SG_Q3_K 2 + +#define N_R0_Q4_K 4 +#define N_SG_Q4_K 2 + +#define N_R0_Q5_K 2 +#define N_SG_Q5_K 2 + +#define N_R0_Q6_K 1 +#define N_SG_Q6_K 2 + +#define N_R0_IQ1_S 4 +#define N_SG_IQ1_S 2 + +#define N_R0_IQ1_M 4 +#define N_SG_IQ1_M 2 + +#define N_R0_IQ2_XXS 4 +#define N_SG_IQ2_XXS 2 + +#define N_R0_IQ2_XS 4 +#define N_SG_IQ2_XS 2 + +#define N_R0_IQ2_S 4 +#define N_SG_IQ2_S 2 + +#define N_R0_IQ3_XXS 4 +#define N_SG_IQ3_XXS 2 + +#define N_R0_IQ3_S 4 +#define N_SG_IQ3_S 2 + +#define N_R0_IQ4_NL 2 +#define N_SG_IQ4_NL 2 + +#define N_R0_IQ4_XS 2 +#define N_SG_IQ4_XS 2 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -155,9 +219,12 @@ typedef struct { int32_t ne11; int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_3; - uint64_t nb_12_1; - uint64_t nb_12_2; - uint64_t nb_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; uint64_t nb31; int32_t ne1; int32_t ne2; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index af65e7d9f53d4..456e1fd994c40 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -351,42 +351,56 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, + 
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, @@ -395,6 +409,20 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, @@ -758,313 +786,341 @@ @implementation GGMLMetalClass // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, 
has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, 
mul_mm_id_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, 
flash_attn_ext_f16_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, 
use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, 
has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, 
mul_mm_id_q5_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, 
flash_attn_ext_f16_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, 
flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, 
has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } return ctx; @@ -2561,171 +2617,180 @@ static void ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - id pipeline = nil; + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; + nr1 = 4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; - nrows = 4; } break; case GGML_TYPE_F16: { - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { 
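// note: a condensed sketch of the F16 mat-vec kernel selection (the BF16 case
// below mirrors it); ne00 is the src0 row length, ne11*ne12 the src1 row count:
//
//   ne11*ne12 < 4                               -> *_1ROW kernel (nr1 stays 1)
//   ne00 >= 128 && ne01 >= 8 && ne00 % 4 == 0   -> *_L4 kernel   (nr1 = ne11)
//   otherwise                                   -> generic       (nr1 = 4)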
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; - nrows = ne11; + nr1 = ne11; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; - nrows = 4; + nr1 = 4; } } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; - nrows = 4; + nr1 = 4; } } break; case GGML_TYPE_BF16: { - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; - nrows = ne11; + nr1 = ne11; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; - nrows = 4; + nr1 = 4; } } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; - nrows = 4; + nr1 = 4; } } break; case GGML_TYPE_Q4_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { - nth0 = 4; //1; - nth1 = 8; //32; + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; } break; case GGML_TYPE_IQ2_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; } break; case GGML_TYPE_IQ3_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; } break; case GGML_TYPE_IQ3_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; } break; case 
GGML_TYPE_IQ2_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; } break; case GGML_TYPE_IQ1_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; } break; case GGML_TYPE_IQ1_M: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; } break; case GGML_TYPE_IQ4_NL: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; } break; case GGML_TYPE_IQ4_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; } break; default: @@ -2762,41 +2827,10 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 
256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (ne11 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; case GGML_OP_MUL_MAT_ID: @@ -2822,20 +2856,19 @@ static void ggml_metal_encode_node( // ne21 = n_rows const int dst_rows = ne20*ne21; const int dst_rows_min = n_as; - const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4; + const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4; // max size of the rowids array in the kernel shared buffer - GGML_ASSERT(dst_rows <= dst_rows_max); + //GGML_ASSERT(dst_rows <= dst_rows_max); // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - // !!! - // TODO: for now, always use mat-vec kernels until we figure out how to improve the - // indirect matrix multiplication - // !!! 
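// note: the commented-out GGML_ASSERT above becomes the dst_rows <= dst_rows_max
// term of the condition below, so oversized rowid arrays now fall back to the
// mat-vec path instead of aborting. Worked bound, assuming the typical 32 KiB
// maxThreadgroupMemoryLength of Apple GPUs:
//
//   dst_rows_max = (32768/2 - 8192)/4 = 2048   // rowids at 4 bytes each, with
//                                              // 8192 bytes kept for the mm scratch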
if ([device supportsFamily:MTLGPUFamilyApple7] && ne00 % 32 == 0 && ne00 >= 64 && - dst_rows > dst_rows_min) { + //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments + dst_rows > dst_rows_min && + dst_rows <= dst_rows_max) { + // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) switch (src0->type) { @@ -2902,146 +2935,155 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - id pipeline = nil; + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; } break; case GGML_TYPE_F16: { GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; } break; case GGML_TYPE_BF16: { GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; } break; case GGML_TYPE_Q4_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { - nth0 = 4; //1; - nth1 = 8; //32; + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; } break; case 
GGML_TYPE_IQ2_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; } break; case GGML_TYPE_IQ3_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; } break; case GGML_TYPE_IQ3_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; } break; case GGML_TYPE_IQ2_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; } break; case GGML_TYPE_IQ1_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; } break; case GGML_TYPE_IQ1_M: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; } break; case GGML_TYPE_IQ4_NL: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; } break; case GGML_TYPE_IQ4_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; } break; default: @@ -3052,7 +3094,7 @@ static void ggml_metal_encode_node( }; if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); + GGML_ASSERT(ne00 >= nsg*nr0); } ggml_metal_kargs_mul_mv_id args = { @@ -3085,43 +3127,12 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; - const int tgz = dst_rows; + const int64_t ne123 = dst_rows; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 
256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; case GGML_OP_GET_ROWS: @@ -3776,7 +3787,9 @@ static void ggml_metal_encode_node( GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == src2->type); - GGML_ASSERT(ggml_are_same_shape (src1, src2)); + //GGML_ASSERT(ggml_are_same_shape (src1, src2)); + GGML_ASSERT(ne11 == ne21); + GGML_ASSERT(ne12 == ne22); struct ggml_tensor * src3 = node->src[3]; @@ -3823,125 +3836,161 @@ static void ggml_metal_encode_node( // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // for now avoiding mainly to keep the number of templates/kernels a bit lower - if (ne01 >= 4 || (ne00%128 != 0)) { + // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 + if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) { switch (src1->type) { case GGML_TYPE_F16: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline 
= ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_BF16: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q4_0: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q4_1: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q5_0: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; - case 256: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q5_1: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q8_0: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; - case 96: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; default: @@ -3973,6 +4022,42 @@ static void ggml_metal_encode_node( } } } break; + case 192: + { + if (ne20 == 128) { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } else { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } + } break; case 256: { switch (src1->type) { @@ -4010,9 +4095,12 @@ static void ggml_metal_encode_node( /*.ne11 =*/ ne11, /*.ne_12_2 =*/ ne12, /*.ne_12_3 =*/ ne13, - /*.nb_12_1 =*/ nb11, - /*.nb_12_2 =*/ nb12, - /*.nb_12_3 =*/ nb13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, /*.nb31 =*/ nb31, /*.ne1 =*/ ne1, /*.ne2 =*/ ne2, @@ -4091,10 +4179,9 @@ static void ggml_metal_encode_node( // ne00*(nsg) // each simdgroup has a full f16 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; - while (true) { const size_t smem = FATTN_SMEM(nsgmax); if (smem > device.maxThreadgroupMemoryLength) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3cef81b797197..b08666e27991f 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { - reg = (type4)(*(src + il)); + reg = (type4)(*(src)); } #if defined(GGML_METAL_USE_BF16) @@ -56,6 +56,11 @@ template void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); } + +template +void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src)); +} #endif template @@ -1439,7 +1444,7 @@ kernel void kernel_rwkv_wkv7_f32( float4 sa_vec(0.0); - for (int j = 0; j < head_size; j += 4) { + for (uint j = 0; j < head_size; j += 4) { float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); sa_vec += a_vec * s_vec; @@ -1853,14 +1858,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// guard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. 
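// note: the removed N_DST/N_SIMDGROUP defines above are superseded by per-type
// N_R0_* / N_SG_* constants shared between the host-side dispatch and the kernel
// templates below. Worked example, assuming N_R0_Q4_0 = 4 and N_SG_Q4_0 = 2
// (the values matching the old defaults):
//
//   threadgroup size    = 32 x nsg = 64 threads
//   rows per TG         = nr0*nsg  = 8
//   grid.x (host)       = (ne01 + nr0*nsg - 1)/(nr0*nsg)
//   first_row (kernel)  = (tgpig.x*nsg + sgitg)*nr0   // same partition on both
//                                                     // sides, no per-type cases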
-template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
+template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
 void mul_vec_q_n_f32_impl(
         args_t args,
         device const char * src0,
@@ -1876,7 +1874,7 @@ void mul_vec_q_n_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * nsg + sgitg) * nr;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -1888,15 +1886,15 @@ void mul_vec_q_n_f32_impl(
     device const float * y = (device const float *) (src1 + offset1);
     // pointers to src0 rows
-    device const block_q_type * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q_type * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
         ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
     }
     float yl[16]; // src1 vector cache
-    float sumf[nr] = {0.f};
+    float sumf[nr0] = {0.f};
     const short ix = (tiisg/2);
     const short il = (tiisg%2)*8;
@@ -1908,7 +1906,7 @@ void mul_vec_q_n_f32_impl(
         float sumy[2] = { 0.f, 0.f };
 #pragma unroll
-        for (int i = 0; i < 8; i += 2) {
+        for (short i = 0; i < 8; i += 2) {
             sumy[0] += yb[i + 0] + yb[i + 1];
             yl[i + 0] = yb[i + 0];
             yl[i + 1] = yb[i + 1]/256.f;
@@ -1919,7 +1917,7 @@
         }
 #pragma unroll
-        for (int row = 0; row < nr; row++) {
+        for (short row = 0; row < nr0; row++) {
             sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }
@@ -1928,7 +1926,7 @@
     device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < args.ne01) {
@@ -1945,7 +1943,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 kernel void kernel_mul_mv_q4_1_f32(
@@ -1956,7 +1954,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 kernel void kernel_mul_mv_q5_0_f32(
@@ -1967,7 +1965,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 kernel void kernel_mul_mv_q5_1_f32(
@@ -1978,12 +1976,12 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 #define NB_Q8_0 8
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q8_0_f32_impl(
         args_t args,
         device const char * src0,
@@ -1993,16 +1991,13 @@ void kernel_mul_mv_q8_0_f32_impl(
         uint3 tgpig,
         ushort tiisg,
         ushort sgitg) {
-    const int nr = N_DST;
-    const int nsg = N_SIMDGROUP;
-    const int nw = 
N_SIMDWIDTH; - const int nb = args.ne00/QK8_0; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0*nsg + sgitg)*nr; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -2014,15 +2009,15 @@ void kernel_mul_mv_q8_0_f32_impl( device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr]; - for (int row = 0; row < nr; ++row) { + device const block_q8_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } float yl[NB_Q8_0]; - float sumf[nr] = { 0.f }; + float sumf[nr0] = { 0.f }; const short ix = tiisg/4; const short il = tiisg%4; @@ -2035,7 +2030,7 @@ void kernel_mul_mv_q8_0_f32_impl( yl[i] = yb[i]; } - for (int row = 0; row < nr; row++) { + for (short row = 0; row < nr0; row++) { device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; for (short iq = 0; iq < NB_Q8_0; ++iq) { @@ -2049,7 +2044,7 @@ void kernel_mul_mv_q8_0_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr; ++row) { + for (int row = 0; row < nr0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { @@ -2067,7 +2062,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 @@ -2404,9 +2399,9 @@ void kernel_mul_mv_impl( sumf += (T0) x[i] * (T1) y[i]; } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } else { @@ -2427,10 +2422,10 @@ void kernel_mul_mv_impl( sumf += dot((float4) x4[i], (float4) y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } @@ -2492,9 +2487,9 @@ kernel void kernel_mul_mv_1row( for (int i = tiisg; i < args.ne00; i += 32) { sumf += (float) x[i] * (float) y[i]; } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[r0] = all_sum; + dst_f32[r0] = sum_all; } } else { device const T4 * x4 = (device const T4 *) x; @@ -2504,11 +2499,11 @@ kernel void kernel_mul_mv_1row( sumf += dot((float4) x4[i], y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[r0] = sum_all; } } } @@ -2553,9 +2548,9 @@ kernel void kernel_mul_mv_l4( sumf += dot((float4) x4[i], y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - 
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } @@ -3110,7 +3105,8 @@ template< typename vd4x4_t, // key type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short D, // head size + short DK, // K head size + short DV, // V head size short Q = 8, // queries per threadgroup short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup @@ -3132,20 +3128,24 @@ kernel void kernel_flash_attn_ext( const int iq2 = tgpig[1]; const int iq1 = tgpig[0]*Q; - const short D4 = D/4; - const short D8 = D/8; - const short D16 = D/16; - const short NW = N_SIMDWIDTH; - const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + constexpr short DK4 = DK/4; + constexpr short DK8 = DK/8; + constexpr short DK16 = DK/16; + constexpr short DV4 = DV/4; + constexpr short DV8 = DV/8; + constexpr short DV16 = DV/16; + + constexpr short NW = N_SIMDWIDTH; + constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = D + 2*TS; // shared memory size per query in (half) + const short T = DK + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t @@ -3154,23 +3154,23 @@ kernel void kernel_flash_attn_ext( threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o8x8_t lo[D8]; + o8x8_t lo[DV8]; // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); - for (short i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < DK4; i += NW) { if (iq1 + j < args.ne01) { - sq4[j*D4 + i] = (q4_t) q4[i]; + sq4[j*DK4 + i] = (q4_t) q4[i]; } else { - sq4[j*D4 + i] = (q4_t) 0.0f; + sq4[j*DK4 + i] = (q4_t) 0.0f; } } } // zero out lo - for (short i = 0; i < D8; ++i) { + for (short i = 0; i < DV8; ++i) { lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); } @@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext( 
threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { [0 ... Q-1] = 0.0f }; - half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; + float S[Q] = { [0 ... Q-1] = 0.0f }; + float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; // thread indices inside the simdgroup // TODO: see if we can utilize quad-group functions for better performance @@ -3200,22 +3200,15 @@ kernel void kernel_flash_attn_ext( const short ikv2 = iq2/(args.ne02/args.ne_12_2); const short ikv3 = iq3/(args.ne03/args.ne_12_3); - // load the queries from shared memory into local memory - q8x8_t mq[D8]; - - for (short i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, D); - } - const bool has_mask = mask != q; - half slope = 1.0f; + float slope = 1.0f; // ALiBi if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); @@ -3231,14 +3224,14 @@ kernel void kernel_flash_attn_ext( if (has_mask) { // used to detect blocks full of -INF - half smax = -INFINITY; + float smax = -INFINITY; // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); - const half m = pm[ic + tiisg]; + const float m = pm[ic + tiisg]; ss[j*TS + C + tiisg] = m; smax = max(smax, m); @@ -3259,20 +3252,22 @@ kernel void kernel_flash_attn_ext( // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); - #pragma unroll(D8) - for (short i = 0; i < D8; ++i) { + #pragma unroll(DK8) + for (short i = 0; i < DK8; ++i) { k8x8_t mk; - simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + q8x8_t mq; + simdgroup_load(mq, sq + i*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } else { - for (short ii = 0; ii < D16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); - if (D16%4 == 0) { + if (DK16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks { k4x4_t tmp; @@ -3285,15 +3280,18 @@ kernel void kernel_flash_attn_ext( #pragma unroll(4) for (short k = 0; k < 4; ++k) { k8x8_t mk; + q8x8_t mq; simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, 
mqk); } } else { - if (ii + tx < D16) { + if (ii + tx < DK16) { k4x4_t tmp; deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); sk4x4[4*ty + tx] = tmp; @@ -3301,14 +3299,17 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - for (short k = 0; k < 4 && ii + k < D16; ++k) { + for (short k = 0; k < 4 && ii + k < DK16; ++k) { k8x8_t mk; + q8x8_t mq; simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } } @@ -3326,10 +3327,10 @@ kernel void kernel_flash_attn_ext( // online softmax { for (ushort j = 0; j < Q; ++j) { - const half m = M[j]; + const float m = M[j]; // scale and apply the logitcap / mask - half s = ss[j*TS + tiisg]*args.scale; + float s = ss[j*TS + tiisg]*args.scale; if (args.logit_softcap != 0.0f) { s = args.logit_softcap*precise::tanh(s); @@ -3340,8 +3341,8 @@ kernel void kernel_flash_attn_ext( M[j] = simd_max(max(M[j], s)); - const half ms = exp(m - M[j]); - const half vs = exp(s - M[j]); + const float ms = exp(m - M[j]); + const float vs = exp(s - M[j]); S[j] = S[j]*ms + simd_sum(vs); @@ -3360,8 +3361,8 @@ kernel void kernel_flash_attn_ext( s8x8_t mm; simdgroup_load(mm, ss + 2*C, TS, 0, false); - #pragma unroll(D8) - for (short i = 0; i < D8; ++i) { + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { simdgroup_multiply(lo[i], mm, lo[i]); } } @@ -3374,20 +3375,20 @@ kernel void kernel_flash_attn_ext( if (is_same::value) { // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); - #pragma unroll(D8) - for (short i = 0; i < D8; ++i) { + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { v8x8_t mv; - simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 + simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); } } else { - for (short ii = 0; ii < D16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + for (short ii = 0; ii < DV16; ii += 4) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); - if (D16%4 == 0) { + if (DV16%4 == 0) { // no need for bound checks { v4x4_t tmp; @@ -3408,7 +3409,7 @@ kernel void kernel_flash_attn_ext( simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); } } else { - if (ii + tx < D16) { + if (ii + tx < DV16) { v4x4_t tmp; deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); sv4x4[4*ty + tx] = tmp; @@ -3416,7 +3417,7 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - for (short k = 0; k < 4 && ii + k < D16; ++k) { + for (short k = 0; k < 4 && ii + k < DV16; ++k) { v8x8_t mv; simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); @@ -3443,15 +3444,15 @@ kernel void 
kernel_flash_attn_ext( // reduce the warps sequentially for (ushort sg = 1; sg < nsg; ++sg) { - half S = { 0.0f }; - half M = { -__FLT16_MAX__/2 }; + float S = { 0.0f }; + float M = { -__FLT16_MAX__/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], so + i*8, D, 0, false); + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); } } @@ -3460,16 +3461,16 @@ kernel void kernel_flash_attn_ext( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { for (short j = 0; j < Q; ++j) { - const half S0 = ss[j*TS + 0]; - const half S1 = ss[j*TS + sg*SH + 0]; + const float S0 = ss[j*TS + 0]; + const float S1 = ss[j*TS + sg*SH + 0]; - const half M0 = ss[j*TS + 1]; - const half M1 = ss[j*TS + sg*SH + 1]; + const float M0 = ss[j*TS + 1]; + const float M1 = ss[j*TS + sg*SH + 1]; M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); S = S0*ms0 + S1*ms1; @@ -3490,11 +3491,11 @@ kernel void kernel_flash_attn_ext( simdgroup_load(ms0, ss + 2*C, TS, 0, false); simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); - #pragma unroll(D8) - for (short i = 0; i < D8; ++i) { + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { o8x8_t t; - simdgroup_load (t, so + i*8, D, 0, false); + simdgroup_load (t, so + i*8, DV, 0, false); simdgroup_multiply(t, ms1, t); simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); @@ -3505,8 +3506,8 @@ kernel void kernel_flash_attn_ext( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], so + i*8, D, 0, false); + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); } } @@ -3517,8 +3518,8 @@ kernel void kernel_flash_attn_ext( for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { const float S = ss[j*TS + 0]; - for (short i = tiisg; i < D4; i += NW) { - dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S; } } } @@ -3535,80 +3536,94 @@ kernel void kernel_flash_attn_ext( float, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 -typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; +typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; 
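The running softmax state (S, M), the ALiBi slope, and the loaded mask values all move from half to float in the hunks above, presumably to avoid fp16 range and rounding problems on long sequences. The slope branch itself is unchanged; for reference, here it is as host-side C++, with the m0/m1/n_head_log2 setup taken from ggml's CPU convention as an assumption rather than from this diff:

```cpp
#include <cmath>

// ALiBi slope for head h, mirroring the branch in kernel_flash_attn_ext.
// Assumed host-side setup (ggml convention, not shown in this diff):
//   n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
//   m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
//   m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
float alibi_slope(int h, int n_head_log2, float m0, float m1) {
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, (float) exph);
}
```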
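The same precision change touches the online-softmax bookkeeping, which is easiest to sanity-check in scalar form. A minimal model of the two steps used above, the per-block update and the cross-simdgroup merge, with plain loops standing in for simd_max/simd_sum:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

struct FaState {
    float M = -INFINITY; // running maximum of the attention scores
    float S = 0.0f;      // running sum of exp(score - M)
};

// Per-block update: new scores may raise M, and exp(M_old - M) rescales the
// old sum (in the kernel the same factor ms also rescales the O tiles lo[]).
void fa_update(FaState & st, const std::vector<float> & scores) {
    float M = st.M;
    for (float s : scores) M = std::max(M, s);

    float vs = 0.0f;
    for (float s : scores) vs += std::exp(s - M);

    st.S = st.S*std::exp(st.M - M) + vs;
    st.M = M;
}

// Cross-simdgroup merge, as in the sequential warp reduction above; the kernel
// additionally applies O = diag(ms0)*O_0 + diag(ms1)*O_1 to the 8x8 tiles.
FaState fa_merge(FaState a, FaState b) {
    const float M   = std::max(a.M, b.M);
    const float ms0 = std::exp(a.M - M);
    const float ms1 = std::exp(b.M - M);
    return { M, a.S*ms0 + b.S*ms1 };
}
```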
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif -template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; 
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template 
[[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES template< - typename q4_t, // query types in shared memory - typename q4x4_t, - typename k4x4_t, // key types in shared memory - typename v4x4_t, // value types in shared memory - typename qk_t, // Q*K types - typename s_t, // soft-max types + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types typename s4_t, - typename s4x4_t, - typename o4x4_t, // attention accumulation types - typename kd4x4_t, // key type in device memory + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory short nl_k, - void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), - typename vd4x4_t, // key type in device memory + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // key type in device memory short nl_v, - void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short D, // head size - short Q = 1, // queries per threadgroup - short C = 32> // cache 
items per threadgroup + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, @@ -3627,29 +3642,28 @@ kernel void kernel_flash_attn_ext_vec( const int iq2 = tgpig[1]; const int iq1 = tgpig[0]; - const short D4 = D/4; - const short D16 = D/16; - const short NW = N_SIMDWIDTH; - const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0 - const short SH = 2*C; // shared memory per simdgroup + constexpr short DK4 = DK/4; + constexpr short DV4 = DV/4; + constexpr short NW = N_SIMDWIDTH; + constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads + constexpr short SH = 4*C; // shared memory per simdgroup - const short T = D + nsg*SH; // shared memory size per query in (half) + const short T = DK + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t - threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask - threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o4x4_t lo[D16/NL]; + // store the result for all queries in local memory (the O matrix from the paper) + o4_t lo[DV4/NL]; // load heads from Q to shared memory device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); - for (short i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < DK4; i += NW) { if (iq1 < args.ne01) { sq4[i] = (q4_t) q4[i]; } else { @@ -3658,8 +3672,8 @@ kernel void kernel_flash_attn_ext_vec( } // zero out lo - for (short i = 0; i < D16/NL; ++i) { - lo[i] = (o4x4_t) 0.0f; + for (short i = 0; i < DV4/NL; ++i) { + lo[i] = (o4_t) 0.0f; } // zero out shared memory SH @@ -3670,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S = 0.0f; - half M = -__FLT16_MAX__/2; + float S = 0.0f; + float M = -__FLT16_MAX__/2; // thread indices inside the simdgroup const short tx = tiisg%NL; @@ -3684,26 
+3698,18 @@ kernel void kernel_flash_attn_ext_vec( const short ikv2 = iq2/(args.ne02/args.ne_12_2); const short ikv3 = iq3/(args.ne03/args.ne_12_3); - // load the queries from shared memory into local memory - q4x4_t mq[D16/NL]; - - #pragma unroll(D16/NL) - for (short ii = 0; ii < D16; ii += NL) { - mq[ii/NL] = sq4x4[ii + tx]; - } - const bool has_mask = mask != q; // pointer to the mask device const half * pm = (device const half *) (mask + iq1*args.nb31); - half slope = 1.0f; + float slope = 1.0f; // ALiBi if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); @@ -3723,43 +3729,56 @@ kernel void kernel_flash_attn_ext_vec( // Q*K^T { - // each simdgroup processes 1 query and 4 (NW/NL) keys - for (short cc = 0; cc < C/4; ++cc) { - qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; + // each simdgroup processes 1 query and NE (NW/NL) head elements + for (short cc = 0; cc < C/NE; ++cc) { + qk_t mqk = 0.0f; - device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); - #pragma unroll(D16/NL) - for (short ii = 0; ii < D16; ii += NL) { + #pragma unroll(DK4/NL) + for (short ii = 0; ii < DK4; ii += NL) { const short i = ii + tx; - k4x4_t mk; - deq_k(pk + i/nl_k, i%nl_k, mk); + k4_t mk; + deq_k_t4(pk + i/nl_k, i%nl_k, mk); // note: this is less precise than the version below - //mqka[0] += dot(mq[ii/NL][0], mk[0]); - //mqka[1] += dot(mq[ii/NL][1], mk[1]); - //mqka[2] += dot(mq[ii/NL][2], mk[2]); - //mqka[3] += dot(mq[ii/NL][3], mk[3]); - - mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]); - mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]); - mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]); - mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]); + //mqka[0] += dot(mq[0], mk[0]); + //mqka[1] += dot(mq[1], mk[1]); + //mqka[2] += dot(mq[2], mk[2]); + //mqka[3] += dot(mq[3], mk[3]); + + //q4x4_t mq = sq4x4[i]; + //mqka[0] += dot((float4) mq[0], (float4) mk[0]); + //mqka[1] += dot((float4) mq[1], (float4) mk[1]); + //mqka[2] += dot((float4) mq[2], (float4) mk[2]); + //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + + mqk += dot((float4) mk, (float4) sq4[i]); } - qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3]; + static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails - // simdgroup reduce + // simdgroup reduce (NE = 4) // [ 0 .. 7] -> [ 0] // [ 8 .. 15] -> [ 8] // [16 .. 23] -> [16] // [24 .. 
31] -> [24] - //mqk += simd_shuffle_down(mqk, 16); - //mqk += simd_shuffle_down(mqk, 8); - mqk += simd_shuffle_down(mqk, 4); - mqk += simd_shuffle_down(mqk, 2); - mqk += simd_shuffle_down(mqk, 1); + if (NE <= 1) { + mqk += simd_shuffle_down(mqk, 16); + } + if (NE <= 2) { + mqk += simd_shuffle_down(mqk, 8); + } + if (NE <= 4) { + mqk += simd_shuffle_down(mqk, 4); + } + if (NE <= 8) { + mqk += simd_shuffle_down(mqk, 2); + } + if (NE <= 16) { + mqk += simd_shuffle_down(mqk, 1); + } // mqk = mqk*scale + mask*slope if (tx == 0) { @@ -3769,9 +3788,9 @@ kernel void kernel_flash_attn_ext_vec( mqk = args.logit_softcap*precise::tanh(mqk); } - mqk += sm[4*cc + ty]*slope; + mqk += sm[NE*cc + ty]*slope; - ss[4*cc + ty] = mqk; + ss[NE*cc + ty] = mqk; } } } @@ -3780,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec( // online softmax { - const half m = M; - const half s = ss[tiisg]; + const float m = M; + const float s = ss[tiisg]; M = simd_max(max(M, s)); - const half ms = exp(m - M); - const half vs = exp(s - M); + const float ms = exp(m - M); + const float vs = exp(s - M); S = S*ms + simd_sum(vs); @@ -3794,8 +3813,8 @@ kernel void kernel_flash_attn_ext_vec( ss[tiisg] = vs; // O = diag(ms)*O - #pragma unroll(D16/NL) - for (short ii = 0; ii < D16; ii += NL) { + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { lo[ii/NL] *= ms; } } @@ -3804,19 +3823,20 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { - for (short cc = 0; cc < C/4; ++cc) { - device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + //#pragma unroll(C/NE) + for (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); - const s4x4_t ms(ss[4*cc + ty]); + const s4_t ms(ss[NE*cc + ty]); - #pragma unroll(D16/NL) - for (short ii = 0; ii < D16; ii += NL) { + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { const short i = ii + tx; - v4x4_t mv; - deq_v(pv4 + i/nl_v, i%nl_v, mv); + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); - lo[ii/NL] += mv*ms; + lo[ii/NL] += o4_t(float4(mv)*float4(ms)); } } } @@ -3829,7 +3849,7 @@ kernel void kernel_flash_attn_ext_vec( } } - // simdgroup reduce + // simdgroup reduce (NE = 4) // [ 0, 8, 16, 24] -> [ 0] // [ 1, 9, 17, 25] -> [ 1] // [ 2, 10, 18, 26] -> [ 2] @@ -3838,37 +3858,48 @@ kernel void kernel_flash_attn_ext_vec( // [ 5, 13, 21, 29] -> [ 5] // [ 6, 14, 22, 30] -> [ 6] // [ 7, 15, 23, 31] -> [ 7] - for (short ii = 0; ii < D16; ii += NL) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); - //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); - //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); - //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); - - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); - //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); - //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); - //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); - - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); - //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); - //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); - //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); - - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); - lo[ii/NL][3] += 
simd_shuffle_down(lo[ii/NL][3], 8); - //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); - //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); - //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); + for (short ii = 0; ii < DV4; ii += NL) { + if (NE > 1) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); + } + + if (NE > 2) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); + } + + if (NE > 4) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); + } + + if (NE > 8) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); + } + + if (NE > 16) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); + } } threadgroup_barrier(mem_flags::mem_threadgroup); // store results to shared memory - for (short i = tiisg; i < D16; i += NL) { - sr4x4[i] = lo[i/NL]; + for (short i = tiisg; i < DV4; i += NL) { + sr4[i] = lo[i/NL]; } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3876,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec( // parallel reduce for (short r = nsg/2; r > 0; r >>= 1) { if (sgitg < r) { - const half S0 = ss[ 0]; - const half S1 = ss[r*SH + 0]; + const float S0 = ss[ 0]; + const float S1 = ss[r*(SH/2) + 0]; - const half M0 = ss[ 1]; - const half M1 = ss[r*SH + 1]; + const float M0 = ss[ 1]; + const float M1 = ss[r*(SH/2) + 1]; - const half M = max(M0, M1); + const float M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); - const half S = S0*ms0 + S1*ms1; + const float S = S0*ms0 + S1*ms1; if (tiisg == 0) { ss[0] = S; @@ -3895,22 +3926,22 @@ kernel void kernel_flash_attn_ext_vec( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short i = tiisg; i < D16; i += NW) { - sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1; + for (short i = tiisg; i < DV4; i += NW) { + sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; } } threadgroup_barrier(mem_flags::mem_threadgroup); } - device float4x4 * dst44 = (device float4x4 *) dst; + device float4 * dst4 = (device float4 *) dst; // final rescale with 1/S and store to global memory if (sgitg == 0) { const float S = ss[0]; - for (short i = tiisg; i < D16; i += NW) { - dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S; } } } @@ -3919,34 +3950,54 @@ kernel void kernel_flash_attn_ext_vec( // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max // #define FA_TYPES \ - half4, half4x4, \ - half4x4, \ - half4x4, \ - float, \ - half, half4, half4x4, \ - half4x4 + half4, \ + half4, \ + half4, \ + float, \ + float, 
float4, \ + half4 + +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template 
[[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES @@ -4321,7 +4372,7 @@ kernel void kernel_cpy_f32_iq4_nl( float amax = 0.0f; // absolute max float max = 0.0f; - for (int j = 0; j < QK4_0; j++) { + for (int j = 0; j < QK4_NL; j++) { const float v = src[j]; if (amax < fabs(v)) { amax = fabs(v); @@ -4429,7 +4480,7 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( args_t args, device const char * src0, @@ -4445,7 +4496,7 @@ void kernel_mul_mv_q2_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4457,20 +4508,19 @@ void kernel_mul_mv_q2_K_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 + const short is = (8*ir)/16;// 0 or 1 device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; for (int ib = ix; ib < nb; ib += 4) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { + for (short i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; @@ -4481,7 +4531,7 @@ void kernel_mul_mv_q2_K_f32_impl( device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; - for (int 
row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; i += 2) { @@ -4512,10 +4562,10 @@ void kernel_mul_mv_q2_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -4530,10 +4580,10 @@ kernel void kernel_mul_mv_q2_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q3_K_f32_impl( args_t args, device const char * src0, @@ -4550,7 +4600,7 @@ void kernel_mul_mv_q3_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4566,13 +4616,12 @@ void kernel_mul_mv_q3_K_f32_impl( //const uint16_t kmask1 = 0x3030; //const uint16_t kmask2 = 0x0f0f; - const int tid = tiisg/4; - const int ix = tiisg%4; - const int ip = tid/4; // 0 or 1 - const int il = 2*((tid%4)/2); // 0 or 2 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; + const short tid = tiisg/4; + const short ix = tiisg%4; + const short ip = tid/4; // 0 or 1 + const short il = 2*((tid%4)/2); // 0 or 2 + const short ir = tid%2; + const short l0 = 8*ir; // One would think that the Metal compiler would figure out that ip and il can only have // 4 possible states, and optimize accordingly. Well, no. 
It needs help, and we do it @@ -4597,8 +4646,8 @@ void kernel_mul_mv_q3_K_f32_impl( const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; + const short q_offset = 32*ip + l0; + const short y_offset = 128*ip + 32*il + l0; device const float * y1 = yy + ix*QK_K + y_offset; @@ -4606,10 +4655,11 @@ void kernel_mul_mv_q3_K_f32_impl( thread uint16_t * scales16 = (thread uint16_t *)&scales32; thread const int8_t * scales = (thread const int8_t *)&scales32; - float sumf1[2] = {0.f}; - float sumf2[2] = {0.f}; + float sumf1[nr0] = {0.f}; + float sumf2[nr0] = {0.f}; + for (int i = ix; i < nb; i += 4) { - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+ 0] = y1[l+ 0]; yl[l+ 8] = y1[l+16]; yl[l+16] = y1[l+32]; @@ -4621,7 +4671,7 @@ void kernel_mul_mv_q3_K_f32_impl( device const uint16_t * a = (device const uint16_t *)(x[i].scales); device const half * dh = &x[i].d; - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { const float d_all = (float)dh[0]; scales16[0] = a[4]; @@ -4632,7 +4682,7 @@ void kernel_mul_mv_q3_K_f32_impl( scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2]; s1 += yl[l+0] * (qs & qm[il/2][0]); s2 += yl[l+1] * (qs & qm[il/2][1]); @@ -4647,7 +4697,7 @@ void kernel_mul_mv_q3_K_f32_impl( sumf2[row] += d2 * (scales[2] - 32); s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2+8]; s1 += yl[l+8] * (qs & qm[il/2][0]); s2 += yl[l+9] * (qs & qm[il/2][1]); @@ -4670,7 +4720,7 @@ void kernel_mul_mv_q3_K_f32_impl( y1 += 4 * QK_K; } - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < nr0; ++row) { const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); sumf1[row] = simd_sum(sumf); } @@ -4678,7 +4728,7 @@ void kernel_mul_mv_q3_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { dst_f32[first_row + row] = sumf1[row]; } } @@ -4694,10 +4744,10 @@ kernel void kernel_mul_mv_q3_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( args_t args, device const char * src0, @@ -4707,22 +4757,22 @@ void kernel_mul_mv_q4_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int first_row = r0 * N_DST; + + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const 
uint i13 = im/args.ne12; @@ -4735,7 +4785,8 @@ void kernel_mul_mv_q4_K_f32_impl( float yl[16]; float yh[16]; - float sumf[N_DST]={0.f}, all_sum; + + float sumf[nr0]={0.f}; device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; @@ -4744,7 +4795,8 @@ void kernel_mul_mv_q4_K_f32_impl( for (int ib = ix; ib < nb; ib += 4) { float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { + + for (short i = 0; i < 8; ++i) { yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; @@ -4755,7 +4807,7 @@ void kernel_mul_mv_q4_K_f32_impl( device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { sc16[0] = sc[0] & kmask1; sc16[1] = sc[2] & kmask1; sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); @@ -4765,19 +4817,21 @@ void kernel_mul_mv_q4_K_f32_impl( float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + + for (short i = 0; i < 4; ++i) { + acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); + acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); + acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0); + acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000); + acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F); + acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00); + acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0); + acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000); } float dall = dh[0]; float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + @@ -4794,10 +4848,10 @@ void kernel_mul_mv_q4_K_f32_impl( device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -4812,10 +4866,10 @@ kernel void kernel_mul_mv_q4_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( args_t args, device const char * src0, @@ -4832,7 +4886,7 @@ void kernel_mul_mv_q5_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4843,7 +4897,7 @@ void kernel_mul_mv_q5_K_f32_impl( device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); - float sumf[2]={0.f}; + float sumf[nr0]={0.f}; float yl[16], yh[16]; @@ -4851,15 +4905,14 @@ void kernel_mul_mv_q5_K_f32_impl( const uint16_t kmask2 = 0x0f0f; 
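From kernel_mul_mv_q2_K_f32_impl onward, the hard-coded N_SIMDGROUP/N_DST pair (or the literal 2 rows) becomes nsg/nr0 template parameters, so each quant type can pick its own tile shape at instantiation time. The resulting row tiling, as a compact model of the first_row expression used throughout these hunks:

```cpp
// Each threadgroup column r0 holds nsg simdgroups; each simdgroup accumulates
// nr0 consecutive rows of src0, so one threadgroup covers nsg*nr0 rows and the
// grid needs ceil(ne0 / (nsg*nr0)) columns. The range check mirrors the
// "first_row + row < args.ne0" guards in the write-back loops.
int mv_first_row(int r0, int nsg, int sgitg, int nr0) {
    return (r0*nsg + sgitg)*nr0;
}

int mv_num_tg_x(int ne0, int nsg, int nr0) {
    return (ne0 + nsg*nr0 - 1) / (nsg*nr0);
}
```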
const uint16_t kmask3 = 0xc0c0; - const int tid = tiisg/4; - const int ix = tiisg%4; - const int iq = tid/4; - const int ir = tid%4; - const int n = 8; + const short tid = tiisg/4; + const short ix = tiisg%4; + const short iq = tid/4; + const short ir = tid%4; - const int l0 = n*ir; - const int q_offset = 32*iq + l0; - const int y_offset = 64*iq + l0; + const short l0 = 8*ir; + const short q_offset = 32*iq + l0; + const short y_offset = 64*iq + l0; const uint8_t hm1 = 1u << (2*iq); const uint8_t hm2 = hm1 << 1; @@ -4879,14 +4932,14 @@ void kernel_mul_mv_q5_K_f32_impl( device const float * y2 = y1 + 128; float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; } - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { device const uint8_t * q2 = q1 + 64; sc16[0] = a[0] & kmask1; @@ -4896,7 +4949,7 @@ void kernel_mul_mv_q5_K_f32_impl( float4 acc1 = {0.f}; float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { + for (short l = 0; l < 8; ++l) { uint8_t h = qh[l]; acc1[0] += yl[l+0] * (q1[l] & 0x0F); acc1[1] += yl[l+8] * (q1[l] & 0xF0); @@ -4926,7 +4979,7 @@ void kernel_mul_mv_q5_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = tot; @@ -4944,10 +4997,10 @@ kernel void kernel_mul_mv_q5_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( args_t args, device const char * src0, @@ -4969,62 +5022,77 @@ void kernel_mul_mv_q6_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int row = 2*r0 + sgitg; - - if (row >= args.ne0) { - return; - } + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); - float sumf = 0; + float sumf[nr0] = { 0.f }; - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; + float yl[16]; + + const short tid = tiisg/2; + const short ix = tiisg%2; + const short ip = tid/8; // 0 or 1 + const short il = tid%8; + const short l0 = 4*il; + const short is = 8*ip + l0/16; - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; + const short y_offset = 128*ip + l0; + const short q_offset_l = 64*ip + l0; + const short q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { 
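q6_K is the biggest beneficiary of this refactor: instead of one row per simdgroup with an early return, it now accumulates nr0 rows by stepping raw row pointers, as in the hunk below. The nb01/2 step for dh falls out of the pointer arithmetic; a sketch (next_row is our illustrative name, not code from the patch):

```cpp
#include <cstdint>

// Advance a typed pointer by one src0 row of nb01 bytes: the uint8_t quant
// pointers (ql, qh, scales) move by nb01 elements directly, while a 2-byte
// half pointer covers the same distance in nb01/2 elements, which is why the
// hunk below uses "dh += args.nb01/2".
template <typename T>
const T * next_row(const T * p, uint64_t nb01) {
    return reinterpret_cast<const T *>(reinterpret_cast<const uint8_t *>(p) + nb01);
}
```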
device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; device const int8_t * sc = x[i].scales + is; + device const half * dh = &x[i].d; device const float * y = yy + i * QK_K + y_offset; - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + for (short l = 0; l < 4; ++l) { + yl[4*l + 0] = y[l + 0]; + yl[4*l + 1] = y[l + 32]; + yl[4*l + 2] = y[l + 64]; + yl[4*l + 3] = y[l + 96]; } - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + for (short row = 0; row < nr0; ++row) { + const float dall = dh[0]; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (short l = 0; l < 4; ++l) { + sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + q1 += args.nb01; + q2 += args.nb01; + qh += args.nb01; + sc += args.nb01; + dh += args.nb01/2; + } } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[row] = tot; + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } } } @@ -5038,12 +5106,12 @@ kernel void kernel_mul_mv_q6_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( args_t args, device const char * src0, @@ -5059,7 +5127,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5071,7 +5139,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5092,8 +5160,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5104,18 +5171,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const uint16_t * q2 = xr->qs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; device const uint8_t * aux8 = (device const uint8_t *)q2; const uint32_t aux32 = q2[2] | (q2[3] << 16); const float d = db * 
(0.5f + (aux32 >> 28)); float sum = 0; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } @@ -5130,10 +5196,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5148,10 +5214,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( args_t args, device const char * src0, @@ -5167,7 +5233,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5179,7 +5245,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5200,8 +5266,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5213,8 +5278,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const uint8_t * sc = xr->scales + ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint8_t ls1 = sc[0] & 0xf; const uint8_t ls2 = sc[0] >> 4; @@ -5222,17 +5286,17 @@ void kernel_mul_mv_iq2_xs_f32_impl( const float d2 = db * (0.5f + ls2); float sum1 = 0, sum2 = 0; - for (int l = 0; l < 2; ++l) { + for (short l = 0; l < 2; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } - for (int l = 2; l < 4; ++l) { + for (short l = 2; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); } } @@ -5248,10 +5312,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5267,10 +5331,10 @@ kernel void kernel_mul_mv_iq2_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( args_t args, device const char * src0, @@ -5286,7 +5350,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5298,7 +5362,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5319,7 +5383,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5331,17 +5395,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
-1.f : 1.f); } @@ -5358,10 +5422,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.5f; + dst_f32[first_row + row] = sum_all * 0.5f; } } } @@ -5377,10 +5441,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( args_t args, device const char * src0, @@ -5396,7 +5460,7 @@ void kernel_mul_mv_iq3_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5408,7 +5472,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5425,8 +5489,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5440,18 +5503,17 @@ void kernel_mul_mv_iq3_s_f32_impl( device const uint8_t * signs = xr->signs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? 
svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); } @@ -5470,10 +5532,10 @@ void kernel_mul_mv_iq3_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -5489,10 +5551,10 @@ kernel void kernel_mul_mv_iq3_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( args_t args, device const char * src0, @@ -5508,7 +5570,7 @@ void kernel_mul_mv_iq2_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5520,7 +5582,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5532,13 +5594,12 @@ void kernel_mul_mv_iq2_s_f32_impl( // threadgroup_barrier(mem_flags::mem_threadgroup); //} - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5552,19 +5613,18 @@ void kernel_mul_mv_iq2_s_f32_impl( device const uint8_t * signs = qs + QK_K/8; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d1 = db * (0.5f + (sc[0] & 0xf)); const float d2 = db * (0.5f + (sc[0] >> 4)); float2 sum = {0}; - for (int l = 0; l < 2; ++l) { + for (short l = 0; l < 2; ++l) { //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); } @@ -5583,10 +5643,10 @@ void kernel_mul_mv_iq2_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for 
(int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5602,10 +5662,10 @@ kernel void kernel_mul_mv_iq2_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( args_t args, device const char * src0, @@ -5621,7 +5681,7 @@ void kernel_mul_mv_iq1_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5633,18 +5693,17 @@ void kernel_mul_mv_iq1_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - float sumy = 0; - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; sumy += yl[i]; } @@ -5657,15 +5716,14 @@ void kernel_mul_mv_iq1_s_f32_impl( device const uint16_t * qh = xr->qh + ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); float sum = 0; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) @@ -5683,15 +5741,28 @@ void kernel_mul_mv_iq1_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq1_m_f32_impl( args_t args, device const char * src0, @@ -5703,11 +5774,12 @@ void kernel_mul_mv_iq1_m_f32_impl( ushort sgitg) { const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 
* nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5719,20 +5791,19 @@ void kernel_mul_mv_iq1_m_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; iq1m_scale_t scale; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - float4 sumy = {0.f}; - for (int i = 0; i < 8; ++i) { + for (short i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; @@ -5747,7 +5818,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device const uint8_t * qh = xr->qh + 2 * ib; device const uint16_t * sc = (device const uint16_t *)xr->scales; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); @@ -5756,7 +5827,7 @@ void kernel_mul_mv_iq1_m_f32_impl( constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); float2 sum = {0.f}; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) @@ -5778,15 +5849,28 @@ void kernel_mul_mv_iq1_m_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -5799,10 +5883,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; + + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5813,14 +5899,14 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0 or 1 + const short ix = tiisg/2; // 0...15 + const short it = tiisg%2; // 0 or 1 shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[2]={0.f}, all_sum; + float sumf[nr0]={0.f}; device const float * yb = y + ix * QK4_NL + it * 8; @@ -5830,12 
+5916,13 @@ void kernel_mul_mv_iq4_nl_f32_impl( float4 qf1, qf2; for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; + for (short row = 0; row < nr0; row++) { device const block_iq4_nl & xb = x[row*nb + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); @@ -5860,7 +5947,6 @@ void kernel_mul_mv_iq4_nl_f32_impl( acc1 += acc2; sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - } yb += 16 * QK4_NL; @@ -5868,15 +5954,29 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -5892,7 +5992,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5903,16 +6003,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); - const int ix = tiisg/16; // 0 or 1 - const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; + const short ix = tiisg/16; // 0 or 1 + const short it = tiisg%16; // 0...15 + const short ib = it/2; + const short il = it%2; shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[2]={0.f}, all_sum; + float sumf[nr0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -5923,9 +6023,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( for (int ibl = ix; ibl < nb; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); @@ -5949,7 +6052,6 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - } yb += 2 * QK_K; @@ -5957,54 +6059,14 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst 
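
The `iq4_nl`/`iq4_xs` kernels above first broadcast the 16-entry non-linear codebook into threadgroup memory (`shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]`) and then dequantize by indexing it with each 4-bit nibble. A CPU sketch of one `iq4_nl` block, assuming ggml's codebook values and a block layout of `d` (fp16 scale) plus `qs[16]` (nibbles):

```cpp
// Dequantize one 32-element iq4_nl block: y[j] = d * codebook[nibble].
#include <cstdint>
#include <cstdio>

static const int8_t kvalues_iq4nl[16] = { // assumed to match ggml's table
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

void dequant_iq4_nl(float d, const uint8_t qs[16], float y[32]) {
    for (int j = 0; j < 16; ++j) {
        y[j +  0] = d * kvalues_iq4nl[qs[j] & 0xf]; // low nibbles first
        y[j + 16] = d * kvalues_iq4nl[qs[j] >>  4]; // then high nibbles
    }
}

int main() {
    uint8_t qs[16] = {0};
    float   y[32];
    qs[0] = 0xF8; // low nibble 8 -> +1, high nibble 15 -> +113
    dequant_iq4_nl(0.5f, qs, y);
    printf("y[0] = %g, y[16] = %g\n", y[0], y[16]); // 0.5, 56.5
}
```
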
+ (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -[[host_name("kernel_mul_mv_iq1_s_f32")]] -kernel void kernel_mul_mv_iq1_s_f32( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq1_m_f32")]] -kernel void kernel_mul_mv_iq1_m_f32( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq4_nl_f32")]] -kernel void kernel_mul_mv_iq4_nl_f32( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); -} - [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( constant ggml_metal_kargs_mul_mv & args, @@ -6016,7 +6078,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -6660,25 +6722,27 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] 
kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; kernel void kernel_pool_2d_max_f32( device const float * src0, diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index cece8e37d48df..0a981031f5f70 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -64,6 +64,7 @@ set(GGML_OPENCL_KERNELS ggml-opencl_transpose_16 ggml-opencl_transpose_32 ggml-opencl_transpose_32_16 + ggml-opencl_im2col ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 38c3f5c02b5cf..ca95c2d84f05c 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -225,12 +225,14 @@ struct ggml_backend_opencl_context { cl_program program; cl_program program_1; cl_program program_2; + cl_program program_im2col; cl_kernel kernel_add, kernel_add_row; cl_kernel kernel_mul, kernel_mul_row; cl_kernel 
kernel_scale; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; + cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; cl_kernel kernel_clamp; cl_kernel kernel_norm; @@ -240,6 +242,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; + cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; cl_kernel kernel_mul_mat_f32_f32; cl_kernel kernel_mul_mat_f16_f16; @@ -253,6 +256,7 @@ struct ggml_backend_opencl_context { kernel_mul_mat_q4_0_f32_flat_img_v0; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; + cl_kernel kernel_im2col_f32, kernel_im2col_f16; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Transpose kernels @@ -724,6 +728,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err)); CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); @@ -738,6 +744,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); @@ -785,6 +795,19 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = 
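
The new kernels are registered with the same pattern as the existing ones: build the program from embedded or on-disk source, then materialize each kernel via `CL_CHECK((lhs = clCreateKernel(...), err), err)`, where the comma expression performs the assignment and then yields `err` for checking. A stand-in sketch of the idiom (the backend's real `CL_CHECK` may differ in detail):

```cpp
// Stand-in for the backend's CL_CHECK idiom: the comma expression runs
// the assignment, then evaluates to the error code that gets checked.
#include <cstdio>
#include <cstdlib>

#define CL_CHECK(expr)                                                   \
    do {                                                                 \
        const int err_ = (expr);                                         \
        if (err_ != 0) {                                                 \
            fprintf(stderr, "cl error %d at line %d\n", err_, __LINE__); \
            exit(1);                                                     \
        }                                                                \
    } while (0)

// toy clCreateKernel: returns a handle, reports success through *err
static int fake_create_kernel(const char * name, int * err) {
    (void)name; *err = 0; return 42;
}

int main() {
    int err = 0, kernel_im2col_f32 = 0;
    CL_CHECK((kernel_im2col_f32 = fake_create_kernel("kernel_im2col_f32", &err), err));
    printf("kernel handle: %d\n", kernel_im2col_f32);
}
```
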
clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); + // im2col kernels +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_im2col { + #include "ggml-opencl_im2col.cl.h" + }; +#else + const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl"); +#endif + backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err)); + // Kernels for Adreno #ifdef GGML_OPENCL_USE_ADRENO_KERNELS #ifdef GGML_OPENCL_EMBED_KERNELS @@ -914,10 +937,30 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts); CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + // TODO: fixme: these sizes are hardcoded for now. + // they should be allocated based on the model's size + // and the device's max alloc size // Allocate intermediate buffers and images - size_t max_A_q_d_bytes = 311164928; - size_t max_A_s_d_bytes = 38895616; - size_t max_B_d_bytes = 45088768; + size_t required_A_q_d_bytes = 311164928; + size_t required_A_s_d_bytes = 38895616; + size_t required_B_d_bytes = 45088768; + + // Ensure buffer sizes do not exceed the maximum allocation size + size_t max_A_q_d_bytes = MIN(required_A_q_d_bytes, backend_ctx->max_alloc_size); + size_t max_A_s_d_bytes = MIN(required_A_s_d_bytes, backend_ctx->max_alloc_size); + size_t max_B_d_bytes = MIN(required_B_d_bytes, backend_ctx->max_alloc_size); + if (required_A_q_d_bytes > backend_ctx->max_alloc_size) { + GGML_LOG_WARN("ggml_opencl: A_q_d buffer size reduced from %zu to %zu due to device limitations.\n", + required_A_q_d_bytes, max_A_q_d_bytes); + } + if (required_A_s_d_bytes > backend_ctx->max_alloc_size) { + GGML_LOG_WARN("ggml_opencl: A_s_d buffer size reduced from %zu to %zu due to device limitations.\n", + required_A_s_d_bytes, max_A_s_d_bytes); + } + if (required_B_d_bytes > backend_ctx->max_alloc_size) { + GGML_LOG_WARN("ggml_opencl: B_d buffer size reduced from %zu to %zu due to device limitations.\n", + required_B_d_bytes, max_B_d_bytes); + } CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err)); CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err)); @@ -1203,6 +1246,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_GELU_QUICK: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -1232,14 +1276,26 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->ne[3] == 1; case GGML_OP_ROPE: { const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + if (is_mrope && !is_vision) { + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16) { + return true; + } return false; } - if (mode & GGML_ROPE_TYPE_VISION) { + if (is_vision) { + 
if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16) { + return true; + } return false; } return true; } + case GGML_OP_IM2COL: + return true; default: return false; } @@ -2598,6 +2654,53 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const #endif } +static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_quick_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu_quick; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL); +#endif +} + static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -3996,6 +4099,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const float attn_factor; float beta_fast; float beta_slow; + int32_t sections[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); @@ -4003,29 +4107,62 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int32_t)*4); const bool is_neox = mode & 2; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } cl_kernel kernel; - if (!is_neox) { + if (is_neox) { switch (src0->type) { case GGML_TYPE_F32: - kernel = backend_ctx->kernel_rope_norm_f32; + kernel = backend_ctx->kernel_rope_neox_f32; break; case GGML_TYPE_F16: - kernel = backend_ctx->kernel_rope_norm_f16; + kernel = backend_ctx->kernel_rope_neox_f16; break; default: GGML_ASSERT(false); }; + } else if (is_mrope && !is_vision) { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_multi_f32; 
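
Both `ggml_opencl_supports_op` and `ggml_cl_rope` key off the same mode bits: the mrope bit is also set for vision rope, so `mode & GGML_ROPE_TYPE_MROPE` alone cannot distinguish the two, and the vision case is tested with equality. A sketch of the decoding, assuming ggml's flag values (`NEOX = 2`, `MROPE = 8`, `VISION = 24`); verify the numbers against `ggml.h` before relying on them:

```cpp
// Decode the rope mode the way the dispatch code does.
#include <cstdio>

enum : int { ROPE_NEOX = 2, ROPE_MROPE = 8, ROPE_VISION = 24 }; // assumed values

int main() {
    for (int mode : {0, ROPE_NEOX, ROPE_MROPE, ROPE_VISION}) {
        const bool is_mrope  = mode & ROPE_MROPE;   // true for vision too
        const bool is_vision = mode == ROPE_VISION; // equality, not bit test
        const bool is_neox   = mode & ROPE_NEOX;
        printf("mode=%2d -> neox=%d mrope=%d vision=%d multi-kernel=%d\n",
               mode, is_neox, is_mrope, is_vision, is_mrope && !is_vision);
    }
}
```
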
+ break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_multi_f16; + break; + default: + GGML_ASSERT(false); + }; + } else if (is_vision) { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_vision_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_vision_f16; + break; + default: + GGML_ASSERT(false); + } } else { switch (src0->type) { case GGML_TYPE_F32: - kernel = backend_ctx->kernel_rope_neox_f32; + kernel = backend_ctx->kernel_rope_norm_f32; break; case GGML_TYPE_F16: - kernel = backend_ctx->kernel_rope_neox_f16; + kernel = backend_ctx->kernel_rope_norm_f16; break; default: GGML_ASSERT(false); @@ -4065,6 +4202,9 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor)); CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast)); CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow)); + if (is_mrope || is_vision) { + CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, §ions)); + } size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; @@ -4080,6 +4220,98 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const #endif } +static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + // src0 - filter, src1 - input + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + + const cl_long IC = src1->ne[is_2D ? 2 : 1]; + const cl_long IH = is_2D ? src1->ne[1] : 1; + const cl_long IW = src1->ne[0]; + + const cl_long KH = is_2D ? src0->ne[1] : 1; + const cl_long KW = src0->ne[0]; + + const cl_long OH = is_2D ? dst->ne[2] : 1; + const cl_long OW = dst->ne[1]; + + // nb is byte offset, src is type float32 + const cl_ulong delta_offset = src1->nb[is_2D ? 2 : 1]/4; + const cl_long batch = src1->ne[is_2D ? 3 : 2]; + const cl_ulong batch_offset = src1->nb[is_2D ? 
3 : 2]/4; + + const cl_long pelements = OW*KW*KH; + const cl_long CHW = IC*KH*KW; + + cl_kernel kernel; + + if(dst->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_im2col_f16; + } else { + kernel = backend_ctx->kernel_im2col_f32; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &batch_offset)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &delta_offset)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &IW)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &IH)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &IC)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &OW)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_long), &OH)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_long), &KW)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_long), &KH)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_long), &pelements)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_long), &CHW)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &s0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &s1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &p0)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &p1)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &d0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &d1)); + + const int num_blocks = (pelements + 256 - 1) / 256; + size_t global_work_size[] = {(size_t)num_blocks*256, (size_t)OH, (size_t)batch*IC}; + size_t local_work_size[] = {256, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + //------------------------------------------------------------------------------ // Op offloading //------------------------------------------------------------------------------ @@ -4138,6 +4370,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_gelu; break; + case GGML_UNARY_OP_GELU_QUICK: + if (!any_on_device) { + return false; + } + func = ggml_cl_gelu_quick; + break; case GGML_UNARY_OP_SILU: if (!any_on_device) { return false; @@ -4210,6 +4448,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_rope; break; + case GGML_OP_IM2COL: + if (!any_on_device) { + return false; + } + func = ggml_cl_im2col; + break; default: return false; } diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl index 1d43642a983be..b88792887937a 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl @@ -404,6 +404,7 @@ kernel void kernel_scale( // gelu //------------------------------------------------------------------------------ #define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f kernel void kernel_gelu( @@ -434,6 +435,32 @@ kernel void kernel_gelu_4( dst[get_global_id(0)] = 
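
`ggml_cl_im2col` launches one work-item per element of `pelements = OW*KW*KH`, padded up to the 256-wide workgroup, with `OH` on dimension 1 and `batch*IC` on dimension 2. For intuition, a plain C++ reference of what the 2-D kernel computes, mirroring the same index math (a sketch, not the backend's code path):

```cpp
#include <cstdio>
#include <vector>

// dst layout: [OH*OW, IC*KH*KW]; out-of-bounds taps are zero-padded.
void im2col_2d(const float * src, float * dst,
               long IW, long IH, long IC, long KW, long KH, long OW, long OH,
               int s0, int s1, int p0, int p1, int d0, int d1) {
    const long CHW = IC * KH * KW;
    for (long ic = 0; ic < IC; ++ic)
    for (long oh = 0; oh < OH; ++oh)
    for (long ow = 0; ow < OW; ++ow)
    for (long ky = 0; ky < KH; ++ky)
    for (long kx = 0; kx < KW; ++kx) {
        const long iiw = ow * s0 + kx * d0 - p0; // input column of this tap
        const long iih = oh * s1 + ky * d1 - p1; // input row of this tap
        const long odx = (oh * OW + ow) * CHW + ic * KW * KH + ky * KW + kx;
        dst[odx] = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)
                 ? 0.0f : src[ic * IW * IH + iih * IW + iiw];
    }
}

int main() {
    const long IW = 4, IH = 4, IC = 1, KW = 3, KH = 3, OW = 2, OH = 2;
    std::vector<float> src(IW * IH), dst(OH * OW * IC * KH * KW);
    for (size_t i = 0; i < src.size(); ++i) src[i] = (float)i;
    im2col_2d(src.data(), dst.data(), IW, IH, IC, KW, KH, OW, OH,
              1, 1, 0, 0, 1, 1);
    printf("first patch starts with %g %g %g\n", dst[0], dst[1], dst[2]); // 0 1 2
}
```
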
0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } +kernel void kernel_gelu_quick( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + //------------------------------------------------------------------------------ // silu //------------------------------------------------------------------------------ @@ -1325,6 +1352,368 @@ kernel void kernel_rope_neox_f16( } } +kernel void kernel_rope_multi_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
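
The `kernel_gelu_quick`/`kernel_gelu_quick_4` pair added above computes the sigmoid approximation GELU(x) ≈ x·σ(1.702x), with the sign folded into `GELU_QUICK_COEF = -1.702` so the body reads `x*(1/(1+exp(-1.702x)))`. A scalar reference for checking values:

```cpp
// gelu_quick reference: x * sigmoid(1.702 * x).
#include <cmath>
#include <cstdio>

static float gelu_quick(float x) { return x / (1.0f + std::exp(-1.702f * x)); }

int main() {
    for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
        printf("gelu_quick(%+.1f) = %+.4f\n", x, gelu_quick(x));
    }
}
```
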
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global half*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ?
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_vision_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? 
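
The `rope_multi` kernels split the rotary dimension pairs into up to four contiguous sections (for mrope-style models: temporal, height, width, extra), and the sector index decides which of the four interleaved position streams in `pos[]` supplies `theta_base`; the vision variant uses only the first two sections. A sketch of the section-to-stream mapping, with illustrative section sizes:

```cpp
// Which position stream pos[i2 + ne2*stream] feeds each dimension pair.
#include <cstdio>

int main() {
    const int sections[4] = {16, 24, 24, 0}; // illustrative mrope sections
    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w     = sections[0] + sections[1];

    for (int ic = 0; ic < sect_dims; ic += 8) {
        const int sector = ic % sect_dims;
        int stream = 3;                             // default: last section
        if      (sector < sections[0])         stream = 0; // temporal
        else if (sector < sec_w)               stream = 1; // height
        else if (sector < sec_w + sections[2]) stream = 2; // width
        printf("dim pair %2d -> pos stream %d\n", ic, stream);
    }
}
```
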
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} + +kernel void kernel_rope_vision_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global half*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} + //------------------------------------------------------------------------------ // cpy //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl new file mode 100644 index 0000000000000..9b41dfb255510 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl @@ -0,0 +1,146 @@ +#ifdef cl_khr_fp16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#elif defined(cl_amd_fp16) +#pragma OPENCL EXTENSION cl_amd_fp16 : enable +#else +#error "Half precision floating point not supported by OpenCL implementation on your device."
+#endif + +#ifdef cl_khr_subgroups +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#elif defined(cl_intel_subgroups) +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#error "Subgroup not supported on your device." +#endif + +#ifdef cl_intel_required_subgroup_size +// Always use subgroup size of 32 on Intel. +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +// Always use subgroups size of 64 on Adreno. +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +// TODO: do not know how to choose subgroup size on other GPUs. +#error "Selecting subgroup size is not supported on your device." +#endif + +kernel void kernel_im2col_f32( + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + // threadIdx.x + blockIdx.x * blockDim.x + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} + +kernel void kernel_im2col_f16( + global float * src1, + ulong offset1, + global half * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global half*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? 
+
+kernel void kernel_im2col_f16(
+        global float * src1,
+        ulong offset1,
+        global half * dst,
+        ulong offsetd,
+        ulong batch_offset,
+        ulong delta_offset,
+        long IW,
+        long IH,
+        long IC,
+        long OW,
+        long OH,
+        long KW,
+        long KH,
+        long pelements,
+        long CHW,
+        int  s0,
+        int  s1,
+        int  p0,
+        int  p1,
+        int  d0,
+        int  d1
+) {
+    long i = get_global_id(0);
+
+    if (i >= pelements) {
+        return;
+    }
+
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global half*)((global char*)dst + offsetd);
+
+    long  ksize = OW * (KH > 1 ? KW : 1);
+    long  kx = i / ksize;
+    long  kd = kx * ksize;
+    long  ky = (i - kd) / OW;
+    long  ix = i % OW;
+
+    long  oh    = get_group_id(1);
+    long  batch = get_group_id(2) / IC;
+    long  ic    = get_group_id(2) % IC;
+
+    long iiw = ix * s0 + kx * d0 - p0;
+    long iih = oh * s1 + ky * d1 - p1;
+
+    long offset_dst =
+        ((batch * OH + oh) * OW + ix) * CHW +
+        (ic * (KW * KH) + ky * KW + kx);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        long offset_src = ic * delta_offset + batch * batch_offset;
+        dst[offset_dst] = src1[offset_src + iih * IW + iiw];
+    }
+}
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index 6c3b80b0883c9..862b9b666175d 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -26,6 +26,10 @@
 #  include
 #endif
 #include
+#include <filesystem>
+#include <fstream>
+
+namespace fs = std::filesystem;
 
 #ifdef _WIN32
 typedef SOCKET sockfd_t;
@@ -80,6 +84,7 @@ enum rpc_cmd {
     RPC_CMD_FREE_BUFFER,
     RPC_CMD_BUFFER_CLEAR,
     RPC_CMD_SET_TENSOR,
+    RPC_CMD_SET_TENSOR_HASH,
     RPC_CMD_GET_TENSOR,
     RPC_CMD_COPY_TENSOR,
     RPC_CMD_GRAPH_COMPUTE,
@@ -89,6 +94,9 @@ enum rpc_cmd {
     RPC_CMD_COUNT,
 };
 
+// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
+const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
+
 struct rpc_msg_get_alloc_size_req {
     rpc_tensor tensor;
 };
@@ -135,6 +143,10 @@ struct rpc_msg_buffer_clear_req {
     uint8_t value;
 };
 
+struct rpc_msg_set_tensor_hash_rsp {
+    uint8_t result;
+};
+
 struct rpc_msg_get_tensor_req {
     rpc_tensor tensor;
     uint64_t offset;
@@ -187,6 +199,18 @@ struct ggml_backend_rpc_buffer_context {
 
 // RPC helper functions
 
+// Computes FNV-1a hash of the data
+static uint64_t fnv_hash(const uint8_t * data, size_t len) {
+    const uint64_t fnv_prime = 0x100000001b3ULL;
+    uint64_t hash = 0xcbf29ce484222325ULL;
+
+    for (size_t i = 0; i < len; ++i) {
+        hash ^= data[i];
+        hash *= fnv_prime;
+    }
+    return hash;
+}
+
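
The constants in fnv_hash are the standard 64-bit FNV-1a offset basis (0xcbf29ce484222325) and prime (0x100000001b3). A self-contained sanity check (illustrative values only, not part of the patch) shows how the hash gives client and server a stable key for a byte payload:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone copy of the FNV-1a routine above, for illustration only.
static uint64_t fnv1a64(const uint8_t * data, size_t len) {
    uint64_t hash = 0xcbf29ce484222325ULL;   // FNV offset basis
    for (size_t i = 0; i < len; ++i) {
        hash ^= data[i];
        hash *= 0x100000001b3ULL;            // FNV prime
    }
    return hash;
}

int main() {
    std::vector<uint8_t> tensor_bytes(1024, 0x5a);   // dummy payload
    // The same bytes always map to the same 64-bit key, which is what lets
    // both ends agree on cached content without resending it.
    std::printf("%016llx\n", (unsigned long long) fnv1a64(tensor_bytes.data(), tensor_bytes.size()));
}
```
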
 static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
 #ifdef _WIN32
     if (fd == INVALID_SOCKET) {
@@ -483,10 +507,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
 
 static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+    rpc_tensor rpc_tensor = serialize_tensor(tensor);
+    if (size > HASH_THRESHOLD) {
+        // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
+        size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
+        std::vector<uint8_t> input(input_size, 0);
+        uint64_t hash = fnv_hash((const uint8_t*)data, size);
+        memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
+        memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
+        memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
+        rpc_msg_set_tensor_hash_rsp response;
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
+        GGML_ASSERT(status);
+        if (response.result) {
+            // the server has the same data, no need to send it
+            return;
+        }
+    }
+    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
     size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
     std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_tensor = serialize_tensor(tensor);
     memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
@@ -772,7 +812,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
 
 class rpc_server {
 public:
-    rpc_server(ggml_backend_t backend) : backend(backend) {}
+    rpc_server(ggml_backend_t backend, const char * cache_dir)
+        : backend(backend), cache_dir(cache_dir) {
+    }
     ~rpc_server();
 
     void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
@@ -782,6 +824,7 @@ class rpc_server {
     bool free_buffer(const rpc_msg_free_buffer_req & request);
     bool buffer_clear(const rpc_msg_buffer_clear_req & request);
     bool set_tensor(const std::vector<uint8_t> & input);
+    bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
@@ -789,6 +832,7 @@ class rpc_server {
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
 
 private:
+    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
     ggml_tensor * create_node(uint64_t id,
                               struct ggml_context * ctx,
@@ -797,6 +841,7 @@ class rpc_server {
 
     ggml_backend_t backend;
+    const char * cache_dir;
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
 
@@ -960,11 +1005,85 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     }
     const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
 
+    if (cache_dir && size > HASH_THRESHOLD) {
+        uint64_t hash = fnv_hash((const uint8_t*)data, size);
+        char hash_str[17];
+        snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+        // save to cache_dir/hash_str
+        fs::path cache_file = fs::path(cache_dir) / hash_str;
+        std::ofstream ofs(cache_file, std::ios::binary);
+        ofs.write((const char *)data, size);
+        printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
+    }
     ggml_backend_tensor_set(tensor, data, offset, size);
     ggml_free(ctx);
     return true;
 }
 
+bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
+    if (!cache_dir) {
+        return false;
+    }
+    char hash_str[17];
+    snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+    fs::path cache_file = fs::path(cache_dir) / hash_str;
+    if (!fs::exists(cache_file)) {
+        return false;
+    }
+    std::ifstream ifs(cache_file, std::ios::binary);
+    ifs.seekg(0, std::ios::end);
+    size_t size = ifs.tellg();
+    ifs.seekg(0, std::ios::beg);
+    data.resize(size);
+    ifs.read((char *)data.data(), size);
+    return true;
+}
+
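
Before the server-side handler, it may help to see the exchange in one place. The sketch below (hypothetical helper name; the real code builds the buffer inline as shown above) packs the `| rpc_tensor | offset | hash |` request that the protocol comments describe:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Illustration of the RPC_CMD_SET_TENSOR_HASH request layout:
// | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
template <typename RpcTensor>
std::vector<uint8_t> pack_set_tensor_hash(const RpcTensor & t, uint64_t offset, uint64_t hash) {
    std::vector<uint8_t> input(sizeof(t) + sizeof(offset) + sizeof(hash), 0);
    std::memcpy(input.data(), &t, sizeof(t));
    std::memcpy(input.data() + sizeof(t), &offset, sizeof(offset));
    std::memcpy(input.data() + sizeof(t) + sizeof(offset), &hash, sizeof(hash));
    return input;
}
```

If the server replies with `result == 1`, the client skips the bulk RPC_CMD_SET_TENSOR transfer entirely, which is the point of the 10 MiB HASH_THRESHOLD.
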
+bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
+{
+    // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
+    if (input.size() != sizeof(rpc_tensor) + 16) {
+        return false;
+    }
+    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
+    uint64_t offset;
+    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
+    const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
+    std::vector<uint8_t> cached_file;
+    if (!get_cached_file(*hash, cached_file)) {
+        response.result = 0;
+        return true;
+    }
+    size_t size = cached_file.size();
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
+        ggml_free(ctx);
+        return false;
+    }
+    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
+
+    // sanitize tensor->data
+    {
+        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
+        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
+
+        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
+            GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
+        }
+    }
+    ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
+    response.result = 1;
+    ggml_free(ctx);
+    return true;
+}
+
 bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
     struct ggml_init_params params {
         /*.mem_size   =*/ ggml_tensor_overhead(),
@@ -1148,8 +1267,9 @@ rpc_server::~rpc_server() {
     }
 }
 
-static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
-    rpc_server server(backend);
+static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
+                             sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+    rpc_server server(backend, cache_dir);
     while (true) {
         uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
@@ -1260,6 +1380,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
             }
             break;
         }
+        case RPC_CMD_SET_TENSOR_HASH: {
+            std::vector<uint8_t> input;
+            if (!recv_msg(sockfd, input)) {
+                return;
+            }
+            rpc_msg_set_tensor_hash_rsp response;
+            if (!server.set_tensor_hash(input, response)) {
+                return;
+            }
+            if (!send_msg(sockfd, &response, sizeof(response))) {
+                return;
+            }
+            break;
+        }
         case RPC_CMD_INIT_TENSOR: {
             rpc_msg_init_tensor_req request;
             if (!recv_msg(sockfd, &request, sizeof(request))) {
@@ -1335,7 +1469,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
     }
 }
 
-void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
+                                   const char * cache_dir,
+                                   size_t free_mem, size_t total_mem) {
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -1364,7 +1500,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
         }
         printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
         fflush(stdout);
-        rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
+        rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
         printf("Client connection closed\n");
         fflush(stdout);
     }
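
For orientation, the new cache_dir argument enters at the public entry point changed above. A minimal launcher might look like the following sketch; the backend construction via ggml_backend_init_by_name and the memory figures are assumptions for illustration, and only the ggml_backend_rpc_start_server signature comes from this patch:

```cpp
#include "ggml-backend.h"
#include "ggml-rpc.h"

int main() {
    // Illustrative only: any ggml backend can be served.
    ggml_backend_t backend = ggml_backend_init_by_name("CPU", nullptr);

    // Passing a non-null cache_dir enables the FNV-1a tensor cache added above;
    // passing nullptr preserves the old always-send behaviour.
    ggml_backend_rpc_start_server(backend, "0.0.0.0:50052", "/tmp/ggml-rpc-cache",
                                  /*free_mem=*/8ull << 30, /*total_mem=*/16ull << 30);
}
```
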
"Intel C")) + set_target_properties(ggml-sycl PROPERTIES VS_PLATFORM_TOOLSET "Intel C++ Compiler 2025") + set(CMAKE_CXX_COMPILER "icx") + set(CMAKE_CXX_COMPILER_ID "IntelLLVM") + endif() +endif() + +find_package(IntelSYCL) +if (IntelSYCL_FOUND) + # Use oneAPI CMake when possible + target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX) +else() + # Fallback to the simplest way of enabling SYCL when using intel/llvm nightly for instance + target_compile_options(ggml-sycl PRIVATE "-fsycl") + target_link_options(ggml-sycl PRIVATE "-fsycl") +endif() + +target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") + +# Link against oneDNN find_package(DNNL) set(GGML_SYCL_DNNL 0) if(DNNL_FOUND) @@ -62,8 +88,6 @@ if (GGML_SYCL_F16) add_compile_definitions(GGML_SYCL_F16) endif() -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl") - if (GGML_SYCL_TARGET STREQUAL "NVIDIA") add_compile_definitions(GGML_SYCL_WARP_SIZE=32) elseif (GGML_SYCL_TARGET STREQUAL "AMD") @@ -76,34 +100,84 @@ else() add_compile_definitions(GGML_SYCL_WARP_SIZE=16) endif() -file(GLOB GGML_HEADERS_SYCL "*.hpp") -file(GLOB GGML_SOURCES_SYCL "*.cpp") -target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL}) - +if (GGML_SYCL_GRAPH) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) +endif() -if (WIN32) - find_package(IntelSYCL REQUIRED) +# Link against Intel oneMKL or oneMath +if (GGML_SYCL_TARGET STREQUAL "INTEL") + # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically + # See https://github.com/uxlfoundation/oneMath/issues/654 find_package(MKL REQUIRED) - target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL) + target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL) else() - if (GGML_SYCL_GRAPH) - add_compile_definitions(GGML_SYCL_GRAPH) + find_package(oneMath QUIET) + if (NOT oneMath_FOUND) + message(STATUS "oneMath not found: oneMath will be automatically downloaded") + # Use FetchContent to automatically pull and build oneMath + include(FetchContent) + set(BUILD_FUNCTIONAL_TESTS False) + set(BUILD_EXAMPLES False) + set(TARGET_DOMAINS blas) + if (GGML_SYCL_TARGET STREQUAL "NVIDIA") + set(ENABLE_MKLCPU_BACKEND False) + set(ENABLE_MKLGPU_BACKEND False) + set(ENABLE_CUBLAS_BACKEND True) + elseif (GGML_SYCL_TARGET STREQUAL "AMD") + set(ENABLE_MKLCPU_BACKEND False) + set(ENABLE_MKLGPU_BACKEND False) + set(ENABLE_ROCBLAS_BACKEND True) + # Ensure setting a string variable here is not overriden by oneMath CACHE variables + cmake_policy(SET CMP0126 NEW) + # Setting the device architecture is only needed and useful for AMD devices in oneMath + set(HIP_TARGETS ${GGML_SYCL_DEVICE_ARCH} CACHE STRING "oneMath HIP target" FORCE) + endif() + FetchContent_Declare( + ONEMATH + GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git + GIT_TAG c255b1b4c41e2ee3059455c1f96a965d6a62568a + ) + FetchContent_MakeAvailable(ONEMATH) + # Create alias to match with find_package targets name + function(onemath_alias target) + if (TARGET ${target}_obj) + # Silence verbose warnings from external libraries + target_compile_options(${target}_obj PRIVATE -w) + endif() + if (TARGET ${target}) + add_library(ONEMATH::${target} ALIAS ${target}) + endif() + endfunction() + onemath_alias(onemath) + onemath_alias(onemath_blas_mklcpu) + onemath_alias(onemath_blas_mklgpu) + onemath_alias(onemath_blas_cublas) + onemath_alias(onemath_blas_rocblas) endif() 
- if (GGML_SYCL_TARGET STREQUAL "INTEL") - target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread) - elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") - add_compile_definitions(GGML_SYCL_NVIDIA) - target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas) + + # Below oneMath compile-time dispatching is used for better performance + if (GGML_SYCL_TARGET STREQUAL "NVIDIA") + target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_cublas) + target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") + target_link_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA) elseif (GGML_SYCL_TARGET STREQUAL "AMD") if (NOT GGML_SYCL_DEVICE_ARCH) message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa") - target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl) + target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas) + target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") + target_link_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_AMD) + else() + # Fallback to oneMath runtime dispatcher + target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GENERIC) endif() +endif() - if (GGML_SYCL_DEVICE_ARCH) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}") - endif() +if (GGML_SYCL_DEVICE_ARCH) + target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) + target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) endif() diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index 9069c47865f68..05fd5ef46c76a 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -66,41 +66,6 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block return sycl_down_blk_size; } -void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const ggml_sycl_op_flatten_t op) try { - - const bool use_src1 = src1 != nullptr; - if(use_src1) - GGML_ASSERT(strcmp(src1->buffer->buft->iface.get_name(src1->buffer->buft), GGML_SYCL_NAME "_Split") != 0); - GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0); - - // dd = data device - float * src0_ddf = (float *) src0->data; - float * src1_ddf = use_src1 ? 
(float *) src1->data : nullptr; - float * dst_ddf = (float *) dst->data; - - ggml_sycl_pool_alloc src0_f(ctx.pool()); - ggml_sycl_pool_alloc src1_f(ctx.pool()); - ggml_sycl_pool_alloc dst_f(ctx.pool()); - - ggml_sycl_set_device(ctx.device); - queue_ptr main_stream = ctx.stream(); - // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n", - // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device); - - // do the computation - op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); - // print_ggml_tensor("tensor", dst); -} -catch (sycl::exception const &exc) { - - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - - void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams) { for (int i = 0; i < ggml_sycl_info().device_count; ++i) { for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 27b447ce30d18..3e1ceeaa49486 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -494,12 +494,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor acc) { int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); -typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream); - template static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, int ne0, int ne1, int ne2, int ne3, @@ -757,24 +751,22 @@ struct bin_bcast_sycl { template inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { + const ggml_tensor *src1, ggml_tensor *dst) { + dpct::queue_ptr main_stream = ctx.stream(); if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, - (sycl::half *)dst_dd, main_stream); + op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, + (sycl::half *)dst->data, main_stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd, + op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream); } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { - op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd, + op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data, main_stream); } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { - op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd, + op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data, main_stream); } else { fprintf(stderr, "%s: unsupported types: 
dst: %s, src0: %s, src1: %s\n", __func__, @@ -784,8 +776,4 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t } bool gpu_has_xmx(sycl::device &dev); - -void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const ggml_sycl_op_flatten_t op); #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index c96395be61312..d538965b096bf 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -16,9 +16,18 @@ #include #include #include -#include #include +#ifdef GGML_SYCL_USE_INTEL_ONEMKL +#include +// Allow to use the same namespace for Intel oneMKL and oneMath +namespace oneapi { + namespace math = mkl; +} +#else +#include +#endif + #include "ggml.h" #if defined(__linux__) @@ -83,13 +92,32 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { } template struct matrix_info_t { - oneapi::mkl::transpose transpose_info[2]; + oneapi::math::transpose transpose_info[2]; Ts value_info[2]; std::int64_t size_info[3]; std::int64_t ld_info[3]; std::int64_t groupsize_info; }; +inline auto get_onemath_backend(sycl::queue& queue) +#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) + -> sycl::queue& +#endif +{ +// If the backend is known at compile-time, use oneMath backend_selector to use +// compile-time dispatching and avoid the need to dlopen libraries. Otherwise +// fallback to runtime dispatching. +#if defined(GGML_SYCL_NVIDIA) + return oneapi::math::backend_selector{ queue }; +#elif defined(GGML_SYCL_AMD) + return oneapi::math::backend_selector{ queue }; +#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) + return queue; +#else + static_assert(false, "Unsupported backend"); +#endif +} + namespace dpct { typedef sycl::queue *queue_ptr; @@ -1686,26 +1714,18 @@ namespace dpct namespace detail { - template - inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, int lda, const void *b, - int ldb, const void *beta, void *c, int ldc) - { - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - auto data_a = get_memory(a); - auto data_b = get_memory(b); - auto data_c = get_memory(c); -#ifdef GGML_SYCL_NVIDIA - oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector{ q }, - a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, - beta_value, data_c, ldc); -#else - oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, - beta_value, data_c, ldc); -#endif - } + template + inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb, + const void * beta, void * c, int ldc) { + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a, + lda, data_b, ldb, beta_value, data_c, ldc); + } template class vectorized_binary @@ -1735,7 +1755,7 @@ namespace dpct }; template - inline void 
gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, + inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, int ldb, const void * beta, void ** c, int ldc, int batch_size, matrix_info_t * matrix_info) { @@ -1754,48 +1774,28 @@ namespace dpct matrix_info->ld_info[2] = ldc; matrix_info->groupsize_info = batch_size; -#ifdef GGML_SYCL_NVIDIA - sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, - matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), - reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), - matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), - reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); -#else - sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, - matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), - reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), - matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), - reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); -#endif + sycl::event e = oneapi::math::blas::column_major::gemm_batch( + get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1, + matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2, + reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, + reinterpret_cast(b), matrix_info->ld_info + 1, + reinterpret_cast(matrix_info->value_info + 1), reinterpret_cast(c), + matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); } template - inline void - gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, - int k, const void *alpha, const void *a, int lda, - long long int stride_a, const void *b, int ldb, - long long int stride_b, const void *beta, void *c, - int ldc, long long int stride_c, int batch_size) - { + inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, + int m, int n, int k, const void * alpha, const void * a, int lda, + long long int stride_a, const void * b, int ldb, long long int stride_b, + const void * beta, void * c, int ldc, long long int stride_c, int batch_size) { Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); -#ifdef GGML_SYCL_NVIDIA - oneapi::mkl::blas::column_major::gemm_batch( - oneapi::mkl::backend_selector{ q }, a_trans, b_trans, m, n, k, - alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c, - batch_size); -#else - oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, - stride_c, batch_size); -#endif + oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, + data_a, lda, stride_a, data_b, ldb, stride_b, 
beta_value, + data_c, ldc, stride_c, batch_size); } } // namespace detail @@ -2259,13 +2259,10 @@ namespace dpct sycl::range<3>(x, y, 1), direction); } - inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, const void *b, library_data_t b_type, int ldb, - const void *beta, void *c, library_data_t c_type, int ldc, - library_data_t scaling_type) - { + inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n, + int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b, + library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc, + library_data_t scaling_type) { if (scaling_type == library_data_t::real_float && c_type == library_data_t::complex_float) { @@ -2329,9 +2326,8 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); break; } case detail::get_type_combination_id( @@ -2369,8 +2365,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_impl( + detail::gemm_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); break; } @@ -2390,7 +2385,7 @@ namespace dpct default: throw std::runtime_error("the combination of data type is unsupported"); } - } // gemm() + } // gemm() /// Computes a batch of matrix-matrix product with general matrices. /// \param [in] q The queue where the routine should be executed. @@ -2412,7 +2407,7 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, + inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, @@ -2450,7 +2445,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } @@ -2458,7 +2453,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } @@ -2534,15 +2529,11 @@ namespace dpct /// \param [in] stride_c Stride between the different C matrices. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. 
- inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, long long int stride_a, const void *b, - library_data_t b_type, int ldb, long long int stride_b, - const void *beta, void *c, library_data_t c_type, - int ldc, long long int stride_c, int batch_size, - library_data_t scaling_type) - { + inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda, + long long int stride_a, const void * b, library_data_t b_type, int ldb, + long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc, + long long int stride_c, int batch_size, library_data_t scaling_type) { if (scaling_type == library_data_t::real_float && c_type == library_data_t::complex_float) { @@ -2611,20 +2602,18 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size); break; } case detail::get_type_combination_id( library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - stride_a, b, ldb, stride_b, beta, c, ldc, - stride_c, batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size); break; } #endif diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 1e12cb220e4c1..0423305bb4016 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -509,497 +509,409 @@ static void pad_f32_sycl(const float *x, float *dst, const int ne00, }); } -inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); 
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); +inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == 
GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { +inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { +inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { +inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + log_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { +inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == 
GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { +inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { +inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { +inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { +inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr 
main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); +inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); float negative_slope; memcpy(&negative_slope, dst->op_params, sizeof(float)); - leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream); } -inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); } -inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - 
GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); - const float sf0 = (float)dst->ne[0]/src0->ne[0]; - const float sf1 = (float)dst->ne[1]/src0->ne[1]; - const float sf2 = (float)dst->ne[2]/src0->ne[2]; - const float sf3 = (float)dst->ne[3]/src0->ne[3]; + const float sf0 = (float)dst->ne[0]/dst->src[0]->ne[0]; + const float sf1 = (float)dst->ne[1]/dst->src[0]->ne[1]; + const float sf2 = (float)dst->ne[2]/dst->src[0]->ne[2]; + const float sf3 = (float)dst->ne[3]/dst->src[0]->ne[3]; - upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); } -inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); pad_f32_sycl(src0_dd, dst_dd, - src0->ne[0], src0->ne[1], src0->ne[2], + dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); } -inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast(dst->src[0]->data); + const float * src1_dd = static_cast(dst->src[1]->data); + float * dst_dd = static_cast(dst->data); int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused int offset = dst->op_params[3] / 4; // offset in bytes - acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream); - - GGML_UNUSED(dst); - GGML_UNUSED(ctx); + acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], 
dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream); } -inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); } -inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); } -inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); } -inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); } void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqrt); + ggml_sycl_op_sqrt(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sin); + ggml_sycl_op_sin(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_cos); + ggml_sycl_op_cos(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_acc); + ggml_sycl_op_acc(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu); + ggml_sycl_op_gelu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_silu); + ggml_sycl_op_silu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void 
ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu_quick); + ggml_sycl_op_gelu_quick(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_tanh); + ggml_sycl_op_tanh(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_relu); + ggml_sycl_op_relu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sigmoid); + ggml_sycl_op_sigmoid(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardsigmoid); + ggml_sycl_op_hardsigmoid(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardswish); + ggml_sycl_op_hardswish(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_exp); + ggml_sycl_op_exp(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_log); + ggml_sycl_op_log(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_neg); + ggml_sycl_op_neg(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_step); + ggml_sycl_op_step(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_leaky_relu); + ggml_sycl_op_leaky_relu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqr); + ggml_sycl_op_sqr(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_upscale); + ggml_sycl_op_upscale(ctx, dst); 
GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pad); + ggml_sycl_op_pad(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } @@ -1007,24 +919,24 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_add); + ggml_sycl_op_add(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sub); + ggml_sycl_op_sub(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_mul); + ggml_sycl_op_mul(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_div); + ggml_sycl_op_div(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index b9cf8767cbab2..64665be464762 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -257,50 +257,54 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens GGML_UNUSED(ctx); } -void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_d, const float *src1_d, - float *dst_d, const queue_ptr &stream) { +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); + GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type)); + GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type)); GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); - const int32_t * src1_i32 = (const int32_t *) src1_d; - - switch (src0->type) { + const int32_t * src1_i32 = (const int32_t *) dst->src[1]->data; + /* TODO: Refactor and remove duplicates */ + switch (dst->src[0]->type) { case GGML_TYPE_F16: - get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d, - src1_i32, dst_d, stream); + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); break; case GGML_TYPE_F32: - get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); break; case GGML_TYPE_Q4_0: if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) { - get_rows_sycl_reorder(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl_reorder(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, 
(float *)dst->data, ctx.stream()); } else { - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); } break; case GGML_TYPE_Q4_1: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); break; case GGML_TYPE_Q5_0: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); break; case GGML_TYPE_Q5_1: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); break; case GGML_TYPE_Q8_0: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); break; default: // TODO: k-quants - GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); + GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(dst->src[0]->type)); GGML_ABORT("fatal error"); } } diff --git a/ggml/src/ggml-sycl/getrows.hpp b/ggml/src/ggml-sycl/getrows.hpp index cdbe6c2f41bc3..1c560cd9f8941 100644 --- a/ggml/src/ggml-sycl/getrows.hpp +++ b/ggml/src/ggml-sycl/getrows.hpp @@ -15,9 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_d, const float *src1_d, - float *dst_d, const queue_ptr &stream); +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst); #endif // GGML_SYCL_GETROWS_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 9fa24b98098f2..dff9f8d4c4ac2 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -37,6 +37,7 @@ #include "ggml-backend-impl.h" #include "ggml-sycl/backend.hpp" +#include "ggml-sycl/common.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/sycl_hw.hpp" @@ -490,6 +491,23 @@ catch (sycl::exception const &exc) { std::exit(1); } +static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, + size_t offset, size_t size) { + GGML_SYCL_DEBUG(" [SYCL] call %s\n", __func__); + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; + SYCL_CHECK(ggml_sycl_set_device(ctx->device)); + auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); + if (size == 0) { + return; // Nothing to do + } + if (tensor->data == nullptr) { + GGML_ABORT("Error: Tensor data pointer is null.\n"); + } + void * target_ptr = static_cast<char *>(tensor->data) + offset; + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size))); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait())); +} + static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) { GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); if (buffer == nullptr) { @@ -510,7 +528,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer, /* .get_base = */
ggml_backend_sycl_buffer_get_base, /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, @@ -1970,16 +1988,8 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_d, const float *src1_d, - float *dst_d, - const queue_ptr &main_stream) { - - ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(src1_d); +static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst); } @@ -2049,8 +2059,8 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( - *stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, + *stream, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, dst_f16.get(), dpct::library_data_t::real_half, ldc, @@ -2087,17 +2097,10 @@ inline void ggml_sycl_op_mul_mat_sycl( #if !GGML_SYCL_DNNL const float alpha = 1.0f; const float beta = 0.0f; -# ifdef GGML_SYCL_NVIDIA - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, - ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); -# else - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, - dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), - dst_dd_i, ldc))); -# endif + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( + get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); #else DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), @@ -2114,13 +2117,14 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { +static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); const int32_t * opts = (const int32_t
*)dst->op_params; enum ggml_op_pool op = static_cast<enum ggml_op_pool>(opts[0]); @@ -2131,8 +2135,8 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens const int p0 = opts[5]; const int p1 = opts[6]; - const int64_t IH = src0->ne[1]; - const int64_t IW = src0->ne[0]; + const int64_t IH = dst->src[0]->ne[1]; + const int64_t IW = dst->src[0]->ne[0]; const int64_t N = dst->ne[3]; const int64_t OC = dst->ne[2]; @@ -2151,163 +2155,125 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens parallel_elements, src0_dd, dst_dd, op, item_ct1); }); - - GGML_UNUSED(src1); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); } -inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); +inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); - const int64_t ne = ggml_nelements(src0); + const int64_t ne = ggml_nelements(dst->src[0]); sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); } -inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); - const int64_t ncols = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); } -inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + int32_t * dst_dd = static_cast<int32_t *>(dst->data); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_I32); - const int64_t ncols = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream); - -
GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream); } -inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_I32); - const int64_t ncols = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + int32_t * dst_dd = static_cast<int32_t *>(dst->data); - argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream); + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); + argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); } -inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int nrows0 = ggml_nrows(src0); + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t ne01 = dst->src[0]->ne[1]; + const int nrows0 = ggml_nrows(dst->src[0]); const int n_past = ((int32_t *) dst->op_params)[0]; diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); } -inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); float scale; memcpy(&scale, dst->op_params, sizeof(float)); - scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream); + scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream); /* DPCT1010:87: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
*/ SYCL_CHECK(0); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); } -inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); float min; float max; memcpy(&min, dst->op_params, sizeof(float)); memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream); + clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(dst->src[0]), ctx.stream()); /* DPCT1010:88: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ SYCL_CHECK(0); - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); } static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { @@ -2677,37 +2643,37 @@ catch (sycl::exception const &exc) { static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat); + ggml_sycl_op_repeat(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows); + ggml_sycl_op_get_rows(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm); + ggml_sycl_op_norm(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm); + ggml_sycl_op_rms_norm(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm); + ggml_sycl_op_l2_norm(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm); + ggml_sycl_op_group_norm(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } @@ -2863,14 +2829,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, - (const char *)src0_as_f16, dpct::library_data_t::real_half, - nb01 /
nb00, nb02 / nb00, - (const char *)src1_f16, dpct::library_data_t::real_half, - nb11 / nb10, nb12 / nb10, beta, - (char *)dst_t, cu_data_type, ne01, nb2 / nb0, - ne12 * ne13, cu_compute_type))); + *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, + (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t, + cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type))); } else { const int ne23 = ne12*ne13; @@ -2905,7 +2867,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, }); } SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, + *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); @@ -3251,48 +3213,48 @@ catch (sycl::exception const &exc) { } static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale); + ggml_sycl_op_scale(ctx, dst); } static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp); + ggml_sycl_op_clamp(ctx, dst); } static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf); + ggml_sycl_op_diag_mask_inf(ctx, dst); } static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope); + ggml_sycl_op_rope(ctx, dst); } static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d); + ggml_sycl_op_pool2d(ctx, dst); } static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col); + ggml_sycl_op_im2col(ctx, dst); } static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum); + ggml_sycl_op_sum(ctx, dst); } static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows); + ggml_sycl_op_sum_rows(ctx, dst); } static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort); + ggml_sycl_op_argsort(ctx, dst); } static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); - ggml_sycl_op_flatten(ctx, dst->src[0], 
dst->src[1], dst, ggml_sycl_op_argmax); + ggml_sycl_op_argmax(ctx, dst); } @@ -3317,7 +3279,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) { +static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try { if (!g_sycl_loaded) return false; if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) { @@ -3510,6 +3472,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg } return true; +} catch (sycl::exception & e) { + std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); } GGML_API void ggml_backend_sycl_get_device_description(int device, char *description, diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 6146a99edbe77..009b42035d026 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -82,10 +82,9 @@ static void im2col_sycl( } } -void ggml_sycl_op_im2col( - ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { +void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -115,12 +114,8 @@ void ggml_sycl_op_im2col( const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 if (dst->type == GGML_TYPE_F16) { - im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + im2col_sycl((const float *) src1->data, (sycl::half *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream()); } else { - im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + im2col_sycl((const float *) src1->data, (float *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream()); } - - GGML_UNUSED(src0); - GGML_UNUSED(src0_dd); - GGML_UNUSED(ctx); } diff --git a/ggml/src/ggml-sycl/im2col.hpp b/ggml/src/ggml-sycl/im2col.hpp index 7db144fbbe524..dbbb248ddb4fc 100644 --- a/ggml/src/ggml-sycl/im2col.hpp +++ b/ggml/src/ggml-sycl/im2col.hpp @@ -16,8 +16,6 @@ #include "common.hpp" void ggml_sycl_op_im2col( - ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream); + ggml_backend_sycl_context & ctx, ggml_tensor *dst); #endif // GGML_SYCL_IM2COL_HPP diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index d9678da8f042e..4e9f438b46ba6 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -367,7 +367,7 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE); }); @@ -389,7 +389,7 @@ static void l2_norm_f32_sycl(const float* x, 
float* dst, const int ncols, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); @@ -397,90 +397,78 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, } } -void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1, - ggml_tensor* dst, const float* src0_dd, - const float* src1_dd, float* dst_dd, - const queue_ptr& main_stream) { +void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); float eps; memcpy(&eps, dst->op_params, sizeof(float)); norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); - - (void)src1; - (void)dst; - (void)src1_dd; } -void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, - const ggml_tensor* src1, ggml_tensor* dst, - const float* src0_dd, const float* src1_dd, - float* dst_dd, - const queue_ptr& main_stream) { +void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); int num_groups = dst->op_params[0]; + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); float eps; memcpy(&eps, dst->op_params + 1, sizeof(float)); - int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); - group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device); - - (void)src1; - (void)dst; - (void)src1_dd; - GGML_UNUSED(ctx); + int group_size = dst->src[0]->ne[0] * dst->src[0]->ne[1] * ((dst->src[0]->ne[2] + num_groups - 1) / num_groups); + group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device); } -void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, - const ggml_tensor* src1, ggml_tensor* dst, - const float* src0_dd, const float* src1_dd, - float* dst_dd, - const queue_ptr& main_stream) { +void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); float eps; memcpy(&eps, dst->op_params,
sizeof(float)); rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); - - (void)src1; - (void)dst; - (void)src1_dd; } -void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, - const ggml_tensor* src1, ggml_tensor* dst, - const float* src0_dd, const float* src1_dd, - float* dst_dd, - const queue_ptr& main_stream) { +void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); float eps; memcpy(&eps, dst->op_params, sizeof(float)); l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); - (void)src1; - (void)dst; - (void)src1_dd; } diff --git a/ggml/src/ggml-sycl/norm.hpp b/ggml/src/ggml-sycl/norm.hpp index 11e91680cc496..612cd67cf9183 100644 --- a/ggml/src/ggml-sycl/norm.hpp +++ b/ggml/src/ggml-sycl/norm.hpp @@ -15,27 +15,12 @@ #include "common.hpp" -void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1, - ggml_tensor* dst, const float* src0_dd, - const float* src1_dd, float* dst_dd, - const queue_ptr& main_stream); - -void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, - const ggml_tensor* src1, ggml_tensor* dst, - const float* src0_dd, const float* src1_dd, - float* dst_dd, - const queue_ptr& main_stream); - -void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, - const ggml_tensor* src1, ggml_tensor* dst, - const float* src0_dd, const float* src1_dd, - float* dst_dd, - const queue_ptr& main_stream); - -void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, - const ggml_tensor* src1, ggml_tensor* dst, - const float* src0_dd, const float* src1_dd, - float* dst_dd, - const queue_ptr& main_stream); +void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + +void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + +void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + +void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); #endif // GGML_SYCL_NORM_HPP diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index 8e8347ff4f95e..b60415784f32d 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -1,8 +1,5 @@ -#include <sycl/sycl.hpp> -#include <oneapi/mkl.hpp> #include "outprod.hpp" - void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { const ggml_tensor *src0 = dst->src[0]; const ggml_tensor *src1 = dst->src[1]; @@ -34,20 +31,13 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { // Handle transposition of src1 const bool src1_T = ggml_is_transposed(src1); - const oneapi::mkl::transpose src1_op = - src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans; + const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans; const int64_t ldb = (src1_T ?
nb10 : nb11) / sizeof(float); try { - // Perform matrix multiplication using oneMKL GEMM -#ifdef GGML_SYCL_NVIDIA - oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, - oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, - ne00, src1_d, ldb, beta, dst_d, ne0); -#else - oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, - src0_d, ne00, src1_d, ldb, beta, dst_d, ne0); -#endif + // Perform matrix multiplication using oneMath GEMM + oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op, + ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0); } catch (sycl::exception const& exc) { std::cerr << exc.what() << std::endl; diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index 1244b231af738..bbcb356e97992 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -192,18 +192,15 @@ static void rope_neox_sycl( } } -void ggml_sycl_op_rope( - ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { - const ggml_tensor * src2 = dst->src[2]; +void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(dst->src[0]->type == dst->type); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t nr = ggml_nrows(src0); + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t ne01 = dst->src[0]->ne[1]; + const int64_t nr = ggml_nrows(dst->src[0]); //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -228,49 +225,47 @@ void ggml_sycl_op_rope( const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - const int32_t * pos = (const int32_t *) src1_dd; + const int32_t * pos = (const int32_t *) dst->src[1]->data; const float * freq_factors = nullptr; - if (src2 != nullptr) { - freq_factors = (const float *) src2->data; + if (dst->src[2] != nullptr) { + freq_factors = (const float *) dst->src[2]->data; } rope_corr_dims corr_dims; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + // compute if (is_neox) { - if (src0->type == GGML_TYPE_F32) { + if (dst->src[0]->type == GGML_TYPE_F32) { rope_neox_sycl( - (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream ); - } else if (src0->type == GGML_TYPE_F16) { + } else if (dst->src[0]->type == GGML_TYPE_F16) { rope_neox_sycl( - (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream ); } else { GGML_ABORT("fatal error"); } } else { - if (src0->type ==
GGML_TYPE_F32) { + if (dst->src[0]->type == GGML_TYPE_F32) { rope_norm_sycl( - (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream ); - } else if (src0->type == GGML_TYPE_F16) { + } else if (dst->src[0]->type == GGML_TYPE_F16) { rope_norm_sycl( - (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream ); } else { GGML_ABORT("fatal error"); } } - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_dd); - GGML_UNUSED(ctx); } diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp index 00354c3131bd7..a399bddb8a07b 100644 --- a/ggml/src/ggml-sycl/rope.hpp +++ b/ggml/src/ggml-sycl/rope.hpp @@ -15,8 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_rope( - ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream); +void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); #endif // GGML_SYCL_ROPE_HPP diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 22df7328eeff1..bcf3706edee64 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -38,9 +38,11 @@ if (Vulkan_FOUND) ERROR_VARIABLE glslc_error) if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") - message( "GL_KHR_cooperative_matrix not supported by glslc") + message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF) else() - message( "GL_KHR_cooperative_matrix supported by glslc") + message(STATUS "GL_KHR_cooperative_matrix supported by glslc") + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON) add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) endif() @@ -52,12 +54,30 @@ if (Vulkan_FOUND) ERROR_VARIABLE glslc_error) if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") - message( "GL_NV_cooperative_matrix2 not supported by glslc") + message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF) else() - message( "GL_NV_cooperative_matrix2 supported by glslc") + message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) endif() + # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported. + # If it's not, there will be an error to stderr. 
+ # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*") + message(STATUS "GL_EXT_integer_dot_product not supported by glslc") + set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF) + else() + message(STATUS "GL_EXT_integer_dot_product supported by glslc") + set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + endif() + target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -122,13 +142,16 @@ if (Vulkan_FOUND) include(ExternalProject) # Native build through ExternalProject_Add ExternalProject_Add( - vulkan-shaders-gen - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders - CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} - -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} - BUILD_COMMAND ${CMAKE_COMMAND} --build . - INSTALL_COMMAND ${CMAKE_COMMAND} --install . - INSTALL_DIR ${CMAKE_BINARY_DIR} + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} + -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} + -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} + -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT} + BUILD_COMMAND ${CMAKE_COMMAND} --build . + INSTALL_COMMAND ${CMAKE_COMMAND} --install . 
+ INSTALL_DIR ${CMAKE_BINARY_DIR} ) ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) endif() diff --git a/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in index b6af747a500ed..2d8a85696d374 100644 --- a/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +++ b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in @@ -4,8 +4,8 @@ set(CMAKE_CXX_FLAGS -O2) set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) -set(CMAKE_C_COMPILER @HOST_C_COMPILER@) -set(CMAKE_CXX_COMPILER @HOST_CXX_COMPILER@) +set(CMAKE_C_COMPILER "@HOST_C_COMPILER@") +set(CMAKE_CXX_COMPILER "@HOST_CXX_COMPILER@") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@) if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC") diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f0f3cae97ffe3..ad2a0f57b3da3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -24,6 +24,28 @@ #include #include +#if defined(_MSC_VER) +# define NOMINMAX 1 +# include <windows.h> +# define YIELD() YieldProcessor() +#elif defined(__clang__) || defined(__GNUC__) +# if defined(__x86_64__) || defined(__i386__) +# include <immintrin.h> +# define YIELD() _mm_pause() +# elif defined(__arm__) || defined(__aarch64__) +# if defined(__clang__) +# include <arm_acle.h> +# define YIELD() __yield() +# else +# define YIELD() asm volatile("yield") +# endif +# endif +#endif + +#if !defined(YIELD) +#define YIELD() +#endif + #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -31,6 +53,7 @@ #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_AMD 0x1002 #define VK_VENDOR_ID_APPLE 0x106b @@ -234,6 +257,8 @@ struct vk_device_struct { bool float_controls_rte_fp16; bool subgroup_add; + bool integer_dot_product; + bool subgroup_size_control; uint32_t subgroup_min_size; uint32_t subgroup_max_size; @@ -245,6 +270,12 @@ struct vk_device_struct { uint32_t coopmat_m; uint32_t coopmat_n; uint32_t coopmat_k; + + bool coopmat_int_support; + uint32_t coopmat_int_m; + uint32_t coopmat_int_n; + uint32_t coopmat_int_k; + bool coopmat2; size_t idx; @@ -263,10 +294,10 @@ struct vk_device_struct { vk_matmul_pipeline pipeline_matmul_f32_f16 {}; vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16_f32; - vk_pipeline pipeline_matmul_split_k_reduce; - vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT]; vk_matmul_pipeline pipeline_matmul_id_f32 {}; vk_matmul_pipeline2 pipeline_matmul_id_f16; @@ -274,6 +305,9 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + vk_pipeline pipeline_matmul_split_k_reduce; + vk_pipeline pipeline_quantize_q8_1; + vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; @@ -341,6 +375,7 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
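The `YIELD()` macro added above gives the fence-polling loop introduced a few hunks later a portable CPU pause hint. A self-contained sketch of the pattern it enables; `cpu_relax` and `spin_until` are illustrative names, not identifiers from llama.cpp:

```cpp
// Busy-wait on a flag, issuing a pause/yield hint each iteration so the
// polling core backs off (less power, less contention with an SMT sibling).
#include <atomic>

#if defined(__x86_64__) || defined(__i386__)
#  include <immintrin.h>
static inline void cpu_relax() { _mm_pause(); }
#elif defined(__aarch64__) || defined(__arm__)
static inline void cpu_relax() { asm volatile("yield"); }
#else
static inline void cpu_relax() { }
#endif

static void spin_until(const std::atomic<bool> & done) {
    while (!done.load(std::memory_order_acquire)) {
        // batch the hints, like the run of YIELD()s in ggml_vk_wait_for_fence,
        // so the atomic load stays off the hot path
        for (int i = 0; i < 100; ++i) {
            cpu_relax();
        }
    }
}
```

The surrounding change pairs this with a two-phase wait: `waitForFences` blocks (letting the CPU sleep) until the new `almost_ready_fence` signals near the end of the graph, and only the short tail is spin-waited via `getFenceStatus`.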
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_split_k_reduce; std::unordered_map<std::string, vk_pipeline> pipelines; std::unordered_map<std::string, uint32_t> pipeline_descriptor_set_requirements; @@ -490,6 +525,10 @@ struct vk_flash_attn_push_constants { uint32_t n_head_log2; float m0; float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; }; struct vk_op_push_constants { @@ -640,6 +679,13 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t H; }; +struct vk_op_upscale_push_constants { + uint32_t ne; uint32_t a_offset; uint32_t d_offset; + uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; + float sf0; float sf1; float sf2; float sf3; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -649,13 +695,6 @@ struct vk_staging_memcpy { size_t n; }; -struct vk_op_upscale_push_constants { - uint32_t ne; uint32_t a_offset; uint32_t d_offset; - uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; - uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; - float sf0; float sf1; float sf2; float sf3; -}; - struct vk_context_struct { vk_submission * s; std::vector<vk_sequence> seqs; @@ -770,7 +809,8 @@ struct ggml_backend_vk_context { ggml_vk_garbage_collector gc; size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; vk_buffer prealloc_x, prealloc_y, prealloc_split_k; - vk::Fence fence; + vk::Fence fence, almost_ready_fence; + bool almost_ready_fence_pending {}; vk_buffer buffer_pool[MAX_VK_BUFFERS]; @@ -861,6 +901,39 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_backend_vk_free(ggml_backend_t backend); +// Wait for ctx->fence to be signaled. +static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { + // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep + // during this wait. + if (ctx->almost_ready_fence_pending) { + VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence"); + ctx->device->device.resetFences({ ctx->almost_ready_fence }); + ctx->almost_ready_fence_pending = false; + } + + // Spin (w/pause) waiting for the graph to finish executing.
+ vk::Result result; + while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) { + if (result != vk::Result::eNotReady) { + fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__); + exit(1); + } + for (uint32_t i = 0; i < 100; ++i) { + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + } + } + ctx->device->device.resetFences({ ctx->fence }); +} + // variables to track number of compiles in progress static uint32_t compile_count = 0; static std::mutex compile_count_mutex; @@ -1462,7 +1535,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ // small rows, large cols if (small_rows) { - return {flash_attention_num_small_rows, 128}; + return {flash_attention_num_small_rows, 64}; } // small cols to reduce register count if (ggml_is_quantized(type) || D == 256) { @@ -1598,6 +1671,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // mulmat std::vector<uint32_t> l_warptile, m_warptile, s_warptile, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms, @@ -1662,6 +1736,10 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; @@ -1755,6 +1833,8 @@ static void ggml_vk_load_shaders(vk_device& device) { // can't use 256 for D==80. uint32_t wg_size = (small_rows && (D % 32) == 0) ?
256 : 128; auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; }; @@ -2000,6 +2080,14 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -2031,6 +2119,16 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, 
vk_mat_mat_push_constants, 4, _id); @@ -2056,6 +2154,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM2 +#undef CREATE_MMQ #undef CREATE_MM } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} @@ -2073,6 +2172,14 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); @@ -2099,6 +2206,16 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , 
mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); @@ -2132,7 +2249,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t rm_stdq = 1; uint32_t rm_kq = 2; if (device->vendor_id == VK_VENDOR_ID_AMD) { - if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN + if (device->architecture == AMD_GCN) { rm_stdq = 2; rm_kq = 4; } @@ -2266,6 +2383,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { if (device->subgroup_add && device->subgroup_require_full_support) { @@ -2452,6 +2571,7 @@ static vk_device ggml_vk_get_device(size_t idx) { bool pipeline_robustness = false; bool coopmat2_support = false; device->coopmat_support = false; + device->integer_dot_product = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -2477,6 +2597,11 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + device->integer_dot_product = true; +#endif } } @@ -2490,6 +2615,7 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceVulkan11Properties vk11_props; vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; props2.pNext = &props3; props3.pNext = &subgroup_props; @@ -2524,6 +2650,11 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; @@ -2570,6 +2701,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->coopmat_support = false; } + 
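As context for the property chaining above and the gating check that follows: VK_KHR_shader_integer_dot_product reports per-format acceleration flags through the standard Vulkan pNext chain. A minimal standalone sketch of the query (a hypothetical helper, not code from this patch; it assumes the extension was already found via enumerateDeviceExtensionProperties, as the surrounding loop does):

```cpp
#include <vulkan/vulkan.hpp>

// Returns true if the device accelerates the packed signed 4x8-bit dot
// product that the q8_1 mmq shaders rely on. Sketch only: availability of
// VK_KHR_shader_integer_dot_product must be confirmed beforehand.
static bool supports_packed_s8_dot(vk::PhysicalDevice physical_device) {
    vk::PhysicalDeviceProperties2 props2;
    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR dot_props;
    props2.pNext = &dot_props;                 // chain the extension struct
    physical_device.getProperties2(&props2);   // fills the whole chain
    return dot_props.integerDotProduct4x8BitPackedSignedAccelerated == VK_TRUE;
}
```

Only when that flag is set does the mmq path pay off, which is why the device's integer_dot_product flag is ANDed with it right after the query.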
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues @@ -2661,7 +2794,15 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct = (VkBaseOutStructure *)&maint4_features; device_extensions.push_back("VK_KHR_maintenance4"); } - + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + device_extensions.push_back("VK_KHR_shader_integer_dot_product"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; @@ -2831,6 +2972,17 @@ static vk_device ggml_vk_get_device(size_t idx) { device->coopmat_acc_f16_support = true; } } + } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup && + device->coopmat_int_m == 0 + ) { + device->coopmat_int_support = true; + device->coopmat_int_m = prop.MSize; + device->coopmat_int_n = prop.NSize; + device->coopmat_int_k = prop.KSize; } } @@ -2935,25 +3087,11 @@ static void ggml_vk_print_gpu_info(size_t idx) { vk::PhysicalDevice physical_device = devices[dev_num]; std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); - vk::PhysicalDeviceProperties2 props2; - vk::PhysicalDeviceMaintenance3Properties props3; - vk::PhysicalDeviceSubgroupProperties subgroup_props; - vk::PhysicalDeviceDriverProperties driver_props; - props2.pNext = &props3; - props3.pNext = &subgroup_props; - subgroup_props.pNext = &driver_props; - physical_device.getProperties2(&props2); - - vk_device_architecture arch = get_device_architecture(physical_device); - uint32_t default_subgroup_size = get_subgroup_size("", arch); - const size_t subgroup_size = (default_subgroup_size != 0) ? 
default_subgroup_size : subgroup_props.subgroupSize; - - const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; - bool fp16_storage = false; bool fp16_compute = false; bool coopmat_support = false; bool coopmat2_support = false; + bool integer_dot_product = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -2969,27 +3107,44 @@ static void ggml_vk_print_gpu_info(size_t idx) { } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#endif +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + integer_dot_product = true; #endif } } const vk_device_architecture device_architecture = get_device_architecture(physical_device); - if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) { - coopmat_support = false; - } - const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; - vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures(); + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + + // Pointer to the last chain element + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props; + + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + + physical_device.getProperties2(&props2); VkPhysicalDeviceFeatures2 device_features2; device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; device_features2.pNext = nullptr; - device_features2.features = (VkPhysicalDeviceFeatures)device_features; VkPhysicalDeviceVulkan11Features vk11_features; vk11_features.pNext = nullptr; @@ -3002,7 +3157,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { vk11_features.pNext = &vk12_features; // Pointer to the last chain element - VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; + last_struct = (VkBaseOutStructure *)&vk12_features; #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; @@ -3014,20 +3169,39 @@ static void ggml_vk_print_gpu_info(size_t idx) { last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; last_struct = (VkBaseOutStructure *)&coopmat_features; } +#endif + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + } vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; - coopmat_support = coopmat_support && 
coopmat_features.cooperativeMatrix; + uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); + const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; + const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + + integer_dot_product = integer_dot_product + && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated + && shader_integer_dot_product_features.shaderIntegerDotProduct; + + coopmat_support = coopmat_support +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + && coopmat_features.cooperativeMatrix #endif + && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n", + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, - props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str()); + props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); @@ -3229,6 +3403,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->prealloc_size_split_k = 0; ctx->fence = ctx->device->device.createFence({}); + ctx->almost_ready_fence = ctx->device->device.createFence({}); #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); @@ -3293,6 +3468,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte } } + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc; + + if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + return nullptr; + } + + return pipelines; + } + if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { return nullptr; } @@ -3585,8 +3771,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo return s; } - - static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array elements) { const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); @@ -4016,8 +4200,8 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int return split_k; } -static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + 
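The shape-based {s,m,l} selection performed by ggml_vk_guess_matmul_pipeline amounts to: small tiles for small matrices to keep more workgroups resident, large tiles for large matrices to amortize global loads. An illustrative sketch using the backend's own pipeline types (the exact thresholds and per-type availability checks live in the function body; these cutoffs are assumptions for illustration):

```cpp
// Illustrative tile-size selection; the real code also verifies that the
// requested variant was actually compiled for the given type.
static vk_pipeline pick_variant(vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned) {
    if (m <= 32 || n <= 32) {
        return aligned ? mmp->a_s : mmp->s;   // small tile
    }
    if (m <= 64 || n <= 64) {
        return aligned ? mmp->a_m : mmp->m;   // medium tile
    }
    return aligned ? mmp->a_l : mmp->l;       // large tile
}
```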
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); if (ctx->device->coopmat2) { // Use large shader when the N dimension is greater than the medium shader's tile size @@ -4042,9 +4226,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); - return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align; +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; } static void ggml_vk_matmul( @@ -4054,7 +4238,7 @@ static void ggml_vk_matmul( uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t padded_n) { - VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")"); + VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? 
split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); ggml_vk_sync_buffers(subctx); if (split_k == 1) { const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; @@ -4072,7 +4256,7 @@ static void ggml_vk_matmul( ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); if (ctx->device->coopmat2) { @@ -4214,6 +4398,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); } +static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { + switch(type) { + case GGML_TYPE_Q8_1: + return ctx->device->pipeline_quantize_q8_1; + default: + std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl; + GGML_ABORT("fatal error"); + } +} + +static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) { + VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")"); + + vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 }); +} + static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -4265,10 +4468,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - 
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + + // Check for mmq first + vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; + + if (mmp == nullptr) { + // Fall back to f16 dequant mul mat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + quantize_y = false; + } const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig); if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat @@ -4278,13 +4490,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); - const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking - uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; const int x_ne = ne01 * ne00; const int y_ne = padded_n * ne10; const int d_ne = ne11 * ne01; @@ -4294,11 +4506,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; - const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? 
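The quantize_y decision above can be read in isolation: the on-the-fly q8_1 path is only attempted when the device advertises accelerated integer dot products and src1 is contiguous f32 whose element count is divisible by 4 (the quantize shader consumes elements in groups of four). A sketch of the predicate, using the same tensor fields as the code above:

```cpp
// Sketch of the mmq eligibility check; the caller falls back to the f16
// dequant mul_mat path when any condition fails or no q8_1 pipeline exists.
static bool can_quantize_y(const ggml_tensor * src1, bool integer_dot_product) {
    return integer_dot_product
        && src1->type == GGML_TYPE_F32
        && ggml_is_contiguous(src1)
        && (src1->ne[1] * src1->ne[0]) % 4 == 0;
}
```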
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t d_sz = sizeof(float) * d_ne; vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; + vk_pipeline to_q8_1 = nullptr; if (x_non_contig) { to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); @@ -4313,6 +4526,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + } + if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t y_sz_upd = y_sz * ne12 * ne13; @@ -4326,7 +4543,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ctx->prealloc_size_x = x_sz_upd; } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { ctx->prealloc_size_y = y_sz_upd; } if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { @@ -4341,6 +4558,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1); + } if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); } @@ -4376,6 +4596,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (qy_needs_dequant) { d_Y = ctx->prealloc_y; GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -4392,6 +4615,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (y_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); } + if (quantize_y) { + ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13); + } uint32_t stride_batch_x = ne00*ne01; uint32_t stride_batch_y = ne10*ne11; @@ -4400,7 +4626,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); } - if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } @@ -5232,7 +5458,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t nbm1 = mask ? 
mask->nb[1] : 0; const uint32_t D = neq0; - const uint32_t N = neq1; + uint32_t N = neq1; const uint32_t KV = nek1; GGML_ASSERT(ne0 == D); @@ -5287,12 +5513,60 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); + vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows && + qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. + gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + + uint32_t split_kv = KV; + uint32_t split_k = 1; + + if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) { + GGML_ASSERT(workgroups_x == 1); + // Try to run two workgroups per SM. + split_k = ctx->device->shader_core_count * 2 / workgroups_y; + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); + split_k = CEIL_DIV(KV, split_kv); + workgroups_x = split_k; + } + } + + // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). + const uint64_t split_k_size = split_k > 1 ? 
(D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; + if (split_k_size > ctx->device->max_memory_allocation_size) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + if (dryrun) { // Request descriptor sets ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + } return; } @@ -5313,8 +5587,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_sync_buffers(subctx); - vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; @@ -5379,16 +5651,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, - mask != nullptr, n_head_log2, m0, m1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - { - vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, - }, - sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); + mask != nullptr, n_head_log2, m0, m1, + gqa_ratio, split_kv, split_k }; + + ggml_vk_sync_buffers(subctx); + + if (split_k > 1) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + }, + // We only use split_k when group query attention is enabled, which means + // there's no more than one tile of rows (i.e. workgroups_x would have been + // one). We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. 
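The reduce dispatch issued just below implements the usual log-sum-exp recombination for split attention: each split stored an unnormalized O tile plus per-row (L, m), and the final row is sum_k exp(m_k - m_max) * O_k divided by sum_k exp(m_k - m_max) * L_k. A scalar C++ restatement of the per-row math (mirrors flash_attn_split_k_reduce.comp, added further down in this patch):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Combine k_num partial results for one row of the output.
// O[k] holds that split's unnormalized D-wide output; m[k]/l[k] are its
// running row max and row sum from the online softmax.
static std::vector<float> combine_splits(const std::vector<std::vector<float>> & O,
                                         const std::vector<float> & m,
                                         const std::vector<float> & l) {
    const size_t k_num = O.size();
    const size_t D = O[0].size();

    float m_max = -INFINITY;
    for (size_t k = 0; k < k_num; ++k) m_max = std::max(m_max, m[k]);

    float L = 0.0f;
    for (size_t k = 0; k < k_num; ++k) L += std::exp(m[k] - m_max) * l[k];

    std::vector<float> out(D, 0.0f);
    for (size_t d = 0; d < D; ++d) {
        for (size_t k = 0; k < k_num; ++k) {
            out[d] += std::exp(m[k] - m_max) * O[k][d];
        }
        out[d] /= L;   // the final division by L that the splits skipped
    }
    return out;
}
```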
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + + ggml_vk_sync_buffers(subctx); + const std::array pc2 = { D, (uint32_t)ne1, split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, + { + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 }); + } else { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z }); + } } static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { @@ -6929,6 +7230,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + ggml_pipeline_allocate_descriptor_sets(ctx->device); vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -7177,6 +7482,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + ggml_pipeline_allocate_descriptor_sets(ctx->device); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); @@ -7236,66 +7545,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ free(x_chk); } -static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) { +// This does not work without ggml q8_1 quantization support +// +// typedef uint16_t ggml_half; +// typedef uint32_t ggml_half2; +// +// #define QK8_1 32 +// typedef struct { +// union { +// struct { +// ggml_half d; // delta +// ggml_half s; // d * sum(qs[i]) +// } GGML_COMMON_AGGR_S; +// ggml_half2 ds; +// } GGML_COMMON_AGGR_U; +// int8_t qs[QK8_1]; // quants +// } block_q8_1; +// +// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { +// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")"); +// GGML_ASSERT(quant == GGML_TYPE_Q8_1); +// +// const size_t x_sz = sizeof(float) * ne; +// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); +// float * x = (float *) malloc(x_sz); +// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz); +// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz); +// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// +// for (size_t i = 0; i < ne; i++) { +// x[i] = rand() / (float)RAND_MAX; +// } +// +// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); +// +// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); +// +// if (ctx->device->need_compiles) { +// ggml_vk_load_shaders(ctx->device); +// } +// +// ggml_pipeline_allocate_descriptor_sets(ctx->device); 
+// +// ggml_vk_buffer_write(x_buf, 0, x, x_sz); +// +// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); +// ggml_vk_ctx_begin(ctx->device, subctx); +// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne); +// ggml_vk_ctx_end(subctx); +// +// auto begin = std::chrono::high_resolution_clock::now(); +// +// ggml_vk_submit(subctx, ctx->fence); +// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); +// ctx->device->device.resetFences({ ctx->fence }); +// +// auto end = std::chrono::high_resolution_clock::now(); +// +// double ms_quant = std::chrono::duration_cast(end-begin).count() / 1000.0; +// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz); +// +// ggml_vk_quantize_data(x, qx_res, ne, quant); +// +// int first_err = -1; +// +// for (size_t i = 0; i < ne / 32; i++) { +// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// for (size_t j = 0; j < 32; j++) { +// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]); +// +// if (first_err < 0 && error > 1) { +// first_err = i; +// } +// } +// } +// +// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl; +// +// if (first_err != -1) { +// std::cerr << "first_error = " << first_err << std::endl; +// std::cerr << "Actual result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " "; +// } +// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " "; +// } +// std::cerr << std::endl; +// } +// +// ggml_vk_destroy_buffer(x_buf); +// ggml_vk_destroy_buffer(qx_buf); +// +// free(x); +// free(qx); +// free(qx_res); +// } + +static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) { VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); const size_t x_ne = m * k * batch; const size_t y_ne = k * n * batch; const size_t d_ne = m * n * batch; + vk_matmul_pipeline2 * pipelines; + + if (mmq) { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1; + } else { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat; + } + + const bool fp16acc = ctx->device->fp16; + vk_pipeline p; std::string shname; if (shader_size == 0) { - p = ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s; + p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; } else if (shader_size == 1) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m; + p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; } else if (shader_size == 2) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l; + p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; } else { GGML_ASSERT(0); } - const size_t kpad = ggml_vk_align_size(k, p->align); + const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align); - if (k != kpad) { + if (mmq || k != kpad) { if (shader_size == 0) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s; + p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s; shname = std::string(ggml_type_name(quant)) + "_S"; } else if (shader_size == 1) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m; + p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m; shname = std::string(ggml_type_name(quant)) + "_M"; } else if (shader_size == 2) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l; + p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l; shname = std::string(ggml_type_name(quant)) + "_L"; } else { GGML_ASSERT(0); } } + if (p == nullptr) { + std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl; + return; + } + const size_t x_sz = sizeof(float) * x_ne; const size_t y_sz = sizeof(float) * y_ne; const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); + const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz; const size_t d_sz = sizeof(float) * d_ne; float * x = (float *) malloc(x_sz); float * y = (float *) malloc(y_sz); void * qx = malloc(qx_sz); vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); float * d = (float *) malloc(d_sz); float * d_chk = (float *) malloc(d_sz); for (size_t i = 0; i < x_ne; i++) { x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + // x[i] = i % k; } ggml_vk_quantize_data(x, qx, x_ne, quant); for (size_t i = 0; i < y_ne; i++) { - // y[i] = rand() / (float)RAND_MAX; - y[i] = (i % k == i / k) ? 1.0f : 0.0f; + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 
1.0f : 0.0f; + // y[i] = i % k; } ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); @@ -7310,6 +7751,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); } } + if (mmq) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it); + } + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } ggml_pipeline_allocate_descriptor_sets(ctx->device); @@ -7318,13 +7766,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ggml_vk_ctx_begin(ctx->device, subctx); - for (size_t i = 0; i < num_it; i++) { - ggml_vk_matmul( - ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), - m, n, k, - k, k, m, k*m, k*n, m*n, - split_k, batch, batch, batch, 1, 1, n - ); + if (mmq) { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne); + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + } else { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } } ggml_vk_ctx_end(subctx); @@ -7382,7 +7842,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); - std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + std::cerr << "TEST dequant matmul " << shname; + if (mmq) { + std::cerr << " mmq"; + } + std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; if (avg_err > 0.01 || std::isnan(avg_err)) { std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; @@ -7392,6 +7856,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, std::cerr << "Expected result: " << std::endl << std::endl; ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << "src0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "src1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b); + if (split_k > 1) { float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); @@ -7414,6 +7884,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ggml_vk_destroy_buffer(qx_buf); 
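For reference alongside this test harness: the q8_1 block produced by ggml_vk_quantize_q8_1 stores a scale d and the pre-scaled sum s = d * sum(qs), which is exactly what the commented-out quantize test above compares. A minimal CPU sketch of one 32-wide block, using plain floats where ggml stores fp16:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

constexpr int QK8_1 = 32;

struct block_q8_1_ref {
    float  d;             // delta (fp16 in ggml)
    float  s;             // d * sum(qs) (fp16 in ggml)
    int8_t qs[QK8_1];     // quants
};

// Sketch of the reference quantization the GPU result is checked against.
static block_q8_1_ref quantize_block_q8_1(const float * x) {
    float amax = 0.0f;
    for (int j = 0; j < QK8_1; ++j) amax = std::max(amax, std::fabs(x[j]));
    const float d  = amax / 127.0f;
    const float id = d != 0.0f ? 1.0f / d : 0.0f;

    block_q8_1_ref b{d, 0.0f, {}};
    int sum = 0;
    for (int j = 0; j < QK8_1; ++j) {
        b.qs[j] = (int8_t)std::lround(x[j] * id);
        sum    += b.qs[j];
    }
    b.s = d * (float)sum;   // lets the kernel fold in the quantization offset cheaply
    return b;
}
```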
ggml_vk_destroy_buffer(y_buf); + ggml_vk_destroy_buffer(qy_buf); ggml_vk_destroy_buffer(d_buf); free(x); @@ -7448,6 +7919,24 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { }; const size_t num_it = 100; + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true); + + abort(); + + for (size_t i = 0; i < vals.size(); i += 3) { ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); @@ -7522,11 +8011,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence); +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true, all operations queued so far are submitted to Vulkan to overlap cmdlist creation and GPU execution.
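The almost_ready flag added to these declarations feeds the batching policy in ggml_backend_vk_graph_compute further down: near the end of the graph one batch is submitted with a second fence, so the CPU can observe "mostly done" without waiting for the final nodes. A sketch of the resulting submit decision (restates logic from the hunks below):

```cpp
#include <cstdint>

// Submit a batch when enough nodes or matmul bytes have accumulated, at the
// last node, or once when < 20% of the graph remains so the almost_ready
// fence gets signaled early.
static bool should_submit(int i, int n_nodes, int last_node,
                          int submitted_nodes, int nodes_per_submit,
                          uint64_t mul_mat_bytes, uint64_t mul_mat_bytes_per_submit,
                          bool almost_ready_fence_pending) {
    const bool almost_ready = (n_nodes - i) < n_nodes / 5;
    return submitted_nodes >= nodes_per_submit
        || mul_mat_bytes >= mul_mat_bytes_per_submit
        || i == last_node
        || (almost_ready && !almost_ready_fence_pending);
}
```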
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){ +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ if (ggml_is_empty(node) || !node->buffer) { return false; } @@ -7898,7 +8387,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false); + bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready); if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; @@ -7912,7 +8401,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){ +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { ggml_backend_buffer * buf = nullptr; switch (tensor->op) { @@ -8015,12 +8504,15 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) { + ggml_vk_submit(subctx, ctx->almost_ready_fence); + ctx->almost_ready_fence_pending = true; + } else { + ggml_vk_submit(subctx, use_fence ? 
ctx->fence : vk::Fence{}); + } if (use_fence) { - VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); - - ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_wait_for_fence(ctx); } #ifdef GGML_VULKAN_CHECK_RESULTS ggml_vk_check_results_1(tensor); @@ -8106,6 +8598,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->gc.events.clear(); ctx->device->device.destroyFence(ctx->fence); + ctx->device->device.destroyFence(ctx->almost_ready_fence); } static int ggml_vk_get_device_count() { @@ -8452,8 +8945,7 @@ static void ggml_backend_vk_synchronize(ggml_backend_t backend) { } ggml_vk_submit(transfer_ctx, ctx->fence); - VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences"); - ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_wait_for_fence(ctx); for (auto& cpy : transfer_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); @@ -8472,7 +8964,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } @@ -8514,11 +9006,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) + bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) || (mul_mat_bytes >= mul_mat_bytes_per_submit) || - (i == last_node); + (i == last_node) || + (almost_ready && !ctx->almost_ready_fence_pending); - bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); + bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); if (enqueued) { ++submitted_nodes; @@ -8530,7 +9025,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg #endif } - if (submit) { + if (submit && enqueued) { first_node_in_batch = true; submitted_nodes = 0; mul_mat_bytes = 0; @@ -8764,6 +9259,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } if (op->src[0]->type != GGML_TYPE_F32) { return false; } @@ -9254,7 +9753,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { - const float *params = (const float *)tensor->op_params; + const float * params = (const float *)tensor->op_params; tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); } else if (tensor->op == GGML_OP_MUL_MAT) { tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); @@ -9271,7 +9770,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_UPSCALE) { tensor_clone = ggml_upscale_ext(ggml_ctx, 
src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_SCALE) { - tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]); + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { @@ -9279,7 +9779,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_COS) { tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { - tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_PAD) { tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); } else if (tensor->op == GGML_OP_REPEAT) { @@ -9293,7 +9794,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { - tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); + const float * float_params = (const float *)tensor->op_params; + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { @@ -9306,14 +9808,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); } else if (tensor->op == GGML_OP_SOFT_MAX) { if (src1 != nullptr) { - tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]); } else { tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); } } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { - tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params); + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]); } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index 72e63bc70fcea..d6e0b2a5a5dd6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -1,5 +1,19 @@ +cmake_minimum_required(VERSION 3.19) +project("vulkan-shaders-gen" C CXX) + find_package (Threads REQUIRED) -add_executable(vulkan-shaders-gen vulkan-shaders-gen.cpp) 
-target_compile_features(vulkan-shaders-gen PRIVATE cxx_std_17) +if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) +endif() +if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +endif() +if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) +endif() +set(TARGET vulkan-shaders-gen) +add_executable(${TARGET} vulkan-shaders-gen.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_compile_features(${TARGET} PRIVATE cxx_std_17) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index df30355f635b8..8ddadb8a15de5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -61,6 +61,10 @@ layout (push_constant) uniform parameter { uint32_t n_head_log2; float m0; float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; } p; layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; @@ -103,6 +107,38 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele #define DECODEFUNC #endif +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c < D) { + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r & (p.gqa_ratio - 1)); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -111,12 +147,22 @@ void main() { const uint32_t N = p.N; const uint32_t KV = p.KV; + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + const uint32_t Tr = CEIL_DIV(N, Br); - const uint32_t Tc = CEIL_DIV(KV, Bc); - const uint32_t i = gl_WorkGroupID.x; + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - const uint32_t iq2 = gl_WorkGroupID.y; + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. 
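Since each workgroup now covers gqa_ratio consecutive heads, the ALiBi slope also becomes per-row. perElemOpComputeSlope above maps row r to head h = iq2 + (r & (gqa_ratio - 1)); a scalar C++ restatement (the dispatch guarantees gqa_ratio is a power of two, so the mask is exact):

```cpp
#include <cmath>
#include <cstdint>

// Per-row ALiBi slope, mirroring perElemOpComputeSlope in the shader.
static float alibi_slope(uint32_t r, uint32_t iq2, uint32_t gqa_ratio,
                         uint32_t n_head_log2, float m0, float m1) {
    const uint32_t h    = iq2 + (r & (gqa_ratio - 1));  // head for this row
    const float    base = h < n_head_log2 ? m0 : m1;
    const int      exph = h < n_head_log2 ? int(h) + 1
                                          : 2 * int(h - n_head_log2) + 1;
    return std::pow(base, float(exph));
}
```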
+ const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; const uint32_t iq3 = gl_WorkGroupID.z; // broadcast factors @@ -149,8 +195,10 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - // nb?1 are already divided by the type size and are in units of elements - uint32_t q_stride = p.nb01; + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; uint32_t k_stride = p.nb11; uint32_t v_stride = p.nb21; // hint to the compiler that strides are aligned for the aligned variant of the shader @@ -179,23 +227,21 @@ void main() { coopmat L, M; + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + L = coopmat(0); - M = coopmat(-1.0/0.0); + M = coopmat(NEG_FLT_MAX_OVER_2); - ACC_TYPE slope = ACC_TYPE(1.0); + coopmat slopeMat = coopmat(1.0); // ALiBi if (p.max_bias > 0.0f) { - const uint32_t h = iq2; - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - slope = pow(base, ACC_TYPE(exph)); + coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } [[dont_unroll]] - for (uint32_t j = 0; j < Tc; ++j) { + for (uint32_t j = start_j; j < end_j; ++j) { coopmat S = coopmat(0); @@ -213,14 +259,18 @@ void main() { } if (p.mask != 0) { - tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + // When using grouped query attention, all rows use the same mask. + if (p.gqa_ratio > 1) { + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1); + } coopmat mv; coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - S += slope*coopmat(mv); + S += slopeMat*coopmat(mv); } // Clear padding elements to -inf, so they don't contribute to rowmax @@ -231,7 +281,7 @@ void main() { uint R = ((i + 1) * Br > N) ? (N % Br) : Br; uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; - coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C); } coopmat rowmax, P, rowsum, eM; @@ -285,6 +335,20 @@ void main() { O = coopMatMulAdd(P_A, V, O); } + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. 
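+ // The reduce shader recombines the k_num partial results as + // O = (sum_k exp(m_k - m_max) * O_k) / (sum_k exp(m_k - m_max) * L_k), + // matching the math in flash_attn_split_k_reduce.comp below.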
+ if (p.k_num > 1) { + coopmat O_D = coopmat(O); + + uint32_t o_offset = D * p.ne1 * split_k_index; + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + return; + } + coopmat Ldiag; // resize L by using smear/reduce @@ -297,13 +361,18 @@ void main() { O = Ldiag*O; - tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); - - // permute dimensions - tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); uint32_t o_offset = iq3*p.ne2*p.ne1; coopmat O_D = coopmat(O); - coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); + if (p.gqa_ratio > 1) { + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + } else { + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); + } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp new file mode 100644 index 0000000000000..a7e3956854c44 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -0,0 +1,59 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#define BLOCK_SIZE 32 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; + +layout (push_constant) uniform parameter { + uint D; + uint N; + uint k_num; +} p; + +void main() { + // Each workgroup handles a row + const uint n = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + uint D = p.D; + uint N = p.N; + uint k_num = p.k_num; + + uint l_offset = D * N * k_num + n; + uint m_offset = D * N * k_num + N + n; + uint lm_stride = N * 2; + + // Compute the max m value for the row + float m_max = -1.0/0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float m = data_a[m_offset + k * lm_stride]; + m_max = max(m_max, m); + } + + // Compute L based on m_max + float L = 0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float l = data_a[l_offset + k * lm_stride]; + float m = data_a[m_offset + k * lm_stride]; + L += exp(m - m_max) * l; + } + + L = 1.0 / L; + + // Scale and sum the O contributions based on m_max and store the result to memory + for (uint d = tid; d < D; d += BLOCK_SIZE) { + float O = 0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + uint o_offset = D * N * k + D * n + d; + float m = data_a[m_offset + k * lm_stride]; + O += exp(m - m_max) * data_a[o_offset]; + } + O *= L; + data_d[D * n + d] = O; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 5a0054bac336c..23ce8ceec332b 100644 --- 
a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -212,7 +212,7 @@ void main() { #else ACC_TYPE sums[WMITER * TM * WNITER * TN]; FLOAT_TYPE cache_a[WMITER * TM]; - FLOAT_TYPE cache_b[WNITER * TN]; + FLOAT_TYPE cache_b[TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { sums[i] = ACC_TYPE(0.0f); @@ -744,16 +744,14 @@ void main() { } [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint j = 0; j < TN; j++) { - cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; + cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; } - } - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) { const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp new file mode 100644 index 0000000000000..284a35caa68ad --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -0,0 +1,442 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#extension GL_EXT_integer_dot_product : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif +layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +// layout (constant_id = 3) const uint BK = 32; +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#define BK 32 + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK / 4 + 4) 
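+// Note: the row stride of the shared buffers is padded past BK / 4 ints (by 4 here, by 1 in the non-coopmat case below), presumably so consecutive rows do not map to the same shared memory banks.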
+#else +#define SHMEM_STRIDE (BK / 4 + 1) +#endif + +shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; + +#ifndef COOPMAT +#if QUANT_AUXF == 1 +shared FLOAT_TYPE buf_a_dm[BM]; +#else +shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; +#endif +#endif + +shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; +#ifndef COOPMAT +shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; +#endif + +#define LOAD_VEC_A (4 * QUANT_R) +#define LOAD_VEC_B 4 + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[3072]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +#include "mul_mmq_funcs.comp" + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + + const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK; + const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; + +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + row_ids[_ne1] = u16vec2(ii0, ii1); + _ne1++; + } + } + } + + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a_ib = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / BK; +#ifdef MUL_MAT_ID + uint pos_b_ib = 0; +#else + uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat cm_result; + + coopmat factors[cms_per_row * cms_per_col]; + + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + int32_t cache_a_qs[WMITER * TM * BK / 4]; + + int32_t cache_b_qs[TN * BK / 4]; + + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + + [[unroll]] for 
(uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + +#if QUANT_AUXF == 1 + FLOAT_TYPE cache_a_dm[WMITER * TM]; +#else + FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; +#endif + + FLOAT_TYPE_VEC2 cache_b_ds[TN]; + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { + const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; + const uint iqs = loadr_a; + const uint buf_ib = loadc_a + l; + + if (iqs == 0) { +#if QUANT_AUXF == 1 + buf_a_dm[buf_ib] = get_d(ib); +#else + buf_a_dm[buf_ib] = get_dm(ib); +#endif + } +#if QUANT_R == 1 + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); +#else + const i32vec2 vals = repack(ib, iqs); + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; +#endif + } + [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; + const uint ib = idx / 8; + const uint iqs = idx & 0x7; +#else + const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; + const uint iqs = loadr_b; +#endif + + const uint buf_ib = loadc_b + l; + + if (iqs == 0) { + buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds); + } + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs]; + } + + barrier(); + + pos_a_ib += 1; + pos_b_ib += 1; + +#ifdef COOPMAT + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint ib_a = warp_r * WM + cm_row * TM; + // Load from shared into cache + coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { + cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; + } + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint ib_b = warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { + cache_b_ds[t_idx] = buf_b_ds[ib_b + t_idx]; + } + + cm_result = coopmat(0); + cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); + + // Fill the per-element scale factors from the declared cache_a_dm/cache_b_ds buffers + // (this path assumes QUANT_AUXF == 1 and applies only the d component of the B scales). + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_dm[store_r]) * float(cache_b_ds[store_c + col].x)); + } + + coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); + } + } +#else + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; + cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; + cache_b_ds[cc] = buf_b_ds[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_b_qs[cc * (BK / 4) + idx_k] =
buf_b_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint cache_a_idx = wsir * TM + cr; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + int32_t q_sum = 0; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], + cache_b_qs[cc * (BK / 4) + idx_k]); + } + + sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]); + } + } + } + } +#endif + + barrier(); + } + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < BN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = 
dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp new file mode 100644 index 0000000000000..63b15471bd3aa --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -0,0 +1,99 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.comp" + +// Each iqs value maps to a 32-bit integer + +#if defined(DATA_A_Q4_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +} +#endif + +#if defined(DATA_A_Q5_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +} +#endif + +#if defined(DATA_A_Q8_0) +int32_t repack(uint ib, uint iqs) { + // Use 2-byte loads since a q8_0 block (34 
bytes) is not divisible by 4 + return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1])); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp new file mode 100644 index 0000000000000..e2e020fec2c6a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -0,0 +1,77 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint ne; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint GROUP_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {vec4 data_a[];}; +layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];}; + +shared float shmem[GROUP_SIZE]; + +void quantize() { + const uint wgid = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + // Each thread handles a vec4, so 8 threads handle a block + const uint blocks_per_group = GROUP_SIZE / 8; + + const uint block_in_wg = tid / 8; + + const uint ib = wgid * blocks_per_group + block_in_wg; + const uint iqs = tid % 8; + + if (ib >= gl_NumWorkGroups.x * blocks_per_group) { + return; + } + + const uint a_idx = ib * 8 + iqs; + + vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f); + const vec4 abs_vals = abs(vals); + + // Find absolute max for each block + shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] = max(shmem[tid], shmem[tid + s]); + } + barrier(); + } + + const float amax = shmem[block_in_wg * 8]; + const float d = amax / 127.0; + const float d_inv = d != 0.0 ? 
1.0 / d : 0.0; + vals = round(vals * d_inv); + data_b[ib].qs[iqs] = pack32(i8vec4(round(vals))); + barrier(); + + // Calculate the sum for each block + shmem[tid] = vals.x + vals.y + vals.z + vals.w; + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] += shmem[tid + s]; + } + barrier(); + } + if (iqs == 0) { + const float sum = shmem[tid]; + + data_b[ib].ds = f16vec2(vec2(d, sum * d)); + } +} + +void main() { + quantize(); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp new file mode 100644 index 0000000000000..470e3074d938a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_EXT_integer_dot_product : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 789776816b75a..f5b29bfb13a66 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -1,4 +1,3 @@ - #if !defined(GGML_TYPES_COMP) #define GGML_TYPES_COMP @@ -51,6 +50,7 @@ struct block_q4_0_packed16 #if defined(DATA_A_Q4_0) #define QUANT_K QUANT_K_Q4_0 #define QUANT_R QUANT_R_Q4_0 +#define QUANT_AUXF 1 #define A_TYPE block_q4_0 #define A_TYPE_PACKED16 block_q4_0_packed16 #endif @@ -72,11 +72,19 @@ struct block_q4_1_packed16 uint16_t qs[16/2]; }; +struct block_q4_1_packed32 +{ + f16vec2 dm; + uint32_t qs[16/4]; +}; + #if defined(DATA_A_Q4_1) #define QUANT_K QUANT_K_Q4_1 #define QUANT_R QUANT_R_Q4_1 +#define QUANT_AUXF 2 #define A_TYPE block_q4_1 #define A_TYPE_PACKED16 block_q4_1_packed16 +#define A_TYPE_PACKED32 block_q4_1_packed32 #endif #define QUANT_K_Q5_0 32 @@ -99,6 +107,7 @@ struct block_q5_0_packed16 #if defined(DATA_A_Q5_0) #define QUANT_K QUANT_K_Q5_0 #define QUANT_R QUANT_R_Q5_0 +#define QUANT_AUXF 1 #define A_TYPE block_q5_0 #define A_TYPE_PACKED16 block_q5_0_packed16 #endif @@ -122,11 +131,20 @@ struct block_q5_1_packed16 uint16_t qs[16/2]; }; +struct block_q5_1_packed32 +{ + f16vec2 dm; + uint qh; + uint32_t qs[16/4]; +}; + #if defined(DATA_A_Q5_1) #define QUANT_K QUANT_K_Q5_1 #define QUANT_R QUANT_R_Q5_1 +#define QUANT_AUXF 2 #define A_TYPE block_q5_1 #define A_TYPE_PACKED16 block_q5_1_packed16 +#define A_TYPE_PACKED32 block_q5_1_packed32 #endif #define QUANT_K_Q8_0 32 @@ -142,14 +160,40 @@ struct block_q8_0_packed16 float16_t d; int16_t qs[32/2]; }; +struct block_q8_0_packed32 +{ + float16_t d; + int32_t qs[32/4]; +}; #if defined(DATA_A_Q8_0) #define QUANT_K QUANT_K_Q8_0 #define QUANT_R QUANT_R_Q8_0 +#define QUANT_AUXF 1 #define A_TYPE block_q8_0 #define A_TYPE_PACKED16 block_q8_0_packed16 +#define A_TYPE_PACKED32 block_q8_0_packed32 #endif +#define QUANT_K_Q8_1 32 +#define QUANT_R_Q8_1 1 + +struct block_q8_1 +{ + f16vec2 ds; + int8_t qs[32]; +}; +struct block_q8_1_packed16 +{ + f16vec2 ds; + int16_t qs[16]; +}; +struct block_q8_1_packed32 +{ + f16vec2 ds; + int32_t qs[8]; +}; + // K-quants #define QUANT_K_Q2_K 256 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 1ded4a321f138..d839ed54c3bfc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -296,7 +296,10 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool std::string aligned_b_type_f32 = coopmat2 ? 
"float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; - std::map base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; + std::map base_dict = { + {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}, + {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"}, + }; std::string shader_name = "matmul"; if (matmul_id) { @@ -314,9 +317,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool base_dict["COOPMAT"] = "1"; } - base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; - - std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; + const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; // Shaders with f16 B_TYPE string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); @@ -340,14 +341,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) { + string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", 
merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + } +#endif } } @@ -459,6 +466,8 @@ void process_shaders() { string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 8ebe91892ac99..cc62356b80b1f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1220,6 +1220,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) { } size_t ggml_nbytes(const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) { + return 0; + } + } + size_t nbytes; const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { @@ -4430,7 +4436,7 @@ struct ggml_tensor * ggml_flash_attn_ext( } // permute(0, 2, 1, 3) - int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); float params[] = { scale, max_bias, logit_softcap }; diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index ab13669c567fe..381a9c7dcbe8f 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -932,6 +932,7 @@ static void gguf_check_reserved_keys(const std::string & key, const T val) { if constexpr (std::is_same<T, uint32_t>::value) { GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2"); } else { + GGML_UNUSED(val); GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32"); } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 13cca7ab009bf..3a52cfd1e39ac 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -286,6 +286,8 @@ class MODEL_ARCH(IntEnum): GRANITE_MOE = auto() CHAMELEON = auto() WAVTOKENIZER_DEC = auto() + PLM = auto() + BAILINGMOE = auto() class MODEL_TENSOR(IntEnum): @@ -488,6 +490,8 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.PLM: "plm", + MODEL_ARCH.BAILINGMOE: "bailingmoe", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1464,6 +1468,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, ], + MODEL_ARCH.PLM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + ], MODEL_ARCH.CHATGLM : [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.ROPE_FREQS, @@ -1651,6 +1669,25 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_V, MODEL_TENSOR.POSNET_ATTN_OUT, ], + MODEL_ARCH.BAILINGMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP,
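+ # shared-expert tensors: the shexp weights hold the shared expert that BailingMoE evaluates alongside the routed experts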
MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], # TODO } @@ -1703,6 +1740,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.BAILINGMOE: [ + MODEL_TENSOR.ROPE_FREQS, + ], } # diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8d4a2b0320183..50bef12e3dbe7 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -29,6 +29,7 @@ class TensorNameMap: "shared", # t5 "rwkv.embeddings", # rwkv6 "model.embeddings", # rwkv7 + "model.word_embeddings", # bailingmoe ), # Token type embeddings diff --git a/include/llama.h b/include/llama.h index 8325a62f06da5..4962b7bb2d666 100644 --- a/include/llama.h +++ b/include/llama.h @@ -108,6 +108,8 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, }; enum llama_rope_type { @@ -278,10 +280,18 @@ extern "C" { }; }; + struct llama_model_tensor_buft_override { + const char * pattern; + ggml_backend_buffer_type_t buft; + }; + struct llama_model_params { // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) ggml_backend_dev_t * devices; + // NULL-terminated list of buffer types to use for tensors that match a pattern + const struct llama_model_tensor_buft_override * tensor_buft_overrides; + int32_t n_gpu_layers; // number of layers to store in VRAM enum llama_split_mode split_mode; // how to split the model across multiple GPUs @@ -1265,6 +1275,10 @@ extern "C" { float tau, float eta); + /// @details Initializes a GBNF grammar, see grammars/README.md for details. + /// @param vocab The vocabulary that this grammar will be used with. + /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. + /// @param grammar_root The name of the start symbol for the grammar.
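+ /// @note Minimal usage sketch (illustrative only, not part of the header): + /// struct llama_sampler * smpl = llama_sampler_init_grammar(vocab, "root ::= \"yes\" | \"no\"", "root");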
LLAMA_API struct llama_sampler * llama_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, diff --git a/media/llama1-logo.svg b/media/llama1-logo.svg new file mode 100644 index 0000000000000..e080481fa67c3 --- /dev/null +++ b/media/llama1-logo.svg @@ -0,0 +1,34 @@ +[34 lines of SVG markup stripped during extraction] \ No newline at end of file diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index 33e8c6414300a..914ff7c55356f 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -69,7 +69,11 @@ while read c; do git format-patch -U${ctx} -k $c~1..$c --stdout -- \ CMakeLists.txt \ src/CMakeLists.txt \ - cmake/FindSIMD.cmake \ + cmake/BuildTypes.cmake \ + cmake/GitVars.cmake \ + cmake/common.cmake \ + cmake/ggml-config.cmake.in \ + src/ggml-cpu/cmake/FindSIMD.cmake \ src/ggml*.h \ src/ggml*.c \ src/ggml*.cpp \ @@ -121,7 +125,12 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # # CMakelists.txt -> ggml/CMakeLists.txt # src/CMakeLists.txt -> ggml/src/CMakeLists.txt - # cmake/FindSIMD.cmake -> ggml/cmake/FindSIMD.cmake + + # cmake/BuildTypes.cmake -> ggml/cmake/BuildTypes.cmake + # cmake/GitVars.cmake -> ggml/cmake/GitVars.cmake + # cmake/common.cmake -> ggml/cmake/common.cmake + # cmake/ggml-config.cmake.in -> ggml/cmake/ggml-config.cmake.in + # src/ggml-cpu/cmake/FindSIMD.cmake -> ggml/src/ggml-cpu/cmake/FindSIMD.cmake # # src/ggml*.c -> ggml/src/ggml*.c # src/ggml*.cpp -> ggml/src/ggml*.cpp @@ -151,7 +160,11 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then cat ggml-src.patch | sed -E \ -e 's/(^[[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \ -e 's/(^[[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \ - -e 's/(^[[:space:]]| [ab]\/)cmake\/FindSIMD.cmake/\1ggml\/cmake\/FindSIMD.cmake/g' \ + -e 's/(^[[:space:]]| [ab]\/)cmake\/BuildTypes.cmake/\1ggml\/cmake\/BuildTypes.cmake/g' \ + -e 's/(^[[:space:]]| [ab]\/)cmake\/GitVars.cmake/\1ggml\/cmake\/GitVars.cmake/g' \ + -e 's/(^[[:space:]]| [ab]\/)cmake\/common.cmake/\1ggml\/cmake\/common.cmake/g' \ + -e 's/(^[[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \ + -e 's/(^[[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \ diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index c7944d1d429c6..88b3d0eed7d59 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -c7dfe3d174f98b14801f9ed12f129179d3e7b638 +f06264eda2e2bf6e814db5a32bbf42e0b2b1ed98 diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index e83d415c04657..aa1a46b4bfccd 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -2,7 +2,9 @@ cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt -cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake + +cp -rpv ../ggml/cmake/* ./ggml/cmake/ +cp -rpv ../ggml/src/ggml-cpu/cmake/* ./ggml/src/ggml-cpu/cmake/ cp -rpv ../ggml/src/ggml*.c ./ggml/src/ cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/ diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp index b448614e471d6..7ac54d2391fd0 100644 --- a/src/llama-adapter.cpp +++ b/src/llama-adapter.cpp @@ -247,6 +247,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } } + // get extra
buffer types of the CPU + // TODO: a more general solution for non-CPU extra buft should be implemented in the future + // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948 + std::vector<ggml_backend_buffer_type_t> buft_extra; + { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_extra.emplace_back(*extra_bufts); + ++extra_bufts; + } + } + } + // add tensors for (auto & it : ab_map) { const std::string & name = it.first; @@ -263,7 +283,23 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); } - ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); + auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer); + + // do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case + for (auto & ex : buft_extra) { + if (ex == buft) { + LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); + + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + buft = ggml_backend_dev_buffer_type(cpu_dev); + + break; + } + } + + LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); + + ggml_context * dev_ctx = ctx_for_buft(buft); // validate tensor shape if (is_token_embd) { // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 8664f8963cc18..047782e7d0fc8 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -65,6 +65,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_PLM, "plm" }, + { LLM_ARCH_BAILINGMOE, "bailingmoe" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -73,6 +75,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, { LLM_KV_GENERAL_NAME, "general.name" }, { LLM_KV_GENERAL_AUTHOR, "general.author" }, { LLM_KV_GENERAL_VERSION, "general.version" }, @@ -1043,6 +1046,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, + { + LLM_ARCH_PLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_CHATGLM, { @@ -1392,6 +1411,29 @@
static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, + { + LLM_ARCH_BAILINGMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_UNKNOWN, { diff --git a/src/llama-arch.h b/src/llama-arch.h index a28815d8a14c7..297cfa4dae571 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -69,6 +69,8 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_PLM, + LLM_ARCH_BAILINGMOE, LLM_ARCH_UNKNOWN, }; @@ -77,6 +79,7 @@ enum llm_kv { LLM_KV_GENERAL_ARCHITECTURE, LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_FILE_TYPE, LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_VERSION, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index de44586967d71..dd27a381423df 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -59,6 +59,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, + { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, + { "bailing", LLM_CHAT_TEMPLATE_BAILING }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -168,6 +170,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { return LLM_CHAT_TEMPLATE_MEGREZ; + } else if (tmpl_contains(" Ассистент:")) { + return LLM_CHAT_TEMPLATE_YANDEX; + } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) { + return LLM_CHAT_TEMPLATE_BAILING; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -567,6 +573,41 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|role_start|>assistant<|role_end|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) { + // Yandex template ("\n\n" is defined as EOT token) + + ss << "<s>"; + + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (role == "user") { + ss << " Пользователь: " << chat[i]->content << "\n\n"; + } else if (role == "assistant") { + ss << " Ассистент: " << chat[i]->content << "\n\n"; + } + } + + // Add generation prompt if needed + if (add_ass) { + ss << " Ассистент:[SEP]"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) { + // Bailing (Ling) template + for (auto message : chat) { + std::string role(message->role); + + if (role == "user") { + role = "HUMAN"; + } else { + std::transform(role.begin(), role.end(), role.begin(), ::toupper); + } + + ss << "<role>" << role << "</role>" << message->content; + } + + if (add_ass) { + ss << "<role>ASSISTANT</role>"; + } } else { // template not supported return -1; @@ -585,4 +626,3 @@ int32_t
llama_chat_builtin_templates(const char ** output, size_t len) { } return (int32_t) LLM_CHAT_TEMPLATES.size(); } - diff --git a/src/llama-chat.h b/src/llama-chat.h index 2f6a0e3e28266..0e0bd772c2eac 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -38,6 +38,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_YANDEX, + LLM_CHAT_TEMPLATE_BAILING, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index aa363df6356ea..4735e98ea040f 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -255,7 +255,8 @@ llama_context::llama_context( model.n_devices() > 1 && model.params.n_gpu_layers > (int) model.hparams.n_layer && model.params.split_mode == LLAMA_SPLIT_MODE_LAYER && - cparams.offload_kqv; + cparams.offload_kqv && + !model.has_tensor_overrides(); // pipeline parallelism requires support for async compute and events in all devices if (pipeline_parallel) { @@ -1201,33 +1202,7 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - // TODO: remove this stuff - class batch_guard { - public: - batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) { - } - - ~batch_guard() { - if (!is_done) { - kv_slot_restorer.restore(); - } - } - - void done() { - is_done = true; - } - - void save(const llama_kv_cache_slot_info & slot_info) { - kv_slot_restorer.save(slot_info); - } - - private: - bool is_done = false; - - llama_kv_slot_restorer kv_slot_restorer; - }; - - batch_guard bg(*kv_self); + llama_kv_cache_guard kv_guard(kv_self.get()); GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -1280,6 +1255,9 @@ int llama_context::decode(llama_batch & inp_batch) { return -2; }; + // handle any pending defrags/shifts + kv_self_update(); + int64_t n_outputs_prev = 0; while (sbatch.n_tokens > 0) { @@ -1317,24 +1295,14 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs = n_outputs_new; } - // non-causal masks do not use the KV cache - if (hparams.causal_attn) { - kv_self_update(); - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) { - kv_self->head = 0; - } + // find KV slot + { + if (!kv_self->find_slot(ubatch)) { + LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - const auto slot_info = kv_self->find_slot(ubatch); - if (!slot_info) { - LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); - return -3; + return 1; } - bg.save(slot_info); - if (!kv_self->recurrent) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears @@ -1371,16 +1339,6 @@ int llama_context::decode(llama_batch & inp_batch) { } } - // update the kv ring buffer - { - kv_self->head += ubatch.n_tokens; - - // Ensure kv cache head points to a valid index. 
- if (kv_self->head >= kv_self->size) { - kv_self->head = 0; - } - } - // plot the computation graph in dot format (for debugging purposes) //if (n_past%100 == 0) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); @@ -1467,7 +1425,7 @@ int llama_context::decode(llama_batch & inp_batch) { } // finalize the batch processing - bg.done(); + kv_guard.commit(); // set output mappings { @@ -2316,11 +2274,6 @@ llama_context * llama_init_from_model( params.flash_attn = false; } - if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { - LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); - params.flash_attn = false; - } - if (ggml_is_quantized(params.type_v) && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; @@ -2521,7 +2474,12 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { } int32_t llama_kv_self_n_tokens(const llama_context * ctx) { - return llama_kv_cache_n_tokens(ctx->get_kv_self()); + const auto * kv = ctx->get_kv_self(); + if (!kv) { + return 0; + } + + return kv->get_n_tokens(); } // deprecated @@ -2530,7 +2488,12 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { } int32_t llama_kv_self_used_cells(const llama_context * ctx) { - return llama_kv_cache_used_cells(ctx->get_kv_self()); + const auto * kv = ctx->get_kv_self(); + if (!kv) { + return 0; + } + + return kv->get_used_cells(); } // deprecated @@ -2539,7 +2502,12 @@ void llama_kv_cache_clear(llama_context * ctx) { } void llama_kv_self_clear(llama_context * ctx) { - llama_kv_cache_clear(ctx->get_kv_self()); + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + kv->clear(); } // deprecated @@ -2556,7 +2524,12 @@ bool llama_kv_self_seq_rm( llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1); + auto * kv = ctx->get_kv_self(); + if (!kv) { + return true; + } + + return kv->seq_rm(seq_id, p0, p1); } // deprecated @@ -2575,7 +2548,12 @@ void llama_kv_self_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1); + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); } // deprecated @@ -2586,7 +2564,12 @@ void llama_kv_cache_seq_keep( } void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id); + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->seq_keep(seq_id); } // deprecated @@ -2605,7 +2588,12 @@ void llama_kv_self_seq_add( llama_pos p0, llama_pos p1, llama_pos delta) { - return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta); + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->seq_add(seq_id, p0, p1, delta); } // deprecated @@ -2624,7 +2612,12 @@ void llama_kv_self_seq_div( llama_pos p0, llama_pos p1, int d) { - return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d); + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->seq_div(seq_id, p0, p1, d); } // deprecated @@ -2633,7 +2626,12 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { } llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id); + const auto * kv = ctx->get_kv_self(); + if 
(!kv) { + return 0; + } + + return kv->seq_pos_max(seq_id); } // deprecated @@ -2642,7 +2640,12 @@ void llama_kv_cache_defrag(llama_context * ctx) { } void llama_kv_self_defrag(llama_context * ctx) { - llama_kv_cache_defrag(ctx->get_kv_self()); + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->defrag(); } // deprecated @@ -2651,7 +2654,12 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) { } bool llama_kv_self_can_shift(const llama_context * ctx) { - return llama_kv_cache_can_shift(ctx->get_kv_self()); + const auto * kv = ctx->get_kv_self(); + if (!kv) { + return false; + } + + return kv->get_can_shift(); } // deprecated diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0bd40174438cc..cec203df49268 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -402,120 +402,86 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { if (self_kq_mask || self_kq_mask_swa) { - // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. - if (cparams.causal_attn) { - const int64_t n_kv = kv_self->n; - const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; - const int64_t n_seqs = ubatch->n_seqs; - - float * data = nullptr; - float * data_swa = nullptr; - - if (self_kq_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - data = (float *) self_kq_mask->data; - } + const int64_t n_kv = kv_self->n; + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; - if (self_kq_mask_swa) { - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - data_swa = (float *) self_kq_mask_swa->data; - } + float * data = nullptr; + float * data_swa = nullptr; - // For causal attention, use only the previous KV cells - // of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - for (int h = 0; h < 1; ++h) { - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[s][0]; + if (self_kq_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + data = (float *) self_kq_mask->data; + } - for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos pos = ubatch->pos[s*n_seq_tokens + j]; + if (self_kq_mask_swa) { + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); + data_swa = (float *) self_kq_mask_swa->data; + } - for (int i = 0; i < n_kv; ++i) { - float f; - if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) { - f = -INFINITY; + // Use only the previous KV cells of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
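+ // This loop covers both the causal and non-causal cases; the only causal-specific + // part is the future-token check (kv_self->cells[i].pos > pos) guarded by cparams.causal_attn.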
+        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+        //   Causal mask:
+        //      xxx-------
+        //      xxxx------
+        //      xxxxx-----
+        //   Non-causal mask:
+        //      xxxxx-----
+        //      xxxxx-----
+        //      xxxxx-----
+        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+        for (int h = 0; h < 1; ++h) {
+            for (int s = 0; s < n_seqs; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        // mask the token if:
+                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
+                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
+                        ) {
+                            f = -INFINITY;
+                        } else {
+                            if (hparams.use_alibi) {
+                                f = -std::abs(kv_self->cells[i].pos - pos);
                             } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self->cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
-                            }
-
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                                f = 0.0f;
                             }
+                        }
 
-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
-                    }
-                }
 
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
                     }
                 }
+            }
 
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
-        } else {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride = n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-            float * data = (float *) self_kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data_swa) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
+                }
+            }
         }
     }
 }
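The merged loop above derives both the causal and the non-causal mask from a single rule. A standalone sketch of just that rule, assuming no ALiBi and no sliding window (`cell_in_seq` stands in for `kv_self->cells[i].has_seq_id(seq_id)`):

```cpp
#include <cmath>

// returns the additive mask value: -INFINITY hides the KV cell, 0.0f keeps it
float kq_mask_value(bool causal_attn, bool cell_in_seq, int cell_pos, int tok_pos) {
    // mask the cell if it belongs to another sequence, or - for causal
    // attention only - if it lies in the future relative to the token
    if (!cell_in_seq || (causal_attn && cell_pos > tok_pos)) {
        return -INFINITY;
    }
    return 0.0f;
}
```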
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 14c8933b4d6c4..dbf5f1187d9e5 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -11,8 +11,6 @@
 #include
 #include
 
-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
@@ -133,7 +131,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
     return result;
 }
 
-uint32_t llama_kv_cache_unified::get_used_cells() const {
+int32_t llama_kv_cache_unified::get_used_cells() const {
     return used;
 }
 
@@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
                 return false;
             }
         }
+
+        return true;
     }
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -428,7 +428,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     }
 }
 
-llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
+llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     llama_pos result = 0;
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -446,16 +446,71 @@ void llama_kv_cache_unified::defrag() {
     }
 }
 
+void llama_kv_cache_unified::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        seq_rm(-1, -1, -1);
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    for (auto & range : pending.ranges) {
+        for (uint32_t i = range.c0; i < range.c1; ++i) {
+            cells[i].seq_id.clear();
+
+            // keep count of the number of used cells
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+        }
+
+        new_head = std::min(new_head, range.c0);
+    }
+
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::commit() {
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        return;
+    }
+
+    if (pending.ranges.empty()) {
+        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+        return;
+    }
+
+    pending.ranges.clear();
+}
+
 bool llama_kv_cache_unified::get_can_shift() const {
     return can_shift;
 }
 
-llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+bool llama_kv_cache_unified::find_slot(
        const llama_ubatch & ubatch) {
     const uint32_t n_tokens     = ubatch.n_tokens;
     const uint32_t n_seqs       = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*ubatch.n_tokens) {
+        head = 0;
+    }
+
     if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
@@ -477,7 +532,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
                 LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-                return llama_kv_cache_slot_info_failed;
+                return false;
             }
             if (j > 0) {
                 llama_kv_cell & seq = cells[seq_id];
@@ -616,14 +671,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
             [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return llama_kv_cache_slot_info(n >= n_seqs);
+        return n >= n_seqs;
     }
 
     // otherwise, one cell per token.
     if (n_tokens > size) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
-        return llama_kv_cache_slot_info_failed;
+        return false;
     }
 
     uint32_t n_tested = 0;
@@ -651,7 +706,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
         if (n_tested >= size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
     }
 
@@ -668,7 +723,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
     used += n_tokens;
 
-    return llama_kv_cache_slot_info(head, head + n_tokens);
+    pending.ranges.push_back({head, head + n_tokens});
+
+    return true;
 }
 
 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
@@ -1033,6 +1090,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
+        commit();
 
         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
@@ -1220,117 +1278,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
     return true;
 }
 
-//
-// interface implementation
-//
-
-int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
-    if (!kv) {
-        return 0;
-    }
-
-    return kv->get_n_tokens();
-}
-
-int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
-    if (!kv) {
-        return 0;
-    }
-
-    return kv->get_used_cells();
-}
-
-void llama_kv_cache_clear(llama_kv_cache * kv) {
-    if (!kv) {
-        return;
-    }
-
-    kv->clear();
-}
-
-bool llama_kv_cache_seq_rm(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1) {
-    if (!kv) {
-        return true;
-    }
-
-    return kv->seq_rm(seq_id, p0, p1);
-}
-
-void llama_kv_cache_seq_cp(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id_src,
-          llama_seq_id   seq_id_dst,
-             llama_pos   p0,
-             llama_pos   p1) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-}
-
-void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_keep(seq_id);
-}
-
-void llama_kv_cache_seq_add(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1,
-             llama_pos   delta) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_add(seq_id, p0, p1, delta);
-}
-
-void llama_kv_cache_seq_div(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1,
-                   int   d) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_div(seq_id, p0, p1, d);
-}
-
-llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
-    if (!kv) {
-        return 0;
-    }
-
-    return kv->seq_pos_max(seq_id);
-}
-
-void llama_kv_cache_defrag(llama_kv_cache * kv) {
-    if (!kv) {
-        return;
-    }
-
-    kv->defrag();
-}
-
-bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
-    if (!kv) {
-        return false;
-    }
-
-    return kv->get_can_shift();
-}
-
 //
 // kv cache view
 //
@@ -1340,7 +1287,7 @@ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
         /*.n_cells            = */ 0,
         /*.n_seq_max          = */ n_seq_max,
         /*.token_count        = */ 0,
-        /*.used_cells         = */ llama_kv_cache_used_cells(&kv),
+        /*.used_cells         = */ kv.get_used_cells(),
         /*.max_contiguous     = */ 0,
         /*.max_contiguous_idx = */ -1,
         /*.cells              = */ nullptr,
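`find_slot()` now reports success as a plain `bool` and records each allocated cell range in `pending.ranges`; `restore()` walks those ranges to empty the cells again, and `commit()` simply forgets them. A simplified sketch of that bookkeeping (toy cache: no sequence sets, no ring-buffer wrap-around):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct slot_range { uint32_t c0 = 0; uint32_t c1 = 0; }; // cell indices, [c0, c1)

struct toy_cache {
    std::vector<int32_t>    pos;     // -1 == empty cell
    std::vector<slot_range> pending; // ranges written since the last commit
    uint32_t head = 0;
    uint32_t used = 0;

    bool find_slot(uint32_t n_tokens) {
        if (head + n_tokens > pos.size()) {
            return false; // the real cache searches for a gap and can wrap around
        }
        pending.push_back({head, head + n_tokens});
        for (uint32_t i = head; i < head + n_tokens; ++i) pos[i] = 0;
        used += n_tokens;
        head += n_tokens;
        return true;
    }

    void commit() { pending.clear(); } // make the pending writes permanent

    void restore() { // undo everything allocated since the last commit
        for (const auto & r : pending) {
            for (uint32_t i = r.c0; i < r.c1; ++i) { pos[i] = -1; --used; }
            head = std::min(head, r.c0);
        }
        pending.clear();
    }
};
```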
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 0a7ff8a4ea3e6..56c74035ae1b9 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -17,17 +17,35 @@ struct llama_ubatch;
 
 struct llama_kv_cache : public llama_memory_i {
     using llama_memory_i::llama_memory_i;
 
-    virtual int32_t  get_n_tokens()   const = 0;
-    virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
+    virtual void restore() = 0; // call if batch processing fails - restores the cache state
+    virtual void commit()  = 0; // call after successful batch processing - clears any pending state
+
+    virtual int32_t get_n_tokens()   const = 0;
+    virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
 
     virtual bool get_can_shift() const = 0;
 
     bool get_can_edit() const override { return get_can_shift(); }
 };
 
+struct llama_kv_cache_guard {
+    llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
+
+    ~llama_kv_cache_guard() {
+        kv->restore();
+    }
+
+    void commit() {
+        kv->commit();
+    }
+
+private:
+    llama_kv_cache * kv;
+};
+
 struct llama_kv_cell {
     llama_pos pos   = -1;
-    llama_pos delta = 0;
+    llama_pos delta =  0;
 
     int32_t   src   = -1; // used by recurrent state models to copy states
     int32_t   tail  = -1;
@@ -46,17 +64,6 @@ struct llama_kv_cell {
     }
 };
 
-// a structure holds information about the slot found in llama_kv_cache_find_slot
-struct llama_kv_cache_slot_info {
-    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
-    bool found = false;                       // the slot was found
-
-    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
-    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
-
-    operator bool() const { return found; }
-};
-
 // ring-buffer of cached KV data
 // TODO: pimpl
 // TODO: add notion of max sequences
@@ -82,8 +89,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
             uint32_t kv_size,
             bool     offload);
 
-    int32_t  get_n_tokens()   const override;
-    uint32_t get_used_cells() const override;
+    int32_t get_n_tokens()   const override;
+    int32_t get_used_cells() const override;
 
     size_t total_size() const;
 
@@ -93,22 +100,24 @@ class llama_kv_cache_unified : public llama_kv_cache {
     void clear() override;
     void defrag() override;
 
+    virtual void restore() override;
+    virtual void commit()  override;
+
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
     void seq_keep(llama_seq_id seq_id) override;
     void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
     void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
 
-    llama_pos seq_pos_max(llama_seq_id seq_id) override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
     bool get_can_shift() const override;
 
     // find an empty slot of size "n_tokens" in the cache
     // updates the cache head
-    // returns a structure holding information about the slot found
     // Note: On success, it's important that cache.head points
     // to the first cell of the slot.
-    llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
+    bool find_slot(const llama_ubatch & batch);
 
     // TODO: maybe not needed
     uint32_t get_padding(const llama_cparams & cparams) const;
@@ -128,7 +137,19 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // return true if cells have been moved
     bool defrag_prepare(int32_t n_max_nodes);
 
-    // state save/load
+    // commit/restore cache
+
+    struct slot_range {
+        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
+        uint32_t c1 = 0;
+    };
+
+    // pending cell updates that are not yet committed
+    struct {
+        std::vector<slot_range> ranges;
+    } pending;
+
+    // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
     void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
@@ -183,101 +204,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
 //    using llama_kv_cache_unified::llama_kv_cache_unified;
 //};
 
-//
-// kv cache restore
-//
-
-// saves the kv_cache state for future recovery.
-// used to rollback llama_kv_cache_find_slot changes.
-struct llama_kv_slot_restorer {
-    struct llama_kv_cache_state {
-        uint32_t head = 0;
-        uint32_t n    = 0;
-    } old_state;
-
-    // for non-recurrent models only
-    // list of slots to restore
-    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
-
-    bool do_restore = false;
-
-    llama_kv_cache_unified & cache;
-
-    explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
-        old_state.head = cache.head;
-        old_state.n    = cache.n;
-    }
-
-    // saves a slot information for future restoration
-    void save(const llama_kv_cache_slot_info & slot) {
-        if (slot) {
-            do_restore = true;
-            if (slot.boundaries.first != slot.boundaries.second) {
-                slot_boundaries.push_back(slot.boundaries);
-            }
-        }
-    }
-
-    // must be explicitly called to restore the kv_cache state
-    // and rollback changes from all llama_kv_cache_find_slot calls
-    void restore() {
-        if (do_restore) {
-            cache.head = old_state.head;
-            cache.n    = old_state.n;
-
-            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
-                cache.seq_rm(-1, -1, -1);
-            } else {
-                for (auto & slot : slot_boundaries) {
-                    cache.seq_rm(-1, slot.first, slot.second);
-                }
-            }
-        }
-    }
-};
-
-// TODO: maybe become part of the public llama_kv_cache in the future
-int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
-
-int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
-
-void llama_kv_cache_clear(llama_kv_cache * kv);
-
-bool llama_kv_cache_seq_rm(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1);
-
-void llama_kv_cache_seq_cp(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id_src,
-          llama_seq_id   seq_id_dst,
-             llama_pos   p0,
-             llama_pos   p1);
-
-void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
-
-void llama_kv_cache_seq_add(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1,
-             llama_pos   delta);
-
-void llama_kv_cache_seq_div(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1,
-                   int   d);
-
-llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
-
-void llama_kv_cache_defrag(llama_kv_cache * kv);
-
-bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
-
 //
 // kv cache view
 //
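`llama_kv_cache_guard` replaces the removed `llama_kv_slot_restorer`: the destructor unconditionally calls `restore()`, and because a successful `commit()` clears `pending.ranges`, that final `restore()` degenerates into a no-op. A sketch of the intended call pattern (hypothetical caller; the real one is `llama_context::decode()` above):

```cpp
// assumes the llama_kv_cache interface declared in this header
int decode_sketch(llama_kv_cache * kv_self) {
    llama_kv_cache_guard kv_guard(kv_self);

    // ... find_slot() per ubatch, build and compute the graph ...
    // any early return unwinds kv_guard, whose destructor calls restore()
    // and erases the partially-allocated cells

    kv_guard.commit(); // success: pending ranges cleared, restore() becomes a no-op
    return 0;
}
```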
diff --git a/src/llama-memory.h b/src/llama-memory.h
index 69e6e34ca4516..dfa8c4e90fc2a 100644
--- a/src/llama-memory.h
+++ b/src/llama-memory.h
@@ -15,7 +15,7 @@ class llama_memory_i {
     virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
     virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
 
-    virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
+    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
 
     virtual bool get_can_edit() const = 0;
 };
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 05d58ad90eba9..ea73a8a7ba944 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader(
         std::vector<std::string> & splits,
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p) {
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
        trace = atoi(getenv("LLAMA_TRACE"));
@@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader(
         }
     }
 
+    tensor_buft_overrides = param_tensor_buft_overrides_p;
+
     // Load the main GGUF
     struct ggml_context * ctx = NULL;
     struct gguf_init_params params = {
@@ -600,7 +603,9 @@ llama_model_loader::llama_model_loader(
 
             if (trace > 0) {
                 const uint16_t sid = w.idx;
-                LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
+                LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ] %8.2f MiB\n", __func__,
+                        sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str(),
+                        ggml_nbytes(tensor)/1024.0f/1024.0f);
             }
         }
 
@@ -640,9 +645,9 @@ llama_model_loader::llama_model_loader(
         ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
 
         {
-            const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV
-            if (kid >= 0) {
-                ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid);
+            uint32_t ftype_val = 0;
+            if (get_key(LLM_KV_GENERAL_FILE_TYPE, ftype_val, false)) {
+                ftype = (llama_ftype) ftype_val;
             }
         }
 
diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h
index fe35404b26889..0f52b011b6986 100644
--- a/src/llama-model-loader.h
+++ b/src/llama-model-loader.h
@@ -77,8 +77,9 @@ struct llama_model_loader {
 
     llama_mmaps mappings;
 
-    std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
-    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
+    std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
+    std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
+    const llama_model_tensor_buft_override * tensor_buft_overrides;
 
     gguf_context_ptr meta;
     std::vector<ggml_context_ptr> contexts;
@@ -95,7 +96,8 @@ struct llama_model_loader {
         std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p);
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
 
     template<typename T>
     typename std::enable_if<std::is_integral<T>::value, bool>::type
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0ae754154b069..ca6e3ab2caeb1 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include <regex>
 #include
 #include
 
@@ -47,6 +48,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_1_4B:           return "1.4B";
         case LLM_TYPE_1_5B:           return "1.5B";
         case LLM_TYPE_1_6B:           return "1.6B";
+        case LLM_TYPE_1_8B:           return "1.8B";
         case LLM_TYPE_2B:             return "2B";
         case LLM_TYPE_2_8B:           return "2.8B";
         case LLM_TYPE_2_9B:           return "2.9B";
@@ -87,6 +89,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_10B_128x3_66B:  return "10B+128x3.66B";
         case LLM_TYPE_57B_A14B:       return "57B.A14B";
         case LLM_TYPE_27B:            return "27B";
+        case LLM_TYPE_290B:           return "290B";
         default:                      return "?B";
     }
 }
@@ -255,7 +258,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) {
     return nullptr;
 }
 
-// CPU: ACCEL -> CPU extra -> GPU host -> CPU
+// CPU: ACCEL -> GPU host -> CPU extra -> CPU
 static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices) {
     buft_list_t buft_list;
 
@@ -271,32 +274,6 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices) {
         }
     }
 
-    bool has_gpu_device = false;
-    for (auto * dev : devices) {
-        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
-            has_gpu_device = true;
-            break;
-        }
-    }
-
-    // add extra buffer types, only if no GPU device is present
-    // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
-    if (!has_gpu_device) {
-        auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-        auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
-        auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
-            ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
-        if (ggml_backend_dev_get_extra_bufts_fn) {
-            ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
-            while (extra_bufts && *extra_bufts) {
-                buft_list.emplace_back(cpu_dev, *extra_bufts);
-                ++extra_bufts;
-            }
-        }
-    } else {
-        LLAMA_LOG_WARN("%s: disabling extra buffer types (i.e. repacking) since a GPU device is available\n", __func__);
-    }
-
     // add a host buffer type
     // storing the tensors in a host buffer is useful when the processing of large batches
     // is offloaded to a GPU device, since it reduces the time spent on data transfers
@@ -311,6 +288,20 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices) {
         }
     }
 
+    // add extra buffer types
+    // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
+    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+    auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+    if (ggml_backend_dev_get_extra_bufts_fn) {
+        ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
+        while (extra_bufts && *extra_bufts) {
+            buft_list.emplace_back(cpu_dev, *extra_bufts);
+            ++extra_bufts;
+        }
+    }
+
     // add the CPU buffer type
     for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
         ggml_backend_dev_t dev = ggml_backend_dev_get(i);
@@ -388,9 +379,12 @@ struct llama_model::impl {
     layer_dev dev_input = {};
     layer_dev dev_output = {};
     std::vector<layer_dev> dev_layer;
+
+    bool has_tensor_overrides;
 };
 
 llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
+    pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
 }
 
 llama_model::~llama_model() {}
@@ -1144,6 +1138,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_PLM:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
+                switch (hparams.n_layer) {
+                    case 32: type = LLM_TYPE_1_8B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_CHATGLM:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -1330,6 +1333,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
             } break;
+        case LLM_ARCH_BAILINGMOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,   hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,         hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,        hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,         hparams.expert_weights_norm, false);
+
+                switch (hparams.n_layer) {
+                    case 28: type = LLM_TYPE_16B; break;
+                    case 88: type = LLM_TYPE_290B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -1557,9 +1575,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
         }
 
-        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+        ggml_backend_buffer_type_t buft = nullptr;
+
+        // check overrides
+        if (ml.tensor_buft_overrides) {
+            std::string tensor_name = tn.str();
+            for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
+                std::regex pattern(overrides->pattern);
+                if (std::regex_search(tensor_name, pattern)) {
+                    LLAMA_LOG_DEBUG("tensor %s buffer type overridden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
+                    buft = overrides->buft;
+                    break;
+                }
+            }
+        }
+
         if (!buft) {
-            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+            buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+            if (!buft) {
+                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+            }
         }
 
         // avoid using a host buffer when using mmap
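The override check above is first-match-wins over a null-terminated pattern table, with `select_weight_buft()` as the fallback. Roughly, as a standalone sketch (an `int` stands in for `ggml_backend_buffer_type_t`, and the table mirrors how `llama_model_tensor_buft_override` arrays are terminated):

```cpp
#include <regex>
#include <string>

struct buft_override {
    const char * pattern; // nullptr terminates the table
    int          buft;    // stand-in for ggml_backend_buffer_type_t
};

int select_buft_with_overrides(const std::string & tensor_name,
                               const buft_override * overrides,
                               int fallback_buft) {
    for (const auto * o = overrides; o && o->pattern != nullptr; ++o) {
        if (std::regex_search(tensor_name, std::regex(o->pattern))) {
            return o->buft; // first matching pattern wins
        }
    }
    return fallback_buft; // no override: use the normal selection logic
}
```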
@@ -3068,6 +3103,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 }
             } break;
+        case LLM_ARCH_PLM:
+            {
+                const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+                const int64_t kv_lora_rank = hparams.n_lora_kv;
+
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+                output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq             = create_tensor(tn(LLM_TENSOR_ATTN_Q,         "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                    layer.wkv_a_mqa      = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA,  "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
+                    layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
+                    layer.wkv_b          = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,      "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                    layer.wo             = create_tensor(tn(LLM_TENSOR_ATTN_OUT,       "weight", i), {n_head * (n_embd_head_v), n_embd}, 0);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                }
+            } break;
         case LLM_ARCH_BITNET:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -3712,6 +3776,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
                 output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {n_embd}, 0);
             } break;
+        case LLM_ARCH_BAILINGMOE:
+            {
+                const int64_t n_ff_exp        = hparams.n_ff_exp;
+                const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_head * n_rot}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                    if (n_expert == 0) {
+                        throw std::runtime_error("n_expert must be > 0");
+                    }
+                    if (n_expert_used == 0) {
+                        throw std::runtime_error("n_expert_used must be > 0");
+                    }
+
+                    layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                    layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                    layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                    layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                    layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_exp * n_expert_shared, n_embd}, 0);
+                    layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                }
+            } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
@@ -3999,6 +4103,14 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
     }
 
+    if (arch == LLM_ARCH_BAILINGMOE) {
+        LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",   __func__, hparams.n_layer_dense_lead);
+        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",   __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",   __func__, hparams.n_expert_shared);
+        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+        LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n",   __func__, hparams.expert_weights_norm);
+    }
+
     vocab.print_info();
 }
 
@@ -4060,6 +4172,10 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
         });
 }
 
+bool llama_model::has_tensor_overrides() const {
+    return pimpl->has_tensor_overrides;
+}
+
 const ggml_tensor * llama_model::get_tensor(const char * name) const {
     auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
             [name](const std::pair<std::string, ggml_tensor *> & it) {
@@ -6284,7 +6400,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                         il);
-            cb(cur, "ffn_moe_out", il);
+            cb(moe_out, "ffn_moe_out", il);
 
             // FFN shared expert
             {
@@ -11615,6 +11731,322 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context {
     }
 };
 
+struct llm_build_plm : public llm_graph_context {
+    llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k));
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                ggml_tensor * q = NULL;
+                q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(q, "q", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        kv_pe_compresseed->nb[1],
+                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+
+                kv_compressed = build_norm(kv_compressed,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                        0);
+                cb(v_states, "v_states", il);
+
+                q_pe = ggml_rope_ext(
+                        ctx0, q_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(q_pe, "q_pe", il);
+
+                // shared RoPE key
+                k_pe = ggml_rope_ext(
+                        ctx0, k_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(k_pe, "k_pe", il);
+
+                ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
+                ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        q_states, k_states, v_states, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    NULL,                      NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_bailingmoe : public llm_graph_context {
+    llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * moe_out =
+                build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, hparams.expert_weights_norm,
+                        false, hparams.expert_weights_scale,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // FFN shared expert
+            {
+                ggml_tensor * ffn_shexp = build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory() const {
     llama_memory_i * res;
 
@@ -11846,10 +12278,11 @@ llm_graph_result_ptr llama_model::build_graph(
                     GGML_ABORT("invalid graph type");
             };
         } break;
-    //case LLM_ARCH_T5ENCODER:
-    //    {
-    //        llm.build_t5_enc(gf);
-    //    } break;
+        case LLM_ARCH_T5ENCODER:
+            {
+                llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
+            }
+            break;
         case LLM_ARCH_JAIS:
             {
                 llm = std::make_unique<llm_build_jais>(*this, params, gf);
@@ -11886,6 +12319,14 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
             } break;
+        case LLM_ARCH_PLM:
+            {
+                llm = std::make_unique<llm_build_plm>(*this, params, gf);
+            } break;
+        case LLM_ARCH_BAILINGMOE:
+            {
+                llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
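In the `llm_build_bailingmoe` graph above, the routed experts (`build_moe_ffn`) and the always-active shared expert (`build_ffn`) both consume the same normalized input, and their outputs are summed before the residual add. The same data flow reduced to scalars (toy stand-ins, not ggml calls):

```cpp
// toy scalar stand-ins for the tensor ops; the 0.5f/0.25f factors are arbitrary
float bailing_ffn_block(float x_norm, float ffn_inp) {
    const float moe_out   = 0.50f * x_norm; // routed top-k experts ("ffn_moe_out")
    const float ffn_shexp = 0.25f * x_norm; // shared expert ("ffn_shexp")
    return (moe_out + ffn_shexp) + ffn_inp; // "ffn_out" plus the residual
}
```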
@@ -11903,6 +12344,7 @@ llm_graph_result_ptr llama_model::build_graph(
 llama_model_params llama_model_default_params() {
     llama_model_params result = {
         /*.devices                     =*/ nullptr,
+        /*.tensor_buft_overrides       =*/ nullptr,
         /*.n_gpu_layers                =*/ 0,
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
@@ -12012,10 +12454,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK:
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_PLM:
         case LLM_ARCH_CHATGLM:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
+        case LLM_ARCH_BAILINGMOE:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
diff --git a/src/llama-model.h b/src/llama-model.h
index a9da1215abbfd..91e6e8725acd2 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -44,6 +44,7 @@ enum llm_type {
     LLM_TYPE_1_4B,
     LLM_TYPE_1_5B,
     LLM_TYPE_1_6B,
+    LLM_TYPE_1_8B,
     LLM_TYPE_2B,
     LLM_TYPE_2_8B,
     LLM_TYPE_2_9B,
@@ -84,6 +85,7 @@ enum llm_type {
     LLM_TYPE_10B_128x3_66B,
     LLM_TYPE_57B_A14B,
     LLM_TYPE_27B,
+    LLM_TYPE_290B,
 };
 
 struct llama_layer_posnet {
@@ -380,6 +382,8 @@ struct llama_model {
 
     ggml_backend_buffer_type_t select_buft(int il) const;
 
+    bool has_tensor_overrides() const;
+
     const struct ggml_tensor * get_tensor(const char * name) const;
 
     // TODO: move this to new llm_arch_model_i interface
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 09eb570779ce5..e3e10fa6cf77f 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -527,7 +527,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     }
 
     std::vector<std::string> splits = {};
-    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
+    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
     ml.init_mappings(false); // no prefetching
 
     llama_model model(llama_model_default_params());
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 452e284659deb..3923f6a71d802 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1477,6 +1477,7 @@ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
     auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
+    GGML_ASSERT(result);
 
     // copy the state
     {
@@ -1548,6 +1549,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
             /* .grammar_root = */ grammar_root,
             /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
         };
+        if (!ctx->grammar) {
+            delete ctx;
+            return nullptr;
+        }
     } else {
         *ctx = {
             /* .vocab        = */ vocab,
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 2ddc8108f4cb4..5d5cafbea1f1b 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -342,6 +342,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
             case LLAMA_VOCAB_PRE_TYPE_MPT:
             case LLAMA_VOCAB_PRE_TYPE_OLMO:
             case LLAMA_VOCAB_PRE_TYPE_JAIS:
+            case LLAMA_VOCAB_PRE_TYPE_TRILLION:
                 regex_exprs = {
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
                 };
@@ -406,6 +407,14 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "(?=(\\d{3})+(?!\\d))",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
+                    // FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?)
+ "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1614,6 +1623,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "superbpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; clean_spaces = false; + } else if ( + tokenizer_pre == "trillion") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; + clean_spaces = false; + } else if ( + tokenizer_pre == "bailingmoe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -1791,6 +1808,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" || t.first == "<|endoftext|>" || t.first == "" + || t.first == "_" || t.first == "<|end▁of▁sentence|>" // DeepSeek ) { special_eot_id = t.second; @@ -1823,6 +1841,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" || t.first == "<|fim▁begin|>" // DeepSeek || t.first == "
"
+                        || t.first == "▁
"          // CodeLlama
                         ) {
                     special_fim_pre_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1840,6 +1859,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == ""
+                        || t.first == "▁"         // CodeLlama
                         ) {
                     special_fim_suf_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1857,6 +1877,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == ""
+                        || t.first == "▁"         // CodeLlama
                         ) {
                     special_fim_mid_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1941,6 +1962,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|endoftext|>"
                     || t.first == "<|eom_id|>"
                     || t.first == ""
+                    || t.first == "_"
                ) {
                 special_eog_ids.insert(t.second);
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2199,14 +2221,12 @@ void llama_vocab::impl::tokenizer_st_partition(std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
                     // check if match is within bounds of offset <-> length
-                    if (match + text.length() > raw_text_base_offset + raw_text_base_length) break;
-
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
 #endif
diff --git a/src/llama.cpp b/src/llama.cpp
index 81e1dd1d0873a..d5164720b2196 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -92,7 +92,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
     model.t_start_us = tm.t_start_us;
 
     try {
-        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides);
 
         ml.print_info();
 
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
index 69604b87ceec4..537fc63a4c975 100644
--- a/tests/test-arg-parser.cpp
+++ b/tests/test-arg-parser.cpp
@@ -77,7 +77,7 @@ int main(void) {
 
     argv = {"binary_name", "-m", "model_file.gguf"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
-    assert(params.model == "model_file.gguf");
+    assert(params.model.path == "model_file.gguf");
 
     argv = {"binary_name", "-t", "1234"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
@@ -89,7 +89,7 @@ int main(void) {
 
     argv = {"binary_name", "-m", "abc.gguf", "--predict", "6789", "--batch-size", "9090"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
-    assert(params.model == "abc.gguf");
+    assert(params.model.path == "abc.gguf");
     assert(params.n_predict == 6789);
     assert(params.n_batch == 9090);
 
@@ -112,7 +112,7 @@ int main(void) {
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
-    assert(params.model == "blah.gguf");
+    assert(params.model.path == "blah.gguf");
     assert(params.cpuparams.n_threads == 1010);
 
 
@@ -122,7 +122,7 @@ int main(void) {
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name", "-m", "overwritten.gguf"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
-    assert(params.model == "overwritten.gguf");
+    assert(params.model.path == "overwritten.gguf");
     assert(params.cpuparams.n_threads == 1010);
 #endif // _WIN32
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 28f860a7f2969..e61a126cf5b2f 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3217,7 +3217,8 @@ struct test_leaky_relu : public test_case {
 
 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const int64_t hs; // head size
+    const int64_t hsk; // K head size
+    const int64_t hsv; // V head size
     const int64_t nh; // num heads
     const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
@@ -3233,7 +3234,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+        return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3243,17 +3244,18 @@ struct test_flash_attn_ext : public test_case {
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         // Just counting matmul costs:
-        // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
-        return 2 * 2 * nh*nr * nb * hs * kv;
+        // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
+        return 2 * nh*nr * nb * (hsk + hsv) * kv;
     }
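
The updated cost model charges the two matmuls separately: Q*K^T is 2*nb*hsk*kv FLOPs per head and P*V is 2*nb*kv*hsv, so the sum folds into 2 * nh*nr * nb * (hsk + hsv) * kv. A quick compile-time check with one of the new shapes (hsk=192, hsv=128, nh=4, nr=1, kv=512, nb=8 - values assumed for illustration):

```cpp
static_assert(2ull * (4*1) * 8 * (192 + 128) * 512 == 10485760ull,
              "flash-attn FLOPs: 2 * nh*nr * nb * (hsk + hsv) * kv");
```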
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
                         bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
                        ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+        : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
+        const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
+        const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
 
         auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
             int64_t ne[4] = {ne0, ne1, ne2, ne3};
@@ -3268,13 +3270,13 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV,       hs_padded, kv, nh,    1);
+        ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,    1);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV,       hs_padded, kv, nh,    1);
+        ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,    1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -3283,7 +3285,7 @@ struct test_flash_attn_ext : public test_case {
             ggml_set_name(m, "m");
         }
 
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
         ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
@@ -4412,27 +4414,32 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
-    for (int hs : { 64, 80, 128, 256, }) {
-        for (bool mask : { true, false } ) {
-            for (float max_bias : { 0.0f, 8.0f }) {
-                if (!mask && max_bias > 0.0f) continue;
-                for (float logit_softcap : {0.0f, 10.0f}) {
-                    if (hs != 128 && logit_softcap != 0.0f) continue;
-                    for (int nh : { 4, }) {
-                        for (int nr : { 1, 4, 16 }) {
-                            if (nr == 16 && hs != 128) continue;
-                            for (int kv : { 512, 1024, }) {
-                                if (nr != 1 && kv != 512) continue;
-                                for (int nb : { 1, 3, 32, 35, }) {
-                                    for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
-                                        if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
-                                        for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                            test_cases.emplace_back(new test_flash_attn_ext(
-                                                hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
-                                            // run fewer test cases permuted
-                                            if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+    for (int hsk : { 64, 80, 128, 192, 256, }) {
+        for (int hsv : { 64, 80, 128, 192, 256, }) {
+            if (hsk != 192 && hsk != hsv) continue;
+            if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+
+            for (bool mask : { true, false } ) {
+                for (float max_bias : { 0.0f, 8.0f }) {
+                    if (!mask && max_bias > 0.0f) continue;
+                    for (float logit_softcap : {0.0f, 10.0f}) {
+                        if (hsk != 128 && logit_softcap != 0.0f) continue;
+                        for (int nh : { 4, }) {
+                            for (int nr : { 1, 4, 16 }) {
+                                if (nr == 16 && hsk != 128) continue;
+                                for (int kv : { 512, 1024, }) {
+                                    if (nr != 1 && kv != 512) continue;
+                                    for (int nb : { 1, 3, 32, 35, }) {
+                                        for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                            if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                                 test_cases.emplace_back(new test_flash_attn_ext(
-                                                    hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                    hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                                // run fewer test cases permuted
+                                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                    test_cases.emplace_back(new test_flash_attn_ext(
+                                                        hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                }
                                             }
                                         }
                                     }
@@ -4509,6 +4516,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
+    for (int kv : { 4096, 8192, 16384, }) {
+        for (int hs : { 64, 128, }) {
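+            // long-context flash-attention cases, used only for performance measurement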
+            test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, 4, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+        }
+    }
+
     return test_cases;
 }
 
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 9231c517afb0b..a9627df684922 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -270,6 +270,22 @@ int main(void) {
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
+        {
+            /* .name= */ "yandex/YandexGPT-5-Lite-8B-instruct",
+            /* .template_str= */ "{%- set names = {'assistant': ' Ассистент:', 'user': ' Пользователь:'} %}\n{%- set tools_prefix = 'Тебе доступны следующие функции:' %}\n{%- macro __render_tool(tool) %}\n    {%- set name = tool.function.name %}\n    {%- set description = tool.function.description|default('') %}\n    {%- set parameters = tool.function.parameters|tojson %}\n    {{- '\\n' }}function {{ '{' }}'name':'{{ name }}',\n    {%- if tool.function.description %}'description':'{{ description }}',{% endif %}\n'parameters':{{ parameters }}\n    {{- '}' }}\n{%- endmacro %}\n{%- macro __render_tools(tools) %}\n    {{- tools_prefix }}\n    {%- for tool in tools %}\n        {{- __render_tool(tool) }}\n    {%- endfor %}\n    {{- '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_tool_message(message) %}\n    {{- '\\n\\nРезультат вызова' }} {{ message.name }}: {{ message.content }} {{ '\\n\\n' }}\n{%- endmacro %}\n{%- if tools -%}\n    {{- __render_tools(tools) }}\n{%- endif -%}\n{%- macro __render_user_message(message) %}\n{{ names.user }} {{ message.content + '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_assistant_message(message) %}\n    {{- names.assistant }}\n    {%- set call = message['function_call'] %}\n    {%- if call %}\n        {{- '\\n[TOOL_CALL_START]' }}{{ call.name }}{{ '\\n' }}{{ call.arguments|tojson }}\n    {%- else %}\n        {{- ' ' + message.content + '\\n\\n' }}\n    {%- endif %}\n{%- endmacro %}\n{%- if not add_generation_prompt is defined %}\n{%- set add_generation_prompt = false %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'user' %}\n        {{- __render_user_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'assistant' and not loop.last %}\n        {{- __render_assistant_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'tool' %}\n        {{- __render_tool_message(message) }}\n    {%- endif %}\n    {%- if loop.last %}\n        {{- ' Ассистент:[SEP]' }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */ " Пользователь: Hello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент:    I am an assistant   \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .expected_output_jinja= */ " Пользователь: You are a helpful assistant\nHello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент:    I am an assistant   \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "inclusionAI/Ling-lite",
+            /* .template_str= */ "{% for message in messages %}{% set role = message['role'] | lower %}{% if role == 'user' %}{% set role = 'HUMAN' %}{% endif %}{% set role = role | upper %}{{ '<role>' + role + '</role>' + message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '<role>ASSISTANT</role>' }}{% endif %}",
+            /* .expected_output= */ "<role>SYSTEM</role>You are a helpful assistant<role>HUMAN</role>Hello<role>ASSISTANT</role>Hi there<role>HUMAN</role>Who are you<role>ASSISTANT</role>   I am an assistant   <role>HUMAN</role>Another question<role>ASSISTANT</role>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp
index 8b696006be714..3c19220e11964 100644
--- a/tests/test-grammar-llguidance.cpp
+++ b/tests/test-grammar-llguidance.cpp
@@ -1086,6 +1086,65 @@ static void test_json_schema() {
         });
 }
 
+static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
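+    // one-hot the logits: every token gets 0 except `selected`, which gets a large positive logit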
+    auto n_vocab = tok_arr.size;
+
+    tok_arr.selected = -1;
+    tok_arr.sorted   = false;
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        tok_arr.data[token_id].id    = token_id;
+        tok_arr.data[token_id].logit = 0.0f;
+    }
+
+    tok_arr.data[selected].logit = 100.0f;
+}
+
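+// feed a string that matches the llguidance grammar through a sampler chain token by token,
+// checking that the grammar-constrained sampler keeps selecting each forced token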
+static void test_sampler_chain(void) {
+    auto sparams            = llama_sampler_chain_default_params();
+    sparams.no_perf         = false;
+    llama_sampler * sampler = llama_sampler_chain_init(sparams);
+
+    const auto grammar_data = R"(%llguidance {}
+start: /[A-Z ]*/)";
+
+    llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
+    llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
+
+    auto input  = "ALL YOUR BASE ARE BELONG TO US";
+    auto tokens = common_tokenize(vocab, input, false, false);
+
+    auto n_vocab = llama_vocab_n_tokens(vocab);
+
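+    // candidate array spanning the whole vocabulary; one_hot() rewrites the logits each step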
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
+    }
+    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
+
+    for (const auto token : tokens) {
+        one_hot(tok_arr, token);
+
+        fprintf(stderr, "applying token: %d\n", token);
+        llama_sampler_apply(sampler, &tok_arr);
+
+        auto idx = tok_arr.selected;
+        fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
+        assert(cur[tok_arr.selected].id == token);
+        llama_sampler_accept(sampler, token);
+    }
+
+    auto tok_eos = llama_vocab_eot(vocab);
+    if (tok_eos == LLAMA_TOKEN_NULL) {
+        tok_eos = llama_vocab_eos(vocab);
+    }
+
+    one_hot(tok_arr, tok_eos);
+
+    llama_sampler_apply(sampler, &tok_arr);
+    assert(cur[tok_arr.selected].id == tok_eos);
+}
+
 int main(int argc, const char ** argv) {
     fprintf(stdout, "Running llguidance integration tests...\n");
 
@@ -1135,6 +1194,9 @@ int main(int argc, const char ** argv) {
     test_special_chars();
     test_quantifiers();
     test_json_schema();
+
+    test_sampler_chain();
+
     fprintf(stdout, "All tests passed.\n");
     return 0;
 }