diff --git a/common/speculative.cpp b/common/speculative.cpp
index 3e83b0964c855..41df275c9605a 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -311,6 +311,29 @@ llama_tokens common_speculative_gen_draft(
 
     common_sampler_reset(smpl);
 
+    // read in from an environment variable for now
+    const std::vector<float> batch_costs = []() {
+        std::vector<float> costs;
+        if (const char* env = std::getenv("GGML_BATCH_COSTS")) {
+            for (const char* p = env; *p; ) {
+                char* end;
+                costs.push_back(std::strtof(p, &end));
+                p = *end == ',' ? end + 1 : end;
+            }
+        }
+        return costs;
+    }();
+    GGML_ASSERT(batch_costs.size() >= 2 && "GGML_BATCH_COSTS must have at least 2 values");
+
+    // read in from an environment variable for now (default = 0)
+    const size_t max_look_ahead = std::getenv("GGML_MAX_LOOK_AHEAD") ? atoi(getenv("GGML_MAX_LOOK_AHEAD")) : 0;
+
+    // the current sequence probability, as predicted by the draft
+    float sequence_p = 1.0;
+
+    // the longest draft size we have seen that is +EV
+    size_t best_size = 0;
+
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);
@@ -335,8 +358,11 @@ llama_tokens common_speculative_gen_draft(
             break;
         }
 
-        // only collect very high-confidence draft tokens
-        if (cur_p->data[0].p < params.p_min) {
+        // only collect +EV draft tokens
+        sequence_p *= cur_p->data[0].p;
+        if (sequence_p > batch_costs[std::min(result.size(), batch_costs.size() - 1)]) {
+            best_size = result.size();
+        } else if (sequence_p <= batch_costs[std::min(result.size() + max_look_ahead, batch_costs.size() - 1)]) {
             break;
         }
 
@@ -348,6 +374,9 @@ llama_tokens common_speculative_gen_draft(
         prompt_dft.push_back(id);
     }
 
+    // truncate to the best we saw that was +EV
+    result.resize(best_size);
+
     if (!spec->vocab_dft_compatible) {
         std::string detokenized = common_detokenize(ctx_dft, result, true);
         detokenized = replace_to_tgt(spec, detokenized);
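
Below is a minimal standalone sketch (not part of this patch) of how the +EV acceptance rule behaves. The batch costs and per-token probabilities are made up for illustration; in the patch they come from the GGML_BATCH_COSTS environment variable and the draft model's top candidate (cur_p->data[0].p).

// standalone illustration of the +EV draft-size selection above;
// batch_costs[n] acts as the break-even sequence probability for a
// draft of size n, and all values here are hypothetical
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // the last entry is reused for all larger draft sizes
    const std::vector<float> batch_costs = { 1.00f, 0.50f, 0.35f, 0.25f, 0.10f };

    // hypothetical top-token probabilities from the draft model
    const std::vector<float> token_p = { 0.95f, 0.60f, 0.30f, 0.90f };

    const size_t max_look_ahead = 1; // GGML_MAX_LOOK_AHEAD

    float  sequence_p = 1.0f; // probability the whole draft is accepted
    size_t best_size  = 0;    // longest draft size seen so far that is +EV

    for (size_t size = 1; size <= token_p.size(); ++size) {
        sequence_p *= token_p[size - 1];
        if (sequence_p > batch_costs[std::min(size, batch_costs.size() - 1)]) {
            // a draft of this size is +EV
            best_size = size;
        } else if (sequence_p <= batch_costs[std::min(size + max_look_ahead, batch_costs.size() - 1)]) {
            // even the cheapest cost within the look-ahead window can no
            // longer be beaten - stop drafting
            break;
        }
        printf("size %zu: sequence_p = %.3f, best_size = %zu\n", size, sequence_p, best_size);
    }

    // mirrors result.resize(best_size) in the patch
    printf("final draft size: %zu\n", best_size);
    return 0;
}

With these numbers the dip at size 3 (sequence_p = 0.171, below its break-even cost of 0.25) would have ended drafting under a pure per-size cutoff, but the look-ahead of 1 keeps the loop alive because 0.171 still beats batch_costs[4] = 0.10, and size 4 comes out +EV again, so all four tokens survive the final resize. Running the patch itself with the same thresholds would look like GGML_BATCH_COSTS=1.0,0.5,0.35,0.25,0.1 GGML_MAX_LOOK_AHEAD=1.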