diff --git a/common/arg.cpp b/common/arg.cpp
index 430ab45dfe26e..ab3386b1dfa67 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1501,6 +1501,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
}
).set_sparam());
+ add_opt(common_arg(
+ {"--backend-sampling"},
+ "enable backend sampling (default: disabled)",
+ [](common_params & params) {
+ params.sampling.backend_sampling = true;
+ }
+ ).set_sparam());
+ add_opt(common_arg(
+ {"--backend-dist"},
+ "perform final (distribution) sampling on backend (default: disabled)",
+ [](common_params & params) {
+ params.sampling.backend_dist = true;
+ params.sampling.backend_sampling = true;
+ }
+ ).set_sparam());
add_opt(common_arg(
{"--pooling"}, "{none,mean,cls,last,rank}",
"pooling type for embeddings, use model default if unspecified",
diff --git a/common/build-info.cpp b/common/build-info.cpp
new file mode 100644
index 0000000000000..6e8240fbb16a1
--- /dev/null
+++ b/common/build-info.cpp
@@ -0,0 +1,4 @@
+int LLAMA_BUILD_NUMBER = 5590;
+char const *LLAMA_COMMIT = "0d398442";
+char const *LLAMA_COMPILER = "cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0";
+char const *LLAMA_BUILD_TARGET = "x86_64-linux-gnu";
diff --git a/common/common.cpp b/common/common.cpp
index f3cc55247e718..6a6f5fec3d540 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -8,6 +8,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
+#include "sampling.h"
#include
#include
@@ -949,20 +950,34 @@ std::vector fs_list_files(const std::string & path) {
// Model utils
//
-struct common_init_result common_init_from_params(common_params & params) {
- common_init_result iparams;
+llama_model * common_load_model_from_params(common_params & params) {
auto mparams = common_model_params_to_llama(params);
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
+ return nullptr;
+ }
+
+ return model;
+}
+
+struct common_init_result common_init_context_from_model(
+ llama_model * model,
+ common_params & params) {
+ common_init_result iparams;
+
+ if (model == NULL) {
+ LOG_ERR("%s: model is NULL\n", __func__);
return iparams;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
auto cparams = common_context_params_to_llama(params);
+ cparams.samplers = params.backend_samplers;
+ cparams.n_samplers = params.n_backend_samplers;
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
@@ -1129,6 +1144,14 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
+struct common_init_result common_init_from_params(common_params & params) {
+ llama_model * model = common_load_model_from_params(params);
+ if (model == NULL) {
+ return common_init_result();
+ }
+ return common_init_context_from_model(model, params);
+}
+
std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
diff --git a/common/common.h b/common/common.h
index de5b404dd8895..01e6dfe59b49d 100644
--- a/common/common.h
+++ b/common/common.h
@@ -195,6 +195,10 @@ struct common_params_sampling {
    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
    std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
+ // Backend sampling flags
+ bool backend_sampling = false; // enable backend sampling
+ bool backend_dist = false; // backend performs final sampling (dist)
+
// print the parameters into a string
std::string print() const;
};
@@ -519,6 +523,9 @@ struct common_params {
bool has_speculative() const {
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
}
+
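+    // optional per-sequence backend sampler chains, forwarded to
+    // llama_context_params::samplers / n_samplers when the context is created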
+ llama_sampler_seq_config * backend_samplers = NULL;
+ size_t n_backend_samplers = 0;
};
// call once at the start of a program if it uses libcommon
@@ -640,6 +647,14 @@ struct common_init_result {
struct common_init_result common_init_from_params(common_params & params);
+// Load model only (allows creating backend samplers before context initialization)
+llama_model * common_load_model_from_params(common_params & params);
+
+// Initialize context from an already-loaded model (allows pre-configuring backend samplers)
+struct common_init_result common_init_context_from_model(
+ llama_model * model,
+ common_params & params);
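+//
+// Typical two-step usage (sketch; the config array name is illustrative):
+//   llama_model * model = common_load_model_from_params(params);
+//   // build llama_sampler_seq_config entries from the model, then:
+//   //   params.backend_samplers = configs; params.n_backend_samplers = n_configs;
+//   common_init_result llama_init = common_init_context_from_model(model, params);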
+
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
diff --git a/common/llguidance.cpp b/common/llguidance.cpp
index adce620e4d62f..27d15516e9438 100644
--- a/common/llguidance.cpp
+++ b/common/llguidance.cpp
@@ -106,12 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
}
static llama_sampler_i llama_sampler_llg_i = {
- /* .name = */ llama_sampler_llg_name,
- /* .accept = */ llama_sampler_llg_accept_impl,
- /* .apply = */ llama_sampler_llg_apply,
- /* .reset = */ llama_sampler_llg_reset,
- /* .clone = */ llama_sampler_llg_clone,
- /* .free = */ llama_sampler_llg_free,
+ /* .name = */ llama_sampler_llg_name,
+ /* .accept = */ llama_sampler_llg_accept_impl,
+ /* .apply = */ llama_sampler_llg_apply,
+ /* .reset = */ llama_sampler_llg_reset,
+ /* .clone = */ llama_sampler_llg_clone,
+ /* .free = */ llama_sampler_llg_free,
+ /* .apply_ggml = */ NULL,
+ /* .accept_ggml = */ NULL,
+ /* .set_input_ggml = */ NULL,
+    /* .init_ggml          = */ NULL,
};
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 7a6b7be1e0ee6..0a2be6bf7df56 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -121,17 +121,34 @@ struct common_sampler {
}
void set_logits(struct llama_context * ctx, int idx) {
- const auto * logits = llama_get_logits_ith(ctx, idx);
+ const float * sampled_probs = llama_get_backend_sampled_probs_ith (ctx, idx);
+ const float * sampled_logits = llama_get_backend_sampled_logits_ith (ctx, idx);
+ const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx);
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
- cur.resize(n_vocab);
-
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
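+        // Prefer backend-sampled probabilities, then backend-sampled logits; otherwise
+        // fall back to the full CPU logits for this output.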
+ if (sampled_probs) {
+ const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
+ cur.resize(sampled_probs_count);
+ for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
+ }
+ } else if (sampled_logits) {
+ const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
+ cur.resize(sampled_logits_count);
+ for (uint32_t i = 0; i < sampled_logits_count; i++) {
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
+ }
+ } else {
+ const auto * logits = llama_get_logits_ith(ctx, idx);
+ GGML_ASSERT(logits != nullptr);
+ cur.resize(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+ }
}
cur_p = { cur.data(), cur.size(), -1, false };
@@ -144,6 +161,10 @@ struct common_sampler {
mutable int64_t t_total_us = 0;
};
+static bool sampler_enabled(const struct common_params_sampling & params, enum common_sampler_type type) {
+ return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end();
+}
+
std::string common_params_sampling::print() const {
char result[1024];
@@ -301,6 +322,43 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
return result;
}
+struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
+ if (!params.backend_sampling) {
+ return nullptr;
+ }
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+
+ llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+ chain_params.no_perf = params.no_perf;
+
+ struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+
+ const bool enable_temp = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE);
+ const bool enable_top_k = params.top_k > 0 && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K);
+ const bool enable_dist = params.backend_dist;
+
+ if (!params.logit_bias.empty()) {
+ llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias(
+ llama_vocab_n_tokens(vocab),
+ params.logit_bias.size(),
+ params.logit_bias.data()));
+ }
+
+ if (enable_temp) {
+ llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
+ }
+
+ if (enable_top_k) {
+ llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
+ }
+
+ if (enable_dist) {
+ llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
+ }
+
+ return chain;
+}
+
void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
@@ -384,6 +442,15 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+ // Check if a backend sampler has already sampled a token in which case we
+ // return that token id directly.
+ {
+ const llama_token id = llama_get_backend_sampled_token_ith(ctx, idx);
+ if (id != LLAMA_TOKEN_NULL) {
+ LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
+ return id;
+ }
+ }
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
diff --git a/common/sampling.h b/common/sampling.h
index e198eecda3810..0ec164de05343 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -38,6 +38,13 @@ struct common_sampler;
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
+// Create a backend sampler chain from common sampling parameters
+// Returns a llama_sampler chain configured with backend samplers based on the parameters
+// This chain can be used per-sequence for backend-based sampling
+// Note: Only samplers that have backend equivalents will be added to the chain
+// The returned sampler should be freed with llama_sampler_free()
+struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params);
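+// Example (sketch, assuming a loaded model and context):
+//   struct llama_sampler * chain = common_sampler_backend_init(model, params.sampling);
+//   if (chain != nullptr) {
+//       llama_set_backend_sampler(lctx, seq_id, chain); // attach the chain to a sequence
+//   }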
+
void common_sampler_free(struct common_sampler * gsmpl);
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index 3722cf3ab26ee..b8003c48c51fc 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -49,28 +49,49 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
size_t temp_storage_bytes = 0;
if (order == GGML_SORT_ORDER_ASC) {
- DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
- temp_indices, dst, // values (indices)
- ncols * nrows, nrows, // num items, num segments
- d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
- stream);
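+        // For a single row (one segment) the plain DeviceRadixSort path is used, avoiding the
+        // overhead of segmented sorting; otherwise fall back to DeviceSegmentedSort.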
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols * nrows, nrows, // num items, num segments
+ d_offsets, d_offsets + 1, stream);
+ }
} else {
- DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
- dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
- sizeof(float) * 8, stream);
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+ dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+ }
}
    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();
if (order == GGML_SORT_ORDER_ASC) {
- DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
- ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
- stream);
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+ ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+ }
} else {
- DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
- temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
- 0, sizeof(float) * 8, stream);
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+ temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
+ stream);
+ }
}
}
#endif // GGML_CUDA_USE_CUB
diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
new file mode 100644
index 0000000000000..041dc7cdb5194
--- /dev/null
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,69 @@
+#include "cumsum.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // GGML_CUDA_USE_CUB
+
+#include <cstdint>
+
+__global__ void cumsum_f32_kernel(const float * x, float * dst, int64_t n) {
+ // TODO: this is a naive implementation just for getting something working.
+ if (threadIdx.x == 0 && blockIdx.x == 0) {
+ dst[0] = x[0];
+ for (int64_t i = 1; i < n; i++) {
+ dst[i] = dst[i-1] + x[i];
+ }
+ }
+}
+
+void cumsum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
+#ifdef GGML_CUDA_USE_CUB
+ size_t tmp_size = 0;
+
+ // Query how much temp storage CUDA UnBound (CUB) needs
+ cub::DeviceScan::InclusiveSum(
+ nullptr, // d_temp_storage (null = just query size)
+ tmp_size, // reference to size (will be set by CUB)
+ x, // input pointer
+ dst, // output pointer
+ ne, // number of elements
+ stream // CUDA stream to use
+ );
+
+    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+
+ // Perform the inclusive scan
+ cub::DeviceScan::InclusiveSum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
+
+#else
+ GGML_UNUSED(pool);
+ cumsum_f32_kernel<<<1, 1, 0, stream>>>(x, dst, ne);
+#endif // GGML_CUDA_USE_CUB
+}
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguously_allocated(src0));
+
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+
+ const int64_t ne0 = src0->ne[0]; // row length (cumsum computed along this dimension)
+ const int64_t ne1 = src0->ne[1];
+ const int64_t ne2 = src0->ne[2];
+ const int64_t ne3 = src0->ne[3];
+ const int64_t nrows = ne1 * ne2 * ne3; // total number of rows
+
+ ggml_cuda_pool & pool = ctx.pool();
+ cudaStream_t stream = ctx.stream();
+
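+    // Run an independent inclusive scan per row; with CUB enabled this launches one
+    // DeviceScan::InclusiveSum per row on the same stream.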
+ for (int64_t i = 0; i < nrows; i++) {
+ const float * src_row = src0_d + i * ne0;
+ float * dst_row = dst_d + i * ne0;
+ cumsum_f32_cuda(pool, src_row, dst_row, ne0, stream);
+ }
+}
diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh
new file mode 100644
index 0000000000000..7fca7e1456437
--- /dev/null
+++ b/ggml/src/ggml-cuda/cumsum.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void cumsum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 8fe0899bb5aac..a4ff461120293 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -19,6 +19,7 @@
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
+#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
@@ -2678,6 +2679,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM:
ggml_cuda_op_sum(ctx, dst);
break;
+ case GGML_OP_CUMSUM:
+ ggml_cuda_op_cumsum(ctx, dst);
+ break;
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
@@ -4223,6 +4227,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
+ case GGML_OP_CUMSUM:
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ARGSORT:
diff --git a/include/llama.h b/include/llama.h
index 8547226ff210c..9fbce771d74cf 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -210,6 +210,13 @@ extern "C" {
bool sorted; // note: do not assume the data is sorted - always check this flag
} llama_token_data_array;
+ struct llama_sampler_ggml_data {
+ struct ggml_tensor * logits;
+ struct ggml_tensor * probs;
+ struct ggml_tensor * sampled;
+ struct ggml_tensor * candidates;
+ };
+
typedef bool (*llama_progress_callback)(float progress, void * user_data);
// Input data for llama_encode/llama_decode
@@ -300,6 +307,11 @@ extern "C" {
bool no_host; // bypass host buffer allowing extra buffers to be used
};
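+
+    // Associates a backend sampler chain with a sequence id; an array of these can be
+    // passed to the context via llama_context_params::samplers / n_samplers.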
+ struct llama_sampler_seq_config {
+ llama_seq_id seq_id;
+ struct llama_sampler * sampler;
+ };
+
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
// https://github.com/ggml-org/llama.cpp/pull/7544
struct llama_context_params {
@@ -348,6 +360,10 @@ extern "C" {
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
+
+ // backend sampler chain configuration
+ struct llama_sampler_seq_config * samplers;
+ size_t n_samplers;
};
// model quantization parameters
@@ -950,6 +966,32 @@ extern "C" {
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
+ // Get the backend sampled token for the ith token.
+ // Returns LLAMA_TOKEN_NULL if no token was sampled.
+ LLAMA_API llama_token llama_get_backend_sampled_token_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled probabilities for the ith token.
+    // The index matches llama_get_backend_sampled_token_ith().
+    // Returns NULL if no probabilities were generated.
+    LLAMA_API float * llama_get_backend_sampled_probs_ith(struct llama_context * ctx, int32_t i);
+    //
+    // Get the number of backend sampled probabilities for the ith token.
+ LLAMA_API uint32_t llama_get_backend_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
+
+ // Get the backend sampled logits for the ith token
+ // Returns NULL if no logits were sampled.
+ LLAMA_API float * llama_get_backend_sampled_logits_ith(struct llama_context * ctx, int32_t i);
+ //
+ // Get the number of backend sampled logits for the ith token.
+ LLAMA_API uint32_t llama_get_backend_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
+
+ // Get the backend sampled candidates (token ids) for the ith token
+ // Returns NULL if no candidates were sampled.
+ LLAMA_API llama_token * llama_get_backend_sampled_candidates_ith(struct llama_context * ctx, int32_t i);
+ //
+ // Get the number of backend sampled candidates for the ith token.
+ LLAMA_API uint32_t llama_get_backend_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
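+
+    // Typical host-side flow (sketch): if llama_get_backend_sampled_token_ith() returns a token
+    // other than LLAMA_TOKEN_NULL, use it directly; otherwise combine the sampled candidates with
+    // the sampled logits/probs when present, falling back to llama_get_logits_ith() when the
+    // backend produced neither.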
+
//
// Vocab
//
@@ -1135,6 +1177,22 @@ extern "C" {
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
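+        // optional backend-sampling hooks (may be NULL):
+        //   apply_ggml     - append the sampler's ops to the decode graph, updating ggml_data
+        //   accept_ggml    - observe the selected token tensor in the graph
+        //   set_input_ggml - upload per-step host inputs (e.g. a fresh RNG draw) before execution
+        //   init_ggml      - allocate any backend tensors the sampler needs for the given buffer type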
+ void (*apply_ggml)( struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct llama_sampler_ggml_data * ggml_data);
+
+ void (*accept_ggml)( struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_tensor * selected_token);
+
+ void (*set_input_ggml)(struct llama_sampler * smpl);
+
+ void (*init_ggml)(struct llama_sampler * smpl,
+ ggml_backend_buffer_type_t buft);
+
+
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
//void (*apply_ggml) (struct llama_sampler * smpl, ...);
};
@@ -1144,6 +1202,8 @@ extern "C" {
llama_sampler_context_t ctx;
};
+ LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
+
// mirror of llama_sampler_i:
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
@@ -1153,6 +1213,18 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
// important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
+ LLAMA_API void llama_sampler_init_ggml(struct llama_sampler * smpl,
+ ggml_backend_buffer_type_t buft);
+ LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl);
+ LLAMA_API void llama_sampler_apply_ggml( struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct llama_sampler_ggml_data * ggml_data);
+
+ LLAMA_API void llama_sampler_accept_ggml( struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_tensor * selected_token);
// llama_sampler_chain
// a type of llama_sampler that can chain multiple samplers one after another
@@ -1299,9 +1371,29 @@ extern "C" {
//
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab);
+ //
+ // Backend samplers
+ //
+
+ /// @details Greedy sampling on backend - always selects the token with the highest probability
+ LLAMA_API struct llama_sampler * llama_sampler_backend_init_greedy(void);
+
+ /// @details Temperature scaling on backend - scales logits by 1/temperature
+ LLAMA_API struct llama_sampler * llama_sampler_backend_init_temp(float temp);
+
+ /// @details Top-K filtering on backend - keeps only the k tokens with highest probabilities
+ LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k);
+
+ /// @details Distribution sampling on backend - final sampling step that selects a token
+ LLAMA_API struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed);
+
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
+ LLAMA_API struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab,
+ int32_t n_logit_bias,
+ const llama_logit_bias * logit_bias);
+
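+    // Example (sketch): build a backend sampler chain and attach it to sequence 0:
+    //
+    //   struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
+    //   llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(0.8f));
+    //   llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(40));
+    //   llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(LLAMA_DEFAULT_SEED));
+    //   llama_set_backend_sampler(ctx, 0, chain);
+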
/// @details Sample and accept a token from the idx-th output of the last evaluation
//
// Shorthand for:
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 8ec95ee176240..c17b89008948b 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -31,6 +31,7 @@ add_library(llama
llama-model.cpp
llama-quant.cpp
llama-sampling.cpp
+ llama-backend-sampler.cpp
llama-vocab.cpp
unicode-data.cpp
unicode.cpp
diff --git a/src/llama-backend-sampler.cpp b/src/llama-backend-sampler.cpp
new file mode 100644
index 0000000000000..cd6b8bb7526a6
--- /dev/null
+++ b/src/llama-backend-sampler.cpp
@@ -0,0 +1,489 @@
+#include "llama.h"
+#include "ggml.h"
+#include <chrono>
+#include <cstdint>
+#include <cstdio>
+#include <random>
+#include <vector>
+
+static void llama_sampler_backend_greedy_apply_ggml(
+ struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct llama_sampler_ggml_data * ggml_data) {
+ GGML_UNUSED(gf);
+ GGML_UNUSED(smpl);
+ struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits);
+ ggml_set_name(argmax_result, "argmax_result");
+ ggml_data->sampled = argmax_result;
+}
+
+static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) {
+ return "test-ggml";
+}
+
+static struct llama_sampler * llama_sampler_backend_greedy_clone(const struct llama_sampler * smpl) {
+ (void) smpl;
+ return llama_sampler_backend_init_greedy();
+}
+
+struct llama_sampler * llama_sampler_backend_init_greedy() {
+ static const llama_sampler_i iface = {
+ /*.name =*/ llama_sampler_backend_greedy_sampler_name,
+ /*.accept =*/ nullptr,
+ /*.apply =*/ nullptr,
+ /*.reset =*/ nullptr,
+ /*.clone =*/ llama_sampler_backend_greedy_clone,
+ /*.free =*/ nullptr,
+ /*.apply_ggml =*/ llama_sampler_backend_greedy_apply_ggml,
+ /*.accept_ggml =*/ nullptr,
+ /*.set_input_ggml =*/ nullptr,
+ /*.init_ggml =*/ nullptr,
+ };
+
+ auto * sampler = new llama_sampler {
+ /*.iface =*/ &iface,
+ /*.ctx =*/ nullptr,
+ };
+
+ return sampler;
+}
+
+struct llama_sampler_backend_temp_ctx {
+ float temp;
+};
+
+
+static void llama_sampler_backend_temp_apply_ggml(
+ struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct llama_sampler_ggml_data * ggml_data) {
+
+ auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx;
+
+ if (ctx_data->temp <= 0.0f) {
+ return;
+ }
+
+ struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp);
+ ggml_set_name(scaled, "temp_scaled");
+
+ // Make sure the scaled tensor is contiguous for subsequent operations
+ ggml_data->logits = ggml_cont(ctx, scaled);
+ ggml_set_name(ggml_data->logits, "temp_scaled_logits");
+
+ ggml_build_forward_expand(gf, ggml_data->logits);
+}
+
+static const char * llama_sampler_backend_temp_name(const struct llama_sampler *) {
+ return "backend-temp";
+}
+
+static void llama_sampler_backend_temp_free(struct llama_sampler * smpl) {
+ auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx;
+ delete ctx_data;
+}
+
+static struct llama_sampler * llama_sampler_backend_temp_clone(const struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_backend_temp_ctx *) smpl->ctx;
+ return llama_sampler_backend_init_temp(ctx->temp);
+}
+
+struct llama_sampler * llama_sampler_backend_init_temp(float temp) {
+ static const llama_sampler_i iface = {
+ /*.name =*/ llama_sampler_backend_temp_name,
+ /*.accept =*/ nullptr,
+ /*.apply =*/ nullptr,
+ /*.reset =*/ nullptr,
+ /*.clone =*/ llama_sampler_backend_temp_clone,
+ /*.free =*/ llama_sampler_backend_temp_free,
+ /*.apply_ggml =*/ llama_sampler_backend_temp_apply_ggml,
+ /*.accept_ggml =*/ nullptr,
+ /*.set_input_ggml =*/ nullptr,
+        /*.init_ggml      =*/ nullptr,
+ };
+
+ auto * ctx_data = new llama_sampler_backend_temp_ctx {
+ /*.temp =*/ temp,
+ };
+
+ auto * sampler = new llama_sampler {
+ /*.iface =*/ &iface,
+ /*.ctx =*/ ctx_data,
+ };
+
+ return sampler;
+}
+
+
+struct llama_sampler_backend_top_k_ctx {
+ int32_t k;
+
+ // Only required for checking operation support and can be removed later.
+ ggml_backend_dev_t device;
+};
+
+static void llama_sampler_backend_top_k_init_ggml(
+ struct llama_sampler * smpl,
+ ggml_backend_buffer_type_t buft) {
+ auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
+ ctx_data->device = ggml_backend_buft_get_device(buft);
+}
+
+static void llama_sampler_backend_top_k_apply_ggml(
+ struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct llama_sampler_ggml_data * ggml_data) {
+
+ auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
+
+ struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k);
+ ggml_set_name(top_k, "top_k");
+
+ // top_k is a view of argsort - check if backend supports the underlying argsort operation
+ // by checking the source tensor (which is the argsort result)
+ if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) {
+ fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n");
+ fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
+ }
+
+ ggml_data->candidates = top_k;
+
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
+ struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
+ ggml_set_name(top_k_rows, "top_k_rows");
+
+ ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
+ ggml_build_forward_expand(gf, ggml_data->logits);
+}
+
+static const char * llama_sampler_backend_top_k_name(const struct llama_sampler *) {
+ return "backend-top-k";
+}
+
+static void llama_sampler_backend_top_k_free(struct llama_sampler * smpl) {
+ auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
+ delete ctx_data;
+}
+
+static struct llama_sampler * llama_sampler_backend_top_k_clone(const struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
+ return llama_sampler_backend_init_top_k(ctx->k);
+}
+
+struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k) {
+ static const llama_sampler_i iface = {
+ /*.name =*/ llama_sampler_backend_top_k_name,
+ /*.accept =*/ nullptr,
+ /*.apply =*/ nullptr,
+ /*.reset =*/ nullptr,
+ /*.clone =*/ llama_sampler_backend_top_k_clone,
+ /*.free =*/ llama_sampler_backend_top_k_free,
+ /*.apply_ggml =*/ llama_sampler_backend_top_k_apply_ggml,
+ /*.accept_ggml =*/ nullptr,
+ /*.set_input_ggml =*/ nullptr,
+ /*.init_ggml =*/ llama_sampler_backend_top_k_init_ggml,
+ };
+
+ auto * ctx_data = new llama_sampler_backend_top_k_ctx {
+ /*.k =*/ k,
+ /*.device =*/ nullptr,
+ };
+
+ auto * sampler = new llama_sampler {
+ /*.iface =*/ &iface,
+ /*.ctx =*/ ctx_data,
+ };
+
+ return sampler;
+}
+
+
+static uint32_t get_rng_seed(uint32_t seed) {
+ if (seed == LLAMA_DEFAULT_SEED) {
+ // use system clock if std::random_device is not a true RNG
+ static bool is_rd_prng = std::random_device().entropy() == 0;
+ if (is_rd_prng) {
+ return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
+ }
+ std::random_device rd;
+ return rd();
+ }
+ return seed;
+}
+
+struct llama_sampler_backend_dist_ctx {
+ const uint32_t seed;
+ uint32_t seed_cur;
+ std::mt19937 rng;
+
+ struct ggml_tensor * uniform;
+ struct ggml_context * ctx;
+ ggml_backend_buffer_t buffer;
+
+ // Only required for checking operation support and can be removed later.
+ ggml_backend_dev_t device;
+};
+
+static void llama_sampler_backend_dist_init_ggml(
+ struct llama_sampler * smpl,
+ ggml_backend_buffer_type_t buft) {
+
+ auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
+ sctx->device = ggml_backend_buft_get_device(buft);
+ ggml_init_params params = {
+ /*.mem_size =*/ ggml_tensor_overhead(),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ sctx->ctx = ggml_init(params);
+
+ // Create the uniform random scalar input tensor. This will be set by
+ // llama_sampler_backend_dist_set_input_ggml after this graph is built.
+ sctx->uniform = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, 1);
+ ggml_set_name(sctx->uniform, "uniform");
+ ggml_set_input(sctx->uniform);
+ ggml_set_output(sctx->uniform);
+
+ // Allocate all tensors from our context to the backend
+ sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft);
+}
+
+static void llama_sampler_backend_dist_set_input_ggml(struct llama_sampler * smpl) {
+ auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
+ GGML_ASSERT(sctx->uniform != nullptr);
+
+    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+ const float rnd = dist(sctx->rng);
+ ggml_backend_tensor_set(sctx->uniform, &rnd, 0, sizeof(float));
+}
+
+static void llama_sampler_backend_dist_apply_ggml(
+ struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct llama_sampler_ggml_data * ggml_data) {
+ GGML_UNUSED(gf);
+ auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
+
+ struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits);
+ ggml_set_name(probs, "dist_probs");
+
+ struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
+ if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) {
+ fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n");
+ fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
+ }
+ ggml_set_name(cumsum, "cumsum");
+
+ // The uniform tensor has a random value and we subtract this tensor with
+ // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
+ // Recall that each entry in cumsum is the cumulative probability up to that
+ // index so values stay negative while the cumulative total is below the
+ // random value, and become zero/positive once the threshold is crossed.
+ struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->uniform);
+ ggml_set_name(diff, "dist_cumsum");
+
+ // The ggml_step function produces a tensor where entries are 1 if the
+ // corresponding entry in diff is > 0, and 0 otherwise. So all values up to
+ // the index where the cumulative probability exceeds the random value are 0,
+ // and all entries after that are 1.
+ struct ggml_tensor * mask = ggml_step(ctx, diff);
+ ggml_set_name(mask, "dist_mask");
+
+ // Taking the sum of the mask gives us the sum of elements after the threshold
+ // we are interested in.
+ struct ggml_tensor * idxf = ggml_sum(ctx, mask);
+ ggml_set_name(idxf, "dist_index_f32");
+
+ // Use ggml_scale_bias to scale the index value by -1 and then add the size
+ // of the mask to that value so we get the correct index ((-1 * idxf) + n).
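+    // Worked example (illustrative): probs = [0.2, 0.5, 0.3] gives cumsum = [0.2, 0.7, 1.0];
+    // with a random value of 0.6 the mask is [0, 1, 1], its sum is 2, and the selected index
+    // is (-1 * 2) + 3 = 1, the first position where the cumulative probability exceeds the
+    // random value.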
+ struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
+ ggml_set_name(idx, "dist_index_i32");
+
+ // Map back to original vocab ids if a candidates tensor is available.
+ struct ggml_tensor * sampled_token = idx;
+ if (ggml_data->candidates != nullptr) {
+ struct ggml_tensor * candidates = ggml_data->candidates;
+ struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
+ ggml_type_size(candidates->type), 0);
+
+ sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx);
+ ggml_set_name(sampled_token, "dist_sampled_token");
+ }
+
+ ggml_set_output(sampled_token);
+ ggml_data->sampled = sampled_token;
+}
+
+static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) {
+ return "backend-dist";
+}
+
+static void llama_sampler_backend_dist_free(struct llama_sampler * smpl) {
+ auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
+ ggml_backend_buffer_free(sctx->buffer);
+ ggml_free(sctx->ctx);
+ delete sctx;
+}
+
+static struct llama_sampler * llama_sampler_backend_dist_clone(const struct llama_sampler * smpl) {
+ auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
+ return llama_sampler_backend_init_dist(sctx->seed);
+}
+
+
+struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed) {
+ static const llama_sampler_i iface = {
+ /*.name =*/ llama_sampler_backend_dist_name,
+ /*.accept =*/ nullptr,
+ /*.apply =*/ nullptr,
+ /*.reset =*/ nullptr,
+ /*.clone =*/ llama_sampler_backend_dist_clone,
+ /*.free =*/ llama_sampler_backend_dist_free,
+ /*.apply_ggml =*/ llama_sampler_backend_dist_apply_ggml,
+ /*.accept_ggml =*/ nullptr,
+ /*.set_input_ggml =*/ llama_sampler_backend_dist_set_input_ggml,
+ /*.init_ggml =*/ llama_sampler_backend_dist_init_ggml,
+ };
+
+ auto seed_cur = get_rng_seed(seed);
+ auto * ctx_data = new llama_sampler_backend_dist_ctx {
+ /*.seed =*/ seed,
+ /*.seed_cur =*/ seed_cur,
+ /*.rng =*/ std::mt19937(seed_cur),
+ /*.uniform =*/ nullptr,
+ /*.ctx =*/ nullptr,
+ /*.buffer =*/ nullptr,
+ /*.device =*/ nullptr,
+ };
+
+ auto * sampler = new llama_sampler {
+ /*.iface =*/ &iface,
+ /*.ctx =*/ ctx_data,
+ };
+
+ return sampler;
+}
+
+struct llama_sampler_backend_logit_bias_ctx {
+ const int32_t n_vocab;
+
+    const std::vector<llama_logit_bias> logit_bias;
+
+ struct ggml_tensor * logit_bias_t;
+ struct ggml_context * ctx;
+ ggml_backend_buffer_t buffer;
+};
+
+static void llama_sampler_backend_logit_bias_init_ggml(
+ struct llama_sampler * smpl,
+ ggml_backend_buffer_type_t buft) {
+ auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
+ if (sctx->logit_bias.empty()) {
+ return;
+ }
+ ggml_init_params params = {
+ /*.mem_size =*/ ggml_tensor_overhead() * sctx->n_vocab * sizeof(float),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ sctx->ctx = ggml_init(params);
+
+ struct ggml_tensor * logit_bias = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, sctx->n_vocab);
+ sctx->logit_bias_t = logit_bias;
+ ggml_set_name(sctx->logit_bias_t, "logit_bias");
+ ggml_set_input(sctx->logit_bias_t);
+ ggml_set_output(sctx->logit_bias_t);
+
+ // Allocate all tensors from our context to the backend
+ sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft);
+}
+
+static void llama_sampler_backend_logit_bias_set_input_ggml(struct llama_sampler * smpl) {
+ auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
+ if (sctx->logit_bias.empty()) {
+ return;
+ }
+ GGML_ASSERT(sctx->logit_bias_t != nullptr);
+
+ // Create a sparse logit_bias vector from the logit_bias entries.
+    std::vector<float> logit_bias_sparse(sctx->n_vocab, 0.0f);
+ for (const auto & lb : sctx->logit_bias) {
+ GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
+ logit_bias_sparse[lb.token] = lb.bias;
+ }
+
+ ggml_backend_tensor_set(sctx->logit_bias_t, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->logit_bias_t));
+}
+
+static void llama_sampler_backend_logit_bias_apply_ggml(
+ struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct llama_sampler_ggml_data * ggml_data) {
+ GGML_UNUSED(gf);
+ GGML_UNUSED(ctx);
+
+ auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
+ if (sctx->logit_bias_t == nullptr) {
+ return;
+ }
+
+    // Add the sparse logit bias tensor to the logits
+ struct ggml_tensor * logit_biased = ggml_add_inplace(sctx->ctx, ggml_data->logits, sctx->logit_bias_t);
+ ggml_build_forward_expand(gf, logit_biased);
+}
+
+static const char * llama_sampler_backend_logit_bias_name(const struct llama_sampler *) {
+ return "backend-logit_bias";
+}
+
+static void llama_sampler_backend_logit_bias_free(struct llama_sampler * smpl) {
+ auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
+ ggml_backend_buffer_free(sctx->buffer);
+ ggml_free(sctx->ctx);
+ delete sctx;
+}
+
+static struct llama_sampler * llama_sampler_backend_logit_bias_clone(const struct llama_sampler * smpl) {
+ auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
+ return llama_sampler_backend_init_logit_bias(sctx->n_vocab,
+ sctx->logit_bias.size(),
+ sctx->logit_bias.data());
+}
+
+
+struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab,
+ int32_t n_logit_bias,
+ const llama_logit_bias * logit_bias) {
+ static const llama_sampler_i iface = {
+ /*.name =*/ llama_sampler_backend_logit_bias_name,
+ /*.accept =*/ nullptr,
+ /*.apply =*/ nullptr,
+ /*.reset =*/ nullptr,
+ /*.clone =*/ llama_sampler_backend_logit_bias_clone,
+ /*.free =*/ llama_sampler_backend_logit_bias_free,
+ /*.apply_ggml =*/ llama_sampler_backend_logit_bias_apply_ggml,
+ /*.accept_ggml =*/ nullptr,
+ /*.set_input_ggml =*/ llama_sampler_backend_logit_bias_set_input_ggml,
+ /*.init_ggml =*/ llama_sampler_backend_logit_bias_init_ggml,
+ };
+
+ auto * ctx_data = new llama_sampler_backend_logit_bias_ctx {
+ /*.n_vocab =*/ n_vocab,
+        /*.logit_bias   =*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+ /*.logit_bias_t =*/ nullptr,
+ /*.ctx =*/ nullptr,
+ /*.buffer =*/ nullptr,
+ };
+
+ auto * sampler = new llama_sampler {
+ /*.iface =*/ &iface,
+ /*.ctx =*/ ctx_data,
+ };
+
+ return sampler;
+}
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 86a1a4ba187ee..f0866a9ca1962 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -28,7 +28,8 @@ bool llama_batch_allocr::init(
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
- bool output_all) {
+ bool output_all,
+ bool backend_sampling) {
clear();
batch = batch_inp;
@@ -145,6 +146,24 @@ bool llama_batch_allocr::init(
}
}
+ if (backend_sampling) {
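+        // Backend samplers write exactly one sampled token per sequence per decode call,
+        // so reject batches that request more than one output token for the same sequence.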
+        std::vector<int32_t> seq_output_count(n_seq_max, 0);
+
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
+ if (batch.logits[i] == 0) {
+ continue;
+ }
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = batch.seq_id[i][s];
+ seq_output_count[seq_id]++;
+ if (seq_output_count[seq_id] > 1) {
+ LLAMA_LOG_ERROR("%s: backend sampling allows at most one output token per sequence (%d)\n", __func__, seq_id);
+ return false;
+ }
+ }
+ }
+ }
+
//
// compute stats
//
diff --git a/src/llama-batch.h b/src/llama-batch.h
index 209cf3699de23..d8751274f376d 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -79,7 +79,8 @@ class llama_batch_allocr {
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
- bool output_all);
+ bool output_all,
+ bool backend_sampling = false);
const llama_batch & get_batch() const;
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 70a3ec62dfc63..1694e44720062 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -58,6 +58,16 @@ llama_context::llama_context(
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
+ // backend samplers
+ if (params.samplers != nullptr && params.n_samplers > 0) {
+ sampling.samplers.reserve(params.n_samplers);
+
+ for (size_t i = 0; i < params.n_samplers; ++i) {
+ const auto & config = params.samplers[i];
+ sampling.samplers[config.seq_id] = config.sampler;
+ }
+ }
+
auto rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
rope_scaling_type = hparams.rope_scaling_type_train;
@@ -420,10 +430,24 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
}
}
+
+ // Initialize the full vocabulary token ids for backend samplers.
+ {
+ const llama_vocab * vocab = llama_model_get_vocab(&model);
+ const int n_vocab = llama_vocab_n_tokens(vocab);
+ sampling.token_ids_full_vocab.resize(n_vocab);
+ for (int i = 0; i < n_vocab; ++i) {
+ sampling.token_ids_full_vocab[i] = i;
+ }
+ }
}
llama_context::~llama_context() {
ggml_opt_free(opt_ctx);
+ // TODO: perhaps use a smart pointer for samplers
+ for (auto const& [seq_id, sampler] : sampling.samplers) {
+ llama_sampler_free(sampler);
+ }
}
void llama_context::synchronize() {
@@ -564,6 +588,35 @@ float * llama_context::get_logits() {
return logits;
}
+int64_t llama_context::resolve_output_row(int32_t i) const {
+ int64_t j = -1;
+
+ // support negative indices (last output row)
+ if (i < 0) {
+ j = n_outputs + i;
+ if (j < 0) {
+ throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
+ }
+ } else if ((size_t) i >= output_ids.size()) {
+ throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
+ } else {
+ // use output_ids to translate the batch token index into a row number
+ // that holds this token's data.
+ j = output_ids[i];
+ }
+
+ if (j < 0) {
+ // the batch token was not configured to output anything
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
+ }
+
+ if (j >= n_outputs) {
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
+ }
+
+ return j;
+}
+
float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1;
@@ -610,6 +663,10 @@ float * llama_context::get_embeddings() {
return embd;
}
+llama_token * llama_context::get_backend_sampled_tokens() {
+ return sampling.sampled;
+}
+
float * llama_context::get_embeddings_ith(int32_t i) {
int64_t j = -1;
@@ -659,6 +716,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}
+llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
+ output_reorder();
+
+ if (sampling.sampled == nullptr) {
+ return LLAMA_TOKEN_NULL;
+ }
+
+ try {
+ const int64_t row = resolve_output_row(idx);
+ GGML_ASSERT(row < (int64_t) sampling.sampled_size);
+ return sampling.sampled[row];
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
+ return LLAMA_TOKEN_NULL;
+ }
+}
+
+float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
+ output_reorder();
+
+ if (sampling.probs == nullptr) {
+ return nullptr;
+ }
+
+ try {
+ const int64_t row = resolve_output_row(idx);
+ if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
+ return nullptr;
+ }
+ return sampling.probs + row*model.vocab.n_tokens();
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
+ return nullptr;
+ }
+}
+
+float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
+ output_reorder();
+
+ if (sampling.logits == nullptr) {
+ return nullptr;
+ }
+
+ try {
+ const int64_t row = resolve_output_row(idx);
+ if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
+ return nullptr;
+ }
+ return sampling.logits + row*model.vocab.n_tokens();
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
+ return nullptr;
+ }
+}
+
+const llama_token * llama_context::get_backend_sampled_candidates_ith(int32_t idx) {
+ output_reorder();
+
+ try {
+ const int64_t row = resolve_output_row(idx);
+ if (sampling.candidates != nullptr &&
+ (size_t) row < sampling.candidates_count.size() &&
+ sampling.candidates_count[row] > 0) {
+ return sampling.candidates + row*model.vocab.n_tokens();
+ }
+ } catch (const std::exception & err) {
+ // fallback to full vocab list
+ }
+
+ return sampling.token_ids_full_vocab.data();
+}
+
+size_t llama_context::get_backend_sampled_candidates_count(int32_t idx) {
+ output_reorder();
+
+ if (sampling.candidates == nullptr) {
+ return 0;
+ }
+
+ try {
+ const int64_t row = resolve_output_row(idx);
+ if ((size_t) row >= sampling.candidates_count.size()) {
+ return 0;
+ }
+ return sampling.candidates_count[row];
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
+ return 0;
+ }
+}
+
+size_t llama_context::get_backend_sampled_logits_count(int32_t idx) {
+ output_reorder();
+
+ if (sampling.logits == nullptr) {
+ return 0;
+ }
+
+ try {
+ const int64_t row = resolve_output_row(idx);
+ if ((size_t) row >= sampling.logits_count.size()) {
+ return 0;
+ }
+ return sampling.logits_count[row];
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
+ return 0;
+ }
+}
+
+size_t llama_context::get_backend_sampled_probs_count(int32_t idx) {
+ output_reorder();
+
+ if (sampling.probs == nullptr) {
+ return 0;
+ }
+
+ try {
+ const int64_t row = resolve_output_row(idx);
+ if ((size_t) row >= sampling.probs_count.size()) {
+ return 0;
+ }
+ return sampling.probs_count[row];
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
+ return 0;
+ }
+}
+
+
void llama_context::attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) {
@@ -715,6 +902,37 @@ void llama_context::set_warmup(bool value) {
cparams.warmup = value;
}
+void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+ LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
+
+ auto it = sampling.samplers.find(seq_id);
+ if (it != sampling.samplers.end()) {
+ // If the sampler to be set is the same that is already set, do nothing.
+ if (it->second == sampler) {
+ return;
+ }
+
+ llama_sampler_free(it->second);
+
+    // If sampler is nullptr, we remove the sampler chain for this seq_id.
+ if (sampler == nullptr) {
+ sampling.samplers.erase(it);
+ return;
+ }
+
+ // Otherwise, we replace the existing sampler with the new one.
+ it->second = sampler;
+ return;
+ }
+
+ // If there is no sampler for this seq_id and the caller provides a non-null
+ // sampler, we set it.
+ if (sampler != nullptr) {
+ sampling.samplers[seq_id] = sampler;
+ }
+}
+
void llama_context::set_adapter_lora(
llama_adapter_lora * adapter,
float scale) {
@@ -979,6 +1197,97 @@ int llama_context::encode(const llama_batch & batch_inp) {
return 0;
}
+static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
+    std::unordered_map<llama_seq_id, uint32_t> seq_to_row;
+ // how many output tokens we have seen so far for this ubatch.
+ uint32_t local = 0;
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+ // skip tokens that are not output.
+ if (!ubatch.output[i]) {
+ continue;
+ }
+
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
+ // row_offset is the number of output tokens before this ubatch.
+ seq_to_row[seq_id] = row_offset + local;
+ ++local;
+ }
+ return seq_to_row;
+}
+
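+// Asynchronously copy each per-sequence sampled-token tensor from its backend into the
+// host-side `sampled` buffer, using seq_to_row to select the destination output row.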
+static void copy_tensor_async_ints(
+        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
+        llama_token * sampled,
+        size_t sampled_size,
+        const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
+ ggml_backend_sched_t sched) {
+ if (sampled == nullptr || sampled_size == 0) {
+ return;
+ }
+
+ for (const auto & [seq_id, tensor] : tensor_map) {
+ auto it = seq_to_row.find(seq_id);
+ GGML_ASSERT(it != seq_to_row.end());
+ const uint32_t row = it->second;
+ GGML_ASSERT(row < sampled_size);
+
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+ ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
+ }
+}
+
+static void copy_tensor_async_floats(
+        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
+        float * dst,
+        size_t stride,
+        std::vector<size_t> & counts,
+        const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
+ ggml_backend_sched_t sched) {
+ if (dst == nullptr || stride == 0) {
+ return;
+ }
+
+ for (const auto & [seq_id, tensor] : tensor_map) {
+ auto it = seq_to_row.find(seq_id);
+ GGML_ASSERT(it != seq_to_row.end());
+ const uint32_t row = it->second;
+ GGML_ASSERT(row < counts.size());
+
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+ float * row_ptr = dst + (size_t) row * stride;
+ ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+ // Update the actual number of logits/probabilities that were written for this row.
+ counts[row] = ggml_nelements(tensor);
+ }
+}
+
+static void copy_tensor_async_candidates(
+        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
+        llama_token * dst,
+        size_t stride,
+        std::vector<size_t> & counts,
+        const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
+ ggml_backend_sched_t sched) {
+ if (dst == nullptr || stride == 0) {
+ return;
+ }
+
+ for (const auto & [seq_id, tensor] : tensor_map) {
+ auto it = seq_to_row.find(seq_id);
+ GGML_ASSERT(it != seq_to_row.end());
+ const uint32_t row = it->second;
+ GGML_ASSERT(row < counts.size());
+
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+ llama_token * row_ptr = dst + (size_t) row * stride;
+ ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+ // Update the actual number of candidates that were written.
+ counts[row] = ggml_nelements(tensor);
+ }
+}
+
int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
@@ -1000,8 +1309,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
// when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
+ const bool has_backend_samplers = !sampling.samplers.empty();
- if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
+ cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
+ output_all,
+ has_backend_samplers)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@@ -1089,6 +1402,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
int64_t n_outputs_prev = 0;
+ // This flag indicates whether a backend sampler has actually sampled a specific
+    // token, or if it has produced probabilities. If true, we can skip the normal copying of logits and embeddings.
+ bool backend_has_sampled = false;
+
do {
const auto & ubatch = mctx->get_ubatch();
@@ -1147,80 +1464,106 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
- auto * t_logits = res->get_logits();
- auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
+ backend_has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
- if (t_embd && res->get_embd_pooled()) {
- t_embd = res->get_embd_pooled();
- }
+ if (has_backend_samplers && backend_has_sampled) {
+ const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
- // extract logits
- if (t_logits && n_outputs > 0) {
- ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
- GGML_ASSERT(backend_res != nullptr);
- GGML_ASSERT(logits != nullptr);
+ // If a backend sampler has sampled a token we only want to copy the
+            // sampled tokens and avoid copying logits and probabilities.
+ if (!res->t_sampled.empty()) {
+ // async copy the sampled tokens from the backend to the host.
+ copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+ } else {
+ // async copy the sampled logits/probs from the backend to the host.
+ copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, n_vocab, sampling.logits_count, seq_to_output_row, sched.get());
+ copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, n_vocab, sampling.probs_count, seq_to_output_row, sched.get());
+ }
- float * logits_out = logits + n_outputs_prev*n_vocab;
+ // async copy the candidate token ids from the backend to the host.
+ // These are needed for:
+ // 1) Backend dist sampler to map indices to vocab token ids.
+ // 2) CPU samplers to associate candidate logits with their token ids.
+ copy_tensor_async_candidates(res->t_candidates, sampling.candidates, n_vocab, sampling.candidates_count, seq_to_output_row, sched.get());
- if (n_outputs) {
- GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
- GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
- ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
- }
}
- // extract embeddings
- if (t_embd && n_outputs > 0) {
- ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
- GGML_ASSERT(backend_embd != nullptr);
+ if (!backend_has_sampled) {
+ auto * t_logits = res->get_logits();
+ auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
- switch (cparams.pooling_type) {
- case LLAMA_POOLING_TYPE_NONE:
- {
- // extract token embeddings
- GGML_ASSERT(embd != nullptr);
- float * embd_out = embd + n_outputs_prev*n_embd;
+ if (t_embd && res->get_embd_pooled()) {
+ t_embd = res->get_embd_pooled();
+ }
- if (n_outputs) {
- GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
- GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
- }
- } break;
- case LLAMA_POOLING_TYPE_MEAN:
- case LLAMA_POOLING_TYPE_CLS:
- case LLAMA_POOLING_TYPE_LAST:
- {
- // extract sequence embeddings (cleared before processing each batch)
- auto & embd_seq_out = embd_seq;
-
- for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id_unq[s];
- const int32_t seq_idx = ubatch.seq_idx[seq_id];
-
- embd_seq_out[seq_id].resize(n_embd);
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
- }
- } break;
- case LLAMA_POOLING_TYPE_RANK:
- {
- // extract the rerank score - n_cls_out floats per sequence
- auto & embd_seq_out = embd_seq;
+ // extract logits
+ if (t_logits && n_outputs > 0) {
+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+ GGML_ASSERT(backend_res != nullptr);
+ GGML_ASSERT(logits != nullptr);
- const uint32_t n_cls_out = hparams.n_cls_out;
+ float * logits_out = logits + n_outputs_prev*n_vocab;
- for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id_unq[s];
- const int32_t seq_idx = ubatch.seq_idx[seq_id];
+ if (n_outputs) {
+ GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
+ }
+ }
- embd_seq_out[seq_id].resize(n_cls_out);
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
+ // extract embeddings
+ if (t_embd && n_outputs > 0) {
+ ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+ GGML_ASSERT(backend_embd != nullptr);
+
+ switch (cparams.pooling_type) {
+ case LLAMA_POOLING_TYPE_NONE:
+ {
+ // extract token embeddings
+ GGML_ASSERT(embd != nullptr);
+ float * embd_out = embd + n_outputs_prev*n_embd;
+
+ if (n_outputs) {
+ GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+ }
+ } break;
+ case LLAMA_POOLING_TYPE_MEAN:
+ case LLAMA_POOLING_TYPE_CLS:
+ case LLAMA_POOLING_TYPE_LAST:
+ {
+ // extract sequence embeddings (cleared before processing each batch)
+ auto & embd_seq_out = embd_seq;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+ embd_seq_out[seq_id].resize(n_embd);
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
+ }
+ } break;
+ case LLAMA_POOLING_TYPE_RANK:
+ {
+ // extract the rerank score - n_cls_out floats per sequence
+ auto & embd_seq_out = embd_seq;
+
+ const uint32_t n_cls_out = hparams.n_cls_out;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+ embd_seq_out[seq_id].resize(n_cls_out);
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
+ }
+ } break;
+ case LLAMA_POOLING_TYPE_UNSPECIFIED:
+ {
+ GGML_ABORT("unknown pooling type");
}
- } break;
- case LLAMA_POOLING_TYPE_UNSPECIFIED:
- {
- GGML_ABORT("unknown pooling type");
- }
+ }
}
}
@@ -1306,8 +1649,31 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
has_embd = true;
}
- logits_size = has_logits ? n_vocab*n_outputs_max : 0;
- embd_size = has_embd ? n_embd*n_outputs_max : 0;
+ const bool backend_sampling = !sampling.samplers.empty();
+ size_t backend_float_count = 0;
+ size_t backend_token_count = 0;
+
+ if (!backend_sampling) {
+ logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+ embd_size = has_embd ? n_embd*n_outputs_max : 0;
+
+ // reset backend sampling values.
+ sampling.logits_size = 0;
+ sampling.probs_size = 0;
+ sampling.sampled_size = 0;
+ sampling.candidates_size = 0;
+ } else {
+ logits_size = 0;
+ embd_size = 0;
+
+ sampling.logits_size = n_vocab*n_outputs_max;
+ sampling.probs_size = n_vocab*n_outputs_max;
+ sampling.sampled_size = n_outputs_max;
+ sampling.candidates_size = n_vocab*n_outputs_max;
+
+ backend_float_count = sampling.logits_size + sampling.probs_size;
+ backend_token_count = sampling.sampled_size + sampling.candidates_size;
+ }
if (output_ids.empty()) {
// init, never resized afterwards
@@ -1315,7 +1681,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
}
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
- const size_t new_size = (logits_size + embd_size) * sizeof(float);
+ const size_t new_size = (logits_size + embd_size + backend_float_count) * sizeof(float)
+ + backend_token_count * sizeof(llama_token);
// alloc only when more than the current capacity is required
// TODO: also consider shrinking the buffer
@@ -1346,8 +1713,56 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
- logits = has_logits ? output_base : nullptr;
- embd = has_embd ? output_base + logits_size : nullptr;
+ logits = nullptr;
+ embd = nullptr;
+
+ // reset sampling pointers.
+ sampling.logits = nullptr;
+ sampling.probs = nullptr;
+ sampling.sampled = nullptr;
+ sampling.candidates = nullptr;
+
+ if (!backend_sampling) {
+ logits = has_logits ? output_base : nullptr;
+ embd = has_embd ? output_base + logits_size : nullptr;
+ } else {
+ size_t offset = 0;
+ uint8_t * base = (uint8_t *) output_base;
+
+ if (sampling.logits_size > 0) {
+ sampling.logits = (float *) (base + offset);
+ offset += sampling.logits_size * sizeof(float);
+ }
+ if (sampling.probs_size > 0) {
+ sampling.probs = (float *) (base + offset);
+ offset += sampling.probs_size * sizeof(float);
+ }
+ if (sampling.sampled_size > 0) {
+ sampling.sampled = (llama_token *) (base + offset);
+ offset += sampling.sampled_size * sizeof(llama_token);
+ }
+ if (sampling.candidates_size > 0) {
+ sampling.candidates = (llama_token *) (base + offset);
+ offset += sampling.candidates_size * sizeof(llama_token);
+ }
+
+ const size_t n_rows = (size_t) n_outputs_max;
+ if (sampling.outputs_capacity < n_rows) {
+ sampling.outputs_capacity = n_rows;
+
+ sampling.logits_count.assign(n_rows, 0);
+ sampling.probs_count.assign(n_rows, 0);
+ sampling.candidates_count.assign(n_rows, 0);
+ } else {
+ std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
+ std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
+ std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
+ }
+
+ if (sampling.sampled && sampling.sampled_size > 0) {
+ std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
+ }
+ }
// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1376,6 +1791,38 @@ void llama_context::output_reorder() {
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
}
}
+
+ if (sampling.logits && sampling.logits_size > 0) {
+ for (uint64_t k = 0; k < n_vocab; ++k) {
+ std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
+ }
+ }
+
+ if (sampling.probs && sampling.probs_size > 0) {
+ for (uint64_t k = 0; k < n_vocab; ++k) {
+ std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
+ }
+ }
+
+ if (sampling.candidates && sampling.candidates_size > 0) {
+ for (uint64_t k = 0; k < n_vocab; ++k) {
+ std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
+ }
+ }
+
+ if (sampling.sampled && sampling.sampled_size > 0) {
+ std::swap(sampling.sampled[i0], sampling.sampled[i1]);
+ }
+
+ if (!sampling.logits_count.empty()) {
+ std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
+ }
+ if (!sampling.probs_count.empty()) {
+ std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
+ }
+ if (!sampling.candidates_count.empty()) {
+ std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
+ }
}
output_swaps.clear();
@@ -1452,10 +1899,12 @@ llm_graph_params llama_context::graph_params(
/*.gtype =*/ gtype,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
+ /*.dev_out =*/ model.dev_output(),
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
+ /*.samplers =*/ sampling.samplers,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
/*.res =*/ res,
@@ -2319,6 +2768,8 @@ llama_context_params llama_context_default_params() {
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.kv_unified =*/ false,
+ /*.sampler =*/ nullptr,
+ /*.n_sampler =*/ 0,
};
return result;
@@ -2478,6 +2929,13 @@ float * llama_get_logits(llama_context * ctx) {
float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
ctx->synchronize();
+ if (ctx->get_backend_sampled_token_ith(i) != LLAMA_TOKEN_NULL) {
+ return nullptr;
+ }
+ if (ctx->get_backend_sampled_probs_ith(i) != nullptr) {
+ return nullptr;
+ }
+
return ctx->get_logits_ith(i);
}
@@ -2499,6 +2957,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}
+void llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
+ ctx->set_backend_sampler(seq_id, smpl);
+}
+
+llama_token llama_get_backend_sampled_token_ith(llama_context * ctx, int32_t i) {
+ ctx->synchronize();
+
+ return ctx->get_backend_sampled_token_ith(i);
+}
+
+float * llama_get_backend_sampled_probs_ith(llama_context * ctx, int32_t i) {
+ ctx->synchronize();
+
+ return ctx->get_backend_sampled_probs_ith(i);
+}
+
+float * llama_get_backend_sampled_logits_ith(llama_context * ctx, int32_t i) {
+ ctx->synchronize();
+
+ return ctx->get_backend_sampled_logits_ith(i);
+}
+
+llama_token * llama_get_backend_sampled_candidates_ith(llama_context * ctx, int32_t i) {
+ ctx->synchronize();
+
+ return const_cast<llama_token *>(ctx->get_backend_sampled_candidates_ith(i));
+}
+
+uint32_t llama_get_backend_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
+ ctx->synchronize();
+
+ return static_cast<uint32_t>(ctx->get_backend_sampled_candidates_count(i));
+}
+
+uint32_t llama_get_backend_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
+ ctx->synchronize();
+
+ return static_cast<uint32_t>(ctx->get_backend_sampled_logits_count(i));
+}
+
+uint32_t llama_get_backend_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
+ ctx->synchronize();
+
+ return static_cast<uint32_t>(ctx->get_backend_sampled_probs_count(i));
+}
+
// llama adapter API
int32_t llama_set_adapter_lora(
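Review note: to make the new public API above concrete, here is a minimal, hedged sketch of how an application could use it. It attaches a sampler chain to a sequence with llama_set_backend_sampler, decodes, and then prefers a backend-sampled token, falling back to the regular chain otherwise. The setup lines in the usage comment are hypothetical; whether the chain actually runs on the backend depends on its samplers implementing apply_ggml, which is not shown for the built-in samplers in this excerpt.

```cpp
// Sketch only: assumes a loaded model/context and a prepared batch.
// Error handling and batching details are omitted for brevity.
#include "llama.h"

static llama_token sample_next(llama_context * ctx, llama_sampler * chain, int32_t idx) {
    // Preferred path: a backend dist sampler already picked a token.
    const llama_token tok = llama_get_backend_sampled_token_ith(ctx, idx);
    if (tok != LLAMA_TOKEN_NULL) {
        return tok;
    }
    // Otherwise fall back to the regular chain; llama_sampler_sample will
    // transparently consume backend-produced candidates/probs when present.
    return llama_sampler_sample(chain, ctx, idx);
}

// usage (hypothetical setup):
//   llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
//   llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
//   llama_sampler_chain_add(chain, llama_sampler_init_dist(1234));
//   llama_set_backend_sampler(ctx, /*seq_id=*/0, chain);   // new API from this patch
//   llama_decode(ctx, batch);
//   llama_token next = sample_next(ctx, chain, /*output idx=*/0);
```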
diff --git a/src/llama-context.h b/src/llama-context.h
index 20cbd78955412..2bdbf8a55326b 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -66,6 +66,18 @@ struct llama_context {
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
+ llama_token * get_backend_sampled_tokens();
+ llama_token get_backend_sampled_token_ith(int32_t idx);
+
+ float * get_backend_sampled_logits_ith(int32_t idx);
+ size_t get_backend_sampled_logits_count(int32_t idx);
+
+ float * get_backend_sampled_probs_ith(int32_t idx);
+ size_t get_backend_sampled_probs_count(int32_t idx);
+
+ const llama_token * get_backend_sampled_candidates_ith(int32_t idx);
+ size_t get_backend_sampled_candidates_count(int32_t idx);
+
void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch);
@@ -191,6 +203,7 @@ struct llama_context {
uint32_t output_reserve(int32_t n_outputs);
void output_reorder();
+ int64_t resolve_output_row(int32_t i) const;
//
// graph
@@ -208,6 +221,8 @@ struct llama_context {
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
+ void set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler);
+
private:
llm_graph_params graph_params(
llm_graph_result * res,
@@ -242,6 +257,31 @@ struct llama_context {
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;
+ struct sampling_info {
+ std::unordered_map<llama_seq_id, llama_sampler *> samplers;
+
+ float * logits = nullptr;
+ size_t logits_size = 0;
+
+ llama_token * sampled = nullptr;
+ size_t sampled_size = 0;
+
+ float * probs = nullptr;
+ size_t probs_size = 0;
+
+ llama_token * candidates = nullptr;
+ size_t candidates_size = 0;
+
+ size_t outputs_capacity = 0;
+ std::vector logits_count;
+ std::vector probs_count;
+ std::vector candidates_count;
+
+ std::vector token_ids_full_vocab;
+ };
+
+ sampling_info sampling;
+
// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
size_t embd_size = 0; // capacity (of floats) for embeddings
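Review note: the sampling_info struct above is backed by the single output buffer sized in output_reserve(); when backend sampling is enabled, the logits/embd areas are replaced by per-output logits, probs, sampled-token and candidate-id regions. Below is a small, self-contained sketch of that size/offset arithmetic with illustrative numbers (not taken from the patch), to make the layout easier to follow.

```cpp
// Rough sketch of the output buffer layout used by output_reserve() when
// backend sampling is active (names mirror the patch; numbers are made up).
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    const size_t n_vocab       = 32000;   // hypothetical vocab size
    const size_t n_outputs_max = 8;       // hypothetical max outputs per batch

    const size_t logits_size     = n_vocab * n_outputs_max; // float
    const size_t probs_size      = n_vocab * n_outputs_max; // float
    const size_t sampled_size    =           n_outputs_max; // llama_token (int32)
    const size_t candidates_size = n_vocab * n_outputs_max; // llama_token (int32)

    const size_t new_size = (logits_size + probs_size) * sizeof(float)
                          + (sampled_size + candidates_size) * sizeof(int32_t);

    // pointers are carved out of one backend buffer in this order:
    size_t offset = 0;
    const size_t off_logits     = offset; offset += logits_size     * sizeof(float);
    const size_t off_probs      = offset; offset += probs_size      * sizeof(float);
    const size_t off_sampled    = offset; offset += sampled_size    * sizeof(int32_t);
    const size_t off_candidates = offset; offset += candidates_size * sizeof(int32_t);

    printf("total %zu bytes; logits@%zu probs@%zu sampled@%zu candidates@%zu\n",
           new_size, off_logits, off_probs, off_sampled, off_candidates);
    return 0;
}
```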
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 650e40ec6ffce..8af9188d05da7 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -462,6 +462,28 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
inp_rs->set_input(ubatch);
}
+void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
+ GGML_UNUSED(ubatch);
+ for (const auto & [seq_id, sampler] : samplers) {
+ if (sampler->iface->set_input_ggml) {
+ sampler->iface->set_input_ggml(sampler);
+ }
+ }
+}
+
+bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
+ if (params.samplers.empty()) {
+ return true;
+ }
+
+ for (const auto & [seq_id, sampler] : params.samplers) {
+ if (samplers[seq_id] != sampler) {
+ return false;
+ }
+ }
+ return true;
+}
+
//
// llm_graph_result
//
@@ -482,6 +504,10 @@ void llm_graph_result::reset() {
t_logits = nullptr;
t_embd = nullptr;
t_embd_pooled = nullptr;
+ t_sampled.clear();
+ t_sampled_probs.clear();
+ t_sampled_logits.clear();
+ t_candidates.clear();
params = {};
@@ -583,10 +609,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
rope_type (hparams.rope_type),
sched (params.sched),
backend_cpu (params.backend_cpu),
+ dev_out (params.dev_out),
cvec (params.cvec),
loras (params.loras),
mctx (params.mctx),
cross (params.cross),
+ samplers (params.samplers),
cb_func (params.cb),
res (params.res),
ctx0 (res->get_ctx()),
@@ -2021,6 +2049,100 @@ void llm_graph_context::build_pooling(
ggml_build_forward_expand(gf, cur);
}
+void llm_graph_context::build_sampling() const {
+ if (samplers.empty()) {
+ return;
+ }
+
+ std::unordered_map<llama_seq_id, int32_t> seq_to_logit_row;
+ int32_t logit_row_idx = 0;
+
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+ if (ubatch.output[i]) {
+ llama_seq_id seq_id = ubatch.seq_id[i][0];
+ seq_to_logit_row[seq_id] = logit_row_idx;
+ logit_row_idx++;
+ }
+ }
+ if (seq_to_logit_row.empty()) {
+ return;
+ }
+
+ // res->t_logits will contain logits for all tokens that specified that they
+ // want logits calculated (logits=1 or output=1)
+ ggml_tensor * logits_t = res->t_logits;
+ GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
+
+ const int64_t n_vocab = logits_t->ne[0];
+
+ ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev_out);
+
+ std::unordered_map<llama_seq_id, llama_sampler *> active_samplers;
+
+ for (const auto & [seq_id, sampler] : samplers) {
+ // Only process samplers for sequences that are in the current batch
+ auto it = seq_to_logit_row.find(seq_id);
+ if (it == seq_to_logit_row.end()) {
+ continue;
+ }
+ const int32_t row_idx = it->second;
+
+ // Allow GPU sampler to create input tensors by implementing init_ggml.
+ if (sampler->iface->init_ggml != nullptr) {
+ sampler->iface->init_ggml(sampler, buft);
+ }
+
+ active_samplers[seq_id] = sampler;
+
+ ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]);
+ ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
+
+ struct llama_sampler_ggml_data ggml_data = {
+ /*.logits =*/ logits_seq,
+ /*.probs =*/ nullptr,
+ /*.sampled =*/ nullptr,
+ /*.candidates =*/ nullptr,
+ };
+
+ llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data);
+
+ if (ggml_data.sampled != nullptr) {
+ res->t_sampled[seq_id] = ggml_data.sampled;
+ ggml_build_forward_expand(gf, ggml_data.sampled);
+ }
+
+ if (ggml_data.probs != nullptr) {
+ res->t_sampled_probs[seq_id] = ggml_data.probs;
+ ggml_build_forward_expand(gf, ggml_data.probs);
+ }
+
+ if (ggml_data.logits != logits_seq) {
+ res->t_sampled_logits[seq_id] = ggml_data.logits;
+ ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
+ }
+
+ if (ggml_data.candidates != nullptr) {
+ res->t_candidates[seq_id] = ggml_data.candidates;
+ ggml_build_forward_expand(gf, ggml_data.candidates);
+ }
+ }
+
+ // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
+ /*
+ for (const auto & [seq_id, sampler] : samplers) {
+ if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
+ ggml_tensor * selected_token = it->second;
+ if (selected_token != nullptr) {
+ llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
+ }
+ }
+ }
+ */
+
+ auto inp_sampling = std::make_unique<llm_graph_input_sampling>(n_vocab, false, active_samplers);
+ res->add_input(std::move(inp_sampling));
+}
+
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
// TODO move to hparams if a T5 variant appears that uses a different value
const int64_t max_distance = 128;
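Review note: build_sampling() hands each per-sequence sampler a one-row view of its logits and lets the sampler attach whatever tensors it produces (sampled token, probs, filtered logits, candidate ids) to llama_sampler_ggml_data. As an illustration, here is a hedged sketch of what a greedy backend sampler's apply_ggml hook could look like; the struct layouts are assumed from the call sites shown above rather than from the patched headers.

```cpp
// Hedged sketch of a backend (ggml) greedy sampler's apply_ggml hook, matching
// the call shape used by build_sampling() in this patch. The exact
// llama_sampler_i / llama_sampler_ggml_data definitions live in the headers
// touched by this patch and are assumed here.
#include "llama.h"
#include "ggml.h"

static void greedy_apply_ggml(
        struct llama_sampler           * smpl,
        struct ggml_context            * ctx,
        struct ggml_cgraph             * gf,
        struct llama_sampler_ggml_data * data) {
    GGML_UNUSED(smpl);
    GGML_UNUSED(gf); // build_sampling() expands the outputs it receives back

    // data->logits is a one-row [n_vocab] view of this sequence's logits.
    // Argmax over the vocab dimension yields a single I32 token index, which
    // for full-vocab logits is the token id itself.
    data->sampled = ggml_argmax(ctx, data->logits);
    ggml_set_name(data->sampled, "greedy_sampled");

    // Leaving data->probs / data->candidates unset means the host copies only
    // the sampled token instead of probabilities or candidate ids.
}
```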
diff --git a/src/llama-graph.h b/src/llama-graph.h
index d0c3934f67927..6797d78a20e38 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -383,6 +383,24 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
const llama_memory_hybrid_context * mctx;
};
+class llm_graph_input_sampling : public llm_graph_input_i {
+public:
+ llm_graph_input_sampling(int32_t n_vocab, bool sorted,
+ std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
+ n_vocab(n_vocab), sorted_value(sorted), samplers(samplers) { }
+ virtual ~llm_graph_input_sampling() = default;
+
+ void set_input(const llama_ubatch * ubatch) override;
+ bool can_reuse(const llm_graph_params & params) override;
+
+ int32_t n_vocab;
+ bool sorted_value;
+ ggml_tensor * size = nullptr; // I32 [1]
+ ggml_tensor * sorted = nullptr; // I32 [1]
+
+ std::unordered_map<llama_seq_id, llama_sampler *> samplers;
+};
+
//
// llm_graph_result
//
@@ -410,12 +428,30 @@ struct llm_graph_params {
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;
+ ggml_backend_dev_t dev_out;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
+ std::unordered_map<llama_seq_id, llama_sampler *> samplers;
+
+ static bool samplers_equal(
+ const std::unordered_map<llama_seq_id, llama_sampler *> & lhs,
+ const std::unordered_map<llama_seq_id, llama_sampler *> & rhs) {
+ if (lhs.size() != rhs.size()) {
+ return false;
+ }
+ for (const auto & [seq_id, sampler] : lhs) {
+ auto it = rhs.find(seq_id);
+ if (it == rhs.end() || it->second != sampler) {
+ return false;
+ }
+ }
+ return true;
+ }
+
uint32_t n_outputs;
llm_graph_cb cb;
@@ -463,7 +499,9 @@ struct llm_graph_params {
cvec == other.cvec &&
loras == other.loras &&
cross == other.cross &&
- n_outputs == other.n_outputs;
+ n_outputs == other.n_outputs &&
+ samplers_equal(samplers, other.samplers);
+
}
};
@@ -504,6 +542,11 @@ class llm_graph_result {
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
+ std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled_logits;
+ std::unordered_map<llama_seq_id, ggml_tensor *> t_candidates;
+ std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled;
+ std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled_probs;
+
std::vector inputs;
ggml_context_ptr ctx_compute;
@@ -574,11 +617,15 @@ struct llm_graph_context {
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+ ggml_backend_dev_t dev_out;
+
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
+ std::unordered_map<llama_seq_id, llama_sampler *> samplers;
+
const llm_graph_cb & cb_func;
llm_graph_result * res;
@@ -819,6 +866,12 @@ struct llm_graph_context {
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;
+ //
+ // sampling (backend sampling)
+ //
+
+ void build_sampling() const;
+
//
// dense (out)
//
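Review note: graph reuse now also depends on samplers_equal(), which compares by pointer identity rather than by sampler configuration. A small, hedged sketch of that semantic: two maps that point at the same sampler object compare equal, while an equivalently configured but distinct chain forces a graph rebuild.

```cpp
// Illustrative only: pointer identity, not configuration, drives graph reuse.
#include <unordered_map>
#include "llama.h"

static bool samplers_equal(const std::unordered_map<llama_seq_id, llama_sampler *> & lhs,
                           const std::unordered_map<llama_seq_id, llama_sampler *> & rhs) {
    if (lhs.size() != rhs.size()) return false;
    for (const auto & [seq_id, sampler] : lhs) {
        auto it = rhs.find(seq_id);
        if (it == rhs.end() || it->second != sampler) return false;
    }
    return true;
}

int main() {
    llama_sampler * a = llama_sampler_init_greedy();
    llama_sampler * b = llama_sampler_init_greedy();

    std::unordered_map<llama_seq_id, llama_sampler *> m1 = {{0, a}};
    std::unordered_map<llama_seq_id, llama_sampler *> m2 = {{0, a}};
    std::unordered_map<llama_seq_id, llama_sampler *> m3 = {{0, b}};

    const bool reuse_ok  = samplers_equal(m1, m2); // true  - same object
    const bool reuse_bad = samplers_equal(m1, m3); // false - same config, different object

    llama_sampler_free(a);
    llama_sampler_free(b);
    return (reuse_ok && !reuse_bad) ? 0 : 1;
}
```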
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 175549a9e30f1..29fc6bbc63025 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7414,6 +7414,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
+ // add backend sampling layers (if any)
+ llm->build_sampling();
+
// if the gguf model was converted with --sentence-transformers-dense-modules
// there will be two additional dense projection layers
// dense linear projections are applied after pooling
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 3f4a729bc36c7..621438a9cf4d4 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -372,6 +372,39 @@ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_ar
smpl->iface->apply(smpl, cur_p);
}
+void llama_sampler_apply_ggml(
+ struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct llama_sampler_ggml_data * ggml_data) {
+ GGML_ASSERT(smpl->iface->apply_ggml);
+ smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data);
+}
+
+void llama_sampler_accept_ggml(
+ struct llama_sampler * smpl,
+ ggml_context * ctx,
+ ggml_cgraph * gf,
+ struct ggml_tensor * selected_token) {
+ if (smpl->iface->accept_ggml) {
+ smpl->iface->accept_ggml(smpl, ctx, gf, selected_token);
+ }
+}
+
+void llama_sampler_set_input_ggml(struct llama_sampler * smpl) {
+ if (smpl->iface->set_input_ggml) {
+ smpl->iface->set_input_ggml(smpl);
+ }
+}
+
+void llama_sampler_init_ggml(
+ struct llama_sampler * smpl,
+ ggml_backend_buffer_type_t buft) {
+ if (smpl->iface->init_ggml) {
+ smpl->iface->init_ggml(smpl, buft);
+ }
+}
+
void llama_sampler_reset(struct llama_sampler * smpl) {
if (smpl->iface->reset) {
smpl->iface->reset(smpl);
@@ -406,7 +439,15 @@ void llama_sampler_free(struct llama_sampler * smpl) {
}
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
- const auto * logits = llama_get_logits_ith(ctx, idx);
+ const llama_token sampled_token = llama_get_backend_sampled_token_ith(ctx, idx);
+ const float * sampled_probs = llama_get_backend_sampled_probs_ith(ctx, idx);
+ const float * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx);
+ const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx);
+
+ // If a backend sampler has already sampled a token, return it.
+ if (sampled_token != LLAMA_TOKEN_NULL) {
+ return sampled_token;
+ }
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -415,9 +456,26 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
// TODO: do not allocate each time
std::vector<llama_token_data> cur;
- cur.reserve(n_vocab);
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+
+ if (sampled_probs) {
+ const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
+ cur.resize(sampled_probs_count);
+ for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
+ }
+ } else if (sampled_logits) {
+ const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
+ cur.resize(sampled_logits_count);
+ for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
+ }
+ } else {
+ const auto * logits = llama_get_logits_ith(ctx, idx);
+ GGML_ASSERT(logits != nullptr);
+ cur.resize(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+ }
}
llama_token_data_array cur_p = {
@@ -462,6 +520,10 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
time_meas tm(chain->t_sample_us, chain->params.no_perf);
for (auto * smpl : chain->samplers) {
+ // Skip backend-only samplers - they have apply_ggml but no apply
+ if (smpl->iface->apply == nullptr) {
+ continue;
+ }
llama_sampler_apply(smpl, cur_p);
}
}
@@ -496,13 +558,67 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
delete chain;
}
+static void llama_sampler_chain_apply_ggml(
+ struct llama_sampler * smpl,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct llama_sampler_ggml_data * ggml_data) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ for (auto * smpl : chain->samplers) {
+ if (smpl->iface->apply_ggml) {
+ smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data);
+ }
+ }
+}
+
+static void llama_sampler_chain_accept_ggml(
+ struct llama_sampler * smpl,
+ ggml_context * ctx,
+ ggml_cgraph * gf,
+ struct ggml_tensor * selected_token) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ for (auto * smpl : chain->samplers) {
+ if (smpl->iface->accept_ggml) {
+ smpl->iface->accept_ggml(smpl, ctx, gf, selected_token);
+ }
+ }
+}
+
+static void llama_sampler_chain_set_input_ggml(struct llama_sampler * smpl) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ for (auto * smpl : chain->samplers) {
+ if (smpl->iface->set_input_ggml) {
+ smpl->iface->set_input_ggml(smpl);
+ }
+ }
+}
+
+static void llama_sampler_chain_set_backend_context(
+ struct llama_sampler * smpl,
+ ggml_backend_buffer_type_t buft) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ for (auto * smpl : chain->samplers) {
+ if (smpl->iface->init_ggml) {
+ smpl->iface->init_ggml(smpl, buft);
+ }
+ }
+}
+
static struct llama_sampler_i llama_sampler_chain_i = {
- /* .name = */ llama_sampler_chain_name,
- /* .accept = */ llama_sampler_chain_accept,
- /* .apply = */ llama_sampler_chain_apply,
- /* .reset = */ llama_sampler_chain_reset,
- /* .clone = */ llama_sampler_chain_clone,
- /* .free = */ llama_sampler_chain_free,
+ /* .name = */ llama_sampler_chain_name,
+ /* .accept = */ llama_sampler_chain_accept,
+ /* .apply = */ llama_sampler_chain_apply,
+ /* .reset = */ llama_sampler_chain_reset,
+ /* .clone = */ llama_sampler_chain_clone,
+ /* .free = */ llama_sampler_chain_free,
+ /* .apply_ggml = */ llama_sampler_chain_apply_ggml,
+ /* .accept_ggml = */ llama_sampler_chain_accept_ggml,
+ /* .set_input_ggml = */ llama_sampler_chain_set_input_ggml,
+ /* .set_backend_context = */ llama_sampler_chain_set_backend_context,
};
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
@@ -571,12 +687,16 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
}
static struct llama_sampler_i llama_sampler_greedy_i = {
- /* .name = */ llama_sampler_greedy_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_greedy_apply,
- /* .reset = */ nullptr,
- /* .clone = */ nullptr,
- /* .free = */ nullptr,
+ /* .name = */ llama_sampler_greedy_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_greedy_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ nullptr,
+ /* .free = */ nullptr,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_greedy() {
@@ -696,12 +816,16 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_dist_i = {
- /* .name = */ llama_sampler_dist_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_dist_apply,
- /* .reset = */ llama_sampler_dist_reset,
- /* .clone = */ llama_sampler_dist_clone,
- /* .free = */ llama_sampler_dist_free,
+ /* .name = */ llama_sampler_dist_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_dist_apply,
+ /* .reset = */ llama_sampler_dist_reset,
+ /* .clone = */ llama_sampler_dist_clone,
+ /* .free = */ llama_sampler_dist_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
@@ -741,12 +865,16 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_top_k_i = {
- /* .name = */ llama_sampler_top_k_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_top_k_apply,
- /* .reset = */ nullptr,
- /* .clone = */ llama_sampler_top_k_clone,
- /* .free = */ llama_sampler_top_k_free,
+ /* .name = */ llama_sampler_top_k_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_top_k_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_top_k_clone,
+ /* .free = */ llama_sampler_top_k_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
@@ -836,12 +964,16 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_top_p_i = {
- /* .name = */ llama_sampler_top_p_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_top_p_apply,
- /* .reset = */ nullptr,
- /* .clone = */ llama_sampler_top_p_clone,
- /* .free = */ llama_sampler_top_p_free,
+ /* .name = */ llama_sampler_top_p_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_top_p_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_top_p_clone,
+ /* .free = */ llama_sampler_top_p_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
@@ -930,12 +1062,16 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_min_p_i = {
- /* .name = */ llama_sampler_min_p_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_min_p_apply,
- /* .reset = */ nullptr,
- /* .clone = */ llama_sampler_min_p_clone,
- /* .free = */ llama_sampler_min_p_free,
+ /* .name = */ llama_sampler_min_p_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_min_p_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_min_p_clone,
+ /* .free = */ llama_sampler_min_p_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
@@ -1029,12 +1165,16 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_typical_i = {
- /* .name = */ llama_sampler_typical_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_typical_apply,
- /* .reset = */ nullptr,
- /* .clone = */ llama_sampler_typical_clone,
- /* .free = */ llama_sampler_typical_free,
+ /* .name = */ llama_sampler_typical_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_typical_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_typical_clone,
+ /* .free = */ llama_sampler_typical_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
@@ -1073,12 +1213,16 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_temp_i = {
- /* .name = */ llama_sampler_temp_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_temp_apply,
- /* .reset = */ nullptr,
- /* .clone = */ llama_sampler_temp_clone,
- /* .free = */ llama_sampler_temp_free,
+ /* .name = */ llama_sampler_temp_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_temp_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_temp_clone,
+ /* .free = */ llama_sampler_temp_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_temp(float temp) {
@@ -1183,12 +1327,16 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_temp_ext_i = {
- /* .name = */ llama_sampler_temp_ext_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_temp_ext_apply,
- /* .reset = */ nullptr,
- /* .clone = */ llama_sampler_temp_ext_clone,
- /* .free = */ llama_sampler_temp_ext_free,
+ /* .name = */ llama_sampler_temp_ext_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_temp_ext_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_temp_ext_clone,
+ /* .free = */ llama_sampler_temp_ext_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
@@ -1277,12 +1425,16 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_xtc_i = {
- /* .name = */ llama_sampler_xtc_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sample_xtc_apply,
- /* .reset = */ llama_sampler_xtc_reset,
- /* .clone = */ llama_sampler_xtc_clone,
- /* .free = */ llama_sampler_xtc_free,
+ /* .name = */ llama_sampler_xtc_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sample_xtc_apply,
+ /* .reset = */ llama_sampler_xtc_reset,
+ /* .clone = */ llama_sampler_xtc_clone,
+ /* .free = */ llama_sampler_xtc_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
@@ -1385,12 +1537,16 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_mirostat_i = {
- /* .name = */ llama_sampler_mirostat_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_mirostat_apply,
- /* .reset = */ llama_sampler_mirostat_reset,
- /* .clone = */ llama_sampler_mirostat_clone,
- /* .free = */ llama_sampler_mirostat_free,
+ /* .name = */ llama_sampler_mirostat_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_mirostat_apply,
+ /* .reset = */ llama_sampler_mirostat_reset,
+ /* .clone = */ llama_sampler_mirostat_clone,
+ /* .free = */ llama_sampler_mirostat_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
@@ -1484,12 +1640,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
- /* .name = */ llama_sampler_mirostat_v2_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_mirostat_v2_apply,
- /* .reset = */ llama_sampler_mirostat_v2_reset,
- /* .clone = */ llama_sampler_mirostat_v2_clone,
- /* .free = */ llama_sampler_mirostat_v2_free,
+ /* .name = */ llama_sampler_mirostat_v2_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_mirostat_v2_apply,
+ /* .reset = */ llama_sampler_mirostat_v2_reset,
+ /* .clone = */ llama_sampler_mirostat_v2_clone,
+ /* .free = */ llama_sampler_mirostat_v2_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
@@ -1601,12 +1761,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_grammar_i = {
- /* .name = */ llama_sampler_grammar_name,
- /* .accept = */ llama_sampler_grammar_accept_impl,
- /* .apply = */ llama_sampler_grammar_apply,
- /* .reset = */ llama_sampler_grammar_reset,
- /* .clone = */ llama_sampler_grammar_clone,
- /* .free = */ llama_sampler_grammar_free,
+ /* .name = */ llama_sampler_grammar_name,
+ /* .accept = */ llama_sampler_grammar_accept_impl,
+ /* .apply = */ llama_sampler_grammar_apply,
+ /* .reset = */ llama_sampler_grammar_reset,
+ /* .clone = */ llama_sampler_grammar_clone,
+ /* .free = */ llama_sampler_grammar_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
static struct llama_sampler * llama_sampler_init_grammar_impl(
@@ -1808,12 +1972,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_penalties_i = {
- /* .name = */ llama_sampler_penalties_name,
- /* .accept = */ llama_sampler_penalties_accept,
- /* .apply = */ llama_sampler_penalties_apply,
- /* .reset = */ llama_sampler_penalties_reset,
- /* .clone = */ llama_sampler_penalties_clone,
- /* .free = */ llama_sampler_penalties_free,
+ /* .name = */ llama_sampler_penalties_name,
+ /* .accept = */ llama_sampler_penalties_accept,
+ /* .apply = */ llama_sampler_penalties_apply,
+ /* .reset = */ llama_sampler_penalties_reset,
+ /* .clone = */ llama_sampler_penalties_clone,
+ /* .free = */ llama_sampler_penalties_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_penalties(
@@ -1899,12 +2067,16 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
- /* .name = */ llama_sampler_top_n_sigma_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_top_n_sigma_apply,
- /* .reset = */ nullptr,
- /* .clone = */ llama_sampler_top_n_sigma_clone,
- /* .free = */ llama_sampler_top_n_sigma_free,
+ /* .name = */ llama_sampler_top_n_sigma_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_top_n_sigma_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_top_n_sigma_clone,
+ /* .free = */ llama_sampler_top_n_sigma_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
@@ -2229,12 +2401,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_dry_i = {
- /* .name = */ llama_sampler_dry_name,
- /* .accept = */ llama_sampler_dry_accept,
- /* .apply = */ llama_sampler_dry_apply,
- /* .reset = */ llama_sampler_dry_reset,
- /* .clone = */ llama_sampler_dry_clone,
- /* .free = */ llama_sampler_dry_free,
+ /* .name = */ llama_sampler_dry_name,
+ /* .accept = */ llama_sampler_dry_accept,
+ /* .apply = */ llama_sampler_dry_apply,
+ /* .reset = */ llama_sampler_dry_reset,
+ /* .clone = */ llama_sampler_dry_clone,
+ /* .free = */ llama_sampler_dry_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
@@ -2370,12 +2546,16 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_logit_bias_i = {
- /* .name = */ llama_sampler_logit_bias_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_logit_bias_apply,
- /* .reset = */ nullptr,
- /* .clone = */ llama_sampler_logit_bias_clone,
- /* .free = */ llama_sampler_logit_bias_free,
+ /* .name = */ llama_sampler_logit_bias_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_logit_bias_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_logit_bias_clone,
+ /* .free = */ llama_sampler_logit_bias_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_logit_bias(
@@ -2600,12 +2780,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_infill_i = {
- /* .name = */ llama_sampler_infill_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_infill_apply,
- /* .reset = */ nullptr,
- /* .clone = */ llama_sampler_infill_clone,
- /* .free = */ llama_sampler_infill_free,
+ /* .name = */ llama_sampler_infill_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_infill_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_infill_clone,
+ /* .free = */ llama_sampler_infill_free,
+ /* .apply_ggml = */ nullptr,
+ /* .accept_ggml = */ nullptr,
+ /* .set_input_ggml = */ nullptr,
+ /* .set_backend_context = */ nullptr,
};
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
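Review note: the reworked llama_sampler_sample() above covers the hybrid case where the backend produces candidate ids plus probabilities (or filtered logits) but no final token. Callers can also consume those buffers directly; below is a hedged sketch that picks the highest-probability candidate on the host, mirroring what llama_sampler_sample() now does internally. It assumes a decode has completed and that `idx` is a valid output index.

```cpp
// Illustrative only: reading backend-produced candidates/probabilities directly.
#include "llama.h"

static llama_token pick_max_prob(llama_context * ctx, int32_t idx) {
    const float       * probs = llama_get_backend_sampled_probs_ith(ctx, idx);
    const llama_token * ids   = llama_get_backend_sampled_candidates_ith(ctx, idx);
    if (probs == nullptr || ids == nullptr) {
        return LLAMA_TOKEN_NULL; // no backend probabilities for this output
    }

    const uint32_t n = llama_get_backend_sampled_probs_count_ith(ctx, idx);
    if (n == 0) {
        return LLAMA_TOKEN_NULL;
    }

    uint32_t best = 0;
    for (uint32_t i = 1; i < n; ++i) {
        if (probs[i] > probs[best]) {
            best = i;
        }
    }
    // candidate ids map positions in the (possibly truncated) candidate set
    // back to vocabulary token ids
    return ids[best];
}
```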
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index d9cc5e933f4ce..0db8b4bd88845 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -206,6 +206,18 @@ llama_build_and_test(test-backend-ops.cpp)
llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
llama_build_and_test(test-autorelease.cpp LABEL "model")
+llama_build_and_test(test-backend-sampler.cpp LABEL "model")
+target_include_directories(test-backend-sampler PRIVATE ${PROJECT_SOURCE_DIR}/src)
+llama_test(test-backend-sampler NAME test-backend-sampler-greedy ARGS --test greedy)
+llama_test(test-backend-sampler NAME test-backend-sampler-temp ARGS --test temp)
+llama_test(test-backend-sampler NAME test-backend-sampler-top_k ARGS --test top_k)
+llama_test(test-backend-sampler NAME test-backend-sampler-dist ARGS --test dist)
+llama_test(test-backend-sampler NAME test-backend-sampler-dist-and-cpu ARGS --test dist_and_cpu)
+llama_test(test-backend-sampler NAME test-backend-sampler-logit-bias ARGS --test logit_bias)
+llama_test(test-backend-sampler NAME test-backend-sampler-mul_seq ARGS --test multi_sequence)
+llama_test(test-backend-sampler NAME test-backend-sampler-set-sampler ARGS --test set_sampler)
+
+
if (NOT GGML_BACKEND_DL)
# these tests use the backends directly and cannot be built with dynamic loading
llama_build_and_test(test-barrier.cpp)
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 2bb4b12224798..8f96250ddc570 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7906,6 +7906,8 @@ static std::vector> make_test_cases_perf() {
}
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
+ test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1}));
+ test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));
return test_cases;
}
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
new file mode 100644
index 0000000000000..2ed13688c98ac
--- /dev/null
+++ b/tests/test-backend-sampler.cpp
@@ -0,0 +1,811 @@
+#include "ggml.h"
+#include "llama.h"
+#include "get-model.h"
+#include "common.h"
+
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include
+#include
+#include
+#include
{/if}
{:else if field.type === 'checkbox'}
- {@const isDisabled = field.key === 'pdfAsImage' && !supportsVision()}
+ {@const pdfDisabled = field.key === 'pdfAsImage' && !supportsVision()}
+ {@const backendDistDisabled = field.key === 'backend_dist' && !localConfig.backend_sampling}
+ {@const isDisabled = pdfDisabled || backendDistDisabled}
{field.help || SETTING_CONFIG_INFO[field.key]}
- {:else if field.key === 'pdfAsImage' && !supportsVision()}
+ {:else if pdfDisabled}
PDF-to-image processing requires a vision-capable model. PDFs will be processed as
text.
+ {:else if backendDistDisabled}
+
+ Enable backend sampling to allow backend distribution sampling.
+
{/if}
diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts
index c25ea23f37be3..672b8e9847b45 100644
--- a/tools/server/webui/src/lib/constants/settings-config.ts
+++ b/tools/server/webui/src/lib/constants/settings-config.ts
@@ -18,6 +18,8 @@ export const SETTING_CONFIG_DEFAULT: Record =
modelSelectorEnabled: false,
// make sure these default values are in sync with `common.h`
samplers: 'top_k;typ_p;top_p;min_p;temperature',
+ backend_sampling: false,
+ backend_dist: false,
temperature: 0.8,
dynatemp_range: 0.0,
dynatemp_exponent: 1.0,
@@ -51,6 +53,10 @@ export const SETTING_CONFIG_INFO: Record = {
'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.',
samplers:
'The order at which samplers are applied, in simplified way. Default is "top_k;typ_p;top_p;min_p;temperature": top_k->typ_p->top_p->min_p->temperature',
+ backend_sampling:
+ 'Enable backend-based samplers. When enabled, supported samplers run on the accelerator backend for faster sampling.',
+ backend_dist:
+ 'Perform the final distribution sampling step on the backend. Requires backend sampling to be enabled.',
temperature:
'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.',
dynatemp_range:
diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts
index aa83910b27f53..5cda8c886865f 100644
--- a/tools/server/webui/src/lib/services/chat.ts
+++ b/tools/server/webui/src/lib/services/chat.ts
@@ -98,6 +98,8 @@ export class ChatService {
dry_penalty_last_n,
// Other parameters
samplers,
+ backend_sampling,
+ backend_dist,
custom,
timings_per_token
} = options;
@@ -182,6 +184,9 @@ export class ChatService {
: samplers;
}
+ if (backend_sampling !== undefined) requestBody.backend_sampling = backend_sampling;
+ if (backend_dist !== undefined) requestBody.backend_dist = backend_dist;
+
if (timings_per_token !== undefined) requestBody.timings_per_token = timings_per_token;
if (custom) {
diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts
index c70b9580cb75b..e00994bc469be 100644
--- a/tools/server/webui/src/lib/stores/chat.svelte.ts
+++ b/tools/server/webui/src/lib/stores/chat.svelte.ts
@@ -298,6 +298,12 @@ class ChatStore {
if (currentConfig.samplers) {
apiOptions.samplers = currentConfig.samplers;
}
+ if (currentConfig.backend_sampling !== undefined) {
+ apiOptions.backend_sampling = Boolean(currentConfig.backend_sampling);
+ }
+ if (currentConfig.backend_dist !== undefined) {
+ apiOptions.backend_dist = Boolean(currentConfig.backend_dist);
+ }
if (currentConfig.custom) {
apiOptions.custom = currentConfig.custom;
}
diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts
index 1a8bc64989957..149d4fb118f54 100644
--- a/tools/server/webui/src/lib/types/api.d.ts
+++ b/tools/server/webui/src/lib/types/api.d.ts
@@ -181,6 +181,8 @@ export interface ApiChatCompletionRequest {
dry_penalty_last_n?: number;
// Sampler configuration
samplers?: string[];
+ backend_sampling?: boolean;
+ backend_dist?: boolean;
// Custom parameters (JSON string)
custom?: Record;
timings_per_token?: boolean;
diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts
index b47842b66e619..e68d107faa3bf 100644
--- a/tools/server/webui/src/lib/types/settings.d.ts
+++ b/tools/server/webui/src/lib/types/settings.d.ts
@@ -37,6 +37,8 @@ export interface SettingsChatServiceOptions {
dry_penalty_last_n?: number;
// Sampler configuration
samplers?: string | string[];
+ backend_sampling?: boolean;
+ backend_dist?: boolean;
// Custom parameters
custom?: string;
timings_per_token?: boolean;