9 changes: 5 additions & 4 deletions common/sampling.cpp

@@ -190,6 +190,11 @@ static llama_token llama_sampling_sample_impl(
         logits[it->first] += it->second;
     }
 
+    if (ctx_cfg) {
+        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    }
+
     cur.clear();
 
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -198,10 +203,6 @@ static llama_token llama_sampling_sample_impl(
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
-    if (ctx_cfg) {
-        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale);
-    }
-
     // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
     const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
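The call-site change moves guidance from the `cur_p` candidate array to the raw logit buffers, before the candidate list is even built. A condensed sketch of the new call pattern (a sketch only; `ctx_main`, `ctx_cfg`, `idx`, and `params` stand in for the caller's actual state):

// Guidance now happens on the raw logits, before candidate construction.
float * logits = llama_get_logits_ith(ctx_main, idx);
if (ctx_cfg) {
    float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
    // Mutates `logits` in place; cfg_scale == 1.0f leaves them unchanged.
    llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
}
// ... then build the llama_token_data_array from the already-guided logits ...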
56 changes: 38 additions & 18 deletions llama.cpp

@@ -7898,39 +7898,59 @@ static void llama_log_softmax(float * array, size_t size) {
     }
 }
 
+void llama_sample_apply_guidance(
+          struct llama_context * ctx,
+                         float * logits,
+                         float * logits_guidance,
+                         float   scale) {
+    GGML_ASSERT(ctx);
+
+    const auto t_start_sample_us = ggml_time_us();
+    const auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    llama_log_softmax(logits, n_vocab);
+    llama_log_softmax(logits_guidance, n_vocab);
+
+    for (int i = 0; i < n_vocab; ++i) {
+              auto & l = logits[i];
+        const auto & g = logits_guidance[i];
+
+        l = scale * (l - g) + g;
+    }
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+}
+
 void llama_sample_classifier_free_guidance(
           struct llama_context * ctx,
         llama_token_data_array * candidates,
           struct llama_context * guidance_ctx,
                          float   scale) {
-    int64_t t_start_sample_us = ggml_time_us();
-
     GGML_ASSERT(ctx);
+    int64_t t_start_sample_us;
 
-    auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+    t_start_sample_us = ggml_time_us();
+    const size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-    GGML_ASSERT(n_vocab == (int)candidates->size);
+    GGML_ASSERT(n_vocab == candidates->size);
     GGML_ASSERT(!candidates->sorted);
 
-    std::vector<float> logits_base;
-    logits_base.reserve(candidates->size);
-    for (size_t i = 0; i < candidates->size; ++i) {
-        logits_base.push_back(candidates->data[i].logit);
+    std::vector<float> logits_base(n_vocab);
+    for (size_t i = 0; i < n_vocab; ++i) {
+        logits_base[i] = candidates->data[i].logit;
     }
-    llama_log_softmax(logits_base.data(), candidates->size);
 
-    float* logits_guidance = llama_get_logits(guidance_ctx);
-    llama_log_softmax(logits_guidance, n_vocab);
+    float * logits_guidance = llama_get_logits(guidance_ctx);
 
-    for (int i = 0; i < n_vocab; ++i) {
-        float logit_guidance = logits_guidance[i];
-        float logit_base = logits_base[i];
-        candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance;
-    }
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    llama_sample_apply_guidance(ctx, logits_base.data(), logits_guidance, scale);
+    t_start_sample_us = ggml_time_us();
 
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    for (size_t i = 0; i < n_vocab; ++i) {
+        candidates->data[i].logit = logits_base[i];
     }
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
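The blend itself is a log-space interpolation: both arrays are log-softmaxed, then each logit becomes scale * (l - g) + g, so scale == 1.0f reproduces the main distribution and larger values push the result away from the guidance distribution. A self-contained toy sketch of that arithmetic (toy numbers, no llama.cpp dependency; log_softmax here mirrors what llama_log_softmax computes):

#include <cmath>
#include <cstdio>

// Mirrors what llama_log_softmax computes: a[i] = log(softmax(a)[i]).
static void log_softmax(float * a, int n) {
    float max_l = a[0];
    for (int i = 1; i < n; ++i) if (a[i] > max_l) max_l = a[i];
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) sum += std::exp(a[i] - max_l);
    for (int i = 0; i < n; ++i) a[i] = a[i] - max_l - std::log(sum);
}

int main() {
    float logits[3]   = { 2.0f, 0.5f, -1.0f }; // main-context logits (toy values)
    float guidance[3] = { 1.0f, 1.5f, -0.5f }; // negative-prompt-context logits
    const float scale = 1.5f;                  // > 1.0f strengthens the guidance

    log_softmax(logits, 3);
    log_softmax(guidance, 3);
    for (int i = 0; i < 3; ++i) {
        // The same blend llama_sample_apply_guidance performs;
        // with scale == 1.0f the main logits come back unchanged.
        logits[i] = scale * (logits[i] - guidance[i]) + guidance[i];
    }
    for (int i = 0; i < 3; ++i) std::printf("logit[%d] = %.4f\n", i, logits[i]);
    return 0;
}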
17 changes: 12 additions & 5 deletions llama.h

@@ -714,14 +714,21 @@ extern "C" {
                            float   penalty_present);
 
     /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
-    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
-    /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
-    /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    LLAMA_API void llama_sample_classifier_free_guidance(
+    /// @param logits Logits extracted from the original generation context.
+    /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
+    LLAMA_API void llama_sample_apply_guidance(
               struct llama_context * ctx,
+                             float * logits,
+                             float * logits_guidance,
+                             float   scale);
+
+    LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
+              struct llama_context * ctx,
             llama_token_data_array * candidates,
               struct llama_context * guidance_ctx,
-                             float   scale);
+                             float   scale),
+              "use llama_sample_apply_guidance() instead");
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(
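For existing callers the deprecated entry point keeps working but now warns at compile time; migrating means passing the two raw logit arrays instead of a candidate array plus guidance context. A minimal sketch of the replacement pattern, assuming `ctx_main` and `ctx_guidance` were created from the same model and have decoded the same tokens apart from the negative prompt (the helper name and `cfg_scale` are illustrative):

#include "llama.h"
#include <vector>

// Sketch: sample-time guidance via the new API instead of the deprecated
//   llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_guidance, cfg_scale);
static llama_token_data_array guided_candidates(
        llama_context * ctx_main,
        llama_context * ctx_guidance,
        std::vector<llama_token_data> & cur,
        float cfg_scale) {
    float * logits          = llama_get_logits(ctx_main);
    float * logits_guidance = llama_get_logits(ctx_guidance);

    // Blends in place; `logits` holds the guided values afterwards.
    llama_sample_apply_guidance(ctx_main, logits, logits_guidance, cfg_scale);

    // Build the candidate array from the already-guided logits, as usual.
    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
    cur.clear();
    for (llama_token id = 0; id < n_vocab; ++id) {
        cur.push_back({ id, logits[id], 0.0f });
    }
    return { cur.data(), cur.size(), false };
}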