diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 15e11330952..36084c55078 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -161,15 +161,16 @@ jobs: - name: Dawn Dependency id: dawn-depends run: | - DAWN_VERSION="v1.0.0" + DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip" echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" - curl -L -o artifact.tar.gz \ + curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - tar -xvf artifact.tar.gz -C dawn --strip-components=1 + unzip artifact.zip + tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build @@ -521,15 +522,16 @@ jobs: id: dawn-depends run: | sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev - DAWN_VERSION="v1.0.0" + DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip" echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" - curl -L -o artifact.tar.gz \ + curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - tar -xvf artifact.tar.gz -C dawn --strip-components=1 + unzip artifact.zip + tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build diff --git a/common/arg.cpp b/common/arg.cpp index 5597de121c1..a5708102814 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -740,6 +740,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex exit(0); } )); + add_opt(common_arg( + {"-cl", "--cache-list"}, + "show list of models in cache", + [](common_params &) { + printf("model cache directory: %s\n", fs_get_cache_directory().c_str()); + auto models = common_list_cached_models(); + printf("number of models in cache: %zu\n", models.size()); + for (size_t i = 0; i < models.size(); i++) { + auto & model = models[i]; + printf("%4d. 
%s\n", (int) i + 1, model.to_string().c_str()); + } + exit(0); + } + )); add_opt(common_arg( {"--completion-bash"}, "print source-able bash completion script for llama.cpp", diff --git a/common/common.cpp b/common/common.cpp index b0591e84b06..a8d709ab1d0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -908,6 +908,39 @@ std::string fs_get_cache_file(const std::string & filename) { return cache_directory + filename; } +std::vector fs_list_files(const std::string & path) { + std::vector files; + if (path.empty()) return files; + + std::filesystem::path dir(path); + if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) { + return files; + } + + for (const auto & entry : std::filesystem::directory_iterator(dir)) { + try { + // Only include regular files (skip directories) + const auto & p = entry.path(); + if (std::filesystem::is_regular_file(p)) { + common_file_info info; + info.path = p.string(); + info.name = p.filename().string(); + try { + info.size = static_cast(std::filesystem::file_size(p)); + } catch (const std::filesystem::filesystem_error &) { + info.size = 0; + } + files.push_back(std::move(info)); + } + } catch (const std::filesystem::filesystem_error &) { + // skip entries we cannot inspect + continue; + } + } + + return files; +} + // // Model utils diff --git a/common/common.h b/common/common.h index 54b7849b174..8540725aaa4 100644 --- a/common/common.h +++ b/common/common.h @@ -611,6 +611,13 @@ bool fs_create_directory_with_parents(const std::string & path); std::string fs_get_cache_directory(); std::string fs_get_cache_file(const std::string & filename); +struct common_file_info { + std::string path; + std::string name; + size_t size = 0; // in bytes +}; +std::vector fs_list_files(const std::string & path); + // // Model utils // diff --git a/common/download.cpp b/common/download.cpp index 02d75fc0d09..57308a5c6d5 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -50,6 +50,22 @@ using json = nlohmann::ordered_json; // downloader // +// validate repo name format: owner/repo +static bool validate_repo_name(const std::string & repo) { + static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)"); + return std::regex_match(repo, repo_regex); +} + +static std::string get_manifest_path(const std::string & repo, const std::string & tag) { + // we use "=" to avoid clashing with other component, while still being allowed on windows + std::string fname = "manifest=" + repo + "=" + tag + ".json"; + if (!validate_repo_name(repo)) { + throw std::runtime_error("error: repo name must be in the format 'owner/repo'"); + } + string_replace_all(fname, "/", "="); + return fs_get_cache_file(fname); +} + static std::string read_file(const std::string & fname) { std::ifstream file(fname); if (!file) { @@ -829,17 +845,13 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response // User-Agent header is already set in common_remote_get_content, no need to set it here - // we use "=" to avoid clashing with other component, while still being allowed on windows - std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json"; - string_replace_all(cached_response_fname, "/", "_"); - std::string cached_response_path = fs_get_cache_file(cached_response_fname); - // make the request common_remote_params params; params.headers = headers; long res_code = 0; std::string res_str; bool use_cache = false; + std::string 
cached_response_path = get_manifest_path(hf_repo, tag); if (!offline) { try { auto res = common_remote_get_content(url, params); @@ -895,6 +907,33 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons return { hf_repo, ggufFile, mmprojFile }; } +std::vector common_list_cached_models() { + std::vector models; + const std::string cache_dir = fs_get_cache_directory(); + const std::vector files = fs_list_files(cache_dir); + for (const auto & file : files) { + if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) { + common_cached_model_info model_info; + model_info.manifest_path = file.path; + std::string fname = file.name; + string_replace_all(fname, ".json", ""); // remove extension + auto parts = string_split(fname, '='); + if (parts.size() == 4) { + // expect format: manifest==== + model_info.user = parts[1]; + model_info.model = parts[2]; + model_info.tag = parts[3]; + } else { + // invalid format + continue; + } + model_info.size = 0; // TODO: get GGUF size, not manifest size + models.push_back(model_info); + } + } + return models; +} + // // Docker registry functions // @@ -959,6 +998,7 @@ std::string common_docker_resolve_model(const std::string & docker) { std::string token = common_docker_get_token(repo); // Get authentication token // Get manifest + // TODO: cache the manifest response so that it appears in the model list const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo; std::string manifest_url = url_prefix + "/manifests/" + tag; common_remote_params manifest_params; diff --git a/common/download.h b/common/download.h index ddf36155ef8..45a6bd6bba8 100644 --- a/common/download.h +++ b/common/download.h @@ -8,16 +8,23 @@ struct common_params_model; // download functionalities // +struct common_cached_model_info { + std::string manifest_path; + std::string user; + std::string model; + std::string tag; + size_t size = 0; // GGUF size in bytes + std::string to_string() const { + return user + "/" + model + ":" + tag; + } +}; + struct common_hf_file_res { std::string repo; // repo name with ":tag" removed std::string ggufFile; std::string mmprojFile; }; -// resolve and download model from Docker registry -// return local path to downloaded model file -std::string common_docker_resolve_model(const std::string & docker); - /** * Allow getting the HF file from the HF repo with tag (like ollama), for example: * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 @@ -39,3 +46,10 @@ bool common_download_model( const common_params_model & model, const std::string & bearer_token, bool offline); + +// returns list of cached models +std::vector common_list_cached_models(); + +// resolve and download model from Docker registry +// return local path to downloaded model file +std::string common_docker_resolve_model(const std::string & docker); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 181f179ed17..869796f0e3b 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -168,7 +168,7 @@ option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) -option(GGML_VXE "ggml: enable vxe" ON) +option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") diff --git a/ggml/src/ggml-cuda/CMakeLists.txt 
b/ggml/src/ggml-cuda/CMakeLists.txt index 30247751359..67af1d8ccc1 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -124,6 +124,7 @@ if (CUDAToolkit_FOUND) if (GGML_CUDA_DEBUG) list(APPEND CUDA_FLAGS -lineinfo) + add_compile_definitions(GGML_CUDA_DEBUG) endif() if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 049aece1b52..68dc57843e4 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -27,7 +27,6 @@ #include "ggml-cuda/mmq.cuh" #include "ggml-cuda/mmvf.cuh" #include "ggml-cuda/mmvq.cuh" -#include "ggml-cuda/moe-expert-reduce.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" #include "ggml-cuda/opt-step-sgd.cuh" @@ -3152,8 +3151,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - - #ifdef GGML_CUDA_DEBUG const int nodes_fused = i - prev_i - 1; prev_i = i; @@ -3199,31 +3196,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } - if (node->op == GGML_OP_MUL) { - int current_node = i + 1; - int num_views = 0; - int num_adds = 0; - while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) { - num_views++; - current_node++; - } - - while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD && - num_adds < num_views - 1) { - num_adds++; - current_node++; - } - - if (num_adds == num_views - 1 && num_views > 0) { - ggml_tensor * dst_node = cgraph->nodes[current_node - 1]; - if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) { - ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node); - i += num_views + num_adds; - continue; - } - } - } - if (node->op == GGML_OP_ADD) { int n_fuse = 0; ggml_op ops[8]; @@ -3302,6 +3274,13 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + // we don't support repeating adds + if (bias_op == GGML_OP_ADD && + (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || + !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { + continue; + } + const ggml_tensor * src0 = up_n->src[0]; const ggml_tensor * src1 = up_n->src[1]; const ggml_tensor * ids = up_n->src[2]; @@ -3411,6 +3390,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { + continue; + } + ggml_cuda_mm_fusion_args_host fusion_data{}; fusion_data.x_bias = bias_tensor; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index c9a07e82fed..2e133b6bda8 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3494,7 +3494,7 @@ static __global__ void mul_mat_q_stream_k_fixup( const int col_diff = col_high - col_low; for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { - ids_dst_shared[j] = ids_dst[col_low + j]; + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; } __syncthreads(); diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cu b/ggml/src/ggml-cuda/moe-expert-reduce.cu deleted file mode 100644 index a97c5d573bb..00000000000 --- a/ggml/src/ggml-cuda/moe-expert-reduce.cu +++ /dev/null @@ -1,168 +0,0 @@ -#include "moe-expert-reduce.cuh" - -// This kernel is a fusion of the expert weight reduce, 
common in MoE models - -template -__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts, - const float * __restrict__ weights, - float * __restrict__ dst, - const int n_expert_used, - const int n_cols) { - const int row = blockIdx.x; - const int col = blockIdx.y * blockDim.x + threadIdx.x; - if (col >= n_cols) { - return; - } - - experts += row * n_cols * n_expert_used; - weights += row * n_expert_used; - dst += row * n_cols; - - float acc = 0.f; - if constexpr (n_expert_used_template == 0) { - for (int expert = 0; expert < n_expert_used; ++expert) { - ggml_cuda_mad(acc, experts[col], weights[expert]); - experts += n_cols; - } - dst[col] = acc; - } else { -#pragma unroll - for (int i = 0; i < n_expert_used_template; ++i) { - ggml_cuda_mad(acc, experts[col], weights[i]); - experts += n_cols; - } - dst[col] = acc; - } -} - -static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx, - const float * experts, - const float * weights, - float * dst, - const int n_expert_used, - const int n_cols, - const int n_rows) { - const int block_size = 32; - - const int n_blocks_x = n_rows; - const int n_blocks_y = (n_cols + block_size - 1) / block_size; - - dim3 block_dims(block_size); - dim3 grid_dims(n_blocks_x, n_blocks_y); - - cudaStream_t stream = ctx.stream(); - switch (n_expert_used) { - case 1: - moe_expert_reduce_cuda<1> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 2: - moe_expert_reduce_cuda<2> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 4: - moe_expert_reduce_cuda<4> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 6: - moe_expert_reduce_cuda<6> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 8: - moe_expert_reduce_cuda<8> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 16: - moe_expert_reduce_cuda<16> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 32: - moe_expert_reduce_cuda<32> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 64: - moe_expert_reduce_cuda<64> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 128: - moe_expert_reduce_cuda<128> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - default: - moe_expert_reduce_cuda<0> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - } -} - -bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) { - const ggml_tensor * mul = cgraph->nodes[start_index]; - - if (mul->op != GGML_OP_MUL || !ggml_is_contiguous(mul->src[0]) || !ggml_is_contiguous(mul->src[1])) { - return false; - } - - int current_node = start_index + 1; - size_t current_offset = 0; - - std::vector view_nodes; - //check if all are views of the expert in increasing order - while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) { - const ggml_tensor * node = cgraph->nodes[current_node]; - if (node->view_src != mul) { - return false; - } - if (node->view_offs < current_offset) { - return false; - } - current_offset = node->view_offs; - current_node++; - view_nodes.push_back(node); - } - - //check if all the adds are in increasing order - const ggml_tensor * prev_add_src = view_nodes.empty() ? nullptr : view_nodes[0]; - int num_adds = 0; - int num_views = view_nodes.size(); - while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) { - const ggml_tensor * add_node = cgraph->nodes[current_node]; - - bool is_first_op_ok = num_views > num_adds ? 
add_node->src[0] == prev_add_src : false; - bool is_second_op_ok = num_views > num_adds ? add_node->src[1] == view_nodes[num_adds + 1] : false; - - if (!is_first_op_ok || !is_second_op_ok) { - return false; - } - prev_add_src = add_node; - - num_adds++; - current_node++; - } - - if (num_views != num_adds + 1) { - return false; - } - - return true; -} - -void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx, - const ggml_tensor * experts, - const ggml_tensor * weights, - ggml_tensor * dst) { - const int n_rows = experts->ne[2]; - const int n_expert_used = experts->ne[1]; - const int n_cols = experts->ne[0]; - - GGML_ASSERT(experts->type == GGML_TYPE_F32); - GGML_ASSERT(weights->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(experts)); - GGML_ASSERT(ggml_is_contiguous(weights)); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - const float * experts_d = (const float *) experts->data; - const float * weights_d = (const float *) weights->data; - float * dst_d = (float *) dst->data; - - launch_moe_expert_reduce(ctx, experts_d, weights_d, dst_d, n_expert_used, n_cols, n_rows); -} diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cuh b/ggml/src/ggml-cuda/moe-expert-reduce.cuh deleted file mode 100644 index cafc50e104a..00000000000 --- a/ggml/src/ggml-cuda/moe-expert-reduce.cuh +++ /dev/null @@ -1,11 +0,0 @@ -#include "common.cuh" -#include "ggml.h" - -#include - -void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx, - const ggml_tensor * experts, - const ggml_tensor * weights, - ggml_tensor * dst); - -bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a0a05f2e5b2..6da7bbd2f61 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -130,9 +130,9 @@ struct vk_pipeline_struct { // true if fields have been set by ggml_vk_create_pipeline bool initialized {}; // set to true to request the pipeline is compiled - bool needed {}; + std::atomic needed {}; // set to true when the shader has been compiled - bool compiled {}; + std::atomic compiled {}; // number of registers used, extracted from pipeline executable properties uint32_t register_count {}; }; @@ -351,6 +351,12 @@ enum vk_conv_shapes { CONV_SHAPE_COUNT, }; +uint32_t conv_shapes_wg_denoms[][3] = { + { 128, 128, 1 }, + { 64, 32, 1 }, + { 32, 256, 1 }, +}; + enum dmmv_wg_sizes { DMMV_WG_SIZE_SUBGROUP, DMMV_WG_SIZE_LARGE, @@ -379,6 +385,18 @@ struct vk_fa_pipeline_state { } }; +struct vk_conv2d_pipeline_state { + vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH) + : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {} + + uint32_t s0, s1, p0, p1, d0, d1, KW, KH; + + bool operator<(const vk_conv2d_pipeline_state &b) const { + return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) < + std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH); + } +}; + enum shader_reduction_mode { SHADER_REDUCTION_MODE_SHMEM, SHADER_REDUCTION_MODE_HYBRID, @@ -466,6 +484,14 @@ static constexpr std::initializer_list> rope_view_set_rows_ed { 2, 0, 1 }, // set_rows->src[0] == view }; +static constexpr std::initializer_list> rms_norm_mul_rope_view_set_rows_edges { + { 1, 0, 0 }, // mul->src[0] == rms + { 2, 0, 1 }, // rope->src[0] == mul + { 3, 0, 2 }, // view->src[0] == rope + { 4, 0, 3 }, // set_rows->src[0] == view +}; + + struct vk_device_struct { 
std::recursive_mutex mutex; @@ -617,6 +643,8 @@ struct vk_device_struct { vk_pipeline pipeline_rms_norm_mul_f32; vk_pipeline pipeline_rms_norm_partials_f32; vk_pipeline pipeline_rms_norm_mul_partials_f32; + vk_pipeline pipeline_rms_norm_mul_rope_f32_f32; + vk_pipeline pipeline_rms_norm_mul_rope_f32_f16; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; @@ -665,10 +693,10 @@ struct vk_device_struct { vk_pipeline pipeline_ssm_conv_f32; vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_sgd_f32; - vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; - vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; - vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; - vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv2d_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; @@ -1060,6 +1088,7 @@ struct vk_op_diag_mask_push_constants { }; struct vk_op_rope_push_constants { + uint32_t rope_mode; uint32_t ncols; uint32_t n_dims; float freq_scale; @@ -1079,6 +1108,12 @@ struct vk_op_rope_push_constants { uint32_t set_rows_stride; }; +// For fused rms_norm+mul+rope(+view+set_rows) +struct vk_op_rms_norm_mul_rope_push_constants { + vk_op_binary_push_constants bin; + vk_op_rope_push_constants rope; +}; + struct vk_op_soft_max_push_constants { uint32_t KX; uint32_t KY; @@ -1241,17 +1276,13 @@ struct vk_op_conv2d_push_constants { uint32_t nb2; uint32_t nb3; - // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH - uint32_t KWmp; uint32_t KWL; - uint32_t KWKHmp; uint32_t KWKHL; + // init_fastdiv_values constants for dividing by OW, OW*OH uint32_t OWmp; uint32_t OWL; uint32_t OWOHmp; uint32_t OWOHL; }; template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { - // Compute magic values to divide by KW, KW*KH, OW, OW*OH - init_fastdiv_values(p.KW, p.KWmp, p.KWL); - init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + // Compute magic values to divide by OW, OW*OH init_fastdiv_values(p.OW, p.OWmp, p.OWL); init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); } @@ -1287,23 +1318,15 @@ struct vk_op_conv_transpose_2d_push_constants { uint32_t nb2; uint32_t nb3; - // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1 - uint32_t KWmp; uint32_t KWL; - uint32_t KWKHmp; uint32_t KWKHL; + // init_fastdiv_values constants for dividing by OW, OW*OH uint32_t OWmp; uint32_t OWL; uint32_t OWOHmp; uint32_t OWOHL; - uint32_t s0mp; uint32_t s0L; - uint32_t s1mp; uint32_t s1L; }; template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) { - // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1 - init_fastdiv_values(p.KW, p.KWmp, p.KWL); - init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + // Compute magic values to divide by OW, OW*OH init_fastdiv_values(p.OW, p.OWmp, p.OWL); init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); - init_fastdiv_values(p.s0, p.s0mp, p.s0L); - init_fastdiv_values(p.s1, p.s1mp, p.s1L); } struct vk_op_conv2d_dw_push_constants { @@ -1842,10 +1865,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } } - { - std::lock_guard guard(device->mutex); - device->all_pipelines.push_back(pipeline); - } + 
device->all_pipelines.push_back(pipeline); { std::lock_guard guard(compile_count_mutex); @@ -2536,6 +2556,7 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); + std::lock_guard guard(device->mutex); // some shaders have a minimum subgroup size const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); @@ -2729,6 +2750,8 @@ static void ggml_vk_load_shaders(vk_device& device) { if (!pipeline->needed || pipeline->compiled) { return; } + // TODO: We're no longer benefitting from the async compiles (shaders are + // compiled individually, as needed) and this complexity can be removed. { // wait until fewer than N compiles are in progress uint32_t N = std::max(1u, std::thread::hardware_concurrency()); @@ -3557,6 +3580,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + if (device->float_controls_rte_fp16 && + sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); + } + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -3835,22 +3864,22 @@ static void ggml_vk_load_shaders(vk_device& device) { switch (s) { default: case CONV_SHAPE_128x128: - conv2d_BS_K = 128; - conv2d_BS_NPQ = 128; + conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_128x128][0]; + conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_128x128][1]; conv2d_BS_CRS = 16; if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) { conv2d_UNROLL = false; } break; case CONV_SHAPE_64x32: - conv2d_BS_K = 64; - conv2d_BS_NPQ = 32; + conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_64x32][0]; + conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_64x32][1]; conv2d_BS_CRS = 32; conv2d_TS_K = 4; break; case CONV_SHAPE_32x256: - conv2d_BS_K = 32; - conv2d_BS_NPQ = 256; + conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_32x256][0]; + conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_32x256][1]; conv2d_BS_CRS = 16; break; } @@ -3884,10 +3913,22 @@ static void ggml_vk_load_shaders(vk_device& device) { std::vector spec_constants = { conv2d_WG_SIZE, 
conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; #define CREATE_CONV(name, type_suffix, spv_suffix) \ - ggml_vk_create_pipeline( \ - device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \ - name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ - sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + for (auto &c : device->pipeline_##name##type_suffix[s]) { \ + const vk_conv2d_pipeline_state &state = c.first; \ + std::vector spec_constants_cpy = spec_constants; \ + spec_constants_cpy.push_back(state.s0); \ + spec_constants_cpy.push_back(state.s1); \ + spec_constants_cpy.push_back(state.p0); \ + spec_constants_cpy.push_back(state.p1); \ + spec_constants_cpy.push_back(state.d0); \ + spec_constants_cpy.push_back(state.d1); \ + spec_constants_cpy.push_back(state.KW); \ + spec_constants_cpy.push_back(state.KH); \ + ggml_vk_create_pipeline( \ + device, c.second, #name #type_suffix, \ + name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ + sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \ + } #define CREATE_CONVS(spv_suffix) \ CREATE_CONV(conv2d, _f32, spv_suffix) \ CREATE_CONV(conv2d, _f16_f32, spv_suffix) \ @@ -7914,12 +7955,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline = nullptr; - auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; - auto it = pipelines.find(fa_pipeline_state); - if (it != pipelines.end()) { - pipeline = it->second; - } else { - pipelines[fa_pipeline_state] = pipeline = std::make_shared(); + { + std::lock_guard guard(ctx->device->mutex); + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto it = pipelines.find(fa_pipeline_state); + if (it != pipelines.end()) { + pipeline = it->second; + } else { + pipelines[fa_pipeline_state] = pipeline = std::make_shared(); + } } assert(pipeline); @@ -8510,7 +8554,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const uint32_t tiles[CONV_SHAPE_COUNT]; for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) { - tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]); + tiles[i] = CEIL_DIV(elements[0], conv_shapes_wg_denoms[i][0]) * CEIL_DIV(elements[1], conv_shapes_wg_denoms[i][1]); } // We can't query number of shader cores on Intel, use 32 as a placeholder @@ -8525,19 +8569,45 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const shape = CONV_SHAPE_64x32; } + uint32_t KW = static_cast(src0->ne[0]); + uint32_t KH = static_cast(src0->ne[1]); + uint32_t s0 = static_cast(dst->op_params[0]); + uint32_t s1 = op == GGML_OP_CONV_2D ? static_cast(dst->op_params[1]) : static_cast(dst->op_params[0]); + uint32_t p0 = op == GGML_OP_CONV_2D ? static_cast(dst->op_params[2]) : 0; + uint32_t p1 = op == GGML_OP_CONV_2D ? static_cast(dst->op_params[3]) : 0; + uint32_t d0 = op == GGML_OP_CONV_2D ? static_cast(dst->op_params[4]) : 1; + uint32_t d1 = op == GGML_OP_CONV_2D ? 
static_cast(dst->op_params[5]) : 1; + + vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH); + + std::map *pipelines = nullptr; if (op == GGML_OP_CONV_2D) { if (src0->type == GGML_TYPE_F32) { - return ctx->device->pipeline_conv2d_f32[shape]; + pipelines = &ctx->device->pipeline_conv2d_f32[shape]; } else if (src0->type == GGML_TYPE_F16) { - return ctx->device->pipeline_conv2d_f16_f32[shape]; + pipelines = &ctx->device->pipeline_conv2d_f16_f32[shape]; } } else if (op == GGML_OP_CONV_TRANSPOSE_2D) { if (src0->type == GGML_TYPE_F32) { - return ctx->device->pipeline_conv_transpose_2d_f32[shape]; + pipelines = &ctx->device->pipeline_conv_transpose_2d_f32[shape]; } else if (src0->type == GGML_TYPE_F16) { - return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape]; + pipelines = &ctx->device->pipeline_conv_transpose_2d_f16_f32[shape]; } } + + vk_pipeline pipeline = nullptr; + + { + std::lock_guard guard(ctx->device->mutex); + auto it = pipelines->find(conv2d_pipeline_state); + if (it != pipelines->end()) { + pipeline = it->second; + } else { + (*pipelines)[conv2d_pipeline_state] = pipeline = std::make_shared(); + } + } + + return pipeline; } return nullptr; case GGML_OP_CONV_2D_DW: @@ -9587,21 +9657,149 @@ static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const g return num_bytes; } -static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params) { +static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) { + const int n_dims = ((const int32_t *) dst->op_params)[1]; + const int mode = ((const int32_t *) dst->op_params)[2]; + // const int n_ctx = ((const int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; + const float freq_base = ((const float *) dst->op_params)[5]; + const float freq_scale = ((const float *) dst->op_params)[6]; + const float ext_factor = ((const float *) dst->op_params)[7]; + const float attn_factor = ((const float *) dst->op_params)[8]; + const float beta_fast = ((const float *) dst->op_params)[9]; + const float beta_slow = ((const float *) dst->op_params)[10]; + int sections[4] {}; + if (mode & GGML_ROPE_TYPE_MROPE) { + memcpy(sections, (const int32_t *) dst->op_params + 11, sizeof(int)*4); + } + + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); + + vk_op_rope_push_constants rope { + (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, + has_ff, (uint32_t)src0->ne[2], nb01, nb02, + { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, + }; + + return rope; +} + +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) { + ggml_tensor * dst; + const ggml_tensor * src0; + const ggml_tensor * src1; + + if (ctx->num_additional_fused_ops > 0) { + // fused rms_norm + mul + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + 
ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0]; + dst = mul; + src0 = cgraph->nodes[node_idx]->src[0]; + src1 = other_src; + } else { + dst = cgraph->nodes[node_idx]; + src0 = src1 = dst->src[0]; + } + const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + vk_op_binary_push_constants bin { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, op_params[0], 0.0f, (int32_t)param3, - }); + }; + + // more than one fused op means rms_norm+mul+rope + if (ctx->num_additional_fused_ops > 1) { + static constexpr uint32_t max_tensors = 7; + const ggml_tensor *tensors[max_tensors] {}; + + ggml_tensor *rms = cgraph->nodes[node_idx + 0]; + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *rope = cgraph->nodes[node_idx + 2]; + + ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; + + bool do_set_rows = ctx->num_additional_fused_ops == 4; + + tensors[0] = rms->src[0]; + tensors[1] = other_src; + tensors[2] = mul; + tensors[3] = rope->src[1]; // pos + tensors[4] = rope->src[2]; // ff + tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst + tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr; + const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0; + + vk_op_rms_norm_mul_rope_push_constants pc; + pc.bin = bin; + pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride); + + vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? 
ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32; + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + ggml_backend_vk_buffer_context * buf_ctx[max_tensors]; + vk_buffer buf[max_tensors]; + size_t offset[max_tensors]; + bool uma[max_tensors]; + + for (uint32_t i = 0; i < max_tensors; ++i) { + if (!tensors[i]) { + // If any remaining descriptors are unused, just point them at src[0] + buf[i] = buf[0]; + offset[i] = 0; + continue; + } + buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context; + buf[i] = nullptr; + offset[i] = 0; + uma[i] = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]); + uma[i] = buf[i] != nullptr; + } + if (!uma[i]) { + buf[i] = buf_ctx[i]->dev_buffer; + offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs; + } + GGML_ASSERT(buf[i] != nullptr); + } + + std::array elements; + elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] }; + + static_assert(max_tensors == 7); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + ggml_vk_subbuffer(ctx, buf[0], offset[0]), + ggml_vk_subbuffer(ctx, buf[1], offset[1]), + ggml_vk_subbuffer(ctx, buf[2], offset[2]), + ggml_vk_subbuffer(ctx, buf[3], offset[3]), + ggml_vk_subbuffer(ctx, buf[4], offset[4]), + ggml_vk_subbuffer(ctx, buf[5], offset[5]), + ggml_vk_subbuffer(ctx, buf[6], offset[6]), + }, pc, elements); + } else { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin)); + } if (ctx->do_add_rms_partials_offset_calculation) { ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0); @@ -9755,9 +9953,6 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons // const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; const float freq_base = ((float *) dst->op_params)[5]; - const float freq_scale = ((float *) dst->op_params)[6]; - const float ext_factor = ((float *) dst->op_params)[7]; - const float attn_factor = ((float *) dst->op_params)[8]; const float beta_fast = ((float *) dst->op_params)[9]; const float beta_slow = ((float *) dst->op_params)[10]; int sections[4] {}; @@ -9765,16 +9960,9 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); } - const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; - float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); - uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); - uint32_t set_rows_stride = 0; // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride // and overrides the dst and sets src3=row_indices @@ -9784,12 +9972,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons dst = cgraph->nodes[node_idx + 2]; } - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, { - (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], - freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, - src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, - { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, - }); + ggml_vk_op_f32(ctx, subctx, 
src0, src1, src2, src3, dst, GGML_OP_ROPE, + ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride)); } static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -11304,6 +11488,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (n->op == GGML_OP_GLU) { std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " "; } + if (n->op == GGML_OP_ROPE) { + const int mode = ((const int32_t *) n->op_params)[2]; + std::cerr << " rope mode: " << mode; + } std::cerr << std::endl; } #endif @@ -11411,14 +11599,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_RMS_NORM: - if (ctx->num_additional_fused_ops > 0) { - // fused rms_norm + mul - ggml_tensor *mul = cgraph->nodes[node_idx + 1]; - ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0]; - ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params); - } else { - ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params); - } + ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params); break; case GGML_OP_RMS_NORM_BACK: ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node); @@ -12404,6 +12585,70 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } +// Check whether the tensors overlap in memory but are not equal. +// Fusions can potenitally overwrite src tensors in ways that are not prevented +// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them +// to overlap if they are exactly equal. +// XXX TODO this check is probably missing from several fusion optimizations. 
+static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) { + ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context; + vk_buffer a_buf = a_buf_ctx->dev_buffer; + ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context; + vk_buffer b_buf = b_buf_ctx->dev_buffer; + if (a_buf == b_buf) { + auto a_base = vk_tensor_offset(a) + a->view_offs; + auto a_size = ggml_nbytes(a); + auto b_base = vk_tensor_offset(b) + b->view_offs; + auto b_size = ggml_nbytes(b); + + if (a_base == b_base && a_size == b_size) { + return false; + } + + if ((b_base <= a_base && a_base < b_base + b_size) || + (a_base <= b_base && b_base < a_base + a_size)) { + return true; + } + } + return false; +} + +static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx) { + GGML_UNUSED(ctx); + const ggml_tensor *rms = cgraph->nodes[node_idx + 0]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + const ggml_tensor *rope = cgraph->nodes[node_idx + 2]; + + const int mode = ((const int32_t *) rope->op_params)[2]; + + // noncontig tensors aren't tested, and don't seem common in practice + if (!ggml_is_contiguous(rms) || + !ggml_is_contiguous(mul) || + !ggml_is_contiguous(rope)) { + return false; + } + + // only norm/neox are handled in the shader + if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) { + return false; + } + + // shared memory size for passing data from mul->rope + if (mul->ne[0] > 1024) { + return false; + } + + // must not overwrite srcs in a way that's not elementwise + ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; + if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) || + ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) { + return false; + } + + return true; +} + static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { const ggml_tensor *first_node = cgraph->nodes[node_idx]; @@ -12549,12 +12794,20 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); if (num_adds) { ctx->num_additional_fused_ops = num_adds - 1; - } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { - ctx->num_additional_fused_ops = 1; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 1; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) { ctx->num_additional_fused_ops = 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) && + ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) && + ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) && + ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) { + ctx->num_additional_fused_ops = 4; + } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&& + ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 2; + } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && 
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { @@ -12787,14 +13040,34 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } if (ok) { current_set.push_back(j); + + int rope_idx = j; + + // When we've found RMS_NORM + MUL, try to find a ROPE that uses it + if (j > 0 && + graph->nodes[j]->op == GGML_OP_MUL && + graph->nodes[j-1]->op == GGML_OP_RMS_NORM) { + for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { + if (graph->nodes[k]->op == GGML_OP_ROPE && + graph->nodes[k]->src[0] == graph->nodes[j] && + // Check that other srcs are already valid + graph->nodes[k]->src[1]->op == GGML_OP_NONE && + (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) { + rope_idx = k; + current_set.push_back(rope_idx); + used[rope_idx] = true; + break; + } + } + } // Look for ROPE + VIEW + SET_ROWS and make them consecutive - if (graph->nodes[j]->op == GGML_OP_ROPE) { + if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) { int view_idx = -1; int set_rows_idx = -1; - for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) { + for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) { if (view_idx == -1 && graph->nodes[k]->op == GGML_OP_VIEW && - graph->nodes[k]->src[0] == graph->nodes[j]) { + graph->nodes[k]->src[0] == graph->nodes[rope_idx]) { view_idx = k; continue; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 0367e80bbfa..e9bdbf7db5e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -62,14 +62,8 @@ layout(push_constant) uniform parameter { uint32_t nb3; // fastdiv helper values - uint32_t KWmp; uint32_t KWL; - uint32_t KWKHmp; uint32_t KWKHL; uint32_t OWmp; uint32_t OWL; uint32_t OWOHmp; uint32_t OWOHL; -#ifdef TRANSPOSE - uint32_t s0mp; uint32_t s0L; - uint32_t s1mp; uint32_t s1L; -#endif } p; @@ -84,6 +78,15 @@ layout(constant_id = 4) const uint TS_K = 8; layout(constant_id = 5) const uint use_collectives = 1; layout(constant_id = 6) const uint SHMEM_PAD = 4; +layout(constant_id = 7) const uint s0 = 1; +layout(constant_id = 8) const uint s1 = 1; +layout(constant_id = 9) const uint p0 = 0; +layout(constant_id = 10) const uint p1 = 0; +layout(constant_id = 11) const uint d0 = 1; +layout(constant_id = 12) const uint d1 = 1; +layout(constant_id = 13) const uint KW = 1; +layout(constant_id = 14) const uint KH = 1; + uint32_t tid = gl_LocalInvocationID.x; const uint32_t WG_SIZE = gl_WorkGroupSize.x; @@ -92,7 +95,7 @@ uint splitWork(uint work_size, uint block_size) { } uint32_t K = p.Cout; -uint32_t CRS = p.Cin * p.KH * p.KW; +uint32_t CRS = p.Cin * KH * KW; uint32_t NPQ = p.N * p.OH * p.OW; uint32_t n_elems_out = K * NPQ; @@ -187,7 +190,7 @@ void main() { } #endif /* Advance block in CRS dim */ - for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { + [[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { uint32_t CRS_idx_a; uint32_t Cin_idx_a; uint32_t KH_idx_a; @@ -200,10 +203,10 @@ void main() { uint32_t cached_KW_idx; if (use_collectives == 1) { cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID; - cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); - uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH); - cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW; + 
cached_Cin_idx = cached_CRS_idx / (KW * KH); + uint32_t cached_CRS_remainder = cached_CRS_idx % (KW * KH); + cached_KH_idx = cached_CRS_remainder / KW; + cached_KW_idx = cached_CRS_remainder % KW; CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); @@ -211,21 +214,21 @@ void main() { KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); } else { CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) - Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); - uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; - KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - KW_idx_a = CRS_remainder - KH_idx_a * p.KW; + Cin_idx_a = CRS_idx_a / (KW * KH); + uint32_t CRS_remainder = CRS_idx_a % (KW * KH); + KH_idx_a = CRS_remainder / KW; + KW_idx_a = CRS_remainder % KW; } #else CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) - Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); - CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; - KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - KW_idx_a = CRS_remainder - KH_idx_a * p.KW; + Cin_idx_a = CRS_idx_a / (KW * KH); + CRS_remainder = CRS_idx_a % (KW * KH); + KH_idx_a = CRS_remainder / KW; + KW_idx_a = CRS_remainder % KW; #endif /* Load kernel to A_block: (BS_K x BS_CRS)*/ - for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { uint32_t B_ly = r_offset + Ar; uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ @@ -262,27 +265,27 @@ void main() { KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); } else { CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ - Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); - uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; - KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - KW_idx_b = CRS_remainder - KH_idx_b * p.KW; + Cin_idx_b = CRS_idx_b / (KW * KH); + uint32_t CRS_remainder = CRS_idx_b % (KW * KH); + KH_idx_b = CRS_remainder / KW; + KW_idx_b = CRS_remainder % KW; } #else CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ - Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); - uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; - KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - KW_idx_b = CRS_remainder - KH_idx_b * p.KW; + Cin_idx_b = CRS_idx_b / (KW * KH); + uint32_t CRS_remainder = CRS_idx_b % (KW * KH); + KH_idx_b = CRS_remainder / KW; + KW_idx_b = CRS_remainder % KW; #endif #ifdef TRANSPOSE - uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1; - uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0; - uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L); - uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L); + uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * d1 + p1; + uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * d0 + p0; + uint32_t H_idx = H_idx_x_s1 / s1; + uint32_t W_idx = W_idx_x_s0 / s0; #else - uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1; - uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0; + uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1; + uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0; #endif uint32_t src_idx = min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), 
p.Cin * p.N * p.W * p.H - 1); @@ -290,7 +293,7 @@ void main() { if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case) #ifdef TRANSPOSE - || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0) + || (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0) #endif ) { val = 0.0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index 99595fc688c..c1ad5172562 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -3,6 +3,9 @@ #include "rte.glsl" #include "utils.glsl" +#if RMS_NORM_ROPE_FUSION +#include "rope_params.glsl" +#endif layout (push_constant) uniform parameter { @@ -12,11 +15,16 @@ layout (push_constant) uniform parameter uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; uint misalign_offsets; float param1; float param2; int param3; +#if RMS_NORM_ROPE_FUSION + rope_params rope; +#endif } p; +#if !RMS_NORM_ROPE_FUSION layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#endif // true if src0/src1 are the same shape and the indices can be reused without additional modulus layout(constant_id = 0) const bool norepeat = false; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index d260969f07e..5c5251da39b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -100,7 +100,6 @@ layout (push_constant) uniform parameter layout (constant_id = 0) const uint BLOCK_SIZE = 64; layout (constant_id = 1) const uint BM = 64; layout (constant_id = 2) const uint BN = 64; -layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant layout (constant_id = 4) const uint WM = 32; layout (constant_id = 5) const uint WN = 32; layout (constant_id = 6) const uint WMITER = 2; @@ -109,6 +108,14 @@ layout (constant_id = 8) const uint TN = 2; layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat layout (constant_id = 10) const uint WARP = 32; +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#define BK 32 +#define BK_STEP 4 +#else +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +#define BK_STEP 2 +#endif + #ifdef COOPMAT #define SHMEM_STRIDE (BK / 2 + 4) #else @@ -244,8 +251,13 @@ void main() { } #else ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; +#if defined(DATA_A_F32) || defined(DATA_A_F16) + FLOAT_TYPE_VEC4 cache_a[WMITER * TM]; + FLOAT_TYPE_VEC4 cache_b; +#else FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; FLOAT_TYPE_VEC2 cache_b; +#endif [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); @@ -283,24 +295,41 @@ void main() { } } #else - [[unroll]] for (uint i = 0; i < BK / 2; i++) { + [[unroll]] for (uint i = 0; i < BK / BK_STEP; i++) { // Load from shared into cache [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint j = 0; j < TM; j++) { + #if defined(DATA_A_F32) || defined(DATA_A_F16) + cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i ]; + cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 
* i + 1]; + #else cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; + #endif } } [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) { + #if defined(DATA_A_F32) || defined(DATA_A_F16) + cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i ]; + cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1]; + #else cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i]; + #endif [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; + #if defined(DATA_A_F32) || defined(DATA_A_F16) + sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), + fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x)))); + sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), + fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y)))); + #else sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); + #endif } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index d5b211ffaa7..3a47949d5a6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -3,6 +3,32 @@ #include "generic_binary_head.glsl" #include "types.glsl" +#if RMS_NORM_ROPE_FUSION + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; + +// data is passed from rms_norm -> rope through shared memory. +// rms_norm calls this data_d, rope calls this rope_data_a. 
+// Binding 2 is not used +shared FLOAT_TYPE rope_data_a[1024]; +#define data_d rope_data_a + +layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];}; +layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];}; +layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];}; +layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows + +#include "rope_params.glsl" +#include "rope_funcs.glsl" + +#define GGML_ROPE_TYPE_NORMAL 0 +#define GGML_ROPE_TYPE_NEOX 2 +#define GGML_ROPE_TYPE_MROPE 8 +#define GGML_ROPE_TYPE_VISION 24 + +#endif + #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 @@ -28,8 +54,12 @@ void rms_norm(uint num_iters) { uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); +#if RMS_NORM_ROPE_FUSION + // Per-row offset in shared memory + uint32_t d_offset = 0; +#else uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); - +#endif FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { @@ -79,6 +109,18 @@ void rms_norm(uint num_iters) { data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); } } +#if RMS_NORM_ROPE_FUSION + barrier(); + rope_params rp = p.rope; + uint rope_row = (samp*nchannels + channel)*nrows + row; + for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) { + if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) { + rope_neox(t, rope_row, rp); + } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) { + rope_norm(t, rope_row, rp); + } + } +#endif } void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl new file mode 100644 index 00000000000..9726b722d1e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -0,0 +1,227 @@ + +float rope_yarn_ramp(const float low, const float high, const uint i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) { +#if RMS_NORM_ROPE_FUSION + // Per-row offset in shared memory + const uint ix = i0; +#else + const uint ix = i02*p.nb02 + i01*p.nb01 + i0; +#endif + return ix; +} + +void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) { + float mscale = p.attn_factor; + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = p.freq_scale * theta_extrap; + float theta = theta_interp; + if (p.ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); + } + // Backprogagation uses inverted rotation + if (p.is_back != 0) { + theta = -theta; + } + cos_theta = cos(theta) * mscale; + sin_theta = sin(theta) * mscale; +} + +void rope_norm(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + uint idst = i1*ne0 + i0; + const uint ix = rope_a_coord(i0, i01, i02, p); + + // Fusion optimization: 
ROPE + VIEW + SET_ROWS.. + // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. + if (p.set_rows_stride != 0) { + idst = i01*ne0 + i0; + idst += rope_data_i[i02].x * p.set_rows_stride; + } + + if (i0 >= p.n_dims) { + rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]); + rope_data_d[idst + 1] = ROPE_D_TYPE(rope_data_a[ix + 1]); + + return; + } + + const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + 1]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + +void rope_neox(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + uint idst = i1*ne0 + i0/2; + const uint ix = rope_a_coord(i0/2, i01, i02, p); + + // Fusion optimization: ROPE + VIEW + SET_ROWS.. + // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. + if (p.set_rows_stride != 0) { + idst = i01*ne0 + i0/2; + idst += rope_data_i[i02].x * p.set_rows_stride; + } + + if (i0 >= p.n_dims) { + rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); + rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]); + + return; + } + + const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + p.n_dims/2]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + + +void rope_multi(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + const uint idst = i1*ne0 + i0/2; + const uint ix = rope_a_coord(i0/2, i01, i02, p); + + if (i0 >= p.n_dims) { + rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); + rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]); + + return; + } + + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (p.is_imrope != 0) { + if (sector % 3 == 1 && sector < 3 * p.sections[1]) { + theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { + theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { + theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + } else { + theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + } else { + if (sector < p.sections[0]) { + theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = rope_data_pos[i02 + ne2 
* 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + } + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + p.n_dims/2]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + +void rope_vision(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + const uint idst = i1*ne0 + i0/2; + const uint ix = rope_a_coord(i0/2, i01, i02, p); + + const int sect_dims = p.sections[0] + p.sections[1]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + const uint p0 = sector; + theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0); + } + else if (sector >= p.sections[0] && sector < sec_w) { + const uint p0 = sector - p.sections[0]; + theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0); + } + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + p.n_dims]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + p.n_dims] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index fa2bb33394c..d9b4d4c03f3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -3,56 +3,18 @@ #extension GL_EXT_shader_16bit_storage : require #include "rte.glsl" +#include "rope_params.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {int data_pos[];}; -layout (binding = 2) readonly buffer Z {float data_ff[];}; -layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; -layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows +layout (binding = 0) readonly buffer X {A_TYPE rope_data_a[];}; +layout (binding = 1) readonly buffer Y {int rope_data_pos[];}; +layout (binding = 2) readonly buffer Z {float rope_data_ff[];}; +layout (binding = 3) writeonly buffer D {ROPE_D_TYPE rope_data_d[];}; +layout (binding = 4) readonly buffer I {uvec2 rope_data_i[];}; // indices for set_rows -layout (push_constant) uniform parameter { - uint ncols; - uint n_dims; - float freq_scale; - uint p_delta_rows; - float freq_base; - float ext_factor; - float attn_factor; - float corr_dims[2]; - float theta_scale; - uint has_ff; - uint ne02; - uint s1; - uint s2; - int sections[4]; - uint is_imrope; - uint is_back; - uint set_rows_stride; -} p; - -float rope_yarn_ramp(const float low, const float high, const uint i0) { - const float y = (i0 / 2 - low) / max(0.001f, high 
- low); - return 1.0f - min(1.0f, max(0.0f, y)); -} -void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { - float mscale = p.attn_factor; - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = p.freq_scale * theta_extrap; - float theta = theta_interp; - if (p.ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; +layout (push_constant) uniform parameter { + rope_params pc; +}; - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); - } - // Backprogagation uses inverted rotation - if (p.is_back != 0) { - theta = -theta; - } - cos_theta = cos(theta) * mscale; - sin_theta = sin(theta) * mscale; -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 54aabcf2228..7c1fb1cd224 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -1,70 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - const uint idst = row_dst*ne0 + i0/2; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - - if (i0 >= p.n_dims) { - data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; - data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; - - return; - } - - const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; - const int sec_w = p.sections[1] + p.sections[0]; - const uint sector = (i0 / 2) % sect_dims; - - float theta_base = 0.0; - if (p.is_imrope != 0) { - if (sector % 3 == 1 && sector < 3 * p.sections[1]) { - theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } else { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); - } - } else { - if (sector < p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= p.sections[0] && sector < sec_w) { - theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); - } - } - - const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + p.n_dims/2]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_multi(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 9f4538155a0..68f00c180bb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -1,48 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - uint idst = row_dst*ne0 + i0/2; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - - // Fusion optimization: ROPE + VIEW + SET_ROWS.. - // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. - if (p.set_rows_stride != 0) { - idst = row_x*ne0 + i0/2; - idst += data_i[channel_x].x * p.set_rows_stride; - } - - if (i0 >= p.n_dims) { - data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]); - data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]); - - return; - } - - const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - - const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + p.n_dims/2]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_neox(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index f4209ed9582..28a939ec6ad 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -1,48 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - uint idst = row_dst*ne0 + i0; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; - - // Fusion optimization: ROPE + VIEW + SET_ROWS.. - // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. - if (p.set_rows_stride != 0) { - idst = row_x*ne0 + i0; - idst += data_i[channel_x].x * p.set_rows_stride; - } - - if (i0 >= p.n_dims) { - data_d[idst + 0] = D_TYPE(data_a[ix + 0]); - data_d[idst + 1] = D_TYPE(data_a[ix + 1]); - - return; - } - - const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - - const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + 1]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_norm(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl new file mode 100644 index 00000000000..82f39cee349 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -0,0 +1,27 @@ +#if !defined(GGML_ROPE_PARAMS) +#define GGML_ROPE_PARAMS + +#include "rte.glsl" + +struct rope_params { + uint rope_mode; + uint ncols; + uint n_dims; + float freq_scale; + uint p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint has_ff; + uint ne02; + uint nb01; + uint nb02; + int sections[4]; + uint is_imrope; + uint is_back; + uint set_rows_stride; +}; + +#endif // !defined(GGML_ROPE_PARAMS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index d37d1c1043f..ea1e0fdb416 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -1,47 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - const uint idst = row_dst*ne0 + i0/2; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - - const int sect_dims = p.sections[0] + p.sections[1]; - const int sec_w = p.sections[1] + p.sections[0]; - const uint sector = (i0 / 2) % sect_dims; - - float theta_base = 0.0; - if (sector < p.sections[0]) { - const uint p0 = sector; - theta_base = data_pos[channel_x]*pow(p.theta_scale, p0); - } - else if (sector >= p.sections[0] && sector < sec_w) { - const uint p0 = sector - p.sections[0]; - theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0); - } - - const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + p.n_dims]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_vision(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bd178875d55..c2e42cf006e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -695,6 +695,8 @@ void process_shaders() { string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}})); + string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -840,25 +842,25 @@ void process_shaders() { string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - - string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - - string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", 
"float16_t"}}); - string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - - string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); + string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); + string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); + string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); + string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1a157567315..9e8cbc477ed 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -73,6 +74,30 @@ // For operations which process a row in parallel, this seems like a reasonable default #define WEBGPU_ROW_SPLIT_WG_SIZE 64 +// Matrix multiplication parameters + +// Register tiling parameters +#define WEBGPU_MUL_MAT_TILE_M 8 +#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_TILE_K 32 + +// Subgroup matrix parameters +// The number of subgroups in the M dimension +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +// The number of subgroups in the N dimension +#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +// The number of subgroup matrices each subgroup accumulates over +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 + +// Matrix-vector multiplication parameters +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +// Must be multiple of 4 
to work with vectorized paths, and must divide mul_mat_vec wg size +#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_TILE_K 256 + /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. @@ -236,6 +261,10 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; + bool supports_subgroup_matrix = false; + uint32_t subgroup_size; + wgpu::SubgroupMatrixConfig subgroup_matrix_config; + // Separate this out from limits since on some Metal systems, the limit returned by // querying the limits is higher than the actual allowed maximum. uint32_t max_wg_size_x; @@ -247,6 +276,11 @@ struct webgpu_context_struct { webgpu_buf_pool set_rows_error_buf_pool; webgpu_pipeline memset_pipeline; + + std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized + std::map>> + mul_mat_vec_pipelines; // src0_type, src1_type, vectorized + webgpu_pipeline mul_mat_pipeline[30][2]; webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized webgpu_pipeline get_rows_pipeline[30]; @@ -321,6 +355,25 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ +// Process a WGSL shader string, replacing tokens of the form {{KEY}} with +// the corresponding values provided in `repls`. +static std::string ggml_webgpu_process_shader_repls(const char * src, + const std::map & repls) { + if (!src) { + return std::string(); + } + std::string s = src; + for (const auto & kv : repls) { + std::string token = "{{" + kv.first + "}}"; + size_t pos = 0; + while ((pos = s.find(token, pos)) != std::string::npos) { + s.replace(pos, token.length(), kv.second); + pos += kv.second.length(); + } + } + return s; +} + static void ggml_webgpu_create_pipeline(wgpu::Device & device, webgpu_pipeline & pipeline, const char * shader_code, @@ -346,6 +399,30 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } +static webgpu_pipeline ggml_webgpu_create_pipeline2(wgpu::Device & device, + const char * shader_code, + const char * label, + const std::vector & constants = {}) { + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code; + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label; + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout + if (constants.size() > 0) { + pipeline_desc.compute.constants = constants.data(); + pipeline_desc.compute.constantCount = constants.size(); + } + return { device.CreateComputePipeline(&pipeline_desc), label }; +} + static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::Buffer & buffer, size_t size, @@ -512,6 +589,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, + uint32_t wg_y = 1, std::optional set_rows_error_bufs = std::nullopt) { webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); @@ -557,7 +635,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context & #endif pass.SetPipeline(pipeline.pipeline); pass.SetBindGroup(0, bind_group); - pass.DispatchWorkgroups(wg_x, 1, 1); + pass.DispatchWorkgroups(wg_x, wg_y, 1); pass.End(); #ifdef 
GGML_WEBGPU_GPU_PROFILE @@ -779,7 +857,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size; - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs); } static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, @@ -835,8 +913,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) dst->ne[1], // number of rows in result (M) - (uint32_t) dst->ne[0], // number of columns in result (N) + (uint32_t) dst->ne[0], // number of rows in result (M, transposed) + (uint32_t) dst->ne[1], // number of columns in result (N) (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 @@ -865,9 +943,67 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; + webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type]; + uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); + uint32_t wg_y = 1; + + bool use_fast = false; + switch (src1->type) { + case GGML_TYPE_F16: + use_fast = (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + use_fast = true; + break; + default: + break; + } + break; + default: + break; + } + + if (use_fast) { + int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; + if (dst->ne[1] == 1) { + // We don't support vectorized mul_mat_vec for quantized types + vectorized = vectorized && (src0->type < 2); + pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = + (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; + wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / + ctx->limits.maxComputeWorkgroupsPerDimension; + } else { + pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; + uint32_t wg_m; + uint32_t wg_n; + if (ctx->supports_subgroup_matrix) { + // The total number of subgroups/workgroups needed per matrix. 
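+                // Worked example, assuming the defaults above and an 8x8 f16 subgroup matrix:
+                //   rows per workgroup    = SUBGROUP_M * SUBGROUP_MATRIX_M * config.M = 2 * 4 * 8 = 64
+                //   columns per workgroup = SUBGROUP_N * SUBGROUP_MATRIX_N * config.N = 2 * 2 * 8 = 32
+                // so a hypothetical 512 x 512 dst needs ceil(512/64) * ceil(512/32) = 8 * 16 = 128
+                // workgroups per matrix in the batch.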
+ uint32_t wg_m_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; + wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile; + uint32_t wg_n_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; + wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; + } else { + uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; + uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; + wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; + wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; + } + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + } + } + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, @@ -1583,12 +1719,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], - wgsl_mul_mat_f32_f32, "mul_mat_f32_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], - wgsl_mul_mat_f16_f16, "mul_mat_f16_f16"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], - wgsl_mul_mat_f16_f32, "mul_mat_f16_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32], wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32], @@ -1627,6 +1757,136 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + + if (webgpu_ctx->supports_subgroup_matrix) { + std::map sg_matrix_repls; + sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); + sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); + sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); + sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); + sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); + sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); + sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M); + sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); + sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K); + + std::string proc_mul_mat_subgroup_matrix_f32_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f32_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); + std::string 
proc_mul_mat_subgroup_matrix_f16_f16 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_q4_0_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), + "mul_mat_subgroup_matrix_f32_f32_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), + "mul_mat_subgroup_matrix_f16_f32_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), + "mul_mat_subgroup_matrix_f16_f16_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), + "mul_mat_subgroup_matrix_q4_0_f32_vec"); + } else { + std::vector mul_mat_reg_tile_constants(3); + mul_mat_reg_tile_constants[0].key = "TILE_K"; + mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; + mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; + mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; + mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; + mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; + + std::map reg_repls; + reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); + reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); + + // Process each reg-tile shader with tile replacements. + // Keep the processed strings in-scope so .c_str() remains valid. 
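+        // TILE_M/TILE_N are baked in via {{...}} replacement rather than passed as pipeline
+        // overrides, presumably because they size per-thread arrays in the shader and WGSL only
+        // allows override-sized array lengths for workgroup storage.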
+ std::string proc_mul_mat_reg_tile_f32_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); + std::string proc_mul_mat_reg_tile_f32_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f16 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f16_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + std::string proc_mul_mat_reg_tile_q4_0_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); + std::string proc_mul_mat_reg_tile_q4_0_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), + "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), + "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), + "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), + "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), + "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), + "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), + "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), + "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); + } + + std::vector mul_mat_vec_constants(3); + mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; + mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE; + mul_mat_vec_constants[1].key = "TILE_K"; + mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K; + mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; + mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, 
"mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { @@ -2124,7 +2384,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; - wgpu::RequestAdapterOptions options = {}; + // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 + const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; + wgpu::DawnTogglesDescriptor adapterTogglesDesc; + adapterTogglesDesc.enabledToggles = adapterEnabledToggles; + adapterTogglesDesc.enabledToggleCount = 2; + wgpu::RequestAdapterOptions options = {}; + options.nextInChain = &adapterTogglesDesc; ctx->instance.WaitAny(ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { @@ -2140,12 +2406,46 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetLimits(&ctx->limits); ctx->max_wg_size_x = 288; // default value - wgpu::AdapterInfo info{}; + wgpu::AdapterInfo info{}; + wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; + if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + info.nextInChain = &subgroup_matrix_configs; + } ctx->adapter.GetInfo(&info); + wgpu::SupportedFeatures features; + ctx->adapter.GetFeatures(&features); + // we require f16 support + GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + + // Only support square f16 matrices of size 8 or 16 for now + bool valid_subgroup_matrix_config = false; + if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && + config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { + ctx->subgroup_matrix_config = config; + valid_subgroup_matrix_config = true; + break; + } + } + } + + // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. + // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. 
+ ctx->subgroup_size = info.subgroupMaxSize; + ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; + // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization }; + if (ctx->supports_subgroup_matrix) { + required_features.push_back(wgpu::FeatureName::Subgroups); + required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + } + #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 251051eaeca..ed8068d416e 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -72,9 +72,12 @@ def generate_variants(fname, input_dir, output_dir, outfile): except ValueError: decls_map = {} - with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f: - common_decls = f.read() - decls_map.update(parse_decls(common_decls)) + for fname in sorted(os.listdir(input_dir)): + if fname.endswith(".tmpl"): + tmpl_path = os.path.join(input_dir, fname) + with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: + decls = f_tmpl.read() + decls_map.update(parse_decls(decls)) shader_template = extract_block(text, "SHADER") for variant in variants: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl index 141db9b39d9..0f8e6e5ac3d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -864,8 +864,8 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) @group(0) @binding(2) var dst: array; // M rows, N columns @group(0) @binding(3) var params: MulMatParams; @@ -891,8 +891,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { let dst2_rem = dst3_rem % dst2_stride; - let row = dst2_rem / params.n; // output row - let col = dst2_rem % params.n; // output column + let row = dst2_rem / params.m; // output row + let col = dst2_rem % params.m; // output column let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; @@ -901,7 +901,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { sum += multiply_add(src0_idx_base, src1_idx_base, i); } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum; + dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; } #end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl new file mode 100644 index 00000000000..109ff8d6159 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -0,0 +1,97 @@ +#decl(SHMEM_VEC) +fn store_shmem(val: vec4, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 
3] = val.w; +} +#enddecl(SHMEM_VEC) + +#decl(SHMEM_SCALAR) +fn store_shmem(val: f16, idx: u32) { + shmem[idx] = val; +} +#enddecl(SHMEM_SCALAR) + +#decl(INIT_SRC0_SHMEM_FLOAT) + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let src0_val = select( // taking a slight performance hit to avoid oob + {{SRC0_TYPE}}(0.0), + src0[src0_idx/{{VEC_SIZE}}], + global_m < params.m && global_k < params.k); + store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); + } +} + +#enddecl(INIT_SRC0_SHMEM_FLOAT) + +#decl(INIT_SRC1_SHMEM) + +fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_n = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_n = offset_n + tile_n; + let global_k = k_outer + tile_k; + let src1_idx = batch_offset + global_n * params.stride_11 + global_k; + let src1_val = select( + {{SRC1_TYPE}}(0.0), + src1[src1_idx/{{VEC_SIZE}}], + global_n < params.n && global_k < params.k); + store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); + } +} + +#enddecl(INIT_SRC1_SHMEM) + +#decl(INIT_SRC0_SHMEM_Q4_0) + +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
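+// With the TILE_K of 32 used by this backend, BLOCKS_K is 1, i.e. one q4_0 block (32 weights)
+// per row of the k-tile.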
+override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1u + block_offset + j]; + let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + shmem[shmem_idx + j * 2 + k] = q_lo; + shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + } + } + } + } +} + +#enddecl(INIT_SRC0_SHMEM_Q4_0) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl new file mode 100644 index 00000000000..6b1dd26cd9e --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -0,0 +1,247 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32_vec", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" 
: "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); +} +#enddecl(VEC) + +#decl(SCALAR) +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return f32(acc[tm][tn]); +} +#enddecl(SCALAR) + +#end(DECLS) + +#define(SHADER) +enable f16; + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +DECLS + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +// TILE_M must be multiple of 4 for vec4 loads +const TILE_M = {{WEBGPU_TILE_M}}u; +const TILE_N = {{WEBGPU_TILE_N}}u; + +override WORKGROUP_SIZE_M: u32; +override WORKGROUP_SIZE_N: u32; +override TILE_K: u32; + +override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + +var shmem: array; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_per_matrix = wg_m_count * wg_n_count; + + let batch_idx = wg_id.x / wg_per_matrix; + + let wg_in_batch = wg_id.x % wg_per_matrix; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; + let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; + + var acc: array, TILE_M>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // see mul_mat_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + 
init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); + + workgroupBarrier(); + + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = shmem[src0_idx]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tm][tn] += src0_tile[tm] * src1_val; + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + for (var tn = 0u; tn < TILE_N; tn++) { + let global_col = output_col_base + tn; + if (global_col < params.n) { + for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; + dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); + } + } + } + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl new file mode 100644 index 00000000000..47c8ce36ab3 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -0,0 +1,302 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32_vec", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + } +] + 
+#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = vec4( + f32(shmem[shmem_idx]), + f32(shmem[shmem_idx + 1]), + f32(shmem[shmem_idx + 2]), + f32(shmem[shmem_idx + 3]) + ); +} +#enddecl(VEC) + +#decl(SCALAR) +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = f32(shmem[shmem_idx]); +} +#enddecl(SCALAR) + +#end(DECLS) + +#define(SHADER) +diagnostic(off, chromium.subgroup_matrix_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +DECLS + +// Note: These are string interpolated at build time, cannot use override constants due to limitations in +// current Dawn version type definitions/matrix load requirements for constant memory sizes. +const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; +const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; +// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the +// runtime subgroup size is smaller. +const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u; + +const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; + +const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; +const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; +const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u; + +const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u; +const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u; + +const TILE_K = {{WEBGPU_TILE_K}}u; + +const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + +const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; +const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + +const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE; + +// We reuse shmem for accumulation matrices +const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM); + +var shmem: array; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32) { + + let thread_id = local_id.x; + let subgroup_m = subgroup_id % SUBGROUP_M; + let subgroup_n = subgroup_id / SUBGROUP_M; + + let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE; + let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; + let wg_per_matrix = wg_m_count * wg_n_count; + + let batch_idx = wg_id.x / wg_per_matrix; + + let wg_in_batch = wg_id.x % wg_per_matrix; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = 
batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + + var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // see mul_mat_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); + + workgroupBarrier(); + + if (subgroup_id < EXPECTED_SUBGROUPS) { + + for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { + + let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; + var src0_sg_mats: array, SUBGROUP_MATRIX_M>; + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + src0_sg_mats[m] = subgroupMatrixLoad>( + &shmem, + src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, + false, + TILE_K + ); + } + + let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + let src1_sg_mat = subgroupMatrixLoad>( + &shmem, + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, + true, + TILE_K + ); + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]); + } + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + // Stage the subgroup matrix tiles into shared memory + // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile). 
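+    // Note: staging the per-subgroup accumulators here lets the cooperative write below treat the
+    // whole workgroup tile as one contiguous region: each thread copies {{VEC_SIZE}} elements per
+    // iteration with a single bounds check against params.m/params.n, instead of each subgroup
+    // storing its matrices directly to global memory where edge tiles may be partially out of range.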
+ let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE; + let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + + if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :( + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE; + let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE; + let out_base = local_row * WG_TILE_STRIDE + local_col; + subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); + } + } + } + + workgroupBarrier(); + + // Cooperative write: iterate over the entire workgroup tile + let tile_rows = WG_N_SG_TILE_SIZE; + let tile_cols = WG_M_SG_TILE_SIZE; + let total_tile_elems = tile_rows * tile_cols; + let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + + for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let local_row = idx % WG_TILE_STRIDE; + let local_col = idx / WG_TILE_STRIDE; + + let global_row = tile_dst_row_base + local_row; + let global_col = tile_dst_col_base + local_col; + + if (global_col < params.n && global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; + store_dst(idx, dst_idx/{{VEC_SIZE}}); + } + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl new file mode 100644 index 00000000000..ffbb6403285 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl @@ -0,0 +1,267 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { + return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); +} + +fn store_val(group_base: u32) -> vec4 { + return vec4(partial_sums[group_base], + partial_sums[group_base + THREADS_PER_OUTPUT], + partial_sums[group_base + THREADS_PER_OUTPUT * 2], + partial_sums[group_base + THREADS_PER_OUTPUT 
* 3]); +} +#enddecl(VEC) + +#decl(SCALAR) +fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { + return f32(src0_val) * f32(src1_val); +} + +fn store_val(group_base: u32) -> f32 { + return partial_sums[group_base]; +} +#enddecl(SCALAR) + +#decl(MUL_ACC_FLOAT) + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { + let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; + let b = shared_vector[i / {{VEC_SIZE}}]; + local_sum += inner_dot(a, b); + } + return local_sum; +} + +#enddecl(MUL_ACC_FLOAT) + +#decl(MUL_ACC_Q4_0) + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1 + block_offset + j]; + let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0) * d; + local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + } + } + } + return local_sum; +} + +#enddecl(MUL_ACC_Q4_0) + +#end(DECLS) + +#define(SHADER) +enable f16; + +DECLS + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // Result vector (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +override WORKGROUP_SIZE: u32; +override TILE_K: u32; +override OUTPUTS_PER_WG: u32; +override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG; + +// Shared memory for collaborative loading and reduction +var shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile +var partial_sums: array; // For reduction + +@compute @workgroup_size(WORKGROUP_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let thread_id = local_id.x; + + // Handle batch dimensions + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + // Which of the 
outputs does this thread belong to? + let thread_group = thread_id / THREADS_PER_OUTPUT; + let thread_in_group = thread_id % THREADS_PER_OUTPUT; + + // Each workgroup computes OUTPUTS_PER_WG consecutive outputs + let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; + + var local_sum = 0.0; + + // Each thread processes multiple K elements and accumulates + for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { + let tile_size = min(TILE_K, params.k - k_tile); + + // Cooperatively load vector tile into shared memory (all threads) + for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { + shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; + } + + workgroupBarrier(); + + if (output_row < params.m) { + local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); + } + + workgroupBarrier(); + } + + // Store partial sums and reduce within each partition + partial_sums[thread_id] = local_sum; + workgroupBarrier(); + let group_base = thread_group * THREADS_PER_OUTPUT; + let thread_base = group_base + thread_in_group; + var offset = THREADS_PER_OUTPUT / 2; + while (offset > 0) { + if (thread_in_group < offset) { + partial_sums[thread_base] += partial_sums[thread_base + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + + // Store back to global memory + if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { + dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); + } +} +#end(SHADER) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b9ae82eeddd..2470c148d66 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2294,6 +2294,79 @@ struct test_rope_set_rows : public test_case { } }; +// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ROPE (+ GGML_OP_VIEW + GGML_OP_SET_ROWS) +struct test_rms_norm_mul_rope : public test_case { + const std::array ne; + const float eps; + const bool multi_add; // test a sequence of adds feeding into rms_norm + const bool set_rows; + int mode; + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "RMS_NORM_MUL_ROPE"; + } + + bool run_whole_graph() override { return true; } + + std::string vars() override { + return VARS_TO_STR5(ne, eps, multi_add, set_rows, mode); + } + + test_rms_norm_mul_rope(std::array ne, float eps = 1e-6f, bool multi_add = false, + bool set_rows = false, int mode = GGML_ROPE_TYPE_NORMAL) + : ne(ne), eps(eps), multi_add(multi_add), set_rows(set_rows), mode(mode) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1); + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1); + ggml_tensor * c = 
ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1); + + if (multi_add) { + a = ggml_add(ctx, ggml_add(ctx, a, b), c); + } + + a = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b); + + ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]); + + ggml_tensor * rope = ggml_rope(ctx, a, pos, ne[0], mode); + + ggml_tensor * out; + + if (set_rows) { + ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0); + + ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne[0] * ne[1], ne[2] * ne[3], 1, 1); + ggml_set_name(dst, "dst"); + + ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, ne[2], 1, 1); + ggml_set_name(row_idxs, "row_idxs"); + + out = ggml_set_rows(ctx, dst, view, row_idxs); + ggml_set_name(out, "out"); + } else { + out = rope; + } + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { + continue; + } + + init_set_rows_row_ids(t, ne[2]); + } else { + init_tensor_uniform(t); + } + } + } +}; + // GGML_OP_ARGMAX struct test_argmax : public test_case { const ggml_type type; @@ -4809,60 +4882,6 @@ struct test_topk_moe: public test_case { } }; -struct test_moe_expert_reduce : public test_case { - const int64_t n_embd; - const int64_t n_tokens; - const int64_t n_expert_used; - - test_moe_expert_reduce(int64_t n_embd = 64, int64_t n_tokens = 5, int64_t n_expert_used = 4) - : n_embd(n_embd), n_tokens(n_tokens), n_expert_used(n_expert_used) { - GGML_ASSERT(n_expert_used > 1); - } - - std::string vars() override { - return VARS_TO_STR3(n_embd, n_tokens, n_expert_used); - } - - std::string op_desc(ggml_tensor * t) override { - GGML_UNUSED(t); - return "MOE_EXPERT_REDUCE"; - } - - bool run_whole_graph() override { return true; } - - ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * experts = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens); - ggml_set_name(experts, "experts"); - - ggml_tensor * weights = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, n_expert_used, n_tokens); - ggml_set_name(weights, "weights"); - - ggml_tensor * weighted = ggml_mul(ctx, experts, weights); - ggml_set_name(weighted, "weighted_experts"); - - std::vector expert_views(n_expert_used); - for (int64_t i = 0; i < n_expert_used; ++i) { - expert_views[i] = ggml_view_2d(ctx, weighted, n_embd, n_tokens, weighted->nb[2], i * weighted->nb[1]); - - std::string name = "expert_view_" + std::to_string(i); - ggml_set_name(expert_views[i], name.c_str()); - ggml_build_forward_expand(gf, expert_views[i]); - } - - ggml_tensor * moe_out = expert_views[0]; - for (int64_t i = 1; i < n_expert_used; ++i) { - moe_out = ggml_add(ctx, moe_out, expert_views[i]); - - std::string name = "expert_add_" + std::to_string(i - 1); - ggml_set_name(moe_out, name.c_str()); - } - - ggml_set_name(moe_out, "moe_out"); - - return moe_out; - } -}; - struct test_mul_mat_vec_fusion : public test_case { const ggml_type type; const ggml_glu_op glu_op; @@ -4911,8 +4930,10 @@ struct test_mul_mat_vec_fusion : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { if (!use_id) { - std::array ne = {k, m, 1, 1}; - std::array ne0 = {k, n, 1, 1}; + const int channels = 4; + const int samples = 2; + std::array ne = { k, m, channels, samples }; + std::array ne0 = { k, n, channels, samples }; ggml_tensor * cur = ggml_new_tensor(ctx, 
GGML_TYPE_F32, 4, ne.data()); ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr; @@ -4920,14 +4941,14 @@ struct test_mul_mat_vec_fusion : public test_case { ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur); if (with_bias) { - std::array bias_ne = {ffn_up->ne[0], 1, 1, 1}; + std::array bias_ne = { ffn_up->ne[0], 1, channels, samples }; ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data()); ffn_up = ggml_add(ctx, ffn_up, up_bias); } ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr; if (with_bias && with_gate) { - std::array bias_ne = {ffn_gate->ne[0], 1, 1, 1}; + std::array bias_ne = { ffn_gate->ne[0], 1, channels, samples }; ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data()); ffn_gate = ggml_add(ctx, ffn_gate, gate_bias); } @@ -6751,6 +6772,22 @@ static std::vector> make_test_cases_eval() { } } + for (auto multi_add : {false, true}) { + for (auto set_rows : {false, true}) { + for (auto rope : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX}) { + test_cases.emplace_back(new test_rms_norm_mul_rope({768, 1, 1, 1}, 1e-6f, multi_add, set_rows, rope)); + test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 1, 1}, 1e-6f, multi_add, set_rows, rope)); + test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 5, 1}, 1e-6f, multi_add, set_rows, rope)); + test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 2, 1}, 1e-6f, multi_add, set_rows, rope)); + test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 2, 1}, 1e-6f, multi_add, set_rows, rope)); + test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 50, 1}, 1e-6f, multi_add, set_rows, rope)); + test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 50, 1}, 1e-6f, multi_add, set_rows, rope)); + test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope)); + test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope)); + } + } + } + test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f)); for (int64_t d_conv : {3, 4}) { @@ -7324,10 +7361,6 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true)); test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true)); - test_cases.emplace_back(new test_moe_expert_reduce(1024, 5, 4)); - test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 6)); - test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 7)); - #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging test_cases.emplace_back(new test_llama(2, true)); diff --git a/tools/server/README.md b/tools/server/README.md index 6828ef73824..8fd478eb328 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -512,7 +512,7 @@ These words will not be included in the completion, so make sure to add them to `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false` -`return_progress`: Include prompt processing progress in `stream` mode. The progress will be contained inside `prompt_progress` with 3 values: `total`, `cache` and `processed`. The overall progress is `processed/total`, while the actual timed progress is `(processed-cache)/(total-cache)`. Default: `false` +`return_progress`: Include prompt processing progress in `stream` mode. 
The progress will be contained inside `prompt_progress` with 4 values: `total`, `cache`, `processed`, and `time_ms`. The overall progress is `processed/total`, while the actual timed progress is `(processed-cache)/(total-cache)`. The `time_ms` field contains the elapsed time in milliseconds since prompt processing started. Default: `false` `post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain. diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index a796c255c18..976d6585da6 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 164e8cf4e70..9d91e32d1fb 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3078,7 +3078,7 @@ struct server_context { res->progress.total = slot.task->n_tokens(); res->progress.cache = slot.n_prompt_tokens_cache; res->progress.processed = slot.prompt.tokens.size(); - res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000); + res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; } else { res->content = tkn.text_to_send; res->tokens = { tkn.tok }; diff --git a/tools/server/webui/src/routes/+layout.svelte b/tools/server/webui/src/routes/+layout.svelte index 075bdd356bc..b08bd59c15e 100644 --- a/tools/server/webui/src/routes/+layout.svelte +++ b/tools/server/webui/src/routes/+layout.svelte @@ -44,12 +44,12 @@ } } - if (isCtrlOrCmd && event.shiftKey && event.key === 'o') { + if (isCtrlOrCmd && event.shiftKey && event.key === 'O') { event.preventDefault(); goto('?new_chat=true#/'); } - if (event.shiftKey && isCtrlOrCmd && event.key === 'e') { + if (event.shiftKey && isCtrlOrCmd && event.key === 'E') { event.preventDefault(); if (chatSidebar?.editActiveConversation) {