diff --git a/common/arg.cpp b/common/arg.cpp index 430ab45dfe26e..ab3386b1dfa67 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1501,6 +1501,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); } ).set_sparam()); + add_opt(common_arg( + {"--backend-sampling"}, + "enable backend sampling (default: disabled)", + [](common_params & params) { + params.sampling.backend_sampling = true; + } + ).set_sparam()); + add_opt(common_arg( + {"--backend-dist"}, + "perform final (distribution) sampling on backend (default: disabled)", + [](common_params & params) { + params.sampling.backend_dist = true; + params.sampling.backend_sampling = true; + } + ).set_sparam()); add_opt(common_arg( {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", diff --git a/common/build-info.cpp b/common/build-info.cpp new file mode 100644 index 0000000000000..6e8240fbb16a1 --- /dev/null +++ b/common/build-info.cpp @@ -0,0 +1,4 @@ +int LLAMA_BUILD_NUMBER = 5590; +char const *LLAMA_COMMIT = "0d398442"; +char const *LLAMA_COMPILER = "cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0"; +char const *LLAMA_BUILD_TARGET = "x86_64-linux-gnu"; diff --git a/common/common.cpp b/common/common.cpp index f3cc55247e718..6a6f5fec3d540 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "sampling.h" #include #include @@ -949,20 +950,34 @@ std::vector fs_list_files(const std::string & path) { // Model utils // -struct common_init_result common_init_from_params(common_params & params) { - common_init_result iparams; +llama_model * common_load_model_from_params(common_params & params) { auto mparams = common_model_params_to_llama(params); llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); if (model == NULL) { LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", __func__, params.model.path.c_str()); + return nullptr; + } + + return model; +} + +struct common_init_result common_init_context_from_model( + llama_model * model, + common_params & params) { + common_init_result iparams; + + if (model == NULL) { + LOG_ERR("%s: model is NULL\n", __func__); return iparams; } const llama_vocab * vocab = llama_model_get_vocab(model); auto cparams = common_context_params_to_llama(params); + cparams.samplers = params.backend_samplers; + cparams.n_samplers = params.n_backend_samplers; llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { @@ -1129,6 +1144,14 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } +struct common_init_result common_init_from_params(common_params & params) { + llama_model * model = common_load_model_from_params(params); + if (model == NULL) { + return common_init_result(); + } + return common_init_context_from_model(model, params); +} + std::string get_model_endpoint() { const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. 
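Note on the `common/common.cpp` split above: `common_init_from_params` is now decomposed into `common_load_model_from_params` plus `common_init_context_from_model` so that backend sampler chains, which need the model's vocab, can be built between the two steps and handed to the context via `params.backend_samplers`. A minimal usage sketch under those assumptions (the helper name and the one-chain-per-sequence loop are illustrative, not part of this patch):

```cpp
// Sketch: build one backend sampler chain per sequence before the context exists.
// Assumes a populated common_params `params`; `init_with_backend_samplers` is a
// hypothetical helper, not an API added by this change.
#include "common.h"
#include "sampling.h"
#include "llama.h"

#include <vector>

static common_init_result init_with_backend_samplers(common_params & params, int32_t n_seq) {
    // 1. load only the model - the vocab is needed to build the sampler chains
    llama_model * model = common_load_model_from_params(params);
    if (model == nullptr) {
        return common_init_result();
    }

    // 2. build one backend sampler chain per sequence from the sampling params
    //    (common_sampler_backend_init returns nullptr when backend sampling is disabled)
    std::vector<llama_sampler_seq_config> configs;
    for (int32_t s = 0; s < n_seq; ++s) {
        llama_sampler * chain = common_sampler_backend_init(model, params.sampling);
        if (chain != nullptr) {
            configs.push_back({ /*.seq_id =*/ s, /*.sampler =*/ chain });
        }
    }

    // 3. hand the chains to the context; the config array is only read during
    //    context creation, and the resulting llama_context frees the samplers
    //    in its destructor
    params.backend_samplers   = configs.data();
    params.n_backend_samplers = configs.size();

    return common_init_context_from_model(model, params);
}
```

With `--backend-dist` enabled, the context then returns the sampled token directly via `llama_get_backend_sampled_token_ith()`, which is the first thing `common_sampler_sample()` checks below.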
diff --git a/common/common.h b/common/common.h index de5b404dd8895..01e6dfe59b49d 100644 --- a/common/common.h +++ b/common/common.h @@ -195,6 +195,10 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + // Backend sampling flags + bool backend_sampling = false; // enable backend sampling + bool backend_dist = false; // backend performs final sampling (dist) + // print the parameters into a string std::string print() const; }; @@ -519,6 +523,9 @@ struct common_params { bool has_speculative() const { return !speculative.model.path.empty() || !speculative.model.hf_repo.empty(); } + + llama_sampler_seq_config * backend_samplers = NULL; + size_t n_backend_samplers = 0; }; // call once at the start of a program if it uses libcommon @@ -640,6 +647,14 @@ struct common_init_result { struct common_init_result common_init_from_params(common_params & params); +// Load model only (allows creating backend samplers before context initialization) +llama_model * common_load_model_from_params(common_params & params); + +// Initialize context from an already-loaded model (allows pre-configuring backend samplers) +struct common_init_result common_init_context_from_model( + llama_model * model, + common_params & params); + struct llama_model_params common_model_params_to_llama ( common_params & params); struct llama_context_params common_context_params_to_llama(const common_params & params); struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params); diff --git a/common/llguidance.cpp b/common/llguidance.cpp index adce620e4d62f..27d15516e9438 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -106,12 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) { } static llama_sampler_i llama_sampler_llg_i = { - /* .name = */ llama_sampler_llg_name, - /* .accept = */ llama_sampler_llg_accept_impl, - /* .apply = */ llama_sampler_llg_apply, - /* .reset = */ llama_sampler_llg_reset, - /* .clone = */ llama_sampler_llg_clone, - /* .free = */ llama_sampler_llg_free, + /* .name = */ llama_sampler_llg_name, + /* .accept = */ llama_sampler_llg_accept_impl, + /* .apply = */ llama_sampler_llg_apply, + /* .reset = */ llama_sampler_llg_reset, + /* .clone = */ llama_sampler_llg_clone, + /* .free = */ llama_sampler_llg_free, + /* .apply_ggml = */ NULL, + /* .accept_ggml = */ NULL, + /* .set_input_ggml = */ NULL, + /* .set_backend_context = */ NULL, }; static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len, diff --git a/common/sampling.cpp b/common/sampling.cpp index 7a6b7be1e0ee6..0a2be6bf7df56 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -121,17 +121,34 @@ struct common_sampler { } void set_logits(struct llama_context * ctx, int idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const float * sampled_probs = llama_get_backend_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_backend_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx); const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab); - cur.resize(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + if (sampled_probs) { + const uint32_t 
sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (uint32_t i = 0; i < sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } cur_p = { cur.data(), cur.size(), -1, false }; @@ -144,6 +161,10 @@ struct common_sampler { mutable int64_t t_total_us = 0; }; +static bool sampler_enabled(const struct common_params_sampling & params, enum common_sampler_type type) { + return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end(); +} + std::string common_params_sampling::print() const { char result[1024]; @@ -301,6 +322,43 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co return result; } +struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) { + if (!params.backend_sampling) { + return nullptr; + } + const llama_vocab * vocab = llama_model_get_vocab(model); + + llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + chain_params.no_perf = params.no_perf; + + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + + const bool enable_temp = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE); + const bool enable_top_k = params.top_k > 0 && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K); + const bool enable_dist = params.backend_dist; + + if (!params.logit_bias.empty()) { + llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias( + llama_vocab_n_tokens(vocab), + params.logit_bias.size(), + params.logit_bias.data())); + } + + if (enable_temp) { + llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp)); + } + + if (enable_top_k) { + llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k)); + } + + if (enable_dist) { + llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed)); + } + + return chain; +} + void common_sampler_free(struct common_sampler * gsmpl) { if (gsmpl) { llama_sampler_free(gsmpl->grmr); @@ -384,6 +442,15 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { + // Check if a backend sampler has already sampled a token in which case we + // return that token id directly. + { + const llama_token id = llama_get_backend_sampled_token_ith(ctx, idx); + if (id != LLAMA_TOKEN_NULL) { + LOG_DBG("%s: Backend sampler selected token: '%d'. 
Will not run any CPU samplers\n", __func__, id); + return id; + } + } llama_synchronize(ctx); // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations diff --git a/common/sampling.h b/common/sampling.h index e198eecda3810..0ec164de05343 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -38,6 +38,13 @@ struct common_sampler; struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); +// Create a backend sampler chain from common sampling parameters +// Returns a llama_sampler chain configured with backend samplers based on the parameters +// This chain can be used per-sequence for backend-based sampling +// Note: Only samplers that have backend equivalents will be added to the chain +// The returned sampler should be freed with llama_sampler_free() +struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params); + void common_sampler_free(struct common_sampler * gsmpl); // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 3722cf3ab26ee..b8003c48c51fc 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -49,28 +49,49 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, size_t temp_storage_bytes = 0; if (order == GGML_SORT_ORDER_ASC) { - DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols * nrows, nrows, // num items, num segments - d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits - stream); + if (nrows == 1) { + DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + d_offsets, d_offsets + 1, stream); + } } else { - DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, - sizeof(float) * 8, stream); + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, + dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + } } ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); void * d_temp_storage = temp_storage_alloc.get(); if (order == GGML_SORT_ORDER_ASC) { - DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8, - stream); + if (nrows == 1) { + DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, + ncols * nrows, 
nrows, d_offsets, d_offsets + 1, stream); + } } else { - DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, - 0, sizeof(float) * 8, stream); + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, + stream); + } } } #endif // GGML_CUDA_USE_CUB diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu new file mode 100644 index 0000000000000..041dc7cdb5194 --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -0,0 +1,69 @@ +#include "cumsum.cuh" + +#ifdef GGML_CUDA_USE_CUB +#include +using namespace cub; +#endif // GGML_CUDA_USE_CUB + +#include + +__global__ void cumsum_f32_kernel(const float * x, float * dst, int64_t n) { + // TODO: this is a naive implementation just for getting something working. + if (threadIdx.x == 0 && blockIdx.x == 0) { + dst[0] = x[0]; + for (int64_t i = 1; i < n; i++) { + dst[i] = dst[i-1] + x[i]; + } + } +} + +void cumsum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) { +#ifdef GGML_CUDA_USE_CUB + size_t tmp_size = 0; + + // Query how much temp storage CUDA UnBound (CUB) needs + cub::DeviceScan::InclusiveSum( + nullptr, // d_temp_storage (null = just query size) + tmp_size, // reference to size (will be set by CUB) + x, // input pointer + dst, // output pointer + ne, // number of elements + stream // CUDA stream to use + ); + + ggml_cuda_pool_alloc tmp_alloc(pool, tmp_size); + + // Perform the inclusive scan + cub::DeviceScan::InclusiveSum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream); + +#else + GGML_UNUSED(pool); + cumsum_f32_kernel<<<1, 1, 0, stream>>>(x, dst, ne); +#endif // GGML_CUDA_USE_CUB +} + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + + const int64_t ne0 = src0->ne[0]; // row length (cumsum computed along this dimension) + const int64_t ne1 = src0->ne[1]; + const int64_t ne2 = src0->ne[2]; + const int64_t ne3 = src0->ne[3]; + const int64_t nrows = ne1 * ne2 * ne3; // total number of rows + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t stream = ctx.stream(); + + for (int64_t i = 0; i < nrows; i++) { + const float * src_row = src0_d + i * ne0; + float * dst_row = dst_d + i * ne0; + cumsum_f32_cuda(pool, src_row, dst_row, ne0, stream); + } +} diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh new file mode 100644 index 0000000000000..7fca7e1456437 --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +void cumsum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream); + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8fe0899bb5aac..a4ff461120293 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ 
b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -19,6 +19,7 @@ #include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" +#include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" @@ -2678,6 +2679,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SUM: ggml_cuda_op_sum(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; case GGML_OP_SUM_ROWS: ggml_cuda_op_sum_rows(ctx, dst); break; @@ -4223,6 +4227,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_POOL_2D: case GGML_OP_ACC: return true; + case GGML_OP_CUMSUM: case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ARGSORT: diff --git a/include/llama.h index 8547226ff210c..9fbce771d74cf 100644 --- a/include/llama.h +++ b/include/llama.h @@ -210,6 +210,13 @@ extern "C" { bool sorted; // note: do not assume the data is sorted - always check this flag } llama_token_data_array; + struct llama_sampler_ggml_data { + struct ggml_tensor * logits; + struct ggml_tensor * probs; + struct ggml_tensor * sampled; + struct ggml_tensor * candidates; + }; + typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_encode/llama_decode @@ -300,6 +307,11 @@ extern "C" { bool no_host; // bypass host buffer allowing extra buffers to be used }; + struct llama_sampler_seq_config { + llama_seq_id seq_id; + struct llama_sampler * sampler; + }; + // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations // https://github.com/ggml-org/llama.cpp/pull/7544 struct llama_context_params { @@ -348,6 +360,10 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + + // backend sampler chain configuration + struct llama_sampler_seq_config * samplers; + size_t n_samplers; }; // model quantization parameters @@ -950,6 +966,32 @@ extern "C" { // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + // Get the backend sampled token for the ith token. + // Returns LLAMA_TOKEN_NULL if no token was sampled. + LLAMA_API llama_token llama_get_backend_sampled_token_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled probabilities for the ith token + // The index matches llama_get_backend_sampled_token_ith(). + // Returns NULL if no probabilities were generated. + LLAMA_API float * llama_get_backend_sampled_probs_ith(struct llama_context * ctx, int32_t i); + // + // Get the number of backend sampled probabilities for the ith token. + LLAMA_API uint32_t llama_get_backend_sampled_probs_count_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled logits for the ith token + // Returns NULL if no logits were sampled. + LLAMA_API float * llama_get_backend_sampled_logits_ith(struct llama_context * ctx, int32_t i); + // + // Get the number of backend sampled logits for the ith token. 
+ LLAMA_API uint32_t llama_get_backend_sampled_logits_count_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled candidates (token ids) for the ith token + // Returns NULL if no candidates were sampled. + LLAMA_API llama_token * llama_get_backend_sampled_candidates_ith(struct llama_context * ctx, int32_t i); + // + // Get the number of backend sampled candidates for the ith token. + LLAMA_API uint32_t llama_get_backend_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i); + // // Vocab // @@ -1135,6 +1177,22 @@ extern "C" { struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL + void (*apply_ggml)( struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data); + + void (*accept_ggml)( struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_tensor * selected_token); + + void (*set_input_ggml)(struct llama_sampler * smpl); + + void (*init_ggml)(struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft); + + // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_sampler * smpl, ...); }; @@ -1144,6 +1202,8 @@ extern "C" { llama_sampler_context_t ctx; }; + LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); + // mirror of llama_sampler_i: LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); @@ -1153,6 +1213,18 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + LLAMA_API void llama_sampler_init_ggml(struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft); + LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl); + LLAMA_API void llama_sampler_apply_ggml( struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data); + + LLAMA_API void llama_sampler_accept_ggml( struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_tensor * selected_token); // llama_sampler_chain // a type of llama_sampler that can chain multiple samplers one after another @@ -1299,9 +1371,29 @@ extern "C" { // LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); + // + // Backend samplers + // + + /// @details Greedy sampling on backend - always selects the token with the highest probability + LLAMA_API struct llama_sampler * llama_sampler_backend_init_greedy(void); + + /// @details Temperature scaling on backend - scales logits by 1/temperature + LLAMA_API struct llama_sampler * llama_sampler_backend_init_temp(float temp); + + /// @details Top-K filtering on backend - keeps only the k tokens with highest probabilities + LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k); + + /// @details Distribution sampling on backend - final sampling step that selects a token + LLAMA_API struct llama_sampler * llama_sampler_backend_init_dist(uint32_t 
seed); + // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); + /// @details Sample and accept a token from the idx-th output of the last evaluation // // Shorthand for: diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8ec95ee176240..c17b89008948b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,6 +31,7 @@ add_library(llama llama-model.cpp llama-quant.cpp llama-sampling.cpp + llama-backend-sampler.cpp llama-vocab.cpp unicode-data.cpp unicode.cpp diff --git a/src/llama-backend-sampler.cpp b/src/llama-backend-sampler.cpp new file mode 100644 index 0000000000000..cd6b8bb7526a6 --- /dev/null +++ b/src/llama-backend-sampler.cpp @@ -0,0 +1,489 @@ +#include "llama.h" +#include "ggml.h" +#include +#include +#include +#include +#include + +static void llama_sampler_backend_greedy_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + GGML_UNUSED(gf); + GGML_UNUSED(smpl); + struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits); + ggml_set_name(argmax_result, "argmax_result"); + ggml_data->sampled = argmax_result; +} + +static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) { + return "test-ggml"; +} + +static struct llama_sampler * llama_sampler_backend_greedy_clone(const struct llama_sampler * smpl) { + (void) smpl; + return llama_sampler_backend_init_greedy(); +} + +struct llama_sampler * llama_sampler_backend_init_greedy() { + static const llama_sampler_i iface = { + /*.name =*/ llama_sampler_backend_greedy_sampler_name, + /*.accept =*/ nullptr, + /*.apply =*/ nullptr, + /*.reset =*/ nullptr, + /*.clone =*/ llama_sampler_backend_greedy_clone, + /*.free =*/ nullptr, + /*.apply_ggml =*/ llama_sampler_backend_greedy_apply_ggml, + /*.accept_ggml =*/ nullptr, + /*.set_input_ggml =*/ nullptr, + /*.init_ggml =*/ nullptr, + }; + + auto * sampler = new llama_sampler { + /*.iface =*/ &iface, + /*.ctx =*/ nullptr, + }; + + return sampler; +} + +struct llama_sampler_backend_temp_ctx { + float temp; +}; + + +static void llama_sampler_backend_temp_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + + auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx; + + if (ctx_data->temp <= 0.0f) { + return; + } + + struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp); + ggml_set_name(scaled, "temp_scaled"); + + // Make sure the scaled tensor is contiguous for subsequent operations + ggml_data->logits = ggml_cont(ctx, scaled); + ggml_set_name(ggml_data->logits, "temp_scaled_logits"); + + ggml_build_forward_expand(gf, ggml_data->logits); +} + +static const char * llama_sampler_backend_temp_name(const struct llama_sampler *) { + return "backend-temp"; +} + +static void llama_sampler_backend_temp_free(struct llama_sampler * smpl) { + auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx; + delete ctx_data; +} + +static struct llama_sampler * llama_sampler_backend_temp_clone(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_backend_temp_ctx *) smpl->ctx; + return llama_sampler_backend_init_temp(ctx->temp); +} + +struct 
llama_sampler * llama_sampler_backend_init_temp(float temp) { + static const llama_sampler_i iface = { + /*.name =*/ llama_sampler_backend_temp_name, + /*.accept =*/ nullptr, + /*.apply =*/ nullptr, + /*.reset =*/ nullptr, + /*.clone =*/ llama_sampler_backend_temp_clone, + /*.free =*/ llama_sampler_backend_temp_free, + /*.apply_ggml =*/ llama_sampler_backend_temp_apply_ggml, + /*.accept_ggml =*/ nullptr, + /*.set_input_ggml =*/ nullptr, + /*.set_backend_context =*/ nullptr, + }; + + auto * ctx_data = new llama_sampler_backend_temp_ctx { + /*.temp =*/ temp, + }; + + auto * sampler = new llama_sampler { + /*.iface =*/ &iface, + /*.ctx =*/ ctx_data, + }; + + return sampler; +} + + +struct llama_sampler_backend_top_k_ctx { + int32_t k; + + // Only required for checking operation support and can be removed later. + ggml_backend_dev_t device; +}; + +static void llama_sampler_backend_top_k_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx; + ctx_data->device = ggml_backend_buft_get_device(buft); +} + +static void llama_sampler_backend_top_k_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + + auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx; + + struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k); + ggml_set_name(top_k, "top_k"); + + // top_k is a view of argsort - check if backend supports the underlying argsort operation + // by checking the source tensor (which is the argsort result) + if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) { + fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n"); + fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); + } + + ggml_data->candidates = top_k; + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); + struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); + ggml_set_name(top_k_rows, "top_k_rows"); + + ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k); + ggml_build_forward_expand(gf, ggml_data->logits); +} + +static const char * llama_sampler_backend_top_k_name(const struct llama_sampler *) { + return "backend-top-k"; +} + +static void llama_sampler_backend_top_k_free(struct llama_sampler * smpl) { + auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx; + delete ctx_data; +} + +static struct llama_sampler * llama_sampler_backend_top_k_clone(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_backend_top_k_ctx *) smpl->ctx; + return llama_sampler_backend_init_top_k(ctx->k); +} + +struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k) { + static const llama_sampler_i iface = { + /*.name =*/ llama_sampler_backend_top_k_name, + /*.accept =*/ nullptr, + /*.apply =*/ nullptr, + /*.reset =*/ nullptr, + /*.clone =*/ llama_sampler_backend_top_k_clone, + /*.free =*/ llama_sampler_backend_top_k_free, + /*.apply_ggml =*/ llama_sampler_backend_top_k_apply_ggml, + /*.accept_ggml =*/ nullptr, + /*.set_input_ggml =*/ nullptr, + /*.init_ggml =*/ llama_sampler_backend_top_k_init_ggml, + }; + + auto * ctx_data = new llama_sampler_backend_top_k_ctx { + /*.k =*/ k, + /*.device =*/ nullptr, + }; + + auto * sampler = new llama_sampler { + /*.iface =*/ &iface, + /*.ctx 
=*/ ctx_data, + }; + + return sampler; +} + + +static uint32_t get_rng_seed(uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + // use system clock if std::random_device is not a true RNG + static bool is_rd_prng = std::random_device().entropy() == 0; + if (is_rd_prng) { + return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count(); + } + std::random_device rd; + return rd(); + } + return seed; +} + +struct llama_sampler_backend_dist_ctx { + const uint32_t seed; + uint32_t seed_cur; + std::mt19937 rng; + + struct ggml_tensor * uniform; + struct ggml_context * ctx; + ggml_backend_buffer_t buffer; + + // Only required for checking operation support and can be removed later. + ggml_backend_dev_t device; +}; + +static void llama_sampler_backend_dist_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + sctx->device = ggml_backend_buft_get_device(buft); + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + sctx->ctx = ggml_init(params); + + // Create the uniform random scalar input tensor. This will be set by + // llama_sampler_backend_dist_set_input_ggml after this graph is built. + sctx->uniform = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, 1); + ggml_set_name(sctx->uniform, "uniform"); + ggml_set_input(sctx->uniform); + ggml_set_output(sctx->uniform); + + // Allocate all tensors from our context to the backend + sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft); +} + +static void llama_sampler_backend_dist_set_input_ggml(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + GGML_ASSERT(sctx->uniform != nullptr); + + std::uniform_real_distribution dist(0.0f, 1.0f); + const float rnd = dist(sctx->rng); + ggml_backend_tensor_set(sctx->uniform, &rnd, 0, sizeof(float)); +} + +static void llama_sampler_backend_dist_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + GGML_UNUSED(gf); + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + + struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits); + ggml_set_name(probs, "dist_probs"); + + struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); + if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) { + fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n"); + fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); + } + ggml_set_name(cumsum, "cumsum"); + + // The uniform tensor has a random value and we subtract this tensor with + // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub). + // Recall that each entry in cumsum is the cumulative probability up to that + // index so values stay negative while the cumulative total is below the + // random value, and become zero/positive once the threshold is crossed. + struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->uniform); + ggml_set_name(diff, "dist_cumsum"); + + // The ggml_step function produces a tensor where entries are 1 if the + // corresponding entry in diff is > 0, and 0 otherwise. So all values up to + // the index where the cumulative probability exceeds the random value are 0, + // and all entries after that are 1. 
+ struct ggml_tensor * mask = ggml_step(ctx, diff); + ggml_set_name(mask, "dist_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "dist_index_f32"); + + // Use ggml_scale_bias to scale the index value by -1 and then add the size + // of the mask to that value so we get the correct index ((-1 * idxf) + n). + struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); + ggml_set_name(idx, "dist_index_i32"); + + // Map back to original vocab ids if a candidates tensor is available. + struct ggml_tensor * sampled_token = idx; + if (ggml_data->candidates != nullptr) { + struct ggml_tensor * candidates = ggml_data->candidates; + struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates), + ggml_type_size(candidates->type), 0); + + sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx); + ggml_set_name(sampled_token, "dist_sampled_token"); + } + + ggml_set_output(sampled_token); + ggml_data->sampled = sampled_token; +} + +static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) { + return "backend-dist"; +} + +static void llama_sampler_backend_dist_free(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + ggml_backend_buffer_free(sctx->buffer); + ggml_free(sctx->ctx); + delete sctx; +} + +static struct llama_sampler * llama_sampler_backend_dist_clone(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + return llama_sampler_backend_init_dist(sctx->seed); +} + + +struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed) { + static const llama_sampler_i iface = { + /*.name =*/ llama_sampler_backend_dist_name, + /*.accept =*/ nullptr, + /*.apply =*/ nullptr, + /*.reset =*/ nullptr, + /*.clone =*/ llama_sampler_backend_dist_clone, + /*.free =*/ llama_sampler_backend_dist_free, + /*.apply_ggml =*/ llama_sampler_backend_dist_apply_ggml, + /*.accept_ggml =*/ nullptr, + /*.set_input_ggml =*/ llama_sampler_backend_dist_set_input_ggml, + /*.init_ggml =*/ llama_sampler_backend_dist_init_ggml, + }; + + auto seed_cur = get_rng_seed(seed); + auto * ctx_data = new llama_sampler_backend_dist_ctx { + /*.seed =*/ seed, + /*.seed_cur =*/ seed_cur, + /*.rng =*/ std::mt19937(seed_cur), + /*.uniform =*/ nullptr, + /*.ctx =*/ nullptr, + /*.buffer =*/ nullptr, + /*.device =*/ nullptr, + }; + + auto * sampler = new llama_sampler { + /*.iface =*/ &iface, + /*.ctx =*/ ctx_data, + }; + + return sampler; +} + +struct llama_sampler_backend_logit_bias_ctx { + const int32_t n_vocab; + + const std::vector logit_bias; + + struct ggml_tensor * logit_bias_t; + struct ggml_context * ctx; + ggml_backend_buffer_t buffer; +}; + +static void llama_sampler_backend_logit_bias_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead() * sctx->n_vocab * sizeof(float), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + sctx->ctx = ggml_init(params); + + struct ggml_tensor * logit_bias = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, sctx->n_vocab); + sctx->logit_bias_t = logit_bias; + ggml_set_name(sctx->logit_bias_t, "logit_bias"); + ggml_set_input(sctx->logit_bias_t); + 
ggml_set_output(sctx->logit_bias_t); + + // Allocate all tensors from our context to the backend + sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft); +} + +static void llama_sampler_backend_logit_bias_set_input_ggml(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + GGML_ASSERT(sctx->logit_bias_t != nullptr); + + // Create a sparse logit_bias vector from the logit_bias entries. + std::vector logit_bias_sparse(sctx->n_vocab, 0.0f); + for (const auto & lb : sctx->logit_bias) { + GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab); + logit_bias_sparse[lb.token] = lb.bias; + } + + ggml_backend_tensor_set(sctx->logit_bias_t, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->logit_bias_t)); +} + +static void llama_sampler_backend_logit_bias_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + GGML_UNUSED(gf); + GGML_UNUSED(ctx); + + auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; + if (sctx->logit_bias_t == nullptr) { + return; + } + + // Add the sparse logit logit_bias to the logits + struct ggml_tensor * logit_biased = ggml_add_inplace(sctx->ctx, ggml_data->logits, sctx->logit_bias_t); + ggml_build_forward_expand(gf, logit_biased); +} + +static const char * llama_sampler_backend_logit_bias_name(const struct llama_sampler *) { + return "backend-logit_bias"; +} + +static void llama_sampler_backend_logit_bias_free(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; + ggml_backend_buffer_free(sctx->buffer); + ggml_free(sctx->ctx); + delete sctx; +} + +static struct llama_sampler * llama_sampler_backend_logit_bias_clone(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; + return llama_sampler_backend_init_logit_bias(sctx->n_vocab, + sctx->logit_bias.size(), + sctx->logit_bias.data()); +} + + +struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + static const llama_sampler_i iface = { + /*.name =*/ llama_sampler_backend_logit_bias_name, + /*.accept =*/ nullptr, + /*.apply =*/ nullptr, + /*.reset =*/ nullptr, + /*.clone =*/ llama_sampler_backend_logit_bias_clone, + /*.free =*/ llama_sampler_backend_logit_bias_free, + /*.apply_ggml =*/ llama_sampler_backend_logit_bias_apply_ggml, + /*.accept_ggml =*/ nullptr, + /*.set_input_ggml =*/ llama_sampler_backend_logit_bias_set_input_ggml, + /*.init_ggml =*/ llama_sampler_backend_logit_bias_init_ggml, + }; + + auto * ctx_data = new llama_sampler_backend_logit_bias_ctx { + /*.n_vocab =*/ n_vocab, + /*.logit_bias =*/ std::vector(logit_bias, logit_bias + n_logit_bias), + /*.logit_bias_t =*/ nullptr, + /*.ctx =*/ nullptr, + /*.buffer =*/ nullptr, + }; + + auto * sampler = new llama_sampler { + /*.iface =*/ &iface, + /*.ctx =*/ ctx_data, + }; + + return sampler; +} diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 86a1a4ba187ee..f0866a9ca1962 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -28,7 +28,8 @@ bool llama_batch_allocr::init( const llama_memory_i * memory, uint32_t n_embd, uint32_t n_seq_max, - bool output_all) { + bool output_all, + bool backend_sampling) { clear(); batch = batch_inp; @@ -145,6 +146,24 @@ bool llama_batch_allocr::init( } } + if (backend_sampling) { + std::vector 
seq_output_count(n_seq_max, 0); + + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.logits[i] == 0) { + continue; + } + for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { + const llama_seq_id seq_id = batch.seq_id[i][s]; + seq_output_count[seq_id]++; + if (seq_output_count[seq_id] > 1) { + LLAMA_LOG_ERROR("%s: backend sampling allows at most one output token per sequence (%d)\n", __func__, seq_id); + return false; + } + } + } + } + // // compute stats // diff --git a/src/llama-batch.h b/src/llama-batch.h index 209cf3699de23..d8751274f376d 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -79,7 +79,8 @@ class llama_batch_allocr { const llama_memory_i * memory, uint32_t n_embd, uint32_t n_seq_max, - bool output_all); + bool output_all, + bool backend_sampling = false); const llama_batch & get_batch() const; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 70a3ec62dfc63..1694e44720062 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -58,6 +58,16 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + // backend samplers + if (params.samplers != nullptr && params.n_samplers > 0) { + sampling.samplers.reserve(params.n_samplers); + + for (size_t i = 0; i < params.n_samplers; ++i) { + const auto & config = params.samplers[i]; + sampling.samplers[config.seq_id] = config.sampler; + } + } + auto rope_scaling_type = params.rope_scaling_type; if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { rope_scaling_type = hparams.rope_scaling_type_train; @@ -420,10 +430,24 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); } } + + // Initialize the full vocabulary token ids for backend samplers. + { + const llama_vocab * vocab = llama_model_get_vocab(&model); + const int n_vocab = llama_vocab_n_tokens(vocab); + sampling.token_ids_full_vocab.resize(n_vocab); + for (int i = 0; i < n_vocab; ++i) { + sampling.token_ids_full_vocab[i] = i; + } + } } llama_context::~llama_context() { ggml_opt_free(opt_ctx); + // TODO: perhaps use a smart pointer for samplers + for (auto const& [seq_id, sampler] : sampling.samplers) { + llama_sampler_free(sampler); + } } void llama_context::synchronize() { @@ -564,6 +588,35 @@ float * llama_context::get_logits() { return logits; } +int64_t llama_context::resolve_output_row(int32_t i) const { + int64_t j = -1; + + // support negative indices (last output row) + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); + } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); + } else { + // use output_ids to translate the batch token index into a row number + // that holds this token's data. 
+ j = output_ids[i]; + } + + if (j < 0) { + // the batch token was not configured to output anything + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + + if (j >= n_outputs) { + throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); + } + + return j; +} + float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; @@ -610,6 +663,10 @@ float * llama_context::get_embeddings() { return embd; } +llama_token * llama_context::get_backend_sampled_tokens() { + return sampling.sampled; +} + float * llama_context::get_embeddings_ith(int32_t i) { int64_t j = -1; @@ -659,6 +716,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) { + output_reorder(); + + if (sampling.sampled == nullptr) { + return LLAMA_TOKEN_NULL; + } + + try { + const int64_t row = resolve_output_row(idx); + GGML_ASSERT(row < (int64_t) sampling.sampled_size); + return sampling.sampled[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); + return LLAMA_TOKEN_NULL; + } +} + +float * llama_context::get_backend_sampled_probs_ith(int32_t idx) { + output_reorder(); + + if (sampling.probs == nullptr) { + return nullptr; + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { + return nullptr; + } + return sampling.probs + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +float * llama_context::get_backend_sampled_logits_ith(int32_t idx) { + output_reorder(); + + if (sampling.logits == nullptr) { + return nullptr; + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { + return nullptr; + } + return sampling.logits + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +const llama_token * llama_context::get_backend_sampled_candidates_ith(int32_t idx) { + output_reorder(); + + try { + const int64_t row = resolve_output_row(idx); + if (sampling.candidates != nullptr && + (size_t) row < sampling.candidates_count.size() && + sampling.candidates_count[row] > 0) { + return sampling.candidates + row*model.vocab.n_tokens(); + } + } catch (const std::exception & err) { + // fallback to full vocab list + } + + return sampling.token_ids_full_vocab.data(); +} + +size_t llama_context::get_backend_sampled_candidates_count(int32_t idx) { + output_reorder(); + + if (sampling.candidates == nullptr) { + return 0; + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= sampling.candidates_count.size()) { + return 0; + } + return sampling.candidates_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_backend_sampled_logits_count(int32_t idx) { + output_reorder(); + + if (sampling.logits == nullptr) { + return 0; + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= 
sampling.logits_count.size()) { + return 0; + } + return sampling.logits_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_backend_sampled_probs_count(int32_t idx) { + output_reorder(); + + if (sampling.probs == nullptr) { + return 0; + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= sampling.probs_count.size()) { + return 0; + } + return sampling.probs_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -715,6 +902,37 @@ void llama_context::set_warmup(bool value) { cparams.warmup = value; } +void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) { + LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + + auto it = sampling.samplers.find(seq_id); + if (it != sampling.samplers.end()) { + // If the sampler to be set is the same that is already set, do nothing. + if (it->second == sampler) { + return; + } + + llama_sampler_free(it->second); + + // If sampler is nullptr, we remove the sampler chain for this seq_id. + if (sampler == nullptr) { + sampling.samplers.erase(it); + return; + } + + // Otherwise, we replace the existing sampler with the new one. + it->second = sampler; + return; + } + + // If there is no sampler for this seq_id and the caller provides a non-null + // sampler, we set it. + if (sampler != nullptr) { + sampling.samplers[seq_id] = sampler; + } +} + void llama_context::set_adapter_lora( llama_adapter_lora * adapter, float scale) { @@ -979,6 +1197,97 @@ int llama_context::encode(const llama_batch & batch_inp) { return 0; } +static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { + std::unordered_map<llama_seq_id, uint32_t> seq_to_row; + // how many output tokens we have seen so far for this ubatch. + uint32_t local = 0; + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + // skip tokens that are not output. + if (!ubatch.output[i]) { + continue; + } + + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + // row_offset is the number of output tokens before this ubatch. 
+ seq_to_row[seq_id] = row_offset + local; + ++local; + } + return seq_to_row; +} + +static void copy_tensor_async_ints( + const std::unordered_map & tensor_map, + llama_token * sampled, + size_t sampled_size, + const std::unordered_map & seq_to_row, + ggml_backend_sched_t sched) { + if (sampled == nullptr || sampled_size == 0) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + GGML_ASSERT(it != seq_to_row.end()); + const uint32_t row = it->second; + GGML_ASSERT(row < sampled_size); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); + } +} + +static void copy_tensor_async_floats( + const std::unordered_map & tensor_map, + float * dst, + size_t stride, + std::vector & counts, + const std::unordered_map & seq_to_row, + ggml_backend_sched_t sched) { + if (dst == nullptr || stride == 0) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + GGML_ASSERT(it != seq_to_row.end()); + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + float * row_ptr = dst + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of logits/probabilities that were written for this row. + counts[row] = ggml_nelements(tensor); + } +} + +static void copy_tensor_async_candidates( + const std::unordered_map & tensor_map, + llama_token * dst, + size_t stride, + std::vector & counts, + const std::unordered_map & seq_to_row, + ggml_backend_sched_t sched) { + if (dst == nullptr || stride == 0) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + GGML_ASSERT(it != seq_to_row.end()); + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + llama_token * row_ptr = dst + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of candidates that were written. + counts[row] = ggml_nelements(tensor); + } +} + int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT @@ -1000,8 +1309,12 @@ int llama_context::decode(const llama_batch & batch_inp) { // when computing embeddings, all tokens are output const bool output_all = cparams.embeddings; + const bool has_backend_samplers = !sampling.samplers.empty(); - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, + cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, + output_all, + has_backend_samplers)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } @@ -1089,6 +1402,10 @@ int llama_context::decode(const llama_batch & batch_inp) { int64_t n_outputs_prev = 0; + // This flag indicates whether a backend sampler has actually sampled a specific + // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings. 
+ bool backend_has_sampled = false; + do { const auto & ubatch = mctx->get_ubatch(); @@ -1147,80 +1464,106 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + backend_has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); - if (t_embd && res->get_embd_pooled()) { - t_embd = res->get_embd_pooled(); - } + if (has_backend_samplers && backend_has_sampled) { + const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); - // extract logits - if (t_logits && n_outputs > 0) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + // If a backend sampler has sampled a token we only want to copy the + // sampled tokens and avoid copying logits and probabilites. + if (!res->t_sampled.empty()) { + // async copy the sampled tokens from the backend to the host. + copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get()); + } else { + // async copy the sampled logits/probs from the backend to the host. + copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, n_vocab, sampling.logits_count, seq_to_output_row, sched.get()); + copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, n_vocab, sampling.probs_count, seq_to_output_row, sched.get()); + } - float * logits_out = logits + n_outputs_prev*n_vocab; + // async copy the candidate token ids from the backend to the host. + // These are needed for: + // 1) Backend dist sampler to map indices to vocab token ids. + // 2) CPU samplers to associate candidate logits with their token ids. + copy_tensor_async_candidates(res->t_candidates, sampling.candidates, n_vocab, sampling.candidates_count, seq_to_output_row, sched.get()); - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); - } } - // extract embeddings - if (t_embd && n_outputs > 0) { - ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); - GGML_ASSERT(backend_embd != nullptr); + if (!backend_has_sampled) { + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? 
res->get_embd() : nullptr; - switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; + if (t_embd && res->get_embd_pooled()) { + t_embd = res->get_embd_pooled(); + } - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings (cleared before processing each batch) - auto & embd_seq_out = embd_seq; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // extract the rerank score - n_cls_out floats per sequence - auto & embd_seq_out = embd_seq; + // extract logits + if (t_logits && n_outputs > 0) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); - const uint32_t n_cls_out = hparams.n_cls_out; + float * logits_out = logits + n_outputs_prev*n_vocab; - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } + } - embd_seq_out[seq_id].resize(n_cls_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + // extract embeddings + if (t_embd && n_outputs > 0) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + + const uint32_t n_cls_out = 
hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); - } + } } } @@ -1306,8 +1649,31 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { has_embd = true; } - logits_size = has_logits ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd*n_outputs_max : 0; + const bool backend_sampling = !sampling.samplers.empty(); + size_t backend_float_count = 0; + size_t backend_token_count = 0; + + if (!backend_sampling) { + logits_size = has_logits ? n_vocab*n_outputs_max : 0; + embd_size = has_embd ? n_embd*n_outputs_max : 0; + + // reset backend sampling values. + sampling.logits_size = 0; + sampling.probs_size = 0; + sampling.sampled_size = 0; + sampling.candidates_size = 0; + } else { + logits_size = 0; + embd_size = 0; + + sampling.logits_size = n_vocab*n_outputs_max; + sampling.probs_size = n_vocab*n_outputs_max; + sampling.sampled_size = n_outputs_max; + sampling.candidates_size = n_vocab*n_outputs_max; + + backend_float_count = sampling.logits_size + sampling.probs_size; + backend_token_count = sampling.sampled_size + sampling.candidates_size; + } if (output_ids.empty()) { // init, never resized afterwards @@ -1315,7 +1681,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { } const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; - const size_t new_size = (logits_size + embd_size) * sizeof(float); + const size_t new_size = (logits_size + embd_size + backend_float_count) * sizeof(float) + + backend_token_count * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1346,8 +1713,56 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = has_logits ? output_base : nullptr; - embd = has_embd ? output_base + logits_size : nullptr; + logits = nullptr; + embd = nullptr; + + // reset sampling pointers. + sampling.logits = nullptr; + sampling.probs = nullptr; + sampling.sampled = nullptr; + sampling.candidates = nullptr; + + if (!backend_sampling) { + logits = has_logits ? output_base : nullptr; + embd = has_embd ? 
output_base + logits_size : nullptr; + } else { + size_t offset = 0; + uint8_t * base = (uint8_t *) output_base; + + if (sampling.logits_size > 0) { + sampling.logits = (float *) (base + offset); + offset += sampling.logits_size * sizeof(float); + } + if (sampling.probs_size > 0) { + sampling.probs = (float *) (base + offset); + offset += sampling.probs_size * sizeof(float); + } + if (sampling.sampled_size > 0) { + sampling.sampled = (llama_token *) (base + offset); + offset += sampling.sampled_size * sizeof(llama_token); + } + if (sampling.candidates_size > 0) { + sampling.candidates = (llama_token *) (base + offset); + offset += sampling.candidates_size * sizeof(llama_token); + } + + const size_t n_rows = (size_t) n_outputs_max; + if (sampling.outputs_capacity < n_rows) { + sampling.outputs_capacity = n_rows; + + sampling.logits_count.assign(n_rows, 0); + sampling.probs_count.assign(n_rows, 0); + sampling.candidates_count.assign(n_rows, 0); + } else { + std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); + std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); + std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); + } + + if (sampling.sampled && sampling.sampled_size > 0) { + std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); + } + } // set all ids as invalid (negative) std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1376,6 +1791,38 @@ void llama_context::output_reorder() { std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); } } + + if (sampling.logits && sampling.logits_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); + } + } + + if (sampling.probs && sampling.probs_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); + } + } + + if (sampling.candidates && sampling.candidates_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); + } + } + + if (sampling.sampled && sampling.sampled_size > 0) { + std::swap(sampling.sampled[i0], sampling.sampled[i1]); + } + + if (!sampling.logits_count.empty()) { + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + } + if (!sampling.probs_count.empty()) { + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); + } + if (!sampling.candidates_count.empty()) { + std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); + } } output_swaps.clear(); @@ -1452,10 +1899,12 @@ llm_graph_params llama_context::graph_params( /*.gtype =*/ gtype, /*.sched =*/ sched.get(), /*.backend_cpu =*/ backend_cpu, + /*.dev_out =*/ model.dev_output(), /*.cvec =*/ &cvec, /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -2319,6 +2768,8 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.sampler =*/ nullptr, + /*.n_sampler =*/ 0, }; return result; @@ -2478,6 +2929,13 @@ float * llama_get_logits(llama_context * ctx) { float * llama_get_logits_ith(llama_context * ctx, int32_t i) { ctx->synchronize(); + if (ctx->get_backend_sampled_token_ith(i) != LLAMA_TOKEN_NULL) { + return nullptr; + } + if (ctx->get_backend_sampled_probs_ith(i) != nullptr) { + return nullptr; + } + return ctx->get_logits_ith(i); } @@ 
-2499,6 +2957,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { + ctx->set_backend_sampler(seq_id, smpl); +} + +llama_token llama_get_backend_sampled_token_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_backend_sampled_token_ith(i); +} + +float * llama_get_backend_sampled_probs_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_backend_sampled_probs_ith(i); +} + +float * llama_get_backend_sampled_logits_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_backend_sampled_logits_ith(i); +} + +llama_token * llama_get_backend_sampled_candidates_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return const_cast(ctx->get_backend_sampled_candidates_ith(i)); +} + +uint32_t llama_get_backend_sampled_candidates_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_backend_sampled_candidates_count(i)); +} + +uint32_t llama_get_backend_sampled_logits_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_backend_sampled_logits_count(i)); +} + +uint32_t llama_get_backend_sampled_probs_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_backend_sampled_probs_count(i)); +} + // llama adapter API int32_t llama_set_adapter_lora( diff --git a/src/llama-context.h b/src/llama-context.h index 20cbd78955412..2bdbf8a55326b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -66,6 +66,18 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + llama_token * get_backend_sampled_tokens(); + llama_token get_backend_sampled_token_ith(int32_t idx); + + float * get_backend_sampled_logits_ith(int32_t idx); + size_t get_backend_sampled_logits_count(int32_t idx); + + float * get_backend_sampled_probs_ith(int32_t idx); + size_t get_backend_sampled_probs_count(int32_t idx); + + const llama_token * get_backend_sampled_candidates_ith(int32_t idx); + size_t get_backend_sampled_candidates_count(int32_t idx); + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -191,6 +203,7 @@ struct llama_context { uint32_t output_reserve(int32_t n_outputs); void output_reorder(); + int64_t resolve_output_row(int32_t i) const; // // graph @@ -208,6 +221,8 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false); + void set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler); + private: llm_graph_params graph_params( llm_graph_result * res, @@ -242,6 +257,31 @@ struct llama_context { size_t logits_size = 0; // capacity (of floats) for logits float * logits = nullptr; + struct sampling_info { + std::unordered_map samplers; + + float * logits = nullptr; + size_t logits_size = 0; + + llama_token * sampled = nullptr; + size_t sampled_size = 0; + + float * probs = nullptr; + size_t probs_size = 0; + + llama_token * candidates = nullptr; + size_t candidates_size = 0; + + size_t outputs_capacity = 0; + std::vector logits_count; + std::vector probs_count; + std::vector candidates_count; + + std::vector token_ids_full_vocab; + }; + + sampling_info sampling; 
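+    // The sampling buffers above are sized in output_reserve():
+    //   - logits/probs hold up to n_vocab floats per output row
+    //   - candidates holds up to n_vocab token ids per output row
+    //   - sampled holds one token per output row (initialized to LLAMA_TOKEN_NULL)
+    //   - the *_count vectors record how many entries were actually written per row
+    //
+    // Rough usage sketch from the application side (API names as introduced by
+    // this patch; see tests/test-backend-sampler.cpp for a complete example):
+    //
+    //   llama_context_params cparams = llama_context_default_params();
+    //   cparams.samplers   = configs.data(); // one llama_sampler_seq_config per sequence
+    //   cparams.n_samplers = configs.size();
+    //   llama_context * ctx = llama_init_from_model(model, cparams);
+    //   // ... llama_decode(ctx, batch) ...
+    //   llama_token tok = llama_get_backend_sampled_token_ith(ctx, output_idx);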
+ // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE size_t embd_size = 0; // capacity (of floats) for embeddings diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 650e40ec6ffce..8af9188d05da7 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -462,6 +462,28 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { inp_rs->set_input(ubatch); } +void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + for (const auto & [seq_id, sampler] : samplers) { + if (sampler->iface->set_input_ggml) { + sampler->iface->set_input_ggml(sampler); + } + } +} + +bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) { + if (params.samplers.empty()) { + return true; + } + + for (const auto & [seq_id, sampler] : params.samplers) { + if (samplers[seq_id] != sampler) { + return false; + } + } + return true; +} + // // llm_graph_result // @@ -482,6 +504,10 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + t_sampled.clear(); + t_sampled_probs.clear(); + t_sampled_logits.clear(); + t_candidates.clear(); params = {}; @@ -583,10 +609,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : rope_type (hparams.rope_type), sched (params.sched), backend_cpu (params.backend_cpu), + dev_out (params.dev_out), cvec (params.cvec), loras (params.loras), mctx (params.mctx), cross (params.cross), + samplers (params.samplers), cb_func (params.cb), res (params.res), ctx0 (res->get_ctx()), @@ -2021,6 +2049,100 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } +void llm_graph_context::build_sampling() const { + if (samplers.empty()) { + return; + } + + std::unordered_map seq_to_logit_row; + int32_t logit_row_idx = 0; + + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (ubatch.output[i]) { + llama_seq_id seq_id = ubatch.seq_id[i][0]; + seq_to_logit_row[seq_id] = logit_row_idx; + logit_row_idx++; + } + } + if (seq_to_logit_row.empty()) { + return; + } + + // res->t_logits will contain logits for all tokens that specied that want + // logits calculated (logits=1 or output=1) + ggml_tensor * logits_t = res->t_logits; + GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor"); + + const int64_t n_vocab = logits_t->ne[0]; + + ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev_out); + + std::unordered_map active_samplers; + + for (const auto & [seq_id, sampler] : samplers) { + // Only process samplers for sequences that are in the current batch + auto it = seq_to_logit_row.find(seq_id); + if (it == seq_to_logit_row.end()) { + continue; + } + const int32_t row_idx = it->second; + + // Allow GPU sampler to create input tensors by implementing init_ggml. 
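+        // init_ggml is optional: it receives the buffer type of the output
+        // device (dev_out) so that a backend sampler can allocate any input
+        // tensors it needs (e.g. logit-bias values) on the same backend that
+        // produces the logits.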
+ if (sampler->iface->init_ggml != nullptr) { + sampler->iface->init_ggml(sampler, buft); + } + + active_samplers[seq_id] = sampler; + + ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]); + ggml_format_name(logits_seq, "logits_seq_%d", seq_id); + + struct llama_sampler_ggml_data ggml_data = { + /*.logits =*/ logits_seq, + /*.probs =*/ nullptr, + /*.sampled =*/ nullptr, + /*.candidates =*/ nullptr, + }; + + llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data); + + if (ggml_data.sampled != nullptr) { + res->t_sampled[seq_id] = ggml_data.sampled; + ggml_build_forward_expand(gf, ggml_data.sampled); + } + + if (ggml_data.probs != nullptr) { + res->t_sampled_probs[seq_id] = ggml_data.probs; + ggml_build_forward_expand(gf, ggml_data.probs); + } + + if (ggml_data.logits != logits_seq) { + res->t_sampled_logits[seq_id] = ggml_data.logits; + ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]); + } + + if (ggml_data.candidates != nullptr) { + res->t_candidates[seq_id] = ggml_data.candidates; + ggml_build_forward_expand(gf, ggml_data.candidates); + } + } + + // TODO: Call llama_sampler_accept_ggml after all samplers have been applied. + /* + for (const auto & [seq_id, sampler] : samplers) { + if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) { + ggml_tensor * selected_token = it->second; + if (selected_token != nullptr) { + llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token); + } + } + } + */ + + auto inp_sampling = std::make_unique(n_vocab, false, active_samplers); + res->add_input(std::move(inp_sampling)); +} + int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c3934f67927..6797d78a20e38 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -383,6 +383,24 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_sampling : public llm_graph_input_i { +public: + llm_graph_input_sampling(int32_t n_vocab, bool sorted, + std::unordered_map samplers) : + n_vocab(n_vocab), sorted_value(sorted), samplers(samplers) { } + virtual ~llm_graph_input_sampling() = default; + + void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + int32_t n_vocab; + bool sorted_value; + ggml_tensor * size = nullptr; // I32 [1] + ggml_tensor * sorted = nullptr; // I32 [1] + + std::unordered_map samplers; +}; + // // llm_graph_result // @@ -410,12 +428,30 @@ struct llm_graph_params { ggml_backend_sched_t sched; ggml_backend_t backend_cpu; + ggml_backend_dev_t dev_out; const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + std::unordered_map samplers; + + static bool samplers_equal( + const std::unordered_map & lhs, + const std::unordered_map & rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (const auto & [seq_id, sampler] : lhs) { + auto it = rhs.find(seq_id); + if (it == rhs.end() || it->second != sampler) { + return false; + } + } + return true; + } + uint32_t n_outputs; llm_graph_cb cb; @@ -463,7 +499,9 @@ struct llm_graph_params { cvec == other.cvec && loras == other.loras && cross == other.cross && - n_outputs == other.n_outputs; + n_outputs == other.n_outputs && + samplers_equal(samplers, 
other.samplers); + } }; @@ -504,6 +542,11 @@ class llm_graph_result { ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + std::unordered_map t_sampled_logits; + std::unordered_map t_candidates; + std::unordered_map t_sampled; + std::unordered_map t_sampled_probs; + std::vector inputs; ggml_context_ptr ctx_compute; @@ -574,11 +617,15 @@ struct llm_graph_context { ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? + ggml_backend_dev_t dev_out; + const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + std::unordered_map samplers; + const llm_graph_cb & cb_func; llm_graph_result * res; @@ -819,6 +866,12 @@ struct llm_graph_context { ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + // + // sampling (backend sampling) + // + + void build_sampling() const; + // // dense (out) // diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 175549a9e30f1..29fc6bbc63025 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7414,6 +7414,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // add backend sampling layers (if any) + llm->build_sampling(); + // if the gguf model was converted with --sentence-transformers-dense-modules // there will be two additional dense projection layers // dense linear projections are applied after pooling diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 3f4a729bc36c7..621438a9cf4d4 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -372,6 +372,39 @@ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_ar smpl->iface->apply(smpl, cur_p); } +void llama_sampler_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + GGML_ASSERT(smpl->iface->apply_ggml); + smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data); +} + +void llama_sampler_accept_ggml( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + if (smpl->iface->accept_ggml) { + smpl->iface->accept_ggml(smpl, ctx, gf, selected_token); + } +} + +void llama_sampler_set_input_ggml(struct llama_sampler * smpl) { + if (smpl->iface->set_input_ggml) { + smpl->iface->set_input_ggml(smpl); + } +} + +void llama_sampler_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + if (smpl->iface->init_ggml) { + smpl->iface->init_ggml(smpl, buft); + } +} + void llama_sampler_reset(struct llama_sampler * smpl) { if (smpl->iface->reset) { smpl->iface->reset(smpl); @@ -406,7 +439,15 @@ void llama_sampler_free(struct llama_sampler * smpl) { } llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const llama_token sampled_token = llama_get_backend_sampled_token_ith(ctx, idx); + const float * sampled_probs = llama_get_backend_sampled_probs_ith(ctx, idx); + const float * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx); + const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx); + + // If a backend sampler has already sampled a token, return it. 
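+    // Otherwise fall back in order of preference:
+    //   1. backend-sampled probs  -> build the candidate array from the
+    //      (possibly truncated) candidate list produced on the backend
+    //   2. backend-sampled logits -> same, with probs left at 0.0f
+    //   3. regular logits         -> full-vocabulary candidate array, as before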
+ if (sampled_token != LLAMA_TOKEN_NULL) { + return sampled_token; + } const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -415,9 +456,26 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte // TODO: do not allocate each time std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (llama_token i = 0; i < (int)sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } llama_token_data_array cur_p = { @@ -462,6 +520,10 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d time_meas tm(chain->t_sample_us, chain->params.no_perf); for (auto * smpl : chain->samplers) { + // Skip GPU samplers - they have apply_ggml but no apply + if (smpl->iface->apply == nullptr) { + continue; + } llama_sampler_apply(smpl, cur_p); } } @@ -496,13 +558,67 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) { delete chain; } +static void llama_sampler_chain_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + if (smpl->iface->apply_ggml) { + smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data); + } + } +} + +static void llama_sampler_chain_accept_ggml( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + if (smpl->iface->accept_ggml) { + smpl->iface->accept_ggml(smpl, ctx, gf, selected_token); + } + } +} + +static void llama_sampler_chain_set_input_ggml(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + if (smpl->iface->set_input_ggml) { + smpl->iface->set_input_ggml(smpl); + } + } +} + +static void llama_sampler_chain_set_backend_context( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + if (smpl->iface->init_ggml) { + smpl->iface->init_ggml(smpl,buft); + } + } +} + static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ llama_sampler_chain_name, - /* .accept = */ llama_sampler_chain_accept, - /* .apply = */ llama_sampler_chain_apply, - /* .reset = */ llama_sampler_chain_reset, - /* .clone = */ llama_sampler_chain_clone, - /* .free = */ llama_sampler_chain_free, + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = 
*/ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, + /* .apply_ggml = */ llama_sampler_chain_apply_ggml, + /* .accept_ggml = */ llama_sampler_chain_accept_ggml, + /* .set_input_ggml = */ llama_sampler_chain_set_input_ggml, + /* .set_backend_context = */ llama_sampler_chain_set_backend_context, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -571,12 +687,16 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to } static struct llama_sampler_i llama_sampler_greedy_i = { - /* .name = */ llama_sampler_greedy_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_greedy_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_greedy() { @@ -696,12 +816,16 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_dist_i = { - /* .name = */ llama_sampler_dist_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_dist_apply, - /* .reset = */ llama_sampler_dist_reset, - /* .clone = */ llama_sampler_dist_clone, - /* .free = */ llama_sampler_dist_free, + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { @@ -741,12 +865,16 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_k_i = { - /* .name = */ llama_sampler_top_k_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_k_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_k_clone, - /* .free = */ llama_sampler_top_k_free, + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { @@ -836,12 +964,16 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_p_i = { - /* .name = */ llama_sampler_top_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_p_clone, - /* .free = */ llama_sampler_top_p_free, + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ 
nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { @@ -930,12 +1062,16 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_min_p_i = { - /* .name = */ llama_sampler_min_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_min_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_min_p_clone, - /* .free = */ llama_sampler_min_p_free, + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { @@ -1029,12 +1165,16 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_typical_i = { - /* .name = */ llama_sampler_typical_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_typical_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_typical_clone, - /* .free = */ llama_sampler_typical_free, + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { @@ -1073,12 +1213,16 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_temp_i = { - /* .name = */ llama_sampler_temp_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_clone, - /* .free = */ llama_sampler_temp_free, + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp(float temp) { @@ -1183,12 +1327,16 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_temp_ext_i = { - /* .name = */ llama_sampler_temp_ext_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_ext_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_ext_clone, - /* .free = */ llama_sampler_temp_ext_free, + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { @@ -1277,12 +1425,16 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { } static struct 
llama_sampler_i llama_sampler_xtc_i = { - /* .name = */ llama_sampler_xtc_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sample_xtc_apply, - /* .reset = */ llama_sampler_xtc_reset, - /* .clone = */ llama_sampler_xtc_clone, - /* .free = */ llama_sampler_xtc_free, + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { @@ -1385,12 +1537,16 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_i = { - /* .name = */ llama_sampler_mirostat_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_apply, - /* .reset = */ llama_sampler_mirostat_reset, - /* .clone = */ llama_sampler_mirostat_clone, - /* .free = */ llama_sampler_mirostat_free, + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { @@ -1484,12 +1640,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_v2_i = { - /* .name = */ llama_sampler_mirostat_v2_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_v2_apply, - /* .reset = */ llama_sampler_mirostat_v2_reset, - /* .clone = */ llama_sampler_mirostat_v2_clone, - /* .free = */ llama_sampler_mirostat_v2_free, + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -1601,12 +1761,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_grammar_i = { - /* .name = */ llama_sampler_grammar_name, - /* .accept = */ llama_sampler_grammar_accept_impl, - /* .apply = */ llama_sampler_grammar_apply, - /* .reset = */ llama_sampler_grammar_reset, - /* .clone = */ llama_sampler_grammar_clone, - /* .free = */ llama_sampler_grammar_free, + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; static struct llama_sampler * llama_sampler_init_grammar_impl( @@ -1808,12 +1972,16 
@@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_penalties_i = { - /* .name = */ llama_sampler_penalties_name, - /* .accept = */ llama_sampler_penalties_accept, - /* .apply = */ llama_sampler_penalties_apply, - /* .reset = */ llama_sampler_penalties_reset, - /* .clone = */ llama_sampler_penalties_clone, - /* .free = */ llama_sampler_penalties_free, + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_penalties( @@ -1899,12 +2067,16 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_n_sigma_i = { - /* .name = */ llama_sampler_top_n_sigma_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_n_sigma_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_n_sigma_clone, - /* .free = */ llama_sampler_top_n_sigma_free, + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { @@ -2229,12 +2401,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_dry_i = { - /* .name = */ llama_sampler_dry_name, - /* .accept = */ llama_sampler_dry_accept, - /* .apply = */ llama_sampler_dry_apply, - /* .reset = */ llama_sampler_dry_reset, - /* .clone = */ llama_sampler_dry_clone, - /* .free = */ llama_sampler_dry_free, + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { @@ -2370,12 +2546,16 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_logit_bias_i = { - /* .name = */ llama_sampler_logit_bias_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_logit_bias_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_logit_bias_clone, - /* .free = */ llama_sampler_logit_bias_free, + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* 
.set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_logit_bias( @@ -2600,12 +2780,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_infill_i = { - /* .name = */ llama_sampler_infill_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_infill_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_infill_clone, - /* .free = */ llama_sampler_infill_free, + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f4ce..0db8b4bd88845 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -206,6 +206,18 @@ llama_build_and_test(test-backend-ops.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") +llama_build_and_test(test-backend-sampler.cpp LABEL "model") +target_include_directories(test-backend-sampler PRIVATE ${PROJECT_SOURCE_DIR}/src) +llama_test(test-backend-sampler NAME test-backend-sampler-greedy ARGS --test greedy) +llama_test(test-backend-sampler NAME test-backend-sampler-temp ARGS --test temp) +llama_test(test-backend-sampler NAME test-backend-sampler-top_k ARGS --test top_k) +llama_test(test-backend-sampler NAME test-backend-sampler-dist ARGS --test dist) +llama_test(test-backend-sampler NAME test-backend-sampler-dist-and-cpu ARGS --test dist_and_cpu) +llama_test(test-backend-sampler NAME test-backend-sampler-logit-bias ARGS --test logit_bias) +llama_test(test-backend-sampler NAME test-backend-sampler-mul_seq ARGS --test multi_sequence) +llama_test(test-backend-sampler NAME test-backend-sampler-set-sampler ARGS --test set_sampler) + + if (NOT GGML_BACKEND_DL) # these tests use the backends directly and cannot be built with dynamic loading llama_build_and_test(test-barrier.cpp) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2bb4b12224798..8f96250ddc570 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7906,6 +7906,8 @@ static std::vector> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1})); return test_cases; } diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp new file mode 100644 index 0000000000000..2ed13688c98ac --- /dev/null +++ b/tests/test-backend-sampler.cpp @@ -0,0 +1,811 @@ +#include "ggml.h" +#include "llama.h" +#include "get-model.h" +#include "common.h" + +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include +#include +#include +#include +#include +#include +#include + +struct test_model_context { + llama_model * model = nullptr; + llama_context * ctx = nullptr; + const llama_vocab * vocab = nullptr; + int n_vocab = 0; + std::unordered_map seq_positions; + std::unordered_map last_batch_info; + + bool setup_model(const char * model_path) { + if (model != nullptr) { + return true; + } + + llama_backend_init(); + + llama_model_params 
mparams = llama_model_default_params(); + model = llama_model_load_from_file(model_path, mparams); + if (model == nullptr) { + fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path); + cleanup(); + return false; + } + vocab = llama_model_get_vocab(model); + + return true; + } + + bool setup(const char * model_path, std::vector & configs) { + if (model == nullptr) { + setup_model(model_path); + } + + if (model != nullptr && ctx != nullptr) { + return true; + } + + llama_context_params cparams = llama_context_default_params(); + cparams.n_ctx = 512; + cparams.n_batch = 512; + cparams.samplers = configs.data(); + cparams.n_samplers = configs.size(); + + int32_t max_seq_id = 0; + for (const auto & config : configs) { + if (config.seq_id > max_seq_id) { + max_seq_id = config.seq_id; + } + } + cparams.n_seq_max = max_seq_id + 1; + + ctx = llama_init_from_model(model, cparams); + if (ctx == nullptr) { + fprintf(stderr, "Warning: failed to create context, skipping test\n"); + cleanup(); + return false; + } + llama_set_warmup(ctx, false); + + vocab = llama_model_get_vocab(model); + n_vocab = llama_vocab_n_tokens(vocab); + fprintf(stderr, "Vocabulary size: %d\n", n_vocab); + + return true; + } + + bool decode(const std::map & prompts) { + if (ctx == nullptr || vocab == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + last_batch_info.clear(); + llama_batch batch = llama_batch_init(512, 0, prompts.size()); + + int n_tokens_per_prompt = 0; + + for (const auto & [seq_id, prompt] : prompts) { + std::vector tokens; + tokens.push_back(llama_vocab_bos(vocab)); + + std::vector prompt_tokens(32); + int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(), + prompt_tokens.data(), prompt_tokens.size(), + false, false); + //TODO: refactor this function to just handle a single prompt at a time + // to avoid this check and complexity. + if (n_tokens_per_prompt == 0) { + n_tokens_per_prompt = n_tokens; + } else { + if (n_tokens != n_tokens_per_prompt) { + fprintf(stderr, "Error: prompts must have the same number of tokens\n"); + llama_batch_free(batch); + return false; + } + n_tokens_per_prompt = n_tokens; + } + if (n_tokens < 0) { + fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id); + llama_batch_free(batch); + return false; + } + + for (int i = 0; i < n_tokens; i++) { + tokens.push_back(prompt_tokens[i]); + } + + for (size_t i = 0; i < tokens.size(); i++) { + common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); + } + + seq_positions[seq_id] = tokens.size(); + } + + + printf("Batch contents:\n"); + printf(" n_tokens: %d\n", batch.n_tokens); + for (int i = 0; i < batch.n_tokens; i++) { + printf(" token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]); + + for (int j = 0; j < batch.n_seq_id[i]; j++) { + printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? 
", " : ""); + } + printf("], logits=%d\n", batch.logits[i]); +} + + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed\n"); + llama_batch_free(batch); + return false; + } + + // Build mapping from seq id to batch token idx + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id seq_id = batch.seq_id[i][0]; + last_batch_info[seq_id] = i; + printf("seq %d : batch idx %d\n", seq_id, i); + } + } + + llama_batch_free(batch); + return true; + } + + int32_t idx_for_seq(llama_seq_id seq_id) { + auto it = last_batch_info.find(seq_id); + if (it == last_batch_info.end()) { + fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id); + return -1; + } + return it->second; + } + + bool decode_token(llama_token token, llama_seq_id seq_id = 0) { + if (ctx == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + llama_batch batch = llama_batch_init(1, 0, 1); + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id); + llama_batch_free(batch); + return false; + } + + last_batch_info.clear(); + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id cur_seq = batch.seq_id[i][0]; + last_batch_info[cur_seq] = i; + } + } + + seq_positions[seq_id]++; + llama_batch_free(batch); + return true; + } + + bool decode_tokens(const std::map & seq_tokens) { + if (ctx == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size()); + + for (const auto & [seq_id, token] : seq_tokens) { + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + } + + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for batch tokens\n"); + llama_batch_free(batch); + return false; + } + + for (const auto & [seq_id, _] : seq_tokens) { + seq_positions[seq_id]++; + } + + last_batch_info.clear(); + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id cur_seq = batch.seq_id[i][0]; + last_batch_info[cur_seq] = i; + } + } + + llama_batch_free(batch); + return true; + } + + std::string token_to_piece(llama_token token, bool special) { + std::string piece; + piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); + } + + return piece; + } + + void cleanup() { + if (ctx) llama_free(ctx); + if (model) llama_model_free(model); + llama_backend_free(); + ctx = nullptr; + model = nullptr; + vocab = nullptr; + } + + ~test_model_context() { + cleanup(); + } +}; + +static void test_backend_greedy_sampling(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + + struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_sampler_params); + + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_greedy()); + std::vector 
backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Some"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, loop_idx); + printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + test_ctx.decode_token(token, 0); + } +} + +static void test_backend_top_k_sampling(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t k = 8; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_top_k(k)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx); + uint32_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx); + for (size_t i = 0; i < n_logits; ++i) { + printf("top_k logit[%zu] = %.6f\n", i, logits[i]); + } + + llama_token * candidates = llama_get_backend_sampled_candidates_ith(test_ctx.ctx, batch_idx); + uint32_t n_candidates = llama_get_backend_sampled_candidates_count_ith(test_ctx.ctx, batch_idx); + for (size_t i = 0; i < n_candidates; ++i) { + printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i], + test_ctx.token_to_piece(candidates[i], false).c_str()); + } + + // Sample using CPU sampler for verification that it is possible to do hybrid + // sampling, first top_k on the backend and then dist on the CPU. 
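+    // llama_sampler_sample() notices that the backend kept only the top-k
+    // candidates (exposed via llama_get_backend_sampled_logits_ith and
+    // llama_get_backend_sampled_candidates_ith) and runs the CPU dist sampler
+    // over that reduced candidate set instead of the full vocabulary.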
+ struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + GGML_ASSERT(chain->iface->apply_ggml != nullptr); + + llama_sampler_chain_add(chain, llama_sampler_init_dist(18)); + llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + printf("backend top-k hybrid sampling test PASSED\n"); + + llama_sampler_free(chain); +} + +static void test_backend_temp_sampling(const char * model_path) { + test_model_context test_ctx; + + const float temp_0 = 0.8f; + struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0); + llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_backend_init_temp(temp_0)); + + const float temp_1 = 0.1f; + struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1); + llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_backend_init_temp(temp_1)); + + std::vector backend_sampler_configs = { + { 0, backend_sampler_chain_0 }, + { 1, backend_sampler_chain_1 } + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) { + return; + } + + int32_t batch_idx_0 = test_ctx.idx_for_seq(0); + int32_t batch_idx_1 = test_ctx.idx_for_seq(1); + + int n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx_0); + GGML_ASSERT(n_logits == test_ctx.n_vocab); + + // Sample from sequence 0 using CPU sampler + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + struct llama_sampler * chain_0 = llama_sampler_chain_init(chain_params_0); + llama_sampler_chain_add(chain_0, llama_sampler_init_dist(18)); + + llama_token token_0 = llama_sampler_sample(chain_0, test_ctx.ctx, batch_idx_0); + const std::string token_0_str = test_ctx.token_to_piece(token_0, false); + printf("Sequence 0 sampled token id:%d, string: '%s'\n", token_0, token_0_str.c_str()); + GGML_ASSERT(token_0 >= 0 && token_0 < test_ctx.n_vocab); + + // Sample from sequence 1 using CPU sampler + struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params(); + struct llama_sampler * chain_1 = llama_sampler_chain_init(chain_params_1); + llama_sampler_chain_add(chain_1, llama_sampler_init_dist(18)); + + llama_token token_1 = llama_sampler_sample(chain_1, test_ctx.ctx, batch_idx_1); + const std::string token_1_str = test_ctx.token_to_piece(token_1, false); + printf("Sequence 1 sampled token id:%d, string: '%s'\n", token_1, token_1_str.c_str()); + GGML_ASSERT(token_1 >= 0 && token_1 < test_ctx.n_vocab); + + printf("backend temp sampling test PASSED\n"); + + llama_sampler_free(chain_0); + llama_sampler_free(chain_1); +} + +static void test_backend_multi_sequence_sampling(const char * model_path) { + test_model_context test_ctx; + + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0); + llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_greedy()); + + struct llama_sampler_chain_params 
chain_params_1 = llama_sampler_chain_default_params(); + struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1); + llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_temp(0.8f)); + llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_greedy()); + + std::vector backend_sampler_configs = { + { 0, sampler_chain_0 }, + { 1, sampler_chain_1 } + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + std::map prompts = { + {0, "Hello"}, + {1, "Some"} + }; + + if (!test_ctx.decode(prompts)) { + return; + } + + int32_t batch_idx_0 = test_ctx.idx_for_seq(0); + llama_token seq0_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_0); + const std::string seq0_token_str = test_ctx.token_to_piece(seq0_token, false); + printf("Seq 0 sampled token id=%d, string='%s'\n", seq0_token, seq0_token_str.c_str()); + GGML_ASSERT(seq0_token >= 0 && seq0_token < test_ctx.n_vocab); + + int32_t batch_idx_1 = test_ctx.idx_for_seq(1); + llama_token seq1_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_1); + const std::string seq1_token_str = test_ctx.token_to_piece(seq1_token, false); + printf("Seq 1 sampled token id=%d, string='%s'\n", seq1_token, seq1_token_str.c_str()); + GGML_ASSERT(seq1_token >= 0 && seq1_token < test_ctx.n_vocab); + + // Generate tokens for each sequence + printf("\nMulti-sequence generation:\n"); + for (int step = 0; step < 4; step++) { + std::map tokens; + + for (llama_seq_id seq_id : {0, 1}) { + int32_t idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str()); + tokens[seq_id] = token; + } + + // Decode all tokens in a single batch + if (!test_ctx.decode_tokens(tokens)) { + break; + } + } + + printf("backend multi-sequence sampling test PASSED\n"); +} + +static void test_backend_dist_sampling(const char * model_path) { + test_model_context test_ctx; + + const int32_t seed = 88; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed)); + std::vector backend_sampler_configs = {{ 0, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{0, "Hello"}})) { + return; + } + + llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(0)); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); +} + +static void test_backend_dist_sampling_and_cpu(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t seed = 88; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, 
llama_sampler_backend_init_dist(seed)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + // Sample using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + llama_sampler_chain_add(chain, llama_sampler_init_dist(18)); + + llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx); + llama_token cpu_token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx); + GGML_ASSERT(backend_token == cpu_token); +} + +static void test_backend_logit_bias_sampling(const char * model_path) { + test_model_context test_ctx; + + // Calling setup_model to ensure vocab is loaded and can be accessed + if (!test_ctx.setup_model(model_path)) { + return; + } + + const int seq_id = 0; + + // Create the logit biases vector. + std::vector<llama_logit_bias> logit_bias; + + // Get the token for the piece "World". + const std::string piece = "World"; + std::vector<llama_token> tokens(16); + llama_tokenize(test_ctx.vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false); + llama_token bias_token = tokens[0]; + logit_bias.push_back({ bias_token, +100.0f }); + printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token); + + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_logit_bias( + llama_vocab_n_tokens(test_ctx.vocab), + logit_bias.size(), + logit_bias.data())); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(88)); + + std::vector<llama_sampler_seq_config> backend_sampler_configs = { + { seq_id, backend_sampler_chain }, + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id)); + const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false); + printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str()); + GGML_ASSERT(backend_token == bias_token); +} + +static void test_backend_set_sampler(const char * model_path) { + test_model_context test_ctx; + + const int32_t seed = 88; + const int seq_id = 0; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + // Sample using backend sampler configured above + llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx); + const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false); + printf("dist sampled token = %d, string='%s'\n", backend_token, 
backend_token_str.c_str()); + + // Now clear the backend sampler for this sequence. + llama_set_backend_sampler(test_ctx.ctx, seq_id, nullptr); + printf("Cleared backend sampler for seq_id %d\n", seq_id); + + // Sample using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + llama_sampler_chain_add(chain, llama_sampler_init_dist(18)); + + std::map<llama_seq_id, llama_token> tokens = { { seq_id, backend_token}, }; + if (!test_ctx.decode_tokens(tokens)) { + return; + } + + // Should not have any sampled token or probs after clearing the backend sampler. + const int32_t idx = test_ctx.idx_for_seq(seq_id); + GGML_ASSERT(llama_get_backend_sampled_token_ith(test_ctx.ctx, idx) == LLAMA_TOKEN_NULL); + GGML_ASSERT(llama_get_backend_sampled_probs_ith(test_ctx.ctx, idx) == nullptr); + + // Sample the token using the CPU sampler chain. + llama_token token2 = llama_sampler_sample(chain, test_ctx.ctx, seq_id); + const std::string token2_str = test_ctx.token_to_piece(token2, false); + printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str()); + std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2}, }; + + // Set a new backend sampler for the sequence. + struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params); + llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_top_k(20)); + llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_dist(seed)); + llama_set_backend_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain); + + if (!test_ctx.decode_tokens(tokens2)) { + return; + } + + llama_token new_backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id)); + const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false); + printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str()); +} + +static void test_backend_max_outputs(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t seed = 88; + llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + llama_batch batch = llama_batch_init(512, 0, 1); + std::string prompt = "Hello"; + + std::vector<llama_token> tokens; + tokens.push_back(llama_vocab_bos(test_ctx.vocab)); + + std::vector<llama_token> prompt_tokens(32); + int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(), + prompt_tokens.data(), prompt_tokens.size(), + false, false); + for (int i = 0; i < n_tokens; i++) { + tokens.push_back(prompt_tokens[i]); + } + + for (size_t i = 0; i < tokens.size(); i++) { + // set all tokens as output to trigger error + common_batch_add(batch, tokens[i], i, { seq_id }, true); + } + + printf(">>> test_max_outputs expected error start:\n"); + const int ret = llama_decode(test_ctx.ctx, batch); + GGML_ASSERT(ret != 0 && "llama_decode should not succeed with multiple outputs per sequence"); + printf("<<< test_max_outputs expected error end.\n"); + 
llama_batch_free(batch); +} + +struct backend_test_case { + const char * name; + void (*fn)(const char *); + bool enabled_by_default; +}; + +static const backend_test_case BACKEND_TESTS[] = { + { "greedy", test_backend_greedy_sampling, true }, + { "logit_bias", test_backend_logit_bias_sampling, true }, + { "temp", test_backend_temp_sampling, true }, + { "top_k", test_backend_top_k_sampling, true }, + { "multi_sequence", test_backend_multi_sequence_sampling, true }, + { "dist", test_backend_dist_sampling, true }, + { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true }, + { "set_sampler", test_backend_set_sampler, true }, + { "max_outputs", test_backend_max_outputs, true }, +}; + +struct backend_cli_args { + const char * model = nullptr; + const char * test = nullptr; +}; + +static backend_cli_args parse_backend_cli(int argc, char ** argv) { + backend_cli_args out; + + for (int i = 1; i < argc; ++i) { + const char * arg = argv[i]; + + if (std::strcmp(arg, "--test") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--test expects a value\n"); + exit(EXIT_FAILURE); + } + out.test = argv[++i]; + continue; + } + if (std::strncmp(arg, "--test=", 7) == 0) { + out.test = arg + 7; + continue; + } + if (std::strcmp(arg, "--model") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--model expects a value\n"); + exit(EXIT_FAILURE); + } + out.model = argv[++i]; + continue; + } + if (std::strncmp(arg, "--model=", 8) == 0) { + out.model = arg + 8; + continue; + } + if (!out.model) { + out.model = arg; + continue; + } + + fprintf(stderr, "Unexpected argument: %s\n", arg); + exit(EXIT_FAILURE); + } + + return out; +} + +static std::vector<const backend_test_case *> collect_tests_to_run(const char * requested) { + std::vector<const backend_test_case *> selected; + + if (requested != nullptr) { + for (const auto & test : BACKEND_TESTS) { + if (std::strcmp(test.name, requested) == 0) { + selected.push_back(&test); + break; + } + } + if (selected.empty()) { + fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested); + for (const auto & test : BACKEND_TESTS) { + fprintf(stderr, " %s\n", test.name); + } + exit(EXIT_FAILURE); + } + } else { + for (const auto & test : BACKEND_TESTS) { + if (test.enabled_by_default) { + selected.push_back(&test); + } + } + } + + if (selected.empty()) { + fprintf(stderr, "No backend sampling tests selected. Use --test=<name> to pick one.\n"); + } + + return selected; +} + +static void run_tests(const std::vector<const backend_test_case *> & tests, const char * model_path) { + for (const auto * test : tests) { + fprintf(stderr, "\n=== %s ===\n", test->name); + test->fn(model_path); + } +} + + +int main(int argc, char ** argv) { + const backend_cli_args args = parse_backend_cli(argc, argv); + + std::array<char *, 2> model_argv { argv[0], const_cast<char *>(args.model) }; + const int model_argc = args.model ? 
2 : 1; + char * model_path = get_model_or_exit(model_argc, model_argv.data()); + + auto * file = fopen(model_path, "r"); + if (file == nullptr) { + fprintf(stderr, "no model at '%s' found\n", model_path); + return EXIT_FAILURE; + } + + fprintf(stderr, "using '%s'\n", model_path); + fclose(file); + + ggml_time_init(); + + const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test); + if (!tests.empty()) { + run_tests(tests, model_path); + } + + return 0; +} diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 78b42267b59c3..64855b646fe76 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -137,16 +137,30 @@ int main(int argc, char ** argv) { // load the model and apply lora adapter, if any LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__); - common_init_result llama_init = common_init_from_params(params); - - model = llama_init.model.get(); - ctx = llama_init.context.get(); + model = common_load_model_from_params(params); if (model == NULL) { LOG_ERR("%s: error: unable to load model\n", __func__); return 1; } + // Configure a backend sampler, if one was requested + llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams); + if (backend_sampler) { + llama_sampler_seq_config sampler_config = { 0, backend_sampler }; + params.backend_samplers = &sampler_config; + params.n_backend_samplers = 1; + } + + common_init_result llama_init = common_init_context_from_model(model, params); + ctx = llama_init.context.get(); + model = llama_init.model.get(); // Update pointer (now managed by llama_init) + + if (ctx == NULL) { + LOG_ERR("%s: error: unable to create context\n", __func__); + return 1; + } + llama_memory_t mem = llama_get_memory(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 097c9440be2d9..47bf04750aa50 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 3750c8fdb6065..18e30a9589540 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -200,6 +200,8 @@ struct slot_params { {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, + {"backend_sampling", sampling.backend_sampling}, + {"backend_dist", sampling.backend_dist}, {"lora", lora}, }; } @@ -258,6 +260,8 @@ struct slot_params { {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, + {"backend_sampling", sampling.backend_sampling}, + {"backend_dist", sampling.backend_dist}, {"lora", lora}, }; } @@ -360,6 +364,11 @@ struct server_task { params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + const bool request_backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling); + const bool request_backend_dist = json_value(data, "backend_dist", defaults.sampling.backend_dist); + params.sampling.backend_sampling = defaults.sampling.backend_sampling && request_backend_sampling; + params.sampling.backend_dist = params.sampling.backend_sampling && request_backend_dist; + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); 
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); @@ -1705,6 +1714,7 @@ struct server_slot { json json_schema; struct common_sampler * smpl = nullptr; + llama_sampler * backend_sampler = nullptr; llama_token sampled; @@ -1750,6 +1760,13 @@ struct server_slot { n_draft_total = 0; n_draft_accepted = 0; + if (backend_sampler != nullptr) { + if (ctx != nullptr) { + llama_set_backend_sampler(ctx, id, nullptr); + } + backend_sampler = nullptr; + } + task.reset(); task_prev.reset(); @@ -2371,6 +2388,13 @@ struct server_context { common_sampler_free(slot.smpl); slot.smpl = nullptr; + if (slot.backend_sampler != nullptr) { + if (ctx != nullptr) { + llama_set_backend_sampler(ctx, slot.id, nullptr); + } + slot.backend_sampler = nullptr; + } + llama_free(slot.ctx_dft); slot.ctx_dft = nullptr; @@ -2850,6 +2874,11 @@ struct server_context { SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); } + if (!configure_slot_backend_sampler(slot, task.params.sampling)) { + send_error(task, "Failed to configure backend samplers", ERROR_TYPE_SERVER); + return false; + } + // initialize draft batch // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] if (slot.ctx_dft) { @@ -2867,6 +2896,31 @@ struct server_context { return true; } + bool configure_slot_backend_sampler(server_slot & slot, const common_params_sampling & sampling) { + if (!sampling.backend_sampling) { + if (slot.backend_sampler != nullptr) { + llama_set_backend_sampler(ctx, slot.id, nullptr); + slot.backend_sampler = nullptr; + } + return true; + } + + llama_sampler * backend_chain = common_sampler_backend_init(model, sampling); + if (backend_chain == nullptr) { + SLT_ERR(slot, "%s", "failed to initialize backend sampler\n"); + return false; + } + + if (slot.backend_sampler != nullptr) { + llama_set_backend_sampler(ctx, slot.id, nullptr); + slot.backend_sampler = nullptr; + } + + slot.backend_sampler = backend_chain; + llama_set_backend_sampler(ctx, slot.id, backend_chain); + return true; + } + bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling const std::string token_str = result.text_to_send; diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte index d00ae128538b4..eaed6cd75d6cc 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte @@ -159,6 +159,16 @@ key: 'samplers', label: 'Samplers', type: 'input' + }, + { + key: 'backend_sampling', + label: 'Backend sampling', + type: 'checkbox' + }, + { + key: 'backend_dist', + label: 'Backend dist sampling', + type: 'checkbox' } ] }, @@ -283,6 +293,10 @@ function handleConfigChange(key: string, value: string | boolean) { localConfig[key] = value; + + if (key === 'backend_sampling' && value === false) { + localConfig.backend_dist = false; + } } function handleReset() { diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte index 8834e3e3e1cc1..1bafaf137a4b0 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte @@ 
-210,7 +210,9 @@

{/if} {:else if field.type === 'checkbox'} - {@const isDisabled = field.key === 'pdfAsImage' && !supportsVision()} + {@const pdfDisabled = field.key === 'pdfAsImage' && !supportsVision()} + {@const backendDistDisabled = field.key === 'backend_dist' && !localConfig.backend_sampling} + {@const isDisabled = pdfDisabled || backendDistDisabled}
{field.help || SETTING_CONFIG_INFO[field.key]}

- {:else if field.key === 'pdfAsImage' && !supportsVision()} + {:else if pdfDisabled}

PDF-to-image processing requires a vision-capable model. PDFs will be processed as text.

+ {:else if backendDistDisabled} +

+ Enable backend sampling to allow backend dist sampling. +

{/if}
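
The webui checkbox gating above mirrors the server-side rule that `backend_dist` is only honored when `backend_sampling` is enabled. The following sketch (not part of the patch) condenses the `server_task` request parsing and `configure_slot_backend_sampler` changes from `tools/server/server.cpp` earlier in this diff into a single function; the `apply_backend_sampling_request`, `request_*` and `slot_backend_sampler` names are illustrative, and the usual server includes (`common.h`, `sampling.h`, `llama.h`) are assumed:

```cpp
// Illustrative only: condensed request -> slot flow for backend sampling.
static bool apply_backend_sampling_request(
        llama_context          * ctx,
        llama_model            * model,
        int                      slot_id,
        llama_sampler         *& slot_backend_sampler, // per-slot bookkeeping, as in server_slot
        common_params_sampling   sampling,             // copy of the server defaults, resolved per request
        bool                     request_backend_sampling,
        bool                     request_backend_dist) {
    // a request can only opt in to what the server default allows,
    // and backend dist sampling additionally requires backend sampling
    sampling.backend_sampling = sampling.backend_sampling && request_backend_sampling;
    sampling.backend_dist     = sampling.backend_sampling && request_backend_dist;

    // detach whatever chain this slot had from a previous request
    if (slot_backend_sampler != nullptr) {
        llama_set_backend_sampler(ctx, slot_id, nullptr);
        slot_backend_sampler = nullptr;
    }

    if (!sampling.backend_sampling) {
        return true; // CPU-side sampling only
    }

    // build a backend chain from the resolved sampling params and install it for this sequence
    llama_sampler * chain = common_sampler_backend_init(model, sampling);
    if (chain == nullptr) {
        return false;
    }

    slot_backend_sampler = chain;
    llama_set_backend_sampler(ctx, slot_id, chain);
    return true;
}
```

Clearing the previous chain before installing a new one keeps a reused slot in sync with each request's settings, which is what the slot reset and configure paths in the server changes above do.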
diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index c25ea23f37be3..672b8e9847b45 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -18,6 +18,8 @@ export const SETTING_CONFIG_DEFAULT: Record = modelSelectorEnabled: false, // make sure these default values are in sync with `common.h` samplers: 'top_k;typ_p;top_p;min_p;temperature', + backend_sampling: false, + backend_dist: false, temperature: 0.8, dynatemp_range: 0.0, dynatemp_exponent: 1.0, @@ -51,6 +53,10 @@ export const SETTING_CONFIG_INFO: Record = { 'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.', samplers: 'The order at which samplers are applied, in simplified way. Default is "top_k;typ_p;top_p;min_p;temperature": top_k->typ_p->top_p->min_p->temperature', + backend_sampling: + 'Enable backend-based samplers. When enabled, supported samplers run on the accelerator backend for faster sampling.', + backend_dist: + 'Perform the final distribution sampling step on the backend. Requires backend sampling to be enabled.', temperature: 'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.', dynatemp_range: diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index aa83910b27f53..5cda8c886865f 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ b/tools/server/webui/src/lib/services/chat.ts @@ -98,6 +98,8 @@ export class ChatService { dry_penalty_last_n, // Other parameters samplers, + backend_sampling, + backend_dist, custom, timings_per_token } = options; @@ -182,6 +184,9 @@ export class ChatService { : samplers; } + if (backend_sampling !== undefined) requestBody.backend_sampling = backend_sampling; + if (backend_dist !== undefined) requestBody.backend_dist = backend_dist; + if (timings_per_token !== undefined) requestBody.timings_per_token = timings_per_token; if (custom) { diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index c70b9580cb75b..e00994bc469be 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -298,6 +298,12 @@ class ChatStore { if (currentConfig.samplers) { apiOptions.samplers = currentConfig.samplers; } + if (currentConfig.backend_sampling !== undefined) { + apiOptions.backend_sampling = Boolean(currentConfig.backend_sampling); + } + if (currentConfig.backend_dist !== undefined) { + apiOptions.backend_dist = Boolean(currentConfig.backend_dist); + } if (currentConfig.custom) { apiOptions.custom = currentConfig.custom; } diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts index 1a8bc64989957..149d4fb118f54 100644 --- a/tools/server/webui/src/lib/types/api.d.ts +++ b/tools/server/webui/src/lib/types/api.d.ts @@ -181,6 +181,8 @@ export interface ApiChatCompletionRequest { dry_penalty_last_n?: number; // Sampler configuration samplers?: string[]; + backend_sampling?: boolean; + backend_dist?: boolean; // Custom parameters (JSON string) custom?: Record; timings_per_token?: boolean; diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts index b47842b66e619..e68d107faa3bf 100644 --- 
a/tools/server/webui/src/lib/types/settings.d.ts +++ b/tools/server/webui/src/lib/types/settings.d.ts @@ -37,6 +37,8 @@ export interface SettingsChatServiceOptions { dry_penalty_last_n?: number; // Sampler configuration samplers?: string | string[]; + backend_sampling?: boolean; + backend_dist?: boolean; // Custom parameters custom?: string; timings_per_token?: boolean;
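
Taken together, a standalone consumer of this API follows the same pattern as the `tools/main/main.cpp` change and the tests above: load the model, attach a per-sequence backend sampler chain, create the context, decode, and read the sampled token back. Below is a minimal sketch under those assumptions; the prompt, top-k value, seed and the helper name `backend_sampling_demo` are illustrative, the usual `common.h`/`sampling.h`/`llama.h` includes are assumed, and sampler/context cleanup is elided:

```cpp
// Illustrative only -- mirrors the wiring in tools/main/main.cpp and the tests above.
static llama_token backend_sampling_demo(llama_model * model, common_params & params) {
    // backend chain: top-k filtering followed by the final dist step on the backend
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(40)); // illustrative k
    llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(88));  // illustrative seed

    // register the chain for sequence 0 before the context exists --
    // the reason model and context initialization are now split
    llama_sampler_seq_config cfg = { 0, chain };
    params.backend_samplers   = &cfg;
    params.n_backend_samplers = 1;

    common_init_result init = common_init_context_from_model(model, params);
    llama_context * ctx = init.context.get();
    if (ctx == NULL) {
        return LLAMA_TOKEN_NULL;
    }

    // tokenize a short prompt and mark only the last token as an output
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const std::string prompt = "Hello";
    std::vector<llama_token> prompt_tokens(32);
    const int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
                                        prompt_tokens.data(), prompt_tokens.size(),
                                        true, false);
    if (n_tokens <= 0) {
        return LLAMA_TOKEN_NULL;
    }

    llama_batch batch = llama_batch_init(n_tokens, 0, 1);
    for (int i = 0; i < n_tokens; i++) {
        common_batch_add(batch, prompt_tokens[i], i, { 0 }, i == n_tokens - 1);
    }

    // the backend chain runs as part of decoding; the result is fetched afterwards
    llama_token token = LLAMA_TOKEN_NULL;
    if (llama_decode(ctx, batch) == 0) {
        token = llama_get_backend_sampled_token_ith(ctx, -1);
    }

    llama_batch_free(batch);
    return token;
}
```

Because this chain ends in a backend dist sampler, the final token is read back with `llama_get_backend_sampled_token_ith`; with a filtering-only chain (for example temperature or top-k without dist), the host would instead finish sampling on the CPU via `llama_sampler_sample`, as the hybrid tests above do.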