From 010571490fd3d7ff52715f452d6a0d88a9b3599a Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Wed, 22 May 2024 12:14:24 -0500 Subject: [PATCH 1/6] create append_pooling operation; allow to specify attention_type; add last token pooling; update examples --- common/common.cpp | 14 +++ common/common.h | 1 + examples/embedding/embedding.cpp | 38 ++++++--- examples/gritlm/gritlm.cpp | 20 +++-- examples/retrieval/retrieval.cpp | 25 +++++- llama.cpp | 142 ++++++++++++++++++++----------- llama.h | 9 +- 7 files changed, 175 insertions(+), 74 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 1591790e6df4c..d3afbeded90e5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -542,6 +542,18 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else { invalid_param = true; } + return true; + } + if (arg == "--attention") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::string value(argv[i]); + /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } + else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; } else { invalid_param = true; } return true; } @@ -1820,6 +1832,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "backend" }); options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" }); + if (llama_supports_mlock()) { options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" }); } @@ -2447,6 +2460,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.pooling_type = params.pooling_type; + cparams.attention_type = params.attention_type; cparams.defrag_thold = params.defrag_thold; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; diff --git a/common/common.h b/common/common.h index 2345d855eed3c..2640c65f3d1f5 100644 --- a/common/common.h +++ b/common/common.h @@ -94,6 +94,7 @@ struct gpt_params { enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings + enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type // // sampling parameters struct llama_sampling_params sparams; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 244751e003d9e..9d66b5477293b 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -17,9 +17,25 @@ static std::vector split_lines(const std::string & s) { return lines; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id) { - for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); +static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) { + switch (pooling_type) { + case 
LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_NONE: + return true; + case LLAMA_POOLING_TYPE_CLS: + return pos == 0; + case LLAMA_POOLING_TYPE_LAST: + return pos == n_tokens - 1; + default: + GGML_ASSERT(false && "unsupported pooling type"); + } +} + +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id, enum llama_pooling_type pooling_type) { + int n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + bool logit = needs_logit(pooling_type, i, n_tokens); + llama_batch_add(batch, tokens[i], i, { seq_id }, logit); } } @@ -40,13 +56,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu // try to get sequence embeddings - supported only when pooling_type is not NONE const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - if (embd == NULL) { - fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i); - continue; - } - } + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); float * out = output + batch.seq_id[i][0] * n_embd; //TODO: I would also add a parameter here to enable normalization or not. @@ -97,6 +107,12 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + return 1; + } + if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); @@ -176,7 +192,7 @@ int main(int argc, char ** argv) { } // add to batch - batch_add_seq(batch, inp, s); + batch_add_seq(batch, inp, s, pooling_type); s += 1; } diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2135157916c97..8a3bdef83774a 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -44,7 +44,6 @@ static std::vector> encode(llama_context * ctx, const std::ve // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); - llama_set_causal_attn(ctx, false); // run model llama_decode(ctx, batch); @@ -98,7 +97,6 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_token eos_token = llama_token_eos(mdl); llama_kv_cache_clear(ctx); - llama_set_causal_attn(ctx, true); llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); std::vector inputs = llama_tokenize(mdl, prompt, false, true); @@ -166,9 +164,14 @@ int main(int argc, char * argv[]) { llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); - // create new context - set to embedding mode + // create generation context + llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams); + + // create embedding context cparams.embeddings = true; - llama_context * ctx = llama_new_context_with_model(mdl, cparams); + cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; + cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; + llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic @@ -186,8 +189,8 @@ int main(int argc, char * argv[]) { }; // No need to add instruction for retrieval documents - const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); - const std::vector> q_rep = 
encode(ctx, queries, gritlm_instruction(instruction)); + const std::vector> d_rep = encode(ctx_emb, documents, gritlm_instruction("")); + const std::vector> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction)); const int n_embd = llama_n_embd(mdl); @@ -206,10 +209,11 @@ int main(int argc, char * argv[]) { // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction { const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n"; - std::string response = generate(ctx, prompt, true); + std::string response = generate(ctx_gen, prompt, true); } - llama_free(ctx); + llama_free(ctx_gen); + llama_free(ctx_emb); llama_free_model(mdl); llama_backend_free(); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 55b7b2f70ae2a..ee43430f0d3f5 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -73,9 +73,25 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz return chunks; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id) { +static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) { + switch (pooling_type) { + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_NONE: + return true; + case LLAMA_POOLING_TYPE_CLS: + return pos == 0; + case LLAMA_POOLING_TYPE_LAST: + return pos == n_tokens - 1; + default: + GGML_ASSERT(false && "unsupported pooling type"); + } +} + +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id, enum llama_pooling_type pooling_type) { + int n_tokens = tokens.size(); for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); + bool logit = needs_logit(pooling_type, i, n_tokens); + llama_batch_add(batch, tokens[i], i, { seq_id }, logit); } } @@ -159,6 +175,7 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", @@ -230,7 +247,7 @@ int main(int argc, char ** argv) { } // add to batch - batch_add_seq(batch, inp, s); + batch_add_seq(batch, inp, s, pooling_type); s += 1; } @@ -253,7 +270,7 @@ int main(int argc, char ** argv) { std::vector query_tokens = llama_tokenize(ctx, query, true); struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); - batch_add_seq(query_batch, query_tokens, 0); + batch_add_seq(query_batch, query_tokens, 0, pooling_type); std::vector query_emb(n_embd, 0); batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); diff --git a/llama.cpp b/llama.cpp index 05591aa4389a7..2f906af7da815 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7435,6 +7435,44 @@ struct llm_build_context { return lctx.inp_s_seq; } + struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { + struct ggml_tensor * inp = gf->nodes[gf->n_nodes - 1]; + if (strcmp(inp->name, "result_embd") != 0) { + inp = gf->nodes[gf->n_nodes - 2]; + GGML_ASSERT(strcmp(inp->name, "result_norm") == 0 && "embeddings tensor not found"); + } + + struct ggml_tensor * cur; + + switch (pooling_type) { + case LLAMA_POOLING_TYPE_MEAN: + { + struct ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, 
ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } break; + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + struct ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } break; + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; + default: + { + GGML_ASSERT(false && "unknown pooling type"); + } break; + } + + cb(cur, "result_embd_pooled", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -8415,8 +8453,6 @@ struct llm_build_context { if (model.arch != LLM_ARCH_JINA_BERT_V2) { inp_pos = build_inp_pos(); } - struct ggml_tensor * inp_mean = build_inp_mean(); - struct ggml_tensor * inp_cls = build_inp_cls(); // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); @@ -8591,28 +8627,6 @@ struct llm_build_context { cur = inpL; cb(cur, "result_embd", -1); - // pooling layer - switch (pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // nop - } break; - case LLAMA_POOLING_TYPE_MEAN: - { - cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); - cb(cur, "result_embd_pooled", -1); - } break; - case LLAMA_POOLING_TYPE_CLS: - { - cur = ggml_get_rows(ctx0, cur, inp_cls); - cb(cur, "result_embd_pooled", -1); - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ASSERT(false && "Invalid pooling type"); - } break; - } - ggml_build_forward_expand(gf, cur); return gf; @@ -11697,6 +11711,11 @@ static struct ggml_cgraph * llama_build_graph( GGML_ASSERT(false); } + // add on pooling layer + if (lctx.cparams.embeddings) { + result = llm.append_pooling(result); + } + llm.free(); return result; @@ -11918,6 +11937,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } + if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(lctx.inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); + + uint32_t * data = (uint32_t *) lctx.inp_cls->data; + memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); + + std::vector last_pos(n_tokens, -1); + std::vector last_row(n_tokens, -1); + + for (int i = 0; i < n_tokens; ++i) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + const llama_pos pos = batch.pos[i]; + + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = i; + } + } + + for (int i = 0; i < n_tokens; ++i) { + if (last_row[i] >= 0) { + data[i] = last_row[i]; + } + } + } + if (kv_self.recurrent) { const int64_t n_kv = kv_self.n; @@ -12245,30 +12295,13 @@ static int llama_decode_internal( // no output res = nullptr; embd = nullptr; - } else if (!hparams.causal_attn) { - res = nullptr; // do not extract logits for embedding models such as BERT - - // token or sequence embeddings - embd = gf->nodes[gf->n_nodes - 1]; - - GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0); } else if (cparams.embeddings) { - // the embeddings could be in the second to last tensor, or any of the previous tensors - int i_embd = gf->n_nodes - 2; - for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) { - i_embd = gf->n_nodes - i; - if (i_embd < 0) { break; } - embd = gf->nodes[i_embd]; - } - GGML_ASSERT(i_embd >= 0 && "missing 
result_norm tensor"); - - // TODO: use a per-batch flag to know when to skip logits while keeping embeddings - if (!cparams.causal_attn) { - res = nullptr; // do not extract logits when not needed - // skip computing logits - // TODO: is this safe? - gf->n_nodes = i_embd + 1; + res = nullptr; // do not extract logits for embedding case + embd = gf->nodes[gf->n_nodes - 1]; + if (strcmp(embd->name, "result_embd_pooled") != 0) { + embd = gf->nodes[gf->n_nodes - 2]; } + GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor"); } else { embd = nullptr; // do not extract embeddings when not needed GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); @@ -12337,11 +12370,10 @@ static int llama_decode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float)); } } break; - case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: { - GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0); - // extract sequence embeddings auto & embd_seq_out = lctx.embd_seq; embd_seq_out.clear(); @@ -15893,6 +15925,7 @@ struct llama_context_params llama_context_default_params() { /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, + /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, @@ -16134,7 +16167,12 @@ struct llama_context * llama_new_context_with_model( } cparams.yarn_attn_factor *= hparams.rope_attn_factor; - cparams.causal_attn = hparams.causal_attn; + + if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { + cparams.causal_attn = hparams.causal_attn; + } else { + cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; + } if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { @@ -16494,6 +16532,10 @@ enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { return ctx->cparams.pooling_type; } +bool llama_causal_attn(const struct llama_context * ctx) { + return ctx->cparams.causal_attn; +} + int32_t llama_n_vocab(const struct llama_model * model) { return model->hparams.n_vocab; } diff --git a/llama.h b/llama.h index da310ffaf9ad9..66635a57363e1 100644 --- a/llama.h +++ b/llama.h @@ -174,6 +174,13 @@ extern "C" { LLAMA_POOLING_TYPE_NONE = 0, LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, + LLAMA_POOLING_TYPE_LAST = 3, + }; + + enum llama_attention_type { + LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1, + LLAMA_ATTENTION_TYPE_CAUSAL = 0, + LLAMA_ATTENTION_TYPE_NONCAUSAL = 1, }; enum llama_split_mode { @@ -293,7 +300,7 @@ extern "C" { enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id - // (ignored if no pooling layer) + enum llama_attention_type attention_type; // causal, non-causal, or unspecified // ref: https://github.com/ggerganov/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model From 1756c4b5b69d1bd2c36d55f235be2425ec3a138e Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Wed, 22 May 2024 22:42:08 -0500 Subject: [PATCH 2/6] find result_norm/result_embd tensors properly; update output allocation logic --- 
examples/embedding/embedding.cpp |  4 ++--
 examples/retrieval/retrieval.cpp |  6 +++---
 llama.cpp                        | 18 ++++++++++++------
 3 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 9d66b5477293b..659c90245f318 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -31,8 +31,8 @@ static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tok
     }
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
-    int n_tokens = tokens.size();
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+    size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
         bool logit = needs_logit(pooling_type, i, n_tokens);
         llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp
index ee43430f0d3f5..3501a0eb34ba9 100644
--- a/examples/retrieval/retrieval.cpp
+++ b/examples/retrieval/retrieval.cpp
@@ -87,9 +87,9 @@ static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tok
     }
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
-    int n_tokens = tokens.size();
-    for (size_t i = 0; i < tokens.size(); i++) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
         bool logit = needs_logit(pooling_type, i, n_tokens);
         llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
     }
diff --git a/llama.cpp b/llama.cpp
index 2f906af7da815..60d562b0053b5 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7436,11 +7436,17 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
-        struct ggml_tensor * inp = gf->nodes[gf->n_nodes - 1];
-        if (strcmp(inp->name, "result_embd") != 0) {
-            inp = gf->nodes[gf->n_nodes - 2];
-            GGML_ASSERT(strcmp(inp->name, "result_norm") == 0 && "embeddings tensor not found");
+        // find result_norm tensor for input
+        struct ggml_tensor * inp = nullptr;
+        for (int i = gf->n_nodes - 1; i >= 0; --i) {
+            inp = gf->nodes[i];
+            if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
+                break;
+            } else {
+                inp = nullptr;
+            }
         }
+        GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
 
         struct ggml_tensor * cur;
 
@@ -12029,8 +12035,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits =  cparams.causal_attn;
-    const bool has_embd   = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    const bool has_logits = !cparams.embeddings;
+    const bool has_embd   =  cparams.embeddings;
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size   = has_embd   ? 
n_embd*n_outputs_max : 0; From 7c37ae9d29761e8fc7d4c704fa66f2e0b7a0e728 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Tue, 4 Jun 2024 01:20:13 -0500 Subject: [PATCH 3/6] only use embd output for pooling_type NONE --- llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 60d562b0053b5..98aa0ab3c4a7b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11779,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } - if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); const int64_t n_tokens = batch.n_tokens; @@ -11811,7 +11811,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // (!a || b) is a logical implication (a -> b) // !hparams.causal_attn -> !cparams.causal_attn (hparams.causal_attn || !cparams.causal_attn) && - "causal attention with embedding models is not supported" + "causal attention is not supported by this model" ); if (lctx.inp_KQ_mask) { @@ -12036,7 +12036,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { // TODO: use a per-batch flag for logits presence instead const bool has_logits = !cparams.embeddings; - const bool has_embd = cparams.embeddings; + const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; From d4e6972f60587cde173974adabe7b11a438fdbcb Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Tue, 4 Jun 2024 11:16:37 -0500 Subject: [PATCH 4/6] get rid of old causal_attn accessor --- llama.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 98aa0ab3c4a7b..56e4a956c69ca 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16538,10 +16538,6 @@ enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { return ctx->cparams.pooling_type; } -bool llama_causal_attn(const struct llama_context * ctx) { - return ctx->cparams.causal_attn; -} - int32_t llama_n_vocab(const struct llama_model * model) { return model->hparams.n_vocab; } From 8093253b41dcc475ba160c97a7435f41b746d04a Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Thu, 6 Jun 2024 15:11:25 -0500 Subject: [PATCH 5/6] take out attention_type; add in llama_set_embeddings --- common/common.cpp | 12 ------------ common/common.h | 1 - examples/gritlm/gritlm.cpp | 22 ++++++++++------------ llama.cpp | 12 +++++------- llama.h | 11 ++++------- 5 files changed, 19 insertions(+), 39 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d3afbeded90e5..7233953cb2e65 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -546,17 +546,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa else { invalid_param = true; } return true; } - if (arg == "--attention") { - if (++i >= argc) { - invalid_param = true; - return true; - } - std::string value(argv[i]); - /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } - else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; } - else { invalid_param = true; } - return true; - } if (arg == "--defrag-thold" || arg == "-dt") { if (++i >= 
argc) { invalid_param = true; @@ -2460,7 +2449,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.pooling_type = params.pooling_type; - cparams.attention_type = params.attention_type; cparams.defrag_thold = params.defrag_thold; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; diff --git a/common/common.h b/common/common.h index 2640c65f3d1f5..2345d855eed3c 100644 --- a/common/common.h +++ b/common/common.h @@ -94,7 +94,6 @@ struct gpt_params { enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings - enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type // // sampling parameters struct llama_sampling_params sparams; diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 8a3bdef83774a..2c61c2e1eb3bc 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -44,6 +44,8 @@ static std::vector> encode(llama_context * ctx, const std::ve // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); + llama_set_embeddings(ctx, true); + llama_set_causal_attn(ctx, false); // run model llama_decode(ctx, batch); @@ -97,6 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_token eos_token = llama_token_eos(mdl); llama_kv_cache_clear(ctx); + llama_set_embeddings(ctx, false); + llama_set_causal_attn(ctx, true); + llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); std::vector inputs = llama_tokenize(mdl, prompt, false, true); @@ -165,13 +170,7 @@ int main(int argc, char * argv[]) { llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); // create generation context - llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams); - - // create embedding context - cparams.embeddings = true; - cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; - cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; - llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams); + llama_context * ctx = llama_new_context_with_model(mdl, cparams); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic @@ -189,8 +188,8 @@ int main(int argc, char * argv[]) { }; // No need to add instruction for retrieval documents - const std::vector> d_rep = encode(ctx_emb, documents, gritlm_instruction("")); - const std::vector> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction)); + const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); + const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); const int n_embd = llama_n_embd(mdl); @@ -209,11 +208,10 @@ int main(int argc, char * argv[]) { // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction { const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. 
Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n"; - std::string response = generate(ctx_gen, prompt, true); + std::string response = generate(ctx, prompt, true); } - llama_free(ctx_gen); - llama_free(ctx_emb); + llama_free(ctx); llama_free_model(mdl); llama_backend_free(); diff --git a/llama.cpp b/llama.cpp index 56e4a956c69ca..42be6bed95e5f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15931,7 +15931,6 @@ struct llama_context_params llama_context_default_params() { /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, - /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, @@ -16173,12 +16172,7 @@ struct llama_context * llama_new_context_with_model( } cparams.yarn_attn_factor *= hparams.rope_attn_factor; - - if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { - cparams.causal_attn = hparams.causal_attn; - } else { - cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; - } + cparams.causal_attn = hparams.causal_attn; if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { @@ -17914,6 +17908,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback) ctx->abort_callback_data = abort_callback_data; } +void llama_set_embeddings(struct llama_context * ctx, bool embeddings) { + ctx->cparams.embeddings = embeddings; +} + void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { ctx->cparams.causal_attn = causal_attn; } diff --git a/llama.h b/llama.h index 66635a57363e1..05d8b092b42a4 100644 --- a/llama.h +++ b/llama.h @@ -177,12 +177,6 @@ extern "C" { LLAMA_POOLING_TYPE_LAST = 3, }; - enum llama_attention_type { - LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1, - LLAMA_ATTENTION_TYPE_CAUSAL = 0, - LLAMA_ATTENTION_TYPE_NONCAUSAL = 1, - }; - enum llama_split_mode { LLAMA_SPLIT_MODE_NONE = 0, // single GPU LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs @@ -300,7 +294,6 @@ extern "C" { enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id - enum llama_attention_type attention_type; // causal, non-causal, or unspecified // ref: https://github.com/ggerganov/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model @@ -793,6 +786,10 @@ extern "C" { // Get the number of threads used for prompt and batch processing (multiple token). 
LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx); + // Set whether the model is in embeddings model or not + // If true, embeddings will be returned but logits will not + LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); + // Set whether to use causal attention or not // If set to true, the model will only attend to the past tokens LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn); From 5cc7b453a49fb2dc15e6207e9714be894480a6f2 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Mon, 17 Jun 2024 00:24:33 -0600 Subject: [PATCH 6/6] bypass logits when doing non-NONE pooling --- examples/embedding/embedding.cpp | 21 +++------------------ examples/retrieval/retrieval.cpp | 28 +++++++++------------------- llama.cpp | 8 +++++--- 3 files changed, 17 insertions(+), 40 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 659c90245f318..b4b73c0175cda 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -17,25 +17,10 @@ static std::vector split_lines(const std::string & s) { return lines; } -static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) { - switch (pooling_type) { - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_NONE: - return true; - case LLAMA_POOLING_TYPE_CLS: - return pos == 0; - case LLAMA_POOLING_TYPE_LAST: - return pos == n_tokens - 1; - default: - GGML_ASSERT(false && "unsupported pooling type"); - } -} - -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) { +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { size_t n_tokens = tokens.size(); for (size_t i = 0; i < n_tokens; i++) { - bool logit = needs_logit(pooling_type, i, n_tokens); - llama_batch_add(batch, tokens[i], i, { seq_id }, logit); + llama_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -192,7 +177,7 @@ int main(int argc, char ** argv) { } // add to batch - batch_add_seq(batch, inp, s, pooling_type); + batch_add_seq(batch, inp, s); s += 1; } diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 3501a0eb34ba9..eb89d16daf18d 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -73,25 +73,10 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz return chunks; } -static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) { - switch (pooling_type) { - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_NONE: - return true; - case LLAMA_POOLING_TYPE_CLS: - return pos == 0; - case LLAMA_POOLING_TYPE_LAST: - return pos == n_tokens - 1; - default: - GGML_ASSERT(false && "unsupported pooling type"); - } -} - -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) { +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { size_t n_tokens = tokens.size(); for (size_t i = 0; i < n_tokens; i++) { - bool logit = needs_logit(pooling_type, i, n_tokens); - llama_batch_add(batch, tokens[i], i, { seq_id }, logit); + llama_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -175,7 +160,12 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = 
llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
 
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
@@ -247,7 +237,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s, pooling_type);
+        batch_add_seq(batch, inp, s);
         s += 1;
     }
 
@@ -270,7 +260,7 @@ int main(int argc, char ** argv) {
     std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);
 
     struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
-    batch_add_seq(query_batch, query_tokens, 0, pooling_type);
+    batch_add_seq(query_batch, query_tokens, 0);
 
     std::vector<float> query_emb(n_embd, 0);
     batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
diff --git a/llama.cpp b/llama.cpp
index 42be6bed95e5f..ef48e3c244d9b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -11779,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
 
@@ -12166,11 +12166,13 @@ static int llama_decode_internal(
     std::vector<std::vector<llama_seq_id>> seq_id;
 
     // count outputs
-    if (batch_all.logits) {
+    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        n_outputs = n_tokens_all;
+    } else if (batch_all.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+    } else if (lctx.logits_all) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
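
Taken together, the series leaves embedding extraction working as follows: a context with cparams.embeddings enabled builds its graph through append_pooling, only the pooled result_embd_pooled tensor is computed for non-NONE pooling, and the result is read back per sequence with llama_get_embeddings_seq. The sketch below illustrates that flow against the API as it stands after PATCH 6/6; it is not part of the patches, it leans on the common.h helpers already used by the examples above, and the model path and prompt are placeholders.

// minimal usage sketch (illustrative only, not part of the patches)
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path

    // request pooled embeddings; LLAMA_POOLING_TYPE_LAST is the pooling mode added in this series
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;
    cparams.pooling_type = LLAMA_POOLING_TYPE_LAST;
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // tokenize one prompt into sequence 0; with non-NONE pooling the per-token
    // logits flags no longer matter, so simply mark every token
    std::vector<llama_token> toks = llama_tokenize(ctx, "hello world", true); // placeholder prompt
    llama_batch batch = llama_batch_init((int32_t) toks.size(), 0, 1);
    for (size_t i = 0; i < toks.size(); i++) {
        llama_batch_add(batch, toks[i], i, { 0 }, true);
    }

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode failed\n");
        return 1;
    }

    // non-NONE pooling produces one embedding per sequence
    const float * emb = llama_get_embeddings_seq(ctx, 0);
    const int n_embd  = llama_n_embd(model);
    if (emb != NULL) {
        printf("pooled embedding: dim %d, first component %f\n", n_embd, emb[0]);
    }

    // the same context can later be flipped back to logits for generation,
    // the way gritlm.cpp does after PATCH 5/6
    llama_set_embeddings(ctx, false);
    llama_set_causal_attn(ctx, true);

    llama_batch_free(batch);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}

Note that after PATCH 6/6 the logits flags in the batch are ignored whenever pooling is active (n_outputs is forced to the full token count), which is why this sketch, like the final embedding.cpp and retrieval.cpp, simply passes true for every token.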