From 70353752f3377fdf9db4607abcdca0b040924ea2 Mon Sep 17 00:00:00 2001
From: Andrei Ungureanu <15995496+Andrei997@users.noreply.github.com>
Date: Wed, 23 Jul 2025 12:19:10 +0800
Subject: [PATCH] feat: add multi-modal embedding support to server.cpp

---
 tools/server/server.cpp | 479 ++++++++++++++++++++++++++++++----------
 1 file changed, 357 insertions(+), 122 deletions(-)

diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 022b5d0b31034..e840e3d20dd59 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -491,6 +491,7 @@ struct server_task {
                 }
             }
         }
+
         // set reverse prompt from cli args if not set in the request
         if (params.antiprompt.empty()) {
            params.antiprompt = defaults.antiprompt;
@@ -1044,6 +1045,10 @@ struct server_task_result_embd : server_task_result {
     std::vector<std::vector<float>> embedding;
     int32_t n_tokens;
+    int32_t start_image_token_idx = -1; // -1 means no image
+    int32_t end_image_token_idx = -1;
+
+    bool has_stored_embeddings = false;
 
     // OAI-compat fields
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
@@ -1059,18 +1064,34 @@
     }
 
     json to_json_non_oaicompat() {
-        return json {
+        json result = json {
             {"index",     index},
             {"embedding", embedding},
         };
+
+        // Add image indices if this was a multimodal request
+        if (start_image_token_idx != -1) {
+            result["start_image_token_idx"] = start_image_token_idx;
+            result["end_image_token_idx"]   = end_image_token_idx;
+        }
+
+        return result;
     }
 
     json to_json_oaicompat() {
-        return json {
+        json result = json {
             {"index",            index},
             {"embedding",        embedding[0]},
             {"tokens_evaluated", n_tokens},
         };
+
+        // Add image indices for OAI-compat too
+        if (start_image_token_idx != -1) {
+            result["start_image_token_idx"] = start_image_token_idx;
+            result["end_image_token_idx"]   = end_image_token_idx;
+        }
+
+        return result;
     }
 };
@@ -1302,6 +1323,11 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
+    // Fields for storing embeddings when processing multi-modal inputs
+    std::vector<std::vector<float>> stored_pre_image_embeddings;
+    std::vector<std::vector<float>> stored_image_embeddings;
+    bool has_stored_embeddings = false;
+
     bool has_next_token = true;
     bool has_new_line   = false;
     bool truncated      = false;
@@ -1355,6 +1381,11 @@ struct server_slot {
         json_schema = json();
         generated_tool_call_ids.clear();
 
+        // *** NEW: Clear multimodal embedding storage ***
+        stored_pre_image_embeddings.clear();
+        stored_image_embeddings.clear();
+        has_stored_embeddings = false;
+
         // clear speculative decoding stats
         n_draft_total = 0;
         n_draft_accepted = 0;
@@ -2570,6 +2601,10 @@ struct server_context {
     }
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
+        printf("=== send_embedding DEBUG ===\n");
+        printf("batch.n_tokens = %d\n", batch.n_tokens);
+        printf("has_stored_embeddings = %s\n", slot.has_stored_embeddings ? "true" : "false");
"true" : "false"); + auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; @@ -2577,39 +2612,119 @@ struct server_context { res->oaicompat = slot.params.oaicompat; const int n_embd = llama_model_n_embd(model); - - std::vector embd_res(n_embd, 0.0f); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; + printf("n_embd = %d\n", n_embd); + + int pooling_type = llama_pooling_type(slot.ctx); + printf("pooling_type = %d (NONE=%d)\n", pooling_type, LLAMA_POOLING_TYPE_NONE); + + if (slot.has_stored_embeddings) { + // *** MULTIMODAL EMBEDDING ASSEMBLY *** + printf("ASSEMBLY: Multimodal embedding - combining all parts\n"); + + // Part 1: Pre-image text embeddings + for (const auto& pre_embd : slot.stored_pre_image_embeddings) { + res->embedding.push_back(pre_embd); } - const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); - } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + printf("ASSEMBLY: Added %zu pre-image embeddings\n", slot.stored_pre_image_embeddings.size()); + + // *** SET IMAGE START INDEX *** + res->start_image_token_idx = slot.stored_pre_image_embeddings.size(); + + // Part 2: Image embeddings + for (const auto& img_embd : slot.stored_image_embeddings) { + res->embedding.push_back(img_embd); } - if (embd == nullptr) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->embedding.push_back(std::vector(n_embd, 0.0f)); - continue; + printf("ASSEMBLY: Added %zu image embeddings\n", slot.stored_image_embeddings.size()); + + // *** SET IMAGE END INDEX *** + res->end_image_token_idx = slot.stored_pre_image_embeddings.size() + slot.stored_image_embeddings.size() - 1; + + // Part 3: Post-image text embeddings (current batch) + if (pooling_type != LLAMA_POOLING_TYPE_NONE) { + printf("ASSEMBLY: Using sequence-level pooling for post-image text\n"); + const float * embd = llama_get_embeddings_seq(ctx, slot.id); + if (embd != nullptr) { + std::vector embd_res(n_embd); + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + printf("ASSEMBLY: Added pooled post-image embedding\n"); + } else { + printf("ASSEMBLY: ERROR - llama_get_embeddings_seq returned NULL for post-image!\n"); + } + } else { + printf("ASSEMBLY: Using token-level embeddings for post-image text\n"); + int post_image_embeddings = 0; + + // For multimodal, we need to get embeddings from the current batch + for (int pos = 0; pos < batch.n_tokens; ++pos) { + const float * embd = llama_get_embeddings_ith(ctx, pos); + if (embd != nullptr) { + res->embedding.emplace_back(embd, embd + n_embd); + post_image_embeddings++; + if (pos < 3) { + printf("ASSEMBLY: Found post-image embedding at batch pos %d\n", pos); + } + } else { + if (pos < 3) { + printf("ASSEMBLY: No post-image embedding at batch pos %d\n", pos); + } + } + } + printf("ASSEMBLY: Added %d post-image embeddings from current batch\n", post_image_embeddings); } + + printf("ASSEMBLY: Total multimodal embeddings: %zu (pre:%zu + img:%zu + post:current)\n", + res->embedding.size(), + slot.stored_pre_image_embeddings.size(), + slot.stored_image_embeddings.size()); + + printf("ASSEMBLY: Image token indices: start=%d, end=%d\n", + res->start_image_token_idx, res->end_image_token_idx); - // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - 
+
+            // Part 3: Post-image text embeddings (current batch)
+            if (pooling_type != LLAMA_POOLING_TYPE_NONE) {
+                printf("ASSEMBLY: Using sequence-level pooling for post-image text\n");
+                const float * embd = llama_get_embeddings_seq(ctx, slot.id);
+                if (embd != nullptr) {
+                    std::vector<float> embd_res(n_embd);
+                    common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                    res->embedding.push_back(embd_res);
+                    printf("ASSEMBLY: Added pooled post-image embedding\n");
+                } else {
+                    printf("ASSEMBLY: ERROR - llama_get_embeddings_seq returned NULL for post-image!\n");
+                }
+            } else {
+                printf("ASSEMBLY: Using token-level embeddings for post-image text\n");
+                int post_image_embeddings = 0;
+
+                // For multimodal, we need to get embeddings from the current batch
+                for (int pos = 0; pos < batch.n_tokens; ++pos) {
+                    const float * embd = llama_get_embeddings_ith(ctx, pos);
+                    if (embd != nullptr) {
+                        res->embedding.emplace_back(embd, embd + n_embd);
+                        post_image_embeddings++;
+                        if (pos < 3) {
+                            printf("ASSEMBLY: Found post-image embedding at batch pos %d\n", pos);
+                        }
+                    } else {
+                        if (pos < 3) {
+                            printf("ASSEMBLY: No post-image embedding at batch pos %d\n", pos);
+                        }
+                    }
+                }
+                printf("ASSEMBLY: Added %d post-image embeddings from current batch\n", post_image_embeddings);
             }
+
+            printf("ASSEMBLY: Total multimodal embeddings: %zu (pre:%zu + img:%zu + post:current)\n",
+                   res->embedding.size(),
+                   slot.stored_pre_image_embeddings.size(),
+                   slot.stored_image_embeddings.size());
+
+            printf("ASSEMBLY: Image token indices: start=%d, end=%d\n",
+                   res->start_image_token_idx, res->end_image_token_idx);
-            // normalize only when there is pooling
-            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
-                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
-                res->embedding.push_back(embd_res);
-                break;
+        } else {
+            // *** REGULAR TEXT-ONLY EMBEDDING (UNCHANGED LOGIC) ***
+            printf("ASSEMBLY: Text-only embedding - using existing logic\n");
+
+            if (pooling_type != LLAMA_POOLING_TYPE_NONE) {
+                printf("Using sequence-level pooling\n");
+                // Sequence-level pooling - get the pooled embedding for the entire sequence
+                const float * embd = llama_get_embeddings_seq(ctx, slot.id);
+                printf("llama_get_embeddings_seq returned: %p\n", (void*)embd);
+
+                if (embd != nullptr) {
+                    std::vector<float> embd_res(n_embd);
+                    common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                    res->embedding.push_back(embd_res);
+                    printf("Added pooled embedding, size = %zu\n", embd_res.size());
+                } else {
+                    printf("ERROR: llama_get_embeddings_seq returned NULL!\n");
+                }
             } else {
-                res->embedding.emplace_back(embd, embd + n_embd);
+                printf("Using token-level embeddings\n");
+                // Token-level embeddings - get embeddings for each position in the sequence
+                int embeddings_found = 0;
+                for (int pos = 0; pos < slot.n_past; ++pos) {
+                    const float * embd = llama_get_embeddings_ith(ctx, pos);
+                    if (embd != nullptr) {
+                        res->embedding.emplace_back(embd, embd + n_embd);
+                        embeddings_found++;
+                        if (pos < 5 || pos >= slot.n_past - 5) {
+                            printf("Found embedding at pos %d\n", pos);
+                        }
+                    } else {
+                        if (pos < 5 || pos >= slot.n_past - 5) {
+                            printf("No embedding at pos %d\n", pos);
+                        }
+                    }
+                }
+                printf("Total embeddings found: %d out of %d positions\n", embeddings_found, slot.n_past);
             }
         }
 
-        SLT_DBG(slot, "%s", "sending embeddings\n");
+        printf("Final embedding count: %zu\n", res->embedding.size());
+        printf("=== send_embedding END ===\n");
 
         queue_results.send(std::move(res));
     }
@@ -3301,10 +3416,36 @@ struct server_context {
 
                 // check if we should process the image
                 if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
+                    printf("=== IMAGE PROCESSING DEBUG ===\n");
+                    printf("Before process_chunk: slot.n_past = %d, slot.n_prompt_tokens = %d\n",
+                           slot.n_past, slot.n_prompt_tokens);
+
                     // process the image
                     int32_t new_n_past;
                     int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
+
+                    // CAPTURE IMAGE EMBEDDINGS immediately after process_chunk()
+                    if (res == 0 && slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
+                        const int n_embd = llama_model_n_embd(model);
+                        int image_embeddings_found = 0;
+
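+                        // walk the per-position outputs until the first null; assumes
+                        // process_chunk() left the image embeddings resident in the
+                        // context and that one image yields at most 512 positions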
+                        for (int batch_pos = 0; batch_pos < 512; batch_pos++) { // reasonable upper bound
+                            const float * embd = llama_get_embeddings_ith(ctx, batch_pos);
+                            if (embd != nullptr) {
+                                slot.stored_image_embeddings.emplace_back(embd, embd + n_embd);
+                                image_embeddings_found++;
+                            } else {
+                                break; // Stop at first nullptr
+                            }
+                        }
+
+                        printf("STORAGE: Captured %d image embeddings dynamically\n", image_embeddings_found);
+                        slot.has_stored_embeddings = true;
+                    }
+
                     int32_t n_pos = new_n_past - slot.n_past;
+                    printf("process_chunk result: res = %d, old_n_past = %d, new_n_past = %d, n_pos = %d\n",
+                           res, slot.n_past, new_n_past, n_pos);
 
                     if (res != 0) {
                         SLT_ERR(slot, "failed to process image, res = %d\n", res);
@@ -3321,19 +3462,18 @@ struct server_context {
 
                     slot.n_past += n_pos;
                     slot.n_prompt_tokens_processed += n_pos;
+
+                    printf("=== IMAGE PROCESSING END ===\n");
                 }
 
                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                    // get next token to process
                     llama_token cur_tok = slot.prompt_tokens[slot.n_past];
                     if (cur_tok == LLAMA_TOKEN_NULL) {
                         break; // end of text chunk
                     }
 
-                    // embedding requires all tokens in the batch to be output
                     const bool need_embd = server_task_type_need_embd(slot.task_type);
-
                     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
                     slot.cache_tokens.push_back(cur_tok);
@@ -3410,6 +3550,31 @@ struct server_context {
 
             const int ret = llama_decode(ctx, batch_view);
 
+            // NOTE: added this embedding capture to store embeddings retrieved before the image
+            if (ret == 0) {
+                for (auto & slot : slots) {
+                    if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
+                        // Capture pre-image embeddings (only before image processing)
+                        if (!slot.has_stored_embeddings) {
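+                            // runs after every successful decode until the image chunk
+                            // sets has_stored_embeddings; assumes the batch holds only
+                            // this slot's tokens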
"YES" : "NO"); + + if (!has_mtmd && !files.empty()) { + throw std::runtime_error("This server does not support multimodal"); + } + + for (size_t i = 0; i < files.size(); i++) { + printf("EMBEDDINGS: Processing file %zu, size: %zu bytes\n", i, files[i].size()); + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, files[i].data(), files[i].size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image or audio file"); + } + std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); + bmp.set_id(hash.c_str()); + printf("EMBEDDINGS: File %zu processed, hash: %s\n", + i, hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } + + // Process prompt + std::vector inputs; + + if (has_mtmd && !files.empty()) { + // multimodal tokenization + std::string prompt_str; + if (prompt.is_string()) { + prompt_str = prompt.get(); + } else { + prompt_str = prompt.dump(); + } + + printf("EMBEDDINGS: Tokenizing multimodal prompt: \"%.100s%s\"\n", + prompt_str.c_str(), prompt_str.length() > 100 ? "..." : ""); + + mtmd_input_text inp_txt = { + prompt_str.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + + printf("EMBEDDINGS: Calling mtmd_tokenize with %zu bitmaps\n", bitmaps_c_ptr.size()); + + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + + if (tokenized != 0) { + printf("EMBEDDINGS: mtmd_tokenize failed with error: %d\n", tokenized); + throw std::runtime_error("Failed to tokenize prompt"); + } + + printf("EMBEDDINGS: mtmd_tokenize succeeded\n"); + + server_tokens tmp(chunks, true); + printf("EMBEDDINGS: Created server_tokens\n"); + inputs.push_back(std::move(tmp)); + + } else { + // non-multimodal version + printf("EMBEDDINGS: Using non-multimodal tokenization\n"); + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + printf("EMBEDDINGS: Tokenized %zu prompts\n", tokenized_prompts.size()); + + for (auto & p : tokenized_prompts) { + printf("EMBEDDINGS: Prompt tokens: %zu\n", p.size()); + auto tmp = server_tokens(p, ctx_server.mctx != nullptr); + inputs.push_back(std::move(tmp)); + } + } + + printf("EMBEDDINGS: Total inputs created: %zu\n", inputs.size()); + + // Create embedding tasks + std::vector tasks; + std::unordered_set task_ids; + + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + + tasks.push_back(std::move(task)); + } + + task_ids = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + + printf("EMBEDDINGS: Posted %zu tasks to queue\n", tasks.size()); + + // Wait for results + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + if (results.size() == 1) { + // single result + printf("EMBEDDINGS: Returning single result\n"); + res_ok(res, results[0]->to_json()); + } else { + // multiple results (multitask) + printf("EMBEDDINGS: Returning %zu results\n", results.size()); + json arr = json::array(); + for (auto & res 
+            server_tokens tmp(chunks, true);
+            printf("EMBEDDINGS: Created server_tokens\n");
+            inputs.push_back(std::move(tmp));
+
+        } else {
+            // non-multimodal version
+            printf("EMBEDDINGS: Using non-multimodal tokenization\n");
+            auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+            printf("EMBEDDINGS: Tokenized %zu prompts\n", tokenized_prompts.size());
+
+            for (auto & p : tokenized_prompts) {
+                printf("EMBEDDINGS: Prompt tokens: %zu\n", p.size());
+                auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
+                inputs.push_back(std::move(tmp));
+            }
+        }
+
+        printf("EMBEDDINGS: Total inputs created: %zu\n", inputs.size());
+
+        // Create embedding tasks
+        std::vector<server_task> tasks;
+        std::unordered_set<int> task_ids;
+
+        tasks.reserve(inputs.size());
+        for (size_t i = 0; i < inputs.size(); i++) {
+            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
+            task.id = ctx_server.queue_tasks.get_new_id();
+            task.index = i;
+
+            task.prompt_tokens = std::move(inputs[i]);
+            task.params = server_task::params_from_json_cmpl(
+                    ctx_server.ctx,
+                    ctx_server.params_base,
+                    data);
+            task.id_selected_slot = json_value(data, "id_slot", -1);
+
+            // OAI-compat
+            task.params.oaicompat = oaicompat;
+
+            tasks.push_back(std::move(task));
+        }
+
+        task_ids = server_task::get_list_id(tasks);
+        ctx_server.queue_results.add_waiting_tasks(tasks);
+        printf("EMBEDDINGS: Posted %zu tasks to queue\n", tasks.size()); // log before tasks are moved away
+        ctx_server.queue_tasks.post(std::move(tasks));
+
+        // Wait for results
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+            if (results.size() == 1) {
+                // single result
+                printf("EMBEDDINGS: Returning single result\n");
+                res_ok(res, results[0]->to_json());
+            } else {
+                // multiple results (multitask)
+                printf("EMBEDDINGS: Returning %zu results\n", results.size());
+                json arr = json::array();
+                for (auto & res : results) {
+                    arr.push_back(res->to_json());
+                }
+                res_ok(res, arr);
+            }
+        }, [&](const json & error_data) {
+            printf("EMBEDDINGS: Error occurred during processing\n");
+            res_error(res, error_data);
+        }, [&req]() {
+            return !req.has_header("Connection") || req.get_header_value("Connection") != "keep-alive";
+        });
+
+        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+        printf("EMBEDDINGS: Processing completed\n");
+    };
 
+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        std::vector<raw_buffer> files;
+
+        // Parse the request body here to extract images
         const json body = json::parse(req.body);
-
-        // for the shape of input/content, see tokenize_input_prompts()
-        json prompt;
-        if (body.count("input") != 0) {
-            prompt = body.at("input");
-        } else if (body.contains("content")) {
-            oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
-            prompt = body.at("content");
-        } else {
-            res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
-            return;
-        }
-
-        bool use_base64 = false;
-        if (body.count("encoding_format") != 0) {
-            const std::string& format = body.at("encoding_format");
-            if (format == "base64") {
-                use_base64 = true;
-            } else if (format != "float") {
-                res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
-                return;
-            }
-        }
-
-        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
-        for (const auto & tokens : tokenized_prompts) {
-            // this check is necessary for models that do not add BOS token to the input
-            if (tokens.empty()) {
-                res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
-                return;
-            }
-        }
-
-        // create and queue the task
-        json responses = json::array();
-        bool error = false;
-        std::unordered_set<int> task_ids;
-        {
-            std::vector<server_task> tasks;
-            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-                server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
-
-                task.id = ctx_server.queue_tasks.get_new_id();
-                task.index = i;
-                task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
-
-                // OAI-compat
-                task.params.oaicompat = oaicompat;
-
-                tasks.push_back(std::move(task));
+
+        // Handle simple image field for non-OAI endpoint
+        if (body.contains("image")) {
+            std::string image_data = body.at("image");
+            if (string_starts_with(image_data, "data:image/")) {
+                auto parts = string_split(image_data, ',');
+                auto decoded_data = base64_decode(parts[1]);
+                files.push_back(decoded_data);
             }
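+            // only data:image/... payloads are decoded; any other "image" value
+            // is silently ignored and the request proceeds as text-only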
-
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
-        }
-
-        // get the result
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-            for (auto & res : results) {
-                GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
-                responses.push_back(res->to_json());
-            }
-        }, [&](const json & error_data) {
-            res_error(res, error_data);
-            error = true;
-        }, req.is_connection_closed);
-
-        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-
-        if (error) {
-            return;
         }
-
-        // write JSON response
-        json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
-            ? format_embeddings_response_oaicompat(body, responses, use_base64)
-            : json(responses);
-        res_ok(res, root);
-    };
-
-    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
+
+        handle_embeddings_impl(req, res, files, OAICOMPAT_TYPE_NONE);
     };
 
     const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
+        std::vector<raw_buffer> files; // dummy files
+        handle_embeddings_impl(req, res, files, OAICOMPAT_TYPE_EMBEDDING);
     };
 
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
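
A quick manual test of the non-OAI endpoint (illustrative sketch only: the
field names "content" and "image" come from the handler above; <__media__> is
mtmd's default media placeholder and may differ per model; the base64 payload
is a placeholder; the server is assumed to be started with --embeddings plus
an --mmproj projector so the multimodal branch is taken):

  curl http://localhost:8080/embeddings \
      -H "Content-Type: application/json" \
      -d '{
            "content": "Describe this image: <__media__>",
            "image": "data:image/png;base64,<BASE64_PNG_DATA>"
          }'

When an image is present, the response follows to_json_non_oaicompat() above:
an "embedding" array of per-position vectors plus the inclusive
start_image_token_idx / end_image_token_idx pair marking the image span.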