diff --git a/README.md b/README.md index 45c5d06f3e10e..f754022de894d 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ### Recent API changes +- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796 - [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849 ### Hot topics diff --git a/common/common.cpp b/common/common.cpp index 036a981349a69..c244db6443eaa 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1292,7 +1292,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; cparams.seed = params.seed; cparams.logits_all = params.logits_all; - cparams.embedding = params.embedding; + cparams.embeddings = params.embedding; cparams.rope_scaling_type = params.rope_scaling_type; cparams.rope_freq_base = params.rope_freq_base; cparams.rope_freq_scale = params.rope_freq_scale; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index acff715e99d05..ff5883da6ba27 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -19,11 +19,11 @@ static std::vector split_lines(const std::string & s) { static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id) { for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, false); + llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); } } -static void normalize(float * vec, float * out, int n) { +static void normalize(const float * vec, float * out, int n) { float norm = 0; for (int i = 0; i < n; i++) { norm += vec[i] * vec[i]; @@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } // normalize on copy - for (int k = 0; k < n_seq; k++) { - float * emb = llama_get_embeddings_ith(ctx, k); - float * out = output + k * n_embd; - normalize(emb, out, n_embd); + for (int i = 0; i < batch.n_tokens; i++) { + if (!batch.logits[i]) { + continue; + } + + // try to get sequence embeddings - supported only when pooling_type is not NONE + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + if (embd == NULL) { + fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i); + continue; + } + } + + float * out = output + batch.seq_id[i][0] * n_embd; + normalize(embd, out, n_embd); } } @@ -132,7 +145,7 @@ int main(int argc, char ** argv) { // initialize batch const int n_prompts = prompts.size(); - struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts); + struct llama_batch batch = llama_batch_init(n_batch, 0, 1); // allocate output const int n_embd = llama_n_embd(model); @@ -145,6 +158,7 @@ int main(int argc, char ** argv) { for (int k = 0; k < n_prompts; k++) { // clamp to n_batch tokens auto & inp = inputs[k]; + const uint64_t n_toks = inp.size(); // encode if at capacity diff --git a/examples/server-embd.py b/examples/server-embd.py new file mode 100644 index 0000000000000..c5c4ea87b09fc --- /dev/null +++ b/examples/server-embd.py @@ -0,0 +1,34 @@ +import asyncio +import requests +import numpy as np + +n = 8 + +result = [] + +async def requests_post_async(*args, **kwargs): + return await asyncio.to_thread(requests.post, *args, **kwargs) + +async def main(): + model_url = "http://127.0.0.1:6900" + responses: list[requests.Response] = await asyncio.gather(*[requests_post_async( + url= f"{model_url}/embedding", + json= {"content": str(i)*1024} + ) for i in range(n)]) + + for response in responses: + embedding = response.json()["embedding"] + print(embedding[-8:]) + result.append(embedding) + +asyncio.run(main()) + +# compute cosine similarity + +for i in range(n-1): + for j in range(i+1, n): + embedding1 = np.array(result[i]) + embedding2 = np.array(result[j]) + similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) + print(f"Similarity between {i} and {j}: {similarity:.2f}") + diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 208edd571cb0e..8fe5e0b19668f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1210,7 +1210,7 @@ struct llama_server_context queue_results.send(res); } - void send_embedding(server_slot &slot) + void send_embedding(server_slot & slot, const llama_batch & batch) { task_result res; res.id = slot.task_id; @@ -1219,6 +1219,7 @@ struct llama_server_context res.stop = true; const int n_embd = llama_n_embd(model); + if (!params.embedding) { LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}}); @@ -1229,12 +1230,29 @@ struct llama_server_context } else { - const float *data = llama_get_embeddings(ctx); - std::vector embedding(data, data + n_embd); - res.result_json = json - { - {"embedding", embedding}, - }; + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + if (embd == NULL) { + LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}}); + res.result_json = json + { + {"embedding", std::vector(n_embd, 0.0f)}, + }; + continue; + } + } + + res.result_json = json + { + {"embedding", std::vector(embd, embd + n_embd)}, + }; + } } queue_results.send(res); } @@ -1845,7 +1863,7 @@ struct llama_server_context ga_i += ga_w/ga_n; } } - llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false); + llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false); slot_npast++; } @@ -1881,7 +1899,7 @@ struct llama_server_context for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); for (auto & slot : slots) { @@ -1954,7 +1972,7 @@ struct llama_server_context // prompt evaluated for embedding if (slot.embedding) { - send_embedding(slot); + send_embedding(slot, batch_view); slot.release(); slot.i_batch = -1; continue; @@ -2036,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); + printf(" --pooling {none,mean,cls}\n"); + printf(" pooling type for embeddings, use model default if unspecified\n"); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); @@ -2276,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.yarn_beta_slow = std::stof(argv[i]); } + else if (arg == "--pooling") + { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } + else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else { invalid_param = true; break; } + } else if (arg == "--threads" || arg == "-t") { if (++i >= argc) @@ -2330,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } params.n_batch = std::stoi(argv[i]); - params.n_batch = std::min(512, params.n_batch); } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { diff --git a/llama.cpp b/llama.cpp index de579d9e372b4..76afcbc135f4c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1665,7 +1665,7 @@ struct llama_hparams { }; struct llama_cparams { - uint32_t n_ctx; // context size used during inference + uint32_t n_ctx; // context size used during inference uint32_t n_batch; uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing @@ -1682,7 +1682,9 @@ struct llama_cparams { float yarn_beta_slow; float defrag_thold; + bool embeddings; bool offload_kqv; + enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; @@ -1972,7 +1974,7 @@ struct llama_context { int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) int32_t n_eval = 0; // number of eval calls - // decode output (2-dimensional array: [n_tokens][n_vocab]) + // logits output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; #ifndef NDEBUG // guard against access to unset logits @@ -1980,8 +1982,13 @@ struct llama_context { #endif bool logits_all = false; - // input embedding (1-dimensional array: [n_embd]) - std::vector embedding; + // embeddings output (2-dimensional array: [n_tokens][n_embd]) + // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE + std::vector embd; + + // sequence embeddings output (map of [n_embd] vectors) + // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE + std::map> embd_seq; // memory buffers used to evaluate the model std::vector buf_compute_meta; @@ -5092,6 +5099,7 @@ static struct ggml_tensor * llm_build_kv( llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); struct ggml_tensor * cur; + cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); @@ -6085,6 +6093,7 @@ struct llm_build_context { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); struct ggml_tensor * cur; @@ -6092,9 +6101,10 @@ struct llm_build_context { // get input vectors with right size const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type); - struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); + + struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0); - struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0); + struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0); // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); @@ -6112,39 +6122,38 @@ struct llm_build_context { cb(inpL, "inp_norm", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); - cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens] + struct ggml_tensor * KQ_mask = ggml_cont(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_tokens, n_tokens, n_tokens*ggml_type_size(lctx.inp_KQ_mask->type), 0)); + cb(KQ_mask, "KQ_mask", -1); // [n_tokens, n_tokens] // iterate layers for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * cur = inpL; + struct ggml_tensor * Qcur; + struct ggml_tensor * Kcur; + struct ggml_tensor * Vcur; + // self-attention if (model.arch == LLM_ARCH_BERT) { - struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); + Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); + Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); + Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); cb(Vcur, "Vcur", il); - // seems like we just need to do this for Q? - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); - cb(cur, "kqv_out", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } else { // compute Q and K and RoPE them cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); - struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -6163,12 +6172,40 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); + } - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); - cb(cur, "kqv_out", il); + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + if (model.layers[il].bo) { + cb(cur, "kqv_wo", il); + } + + if (model.layers[il].bo) { + cur = ggml_add(ctx0, cur, model.layers[il].bo); } + cb(cur, "kqv_out", il); // re-add the layer input cur = ggml_add(ctx0, cur, inpL); @@ -6209,16 +6246,29 @@ struct llm_build_context { // final output cur = inpL; + cb(cur, "result_embd", -1); // pooling layer - if (pooling_type == LLAMA_POOLING_TYPE_MEAN) { - cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); - } else if (pooling_type == LLAMA_POOLING_TYPE_CLS) { - cur = ggml_get_rows(ctx0, cur, inp_cls); - } else { - GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type"); + switch (pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // nop + } break; + case LLAMA_POOLING_TYPE_MEAN: + { + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); + cb(cur, "result_embd_pooled", -1); + } break; + case LLAMA_POOLING_TYPE_CLS: + { + cur = ggml_get_rows(ctx0, cur, inp_cls); + cb(cur, "result_embd_pooled", -1); + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ASSERT(false && "Invalid pooling type"); + } break; } - cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); @@ -7980,7 +8030,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } - { + if (hparams.causal_attn) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = batch.n_tokens; @@ -7995,16 +8045,40 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || - (hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) { + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { f = -INFINITY; } else { - f = 0; + f = 0.0f; } data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } } + } else { + // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used) + const int64_t n_tokens = batch.n_tokens; + + assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); + + float * data = (float *) lctx.inp_KQ_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = 0; i < n_tokens; ++i) { + float f = -INFINITY; + for (int s = 0; s < batch.n_seq_id[i]; ++s) { + if (batch.seq_id[i][s] == seq_id) { + f = 0.0f; + break; + } + } + + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f; + } + } + } } if (hparams.need_kq_pos) { @@ -8023,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); - float * data = (float *) lctx.inp_mean->data; + float * data = (float *) lctx.inp_mean->data; memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); std::vector sum(n_tokens, 0); for (int i = 0; i < n_tokens; ++i) { const llama_seq_id seq_id = batch.seq_id[i][0]; + + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); + sum[seq_id] += 1; } @@ -8051,11 +8128,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); + uint32_t * data = (uint32_t *) lctx.inp_cls->data; + memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); for (int i = 0; i < n_tokens; ++i) { const llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_pos pos = batch.pos[i]; + const llama_pos pos = batch.pos[i]; + + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); + if (pos == 0) { data[seq_id] = i; } @@ -8169,23 +8251,26 @@ static int llama_decode_internal( batch.seq_id = seq_id_arr.data(); } - llama_kv_cache_update(&lctx); + // non-causal masks do not use the KV cache + if (hparams.causal_attn) { + llama_kv_cache_update(&lctx); - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (kv_self.head > kv_self.used + 2*n_tokens) { + kv_self.head = 0; + } - if (!llama_kv_cache_find_slot(kv_self, batch)) { - return 1; - } + if (!llama_kv_cache_find_slot(kv_self, batch)) { + return 1; + } - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); - //kv_self.n = llama_kv_cache_cell_max(kv_self); + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + //kv_self.n = llama_kv_cache_cell_max(kv_self); + } //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); @@ -8195,20 +8280,26 @@ static int llama_decode_internal( ggml_cgraph * gf = llama_build_graph(lctx, batch, false); // the output is always the last tensor in the graph - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - - if (strcmp(res->name, "result_output") == 0) { - // the embeddings could be the second to last tensor, or the third to last tensor - if (strcmp(embeddings->name, "result_norm") != 0) { - embeddings = gf->nodes[gf->n_nodes - 3]; - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); - } - } else if (strcmp(res->name, "result_embd") == 0) { - embeddings = res; - res = nullptr; + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2]; + + if (!hparams.causal_attn) { + res = nullptr; // do not extract logits for embedding models such as BERT + + // token or sequence embeddings + embd = gf->nodes[gf->n_nodes - 1]; + + GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0); } else { - GGML_ASSERT(false); + if (strcmp(res->name, "result_output") == 0) { + // the token embeddings could be the second to last tensor, or the third to last tensor + if (strcmp(embd->name, "result_norm") != 0) { + embd = gf->nodes[gf->n_nodes - 3]; + GGML_ASSERT(strcmp(embd->name, "result_norm") == 0); + } + } else { + GGML_ASSERT(false && "missing result_output tensor"); + } } // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -8275,46 +8366,82 @@ static int llama_decode_internal( logits_out.clear(); #endif - ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res); - GGML_ASSERT(res_backend != nullptr); + ggml_backend_t backend_res = ggml_backend_sched_get_node_backend(lctx.sched, res); + GGML_ASSERT(backend_res != nullptr); + if (batch.logits) { logits_out.resize(n_vocab * n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { if (batch.logits[i] == 0) { continue; } - ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float)); + ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float)); #ifndef NDEBUG logits_valid[i] = true; #endif } } else if (lctx.logits_all) { logits_out.resize(n_vocab * n_tokens); - ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float)); + ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float)); #ifndef NDEBUG std::fill(logits_valid.begin(), logits_valid.end(), true); #endif } else { logits_out.resize(n_vocab); - ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float)); + ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float)); #ifndef NDEBUG logits_valid[0] = true; #endif } - ggml_backend_synchronize(res_backend); + ggml_backend_synchronize(backend_res); } // extract embeddings - if (!lctx.embedding.empty()) { - auto & embedding_out = lctx.embedding; + if (cparams.embeddings && embd) { + ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd); + GGML_ASSERT(backend_embd != nullptr); - const int64_t embd_pos = res ? n_embd * (n_tokens-1) : 0; - const int64_t embd_size = res ? n_embd : n_embd * n_tokens; + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + auto & embd_out = lctx.embd; + + if (batch.logits) { + embd_out.resize(n_embd * n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } - embedding_out.resize(embd_size); - ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); - ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float)); - ggml_backend_synchronize(embeddings_backend); + ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float)); + } + } + } break; + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_MEAN: + { + GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0); + + // extract sequence embeddings + auto & embd_seq_out = lctx.embd_seq; + embd_seq_out.clear(); + + for (uint32_t i = 0; i < n_tokens; i++) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ASSERT(false && "unknown pooling type"); + } break; + } + ggml_backend_synchronize(backend_embd); } // measure the performance only for the single-token evals @@ -8608,19 +8735,19 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { GGML_ASSERT(llama_is_byte_token(vocab, id)); const auto& token_data = vocab.id_to_token.at(id); switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: { - auto buf = token_data.text.substr(3, 2); - return strtol(buf.c_str(), NULL, 16); - } - case LLAMA_VOCAB_TYPE_BPE: { - GGML_ASSERT(false); - return unicode_to_bytes_bpe(token_data.text); - } - case LLAMA_VOCAB_TYPE_WPM: { - GGML_ASSERT(false); - } - default: - GGML_ASSERT(false); + case LLAMA_VOCAB_TYPE_SPM: { + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); + } + case LLAMA_VOCAB_TYPE_BPE: { + GGML_ASSERT(false); + return unicode_to_bytes_bpe(token_data.text); + } + case LLAMA_VOCAB_TYPE_WPM: { + GGML_ASSERT(false); + } + default: + GGML_ASSERT(false); } } @@ -11864,7 +11991,7 @@ struct llama_context_params llama_context_default_params() { /*.type_k =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16, /*.logits_all =*/ false, - /*.embedding =*/ false, + /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, @@ -12015,6 +12142,7 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_beta_fast = params.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.defrag_thold = params.defrag_thold; + cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.pooling_type = params.pooling_type; @@ -12192,8 +12320,8 @@ struct llama_context * llama_new_context_with_model( // resized during inference, reserve maximum ctx->logits.reserve(hparams.n_vocab*cparams.n_batch); - if (params.embedding) { - ctx->embedding.resize(hparams.n_embd); + if (params.embeddings) { + ctx->embd.reserve(hparams.n_embd*cparams.n_batch); } // graph inputs @@ -12628,7 +12756,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) { // assume worst case for logits although only currently set ones are serialized const size_t s_logits = ctx->logits.capacity() * sizeof(float); const size_t s_embedding_size = sizeof(size_t); - const size_t s_embedding = ctx->embedding.size() * sizeof(float); + const size_t s_embedding = ctx->embd.capacity() * sizeof(float); const size_t s_kv_buf_size = sizeof(size_t); const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); @@ -12737,12 +12865,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat // copy embeddings { - const size_t embedding_size = ctx->embedding.size(); + const size_t embeddings_size = ctx->embd.size(); - data_ctx->write(&embedding_size, sizeof(embedding_size)); + data_ctx->write(&embeddings_size, sizeof(embeddings_size)); - if (embedding_size) { - data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float)); + if (embeddings_size) { + data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float)); } } @@ -12846,15 +12974,17 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { // set embeddings { - size_t embedding_size; + size_t embeddings_size; + + memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size); - memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); + GGML_ASSERT(ctx->embd.capacity() == embeddings_size); - GGML_ASSERT(ctx->embedding.capacity() == embedding_size); + if (embeddings_size) { + ctx->embd.resize(embeddings_size); - if (embedding_size) { - memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); - inp += embedding_size * sizeof(float); + memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float)); + inp += embeddings_size * sizeof(float); } } @@ -13104,11 +13234,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { } float * llama_get_embeddings(struct llama_context * ctx) { - return ctx->embedding.data(); + return ctx->embd.data(); } float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { - return ctx->embedding.data() + i*ctx->model.hparams.n_embd; + return ctx->embd.data() + i*ctx->model.hparams.n_embd; +} + +float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) { + auto it = ctx->embd_seq.find(seq_id); + if (it == ctx->embd_seq.end()) { + return nullptr; + } + + return it->second.data(); } const char * llama_token_get_text(const struct llama_model * model, llama_token token) { diff --git a/llama.h b/llama.h index 70da4cb3f0ff6..3dc162b078d30 100644 --- a/llama.h +++ b/llama.h @@ -163,7 +163,7 @@ extern "C" { // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence // - seq_id : the sequence to which the respective token belongs - // - logits : if zero, the logits for the respective token will not be output + // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output // typedef struct llama_batch { int32_t n_tokens; @@ -173,7 +173,7 @@ extern "C" { llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; - int8_t * logits; + int8_t * logits; // TODO: rename this to "output" // NOTE: helpers for smooth API transition - can be deprecated in the future // for future-proof code, use the above fields instead and ignore everything below @@ -260,7 +260,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) - bool embedding; // embedding mode only + bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU // Abort callback @@ -655,14 +655,20 @@ extern "C" { // llama_get_logits(ctx) + i*n_vocab LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); - // Get the embeddings for the input - // shape: [n_embd] (1-dimensional) + // Get all output token embeddings + // shape: [n_tokens*n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); - // Get the embeddings for the ith sequence + // Get the embeddings for the ith token // llama_get_embeddings(ctx) + i*n_embd + // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); + // Get the embeddings for a sequence id + // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE + // shape: [n_embd] (1-dimensional) + LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + // // Vocab //