diff --git a/common/arg.cpp b/common/arg.cpp index a25743c899862..a465eb36234e7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3248,7 +3248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", - "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix", + "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", [](common_params & params, const std::string & value) { params.embd_out = value; } diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index dd9b51a9e50fd..478aa1be7b5b8 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -601,7 +601,10 @@ class SchemaConverter { } std::string _resolve_ref(const std::string & ref) { - std::string ref_name = ref.substr(ref.find_last_of('/') + 1); + auto it = ref.find('#'); + std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref; + static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)"); + std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-"); if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { _refs_being_resolved.insert(ref); json resolved = _refs[ref]; @@ -774,11 +777,24 @@ class SchemaConverter { std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; - if (target.is_null() || !target.contains(sel)) { + if (target.is_object() && target.contains(sel)) { + target = target[sel]; + } else if (target.is_array()) { + size_t sel_index; + try { + sel_index = std::stoul(sel); + } catch (const std::invalid_argument & e) { + sel_index = target.size(); + } + if (sel_index >= target.size()) { + _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); + return; + } + target = target[sel_index]; + } else { _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); return; } - target = target[sel]; } _refs[ref] = target; } diff --git a/examples/embedding/README.md b/examples/embedding/README.md index 3dd279d9fc41a..1684f36480d82 100644 --- a/examples/embedding/README.md +++ b/examples/embedding/README.md @@ -38,6 +38,7 @@ The above command will output space-separated float values. | | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$ | 'json' | openai style | | 'json+' | add cosine similarity matrix | +| 'raw' | plain text output | ### --embd-separator $"string"$ | $"string"$ | | diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 388908bc4d70a..9e3ab5905bb37 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -70,6 +70,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } } +// plain, pipe-friendly output: one embedding per line +static void print_raw_embeddings(const float * emb, + int n_embd_count, + int n_embd, + const llama_model * model, + enum llama_pooling_type pooling_type, + int embd_normalize) { + const uint32_t n_cls_out = llama_model_n_cls_out(model); + const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK); + const int cols = is_rank ? 
std::min(n_embd, (int) n_cls_out) : n_embd; + + for (int j = 0; j < n_embd_count; ++j) { + for (int i = 0; i < cols; ++i) { + if (embd_normalize == 0) { + LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); + } else { + LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); + } + } + LOG("\n"); + } +} + int main(int argc, char ** argv) { common_params params; @@ -372,6 +395,8 @@ int main(int argc, char ** argv) { } if (notArray) LOG("\n}\n"); + } else if (params.embd_out == "raw") { + print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize); } LOG("\n"); diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 2d57549046b88..26989157fe6b6 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -371,8 +371,17 @@ def visit(n: dict): raise ValueError(f'Unsupported ref {ref}') for sel in ref.split('#')[-1].split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' - target = target[sel] + assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}' + if isinstance(target, list): + try: + sel_index = int(sel) + except ValueError: + raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}') + assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel_index] + else: + assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] self._refs[ref] = target else: @@ -547,7 +556,8 @@ def join_seq(): def _resolve_ref(self, ref): - ref_name = ref.split('/')[-1] + ref_fragment = ref.split('#')[-1] + ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment) if ref_name not in self._rules and ref not in self._refs_being_resolved: self._refs_being_resolved.add(ref) resolved = self._refs[ref] diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index f030ea0136a95..5df6dc96a3b2e 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, ACL_MEM_MALLOC_HUGE_FIRST)); acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + theta_scale_ne, theta_scale_nb, 1); float start = 0; float step = 1; @@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); void * yarn_ramp_buffer = yarn_ramp_allocator.get(); acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, - theta_scale_nb, GGML_MAX_DIMS); + theta_scale_nb, 1); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 8bd5449f1f75f..51345742ee59e 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -67,19 +67,30 @@ GGML_ABORT("CANN error"); } +// Thread-local variable to record the current device of this thread. +thread_local int g_current_cann_device = -1; + /** - * @brief Sets the device to be used by CANN. + * @brief Set the CANN device to be used. * - * @param device The device ID to set. + * @param device The target device ID to set. 
*/ void ggml_cann_set_device(const int32_t device) { - int current_device = -1; - aclrtGetDevice(&current_device); + // int current_device = -1; + // Note: In some CANN versions, if no device has been set yet, + // aclrtGetDevice(&current_device) may return 0 by default. + // aclrtGetDevice(&current_device); - if (device == current_device) { + // If the current device is already the target one, no need to switch. + if (device == g_current_cann_device) { return; } + + // Switch to the new device. ACL_CHECK(aclrtSetDevice(device)); + + // Update the global device record. + g_current_cann_device = device; } /** diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 94ab1ec0f5a90..be505748af5a4 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -50,6 +50,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml.h" @@ -2416,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SET_ROWS: ggml_cuda_op_set_rows(ctx, dst); break; + case GGML_OP_SET: + ggml_cuda_op_set(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -3842,6 +3846,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g op->src[0]->type == GGML_TYPE_F32 && (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); } break; + case GGML_OP_SET: + { + const ggml_type t = op->type; + return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) && + t == op->src[0]->type && + t == op->src[1]->type; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index c2c31cdaf231b..4e31783436d80 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -343,6 +343,10 @@ static __global__ void mul_mat_vec_f( } dst[tid*stride_col_dst + row] = value; + + if constexpr (!has_fusion) { + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate); + } } template diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 7a783e4fcf9b4..be04a85cc5515 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -310,6 +310,10 @@ static __global__ void mul_mat_vec_q( dst[j*stride_col_dst + threadIdx.x] = result; } } + + if constexpr (!has_fusion) { + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate); + } } static std::pair<dim3, dim3> calc_launch_params( diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu new file mode 100644 index 0000000000000..04bfe07ba0336 --- /dev/null +++ b/ggml/src/ggml-cuda/set.cu @@ -0,0 +1,39 @@ +#include "set.cuh" +#include "cpy.cuh" + +void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)); + GGML_ASSERT(src1->type == src0->type); + GGML_ASSERT(dst ->type == src0->type); + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t nb1 = ((int32_t *) dst->op_params)[0]; + const size_t nb2 = ((int32_t *) dst->op_params)[1]; + const size_t nb3 = ((int32_t *) dst->op_params)[2]; + const size_t offset = ((int32_t *) dst->op_params)[3]; + const bool 
inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + ggml_cuda_cpy(ctx, src0, dst); + } + + ggml_tensor dst_view = *dst; + dst_view.data = (void *)((char *)dst->data + offset); + dst_view.ne[0] = src1->ne[0]; + dst_view.ne[1] = src1->ne[1]; + dst_view.ne[2] = src1->ne[2]; + dst_view.ne[3] = src1->ne[3]; + + dst_view.nb[0] = ggml_element_size(dst); + dst_view.nb[1] = nb1; + dst_view.nb[2] = nb2; + dst_view.nb[3] = nb3; + + ggml_cuda_cpy(ctx, src1, &dst_view); +} diff --git a/ggml/src/ggml-cuda/set.cuh b/ggml/src/ggml-cuda/set.cuh new file mode 100644 index 0000000000000..dd09529f3e42b --- /dev/null +++ b/ggml/src/ggml-cuda/set.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include "common.cuh" + +#define CUDA_SET_BLOCK_SIZE 256 + +void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index ecfc1c856cb59..5e3dc0a3d0cc1 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -211,7 +211,7 @@ static inline void hex_format_op_names(char * str, const struct ggml_tensor * t) // ** backend sessions struct ggml_hexagon_session { - ggml_hexagon_session(int dev_id) noexcept(false); + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); ~ggml_hexagon_session() noexcept(true); void allocate(int dev_id) noexcept(false); @@ -1631,10 +1631,13 @@ void ggml_hexagon_session::release() noexcept(true) { } } -ggml_hexagon_session::ggml_hexagon_session(int dev_id) noexcept(false) { +ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) { buffer_type.context = nullptr; repack_buffer_type.context = nullptr; + buffer_type.device = dev; + repack_buffer_type.device = dev; + try { allocate(dev_id); @@ -3628,7 +3631,7 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { devices[i].iface = ggml_backend_hexagon_device_i; devices[i].reg = reg; try { - devices[i].context = new ggml_hexagon_session(i); + devices[i].context = new ggml_hexagon_session(i, &devices[i]); } catch (std::exception const &exc) { GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i); devices[i].context = nullptr; diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index ca53f3e90068c..75657f3fca2e7 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -35,6 +35,7 @@ #include "roll.hpp" #include "rope.hpp" #include "set_rows.hpp" +#include "ssm_conv.hpp" #include "softmax.hpp" #include "tsembd.hpp" #include "wkv.hpp" diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 62d0ecd94ee0a..328d1a71b7580 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -50,6 +50,7 @@ #include "ggml-sycl/getrows.hpp" #include "ggml-sycl/repeat_back.hpp" #include "ggml-sycl/quantize.hpp" +#include "ggml-sycl/ssm_conv.hpp" #include "ggml.h" static bool g_sycl_loaded = false; @@ -3921,6 +3922,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_SSM_CONV: + ggml_sycl_ssm_conv(ctx, dst); + break; case GGML_OP_ROLL: ggml_sycl_roll(ctx, dst); break; @@ -4602,6 +4605,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: return true; + case GGML_OP_SSM_CONV: + return op->type == GGML_TYPE_F32 
&& + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32; case GGML_OP_ROLL: return op->type == GGML_TYPE_F32; case GGML_OP_ARANGE: diff --git a/ggml/src/ggml-sycl/ssm_conv.cpp b/ggml/src/ggml-sycl/ssm_conv.cpp new file mode 100644 index 0000000000000..0dc0f71c9a157 --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_conv.cpp @@ -0,0 +1,127 @@ +#include "ssm_conv.hpp" +#include "common.hpp" + +#include <sycl/sycl.hpp> + +using namespace sycl; + +static void kernel_ssm_conv( + queue &q, + const float *src_data, + const float *weights, + float *dst_data, + int d_conv, + int d_inner, + int n_t, + int n_s, + int ncs __attribute__((unused)), + int src_stride_inner, + int src_stride_seq, + int dst_stride_token, + int dst_stride_seq +) { + const size_t total_work = static_cast<size_t>(d_inner) * static_cast<size_t>(n_t) * static_cast<size_t>(n_s); + const size_t work_group_size = 256; + const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size; + + const range<1> global_range(num_work_groups * work_group_size); + const range<1> local_range(work_group_size); + + q.submit([&](handler &h) { + h.parallel_for( + nd_range<1>(global_range, local_range), + [=](nd_item<1> item) { + const size_t idx = item.get_global_id(0); + if (idx >= total_work) { + return; + } + + const int channel = static_cast<int>(idx % d_inner); + const int token = static_cast<int>((idx / d_inner) % n_t); + const int seq = static_cast<int>(idx / (static_cast<size_t>(d_inner) * static_cast<size_t>(n_t))); + + const float *s = src_data + + static_cast<size_t>(seq) * static_cast<size_t>(src_stride_seq) + + static_cast<size_t>(channel) * static_cast<size_t>(src_stride_inner) + + static_cast<size_t>(token); + + const float *c = weights + static_cast<size_t>(channel) * static_cast<size_t>(d_conv); + + float sumf = 0.0f; + for (int i0 = 0; i0 < d_conv; ++i0) { + sumf += s[i0] * c[i0]; + } + + const size_t dst_idx = + static_cast<size_t>(seq) * static_cast<size_t>(dst_stride_seq) + + static_cast<size_t>(token) * static_cast<size_t>(dst_stride_token) + + static_cast<size_t>(channel); + + dst_data[dst_idx] = sumf; + } + ); + }); +} + +void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int d_conv = src1->ne[0]; + const int ncs = src0->ne[0]; + const int d_inner = src0->ne[1]; + const int n_t = dst->ne[1]; + const int n_s = dst->ne[2]; + + GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t); + GGML_ASSERT(src0->ne[1] == d_inner); + GGML_ASSERT(src1->ne[1] == d_inner); + + GGML_ASSERT(dst->ne[0] == d_inner); + GGML_ASSERT(dst->ne[1] == n_t); + GGML_ASSERT(dst->ne[2] == n_s); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + + GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast<size_t>(sizeof(float))); + + const int src_stride_inner = ncs; + const int src_stride_seq = ncs * d_inner; + const int dst_stride_token = d_inner; + const int dst_stride_seq = d_inner * n_t; + + try { + queue *q = ctx.stream(); + + const float *src_data = static_cast<const float *>(src0->data); + const float *weights = static_cast<const float *>(src1->data); + float *dst_data = static_cast<float *>(dst->data); + + GGML_ASSERT(src_data && weights && dst_data); + + kernel_ssm_conv( + *q, + src_data, + weights, + dst_data, + d_conv, + d_inner, + n_t, + n_s, + ncs, + src_stride_inner, + src_stride_seq, + dst_stride_token, + dst_stride_seq + ); + + } catch (const std::exception &e) { + std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what()); + throw; + } +} diff --git 
a/ggml/src/ggml-sycl/ssm_conv.hpp b/ggml/src/ggml-sycl/ssm_conv.hpp new file mode 100644 index 0000000000000..1a8ad05f0c7f0 --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_conv.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 736693e174527..6d5dd6051e782 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache( const uint32_t n_layer_kv = hparams.n_layer_kv(); + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + // create a context for each buffer type - std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache( return nullptr; } - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max); @@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache( } // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); if (!buf) { throw std::runtime_error("failed to allocate buffer for kv cache"); } @@ -179,7 +183,7 @@ llama_kv_cache::llama_kv_cache( LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); ggml_backend_buffer_clear(buf, 0); - bufs.emplace_back(buf); + ctxs_bufs.emplace_back(std::move(ctx), buf); } { @@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) { } if (data) { - for (auto & buf : bufs) { + for (auto & [_, buf] : ctxs_bufs) { ggml_backend_buffer_clear(buf.get(), 0); } } @@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { std::map llama_kv_cache::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } @@ -957,10 +961,14 @@ bool llama_kv_cache::get_has_shift() const { uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; + // pad the n_kv value so that the graph remains constant across batches and can be reused + // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220) + const uint32_t n_pad_cur = std::max(n_pad, 256u); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { const auto & cells = v_cells[sinfo.strm[s]]; - result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), 
n_pad))), result); + result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result); } return result; @@ -1298,7 +1306,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch size_t llama_kv_cache::total_size() const { size_t size = 0; - for (const auto & buf : bufs) { + for (const auto & [_, buf] : ctxs_bufs) { size += ggml_backend_buffer_get_size(buf.get()); } @@ -2010,8 +2018,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } - -uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 256u : 32u; -} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 85f0663d8c1d4..bf7821c07ca8f 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -19,8 +19,6 @@ struct llama_context; class llama_kv_cache : public llama_memory_i { public: - static uint32_t get_padding(const llama_cparams & cparams); - struct stream_copy_info { bool empty() const { assert(ssrc.size() == sdst.size()); @@ -217,8 +215,8 @@ class llama_kv_cache : public llama_memory_i { // this is the SWA type of the cache - not to be confused with the model SWA type const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - std::vector ctxs; - std::vector bufs; + // ggml contexts for the KV cache along with the allocated backend buffers: + std::vector> ctxs_bufs; // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) // note: this is not part of the KV state and it's only used to speed-up the find_slot() method diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index d67f5a5f47b87..276e1697d466c 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent( cells.clear(); cells.resize(mem_size); + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + // create a context for each buffer type - std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent( return nullptr; } - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; r_l.resize(n_layer); @@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent( } // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); if (!buf) { throw std::runtime_error("failed to allocate buffer for rs cache"); } 
ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - bufs.emplace_back(buf); + ctxs_bufs.emplace_back(std::move(ctx), buf); } { @@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) { used = 0; if (data) { - for (auto & buf : bufs) { + for (auto & [_, buf] : ctxs_bufs) { ggml_backend_buffer_clear(buf.get(), 0); } } @@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } @@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const { size_t llama_memory_recurrent::total_size() const { size_t size = 0; - for (const auto & buf : bufs) { + for (const auto & [_, buf] : ctxs_bufs) { size += ggml_backend_buffer_get_size(buf.get()); } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 077c6e3ce938d..47f01d7391248 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -109,8 +109,8 @@ class llama_memory_recurrent : public llama_memory_i { const uint32_t n_seq_max = 1; - std::vector ctxs; - std::vector bufs; + // ggml contexts for the KV cache along with the allocated backend buffers: + std::vector> ctxs_bufs; size_t total_size() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 05e467180089e..ea6f59ed482bb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2231,7 +2231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // define a comparator for the buft -> ctx map to ensure that the order is well-defined: struct ggml_backend_buft_comparator { bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { - return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs); + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; } }; std::map ctx_map; @@ -19641,7 +19641,7 @@ struct llm_build_apertus : public llm_graph_context { } }; -llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { +llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { llama_memory_i * res; switch (arch) { @@ -19692,17 +19692,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } - const auto padding = llama_kv_cache::get_padding(cparams); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - res = new llama_memory_hybrid( /* model */ *this, /* attn_type_k */ params.type_k, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ padding, + /* attn_n_pad */ 1, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, /* recurrent_type_k */ GGML_TYPE_F32, @@ -19714,23 +19710,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* filter_attn */ std::move(filter_attn), /* filter_recr */ std::move(filter_recr)); } else { - const auto padding = llama_kv_cache::get_padding(cparams); - uint32_t n_ctx_per_stream = cparams.n_ctx; if 
(!cparams.kv_unified) { n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max; - n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); - - cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max; - } else { - n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); - - cparams.n_ctx = n_ctx_per_stream; } - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - llama_memory_i::layer_reuse_cb reuse = nullptr; if (arch == LLM_ARCH_GEMMA3N) { @@ -19757,7 +19742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, n_ctx_per_stream, cparams.n_seq_max, cparams.n_ubatch, - padding, + 1, nullptr, reuse); } else { @@ -19772,7 +19757,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.kv_unified, n_ctx_per_stream, cparams.n_seq_max, - padding, + 1, hparams.n_swa, hparams.swa_type, nullptr, diff --git a/src/llama-model.h b/src/llama-model.h index 248f854101cd7..1ab1cf7f8e94d 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -500,9 +500,8 @@ struct llama_model { ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const; - // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; + llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const; // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 67df240c6fef3..8a55bc54ae466 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1124,9 +1124,9 @@ static void test_all(const std::string & lang, std::function` option, each test can be run at a specified context depth, pr For a description of the other options, see the [main example](../main/README.md). +> [!NOTE] +> The measurements with `llama-bench` do not include the times for tokenization and for sampling. 
+ ## Examples ### Text generation with different models @@ -131,7 +134,7 @@ $ ./llama-bench -n 0 -n 16 -p 64 -t 1,2,4,8,16,32 | llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 16 | pp 64 | 33.52 ± 0.03 | | llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 16 | tg 16 | 15.32 ± 0.05 | | llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | pp 64 | 59.00 ± 1.11 | -| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | tg 16 | 16.41 ± 0.79 || +| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | tg 16 | 16.41 ± 0.79 | ### Different numbers of layers offloaded to the GPU diff --git a/tools/server/public_legacy/json-schema-to-grammar.mjs b/tools/server/public_legacy/json-schema-to-grammar.mjs index 6f0952974496a..1d9dc5105eee9 100644 --- a/tools/server/public_legacy/json-schema-to-grammar.mjs +++ b/tools/server/public_legacy/json-schema-to-grammar.mjs @@ -345,10 +345,14 @@ export class SchemaConverter { const selectors = ref.split('#')[1].split('/').slice(1); for (const sel of selectors) { - if (!target || !(sel in target)) { + const selIndex = parseInt(sel, 10); + if (target && sel in target) { + target = target[sel]; + } else if (target && selIndex in target) { + target = target[selIndex]; + } else { throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`); } - target = target[sel]; } this._refs[ref] = target; @@ -594,7 +598,8 @@ export class SchemaConverter { } _resolveRef(ref) { - let refName = ref.split('/').pop(); + let refFragment = ref.split('#').pop(); + let refName = 'ref' + refFragment.replace(/[^a-zA-Z0-9-]+/g, '-'); if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) { this._refsBeingResolved.add(ref); const resolved = this._refs[ref]; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 4124bffa40f85..cb794ab647eba 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2866,10 +2866,12 @@ struct server_context { // if context shifting is disabled, make sure that we don't run out of context if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) { + slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx); + SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); } // check the limits @@ -2929,16 +2931,6 @@ struct server_context { } } - // if context shift is disabled, we stop when it reaches the context limit - if (slot.n_past >= slot.n_ctx) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); - } - if (llama_vocab_is_eog(vocab, result.tok)) { slot.stop = STOP_TYPE_EOS; slot.has_next_token = false; @@ -2946,19 +2938,6 @@ struct server_context { SLT_DBG(slot, "%s", "stopped by EOS\n"); } - const auto n_ctx_train = llama_model_n_ctx_train(model); - - if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; // stop prediction - - SLT_WRN(slot, - "n_predict (%d) is set for infinite generation. 
" - "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", - slot.task->params.n_predict, n_ctx_train); - } - SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); return slot.has_next_token; // continue diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index 4adbbde64f594..7b047b7b3b74d 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -45,7 +45,7 @@ def test_ctx_shift_enabled(): @pytest.mark.parametrize("n_predict,n_token_output,truncated", [ (64, 64, False), - (-1, 120, True), + (-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total ]) def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool): global server