diff --git a/common/sampling.cpp b/common/sampling.cpp index 7fc2e2158d5c4..3cd28d46863a4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -189,12 +189,16 @@ static llama_token llama_sampling_sample_impl( std::vector<float> original_logits; auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); + if (cur_p.data == NULL) { + return -1; + } if (ctx_sampling->grammar != NULL && !is_resampling) { GGML_ASSERT(!original_logits.empty()); } llama_token id = 0; // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); + GGML_ASSERT(logits); // already checked in llama_sampling_prepare if (temp < 0.0) { // greedy sampling, with probs @@ -284,6 +288,9 @@ static llama_token_data_array llama_sampling_prepare_impl( // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); + if (!logits) { + return {NULL, 0, false}; + } if (ctx_sampling->grammar != NULL && !apply_grammar) { GGML_ASSERT(original_logits != NULL); @@ -298,6 +305,9 @@ static llama_token_data_array llama_sampling_prepare_impl( if (ctx_cfg) { float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); + if (!logits_guidance) { + return {NULL, 0, false}; + } llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); } diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index be30d20bf8194..9812dcb3dfd62 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -169,6 +169,9 @@ int main(int argc, char ** argv) { auto n_vocab = llama_n_vocab(model); auto * logits = llama_get_logits_ith(ctx, i_batch[i]); + if (!logits) { + return 1; + } std::vector<llama_token_data> candidates; candidates.reserve(n_vocab); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 52fd719b38ee5..5a60ca9cf8259 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -58,6 +58,9 @@ static std::vector<std::vector<float>> encode(llama_context * 
ctx, const std::vector<std::string> & sentences, const std::string & instruction) { // sum up all token embeddings for (int32_t k = n_inst; k < n_toks; k++) { float * emb = llama_get_embeddings_ith(ctx, k); + if (!emb) { + throw std::runtime_error("llama_get_embeddings_ith failed"); + } for (uint64_t j = 0; j < n_embd; j++) { emb_unorm[j] += emb[j]; } @@ -114,6 +117,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) { llama_decode(ctx, bat); auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); + if (!logits) { + throw std::runtime_error("llama_get_logits_ith failed"); + } auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl)); auto n_candidates = (int32_t)candidates.size(); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index afac145f63934..f1a86346224f2 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -530,6 +530,9 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); + if (id == -1) { + return 1; + } llama_sampling_accept(ctx_sampling, ctx, id, true); diff --git a/examples/llama.android/app/src/main/cpp/llama-android.cpp b/examples/llama.android/app/src/main/cpp/llama-android.cpp index 4af9de3038359..e20ba7a8edca3 100644 --- a/examples/llama.android/app/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/app/src/main/cpp/llama-android.cpp @@ -394,6 +394,9 @@ Java_com_example_llama_Llm_completion_1loop( auto n_vocab = llama_n_vocab(model); auto logits = llama_get_logits_ith(context, batch->n_tokens - 1); + if (!logits) { + throw std::runtime_error("llama_get_logits_ith failed"); + } std::vector<llama_token_data> candidates; candidates.reserve(n_vocab); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index a6d67e5d72cd2..e293edda018e6 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -44,6 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling, struct 
llama_context * ctx_llama, int * n_past) { const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); + GGML_ASSERT(id != -1); llama_sampling_accept(ctx_sampling, ctx_llama, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 9c3540b2008c2..dcf3c21751743 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -159,6 +159,9 @@ int main(int argc, char ** argv) { // sample first token { id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); + if (id == -1) { + return 1; + } llama_sampling_accept(ctx_sampling, ctx, id, true); @@ -284,6 +287,9 @@ int main(int argc, char ** argv) { // sample the next token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch); + if (id == -1) { + return 1; + } llama_sampling_accept(ctx_sampling, ctx, id, true); @@ -361,6 +367,9 @@ int main(int argc, char ** argv) { // sample from the last level for (int i = 0; i < W; i++) { tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i); + if (tokens_j[N - 2][i] == -1) { + return 1; + } } } else { for (int i = 0; i < W; i++) { diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index eebbd00a58e66..01e02f182649d 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -131,6 +131,7 @@ int main(int argc, char ** argv){ while (true) { // sample from the target model llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft); + GGML_ASSERT(id != -1); llama_sampling_accept(ctx_sampling, ctx, id, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 832b51ee086be..3c2ba844b86f6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -706,6 +706,9 @@ int main(int argc, char ** argv) { } const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); + if (id == -1) { + 
return 1; + } llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7c5595d6edb2d..91fdd73fbeeeb 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -341,6 +341,7 @@ int main(int argc, char ** argv) { // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i); + GGML_ASSERT(id != -1); llama_sampling_accept(client.ctx_sampling, ctx, id, true); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index f2ef9ca10d4a2..36757601693dd 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -239,6 +239,9 @@ int main(int argc, char ** argv) { { auto n_vocab = llama_n_vocab(model); auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + if (!logits) { + return 1; + } std::vector<llama_token_data> candidates; candidates.reserve(n_vocab); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index bae014e6f4c16..d196af6e67792 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -638,6 +638,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) { for (int seq = 0; seq < n_seq_batch; seq++) { const float * all_logits = num_batches > 1 ? 
logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first); + if (!all_logits) { + return {std::move(tokens), -1, {}, {}}; + } llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first; if (!params.logits_file.empty()) { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6af5cb96e6d13..908802aa82c52 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2257,6 +2257,14 @@ struct server_context { completion_token_output result; const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i); + if (id == -1) { + send_error(slot, "can't get completions out of an embeddings model"); + slot.cache_tokens.clear(); + slot.reset(); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } llama_sampling_accept(slot.ctx_sampling, ctx, id, true); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index b0f8e0fdc4987..63ba48c1bc965 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -120,6 +120,9 @@ int main(int argc, char ** argv) { { auto n_vocab = llama_n_vocab(model); auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + if (!logits) { + return 1; + } std::vector<llama_token_data> candidates; candidates.reserve(n_vocab); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 12e46fbc91a24..78932cd0c40ad 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -229,6 +229,9 @@ int main(int argc, char ** argv) { // stochastic verification llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL); + if (dist_tgt.data == NULL) { + return 1; + } llama_sample_softmax(ctx_tgt, &dist_tgt); float p_tgt = 0, p_dft = 0; @@ -337,6 +340,9 @@ int main(int argc, char ** argv) { // sample from the target model LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, 
i_dft, drafts[s_keep].i_batch_tgt[i_dft]); token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + if (token_id == -1) { + return 1; + } llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); @@ -457,7 +463,9 @@ int main(int argc, char ** argv) { continue; } - llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft); + if (llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft) == -1) { + return -1; + } const auto & cur_p = drafts[s].ctx_sampling->cur; diff --git a/llama.cpp b/llama.cpp index abff8c1c03e7a..16cbc65dc4749 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17301,42 +17301,39 @@ float * llama_get_logits(struct llama_context * ctx) { return ctx->logits; } +static float * llama_get_logits_ith_fail(int i, const std::string & reason) { + LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, reason.c_str()); +#ifndef NDEBUG + GGML_ASSERT(false); +#endif + return nullptr; +} + float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { int32_t j = -1; llama_synchronize(ctx); - - try { - if (ctx->logits == nullptr) { - throw std::runtime_error("no logits"); - } - - if (i < 0) { - j = ctx->n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); - } - } else if ((size_t) i >= ctx->output_ids.size()) { - throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); - } else { - j = ctx->output_ids[i]; - } - + if (ctx->logits == nullptr) { + // this can happen for embeddings models like bert + return llama_get_logits_ith_fail(i, "no logits"); + } + if (i < 0) { + j = ctx->n_outputs + i; if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= ctx->n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); + return 
llama_get_logits_ith_fail(i, format("negative index out of range [%d, 0)", -ctx->n_outputs)); } - - return ctx->logits + j*ctx->model.hparams.n_vocab; - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); -#ifndef NDEBUG - GGML_ASSERT(false); -#endif - return nullptr; + } else if ((size_t) i >= ctx->output_ids.size()) { + return llama_get_logits_ith_fail(i, format("out of range [0, %lu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; } + if (j < 0) { + return llama_get_logits_ith_fail(i, format("batch.logits[%d] != true", i)); + } + if (j >= ctx->n_outputs) { + // This should not happen + return llama_get_logits_ith_fail(i, format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); + } + return ctx->logits + j*ctx->model.hparams.n_vocab; } float * llama_get_embeddings(struct llama_context * ctx) { @@ -17345,43 +17342,43 @@ float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embd; } +static float * llama_get_embeddings_ith_fail(int i, const std::string & reason) { + LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason.c_str()); +#ifndef NDEBUG + GGML_ASSERT(false); +#endif + return nullptr; +} + float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { int32_t j = -1; - llama_synchronize(ctx); - - try { - if (ctx->embd == nullptr) { - throw std::runtime_error("no embeddings"); - } - - if (i < 0) { - j = ctx->n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); - } - } else if ((size_t) i >= ctx->output_ids.size()) { - throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); - } else { - j = ctx->output_ids[i]; - } - + if (ctx->embd == nullptr) { + return llama_get_embeddings_ith_fail(i, "no embeddings"); + } + if (i < 0) { + j = ctx->n_outputs + i; if (j < 0) { - throw 
std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= ctx->n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); + return llama_get_embeddings_ith_fail( + i, format("negative index out of range [%d, 0)", -ctx->n_outputs)); } - - return ctx->embd + j*ctx->model.hparams.n_embd; - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); -#ifndef NDEBUG - GGML_ASSERT(false); -#endif - return nullptr; + } else if ((size_t) i >= ctx->output_ids.size()) { + return llama_get_embeddings_ith_fail( + i, format("out of range [0, %lu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; + } + if (j < 0) { + return llama_get_embeddings_ith_fail( + i, format("batch.logits[%d] != true", i)); + } + if (j >= ctx->n_outputs) { + // This should not happen + return llama_get_embeddings_ith_fail( + i, format("corrupt output buffer (j=%d, n_outputs=%d)", + j, ctx->n_outputs)); } + return ctx->embd + j*ctx->model.hparams.n_embd; } float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {