From 73b3a7512159eb3f3e3b018fd4f129f85d9bde2a Mon Sep 17 00:00:00 2001
From: Louis
Date: Thu, 29 Feb 2024 09:34:44 +0700
Subject: [PATCH] refactor: parameter mapping - model entity

---
 CMakeLists.txt                   |   3 +-
 common/base.h                    |   8 +-
 controllers/llamaCPP.cc          | 241 +++++++++++++++----------------
 controllers/llamaCPP.h           |  30 ++--
 models/chat_completion_request.h |  36 +++++
 5 files changed, 173 insertions(+), 145 deletions(-)
 create mode 100644 models/chat_completion_request.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a1fe5c876..1de2f6291 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -78,8 +78,9 @@ endif()
 aux_source_directory(controllers CTL_SRC)
 aux_source_directory(common COMMON_SRC)
 aux_source_directory(context CONTEXT_SRC)
+aux_source_directory(models MODEL_SRC)
 # aux_source_directory(filters FILTER_SRC) aux_source_directory(plugins
-# PLUGIN_SRC) aux_source_directory(models MODEL_SRC)
+# PLUGIN_SRC)
 
 # drogon_create_views(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/views
 # ${CMAKE_CURRENT_BINARY_DIR}) use the following line to create views with
diff --git a/common/base.h b/common/base.h
index 9332e58bf..e87d07488 100644
--- a/common/base.h
+++ b/common/base.h
@@ -1,5 +1,6 @@
 #pragma once
 #include
+#include
 
 using namespace drogon;
 
@@ -8,9 +9,8 @@ class BaseModel {
   virtual ~BaseModel() {}
 
   // Model management
-  virtual void LoadModel(
-      const HttpRequestPtr& req,
-      std::function<void(const HttpResponsePtr&)>&& callback) = 0;
+  virtual void LoadModel(const HttpRequestPtr& req,
+                         std::function<void(const HttpResponsePtr&)>&& callback) = 0;
   virtual void UnloadModel(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) = 0;
@@ -25,7 +25,7 @@ class BaseChatCompletion {
 
   // General chat method
   virtual void ChatCompletion(
-      const HttpRequestPtr& req,
+      inferences::ChatCompletionRequest &&completion,
       std::function<void(const HttpResponsePtr&)>&& callback) = 0;
 };
 
diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 9fec77a51..021c79faa 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -1,10 +1,8 @@
 #include "llamaCPP.h"
-
-#include
 #include
+#include
 #include "log.h"
-#include "utils/nitro_utils.h"
 
 // External
 #include "common.h"
@@ -175,19 +173,18 @@ void llamaCPP::WarmupModel() {
 }
 
 void llamaCPP::ChatCompletion(
-    const HttpRequestPtr& req,
+    inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  const auto& jsonBody = req->getJsonObject();
   // Check if model is loaded
   if (CheckModelLoaded(callback)) {
     // Model is loaded
     // Do Inference
-    InferenceImpl(jsonBody, callback);
+    InferenceImpl(std::move(completion), callback);
   }
 }
 
 void llamaCPP::InferenceImpl(
-    std::shared_ptr<Json::Value> jsonBody,
+    inferences::ChatCompletionRequest&& completion,
    std::function<void(const HttpResponsePtr&)>& callback) {
   std::string formatted_output = pre_prompt;
 
@@ -196,131 +193,131 @@ void llamaCPP::InferenceImpl(
   int no_images = 0;
   // To set default value
-  if (jsonBody) {
-    // Increase number of chats received and clean the prompt
-    no_of_chats++;
-    if (no_of_chats % clean_cache_threshold == 0) {
-      LOG_INFO << "Clean cache threshold reached!";
-      llama.kv_cache_clear();
-      LOG_INFO << "Cache cleaned";
-    }
-
-    // Default values to enable auto caching
-    data["cache_prompt"] = caching_enabled;
-    data["n_keep"] = -1;
-
-    // Passing load value
-    data["repeat_last_n"] = this->repeat_last_n;
-
-    data["stream"] = (*jsonBody).get("stream", false).asBool();
-    data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
-    data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
-    data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
-    data["frequency_penalty"] =
(*jsonBody).get("frequency_penalty", 0).asFloat(); - data["presence_penalty"] = (*jsonBody).get("presence_penalty", 0).asFloat(); - const Json::Value& messages = (*jsonBody)["messages"]; + // Increase number of chats received and clean the prompt + no_of_chats++; + if (no_of_chats % clean_cache_threshold == 0) { + LOG_INFO << "Clean cache threshold reached!"; + llama.kv_cache_clear(); + LOG_INFO << "Cache cleaned"; + } - if (!grammar_file_content.empty()) { - data["grammar"] = grammar_file_content; - }; + // Default values to enable auto caching + data["cache_prompt"] = caching_enabled; + data["n_keep"] = -1; + + // Passing load value + data["repeat_last_n"] = this->repeat_last_n; + + LOG_INFO << "Messages:" << completion.messages.toStyledString(); + LOG_INFO << "Stop:" << completion.stop.toStyledString(); + + data["stream"] = completion.stream; + data["n_predict"] = completion.max_tokens; + data["top_p"] = completion.top_p; + data["temperature"] = completion.temperature; + data["frequency_penalty"] = completion.frequency_penalty; + data["presence_penalty"] = completion.presence_penalty; + const Json::Value& messages = completion.messages; + + if (!grammar_file_content.empty()) { + data["grammar"] = grammar_file_content; + }; + + if (!llama.multimodal) { + for (const auto& message : messages) { + std::string input_role = message["role"].asString(); + std::string role; + if (input_role == "user") { + role = user_prompt; + std::string content = message["content"].asString(); + formatted_output += role + content; + } else if (input_role == "assistant") { + role = ai_prompt; + std::string content = message["content"].asString(); + formatted_output += role + content; + } else if (input_role == "system") { + role = system_prompt; + std::string content = message["content"].asString(); + formatted_output = role + content + formatted_output; - if (!llama.multimodal) { - for (const auto& message : messages) { - std::string input_role = message["role"].asString(); - std::string role; - if (input_role == "user") { - role = user_prompt; - std::string content = message["content"].asString(); - formatted_output += role + content; - } else if (input_role == "assistant") { - role = ai_prompt; - std::string content = message["content"].asString(); - formatted_output += role + content; - } else if (input_role == "system") { - role = system_prompt; - std::string content = message["content"].asString(); - formatted_output = role + content + formatted_output; - - } else { - role = input_role; - std::string content = message["content"].asString(); - formatted_output += role + content; - } + } else { + role = input_role; + std::string content = message["content"].asString(); + formatted_output += role + content; } - formatted_output += ai_prompt; - } else { - data["image_data"] = json::array(); - for (const auto& message : messages) { - std::string input_role = message["role"].asString(); - std::string role; - if (input_role == "user") { - formatted_output += role; - for (auto content_piece : message["content"]) { - role = user_prompt; - - json content_piece_image_data; - content_piece_image_data["data"] = ""; - - auto content_piece_type = content_piece["type"].asString(); - if (content_piece_type == "text") { - auto text = content_piece["text"].asString(); - formatted_output += text; - } else if (content_piece_type == "image_url") { - auto image_url = content_piece["image_url"]["url"].asString(); - std::string base64_image_data; - if (image_url.find("http") != std::string::npos) { - LOG_INFO << "Remote image 
+  // Increase number of chats received and clean the prompt
+  no_of_chats++;
+  if (no_of_chats % clean_cache_threshold == 0) {
+    LOG_INFO << "Clean cache threshold reached!";
+    llama.kv_cache_clear();
+    LOG_INFO << "Cache cleaned";
+  }
+
+  // Default values to enable auto caching
+  data["cache_prompt"] = caching_enabled;
+  data["n_keep"] = -1;
+
+  // Passing load value
+  data["repeat_last_n"] = this->repeat_last_n;
+
+  LOG_INFO << "Messages:" << completion.messages.toStyledString();
+  LOG_INFO << "Stop:" << completion.stop.toStyledString();
+
+  data["stream"] = completion.stream;
+  data["n_predict"] = completion.max_tokens;
+  data["top_p"] = completion.top_p;
+  data["temperature"] = completion.temperature;
+  data["frequency_penalty"] = completion.frequency_penalty;
+  data["presence_penalty"] = completion.presence_penalty;
+  const Json::Value& messages = completion.messages;
+
+  if (!grammar_file_content.empty()) {
+    data["grammar"] = grammar_file_content;
+  };
+
+  if (!llama.multimodal) {
+    for (const auto& message : messages) {
+      std::string input_role = message["role"].asString();
+      std::string role;
+      if (input_role == "user") {
+        role = user_prompt;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      } else if (input_role == "assistant") {
+        role = ai_prompt;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      } else if (input_role == "system") {
+        role = system_prompt;
+        std::string content = message["content"].asString();
+        formatted_output = role + content + formatted_output;
+
+      } else {
+        role = input_role;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      }
+    }
+    formatted_output += ai_prompt;
+  } else {
+    data["image_data"] = json::array();
+    for (const auto& message : messages) {
+      std::string input_role = message["role"].asString();
+      std::string role;
+      if (input_role == "user") {
+        formatted_output += role;
+        for (auto content_piece : message["content"]) {
+          role = user_prompt;
+
+          json content_piece_image_data;
+          content_piece_image_data["data"] = "";
+
+          auto content_piece_type = content_piece["type"].asString();
+          if (content_piece_type == "text") {
+            auto text = content_piece["text"].asString();
+            formatted_output += text;
+          } else if (content_piece_type == "image_url") {
+            auto image_url = content_piece["image_url"]["url"].asString();
+            std::string base64_image_data;
+            if (image_url.find("http") != std::string::npos) {
+              LOG_INFO << "Remote image detected but not supported yet";
+            } else if (image_url.find("data:image") != std::string::npos) {
+              LOG_INFO << "Base64 image detected";
+              base64_image_data = nitro_utils::extractBase64(image_url);
+              LOG_INFO << base64_image_data;
+            } else {
+              LOG_INFO << "Local image detected";
+              nitro_utils::processLocalImage(
+                  image_url, [&](const std::string& base64Image) {
+                    base64_image_data = base64Image;
+                  });
+              LOG_INFO << base64_image_data;
+            }
+            content_piece_image_data["data"] = base64_image_data;
+
+            formatted_output += "[img-" + std::to_string(no_images) + "]";
+            content_piece_image_data["id"] = no_images;
+            data["image_data"].push_back(content_piece_image_data);
+            no_images++;
+          }
+        }
+
+      } else if (input_role == "assistant") {
+        role = ai_prompt;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      } else if (input_role == "system") {
+        role = system_prompt;
+        std::string content = message["content"].asString();
+        formatted_output = role + content + formatted_output;
+
+      } else {
+        role = input_role;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      }
+    }
+    formatted_output += ai_prompt;
+    LOG_INFO << formatted_output;
+  }
+
+  data["prompt"] = formatted_output;
+  for (const auto& stop_word : completion.stop) {
+    stopWords.push_back(stop_word.asString());
+  }
+  // specify default stop words
+  // Ensure success case for chatML
+  stopWords.push_back("<|im_end|>");
+  stopWords.push_back(nitro_utils::rtrim(user_prompt));
+  data["stop"] = stopWords;
 
data["stop"] = stopWords; + data["prompt"] = formatted_output; + for (const auto& stop_word : completion.stop) { + stopWords.push_back(stop_word.asString()); } + // specify default stop words + // Ensure success case for chatML + stopWords.push_back("<|im_end|>"); + stopWords.push_back(nitro_utils::rtrim(user_prompt)); + data["stop"] = stopWords; bool is_streamed = data["stream"]; // Enable full message debugging diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h index eebbb5d91..75e597658 100644 --- a/controllers/llamaCPP.h +++ b/controllers/llamaCPP.h @@ -2,42 +2,37 @@ #if defined(_WIN32) #define NOMINMAX #endif - #pragma once -#define LOG_TARGET stdout #include -#include "stb_image.h" -#include "context/llama_server_context.h" - #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error #define CPPHTTPLIB_NO_EXCEPTIONS 1 #endif #include -#include "common/base.h" -#include "utils/json.hpp" - -// auto generated files (update with ./deps.sh) - #include +#include #include -#include -#include +#include "common/base.h" +#include "context/llama_server_context.h" +#include "stb_image.h" +#include "utils/json.hpp" + +#include "models/chat_completion_request.h" #ifndef SERVER_VERBOSE #define SERVER_VERBOSE 1 #endif - using json = nlohmann::json; using namespace drogon; namespace inferences { + class llamaCPP : public drogon::HttpController, public BaseModel, public BaseChatCompletion, @@ -64,14 +59,13 @@ class llamaCPP : public drogon::HttpController, // PATH_ADD("/llama/chat_completion", Post); METHOD_LIST_END void ChatCompletion( - const HttpRequestPtr& req, + inferences::ChatCompletionRequest &&completion, std::function&& callback) override; void Embedding( const HttpRequestPtr& req, std::function&& callback) override; - void LoadModel( - const HttpRequestPtr& req, - std::function&& callback) override; + void LoadModel(const HttpRequestPtr& req, + std::function&& callback) override; void UnloadModel( const HttpRequestPtr& req, std::function&& callback) override; @@ -101,7 +95,7 @@ class llamaCPP : public drogon::HttpController, trantor::ConcurrentTaskQueue* queue; bool LoadModelImpl(std::shared_ptr jsonBody); - void InferenceImpl(std::shared_ptr jsonBody, + void InferenceImpl(inferences::ChatCompletionRequest&& completion, std::function& callback); void EmbeddingImpl(std::shared_ptr jsonBody, std::function& callback); diff --git a/models/chat_completion_request.h b/models/chat_completion_request.h new file mode 100644 index 000000000..bd802d67e --- /dev/null +++ b/models/chat_completion_request.h @@ -0,0 +1,36 @@ +#pragma once +#include + +namespace inferences { +struct ChatCompletionRequest { + bool stream = false; + int max_tokens = 500; + float top_p = 0.95; + float temperature = 0.8; + float frequency_penalty = 0; + float presence_penalty = 0; + Json::Value stop = Json::Value(Json::arrayValue); + Json::Value messages = Json::Value(Json::arrayValue); +}; +} // namespace inferences + +namespace drogon { +template <> +inline inferences::ChatCompletionRequest fromRequest(const HttpRequest& req) { + auto jsonBody = req.getJsonObject(); + inferences::ChatCompletionRequest completion; + if (jsonBody) { + completion.stream = (*jsonBody).get("stream", false).asBool(); + completion.max_tokens = (*jsonBody).get("max_tokens", 500).asInt(); + completion.top_p = (*jsonBody).get("top_p", 0.95).asFloat(); + completion.temperature = (*jsonBody).get("temperature", 0.8).asFloat(); + completion.frequency_penalty = + (*jsonBody).get("frequency_penalty", 0).asFloat(); + 
+    completion.presence_penalty =
+        (*jsonBody).get("presence_penalty", 0).asFloat();
+    completion.messages = (*jsonBody)["messages"];
+    completion.stop = (*jsonBody)["stop"];
+  }
+  return completion;
+}
+}  // namespace drogon
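
Note: with this patch the chat-completion handler no longer unpacks the request body itself; drogon builds an inferences::ChatCompletionRequest through the fromRequest specialization above and hands it to llamaCPP::ChatCompletion. A minimal sketch of exercising that mapping directly, outside the controller (the JSON values below are illustrative and not part of the patch):

    #include <drogon/HttpRequest.h>
    #include "models/chat_completion_request.h"

    int main() {
      // Build the JSON body a client would POST to the chat-completion route.
      Json::Value body;
      body["stream"] = true;
      body["max_tokens"] = 128;
      body["messages"] = Json::Value(Json::arrayValue);

      auto req = drogon::HttpRequest::newHttpJsonRequest(body);

      // Same conversion drogon applies before invoking the handler; fields
      // missing from the body keep the defaults declared in the struct.
      auto completion =
          drogon::fromRequest<inferences::ChatCompletionRequest>(*req);
      return (completion.stream && completion.max_tokens == 128) ? 0 : 1;
    }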