Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 36 additions & 27 deletions examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ int main(int argc, char ** argv) {
}

common_init();
#if 0
if (params.speculative.model.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
Expand Down Expand Up @@ -169,9 +168,9 @@ int main(int argc, char ** argv) {
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
llama_decode_ext(ctx_tgt, batch0);
llama_decode_ext(ctx_tgt, batch1);
llama_decode_ext(ctx_dft, batch2);
llama_decode_ext(ctx_tgt, batch0.get());
llama_decode_ext(ctx_tgt, batch1.get());
llama_decode_ext(ctx_dft, batch2.get());

const auto t_enc_end = ggml_time_us();

Expand Down Expand Up @@ -338,7 +337,7 @@ int main(int argc, char ** argv) {
if (i == s) {
continue;
}
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
// synchronize active status for sequences with the same drafted token
drafts[i].active = drafts[i].active && accept;
if (!drafts[i].active) {
Expand Down Expand Up @@ -446,7 +445,7 @@ int main(int argc, char ** argv) {

llama_batch_ext_clear(batch_dft);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true);
llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true);

llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
Expand Down Expand Up @@ -475,13 +474,19 @@ int main(int argc, char ** argv) {
drafts[0].drafting = true;
drafts[0].i_batch_dft = 0;

llama_batch_ext_clear(batch_tgt);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
// one token staged for the target-model batch
// entries are collected in batch_tgt_data and later handed, field by field,
// to llama_batch_ext_add_text when the target batch is built
struct batch_info {
// token id
llama_token id;
// position passed along with the token when it is added to the batch
llama_pos pos;
// sequence ids the token is assigned to; another id is appended when the
// draft branch containing this token is split into a new sequence
std::vector<llama_seq_id> seq_id;
};

std::vector<batch_info> batch_tgt_data;

batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} });

// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
batch_dft.n_tokens = 0;
llama_batch_ext_clear(batch_dft);

for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].skip = false;
Expand Down Expand Up @@ -512,11 +517,10 @@ int main(int argc, char ** argv) {
llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);

// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
if (batch_tgt.seq_id[t][p] == s) {
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
batch_tgt.n_seq_id[t]++;
for (int t = 0; t < (int) batch_tgt_data.size(); ++t) {
for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) {
if (batch_tgt_data[t].seq_id[p] == s) {
batch_tgt_data[t].seq_id.push_back(n_seq_cur);
break;
}
}
Expand Down Expand Up @@ -558,32 +562,30 @@ int main(int argc, char ** argv) {
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});

// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
drafts[s].i_batch_tgt.push_back(batch_tgt_data.size());

common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }});

// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;

common_batch_add(batch_dft, id, n_past_cur, { s }, true);
drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true);

if (batch_tgt.n_tokens > n_draft) {
if (batch_tgt_data.size() > (size_t) n_draft) {
drafts[s].drafting = false;
}
}
}

// no sequence is drafting anymore
if (batch_dft.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch_dft) == 0) {
break;
}

// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch_dft);
llama_decode_ext(ctx_dft, batch_dft);
++n_past_cur;
++n_drafted;

if (batch_tgt.n_tokens > n_draft) {
if (batch_tgt_data.size() > (size_t) n_draft) {
break;
}
}
Expand All @@ -595,8 +597,15 @@ int main(int argc, char ** argv) {
llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
}

llama_batch_ext_clear(batch_tgt);
for (int i = 0; i < (int) batch_tgt_data.size(); ++i) {
const auto & data = batch_tgt_data[i];

llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true);
}

// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
llama_decode_ext(ctx_tgt, batch_tgt);
++n_past_tgt;
}

Expand Down Expand Up @@ -639,12 +648,12 @@ int main(int argc, char ** argv) {
common_sampler_free(drafts[s].smpl);
}

llama_batch_free(batch_dft);
llama_batch_ext_free(batch_dft);
llama_batch_ext_free(batch_tgt);

llama_backend_free();

LOG("\n\n");

#endif
return 0;
}