diff --git a/CODEOWNERS b/CODEOWNERS index 908d13a35b9..6ef6c0489f1 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,10 +2,8 @@ # multiplie collaborators per item can be specified /.devops/*.Dockerfile @ngxson -/.github/actions/ @slaren @CISC +/.github/actions/ @CISC /.github/workflows/ @CISC -/.github/workflows/release.yml @slaren -/.github/workflows/winget.yml @slaren /ci/ @ggerganov /cmake/ @ggerganov /common/CMakeLists.txt @ggerganov @@ -40,21 +38,14 @@ /examples/passkey/ @ggerganov /examples/retrieval/ @ggerganov /examples/save-load-state/ @ggerganov -/examples/simple-chat/ @slaren -/examples/simple/ @slaren /examples/speculative-simple/ @ggerganov /examples/speculative/ @ggerganov /ggml/cmake/ @ggerganov -/ggml/include/ @ggerganov @slaren -/ggml/src/ggml-alloc.c @slaren -/ggml/src/ggml-backend* @slaren -/ggml/src/ggml-blas/ @slaren -/ggml/src/ggml-common.h @ggerganov @slaren -/ggml/src/ggml-cpu/ @ggerganov @slaren +/ggml/include/ @ggerganov +/ggml/src/ggml-common.h @ggerganov +/ggml/src/ggml-cpu/ @ggerganov /ggml/src/ggml-cpu/spacemit/ @alex-spacemit -/ggml/src/ggml-cuda/common.cuh @slaren /ggml/src/ggml-cuda/fattn* @JohannesGaessler -/ggml/src/ggml-cuda/ggml-cuda.cu @slaren /ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an /ggml/src/ggml-cuda/mmq.* @JohannesGaessler /ggml/src/ggml-cuda/mmvf.* @JohannesGaessler @@ -62,19 +53,19 @@ /ggml/src/ggml-cuda/fattn-wmma* @IMbackK /ggml/src/ggml-hip/ @IMbackK /ggml/src/ggml-cuda/vendors/hip.h @IMbackK -/ggml/src/ggml-impl.h @ggerganov @slaren +/ggml/src/ggml-impl.h @ggerganov /ggml/src/ggml-metal/ @ggerganov /ggml/src/ggml-opencl/ @lhez @max-krasnyansky /ggml/src/ggml-hexagon/ @max-krasnyansky @lhez /ggml/src/ggml-opt.cpp @JohannesGaessler /ggml/src/ggml-quants.* @ggerganov /ggml/src/ggml-rpc/ @rgerganov -/ggml/src/ggml-threading.* @ggerganov @slaren +/ggml/src/ggml-threading.* @ggerganov /ggml/src/ggml-vulkan/ @0cc4m /ggml/src/ggml-webgpu/ @reeselevine /ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM -/ggml/src/ggml.c @ggerganov @slaren -/ggml/src/ggml.cpp @ggerganov @slaren +/ggml/src/ggml.c @ggerganov +/ggml/src/ggml.cpp @ggerganov /ggml/src/gguf.cpp @JohannesGaessler @Green-Sky /gguf-py/ @CISC /media/ @ggerganov @@ -86,15 +77,11 @@ /src/llama-arch.* @CISC /src/llama-chat.* @ngxson /src/llama-graph.* @CISC -/src/llama-model-loader.* @slaren /src/llama-model.* @CISC /src/llama-vocab.* @CISC /src/models/ @CISC /tests/ @ggerganov -/tests/test-backend-ops.cpp @slaren -/tests/test-thread-safety.cpp @slaren /tools/batched-bench/ @ggerganov -/tools/llama-bench/ @slaren /tools/main/ @ggerganov /tools/mtmd/ @ngxson /tools/perplexity/ @ggerganov @@ -106,8 +93,6 @@ /tools/tokenize/ @ggerganov /tools/tts/ @ggerganov /vendor/ @ggerganov -/.clang-format @slaren -/.clang-tidy @slaren /AUTHORS @ggerganov /CMakeLists.txt @ggerganov /CONTRIBUTING.md @ggerganov diff --git a/common/arg.cpp b/common/arg.cpp index 430ab45dfe2..dd787290d25 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1232,6 +1232,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { const auto sampler_names = string_split(value, ';'); params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS; } ).set_sparam()); add_opt(common_arg( @@ -1261,6 +1262,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.sampling.temp = std::stof(value); params.sampling.temp = std::max(params.sampling.temp, 0.0f); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP; } ).set_sparam()); add_opt(common_arg( @@ -1268,6 +1270,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), [](common_params & params, int value) { params.sampling.top_k = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K; } ).set_sparam()); add_opt(common_arg( @@ -1275,6 +1278,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), [](common_params & params, const std::string & value) { params.sampling.top_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P; } ).set_sparam()); add_opt(common_arg( @@ -1282,6 +1286,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), [](common_params & params, const std::string & value) { params.sampling.min_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P; } ).set_sparam()); add_opt(common_arg( @@ -1296,6 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), [](common_params & params, const std::string & value) { params.sampling.xtc_probability = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY; } ).set_sparam()); add_opt(common_arg( @@ -1303,6 +1309,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), [](common_params & params, const std::string & value) { params.sampling.xtc_threshold = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD; } ).set_sparam()); add_opt(common_arg( @@ -1321,6 +1328,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } params.sampling.penalty_last_n = value; params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N; } ).set_sparam()); add_opt(common_arg( @@ -1328,6 +1336,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), [](common_params & params, const std::string & value) { params.sampling.penalty_repeat = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT; } ).set_sparam()); add_opt(common_arg( @@ -1425,6 +1434,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), [](common_params & params, int value) { params.sampling.mirostat = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT; } ).set_sparam()); add_opt(common_arg( @@ -1432,6 +1442,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), [](common_params & params, const std::string & value) { params.sampling.mirostat_eta = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA; } ).set_sparam()); add_opt(common_arg( @@ -1439,6 +1450,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), [](common_params & params, const std::string & value) { params.sampling.mirostat_tau = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU; } ).set_sparam()); add_opt(common_arg( diff --git a/common/common.cpp b/common/common.cpp index f3cc55247e7..0d7fd9a9371 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "sampling.h" #include #include @@ -949,6 +950,58 @@ std::vector fs_list_files(const std::string & path) { // Model utils // +static inline void common_init_sampler_from_model( + const llama_model * model, + common_params_sampling & sparams) { + + const uint64_t config = sparams.user_sampling_config; + + auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[64] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + int32_t v = strtol(buf, &end, 10); + if (end && end != buf) dst = v; + } + }; + + auto get_float = [&](const char * key, float & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[128] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + float v = strtof(buf, &end); + if (end && end != buf) dst = v; + } + }; + + // Sampling sequence + if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) { + char buf[512] = {0}; + if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) { + const std::vector sampler_names = string_split(std::string(buf), ';'); + if (!sampler_names.empty()) { + sparams.samplers = common_sampler_types_from_names(sampler_names, true); + } + } + } + + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); +} + struct common_init_result common_init_from_params(common_params & params) { common_init_result iparams; auto mparams = common_model_params_to_llama(params); @@ -960,6 +1013,8 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + common_init_sampler_from_model(model, params.sampling); + const llama_vocab * vocab = llama_model_get_vocab(model); auto cparams = common_context_params_to_llama(params); diff --git a/common/common.h b/common/common.h index de5b404dd88..2f23d0baa83 100644 --- a/common/common.h +++ b/common/common.h @@ -140,6 +140,22 @@ struct common_grammar_trigger { llama_token token = LLAMA_TOKEN_NULL; }; +enum common_params_sampling_config : uint64_t { + COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2, + COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5, + COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11, +}; + + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -172,6 +188,8 @@ struct common_params_sampling { bool no_perf = false; // disable performance metrics bool timing_per_token = false; + uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers + std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d24a4682f3d..daaf0bf4974 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10061,6 +10061,25 @@ class LazyTorchTensor(gguf.LazyBase): torch.uint8: np.uint8, } + # only used when byteswapping data. Only correct size is needed + _dtype_byteswap_map: dict[torch.dtype, type] = { + torch.float64: np.float64, + torch.float32: np.float32, + torch.bfloat16: np.float16, + torch.float16: np.float16, + torch.int64: np.int64, + torch.uint64: np.uint64, + torch.int32: np.int32, + torch.uint32: np.uint32, + torch.int16: np.int16, + torch.uint16: np.uint16, + torch.int8: np.int8, + torch.uint8: np.uint8, + torch.bool: np.uint8, + torch.float8_e4m3fn: np.uint8, + torch.float8_e5m2: np.uint8, + } + # used for safetensors slices # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 @@ -10104,8 +10123,14 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor: @classmethod def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor: def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor: + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[tensor.dtype] - return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape) + numpy_dtype = cls._dtype_byteswap_map[dtype] + return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape) dtype = cls._dtype_str_map[t.dtype] shape = t.shape lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r)) @@ -10113,10 +10138,16 @@ def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor: @classmethod def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[remote_tensor.dtype] + numpy_dtype = cls._dtype_byteswap_map[dtype] shape = remote_tensor.shape meta = cls.meta_with_dtype_and_shape(dtype, shape) - lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape)) return cast(torch.Tensor, lazy) @classmethod diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 605fcfcb9c2..4dbca868bc7 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -530,6 +530,7 @@ extern "C" { GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, + GGML_OP_TOP_K, GGML_OP_LEAKY_RELU, GGML_OP_TRI, GGML_OP_FILL, @@ -2258,18 +2259,25 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); - GGML_API struct ggml_tensor * ggml_arange( + // similar to ggml_top_k but implemented as `argsort` + `view` + GGML_API struct ggml_tensor * ggml_argsort_top_k( struct ggml_context * ctx, - float start, - float stop, - float step); + struct ggml_tensor * a, + int k); // top k elements per row + // note: the resulting top k indices are in no particular order GGML_API struct ggml_tensor * ggml_top_k( struct ggml_context * ctx, struct ggml_tensor * a, int k); + GGML_API struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step); + #define GGML_KQ_MASK_PAD 64 // q: [n_embd_k, n_batch, n_head, ne3 ] diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 606c6d1783d..08cf2b118b2 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -3236,3 +3237,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst GGML_ABORT("Function is not implemented."); } } + +static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // weight + ggml_tensor * src1 = dst->src[1]; // input + GGML_TENSOR_BINARY_OP_LOCALS + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); + + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; + + const int64_t i12 = i2; + const int64_t i13 = i3; + acl_tensor_ptr accumulator = + ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst->ne, dst->nb, 2); + + // The outer product needs to be accumulated in this dimension. + for (int64_t i1 = 0; i1 < ne11; i1++) { + acl_tensor_ptr acl_input = ggml_cann_create_tensor( + (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src1->ne, src1->nb, 1); + + acl_tensor_ptr acl_weight = ggml_cann_create_tensor( + (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src0->nb, 1); + + ggml_cann_pool_alloc output_allocator(ctx.pool()); + void * output_buffer = output_allocator.alloc(ggml_nbytes(dst)); + acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst->ne, dst->nb, 2); + + GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get()); + float alpha_value = 1.0f; + aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha); + } + } + } +} + +void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + const enum ggml_type type = src0->type; + + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + ggml_cann_out_prod_fp(ctx, dst); + break; + default: + GGML_ABORT("Unsupport type for GGML_OP_OUT_PROD"); + break; + } +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index a6c2eb1226b..1ebbc769c71 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1125,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::functionsrc[0]` and `dst->src[1]`. + * + * @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation + */ +void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3c67c48ffa1..8995a5c121f 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1886,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_FLASH_ATTN_EXT: ggml_cann_flash_attn_ext(ctx, dst); break; + case GGML_OP_OUT_PROD: + ggml_cann_out_prod(ctx, dst); + break; default: return false; } @@ -2563,6 +2566,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_OUT_PROD: + { + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + } case GGML_OP_CONV_TRANSPOSE_1D: // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255. return (op->src[0]->ne[0] - 1) <= 255; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index c7348cc26c1..3247af8bb03 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argsort(params, tensor); } break; + case GGML_OP_TOP_K: + { + ggml_compute_forward_top_k(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: @@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan( cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; + case GGML_OP_TOP_K: + { + cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks; + } break; case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne10 = node->src[1]->ne[0]; // DK diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 41e89d83c2e..d405696539e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding( // ggml_compute_forward_argsort template -struct argsort_cmp { +struct cmp_argsort { const float * data; bool operator()(int32_t a, int32_t b) const { if constexpr (order == GGML_SORT_ORDER_ASC) { @@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32( switch (order) { case GGML_SORT_ORDER_ASC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; case GGML_SORT_ORDER_DESC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; default: @@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort( } } +// ggml_compute_forward_top_k + +struct cmp_top_k { + const float * data; + bool operator()(int32_t a, int32_t b) const { + return data[a] > data[b]; + } +}; + +static void ggml_compute_forward_top_k_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + const int top_k = ne0; + + int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int64_t i = ith; i < nr; i += nth) { + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne00; j++) { + tmp[j] = j; + } + + std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data}); + + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + + std::copy(tmp, tmp + top_k, dst_data); + + // emphasize that the order is not important + if (top_k > 1) { + std::swap(dst_data[0], dst_data[1]); + } + } +} + +void ggml_compute_forward_top_k( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_top_k_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 98a0eec16d9..0fdfee79766 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 0eefc0b137b..329500a03e0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l return res; } +// note: reuse the argsort kernel for top_k +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + // note: the top_k kernel is always descending order + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 39ee6e3427e..3976e622b9b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index acf9dfd5fc1..09b1b503118 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_LEAKY_RELU: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_ARANGE: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 0fae97029f1..342dc4f8c37 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -832,14 +832,19 @@ typedef struct { } ggml_metal_kargs_leaky_relu; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; } ggml_metal_kargs_argsort; typedef struct { @@ -851,6 +856,11 @@ typedef struct { uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; int32_t len; } ggml_metal_kargs_argsort_merge; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e46970dad47..9871e976f23 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_argsort(ctx, idx); } break; + case GGML_OP_TOP_K: + { + n_fuse = ggml_metal_op_top_k(ctx, idx); + } break; case GGML_OP_LEAKY_RELU: { n_fuse = ggml_metal_op_leaky_relu(ctx, idx); @@ -3678,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_argsort args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nth, }; ggml_metal_encoder_set_pipeline(enc, pipeline); @@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); ggml_metal_kargs_argsort_merge args_merge = { - .ne00 = ne00, - .ne01 = ne01, - .ne02 = ne02, - .ne03 = ne03, - .nb00 = nb00, - .nb01 = nb01, - .nb02 = nb02, - .nb03 = nb03, - .len = len, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ ne00, + /*.len =*/ len, }; // merges per row @@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + + // bitonic sort requires the number of elements to be power of 2 + int nth = 1; + while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + // blocks per row + const int npr = (ne00 + nth - 1)/nth; + + const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_tmp = bid_dst; + bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]); + + if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) { + std::swap(bid_dst, bid_tmp); + } + + const int top_k = ne0; + + ggml_metal_kargs_argsort args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices + }; + + if (npr > 1) { + args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k); + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + + int len = args.top_k; + + while (len < args.ne0) { + ggml_metal_op_concurrency_reset(ctx); + + // merges per row + const int nm = (args.ne0 + 2*len - 1) / (2*len); + + const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge))); + + ggml_metal_kargs_argsort_merge args_merge = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements + /*.len =*/ len, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline_merge); + ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); + + std::swap(bid_dst, bid_tmp); + + len <<= 1; + } + + return 1; +} + int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 332e550ee70..b5546146e13 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index f6033ddc97b..70bf6f3d981 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ { res *= 2; } break; + case GGML_OP_TOP_K: + { + res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]); + } break; default: break; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 59e5761704e..73b45c762d9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32( ushort3 ntg[[threads_per_threadgroup]]) { // bitonic sort const int col = tpitg[0]; + const int ib = tgpig[0] / args.ne01; - const int i00 = (tgpig[0]/args.ne01)*ntg.x; - const int i01 = tgpig[0]%args.ne01; - const int i02 = tgpig[1]; - const int i03 = tgpig[2]; + const int i00 = ib*ntg.x; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03); @@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32( } } + const int64_t i0 = ib*args.top_k; + // copy the result to dst without the padding - if (i00 + col < args.ne00) { - dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03; + if (i0 + col < args.ne0 && col < args.top_k) { + dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03; dst[col] = shmem_i32[col]; } @@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32( const int start = im * (2 * args.len); - const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start))); - const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len))); + const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); + const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); const int total = len0 + len1; device const int32_t * tmp0 = tmp + start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.ne0 + + i02*args.ne0*args.ne01 + + i03*args.ne0*args.ne01*args.ne02; device const int32_t * tmp1 = tmp0 + args.len; dst += start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.top_k + + i02*args.top_k*args.ne01 + + i03*args.top_k*args.ne01*args.ne02; device const float * src0_row = (device const float *)(src0 + args.nb01*i01 @@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32( const int chunk = (total + ntg.x - 1) / ntg.x; const int k0 = tpitg.x * chunk; - const int k1 = min(k0 + chunk, total); + const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); + + if (k0 >= args.top_k) { + return; + } if (k0 >= total) { return; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d78c727e53b..6cf15b43bb3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2501,9 +2501,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { if (hsv >= 192) { return 2; + } else if ((hsv | hsk) & 8) { + return 4; } else { return 8; } @@ -2535,9 +2537,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if ((hsv | hsk) & 8) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsv), 64}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; } else { - return {get_fa_scalar_num_large_rows(hsv), 32}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; } } } @@ -7740,7 +7742,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsv); + const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -7871,7 +7873,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSV); + max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a5846a23937..b99345a2e93 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", + "TOP_K", "LEAKY_RELU", "TRI", "FILL", @@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", + "top_k(x)", "leaky_relu(x)", "tri(x)", "fill(x, c)", @@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll( return result; } -// ggml_arange - -struct ggml_tensor * ggml_arange( - struct ggml_context * ctx, - float start, - float stop, - float step) { - GGML_ASSERT(stop > start); - - const int64_t steps = (int64_t) ceilf((stop - start) / step); - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); - - ggml_set_op_params_f32(result, 0, start); - ggml_set_op_params_f32(result, 1, stop); - ggml_set_op_params_f32(result, 2, step); - - result->op = GGML_OP_ARANGE; - - return result; -} - // ggml_timestep_embedding struct ggml_tensor * ggml_timestep_embedding( @@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort( struct ggml_tensor * a, enum ggml_sort_order order) { GGML_ASSERT(a->ne[0] <= INT32_MAX); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); ggml_set_op_params_i32(result, 0, (int32_t) order); @@ -5149,9 +5130,9 @@ struct ggml_tensor * ggml_argsort( return result; } -// ggml_top_k +// ggml_argsort_top_k -struct ggml_tensor * ggml_top_k( +struct ggml_tensor * ggml_argsort_top_k( struct ggml_context * ctx, struct ggml_tensor * a, int k) { @@ -5167,6 +5148,44 @@ struct ggml_tensor * ggml_top_k( return result; } +// ggml_top_k + +struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k) { + GGML_ASSERT(a->ne[0] >= k); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]); + + result->op = GGML_OP_TOP_K; + result->src[0] = a; + + return result; +} + +// ggml_arange + +struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step) { + GGML_ASSERT(stop > start); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); + + ggml_set_op_params_f32(result, 0, start); + ggml_set_op_params_f32(result, 1, stop); + ggml_set_op_params_f32(result, 2, step); + + result->op = GGML_OP_ARANGE; + + return result; +} + // ggml_flash_attn_ext struct ggml_tensor * ggml_flash_attn_ext( diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8bc558fe4b5..6f5a742e04a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -25,6 +25,20 @@ class General: ALIGNMENT = "general.alignment" FILE_TYPE = "general.file_type" + # Recommended Sampler Parameters + SAMPLING_SEQUENCE = "general.sampling.sequence" + SAMPLING_TOP_K = "general.sampling.top_k" + SAMPLING_TOP_P = "general.sampling.top_p" + SAMPLING_MIN_P = "general.sampling.min_p" + SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability" + SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold" + SAMPLING_TEMP = "general.sampling.temp" + SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n" + SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat" + SAMPLING_MIROSTAT = "general.sampling.mirostat" + SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau" + SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta" + # Authorship Metadata NAME = "general.name" AUTHOR = "general.author" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a051daeeb13..57ca2035fe2 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -4,6 +4,7 @@ import os import shutil import struct +import sys import tempfile from dataclasses import dataclass from enum import Enum, auto @@ -372,8 +373,10 @@ def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None, ) -> None: - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ + (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) if self.use_temp_file and self.temp_file is None: fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) fp.seek(0) @@ -399,8 +402,10 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') assert self.fout is not None - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ + (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) file_id = -1 for i, tensors in enumerate(self.tensors): @@ -496,6 +501,42 @@ def add_custom_alignment(self, alignment: int) -> None: def add_file_type(self, ftype: int) -> None: self.add_uint32(Keys.General.FILE_TYPE, ftype) + def add_sampling_sequence(self, sequence: str) -> None: + self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence) + + def add_sampling_top_k(self, top_k: int) -> None: + self.add_int32(Keys.General.SAMPLING_TOP_K, top_k) + + def add_sampling_top_p(self, top_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_TOP_P, top_p) + + def add_sampling_min_p(self, min_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIN_P, min_p) + + def add_sampling_xtc_probability(self, xtc_probability: float) -> None: + self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability) + + def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None: + self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold) + + def add_sampling_temp(self, temp: float) -> None: + self.add_float32(Keys.General.SAMPLING_TEMP, temp) + + def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None: + self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n) + + def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None: + self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat) + + def add_sampling_mirostat(self, mirostat: int) -> None: + self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat) + + def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau) + + def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta) + def add_name(self, name: str) -> None: self.add_string(Keys.General.NAME, name) diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index 67efedbdbc5..e0d478ce95d 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -17,6 +17,20 @@ @dataclass class Metadata: + # Recommended Sampler Parameters to be written to GGUF KV Store + sampling_sequence: Optional[str] = None + sampling_top_k: Optional[int] = None + sampling_top_p: Optional[float] = None + sampling_min_p: Optional[float] = None + sampling_xtc_probability: Optional[float] = None + sampling_xtc_threshold: Optional[float] = None + sampling_temp: Optional[float] = None + sampling_penalty_last_n: Optional[int] = None + sampling_penalty_repeat: Optional[float] = None + sampling_mirostat: Optional[int] = None + sampling_mirostat_tau: Optional[float] = None + sampling_mirostat_eta: Optional[float] = None + # Authorship Metadata to be written to GGUF KV Store name: Optional[str] = None author: Optional[str] = None @@ -54,15 +68,43 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat model_card = Metadata.load_model_card(model_path) hf_params = Metadata.load_hf_parameters(model_path) + gen_config = Metadata.load_generation_config(model_path) # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter # heuristics metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) + if gen_config: + metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence) + metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k) + metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p) + metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p) + metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold) + metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp) + metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta) + # Metadata Override File Provided # This is based on LLM_KV_NAMES mapping in llama.cpp metadata_override = Metadata.load_metadata_override(metadata_override_path) + metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence) + metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k) + metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p) + metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p) + metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold) + metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp) + metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta) + metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) @@ -172,6 +214,23 @@ def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: with open(config_path, "r", encoding="utf-8") as f: return json.load(f) + @staticmethod + def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]: + if model_path is None or not model_path.is_dir(): + return {} + + generation_config_path = model_path / "generation_config.json" + + if not generation_config_path.is_file(): + return {} + + try: + with open(generation_config_path, "r", encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + # not all models have valid generation_config.json + return {} + @staticmethod def id_to_title(string): # Convert capitalization into title form unless acronym or version number @@ -546,6 +605,32 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): assert self.name is not None + + if self.sampling_sequence is not None: + gguf_writer.add_sampling_sequence(self.sampling_sequence) + if self.sampling_top_k is not None: + gguf_writer.add_sampling_top_k(self.sampling_top_k) + if self.sampling_top_p is not None: + gguf_writer.add_sampling_top_p(self.sampling_top_p) + if self.sampling_min_p is not None: + gguf_writer.add_sampling_min_p(self.sampling_min_p) + if self.sampling_xtc_probability is not None: + gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability) + if self.sampling_xtc_threshold is not None: + gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold) + if self.sampling_temp is not None: + gguf_writer.add_sampling_temp(self.sampling_temp) + if self.sampling_penalty_last_n is not None: + gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n) + if self.sampling_penalty_repeat is not None: + gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat) + if self.sampling_mirostat is not None: + gguf_writer.add_sampling_mirostat(self.sampling_mirostat) + if self.sampling_mirostat_tau is not None: + gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau) + if self.sampling_mirostat_eta is not None: + gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta) + gguf_writer.add_name(self.name) if self.author is not None: diff --git a/include/llama.h b/include/llama.h index 8547226ff21..b52eaacfa7e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -246,6 +246,21 @@ extern "C" { LLAMA_KV_OVERRIDE_TYPE_STR, }; + enum llama_model_meta_key { + LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE, + LLAMA_MODEL_META_KEY_SAMPLING_TOP_K, + LLAMA_MODEL_META_KEY_SAMPLING_TOP_P, + LLAMA_MODEL_META_KEY_SAMPLING_MIN_P, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD, + LLAMA_MODEL_META_KEY_SAMPLING_TEMP, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA, + }; + struct llama_model_kv_override { enum llama_model_kv_override_type tag; @@ -518,6 +533,9 @@ extern "C" { // Get the number of metadata key/value pairs LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model); + // Get sampling metadata key name. Returns nullptr if the key is invalid + LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key); + // Get metadata key name by index LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fc6cddc92f5..7ef87acf1b3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -114,19 +114,31 @@ static const std::map LLM_ARCH_NAMES = { }; static const std::map LLM_KV_NAMES = { - { LLM_KV_GENERAL_TYPE, "general.type" }, - { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, - { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, - { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, - { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, - { LLM_KV_GENERAL_NAME, "general.name" }, - { LLM_KV_GENERAL_AUTHOR, "general.author" }, - { LLM_KV_GENERAL_VERSION, "general.version" }, - { LLM_KV_GENERAL_URL, "general.url" }, - { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, - { LLM_KV_GENERAL_LICENSE, "general.license" }, - { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, - { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + { LLM_KV_GENERAL_TYPE, "general.type" }, + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, + { LLM_KV_GENERAL_SAMPLING_SEQUENCE, "general.sampling.sequence" }, + { LLM_KV_GENERAL_SAMPLING_TOP_K, "general.sampling.top_k" }, + { LLM_KV_GENERAL_SAMPLING_TOP_P, "general.sampling.top_p" }, + { LLM_KV_GENERAL_SAMPLING_MIN_P, "general.sampling.min_p" }, + { LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability" }, + { LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, "general.sampling.xtc_threshold" }, + { LLM_KV_GENERAL_SAMPLING_TEMP, "general.sampling.temp" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, "general.sampling.penalty_last_n" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, "general.sampling.penalty_repeat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT, "general.sampling.mirostat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, "general.sampling.mirostat_tau" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, "general.sampling.mirostat_eta" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_VERSION, "general.version" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 02a1c2dc258..9ad3157bf67 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -123,6 +123,18 @@ enum llm_kv { LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, LLM_KV_GENERAL_FILE_TYPE, + LLM_KV_GENERAL_SAMPLING_SEQUENCE, + LLM_KV_GENERAL_SAMPLING_TOP_K, + LLM_KV_GENERAL_SAMPLING_TOP_P, + LLM_KV_GENERAL_SAMPLING_MIN_P, + LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, + LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, + LLM_KV_GENERAL_SAMPLING_TEMP, + LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, + LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_VERSION, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 650e40ec6ff..1d012e09aba 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // organize experts into n_expert_groups ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens] - ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] + ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens] // get top n_group_used expert groups group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens] group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens] - ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] + ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] cb(expert_groups, "ffn_moe_group_topk", il); // mask out the other groups @@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( } // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 35179a98e0c..a042ea9632c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7687,6 +7687,24 @@ int32_t llama_model_meta_count(const llama_model * model) { return (int)model->gguf_kv.size(); } +const char * llama_model_meta_key_str(llama_model_meta_key key) { + switch (key) { + case LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE: return "general.sampling.sequence"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_K: return "general.sampling.top_k"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_P: return "general.sampling.top_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIN_P: return "general.sampling.min_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY: return "general.sampling.xtc_probability"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD: return "general.sampling.xtc_threshold"; + case LLAMA_MODEL_META_KEY_SAMPLING_TEMP: return "general.sampling.temp"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N: return "general.sampling.penalty_last_n"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT: return "general.sampling.penalty_repeat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT: return "general.sampling.mirostat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU: return "general.sampling.mirostat_tau"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA: return "general.sampling.mirostat_eta"; + default: return nullptr; + } +} + int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) { if (i < 0 || i >= (int)model->gguf_kv.size()) { if (buf_size > 0) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ce8c068d7aa..2fc7d4c3c72 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -39,6 +39,7 @@ #include #include #include +#include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); @@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) { return mse_a_b / mse_a_0; } +// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap) +static double jdst(const int32_t * a, const int32_t * b, size_t n) { + std::unordered_map set_a; + std::unordered_map set_b; + + for (size_t i = 0; i < n; ++i) { + set_a[a[i]]++; + set_b[b[i]]++; + } + + size_t diff = 0; + + for (const auto & p : set_a) { + const int64_t na = p.second; + const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0; + + diff += std::abs(na - nb); + } + + for (const auto & p : set_b) { + if (set_a.find(p.first) == set_a.end()) { + diff += p.second; + } + } + + return (double) diff / (2*n); +} + // maximum absolute asymmetry between a and b // asymmetry: (a - b) / (a + b) // This is more stable than relative error if one of the values fluctuates towards zero. @@ -1051,6 +1080,14 @@ struct test_case { return 1e-4; } + virtual double max_err() { + return max_nmse_err(); + } + + virtual double err(const float * a, const float * b, size_t n) { + return nmse(a, b, n); + } + virtual float grad_eps() { return 1e-1f; } @@ -1257,16 +1294,16 @@ struct test_case { // compare struct callback_userdata { bool ok; - double max_err; + test_case * tc; ggml_backend_t backend1; ggml_backend_t backend2; }; callback_userdata ud { true, - max_nmse_err(), + this, backend1, - backend2 + backend2, }; auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { @@ -1314,9 +1351,9 @@ struct test_case { } } - double err = nmse(f1.data(), f2.data(), f1.size()); - if (err > ud->max_err) { - printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); + double err = ud->tc->err(f1.data(), f2.data(), f1.size()); + if (err > ud->tc->max_err()) { + printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err()); //for (int i = 0; i < (int) f1.size(); i++) { // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); //} @@ -4943,7 +4980,71 @@ struct test_argsort : public test_case { } }; -struct test_topk_moe: public test_case { +// GGML_OP_TOP_K +struct test_top_k : public test_case { + const ggml_type type; + const std::array ne; + const int k; + + std::string vars() override { + return VARS_TO_STR3(type, ne, k); + } + + test_top_k(ggml_type type = GGML_TYPE_F32, + std::array ne = {16, 10, 10, 10}, + int k = 4) + : type(type), ne(ne), k(k) {} + + double max_err() override { + return 0.0; + } + + double err(const float * a, const float * b, size_t n) override { + std::vector ia(n); + std::vector ib(n); + + double diff = 0.0f; + + for (size_t i = 0; i < n; i++) { + ia[i] = (int32_t) a[i]; + ib[i] = (int32_t) b[i]; + + // penalize the result if the data is not integer valued + diff += std::fabs(a[i] - ia[i]); + diff += std::fabs(b[i] - ib[i]); + } + + return diff + jdst(ia.data(), ib.data(), n); + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_top_k(ctx, a, k); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // initialize with unique values to avoid ties + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } + } +}; + +struct test_topk_moe : public test_case { const std::array ne; const int n_expert_used; const bool with_norm; @@ -4976,7 +5077,7 @@ struct test_topk_moe: public test_case { ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits); - ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] @@ -7534,6 +7635,23 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) } + for (int k : {1, 2, 3, 7, 15}) { + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k)); + } + + // exhaustive top_k tests + //for (int i = 1; i < 9999; ++i) { + // test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1)); + //} + for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); @@ -7859,6 +7977,9 @@ static std::vector> make_test_cases_perf() { } } + // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012 + test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) { @@ -7911,6 +8032,7 @@ static std::vector> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40)); return test_cases; } diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 484956100bc..ae25b6ddf72 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte index 7e83d30f132..176a98b212f 100644 --- a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte @@ -8,6 +8,7 @@ import rehypeKatex from 'rehype-katex'; import rehypeStringify from 'rehype-stringify'; import { copyCodeToClipboard } from '$lib/utils/copy'; + import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer'; import { preprocessLaTeX } from '$lib/utils/latex-protection'; import { browser } from '$app/environment'; import '$styles/katex-custom.scss'; @@ -60,6 +61,7 @@ .use(remarkRehype) // Convert Markdown AST to rehype .use(rehypeKatex) // Render math using KaTeX .use(rehypeHighlight) // Add syntax highlighting + .use(rehypeRestoreTableHtml) // Restore limited HTML (e.g.,
,
    ) inside Markdown tables .use(rehypeStringify); // Convert to HTML string }); diff --git a/tools/server/webui/src/lib/constants/table-html-restorer.ts b/tools/server/webui/src/lib/constants/table-html-restorer.ts new file mode 100644 index 00000000000..e5d5b120114 --- /dev/null +++ b/tools/server/webui/src/lib/constants/table-html-restorer.ts @@ -0,0 +1,20 @@ +/** + * Matches
    ,
    ,
    tags (case-insensitive). + * Used to detect line breaks in table cell text content. + */ +export const BR_PATTERN = //gi; + +/** + * Matches a complete
      ...
    block. + * Captures the inner content (group 1) for further
  • extraction. + * Case-insensitive, allows multiline content. + */ +export const LIST_PATTERN = /^
      ([\s\S]*)<\/ul>$/i; + +/** + * Matches individual
    • ...
    • elements within a list. + * Captures the inner content (group 1) of each list item. + * Non-greedy to handle multiple consecutive items. + * Case-insensitive, allows multiline content. + */ +export const LI_PATTERN = /
    • ([\s\S]*?)<\/li>/gi; diff --git a/tools/server/webui/src/lib/markdown/table-html-restorer.ts b/tools/server/webui/src/lib/markdown/table-html-restorer.ts new file mode 100644 index 00000000000..918aa468116 --- /dev/null +++ b/tools/server/webui/src/lib/markdown/table-html-restorer.ts @@ -0,0 +1,181 @@ +/** + * Rehype plugin to restore limited HTML elements inside Markdown table cells. + * + * ## Problem + * The remark/rehype pipeline neutralizes inline HTML as literal text + * (remarkLiteralHtml) so that XML/HTML snippets in LLM responses display + * as-is instead of being rendered. This causes
      and
        markup in + * table cells to show as plain text. + * + * ## Solution + * This plugin traverses the HAST post-conversion, parses whitelisted HTML + * patterns from text nodes, and replaces them with actual HAST element nodes + * that will be rendered as real HTML. + * + * ## Supported HTML + * - `
        ` / `
        ` / `
        ` - Line breaks (inline) + * - `
        • ...
        ` - Unordered lists (block) + * + * ## Key Implementation Details + * + * ### 1. Sibling Combination (Critical) + * The Markdown pipeline may fragment content across multiple text nodes and `
        ` + * elements. For example, `
        • a
        ` might arrive as: + * - Text: `"
          "` + * - Element: `
          ` + * - Text: `"
        • a
        "` + * + * We must combine consecutive text nodes and `
        ` elements into a single string + * before attempting to parse list markup. Without this, list detection fails. + * + * ### 2. visitParents for Deep Traversal + * Table cell content may be wrapped in intermediate elements (e.g., `

        ` tags). + * Using `visitParents` instead of direct child iteration ensures we find text + * nodes at any depth within the cell. + * + * ### 3. Reference Comparison for No-Op Detection + * When checking if `
        ` expansion changed anything, we compare: + * `expanded.length !== 1 || expanded[0] !== textNode` + * + * This catches both cases: + * - Multiple nodes created (text was split) + * - Single NEW node created (original had only `
        `, now it's an element) + * + * A simple `length > 1` check would miss the single `
        ` case. + * + * ### 4. Strict List Validation + * `parseList()` rejects malformed markup by checking for garbage text between + * `

      • ` elements. This prevents creating broken DOM from partial matches like + * `
          garbage
        • a
        `. + * + * ### 5. Newline Substitution for `
        ` in Combined String + * When combining siblings, existing `
        ` elements become `\n` in the combined + * string. This allows list content to span visual lines while still being parsed + * as a single unit. + * + * @example + * // Input Markdown: + * // | Feature | Notes | + * // |---------|-------| + * // | Multi-line | First
        Second | + * // | List |
        • A
        • B
        | + * // + * // Without this plugin:
        and
          render as literal text + * // With this plugin:
          becomes line break,
            becomes actual list + */ + +import type { Plugin } from 'unified'; +import type { Element, ElementContent, Root, Text } from 'hast'; +import { visit } from 'unist-util-visit'; +import { visitParents } from 'unist-util-visit-parents'; +import { BR_PATTERN, LIST_PATTERN, LI_PATTERN } from '$lib/constants/table-html-restorer'; + +/** + * Expands text containing `
            ` tags into an array of text nodes and br elements. + */ +function expandBrTags(value: string): ElementContent[] { + const matches = [...value.matchAll(BR_PATTERN)]; + if (!matches.length) return [{ type: 'text', value } as Text]; + + const result: ElementContent[] = []; + let cursor = 0; + + for (const m of matches) { + if (m.index! > cursor) { + result.push({ type: 'text', value: value.slice(cursor, m.index) } as Text); + } + result.push({ type: 'element', tagName: 'br', properties: {}, children: [] } as Element); + cursor = m.index! + m[0].length; + } + + if (cursor < value.length) { + result.push({ type: 'text', value: value.slice(cursor) } as Text); + } + + return result; +} + +/** + * Parses a `
            • ...
            ` string into a HAST element. + * Returns null if the markup is malformed or contains unexpected content. + */ +function parseList(value: string): Element | null { + const match = value.trim().match(LIST_PATTERN); + if (!match) return null; + + const body = match[1]; + const items: ElementContent[] = []; + let cursor = 0; + + for (const liMatch of body.matchAll(LI_PATTERN)) { + // Reject if there's non-whitespace between list items + if (body.slice(cursor, liMatch.index!).trim()) return null; + + items.push({ + type: 'element', + tagName: 'li', + properties: {}, + children: expandBrTags(liMatch[1] ?? '') + } as Element); + + cursor = liMatch.index! + liMatch[0].length; + } + + // Reject if no items found or trailing garbage exists + if (!items.length || body.slice(cursor).trim()) return null; + + return { type: 'element', tagName: 'ul', properties: {}, children: items } as Element; +} + +/** + * Processes a single table cell, restoring HTML elements from text content. + */ +function processCell(cell: Element) { + visitParents(cell, 'text', (textNode: Text, ancestors) => { + const parent = ancestors[ancestors.length - 1]; + if (!parent || parent.type !== 'element') return; + + const parentEl = parent as Element; + const siblings = parentEl.children as ElementContent[]; + const startIndex = siblings.indexOf(textNode as ElementContent); + if (startIndex === -1) return; + + // Combine consecutive text nodes and
            elements into one string + let combined = ''; + let endIndex = startIndex; + + for (let i = startIndex; i < siblings.length; i++) { + const sib = siblings[i]; + if (sib.type === 'text') { + combined += (sib as Text).value; + endIndex = i; + } else if (sib.type === 'element' && (sib as Element).tagName === 'br') { + combined += '\n'; + endIndex = i; + } else { + break; + } + } + + // Try parsing as list first (replaces entire combined range) + const list = parseList(combined); + if (list) { + siblings.splice(startIndex, endIndex - startIndex + 1, list); + return; + } + + // Otherwise, just expand
            tags in this text node + const expanded = expandBrTags(textNode.value); + if (expanded.length !== 1 || expanded[0] !== textNode) { + siblings.splice(startIndex, 1, ...expanded); + } + }); +} + +export const rehypeRestoreTableHtml: Plugin<[], Root> = () => (tree) => { + visit(tree, 'element', (node: Element) => { + if (node.tagName === 'td' || node.tagName === 'th') { + processCell(node); + } + }); +};