Check for llama_get_logits_ith() errors #7448

Open · wants to merge 9 commits into master
Changes from 8 commits
10 changes: 10 additions & 0 deletions common/sampling.cpp
@@ -189,12 +189,16 @@ static llama_token llama_sampling_sample_impl(

std::vector<float> original_logits;
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
if (cur_p.data == NULL) {
return -1;
}
if (ctx_sampling->grammar != NULL && !is_resampling) {
GGML_ASSERT(!original_logits.empty());
}
llama_token id = 0;
// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);
GGML_ASSERT(logits); // already checked in llama_sampling_prepare

if (temp < 0.0) {
// greedy sampling, with probs
@@ -284,6 +288,9 @@ static llama_token_data_array llama_sampling_prepare_impl(

// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);
if (!logits) {
return {NULL, 0, false};
}

if (ctx_sampling->grammar != NULL && !apply_grammar) {
GGML_ASSERT(original_logits != NULL);
Expand All @@ -298,6 +305,9 @@ static llama_token_data_array llama_sampling_prepare_impl(

if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
if (!logits_guidance) {
return {NULL, 0, false};
}
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
}

3 changes: 3 additions & 0 deletions examples/batched/batched.cpp
@@ -169,6 +169,9 @@ int main(int argc, char ** argv) {

auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
if (!logits) {
return 1;
}

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
6 changes: 6 additions & 0 deletions examples/gritlm/gritlm.cpp
@@ -58,6 +58,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
// sum up all token embeddings
for (int32_t k = n_inst; k < n_toks; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
if (!emb) {
throw std::runtime_error("llama_get_embeddings_ith failed");
}
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j];
}
@@ -114,6 +117,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo

llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
if (!logits) {
throw std::runtime_error("llama_get_logits_ith failed");
}

auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
auto n_candidates = (int32_t)candidates.size();
3 changes: 3 additions & 0 deletions examples/infill/infill.cpp
@@ -530,6 +530,9 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {

const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
if (id == -1) {
return 1;
}

llama_sampling_accept(ctx_sampling, ctx, id, true);

3 changes: 3 additions & 0 deletions examples/llama.android/app/src/main/cpp/llama-android.cpp
@@ -394,6 +394,9 @@ Java_com_example_llama_Llm_completion_1loop(

auto n_vocab = llama_n_vocab(model);
auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
if (!logits) {
throw std::runtime_error("llama_get_logits_ith failed");
}

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
1 change: 1 addition & 0 deletions examples/llava/llava-cli.cpp
@@ -44,6 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_llama,
int * n_past) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
GGML_ASSERT(id != -1);
Collaborator:
Sometimes it's a return 1, other times it's an assertion, or an exception.
Which to use when? Should a single way be chosen?
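For reference, a minimal self-contained sketch of the three styles being mixed (status code from main(), assertion, exception); might_fail is a hypothetical stand-in, not part of llama.cpp or of this diff:

#include <cassert>
#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for a call such as llama_sampling_sample() that can fail.
static int might_fail(bool ok) { return ok ? 42 : -1; }

int main() {
    // Style 1: return a non-zero status from main(), as most examples in this PR do.
    if (might_fail(true) == -1) {
        return 1;
    }

    // Style 2: assert; aborts the process when the condition is false
    // (llama.cpp itself uses GGML_ASSERT for this).
    assert(might_fail(true) != -1);

    // Style 3: throw; unwinds to a handler, or terminates the program if none exists
    // (the gritlm and llama.android changes take this route).
    try {
        if (might_fail(false) == -1) {
            throw std::runtime_error("might_fail failed");
        }
    } catch (const std::exception & e) {
        std::fprintf(stderr, "caught: %s\n", e.what());
    }
    return 0;
}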

llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
9 changes: 9 additions & 0 deletions examples/lookahead/lookahead.cpp
@@ -159,6 +159,9 @@ int main(int argc, char ** argv) {
// sample first token
{
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
if (id == -1) {
return 1;
}

llama_sampling_accept(ctx_sampling, ctx, id, true);

@@ -284,6 +287,9 @@ int main(int argc, char ** argv) {

// sample the next token
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
if (id == -1) {
return 1;
}

llama_sampling_accept(ctx_sampling, ctx, id, true);

@@ -361,6 +367,9 @@ int main(int argc, char ** argv) {
// sample from the last level
for (int i = 0; i < W; i++) {
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
if (tokens_j[N - 2][i] == -1) {
return 1;
}
}
} else {
for (int i = 0; i < W; i++) {
1 change: 1 addition & 0 deletions examples/lookup/lookup.cpp
@@ -131,6 +131,7 @@ int main(int argc, char ** argv){
while (true) {
// sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
GGML_ASSERT(id != -1);

llama_sampling_accept(ctx_sampling, ctx, id, true);

3 changes: 3 additions & 0 deletions examples/main/main.cpp
@@ -706,6 +706,9 @@ int main(int argc, char ** argv) {
}

const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
if (id == -1) {
return 1;
}

llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);

1 change: 1 addition & 0 deletions examples/parallel/parallel.cpp
@@ -341,6 +341,7 @@ int main(int argc, char ** argv) {
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);

const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
GGML_ASSERT(id != -1);

llama_sampling_accept(client.ctx_sampling, ctx, id, true);

3 changes: 3 additions & 0 deletions examples/passkey/passkey.cpp
@@ -239,6 +239,9 @@ int main(int argc, char ** argv) {
{
auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
if (!logits) {
return 1;
}

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
3 changes: 3 additions & 0 deletions examples/perplexity/perplexity.cpp
@@ -638,6 +638,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
if (!all_logits) {
return {std::move(tokens), -1, {}, {}};
}

llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
if (!params.logits_file.empty()) {
3 changes: 3 additions & 0 deletions examples/server/server.cpp
@@ -2257,6 +2257,9 @@ struct server_context {

completion_token_output result;
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
if (id == -1) {
continue; // keep going, don't crash, already logged
}

llama_sampling_accept(slot.ctx_sampling, ctx, id, true);

3 changes: 3 additions & 0 deletions examples/simple/simple.cpp
@@ -120,6 +120,9 @@ int main(int argc, char ** argv) {
{
auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
if (!logits) {
return 1;
}

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
10 changes: 9 additions & 1 deletion examples/speculative/speculative.cpp
@@ -229,6 +229,9 @@ int main(int argc, char ** argv) {
// stochastic verification

llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
if (dist_tgt.data == NULL) {
return 1;
}
llama_sample_softmax(ctx_tgt, &dist_tgt);
float p_tgt = 0, p_dft = 0;

@@ -337,6 +340,9 @@ int main(int argc, char ** argv) {
// sample from the target model
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
if (token_id == -1) {
return 1;
}

llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);

@@ -457,7 +463,9 @@ int main(int argc, char ** argv) {
continue;
}

llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
if (llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft) == -1) {
return -1;
Collaborator:

It's returning -1 here while it's returning 1 in other places for the same kind of check in the same file. Why?

Contributor (author):

It returns 1 from main() functions to exit(1).

Collaborator:

Right. But it's -1 (negative one) here, so the exit code will be 255.
}
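For reference, a minimal stand-alone sketch (not part of this diff) of why the reviewer's point holds: a process exit status is truncated to its low 8 bits, so returning -1 from main() is reported as 255.

#include <cstdio>

int main() {
    // A POSIX shell sees only the low 8 bits of this value, so -1 is reported
    // as 255: compile and run `./a.out; echo $?`.
    std::printf("returning -1 from main\n");
    return -1;
}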

const auto & cur_p = drafts[s].ctx_sampling->cur;

119 changes: 58 additions & 61 deletions llama.cpp
@@ -17301,42 +17301,39 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits;
 }
 
+static float * llama_get_logits_ith_fail(int i, const std::string & reason) {
+    LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, reason.c_str());
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
+}
+
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;
     llama_synchronize(ctx);
 
-    try {
-        if (ctx->logits == nullptr) {
-            throw std::runtime_error("no logits");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
+    if (ctx->logits == nullptr) {
+        // this can happen for embeddings models like bert
+        return llama_get_logits_ith_fail(i, "no logits");
+    }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
+        if (j < 0) {
+            return llama_get_logits_ith_fail(i, format("negative index out of range [%d, 0)", -ctx->n_outputs));
+        }
+    } else if ((size_t) i >= ctx->output_ids.size()) {
+        return llama_get_logits_ith_fail(i, format("out of range [0, %lu)", ctx->output_ids.size()));
+    } else {
+        j = ctx->output_ids[i];
+    }
+    if (j < 0) {
+        return llama_get_logits_ith_fail(i, format("batch.logits[%d] != true", i));
+    }
+    if (j >= ctx->n_outputs) {
+        // This should not happen
+        return llama_get_logits_ith_fail(i, format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
     }
+    return ctx->logits + j*ctx->model.hparams.n_vocab;
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
@@ -17345,43 +17342,43 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embd;
 }
 
+static float * llama_get_embeddings_ith_fail(int i, const std::string & reason) {
+    LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason.c_str());
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
+}
+
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;
 
     llama_synchronize(ctx);
 
-    try {
-        if (ctx->embd == nullptr) {
-            throw std::runtime_error("no embeddings");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->embd + j*ctx->model.hparams.n_embd;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
+    if (ctx->embd == nullptr) {
+        return llama_get_embeddings_ith_fail(i, "no embeddings");
+    }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
+        if (j < 0) {
+            return llama_get_embeddings_ith_fail(
+                i, format("negative index out of range [%d, 0)", -ctx->n_outputs));
+        }
+    } else if ((size_t) i >= ctx->output_ids.size()) {
+        return llama_get_embeddings_ith_fail(
+            i, format("out of range [0, %lu)", ctx->output_ids.size()));
+    } else {
+        j = ctx->output_ids[i];
+    }
+    if (j < 0) {
+        return llama_get_embeddings_ith_fail(
+            i, format("batch.logits[%d] != true", i));
+    }
+    if (j >= ctx->n_outputs) {
+        // This should not happen
+        return llama_get_embeddings_ith_fail(
+            i, format("corrupt output buffer (j=%d, n_outputs=%d)",
+                j, ctx->n_outputs));
     }
+    return ctx->embd + j*ctx->model.hparams.n_embd;
 }
 
 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {