From 9218cd9fff419842b6ccef17ab9613876f59ae54 Mon Sep 17 00:00:00 2001
From: hiro
Date: Tue, 23 Jan 2024 23:59:04 +0700
Subject: [PATCH 1/2] feat: Add input as vector

---
 controllers/llamaCPP.cc | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 41ce65e2f..ee69f35b0 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -408,10 +408,22 @@ void llamaCPP::embedding(
 
   json prompt;
   if (jsonBody->isMember("input") != 0) {
-    prompt = (*jsonBody)["input"].asString();
+    if ((*jsonBody)["input"].isString()) {
+      prompt = (*jsonBody)["input"].asString();
+    } else if ((*jsonBody)["input"].isArray()) {
+      const auto &inputArray = (*jsonBody)["input"];
+      std::vector<std::string> inputStrings;
+      for (const auto &input : inputArray) {
+        if (input.isString()) {
+          inputStrings.push_back(input.asString());
+        }
+      }
+      prompt = inputStrings;
+    }
   } else {
     prompt = "";
   }
+
   const int task_id = llama.request_completion(
       {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
   task_result result = llama.next_result(task_id);

From 5f9a0a48a11c3930f9f33cc7bc46421f0a17ccd5 Mon Sep 17 00:00:00 2001
From: hiro
Date: Wed, 24 Jan 2024 08:20:46 +0700
Subject: [PATCH 2/2] fix: Final update for embedding to support both single
 and vector of input string

---
 controllers/llamaCPP.cc | 71 +++++++++++++++++++----------------------
 1 file changed, 32 insertions(+), 39 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index ee69f35b0..63044114e 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -21,13 +21,8 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
 
 // --------------------------------------------
 
-std::string create_embedding_payload(const std::vector<float> &embedding,
+Json::Value create_embedding_payload(const std::vector<float> &embedding,
                                      int prompt_tokens) {
-  Json::Value root;
-
-  root["object"] = "list";
-
-  Json::Value dataArray(Json::arrayValue);
   Json::Value dataItem;
 
   dataItem["object"] = "embedding";
@@ -39,20 +34,7 @@ std::string create_embedding_payload(const std::vector<float> &embedding,
   dataItem["embedding"] = embeddingArray;
   dataItem["index"] = 0;
 
-  dataArray.append(dataItem);
-  root["data"] = dataArray;
-
-  root["model"] = "_";
-
-  Json::Value usage;
-  usage["prompt_tokens"] = prompt_tokens;
-  usage["total_tokens"] = prompt_tokens; // Assuming total tokens equals prompt
-                                         // tokens in this context
-  root["usage"] = usage;
-
-  Json::StreamWriterBuilder writer;
-  writer["indentation"] = ""; // Compact output
-  return Json::writeString(writer, root);
+  return dataItem;
 }
 
 std::string create_full_return_json(const std::string &id,
@@ -406,31 +388,42 @@ void llamaCPP::embedding(
     std::function<void(const HttpResponsePtr &)> &&callback) {
   const auto &jsonBody = req->getJsonObject();
 
-  json prompt;
-  if (jsonBody->isMember("input") != 0) {
-    if ((*jsonBody)["input"].isString()) {
-      prompt = (*jsonBody)["input"].asString();
-    } else if ((*jsonBody)["input"].isArray()) {
-      const auto &inputArray = (*jsonBody)["input"];
-      std::vector<std::string> inputStrings;
-      for (const auto &input : inputArray) {
-        if (input.isString()) {
-          inputStrings.push_back(input.asString());
+  Json::Value responseData(Json::arrayValue);
+
+  if (jsonBody->isMember("input")) {
+    const Json::Value &input = (*jsonBody)["input"];
+    if (input.isString()) {
+      // Process the single string input
+      const int task_id = llama.request_completion(
+          {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
+      task_result result = llama.next_result(task_id);
+      std::vector<float> embedding_result = result.result_json["embedding"];
+      responseData.append(create_embedding_payload(embedding_result, 0));
+    } else if (input.isArray()) {
+      // Process each element in the array input
+      for (const auto &elem : input) {
+        if (elem.isString()) {
+          const int task_id = llama.request_completion(
+              {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
+          task_result result = llama.next_result(task_id);
+          std::vector<float> embedding_result = result.result_json["embedding"];
+          responseData.append(create_embedding_payload(embedding_result, 0));
         }
       }
-      prompt = inputStrings;
     }
-  } else {
-    prompt = "";
   }
 
-  const int task_id = llama.request_completion(
-      {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
-  task_result result = llama.next_result(task_id);
-  std::vector<float> embedding_result = result.result_json["embedding"];
   auto resp = nitro_utils::nitroHttpResponse();
-  std::string embedding_resp = create_embedding_payload(embedding_result, 0);
-  resp->setBody(embedding_resp);
+  Json::Value root;
+  root["data"] = responseData;
+  root["model"] = "_";
+  root["object"] = "list";
+  Json::Value usage;
+  usage["prompt_tokens"] = 0;
+  usage["total_tokens"] = 0;
+  root["usage"] = usage;
+
+  resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
   resp->setContentTypeString("application/json");
   callback(resp);
   return;
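
For reference, below is a minimal standalone sketch of the response body that
the patched embedding() handler now assembles. It is an illustration, not part
of the patches: it assumes only jsoncpp and a C++11 compiler, the embedding
values are invented, and create_embedding_payload() is adapted from the patch
above (the loop that fills embeddingArray falls outside the visible hunks, so
that part is a best guess).

#include <json/json.h>

#include <iostream>
#include <vector>

// Adapted from the patch: wrap one embedding vector in an OpenAI-style
// "embedding" item. In this revision the index is always 0 and the
// prompt_tokens parameter is no longer used inside the function.
Json::Value create_embedding_payload(const std::vector<float> &embedding,
                                     int prompt_tokens) {
  Json::Value dataItem;

  dataItem["object"] = "embedding";

  // Reconstructed from context (not visible in the hunks above): copy the
  // float vector into a JSON array.
  Json::Value embeddingArray(Json::arrayValue);
  for (const float value : embedding) {
    embeddingArray.append(value);
  }
  dataItem["embedding"] = embeddingArray;
  dataItem["index"] = 0;

  return dataItem;
}

int main() {
  // Stand-ins for result.result_json["embedding"] from two llama.cpp tasks.
  const std::vector<std::vector<float>> results = {{0.1f, 0.2f}, {0.3f, 0.4f}};

  // Mirror the patched handler: append one item per input string.
  Json::Value responseData(Json::arrayValue);
  for (const auto &embedding : results) {
    responseData.append(create_embedding_payload(embedding, 0));
  }

  // Top-level body built exactly as in the patched embedding().
  Json::Value root;
  root["data"] = responseData;
  root["model"] = "_";
  root["object"] = "list";
  Json::Value usage;
  usage["prompt_tokens"] = 0;
  usage["total_tokens"] = 0;
  root["usage"] = usage;

  std::cout << Json::writeString(Json::StreamWriterBuilder(), root) << "\n";
  return 0;
}

This prints an OpenAI-compatible "list" payload with one "embedding" item per
input. Note that in this revision usage.prompt_tokens and usage.total_tokens
are hardcoded to 0 and every item's index is 0.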