From 53af4dba425768fd507ce6b45873cfbb3df2295f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Tue, 25 Mar 2025 23:03:10 +0100 Subject: [PATCH 1/8] convert: fix Mistral3/Gemma3 model hparams init (#12571) * Fix Mistral3/Gemma3 model hparams init * set positional args correctly * use existing hparams if passed --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d9fa57027b771..76ab4233ef2c1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1752,7 +1752,7 @@ class Mistral3Model(LlamaModel): # we need to merge the text_config into the root level of hparams def __init__(self, *args, **kwargs): - hparams = Model.load_hparams(kwargs["dir_model"]) + hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0]) if "text_config" in hparams: hparams = {**hparams, **hparams["text_config"]} kwargs["hparams"] = hparams @@ -3385,7 +3385,7 @@ class Gemma3Model(Model): # we need to merge the text_config into the root level of hparams def __init__(self, *args, **kwargs): - hparams = Model.load_hparams(kwargs["dir_model"]) + hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0]) if "text_config" in hparams: hparams = {**hparams, **hparams["text_config"]} kwargs["hparams"] = hparams @@ -5358,7 +5358,7 @@ def main() -> None: logger.error(f"Model {model_architecture} is not supported") sys.exit(1) - model_instance = model_class(dir_model=dir_model, ftype=output_type, fname_out=fname_out, + model_instance = model_class(dir_model, output_type, fname_out, is_big_endian=args.bigendian, use_temp_file=args.use_temp_file, eager=args.no_lazy, metadata_override=args.metadata, model_name=args.model_name, From fd7855f8f522ab26681b25ae506ff99c654e2dd5 Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Wed, 26 Mar 2025 15:09:48 +0800 Subject: [PATCH 2/8] doc: [MUSA] minor changes (#12583) Signed-off-by: Xiaodong Ye --- ci/README.md | 2 +- docs/build.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/README.md b/ci/README.md index db4d9066816e8..2d0004b18c622 100644 --- a/ci/README.md +++ b/ci/README.md @@ -60,7 +60,7 @@ docker run --privileged -it \ Inside the container, execute the following commands: ```bash -apt update -y && apt install -y cmake git python3.10-venv wget +apt update -y && apt install -y bc cmake git python3.10-venv time unzip wget git config --global --add safe.directory /ws GG_BUILD_MUSA=1 bash ./ci/run.sh /ci-results /ci-cache ``` diff --git a/docs/build.md b/docs/build.md index fbf12c7664d50..aa1db9a047130 100644 --- a/docs/build.md +++ b/docs/build.md @@ -218,6 +218,7 @@ By default, all supported compute capabilities are enabled. To customize this be ```bash cmake -B build -DGGML_MUSA=ON -DMUSA_ARCHITECTURES="21" +cmake --build build --config Release ``` This configuration enables only compute capability `2.1` (MTT S80) during compilation, which can help reduce compilation time. From 5ed38b6852bd509d56acfdae54bceec8ab3cc396 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Mar 2025 13:02:00 +0200 Subject: [PATCH 3/8] ggml : fix MUL_MAT_ID repack with Q8_K (#12544) * ggml : fix MUL_MAT_ID repack with Q8_K ggml-ci * ggml : improve repack templates ggml-ci --- ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp | 178 ++++++++++++------------- 1 file changed, 87 insertions(+), 91 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index e852c8253b44e..74a31abb2d6fc 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -250,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +static void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -344,7 +344,7 @@ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRIC #endif } -static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -559,7 +559,7 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC #endif } -static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); const int nb = k / QK_K; @@ -811,7 +811,7 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on for (int j = 0; j < QK_K * 4; j++) { int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; src_offset += (j % blck_size_interleave); int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3); @@ -823,26 +823,25 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC #endif } -static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { +template +void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row); + +template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); - if (blck_size_interleave == 4) { - quantize_q8_0_4x4(x, vy, n_per_row); - } else if (blck_size_interleave == 8) { - quantize_q8_0_4x8(x, vy, n_per_row); - } else { - assert(false); - } + ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row); } -static void quantize_mat_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); - if (blck_size_interleave == 8) { - quantize_q8_K_4x8(x, vy, n_per_row); - } else { - assert(false); - } + ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); +} + +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -5276,52 +5275,50 @@ template <> int repack(struct ggml_tensor * t, const void * //} // gemv -template +template void gemv(int, float *, size_t, const void *, const void *, int, int); -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> -void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } // gemm -template +template void gemm(int, float *, size_t, const void *, const void *, int, int); -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> -void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -5335,32 +5332,32 @@ template op) { - case GGML_OP_MUL_MAT: - size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); - return true; - case GGML_OP_MUL_MAT_ID: - size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); - size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. - size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2]; - return true; - default: - // GGML_ABORT("fatal error"); - break; + case GGML_OP_MUL_MAT: + size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); + return true; + case GGML_OP_MUL_MAT_ID: + size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); + size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. + size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2]; + return true; + default: + // GGML_ABORT("fatal error"); + break; } return false; } bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { switch (op->op) { - case GGML_OP_MUL_MAT: - forward_mul_mat(params, op); - return true; - case GGML_OP_MUL_MAT_ID: - forward_mul_mat_id(params, op); - return true; - default: - // GGML_ABORT("fatal error"); - break; + case GGML_OP_MUL_MAT: + forward_mul_mat(params, op); + return true; + case GGML_OP_MUL_MAT_ID: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error"); + break; } return false; } @@ -5399,17 +5396,10 @@ template from_float; int64_t i11_processed = 0; - if(PARAM_TYPE == GGML_TYPE_Q8_K) { - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - quantize_mat_q8_K((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10, - INTER_SIZE); - } - } else { - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10, - INTER_SIZE); - } + for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { + ggml_quantize_mat_t((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10); } + i11_processed = ne11 - ne11 % 4; for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10); @@ -5422,22 +5412,24 @@ template = src0_end) { return; } // If there are more than three rows in src1, use gemm; otherwise, use gemv. if (ne11 > 3) { - gemm(ne00, (float *) ((char *) dst->data) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + gemm(ne00, + (float *) ((char *) dst->data) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); } for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) { - gemv(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata + (src1_col_stride * iter), 1, - src0_end - src0_start); + gemv(ne00, + (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata + (src1_col_stride * iter), 1, + src0_end - src0_start); } } @@ -5452,7 +5444,7 @@ template ith; const int nth = params->nth; - const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float; + const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(src0->type)); @@ -5474,7 +5466,7 @@ template ne[0]; // n_expert_used const int n_as = ne02; // n_expert - const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10); + const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; @@ -5486,12 +5478,13 @@ template wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) + n_as * ne12 * sizeof(mmid_row_mapping))); - auto wdata = (char *) params->wdata; - auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t)); - int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + auto * wdata = (char *) params->wdata; + auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t)); + auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] - // src1: float32 => block_q8_0 + // src1: float32 => param type for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = ith; i11 < ne11; i11 += nth) { from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11), @@ -5530,34 +5523,37 @@ template data + cur_a*nb02; + const auto * src0_cur = (const char *) src0->data + cur_a*nb02; //const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows int64_t src0_cur_start = (ith * ne01) / nth; int64_t src0_cur_end = ((ith + 1) * ne01) / nth; - src0_cur_start = - (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; - src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; - if (src0_cur_start >= src0_cur_end) return; + src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + if (src0_cur_start >= src0_cur_end) { + return; + } for (int ir1 = 0; ir1 < nr1; ir1++) { struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); - const int id = row_mapping.i1; // selected expert index - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row - auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2); + const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2); - gemv( - ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, - ne01, src0_cur + src0_cur_start * nb01, + gemv(ne00, + (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); } } @@ -5578,7 +5574,7 @@ static const tensor_traits q4_0_8x8_q8_0; static const tensor_traits q4_K_8x8_q8_K; // instance for IQ4 -static const tensor_traits iq4_nl_4x4_q8_0; +static const tensor_traits iq4_nl_4x4_q8_0; } // namespace ggml::cpu::aarch64 From df4d20cd53d5bb6fc21c1dc65f026d53b566d097 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Mar 2025 14:21:05 +0200 Subject: [PATCH 4/8] convert : fix squeeze for ssm_conv tensors (#12573) * convert : fix squeeze for ssm_conv tensors * convert : match ssm_conv tensors by type --------- Co-authored-by: Francis Couture-Harpin --- convert_hf_to_gguf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 76ab4233ef2c1..52637c42f723a 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3803,8 +3803,6 @@ def set_gguf_parameters(self): _tok_embd = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD) @@ -3814,6 +3812,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter logger.debug("A_log --> A ==> " + new_name) data_torch = -torch.exp(data_torch) + # [4 1 8192 1] -> [4 8192 1 1] + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + # assuming token_embd.weight is seen before output.weight if self._tok_embd is not None and new_name == output_name: if torch.equal(self._tok_embd, data_torch): From 02082f1519565fc7b49de211b28bc5404a69209b Mon Sep 17 00:00:00 2001 From: Ivy233 <952254420@qq.com> Date: Wed, 26 Mar 2025 22:06:04 +0800 Subject: [PATCH 5/8] clip: Fix llama-llava-clip-quantize-cli quantization error under CUDA backend (#12566) * [Fix] Compiling clip-quantize-cli and running it in a CUDA environment will cause ggml_fp16_to_fp32 to report an error when trying to access video memory. You need to switch to the CPU backend to run quantize. After the fix, it will automatically run in the CPU backend and will no longer be bound to CUDA. * [Fix]Roll back the signature and implementation of clip_model_load, and change the call in clip_model_quantize to clip_init. --- examples/llava/clip.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index a1f050e39a094..58ee5cf0174b2 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -2989,7 +2989,10 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i assert(itype < GGML_TYPE_COUNT); ggml_type type = static_cast(itype); - auto * ctx_clip = clip_model_load(fname_inp, 2); + auto * ctx_clip = clip_init(fname_inp, clip_context_params{ + /* use_gpu */ false, + /* verbosity */ 2, + }); const auto & ctx_src = ctx_clip->ctx_gguf; const auto & ctx_data = ctx_clip->ctx_data; From 2447ad8a981253a2b8e9f4b31cc8e7fdff83423e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Moskal?= Date: Wed, 26 Mar 2025 11:06:09 -0700 Subject: [PATCH 6/8] upgrade to llguidance 0.7.10 (#12576) --- common/CMakeLists.txt | 4 +- common/llguidance.cpp | 77 ++++++++++++------------------- tests/test-grammar-llguidance.cpp | 62 +++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 49 deletions(-) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 17146fffc1168..829eb5b7238b9 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -114,8 +114,8 @@ if (LLAMA_LLGUIDANCE) ExternalProject_Add(llguidance_ext GIT_REPOSITORY https://github.com/guidance-ai/llguidance - # v0.6.12: - GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09 + # v0.7.10: + GIT_TAG 0309d2a6bf40abda35344a362edc71e06d5009f8 PREFIX ${CMAKE_BINARY_DIR}/llguidance SOURCE_DIR ${LLGUIDANCE_SRC} BUILD_IN_SOURCE TRUE diff --git a/common/llguidance.cpp b/common/llguidance.cpp index 2feeb93c87e30..8bff89ea4aa9e 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -11,25 +11,24 @@ struct llama_sampler_llg { std::string grammar_kind; std::string grammar_data; LlgTokenizer * tokenizer; - LlgConstraint * grammar; - LlgMaskResult llg_res; - bool has_llg_res; + LlgMatcher * grammar; }; -static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind, - const char * grammar_data) { +static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind, + const char * grammar_data) { LlgConstraintInit cinit; llg_constraint_init_set_defaults(&cinit, tokenizer); const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL"); if (log_level && *log_level) { cinit.log_stderr_level = atoi(log_level); } - auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data); - if (llg_get_error(c)) { - LOG_ERR("llg error: %s\n", llg_get_error(c)); - llg_free_constraint(c); + auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data); + if (llg_matcher_get_error(c)) { + LOG_ERR("llg error: %s\n", llg_matcher_get_error(c)); + llg_free_matcher(c); return nullptr; } + return c; } @@ -40,39 +39,29 @@ static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) { static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) { auto * ctx = (llama_sampler_llg *) smpl->ctx; if (ctx->grammar) { - LlgCommitResult res; - llg_commit_token(ctx->grammar, token, &res); - ctx->has_llg_res = false; + llg_matcher_consume_token(ctx->grammar, token); } } static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_llg *) smpl->ctx; if (ctx->grammar) { - if (!ctx->has_llg_res) { - if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) { - ctx->has_llg_res = true; + const uint32_t * mask = llg_matcher_get_mask(ctx->grammar); + if (mask == nullptr) { + if (llg_matcher_compute_mask(ctx->grammar) == 0) { + mask = llg_matcher_get_mask(ctx->grammar); } else { - LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar)); - llg_free_constraint(ctx->grammar); + LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar)); + llg_free_matcher(ctx->grammar); ctx->grammar = nullptr; + return; } } - if (ctx->has_llg_res) { - if (ctx->llg_res.is_stop) { - for (size_t i = 0; i < cur_p->size; ++i) { - if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) { - cur_p->data[i].logit = -INFINITY; - } - } - } else { - const uint32_t * mask = ctx->llg_res.sample_mask; - for (size_t i = 0; i < cur_p->size; ++i) { - auto token = cur_p->data[i].id; - if ((mask[token / 32] & (1 << (token % 32))) == 0) { - cur_p->data[i].logit = -INFINITY; - } - } + + for (size_t i = 0; i < cur_p->size; ++i) { + auto token = cur_p->data[i].id; + if ((mask[token / 32] & (1 << (token % 32))) == 0) { + cur_p->data[i].logit = -INFINITY; } } } @@ -80,14 +69,9 @@ static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array static void llama_sampler_llg_reset(llama_sampler * smpl) { auto * ctx = (llama_sampler_llg *) smpl->ctx; - if (!ctx->grammar) { - return; + if (ctx->grammar) { + llg_matcher_reset(ctx->grammar); } - - auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str()); - llg_free_constraint(ctx->grammar); - ctx->grammar = grammar_new; - ctx->has_llg_res = false; } static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { @@ -102,7 +86,7 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { if (ctx->grammar) { result_ctx->grammar_kind = ctx->grammar_kind; result_ctx->grammar_data = ctx->grammar_data; - result_ctx->grammar = llg_clone_constraint(ctx->grammar); + result_ctx->grammar = llg_clone_matcher(ctx->grammar); result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer); } } @@ -114,7 +98,7 @@ static void llama_sampler_llg_free(llama_sampler * smpl) { const auto * ctx = (llama_sampler_llg *) smpl->ctx; if (ctx->grammar) { - llg_free_constraint(ctx->grammar); + llg_free_matcher(ctx->grammar); llg_free_tokenizer(ctx->tokenizer); } @@ -239,9 +223,11 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g /* .grammar_data = */ grammar_data, /* .tokenizer = */ tokenizer, /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data), - /* .llg_res = */ {}, - /* .has_llg_res = */ false, }; + if (ctx->grammar) { + GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 == + llg_matcher_get_mask_byte_size(ctx->grammar)); + } } else { *ctx = { /* .vocab = */ vocab, @@ -249,15 +235,12 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g /* .grammar_data = */ {}, /* .tokenizer = */ nullptr, /* .grammar = */ nullptr, - /* .llg_res = */ {}, - /* .has_llg_res = */ false, }; } return llama_sampler_init( /* .iface = */ &llama_sampler_llg_i, - /* .ctx = */ ctx - ); + /* .ctx = */ ctx); } #else diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp index 8b696006be714..3c19220e11964 100644 --- a/tests/test-grammar-llguidance.cpp +++ b/tests/test-grammar-llguidance.cpp @@ -1086,6 +1086,65 @@ static void test_json_schema() { }); } +static void one_hot(llama_token_data_array & tok_arr, llama_token selected) { + auto n_vocab = tok_arr.size; + + tok_arr.selected = -1; + tok_arr.sorted = false; + for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) { + tok_arr.data[token_id].id = token_id; + tok_arr.data[token_id].logit = 0.0f; + } + + tok_arr.data[selected].logit = 100.0f; +} + +static void test_sampler_chain(void) { + auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; + llama_sampler * sampler = llama_sampler_chain_init(sparams); + + const auto grammar_data = R"(%llguidance {} +start: /[A-Z ]*/)"; + + llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data)); + llama_sampler_chain_add(sampler, llama_sampler_init_dist(42)); + + auto input = "ALL YOUR BASE ARE BELONG TO US"; + auto tokens = common_tokenize(vocab, input, false, false); + + auto n_vocab = llama_vocab_n_tokens(vocab); + + std::vector cur; + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) { + cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f }); + } + auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false }; + + for (const auto token : tokens) { + one_hot(tok_arr, token); + + fprintf(stderr, "applying token: %d\n", token); + llama_sampler_apply(sampler, &tok_arr); + + auto idx = tok_arr.selected; + fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit); + assert(cur[tok_arr.selected].id == token); + llama_sampler_accept(sampler, token); + } + + auto tok_eos = llama_vocab_eot(vocab); + if (tok_eos == LLAMA_TOKEN_NULL) { + tok_eos = llama_vocab_eos(vocab); + } + + one_hot(tok_arr, tok_eos); + + llama_sampler_apply(sampler, &tok_arr); + assert(cur[tok_arr.selected].id == tok_eos); +} + int main(int argc, const char ** argv) { fprintf(stdout, "Running llguidance integration tests...\n"); @@ -1135,6 +1194,9 @@ int main(int argc, const char ** argv) { test_special_chars(); test_quantifiers(); test_json_schema(); + + test_sampler_chain(); + fprintf(stdout, "All tests passed.\n"); return 0; } From b3298fa47a2d56ae892127ea038942ab1cada190 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Mar 2025 21:38:38 +0200 Subject: [PATCH 7/8] metal : refactor mat-vec code (#12569) * metal : refactor mat-vec code ggml-ci * metal : rename all_sum -> sum_all ggml-ci * metal : fix comments [no ci] * metal : fix nr constant [no ci] * metal : mv q6_K support nr0 > 1 ggml-ci * metal : reduce register pressure ggml-ci * metal : fix typo [no ci] * metal : reduce register pressure ggml-ci --- ggml/src/ggml-metal/ggml-metal-impl.h | 64 +++ ggml/src/ggml-metal/ggml-metal.m | 298 +++++------- ggml/src/ggml-metal/ggml-metal.metal | 657 +++++++++++++------------- 3 files changed, 521 insertions(+), 498 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 1e954b4ceabd7..ca5a00b0322e1 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1,6 +1,70 @@ #ifndef GGML_METAL_IMPL #define GGML_METAL_IMPL +// kernel parameters for mat-vec threadgroups +// +// N_R0: number of src0 rows to process per simdgroup +// N_SG: number of simdgroups per threadgroup +// +// TODO: for optimal performance, become function of the device and work size + +#define N_R0_Q4_0 4 +#define N_SG_Q4_0 2 + +#define N_R0_Q4_1 4 +#define N_SG_Q4_1 2 + +#define N_R0_Q5_0 4 +#define N_SG_Q5_0 2 + +#define N_R0_Q5_1 4 +#define N_SG_Q5_1 2 + +#define N_R0_Q8_0 4 +#define N_SG_Q8_0 2 + +#define N_R0_Q2_K 4 +#define N_SG_Q2_K 2 + +#define N_R0_Q3_K 2 +#define N_SG_Q3_K 2 + +#define N_R0_Q4_K 4 +#define N_SG_Q4_K 2 + +#define N_R0_Q5_K 2 +#define N_SG_Q5_K 2 + +#define N_R0_Q6_K 1 +#define N_SG_Q6_K 2 + +#define N_R0_IQ1_S 4 +#define N_SG_IQ1_S 2 + +#define N_R0_IQ1_M 4 +#define N_SG_IQ1_M 2 + +#define N_R0_IQ2_XXS 4 +#define N_SG_IQ2_XXS 2 + +#define N_R0_IQ2_XS 4 +#define N_SG_IQ2_XS 2 + +#define N_R0_IQ2_S 4 +#define N_SG_IQ2_S 2 + +#define N_R0_IQ3_XXS 4 +#define N_SG_IQ3_XXS 2 + +#define N_R0_IQ3_S 4 +#define N_SG_IQ3_S 2 + +#define N_R0_IQ4_NL 2 +#define N_SG_IQ4_NL 2 + +#define N_R0_IQ4_XS 2 +#define N_SG_IQ4_XS 2 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index af65e7d9f53d4..195d9678275c9 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2561,171 +2561,180 @@ static void ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - id pipeline = nil; + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; + nr1 = 4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; - nrows = 4; } break; case GGML_TYPE_F16: { - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; - nrows = ne11; + nr1 = ne11; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; - nrows = 4; + nr1 = 4; } } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; - nrows = 4; + nr1 = 4; } } break; case GGML_TYPE_BF16: { - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; - nrows = ne11; + nr1 = ne11; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; - nrows = 4; + nr1 = 4; } } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; - nrows = 4; + nr1 = 4; } } break; case GGML_TYPE_Q4_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { - nth0 = 4; //1; - nth1 = 8; //32; + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; } break; case GGML_TYPE_IQ2_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; } break; case GGML_TYPE_IQ3_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; } break; case GGML_TYPE_IQ3_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; } break; case GGML_TYPE_IQ2_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; } break; case GGML_TYPE_IQ1_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; } break; case GGML_TYPE_IQ1_M: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; } break; case GGML_TYPE_IQ4_NL: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; } break; case GGML_TYPE_IQ4_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; } break; default: @@ -2762,41 +2771,10 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (ne11 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; case GGML_OP_MUL_MAT_ID: @@ -2902,146 +2880,155 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - id pipeline = nil; + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; } break; case GGML_TYPE_F16: { GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; } break; case GGML_TYPE_BF16: { GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; } break; case GGML_TYPE_Q4_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { - nth0 = 4; //1; - nth1 = 8; //32; + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; } break; case GGML_TYPE_IQ2_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; } break; case GGML_TYPE_IQ3_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; } break; case GGML_TYPE_IQ3_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; } break; case GGML_TYPE_IQ2_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; } break; case GGML_TYPE_IQ1_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; } break; case GGML_TYPE_IQ1_M: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; } break; case GGML_TYPE_IQ4_NL: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; } break; case GGML_TYPE_IQ4_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; } break; default: @@ -3052,7 +3039,7 @@ static void ggml_metal_encode_node( }; if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); + GGML_ASSERT(ne00 >= nsg*nr0); } ggml_metal_kargs_mul_mv_id args = { @@ -3085,43 +3072,12 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; - const int tgz = dst_rows; + const int64_t ne123 = dst_rows; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; case GGML_OP_GET_ROWS: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3cef81b797197..38f03efbae273 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1439,7 +1439,7 @@ kernel void kernel_rwkv_wkv7_f32( float4 sa_vec(0.0); - for (int j = 0; j < head_size; j += 4) { + for (uint j = 0; j < head_size; j += 4) { float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); sa_vec += a_vec * s_vec; @@ -1853,14 +1853,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// guard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. -template +template void mul_vec_q_n_f32_impl( args_t args, device const char * src0, @@ -1876,7 +1869,7 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -1888,15 +1881,15 @@ void mul_vec_q_n_f32_impl( device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q_type * ax[nr]; - for (int row = 0; row < nr; ++row) { + device const block_q_type * ax[nr0]; + for (int row = 0; row < nr0; ++row) { const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } float yl[16]; // src1 vector cache - float sumf[nr] = {0.f}; + float sumf[nr0] = {0.f}; const short ix = (tiisg/2); const short il = (tiisg%2)*8; @@ -1908,7 +1901,7 @@ void mul_vec_q_n_f32_impl( float sumy[2] = { 0.f, 0.f }; #pragma unroll - for (int i = 0; i < 8; i += 2) { + for (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -1919,7 +1912,7 @@ void mul_vec_q_n_f32_impl( } #pragma unroll - for (int row = 0; row < nr; row++) { + for (short row = 0; row < nr0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } @@ -1928,7 +1921,7 @@ void mul_vec_q_n_f32_impl( device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr; ++row) { + for (int row = 0; row < nr0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { @@ -1945,7 +1938,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -1956,7 +1949,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1967,7 +1960,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1978,12 +1971,12 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define NB_Q8_0 8 -template +template void kernel_mul_mv_q8_0_f32_impl( args_t args, device const char * src0, @@ -1993,16 +1986,13 @@ void kernel_mul_mv_q8_0_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - const int nb = args.ne00/QK8_0; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0*nsg + sgitg)*nr; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -2014,15 +2004,15 @@ void kernel_mul_mv_q8_0_f32_impl( device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr]; - for (int row = 0; row < nr; ++row) { + device const block_q8_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } float yl[NB_Q8_0]; - float sumf[nr] = { 0.f }; + float sumf[nr0] = { 0.f }; const short ix = tiisg/4; const short il = tiisg%4; @@ -2035,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl( yl[i] = yb[i]; } - for (int row = 0; row < nr; row++) { + for (short row = 0; row < nr0; row++) { device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; for (short iq = 0; iq < NB_Q8_0; ++iq) { @@ -2049,7 +2039,7 @@ void kernel_mul_mv_q8_0_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr; ++row) { + for (int row = 0; row < nr0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { @@ -2067,7 +2057,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 @@ -2404,9 +2394,9 @@ void kernel_mul_mv_impl( sumf += (T0) x[i] * (T1) y[i]; } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } else { @@ -2427,10 +2417,10 @@ void kernel_mul_mv_impl( sumf += dot((float4) x4[i], (float4) y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } @@ -2492,9 +2482,9 @@ kernel void kernel_mul_mv_1row( for (int i = tiisg; i < args.ne00; i += 32) { sumf += (float) x[i] * (float) y[i]; } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[r0] = all_sum; + dst_f32[r0] = sum_all; } } else { device const T4 * x4 = (device const T4 *) x; @@ -2504,11 +2494,11 @@ kernel void kernel_mul_mv_1row( sumf += dot((float4) x4[i], y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[r0] = sum_all; } } } @@ -2553,9 +2543,9 @@ kernel void kernel_mul_mv_l4( sumf += dot((float4) x4[i], y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } @@ -4321,7 +4311,7 @@ kernel void kernel_cpy_f32_iq4_nl( float amax = 0.0f; // absolute max float max = 0.0f; - for (int j = 0; j < QK4_0; j++) { + for (int j = 0; j < QK4_NL; j++) { const float v = src[j]; if (amax < fabs(v)) { amax = fabs(v); @@ -4429,7 +4419,7 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( args_t args, device const char * src0, @@ -4445,7 +4435,7 @@ void kernel_mul_mv_q2_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4457,20 +4447,19 @@ void kernel_mul_mv_q2_K_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 + const short is = (8*ir)/16;// 0 or 1 device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; for (int ib = ix; ib < nb; ib += 4) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { + for (short i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; @@ -4481,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl( device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; i += 2) { @@ -4512,10 +4501,10 @@ void kernel_mul_mv_q2_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -4530,10 +4519,10 @@ kernel void kernel_mul_mv_q2_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q3_K_f32_impl( args_t args, device const char * src0, @@ -4550,7 +4539,7 @@ void kernel_mul_mv_q3_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4566,13 +4555,12 @@ void kernel_mul_mv_q3_K_f32_impl( //const uint16_t kmask1 = 0x3030; //const uint16_t kmask2 = 0x0f0f; - const int tid = tiisg/4; - const int ix = tiisg%4; - const int ip = tid/4; // 0 or 1 - const int il = 2*((tid%4)/2); // 0 or 2 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; + const short tid = tiisg/4; + const short ix = tiisg%4; + const short ip = tid/4; // 0 or 1 + const short il = 2*((tid%4)/2); // 0 or 2 + const short ir = tid%2; + const short l0 = 8*ir; // One would think that the Metal compiler would figure out that ip and il can only have // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it @@ -4597,8 +4585,8 @@ void kernel_mul_mv_q3_K_f32_impl( const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; + const short q_offset = 32*ip + l0; + const short y_offset = 128*ip + 32*il + l0; device const float * y1 = yy + ix*QK_K + y_offset; @@ -4606,10 +4594,11 @@ void kernel_mul_mv_q3_K_f32_impl( thread uint16_t * scales16 = (thread uint16_t *)&scales32; thread const int8_t * scales = (thread const int8_t *)&scales32; - float sumf1[2] = {0.f}; - float sumf2[2] = {0.f}; + float sumf1[nr0] = {0.f}; + float sumf2[nr0] = {0.f}; + for (int i = ix; i < nb; i += 4) { - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+ 0] = y1[l+ 0]; yl[l+ 8] = y1[l+16]; yl[l+16] = y1[l+32]; @@ -4621,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl( device const uint16_t * a = (device const uint16_t *)(x[i].scales); device const half * dh = &x[i].d; - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { const float d_all = (float)dh[0]; scales16[0] = a[4]; @@ -4632,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl( scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2]; s1 += yl[l+0] * (qs & qm[il/2][0]); s2 += yl[l+1] * (qs & qm[il/2][1]); @@ -4647,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl( sumf2[row] += d2 * (scales[2] - 32); s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2+8]; s1 += yl[l+8] * (qs & qm[il/2][0]); s2 += yl[l+9] * (qs & qm[il/2][1]); @@ -4670,7 +4659,7 @@ void kernel_mul_mv_q3_K_f32_impl( y1 += 4 * QK_K; } - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < nr0; ++row) { const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); sumf1[row] = simd_sum(sumf); } @@ -4678,7 +4667,7 @@ void kernel_mul_mv_q3_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { dst_f32[first_row + row] = sumf1[row]; } } @@ -4694,10 +4683,10 @@ kernel void kernel_mul_mv_q3_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( args_t args, device const char * src0, @@ -4707,22 +4696,22 @@ void kernel_mul_mv_q4_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int first_row = r0 * N_DST; + + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4735,7 +4724,8 @@ void kernel_mul_mv_q4_K_f32_impl( float yl[16]; float yh[16]; - float sumf[N_DST]={0.f}, all_sum; + + float sumf[nr0]={0.f}; device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; @@ -4744,7 +4734,8 @@ void kernel_mul_mv_q4_K_f32_impl( for (int ib = ix; ib < nb; ib += 4) { float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { + + for (short i = 0; i < 8; ++i) { yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; @@ -4755,7 +4746,7 @@ void kernel_mul_mv_q4_K_f32_impl( device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { sc16[0] = sc[0] & kmask1; sc16[1] = sc[2] & kmask1; sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); @@ -4765,19 +4756,21 @@ void kernel_mul_mv_q4_K_f32_impl( float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + + for (short i = 0; i < 4; ++i) { + acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); + acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); + acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0); + acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000); + acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F); + acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00); + acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0); + acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000); } float dall = dh[0]; float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + @@ -4794,10 +4787,10 @@ void kernel_mul_mv_q4_K_f32_impl( device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -4812,10 +4805,10 @@ kernel void kernel_mul_mv_q4_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( args_t args, device const char * src0, @@ -4832,7 +4825,7 @@ void kernel_mul_mv_q5_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4843,7 +4836,7 @@ void kernel_mul_mv_q5_K_f32_impl( device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); - float sumf[2]={0.f}; + float sumf[nr0]={0.f}; float yl[16], yh[16]; @@ -4851,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl( const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tiisg/4; - const int ix = tiisg%4; - const int iq = tid/4; - const int ir = tid%4; - const int n = 8; + const short tid = tiisg/4; + const short ix = tiisg%4; + const short iq = tid/4; + const short ir = tid%4; - const int l0 = n*ir; - const int q_offset = 32*iq + l0; - const int y_offset = 64*iq + l0; + const short l0 = 8*ir; + const short q_offset = 32*iq + l0; + const short y_offset = 64*iq + l0; const uint8_t hm1 = 1u << (2*iq); const uint8_t hm2 = hm1 << 1; @@ -4879,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl( device const float * y2 = y1 + 128; float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; } - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { device const uint8_t * q2 = q1 + 64; sc16[0] = a[0] & kmask1; @@ -4896,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl( float4 acc1 = {0.f}; float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { + for (short l = 0; l < 8; ++l) { uint8_t h = qh[l]; acc1[0] += yl[l+0] * (q1[l] & 0x0F); acc1[1] += yl[l+8] * (q1[l] & 0xF0); @@ -4926,7 +4918,7 @@ void kernel_mul_mv_q5_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = tot; @@ -4944,10 +4936,10 @@ kernel void kernel_mul_mv_q5_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( args_t args, device const char * src0, @@ -4969,62 +4961,77 @@ void kernel_mul_mv_q6_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int row = 2*r0 + sgitg; - - if (row >= args.ne0) { - return; - } + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); - float sumf = 0; + float sumf[nr0] = { 0.f }; + + float yl[16]; - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; + const short tid = tiisg/2; + const short ix = tiisg%2; + const short ip = tid/8; // 0 or 1 + const short il = tid%8; + const short l0 = 4*il; + const short is = 8*ip + l0/16; - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; + const short y_offset = 128*ip + l0; + const short q_offset_l = 64*ip + l0; + const short q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; device const int8_t * sc = x[i].scales + is; + device const half * dh = &x[i].d; device const float * y = yy + i * QK_K + y_offset; - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + for (short l = 0; l < 4; ++l) { + yl[4*l + 0] = y[l + 0]; + yl[4*l + 1] = y[l + 32]; + yl[4*l + 2] = y[l + 64]; + yl[4*l + 3] = y[l + 96]; } - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + for (short row = 0; row < nr0; ++row) { + const float dall = dh[0]; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (short l = 0; l < 4; ++l) { + sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + q1 += args.nb01; + q2 += args.nb01; + qh += args.nb01; + sc += args.nb01; + dh += args.nb01/2; + } } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[row] = tot; + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } } } @@ -5038,12 +5045,12 @@ kernel void kernel_mul_mv_q6_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( args_t args, device const char * src0, @@ -5059,7 +5066,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5071,7 +5078,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5092,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5104,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const uint16_t * q2 = xr->qs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; device const uint8_t * aux8 = (device const uint8_t *)q2; const uint32_t aux32 = q2[2] | (q2[3] << 16); const float d = db * (0.5f + (aux32 >> 28)); float sum = 0; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } @@ -5130,10 +5135,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5148,10 +5153,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( args_t args, device const char * src0, @@ -5167,7 +5172,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5179,7 +5184,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5200,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5213,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const uint8_t * sc = xr->scales + ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint8_t ls1 = sc[0] & 0xf; const uint8_t ls2 = sc[0] >> 4; @@ -5222,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl( const float d2 = db * (0.5f + ls2); float sum1 = 0, sum2 = 0; - for (int l = 0; l < 2; ++l) { + for (short l = 0; l < 2; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } - for (int l = 2; l < 4; ++l) { + for (short l = 2; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } @@ -5248,10 +5251,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5267,10 +5270,10 @@ kernel void kernel_mul_mv_iq2_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( args_t args, device const char * src0, @@ -5286,7 +5289,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5298,7 +5301,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5319,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5331,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); } @@ -5358,10 +5361,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.5f; + dst_f32[first_row + row] = sum_all * 0.5f; } } } @@ -5377,10 +5380,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( args_t args, device const char * src0, @@ -5396,7 +5399,7 @@ void kernel_mul_mv_iq3_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5408,7 +5411,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5425,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5440,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl( device const uint8_t * signs = xr->signs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); } @@ -5470,10 +5471,10 @@ void kernel_mul_mv_iq3_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -5489,10 +5490,10 @@ kernel void kernel_mul_mv_iq3_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( args_t args, device const char * src0, @@ -5508,7 +5509,7 @@ void kernel_mul_mv_iq2_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5520,7 +5521,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5532,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl( // threadgroup_barrier(mem_flags::mem_threadgroup); //} - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5552,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl( device const uint8_t * signs = qs + QK_K/8; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d1 = db * (0.5f + (sc[0] & 0xf)); const float d2 = db * (0.5f + (sc[0] >> 4)); float2 sum = {0}; - for (int l = 0; l < 2; ++l) { + for (short l = 0; l < 2; ++l) { //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); } @@ -5583,10 +5582,10 @@ void kernel_mul_mv_iq2_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5602,10 +5601,10 @@ kernel void kernel_mul_mv_iq2_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( args_t args, device const char * src0, @@ -5621,7 +5620,7 @@ void kernel_mul_mv_iq1_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5633,18 +5632,17 @@ void kernel_mul_mv_iq1_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - float sumy = 0; - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; sumy += yl[i]; } @@ -5657,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl( device const uint16_t * qh = xr->qh + ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); float sum = 0; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) @@ -5683,15 +5680,28 @@ void kernel_mul_mv_iq1_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq1_m_f32_impl( args_t args, device const char * src0, @@ -5703,11 +5713,12 @@ void kernel_mul_mv_iq1_m_f32_impl( ushort sgitg) { const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5719,20 +5730,19 @@ void kernel_mul_mv_iq1_m_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; iq1m_scale_t scale; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - float4 sumy = {0.f}; - for (int i = 0; i < 8; ++i) { + for (short i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; @@ -5747,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device const uint8_t * qh = xr->qh + 2 * ib; device const uint16_t * sc = (device const uint16_t *)xr->scales; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); @@ -5756,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl( constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); float2 sum = {0.f}; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) @@ -5778,15 +5788,28 @@ void kernel_mul_mv_iq1_m_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -5799,10 +5822,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; + + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5813,14 +5838,14 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0 or 1 + const short ix = tiisg/2; // 0...15 + const short it = tiisg%2; // 0 or 1 shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[2]={0.f}, all_sum; + float sumf[nr0]={0.f}; device const float * yb = y + ix * QK4_NL + it * 8; @@ -5830,12 +5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl( float4 qf1, qf2; for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; + for (short row = 0; row < nr0; row++) { device const block_iq4_nl & xb = x[row*nb + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); @@ -5860,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl( acc1 += acc2; sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - } yb += 16 * QK4_NL; @@ -5868,15 +5893,29 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -5892,7 +5931,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5903,16 +5942,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); - const int ix = tiisg/16; // 0 or 1 - const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; + const short ix = tiisg/16; // 0 or 1 + const short it = tiisg%16; // 0...15 + const short ib = it/2; + const short il = it%2; shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[2]={0.f}, all_sum; + float sumf[nr0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -5923,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( for (int ibl = ix; ibl < nb; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); @@ -5949,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - } yb += 2 * QK_K; @@ -5957,54 +5998,14 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -[[host_name("kernel_mul_mv_iq1_s_f32")]] -kernel void kernel_mul_mv_iq1_s_f32( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq1_m_f32")]] -kernel void kernel_mul_mv_iq1_m_f32( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq4_nl_f32")]] -kernel void kernel_mul_mv_iq4_nl_f32( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); -} - [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( constant ggml_metal_kargs_mul_mv & args, @@ -6016,7 +6017,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -6660,25 +6661,27 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; kernel void kernel_pool_2d_max_f32( device const float * src0, From bd40678df768ab189b26b8b03e3de4062e3b71a3 Mon Sep 17 00:00:00 2001 From: Slobodan Josic <127323561+slojosic-amd@users.noreply.github.com> Date: Wed, 26 Mar 2025 23:46:30 +0100 Subject: [PATCH 8/8] HIP: Add support for RDNA4 targets (#12372) --- docs/build.md | 2 +- ggml/src/ggml-cuda/common.cuh | 18 ++++++++++-------- ggml/src/ggml-cuda/ggml-cuda.cu | 8 +++++--- ggml/src/ggml-cuda/mmq.cu | 2 +- ggml/src/ggml-cuda/mmq.cuh | 4 ++-- ggml/src/ggml-cuda/mmvq.cu | 4 ++-- ggml/src/ggml-cuda/vendors/hip.h | 4 ++++ 7 files changed, 25 insertions(+), 17 deletions(-) diff --git a/docs/build.md b/docs/build.md index aa1db9a047130..9c1314a29431b 100644 --- a/docs/build.md +++ b/docs/build.md @@ -191,7 +191,7 @@ The following compilation options are also available to tweak performance: | Option | Legal values | Default | Description | |-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | +| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 954ff5f16924b..f8c55a2b869fc 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -52,7 +52,7 @@ #define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS) // AMD -// GCN/CNDA, wave size is 64 +// GCN/CDNA, wave size is 64 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a @@ -60,16 +60,18 @@ #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 -// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32 +// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000 #define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) -#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA) #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1) @@ -209,9 +211,9 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA -#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) +#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4)) #define FP16_MMA_AVAILABLE -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4)) #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE @@ -244,14 +246,14 @@ static bool fp16_mma_available(const int cc) { return false; #else return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc); + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc); + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. @@ -409,7 +411,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__) c = __builtin_amdgcn_sdot4(a, b, c, false); -#elif defined(RDNA3) +#elif defined(RDNA3) || defined(RDNA4) c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); #elif defined(RDNA1) || defined(__gfx900__) int tmp1; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6dd5dcb85e15c..3bb472ffbfdf2 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1216,7 +1216,7 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - if (GGML_CUDA_CC_IS_CDNA(cc)) { + if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { const float alpha = 1.0f; const float beta = 0.0f; CUBLAS_CHECK( @@ -1759,7 +1759,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } - if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) { + int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { cu_compute_type = CUBLAS_COMPUTE_32F; alpha = &alpha_f32; beta = &beta_f32; @@ -1836,7 +1838,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } #endif - if (dst->op_params[0] == GGML_PREC_DEFAULT) { + if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) { const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream); } diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 2c19485d51a92..b36b43d5417ba 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } - return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index ee01154254e3d..f136c41955b19 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile( template #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) __launch_bounds__(WARP_SIZE*nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) #else #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA __launch_bounds__(WARP_SIZE*nwarps, 1) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index a7d518a574ddc..45ea30f62df08 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -54,7 +54,7 @@ enum mmvq_parameter_table_id { }; static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { -#if defined(RDNA2) || defined(RDNA3) +#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; @@ -64,7 +64,7 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { } static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { - if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) { + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { return MMVQ_PARAMETERS_RDNA2; } if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index a4c717a321cfb..3983ce5b423c0 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -151,6 +151,10 @@ #define CDNA #endif +#if defined(__GFX12__) +#define RDNA4 +#endif + #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ defined(__gfx1150__) || defined(__gfx1151__) #define RDNA3