20 commits
7884b0e
sampling : add support for backend sampling
danbev Nov 17, 2025
9fe9a00
llama-cli : add backend sampler configuration
danbev Nov 17, 2025
f1f3e68
server : add backend sampling options/configuration
danbev Nov 17, 2025
a3eb847
webui : add backend sampling options
danbev Nov 17, 2025
67d3b8e
ggml : add initial cumsum implementation for CUDA
danbev Nov 17, 2025
71574f9
sampling : enable all backend sampler tests
danbev Nov 18, 2025
4b52e59
graph : do not include llama-model.h
ggerganov Nov 18, 2025
82957a9
sampling : always expose sampled_ids
danbev Nov 18, 2025
311c1a3
sampling : ensure at most one output token per seq
danbev Nov 18, 2025
26be108
CUDA: Optimize argsort for gpu-based token sampling
ORippler Nov 18, 2025
0da7e7d
sampling : remove version from sampler chain
danbev Nov 19, 2025
51fee29
sampling : always populate logits for sampled probs
danbev Nov 19, 2025
7e98ebc
sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
d74eb61
squash! sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
38f408c
common : fix regression caused by extra memory allocations during sam…
ggerganov Nov 19, 2025
18ed4d8
squash! sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
0c660e7
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 20, 2025
ed4345b
squash! common : fix regression caused by extra memory allocations du…
danbev Nov 20, 2025
0d28b16
sampling : introduce sampling_info struct
danbev Nov 20, 2025
c162562
sampling : return early if backend sampling is disabled
danbev Nov 21, 2025
15 changes: 15 additions & 0 deletions common/arg.cpp
@@ -1501,6 +1501,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
}
).set_sparam());
add_opt(common_arg(
{"--backend-sampling"},
"enable backend sampling (default: disabled)",
[](common_params & params) {
params.sampling.backend_sampling = true;
}
).set_sparam());
add_opt(common_arg(
{"--backend-dist"},
"perform final (distribution) sampling on backend (default: disabled)",
[](common_params & params) {
params.sampling.backend_dist = true;
params.sampling.backend_sampling = true;
}
).set_sparam());
add_opt(common_arg(
{"--pooling"}, "{none,mean,cls,last,rank}",
"pooling type for embeddings, use model default if unspecified",
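For orientation, a minimal sketch (not part of the diff) of what the two new flags toggle; the field names come from the common.h hunk below, and note that --backend-dist also turns on backend sampling:

    common_params params;
    params.sampling.backend_sampling = true;  // --backend-sampling
    params.sampling.backend_dist     = true;  // --backend-dist (the parser also sets backend_sampling)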
3 changes: 3 additions & 0 deletions common/common.cpp
@@ -8,6 +8,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"

#include <algorithm>
#include <cinttypes>
@@ -956,6 +957,8 @@ struct common_init_result common_init_from_params(common_params & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);

auto cparams = common_context_params_to_llama(params);
cparams.samplers = params.backend_samplers;
cparams.n_samplers = params.n_backend_samplers;

llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
7 changes: 7 additions & 0 deletions common/common.h
@@ -188,6 +188,10 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

// Backend sampling flags
bool backend_sampling = false; // enable backend sampling
bool backend_dist = false; // backend performs final sampling (dist)

// print the parameters into a string
std::string print() const;
};
@@ -512,6 +516,9 @@ struct common_params {
bool has_speculative() const {
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
}

llama_sampler_seq_config * backend_samplers = NULL;
size_t n_backend_samplers = 0;
};

// call once at the start of a program if it uses libcommon
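A hedged sketch of how the new fields are intended to be wired, based on the common.cpp hunk above; how the per-sequence llama_sampler_seq_config entries are built is not shown in this diff and is only assumed here:

    std::vector<llama_sampler_seq_config> seq_cfgs; // one config per sequence, filled elsewhere
    params.backend_samplers   = seq_cfgs.empty() ? NULL : seq_cfgs.data();
    params.n_backend_samplers = seq_cfgs.size();
    // common_init_from_params() forwards these as cparams.samplers / cparams.n_samplers
    common_init_result llama_init = common_init_from_params(params);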
16 changes: 10 additions & 6 deletions common/llguidance.cpp
@@ -106,12 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
}

static llama_sampler_i llama_sampler_llg_i = {
/* .name = */ llama_sampler_llg_name,
/* .accept = */ llama_sampler_llg_accept_impl,
/* .apply = */ llama_sampler_llg_apply,
/* .reset = */ llama_sampler_llg_reset,
/* .clone = */ llama_sampler_llg_clone,
/* .free = */ llama_sampler_llg_free,
/* .name = */ llama_sampler_llg_name,
/* .accept = */ llama_sampler_llg_accept_impl,
/* .apply = */ llama_sampler_llg_apply,
/* .reset = */ llama_sampler_llg_reset,
/* .clone = */ llama_sampler_llg_clone,
/* .free = */ llama_sampler_llg_free,
/* .apply_ggml = */ NULL,
/* .accept_ggml = */ NULL,
/* .set_input_ggml = */ NULL,
/* .set_backend_context = */ NULL,
};

static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
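The four new llama_sampler_i members added above are the backend (ggml) hooks; a CPU-only sampler like this one simply leaves them NULL. As a sketch, any other custom CPU sampler follows the same pattern (my_sampler_name and my_sampler_apply are hypothetical placeholders, not functions from this diff):

    static llama_sampler_i my_sampler_i = {
        /* .name                = */ my_sampler_name,  // hypothetical
        /* .accept              = */ NULL,
        /* .apply               = */ my_sampler_apply, // hypothetical
        /* .reset               = */ NULL,
        /* .clone               = */ NULL,
        /* .free                = */ NULL,
        /* .apply_ggml          = */ NULL, // no graph hooks: the sampler keeps running on the CPU
        /* .accept_ggml         = */ NULL,
        /* .set_input_ggml      = */ NULL,
        /* .set_backend_context = */ NULL,
    };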
78 changes: 73 additions & 5 deletions common/sampling.cpp
@@ -113,23 +113,44 @@ struct common_sampler {
llama_token_data_array cur_p;

void set_logits(struct llama_context * ctx, int idx) {
const auto * logits = llama_get_logits_ith(ctx, idx);
const float * sampled_probs = llama_get_backend_sampled_probs_ith (ctx, idx);
const float * sampled_logits = llama_get_backend_sampled_logits_ith (ctx, idx);
const llama_token * sampled_ids = llama_get_backend_sampled_token_ids_ith(ctx, idx);

const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);

const int n_vocab = llama_vocab_n_tokens(vocab);

cur.resize(n_vocab);

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
if (sampled_probs) {
const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
cur.resize(sampled_probs_count);
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
}
} else if (sampled_logits) {
const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
cur.resize(sampled_logits_count);
for (uint32_t i = 0; i < sampled_logits_count; i++) {
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
}
} else {
const auto * logits = llama_get_logits_ith(ctx, idx);
GGML_ASSERT(logits != nullptr);
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
}

cur_p = { cur.data(), cur.size(), -1, false };
}
};

static bool sampler_enabled(const struct common_params_sampling & params, enum common_sampler_type type) {
return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end();
}

std::string common_params_sampling::print() const {
char result[1024];

@@ -287,6 +308,43 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
return result;
}

struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
if (!params.backend_sampling) {
return nullptr;
}
const llama_vocab * vocab = llama_model_get_vocab(model);

llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
chain_params.no_perf = params.no_perf;

struct llama_sampler * chain = llama_sampler_chain_init(chain_params);

const bool enable_temp = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE);
const bool enable_top_k = params.top_k > 0 && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K);
const bool enable_dist = params.backend_dist;

if (!params.logit_bias.empty()) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
}

if (enable_temp) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
}

if (enable_top_k) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
}

if (enable_dist) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
}

return chain;
}

void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
@@ -337,6 +395,16 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}

llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
// Check if a backend sampler has already sampled a token, in which case we
// return that token id directly.
{
const llama_token id = llama_get_backend_sampled_token_ith(ctx, idx);
if (id != LLAMA_TOKEN_NULL) {
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
return id;
}
}

gsmpl->set_logits(ctx, idx);

auto & grmr = gsmpl->grmr;
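Putting the sampling.cpp changes together, a hedged decode-loop sketch (ctx, batch and gsmpl are assumed to exist already; this code is not part of the diff):

    llama_decode(ctx, batch);
    const llama_token tok = common_sampler_sample(gsmpl, ctx, /* idx */ 0, /* grammar_first */ false);
    // If a backend dist sampler already picked the token on the device, the call above returns
    // it directly; otherwise set_logits() pulls the backend-sampled probs/logits (or the full
    // logits) and the CPU sampler chain runs as before.
    common_sampler_accept(gsmpl, tok, /* accept_grammar */ true);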
7 changes: 7 additions & 0 deletions common/sampling.h
@@ -38,6 +38,13 @@ struct common_sampler;

struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);

// Create a backend sampler chain from common sampling parameters
// Returns a llama_sampler chain configured with backend samplers based on the parameters
// This chain can be used per-sequence for backend-based sampling
// Note: Only samplers that have backend equivalents will be added to the chain
// The returned sampler should be freed with llama_sampler_free()
struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params);

void common_sampler_free(struct common_sampler * gsmpl);

// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
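A hedged usage sketch for the declaration above; how the chain is actually attached to a context or sequence is not shown in this hunk and is left as a comment:

    llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling);
    if (backend_chain != nullptr) {
        // ... hand the chain over as part of the per-sequence backend sampler configuration ...
        llama_sampler_free(backend_chain); // the caller owns the chain, per the note above
    }
    // a nullptr return simply means backend sampling is disabled in the parameters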
49 changes: 35 additions & 14 deletions ggml/src/ggml-cuda/argsort.cu
@@ -49,28 +49,49 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
size_t temp_storage_bytes = 0;

if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
stream);
if (nrows == 1) {
DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, stream);
}
} else {
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
sizeof(float) * 8, stream);
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
}

ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();

if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
stream);
if (nrows == 1) {
DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
} else {
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
0, sizeof(float) * 8, stream);
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
stream);
}
}
}
#endif // GGML_CUDA_USE_CUB
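The nrows == 1 fast path above relies on CUB's usual two-call idiom: the first call with a null scratch pointer only reports how much temporary storage is needed, the second call performs the sort. A compact sketch of just that idiom, reusing the buffer names from the hunk (descending order shown here):

    size_t temp_storage_bytes = 0;
    DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes,
                                         temp_keys, temp_keys,  // logits, sorted in-place
                                         temp_indices, dst,     // token indices reordered to match
                                         ncols, 0, sizeof(float) * 8, stream);
    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    DeviceRadixSort::SortPairsDescending(temp_storage_alloc.get(), temp_storage_bytes,
                                         temp_keys, temp_keys,
                                         temp_indices, dst,
                                         ncols, 0, sizeof(float) * 8, stream);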
69 changes: 69 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,69 @@
#include "cumsum.cuh"

#ifdef GGML_CUDA_USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB

#include <cstdint>

__global__ void cumsum_f32_kernel(const float * x, float * dst, int64_t n) {
// TODO: this is a naive implementation just for getting something working.
if (threadIdx.x == 0 && blockIdx.x == 0) {
dst[0] = x[0];
for (int64_t i = 1; i < n; i++) {
dst[i] = dst[i-1] + x[i];
}
}
}

void cumsum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
#ifdef GGML_CUDA_USE_CUB
size_t tmp_size = 0;

// Query how much temp storage CUDA UnBound (CUB) needs
cub::DeviceScan::InclusiveSum(
nullptr, // d_temp_storage (null = just query size)
tmp_size, // reference to size (will be set by CUB)
x, // input pointer
dst, // output pointer
ne, // number of elements
stream // CUDA stream to use
);

ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);

// Perform the inclusive scan
cub::DeviceScan::InclusiveSum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);

#else
GGML_UNUSED(pool);
cumsum_f32_kernel<<<1, 1, 0, stream>>>(x, dst, ne);
#endif // GGML_CUDA_USE_CUB
}

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguously_allocated(src0));

const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;

const int64_t ne0 = src0->ne[0]; // row length (cumsum computed along this dimension)
const int64_t ne1 = src0->ne[1];
const int64_t ne2 = src0->ne[2];
const int64_t ne3 = src0->ne[3];
const int64_t nrows = ne1 * ne2 * ne3; // total number of rows

ggml_cuda_pool & pool = ctx.pool();
cudaStream_t stream = ctx.stream();

for (int64_t i = 0; i < nrows; i++) {
const float * src_row = src0_d + i * ne0;
float * dst_row = dst_d + i * ne0;
cumsum_f32_cuda(pool, src_row, dst_row, ne0, stream);
}
}
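For reference, the op computes an inclusive prefix sum per row, i.e. dst[i] = x[0] + ... + x[i]; presumably this feeds the cumulative distribution used by the backend dist sampler. A host-side check is easy to sketch with the standard library (this snippet is an illustration, not part of the diff):

    #include <numeric>
    #include <vector>

    std::vector<float> x = {0.1f, 0.2f, 0.3f, 0.4f};
    std::vector<float> ref(x.size());
    std::inclusive_scan(x.begin(), x.end(), ref.begin()); // ref == {0.1f, 0.3f, 0.6f, 1.0f}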
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

void cumsum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -19,6 +19,7 @@
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
@@ -2678,6 +2679,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM:
ggml_cuda_op_sum(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
@@ -4123,6 +4127,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
case GGML_OP_CUMSUM:
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ARGSORT: