From cb8f25f8fa2ad7cd1c69f89c17a9151aae985187 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 7 Oct 2024 15:25:57 +0700 Subject: [PATCH 01/12] chore: change update to patch --- engine/controllers/models.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engine/controllers/models.h b/engine/controllers/models.h index c2c804170..167d4bb36 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -14,7 +14,7 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::PullModel, "/pull", Post); METHOD_ADD(Models::ListModel, "", Get); METHOD_ADD(Models::GetModel, "/{1}", Get); - METHOD_ADD(Models::UpdateModel, "/{1}", Post); + METHOD_ADD(Models::UpdateModel, "/{1}", Patch); METHOD_ADD(Models::ImportModel, "/import", Post); METHOD_ADD(Models::DeleteModel, "/{1}", Delete); METHOD_ADD(Models::SetModelAlias, "/alias", Post); @@ -24,7 +24,7 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Post); ADD_METHOD_TO(Models::ListModel, "/v1/models", Get); ADD_METHOD_TO(Models::GetModel, "/v1/models/{1}", Get); - ADD_METHOD_TO(Models::UpdateModel, "/v1/models/{1}", Post); + ADD_METHOD_TO(Models::UpdateModel, "/v1/models/{1}", Patch); ADD_METHOD_TO(Models::ImportModel, "/v1/models/import", Post); ADD_METHOD_TO(Models::DeleteModel, "/v1/models/{1}", Delete); ADD_METHOD_TO(Models::SetModelAlias, "/v1/models/alias", Post); From 7f0349ce36e2e42fb51c9a9631185a9ba3ee77f1 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 7 Oct 2024 15:28:25 +0700 Subject: [PATCH 02/12] fix: swagger --- engine/controllers/swagger.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/controllers/swagger.cc b/engine/controllers/swagger.cc index 89733b65c..27d9b7b15 100644 --- a/engine/controllers/swagger.cc +++ b/engine/controllers/swagger.cc @@ -211,7 +211,7 @@ Json::Value SwaggerController::generateOpenAPISpec() { ["message"]["type"] = "string"; // UpdateModel Endpoint - Json::Value& update = spec["paths"]["/v1/models/{model}"]["post"]; + Json::Value& update = spec["paths"]["/v1/models/{model}"]["patch"]; update["summary"] = "Update model details"; update["description"] = "Update various attributes of a model based on the ModelConfig " From 8d8846d1ed41600f5e20cfb990e9c6daa5d9a75f Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 7 Oct 2024 16:57:38 +0700 Subject: [PATCH 03/12] fix: pull api --- engine/services/model_service.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index d9c2aa48f..1b0946133 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -310,7 +310,7 @@ cpp::result ModelService::DownloadModelFromCortexso( } std::string model_id{name + ":" + branch}; - auto on_finished = [&](const DownloadTask& finishedTask) { + auto on_finished = [&, model_id](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; From d626c52e5eda95f94c3fb7a29540f47c81ab1194 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 14 Oct 2024 13:43:43 +0700 Subject: [PATCH 04/12] chore: refactor server controller --- engine/controllers/server.cc | 434 ++++----------------------- engine/controllers/server.h | 65 +--- engine/main.cc | 10 +- engine/services/inference_service.cc | 380 +++++++++++++++++++++++ engine/services/inference_service.h | 76 +++++ engine/utils/cortex_utils.h | 5 - engine/utils/engine_constants.h | 7 +- 7 files changed, 
526 insertions(+), 451 deletions(-) create mode 100644 engine/services/inference_service.cc create mode 100644 engine/services/inference_service.h diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index ffbf2cef3..0c41b0f74 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -9,19 +9,7 @@ using namespace inferences; using json = nlohmann::json; namespace inferences { -namespace { -// Need to change this after we rename repositories -std::string NormalizeEngine(const std::string& engine) { - if (engine == kLlamaEngine) { - return kLlamaRepo; - } else if (engine == kOnnxEngine) { - return kOnnxRepo; - } else if (engine == kTrtLlmEngine) { - return kTrtLlmRepo; - } - return engine; -}; -} // namespace + server::server() { #if defined(_WIN32) SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS); @@ -33,404 +21,118 @@ server::~server() {} void server::ChatCompletion( const HttpRequestPtr& req, std::function&& callback) { - std::string engine_type; - if (!HasFieldInReq(req, "engine")) { - engine_type = kLlamaRepo; - } else { - engine_type = - (*(req->getJsonObject())).get("engine", kLlamaRepo).asString(); - } - - auto ne = NormalizeEngine(engine_type); - - if (!IsEngineLoaded(ne)) { - Json::Value res; - res["message"] = "Engine is not loaded yet"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k409Conflict); + LOG_DEBUG << "Start chat completion"; + auto json_body = req->getJsonObject(); + bool is_stream = (*json_body).get("stream", false).asBool(); + auto q = std::make_shared(); + auto ir = inference_svc_.HandleChatCompletion(q, json_body); + if (ir.has_error()) { + auto err = ir.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err)); + resp->setStatusCode( + static_cast(std::get<0>(err)["status_code"].asInt())); callback(resp); - LOG_WARN << "Engine is not loaded yet"; return; } - - LOG_TRACE << "Start chat completion"; - auto json_body = req->getJsonObject(); - bool is_stream = (*json_body).get("stream", false).asBool(); - auto q = std::make_shared(); - std::get(engines_[ne].engine) - ->HandleChatCompletion(json_body, - [q](Json::Value status, Json::Value res) { - q->push(std::make_pair(status, res)); - }); - LOG_TRACE << "Wait to chat completion responses"; + LOG_DEBUG << "Wait to chat completion responses"; if (is_stream) { ProcessStreamRes(std::move(callback), q); } else { ProcessNonStreamRes(std::move(callback), *q); } - LOG_TRACE << "Done chat completion"; + LOG_DEBUG << "Done chat completion"; } void server::Embedding(const HttpRequestPtr& req, std::function&& callback) { - auto engine_type = - (*(req->getJsonObject())).get("engine", kLlamaRepo).asString(); - auto ne = NormalizeEngine(engine_type); - if (!IsEngineLoaded(ne)) { - Json::Value res; - res["message"] = "Engine is not loaded yet"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k409Conflict); + LOG_TRACE << "Start embedding"; + auto q = std::make_shared(); + auto ir = inference_svc_.HandleEmbedding(q, req->getJsonObject()); + if (ir.has_error()) { + auto err = ir.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err)); + resp->setStatusCode( + static_cast(std::get<0>(err)["status_code"].asInt())); callback(resp); - LOG_WARN << "Engine is not loaded yet"; return; } - - LOG_TRACE << "Start embedding"; - SyncQueue q; - std::get(engines_[ne].engine) - ->HandleEmbedding(req->getJsonObject(), - [&q](Json::Value status, Json::Value res) { - 
q.push(std::make_pair(status, res)); - }); LOG_TRACE << "Wait to embedding"; - ProcessNonStreamRes(std::move(callback), q); + ProcessNonStreamRes(std::move(callback), *q); LOG_TRACE << "Done embedding"; } void server::UnloadModel( const HttpRequestPtr& req, std::function&& callback) { - std::string engine_type; - if (!HasFieldInReq(req, "engine")) { - engine_type = kLlamaRepo; - } else { - engine_type = - (*(req->getJsonObject())).get("engine", kLlamaRepo).asString(); - } - auto ne = NormalizeEngine(engine_type); - - if (!IsEngineLoaded(ne)) { - Json::Value res; - res["message"] = "Engine is not loaded yet"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k409Conflict); - callback(resp); - LOG_WARN << "Engine is not loaded yet"; - return; - } - LOG_TRACE << "Start unload model"; - std::get(engines_[ne].engine) - ->UnloadModel( - req->getJsonObject(), - [cb = std::move(callback)](Json::Value status, Json::Value res) { - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(static_cast( - status["status_code"].asInt())); - cb(resp); - }); - LOG_TRACE << "Done unload model"; + auto ir = inference_svc_.UnloadModel(req->getJsonObject()); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir)); + resp->setStatusCode( + static_cast(std::get<0>(ir)["status_code"].asInt())); + callback(resp); } void server::ModelStatus( const HttpRequestPtr& req, std::function&& callback) { - std::string engine_type; - if (!HasFieldInReq(req, "engine")) { - engine_type = kLlamaRepo; - } else { - engine_type = - (*(req->getJsonObject())).get("engine", kLlamaRepo).asString(); - } - - auto ne = NormalizeEngine(engine_type); - - if (!IsEngineLoaded(ne)) { - Json::Value res; - res["message"] = "Engine is not loaded yet"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k409Conflict); - callback(resp); - LOG_WARN << "Engine is not loaded yet"; - return; - } - - LOG_TRACE << "Start to get model status"; - std::get(engines_[ne].engine) - ->GetModelStatus( - req->getJsonObject(), - [cb = std::move(callback)](Json::Value status, Json::Value res) { - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(static_cast( - status["status_code"].asInt())); - cb(resp); - }); - LOG_TRACE << "Done get model status"; + auto ir = inference_svc_.GetModelStatus(req->getJsonObject()); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir)); + resp->setStatusCode( + static_cast(std::get<0>(ir)["status_code"].asInt())); + callback(resp); } void server::GetModels(const HttpRequestPtr& req, std::function&& callback) { - if (engines_.empty()) { - Json::Value res; - res["message"] = "Engine is not loaded yet"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k409Conflict); - callback(resp); - LOG_WARN << "Engine is not loaded yet"; - return; - } - LOG_TRACE << "Start to get models"; - Json::Value resp_data(Json::arrayValue); - for (auto const& [k, v] : engines_) { - auto e = std::get(v.engine); - if (e->IsSupported("GetModels")) { - e->GetModels(req->getJsonObject(), - [&resp_data](Json::Value status, Json::Value res) { - for (auto r : res["data"]) { - resp_data.append(r); - } - }); - } - } - Json::Value root; - root["data"] = resp_data; - root["object"] = "list"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(root); - resp->setStatusCode(drogon::HttpStatusCode::k200OK); + auto ir = inference_svc_.GetModels(req->getJsonObject()); + auto 
resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir)); + resp->setStatusCode( + static_cast(std::get<0>(ir)["status_code"].asInt())); callback(resp); - LOG_TRACE << "Done get models"; } void server::GetEngines( const HttpRequestPtr& req, std::function&& callback) { - Json::Value res; - Json::Value engine_array(Json::arrayValue); - for (const auto& [s, _] : engines_) { - Json::Value val; - val["id"] = s; - val["object"] = "engine"; - engine_array.append(val); - } - - res["object"] = "list"; - res["data"] = engine_array; - - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + auto ir = inference_svc_.GetEngines(req->getJsonObject()); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ir); callback(resp); } void server::FineTuning( const HttpRequestPtr& req, std::function&& callback) { - auto engine_type = - (*(req->getJsonObject())).get("engine", kPythonRuntimeRepo).asString(); - - if (engines_.find(engine_type) == engines_.end()) { - try { - std::string abs_path = - (getenv("ENGINE_PATH") ? getenv("ENGINE_PATH") - : cortex_utils::GetCurrentPath()) + - cortex_utils::kPythonRuntimeLibPath; - engines_[engine_type].dl = - std::make_unique(abs_path, "engine"); - } catch (const cortex_cpp::dylib::load_error& e) { - - LOG_ERROR << "Could not load engine: " << e.what(); - engines_.erase(engine_type); - - Json::Value res; - res["message"] = "Could not load engine " + engine_type; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k500InternalServerError); - callback(resp); - return; - } - - auto func = engines_[engine_type].dl->get_function( - "get_engine"); - engines_[engine_type].engine = func(); - LOG_INFO << "Loaded engine: " << engine_type; - } - - LOG_TRACE << "Start to fine-tuning"; - auto& en = std::get(engines_[engine_type].engine); - if (en->IsSupported("HandlePythonFileExecutionRequest")) { - en->HandlePythonFileExecutionRequest( - req->getJsonObject(), - [cb = std::move(callback)](Json::Value status, Json::Value res) { - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(static_cast( - status["status_code"].asInt())); - cb(resp); - }); - } else { - Json::Value res; - res["message"] = "Method is not supported yet"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k500InternalServerError); - callback(resp); - LOG_WARN << "Method is not supported yet"; - } + auto ir = inference_svc_.FineTuning(req->getJsonObject()); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir)); + resp->setStatusCode( + static_cast(std::get<0>(ir)["status_code"].asInt())); + callback(resp); LOG_TRACE << "Done fine-tuning"; } void server::LoadModel(const HttpRequestPtr& req, std::function&& callback) { - auto engine_type = - (*(req->getJsonObject())).get("engine", kLlamaRepo).asString(); - - auto ne = NormalizeEngine(engine_type); - - // We have not loaded engine yet, should load it before using it - if (engines_.find(ne) == engines_.end()) { - auto get_engine_path = [](std::string_view e) { - if (e == kLlamaRepo) { - return cortex_utils::kLlamaLibPath; - } else if (e == kOnnxRepo) { - return cortex_utils::kOnnxLibPath; - } else if (e == kTrtLlmRepo) { - return cortex_utils::kTensorrtLlmPath; - } - return cortex_utils::kLlamaLibPath; - }; - - try { - if (ne == kLlamaRepo) { - cortex::cpuid::CpuInfo cpu_info; - LOG_INFO << "CPU instruction set: " << cpu_info.to_string(); - } - - std::string abs_path = - (getenv("ENGINE_PATH") - ? 
getenv("ENGINE_PATH") - : file_manager_utils::GetCortexDataPath().string()) + - get_engine_path(ne); -#if defined(_WIN32) - // TODO(?) If we only allow to load an engine at a time, the logic is simpler. - // We would like to support running multiple engines at the same time. Therefore, - // the adding/removing dll directory logic is quite complicated: - // 1. If llamacpp is loaded and new requested engine is tensorrt-llm: - // Unload the llamacpp dll directory then load the tensorrt-llm - // 2. If tensorrt-llm is loaded and new requested engine is llamacpp: - // Do nothing, llamacpp can re-use tensorrt-llm dependencies (need to be tested careful) - // 3. Add dll directory if met other conditions - - auto add_dll = [this](const std::string& e_type, const std::string& p) { - auto ws = std::wstring(p.begin(), p.end()); - if (auto cookie = AddDllDirectory(ws.c_str()); cookie != 0) { - LOG_INFO << "Added dll directory: " << p; - engines_[e_type].cookie = cookie; - } else { - LOG_WARN << "Could not add dll directory: " << p; - } - }; - - if (IsEngineLoaded(kLlamaRepo) && ne == kTrtLlmRepo) { - // Remove llamacpp dll directory - if (!RemoveDllDirectory(engines_[kLlamaRepo].cookie)) { - LOG_INFO << "Could not remove dll directory: " << kLlamaRepo; - } else { - LOG_WARN << "Removed dll directory: " << kLlamaRepo; - } - - add_dll(ne, abs_path); - } else if (IsEngineLoaded(kTrtLlmRepo) && ne == kLlamaRepo) { - // Do nothing - } else { - add_dll(ne, abs_path); - } -#endif - engines_[ne].dl = std::make_unique(abs_path, "engine"); - - } catch (const cortex_cpp::dylib::load_error& e) { - LOG_ERROR << "Could not load engine: " << e.what(); - engines_.erase(ne); - - Json::Value res; - res["message"] = "Could not load engine " + engine_type; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k500InternalServerError); - callback(resp); - return; - } - cur_engine_type_ = ne; - - auto func = engines_[ne].dl->get_function("get_engine"); - engines_[ne].engine = func(); - - auto& en = std::get(engines_[ne].engine); - if (ne == kLlamaRepo) { //fix for llamacpp engine first - auto config = file_manager_utils::GetCortexConfig(); - if (en->IsSupported("SetFileLogger")) { - en->SetFileLogger(config.maxLogLines, - (std::filesystem::path(config.logFolderPath) / - std::filesystem::path(config.logLlamaCppPath)) - .string()); - } else { - LOG_WARN << "Method SetFileLogger is not supported yet"; - } - } - LOG_INFO << "Loaded engine: " << engine_type; - } - - LOG_TRACE << "Load model"; - auto& en = std::get(engines_[ne].engine); - en->LoadModel(req->getJsonObject(), [cb = std::move(callback)]( - Json::Value status, Json::Value res) { - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode( - static_cast(status["status_code"].asInt())); - cb(resp); - }); + auto ir = inference_svc_.LoadModel(req->getJsonObject()); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir)); + resp->setStatusCode( + static_cast(std::get<0>(ir)["status_code"].asInt())); + callback(resp); LOG_TRACE << "Done load model"; } void server::UnloadEngine( const HttpRequestPtr& req, std::function&& callback) { - std::string engine_type; - if (!HasFieldInReq(req, "engine")) { - engine_type = kLlamaRepo; - } else { - engine_type = - (*(req->getJsonObject())).get("engine", kLlamaRepo).asString(); - } - - auto ne = NormalizeEngine(engine_type); - - if (!IsEngineLoaded(ne)) { - Json::Value res; - res["message"] = "Engine is not loaded yet"; - auto resp = 
cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k409Conflict); - callback(resp); - LOG_WARN << "Engine is not loaded yet"; - return; - } - - EngineI* e = std::get(engines_[ne].engine); - delete e; -#if defined(_WIN32) - if (!RemoveDllDirectory(engines_[ne].cookie)) { - LOG_WARN << "Could not remove dll directory: " << engine_type; - } else { - LOG_INFO << "Removed dll directory: " << engine_type; - } -#endif - engines_.erase(ne); - LOG_INFO << "Unloaded engine " + engine_type; - Json::Value res; - res["message"] = "Unloaded engine " + engine_type; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k200OK); + auto ir = inference_svc_.UnloadEngine(req->getJsonObject()); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir)); + resp->setStatusCode( + static_cast(std::get<0>(ir)["status_code"].asInt())); callback(resp); } void server::ProcessStreamRes(std::function cb, - std::shared_ptr q) { + std::shared_ptr q) { auto err_or_done = std::make_shared(false); auto chunked_content_provider = [q, err_or_done](char* buf, std::size_t buf_size) -> std::size_t { @@ -464,7 +166,7 @@ void server::ProcessStreamRes(std::function cb, } void server::ProcessNonStreamRes(std::function cb, - SyncQueue& q) { + services::SyncQueue& q) { auto [status, res] = q.wait_and_pop(); auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); resp->setStatusCode( @@ -472,32 +174,4 @@ void server::ProcessNonStreamRes(std::function cb, cb(resp); } -bool server::IsEngineLoaded(const std::string& e) { - return engines_.find(e) != engines_.end(); -} - -bool server::HasFieldInReq( - const HttpRequestPtr& req, - std::function& callback, - const std::string& field) { - if (auto o = req->getJsonObject(); !o || (*o)[field].isNull()) { - Json::Value res; - res["message"] = "No " + field + " field in request body"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k409Conflict); - callback(resp); - LOG_WARN << "No " << field << " field in request body"; - return false; - } - return true; -} - -bool server::HasFieldInReq(const HttpRequestPtr& req, - const std::string& field) { - if (auto o = req->getJsonObject(); !o || (*o)[field].isNull()) { - return false; - } - return true; -} - } // namespace inferences diff --git a/engine/controllers/server.h b/engine/controllers/server.h index f1fe89bd5..1eb4203ca 100644 --- a/engine/controllers/server.h +++ b/engine/controllers/server.h @@ -17,9 +17,7 @@ #include #include "common/base.h" -#include "cortex-common/EngineI.h" -#include "cortex-common/cortexpythoni.h" -#include "utils/dylib.h" +#include "services/inference_service.h" #include "utils/json.hpp" #ifndef SERVER_VERBOSE @@ -98,66 +96,11 @@ class server : public drogon::HttpController, private: void ProcessStreamRes(std::function cb, - std::shared_ptr q); + std::shared_ptr q); void ProcessNonStreamRes(std::function cb, - SyncQueue& q); - bool IsEngineLoaded(const std::string& e); - - bool HasFieldInReq(const HttpRequestPtr& req, - std::function& callback, - const std::string& field); - - bool HasFieldInReq(const HttpRequestPtr& req, const std::string& field); - - private: - struct SyncQueue { - void push(std::pair&& p) { - std::unique_lock l(mtx); - q.push(p); - cond.notify_one(); - } - - std::pair wait_and_pop() { - std::unique_lock l(mtx); - cond.wait(l, [this] { return !q.empty(); }); - auto res = q.front(); - q.pop(); - return res; - } - - std::mutex mtx; - std::condition_variable cond; - // Status and result - std::queue> 
q; - }; - struct StreamStatus { - void Done() { - std::unique_lock l(m); - stream_done = true; - cv.notify_all(); - } - - void Wait() { - std::unique_lock l(m); - cv.wait(l, [this] { return stream_done; }); - } - - private: - std::mutex m; - std::condition_variable cv; - bool stream_done = false; - }; + services::SyncQueue& q); private: - using EngineV = std::variant; - struct EngineInfo { - std::unique_ptr dl; - EngineV engine; -#if defined(_WIN32) - DLL_DIRECTORY_COOKIE cookie; -#endif - }; - std::unordered_map engines_; - std::string cur_engine_type_; + services::InferenceService inference_svc_; }; }; // namespace inferences diff --git a/engine/main.cc b/engine/main.cc index 985042a5d..43eecd9ec 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -46,8 +46,10 @@ void RunServer() { std::filesystem::path(config.logFolderPath) / std::filesystem::path(cortex_utils::logs_folder)); trantor::FileLogger asyncFileLogger; - asyncFileLogger.setFileName((std::filesystem::path(config.logFolderPath) / - std::filesystem::path(cortex_utils::logs_base_name)).string()); + asyncFileLogger.setFileName( + (std::filesystem::path(config.logFolderPath) / + std::filesystem::path(cortex_utils::logs_base_name)) + .string()); asyncFileLogger.setMaxLines(config.maxLogLines); // Keep last 100000 lines asyncFileLogger.startLogging(); trantor::Logger::setOutputFunction( @@ -111,8 +113,8 @@ int main(int argc, char* argv[]) { std::string py_home_path = (argc > 3) ? argv[3] : ""; std::unique_ptr dl; try { - std::string abs_path = cortex_utils::GetCurrentPath() + - cortex_utils::kPythonRuntimeLibPath; + std::string abs_path = + cortex_utils::GetCurrentPath() + kPythonRuntimeLibPath; dl = std::make_unique(abs_path, "engine"); } catch (const cortex_cpp::dylib::load_error& e) { LOG_ERROR << "Could not load engine: " << e.what(); diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc new file mode 100644 index 000000000..51457b931 --- /dev/null +++ b/engine/services/inference_service.cc @@ -0,0 +1,380 @@ +#include "inference_service.h" +#include "utils/cpuid/cpu_info.h" +#include "utils/engine_constants.h" +#include "utils/file_manager_utils.h" + +namespace services { + +namespace { +// Need to change this after we rename repositories +std::string NormalizeEngine(const std::string& engine) { + if (engine == kLlamaEngine) { + return kLlamaRepo; + } else if (engine == kOnnxEngine) { + return kOnnxRepo; + } else if (engine == kTrtLlmEngine) { + return kTrtLlmRepo; + } + return engine; +}; +} // namespace + +cpp::result InferenceService::HandleChatCompletion( + std::shared_ptr q, std::shared_ptr json_body) { + std::string engine_type; + if (!HasFieldInReq(json_body, "engine")) { + engine_type = kLlamaRepo; + } else { + engine_type = (*(json_body)).get("engine", kLlamaRepo).asString(); + } + auto ne = NormalizeEngine(engine_type); + if (!IsEngineLoaded(ne)) { + Json::Value res; + res["message"] = "Engine is not loaded yet"; + Json::Value stt; + stt["status_code"] = 409; + LOG_WARN << "Engine is not loaded yet"; + return cpp::fail(std::make_pair(stt, res)); + } + std::get(engines_[ne].engine) + ->HandleChatCompletion(json_body, + [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }); + return {}; +} + +cpp::result InferenceService::HandleEmbedding( + std::shared_ptr q, std::shared_ptr json_body) { + std::string engine_type; + if (!HasFieldInReq(json_body, "engine")) { + engine_type = kLlamaRepo; + } else { + engine_type = (*(json_body)).get("engine", 
kLlamaRepo).asString(); + } + + auto ne = NormalizeEngine(engine_type); + if (!IsEngineLoaded(ne)) { + Json::Value res; + res["message"] = "Engine is not loaded yet"; + Json::Value stt; + stt["status_code"] = 409; + LOG_WARN << "Engine is not loaded yet"; + return cpp::fail(std::make_pair(stt, res)); + } + std::get(engines_["llama-cpp"].engine) + ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }); + return {}; +} + +InferResult InferenceService::LoadModel( + std::shared_ptr json_body) { + std::string engine_type; + if (!HasFieldInReq(json_body, "engine")) { + engine_type = kLlamaRepo; + } else { + engine_type = (*(json_body)).get("engine", kLlamaRepo).asString(); + } + + auto ne = NormalizeEngine(engine_type); + Json::Value r; + Json::Value stt; + // We have not loaded engine yet, should load it before using it + if (engines_.find(ne) == engines_.end()) { + auto get_engine_path = [](std::string_view e) { + if (e == kLlamaRepo) { + return kLlamaLibPath; + } else if (e == kOnnxRepo) { + return kOnnxLibPath; + } else if (e == kTrtLlmRepo) { + return kTensorrtLlmPath; + } + return kLlamaLibPath; + }; + try { + if (ne == kLlamaRepo) { + cortex::cpuid::CpuInfo cpu_info; + LOG_INFO << "CPU instruction set: " << cpu_info.to_string(); + } + + std::string abs_path = + (getenv("ENGINE_PATH") + ? getenv("ENGINE_PATH") + : file_manager_utils::GetCortexDataPath().string()) + + get_engine_path(ne); +#if defined(_WIN32) + // TODO(?) If we only allow to load an engine at a time, the logic is simpler. + // We would like to support running multiple engines at the same time. Therefore, + // the adding/removing dll directory logic is quite complicated: + // 1. If llamacpp is loaded and new requested engine is tensorrt-llm: + // Unload the llamacpp dll directory then load the tensorrt-llm + // 2. If tensorrt-llm is loaded and new requested engine is llamacpp: + // Do nothing, llamacpp can re-use tensorrt-llm dependencies (need to be tested careful) + // 3. 
Add dll directory if met other conditions + + auto add_dll = [this](const std::string& e_type, const std::string& p) { + auto ws = std::wstring(p.begin(), p.end()); + if (auto cookie = AddDllDirectory(ws.c_str()); cookie != 0) { + LOG_INFO << "Added dll directory: " << p; + engines_[e_type].cookie = cookie; + } else { + LOG_WARN << "Could not add dll directory: " << p; + } + }; + + if (IsEngineLoaded(kLlamaRepo) && ne == kTrtLlmRepo) { + // Remove llamacpp dll directory + if (!RemoveDllDirectory(engines_[kLlamaRepo].cookie)) { + LOG_INFO << "Could not remove dll directory: " << kLlamaRepo; + } else { + LOG_WARN << "Removed dll directory: " << kLlamaRepo; + } + + add_dll(ne, abs_path); + } else if (IsEngineLoaded(kTrtLlmRepo) && ne == kLlamaRepo) { + // Do nothing + } else { + add_dll(ne, abs_path); + } +#endif + engines_[ne].dl = std::make_unique(abs_path, "engine"); + + } catch (const cortex_cpp::dylib::load_error& e) { + LOG_ERROR << "Could not load engine: " << e.what(); + engines_.erase(ne); + + r["message"] = "Could not load engine " + ne; + stt["status_code"] = 500; + return std::make_pair(stt, r); + } + + auto func = engines_[ne].dl->get_function("get_engine"); + engines_[ne].engine = func(); + + auto& en = std::get(engines_[ne].engine); + if (ne == kLlamaRepo) { //fix for llamacpp engine first + auto config = file_manager_utils::GetCortexConfig(); + if (en->IsSupported("SetFileLogger")) { + en->SetFileLogger(config.maxLogLines, + (std::filesystem::path(config.logFolderPath) / + std::filesystem::path(config.logLlamaCppPath)) + .string()); + } else { + LOG_WARN << "Method SetFileLogger is not supported yet"; + } + } + LOG_INFO << "Loaded engine: " << ne; + } + + // LOG_TRACE << "Load model"; + auto& en = std::get(engines_[ne].engine); + en->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + return std::make_pair(stt, r); +} + +InferResult InferenceService::UnloadModel( + std::shared_ptr json_body) { + std::string engine_type; + if (!HasFieldInReq(json_body, "engine")) { + engine_type = kLlamaRepo; + } else { + engine_type = (*(json_body)).get("engine", kLlamaRepo).asString(); + } + + auto ne = NormalizeEngine(engine_type); + Json::Value r; + Json::Value stt; + if (!IsEngineLoaded(ne)) { + r["message"] = "Engine is not loaded yet"; + stt["status_code"] = 409; + LOG_WARN << "Engine is not loaded yet"; + return std::make_pair(stt, r); + } + LOG_TRACE << "Start unload model"; + std::get(engines_[ne].engine) + ->UnloadModel(json_body, [&r, &stt](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + return std::make_pair(stt, r); +} + +InferResult InferenceService::GetModelStatus( + std::shared_ptr json_body) { + std::string engine_type; + if (!HasFieldInReq(json_body, "engine")) { + engine_type = kLlamaRepo; + } else { + engine_type = (*(json_body)).get("engine", kLlamaRepo).asString(); + } + + auto ne = NormalizeEngine(engine_type); + Json::Value r; + Json::Value stt; + + if (!IsEngineLoaded(ne)) { + r["message"] = "Engine is not loaded yet"; + stt["status_code"] = 409; + LOG_WARN << "Engine is not loaded yet"; + return std::make_pair(stt, r); + } + + LOG_TRACE << "Start to get model status"; + std::get(engines_[ne].engine) + ->GetModelStatus(json_body, + [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + return std::make_pair(stt, r); +} + +InferResult InferenceService::GetModels( + std::shared_ptr json_body) { + Json::Value r; + Json::Value stt; + if (engines_.empty()) { + 
r["message"] = "Engine is not loaded yet"; + stt["status_code"] = 409; + return std::make_pair(stt, r); + } + + LOG_TRACE << "Start to get models"; + Json::Value resp_data(Json::arrayValue); + for (auto const& [k, v] : engines_) { + auto e = std::get(v.engine); + if (e->IsSupported("GetModels")) { + e->GetModels(json_body, + [&resp_data](Json::Value status, Json::Value res) { + for (auto r : res["data"]) { + resp_data.append(r); + } + }); + } + } + Json::Value root; + root["data"] = resp_data; + root["object"] = "list"; + stt["status_code"] = 200; + return std::make_pair(stt, root); + // LOG_TRACE << "Done get models"; +} + +Json::Value InferenceService::GetEngines( + std::shared_ptr json_body) { + Json::Value res; + Json::Value engine_array(Json::arrayValue); + for (const auto& [s, _] : engines_) { + Json::Value val; + val["id"] = s; + val["object"] = "engine"; + engine_array.append(val); + } + + res["object"] = "list"; + res["data"] = engine_array; + return res; +} + +InferResult InferenceService::FineTuning( + std::shared_ptr json_body) { + std::string ne = kPythonRuntimeRepo; + Json::Value r; + Json::Value stt; + + if (engines_.find(ne) == engines_.end()) { + try { + std::string abs_path = + (getenv("ENGINE_PATH") + ? getenv("ENGINE_PATH") + : file_manager_utils::GetCortexDataPath().string()) + + kPythonRuntimeLibPath; + engines_[ne].dl = std::make_unique(abs_path, "engine"); + } catch (const cortex_cpp::dylib::load_error& e) { + + LOG_ERROR << "Could not load engine: " << e.what(); + engines_.erase(ne); + + Json::Value res; + r["message"] = "Could not load engine " + ne; + stt["status_code"] = 500; + return std::make_pair(stt, r); + } + + auto func = + engines_[ne].dl->get_function("get_engine"); + engines_[ne].engine = func(); + LOG_INFO << "Loaded engine: " << ne; + } + + LOG_TRACE << "Start to fine-tuning"; + auto& en = std::get(engines_[ne].engine); + if (en->IsSupported("HandlePythonFileExecutionRequest")) { + en->HandlePythonFileExecutionRequest( + json_body, [&r, &stt](Json::Value status, Json::Value res) { + r = res; + stt = status; + }); + } else { + LOG_WARN << "Method is not supported yet"; + r["message"] = "Method is not supported yet"; + stt["status_code"] = 500; + return std::make_pair(stt, r); + } + LOG_TRACE << "Done fine-tuning"; + return std::make_pair(stt, r); +} + +InferResult InferenceService::UnloadEngine( + std::shared_ptr json_body) { + std::string engine_type; + if (!HasFieldInReq(json_body, "engine")) { + engine_type = kLlamaRepo; + } else { + engine_type = (*(json_body)).get("engine", kLlamaRepo).asString(); + } + + auto ne = NormalizeEngine(engine_type); + Json::Value r; + Json::Value stt; + + if (!IsEngineLoaded(ne)) { + r["message"] = "Engine is not loaded yet"; + stt["status_code"] = 409; + LOG_WARN << "Engine is not loaded yet"; + return std::make_pair(stt, r); + } + + EngineI* e = std::get(engines_[ne].engine); + delete e; +#if defined(_WIN32) + if (!RemoveDllDirectory(engines_[ne].cookie)) { + LOG_WARN << "Could not remove dll directory: " << ne; + } else { + LOG_INFO << "Removed dll directory: " << ne; + } +#endif + engines_.erase(ne); + LOG_INFO << "Unloaded engine " + ne; + r["message"] = "Unloaded engine " + ne; + stt["status_code"] = 200; + return std::make_pair(stt, r); +} + +bool InferenceService::IsEngineLoaded(const std::string& e) { + return engines_.find(e) != engines_.end(); +} + +bool InferenceService::HasFieldInReq(std::shared_ptr json_body, + const std::string& field) { + if (!json_body || (*json_body)[field].isNull()) { + return false; 
+ } + return true; +} +} // namespace services \ No newline at end of file diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h new file mode 100644 index 000000000..16c8e3b99 --- /dev/null +++ b/engine/services/inference_service.h @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include "common/base.h" +#include "cortex-common/EngineI.h" +#include "cortex-common/cortexpythoni.h" +#include "utils/dylib.h" +#include "utils/json.hpp" +#include "utils/result.hpp" + +namespace services { +// Status and result +using InferResult = std::pair; + +struct SyncQueue { + void push(InferResult&& p) { + std::unique_lock l(mtx); + q.push(p); + cond.notify_one(); + } + + InferResult wait_and_pop() { + std::unique_lock l(mtx); + cond.wait(l, [this] { return !q.empty(); }); + auto res = q.front(); + q.pop(); + return res; + } + + std::mutex mtx; + std::condition_variable cond; + std::queue q; +}; + +class InferenceService { + public: + cpp::result HandleChatCompletion( + std::shared_ptr q, std::shared_ptr json_body); + + cpp::result HandleEmbedding( + std::shared_ptr q, std::shared_ptr json_body); + + InferResult LoadModel(std::shared_ptr json_body); + + InferResult UnloadModel(std::shared_ptr json_body); + + InferResult GetModelStatus(std::shared_ptr json_body); + + InferResult GetModels(std::shared_ptr json_body); + + Json::Value GetEngines(std::shared_ptr json_body); + + InferResult FineTuning(std::shared_ptr json_body); + + InferResult UnloadEngine(std::shared_ptr json_body); + + private: + bool IsEngineLoaded(const std::string& e); + + bool HasFieldInReq(std::shared_ptr json_body, + const std::string& field); + + private: + using EngineV = std::variant; + struct EngineInfo { + std::unique_ptr dl; + EngineV engine; +#if defined(_WIN32) + DLL_DIRECTORY_COOKIE cookie; +#endif + }; + // TODO(sang) move engines_ into engine service? 
+ std::unordered_map engines_; +}; +} // namespace services \ No newline at end of file diff --git a/engine/utils/cortex_utils.h b/engine/utils/cortex_utils.h index 9673f0c1a..f0c2a5c1b 100644 --- a/engine/utils/cortex_utils.h +++ b/engine/utils/cortex_utils.h @@ -27,11 +27,6 @@ #endif namespace cortex_utils { -constexpr static auto kLlamaLibPath = "/engines/cortex.llamacpp"; -constexpr static auto kPythonRuntimeLibPath = "/engines/cortex.python"; -constexpr static auto kOnnxLibPath = "/engines/cortex.onnx"; -constexpr static auto kTensorrtLlmPath = "/engines/cortex.tensorrt-llm"; - inline std::string models_folder = "./models"; inline std::string logs_folder = "./logs"; inline std::string logs_base_name = "./logs/cortex.log"; diff --git a/engine/utils/engine_constants.h b/engine/utils/engine_constants.h index 63334b860..dbc1e223b 100644 --- a/engine/utils/engine_constants.h +++ b/engine/utils/engine_constants.h @@ -7,4 +7,9 @@ constexpr const auto kTrtLlmEngine = "tensorrt-llm"; constexpr const auto kOnnxRepo = "cortex.onnx"; constexpr const auto kLlamaRepo = "cortex.llamacpp"; constexpr const auto kTrtLlmRepo = "cortex.tensorrt-llm"; -constexpr const auto kPythonRuntimeRepo = "cortex.python"; \ No newline at end of file +constexpr const auto kPythonRuntimeRepo = "cortex.python"; + +constexpr const auto kLlamaLibPath = "/engines/cortex.llamacpp"; +constexpr const auto kPythonRuntimeLibPath = "/engines/cortex.python"; +constexpr const auto kOnnxLibPath = "/engines/cortex.onnx"; +constexpr const auto kTensorrtLlmPath = "/engines/cortex.tensorrt-llm"; \ No newline at end of file From 472f7162b9687d4e056ddb5b6eb1dde2fe988e2f Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 14 Oct 2024 14:02:55 +0700 Subject: [PATCH 05/12] fix: update status --- engine/services/inference_service.cc | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 51457b931..08dbc33aa 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -17,6 +17,11 @@ std::string NormalizeEngine(const std::string& engine) { } return engine; }; + +constexpr const int k200OK = 200; +constexpr const int k400BadRequest = 400; +constexpr const int k409Conflict = 409; +constexpr const int k500InternalServerError = 500; } // namespace cpp::result InferenceService::HandleChatCompletion( @@ -32,7 +37,7 @@ cpp::result InferenceService::HandleChatCompletion( Json::Value res; res["message"] = "Engine is not loaded yet"; Json::Value stt; - stt["status_code"] = 409; + stt["status_code"] = k409Conflict; LOG_WARN << "Engine is not loaded yet"; return cpp::fail(std::make_pair(stt, res)); } @@ -58,7 +63,7 @@ cpp::result InferenceService::HandleEmbedding( Json::Value res; res["message"] = "Engine is not loaded yet"; Json::Value stt; - stt["status_code"] = 409; + stt["status_code"] = k409Conflict; LOG_WARN << "Engine is not loaded yet"; return cpp::fail(std::make_pair(stt, res)); } @@ -146,7 +151,7 @@ InferResult InferenceService::LoadModel( engines_.erase(ne); r["message"] = "Could not load engine " + ne; - stt["status_code"] = 500; + stt["status_code"] = k500InternalServerError; return std::make_pair(stt, r); } @@ -191,7 +196,7 @@ InferResult InferenceService::UnloadModel( Json::Value stt; if (!IsEngineLoaded(ne)) { r["message"] = "Engine is not loaded yet"; - stt["status_code"] = 409; + stt["status_code"] = k409Conflict; LOG_WARN << "Engine is not loaded yet"; return 
std::make_pair(stt, r); } @@ -219,7 +224,7 @@ InferResult InferenceService::GetModelStatus( if (!IsEngineLoaded(ne)) { r["message"] = "Engine is not loaded yet"; - stt["status_code"] = 409; + stt["status_code"] = k409Conflict; LOG_WARN << "Engine is not loaded yet"; return std::make_pair(stt, r); } @@ -240,7 +245,7 @@ InferResult InferenceService::GetModels( Json::Value stt; if (engines_.empty()) { r["message"] = "Engine is not loaded yet"; - stt["status_code"] = 409; + stt["status_code"] = k409Conflict; return std::make_pair(stt, r); } @@ -260,7 +265,7 @@ InferResult InferenceService::GetModels( Json::Value root; root["data"] = resp_data; root["object"] = "list"; - stt["status_code"] = 200; + stt["status_code"] = k200OK; return std::make_pair(stt, root); // LOG_TRACE << "Done get models"; } @@ -302,7 +307,7 @@ InferResult InferenceService::FineTuning( Json::Value res; r["message"] = "Could not load engine " + ne; - stt["status_code"] = 500; + stt["status_code"] = k500InternalServerError; return std::make_pair(stt, r); } @@ -323,7 +328,7 @@ InferResult InferenceService::FineTuning( } else { LOG_WARN << "Method is not supported yet"; r["message"] = "Method is not supported yet"; - stt["status_code"] = 500; + stt["status_code"] = k500InternalServerError; return std::make_pair(stt, r); } LOG_TRACE << "Done fine-tuning"; @@ -345,7 +350,7 @@ InferResult InferenceService::UnloadEngine( if (!IsEngineLoaded(ne)) { r["message"] = "Engine is not loaded yet"; - stt["status_code"] = 409; + stt["status_code"] = k409Conflict; LOG_WARN << "Engine is not loaded yet"; return std::make_pair(stt, r); } @@ -362,7 +367,7 @@ InferResult InferenceService::UnloadEngine( engines_.erase(ne); LOG_INFO << "Unloaded engine " + ne; r["message"] = "Unloaded engine " + ne; - stt["status_code"] = 200; + stt["status_code"] = k200OK; return std::make_pair(stt, r); } From 6fd4696c5354dce81af04d4187b7a4d856c59764 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Mon, 14 Oct 2024 22:59:59 +0700 Subject: [PATCH 06/12] feat: mimic openai function calling api with llama3.1 --- engine/controllers/server.cc | 4 +- engine/services/inference_service.cc | 4 +- engine/services/inference_service.h | 1 - engine/utils/function_calling/common.h | 232 +++++++++++++++++++++++ engine/utils/function_calling/llama3.1.h | 43 +++++ 5 files changed, 281 insertions(+), 3 deletions(-) create mode 100644 engine/utils/function_calling/common.h create mode 100644 engine/utils/function_calling/llama3.1.h diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 0c41b0f74..966ba1403 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -5,7 +5,7 @@ #include "utils/cpuid/cpu_info.h" #include "utils/engine_constants.h" #include "utils/file_manager_utils.h" - +#include "utils/function_calling/common.h" using namespace inferences; using json = nlohmann::json; namespace inferences { @@ -168,6 +168,8 @@ void server::ProcessStreamRes(std::function cb, void server::ProcessNonStreamRes(std::function cb, services::SyncQueue& q) { auto [status, res] = q.wait_and_pop(); + function_calling_utils::PostProcessResponse(res); + std::cout << Json::StyledWriter().write(res) << std::endl; auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); resp->setStatusCode( static_cast(status["status_code"].asInt())); diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 08dbc33aa..d708ecede 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ 
-2,7 +2,7 @@ #include "utils/cpuid/cpu_info.h" #include "utils/engine_constants.h" #include "utils/file_manager_utils.h" - +#include "utils/function_calling/common.h" namespace services { namespace { @@ -41,6 +41,8 @@ cpp::result InferenceService::HandleChatCompletion( LOG_WARN << "Engine is not loaded yet"; return cpp::fail(std::make_pair(stt, res)); } + function_calling_utils::PreprocessRequest(json_body); + std::cout << Json::StyledWriter().write(*json_body) << std::endl; std::get(engines_[ne].engine) ->HandleChatCompletion(json_body, [q](Json::Value status, Json::Value res) { diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index 05147c84f..277e85862 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -7,7 +7,6 @@ #include "cortex-common/cortexpythoni.h" #include "utils/dylib.h" #include "utils/result.hpp" - namespace services { // Status and result using InferResult = std::pair; diff --git a/engine/utils/function_calling/common.h b/engine/utils/function_calling/common.h new file mode 100644 index 000000000..b1279d3e2 --- /dev/null +++ b/engine/utils/function_calling/common.h @@ -0,0 +1,232 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "llama3.1.h" + +namespace function_calling_utils { +constexpr auto custom_template_function = ""; + +constexpr auto gamma_json = R"( +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\\x7F\x00-\x1F] | + "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? 
ws
+
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+ws ::= | " " | "\n" [ \t]{0,20})";
+
+inline std::string ReplaceCustomFunctions(const std::string& original,
+                                          const std::string& replacement) {
+  std::string result = original;
+
+  size_t pos = result.find(custom_template_function);
+  if (pos != std::string::npos) {
+    result.replace(pos, std::string(custom_template_function).length(),
+                   replacement);
+  }
+
+  return result;
+}
+
+inline bool HasTools(const std::shared_ptr& request) {
+  return request->isMember("tools") && (*request)["tools"].isArray();
+}
+
+inline std::string ProcessTools(const std::shared_ptr& request) {
+  if (!HasTools(request)) {
+    return "";
+  }
+
+  std::ostringstream result;
+  result << "\n";
+
+  const Json::Value& tools = (*request)["tools"];
+  for (const auto& tool : tools) {
+    if (tool["type"] == "function") {
+      const Json::Value& function = tool["function"];
+      result << "Use the function '" << function["name"].asString()
+             << "' to: " << function["description"].asString() << "\n";
+
+      Json::FastWriter writer;
+      std::string jsonString = writer.write(tool);
+      result << jsonString << "\n";
+    }
+  }
+
+  return result.str();
+}
+
+inline Json::Value ParseMultipleFunctionStrings(const std::string& input) {
+  Json::Value results(Json::arrayValue);
+
+  // Regular expression to match the function name and arguments
+  std::regex functionRegex("<function=([^>]+)>(.+?)</function>");
+
+  // Iterator for regex matches
+  auto words_begin =
+      std::sregex_iterator(input.begin(), input.end(), functionRegex);
+  auto words_end = std::sregex_iterator();
+
+  for (std::sregex_iterator i = words_begin; i != words_end; ++i) {
+    std::smatch match = *i;
+    if (match.size() == 3) {
+      Json::Value function;
+      function["type"] = "function";
+      function["function"]["name"] = match[1].str();
+      function["function"]["arguments"] = match[2].str();
+      results.append(function);
+    }
+  }
+
+  return results;
+}
+
+inline std::string ConvertJsonToFunctionStrings(const Json::Value& jsonArray) {
+  if (!jsonArray.isArray()) {
+    return "";  // Return empty string if input is not an array
+  }
+
+  std::ostringstream result;
+
+  for (const auto& function : jsonArray) {
+    auto function_json = function.get("function", {});
+    if (function_json.isMember("name") && function_json.isMember("arguments")) {
+      result << "<function=" << function_json["name"].asString() << ">"
+             << function_json["arguments"].asString() << "</function>";
+    }
+  }
+  return result.str();
+}
+
+// Helper function to parse a JSON string to Json
+inline Json::Value ParseJsonString(const std::string& jsonString) {
+  Json::Value root;
+  Json::Reader reader;
+  reader.parse(jsonString, root);
+  return root;
+}
+
+inline std::string CreateCustomFunctionsString(
+    std::shared_ptr request) {
+  std::string customFunctions = ProcessTools(request);
+  if (customFunctions.empty()) {
+    return "";  // No custom functions found
+  }
+
+  return "```\n" + customFunctions + "```";
+}
+
+inline void UpdateMessages(const std::string& system_prompt,
+                           std::shared_ptr request) {
+  bool original_stream_config = (*request).get("stream", false).asBool();
+  // (*request)["grammar"] = function_calling_utils::gamma_json;
+  (*request)["stream"] =
+      false;  // when using function calling, disable stream automatically because we need to parse the response to get function name and params
+  if (!request->isMember("messages") || !(*request)["messages"].isArray() ||
+      (*request)["messages"].empty()) {
+    // If no messages, add the system prompt as the first message
+    Json::Value systemMessage;
+    systemMessage["role"] = "system";
systemMessage["content"] = system_prompt; + (*request)["messages"].append(systemMessage); + } else { + Json::Value& firstMessage = (*request)["messages"][0]; + if (firstMessage["role"] == "system") { + bool addCustomPrompt = + request->get("add_custom_system_prompt", false).asBool(); + if (addCustomPrompt) { + firstMessage["content"] = + system_prompt + "\n" + firstMessage["content"].asString(); + } + } else { + // If the first message is not a system message, prepend the system prompt + Json::Value systemMessage; + systemMessage["role"] = "system"; + systemMessage["content"] = system_prompt; + (*request)["messages"].insert(0, systemMessage); + } + Json::Value& lastMessage = + (*request)["messages"][(*request)["messages"].size() - 1]; + if (lastMessage.get("role", "") == "tool") { + lastMessage["role"] = function_calling_llama3_1_utils::tool_role; + (*request)["stream"] = + original_stream_config; // if role is tool then should restore stream config to original value + } + } + for (Json::Value& message : (*request)["messages"]) { + if (message["role"] == "assistant" && message.isMember("tool_calls")) { + const Json::Value& tool_calls = message["tool_calls"]; + if (!tool_calls.isNull() && tool_calls.isArray() && + tool_calls.size() > 0) { + message["content"] = ConvertJsonToFunctionStrings(tool_calls); + message["tool_calls"] = {}; + } + } + } +} +inline void PreprocessRequest(std::shared_ptr request) { + if (!function_calling_utils::HasTools(request)) { + return; // Exit if no tools present + } + + std::string customFunctionsString = + function_calling_utils::CreateCustomFunctionsString(request); + std::string new_system_prompt = + function_calling_utils::ReplaceCustomFunctions( + function_calling_llama3_1_utils::system_prompt, + customFunctionsString); + UpdateMessages(new_system_prompt, request); +} + +inline void PostProcessResponse(Json::Value& response) { + if (!response.isMember("choices") || !response["choices"].isArray() || + response["choices"].empty()) { + // If there are no choices or the structure is incorrect, do nothing + return; + } + + // Get a reference to the first choice + Json::Value& firstChoice = response["choices"][0]; + + // Check if the choice has a message with content + if (firstChoice.isMember("message") && + firstChoice["message"].isMember("content")) { + std::string content = firstChoice["message"]["content"].asString(); + + // Create a new structure for tool_calls + Json::Value toolCall = ParseMultipleFunctionStrings(content); + if (toolCall.size() > 0) { + // Add tool_calls to the message + firstChoice["finish_reason"] = "tool_calls"; + firstChoice["message"]["tool_calls"] = toolCall; + + // Clear the content as it's now represented in tool_calls + firstChoice["message"]["content"] = ""; + } + } + + // Add any additional post-processing logic here +} +} // namespace function_calling_utils diff --git a/engine/utils/function_calling/llama3.1.h b/engine/utils/function_calling/llama3.1.h new file mode 100644 index 000000000..d925605f7 --- /dev/null +++ b/engine/utils/function_calling/llama3.1.h @@ -0,0 +1,43 @@ +#pragma once + +namespace function_calling_llama3_1_utils { +constexpr auto system_prompt = R"( +Environment: ipython +Tools: brave_search, wolfram_alpha +Cutting Knowledge Date: December 2023 +Today Date: 20 September 2024 + +# Tool Instructions +- Always execute python code in messages that you share. 
+- When looking for real time information use relevant functions if available else fallback to brave_search
+
+You have access to the following CUSTOM functions:
+
+
+
+If a you choose to call a CUSTOM function ONLY reply in the following format:
+<{start_tag}={function_name}>{parameters}{end_tag}
+where
+
+start_tag => `<function`
+parameters => a JSON dict with the function argument name as key and function argument value as value.
+end_tag => `</function>`
+
+Here is an example,
+<function=example_function_name>{"example_name": "example_value"}</function>
+
+Reminder:
+- Function calls MUST follow the specified format
+- Required parameters MUST be specified
+- Only call one function at a time
+- Put the entire function call reply on one line
+- Always add your sources when using search results to answer the user query
+- If you cannot find the correct parameters for a function, ask the user again to provide them.
+- No explanations are needed when calling a function.
+
+You are a helpful assistant.
+)";
+
+constexpr auto tool_role = "<|eot_id|>\n<|start_header_id|>ipython<|end_header_id|>\n";
+} // namespace function_calling_llama3_1_utils

From 7e39c782042fa257b9e7d12bba6a26331b3ac813 Mon Sep 17 00:00:00 2001
From: nguyenhoangthuan99
Date: Mon, 14 Oct 2024 23:04:21 +0700
Subject: [PATCH 07/12] chore: remove unnecessary cout

---
 engine/controllers/server.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc
index 966ba1403..87094528b 100644
--- a/engine/controllers/server.cc
+++ b/engine/controllers/server.cc
@@ -169,7 +169,6 @@ void server::ProcessNonStreamRes(std::function cb,
                                  services::SyncQueue& q) {
   auto [status, res] = q.wait_and_pop();
   function_calling_utils::PostProcessResponse(res);
-  std::cout << Json::StyledWriter().write(res) << std::endl;
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
   resp->setStatusCode(
       static_cast(status["status_code"].asInt()));

From f2655cbea98538a26aedd392780c2ddda6a8c4e3 Mon Sep 17 00:00:00 2001
From: nguyenhoangthuan99
Date: Tue, 15 Oct 2024 12:02:29 +0700
Subject: [PATCH 08/12] feat: add tool choice option to api

---
 engine/services/inference_service.cc     | 13 +++++---
 engine/utils/function_calling/common.h   | 41 +++++++++++++++++++++---
 engine/utils/function_calling/llama3.1.h |  2 +-
 3 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc
index 37c049fe8..eebc7e2ee 100644
--- a/engine/services/inference_service.cc
+++ b/engine/services/inference_service.cc
@@ -44,12 +44,15 @@ cpp::result InferenceService::HandleChatCompletion(
   }
   function_calling_utils::PreprocessRequest(json_body);
-
+  Json::Value tool_choice = json_body->get("tool_choice", Json::Value::null);
   std::get(engines_[ne].engine)
-      ->HandleChatCompletion(json_body,
-                             [q](Json::Value status, Json::Value res) {
-                               q->push(std::make_pair(status, res));
-                             });
+      ->HandleChatCompletion(
+          json_body, [q, tool_choice](Json::Value status, Json::Value res) {
+            if (!tool_choice.isNull()) {
+              res["tool_choice"] = tool_choice;
+            }
+            q->push(std::make_pair(status, res));
+          });
   return {};
 }

diff --git a/engine/utils/function_calling/common.h b/engine/utils/function_calling/common.h
index b1279d3e2..b5f682fd1 100644
--- a/engine/utils/function_calling/common.h
+++ b/engine/utils/function_calling/common.h
@@ -6,7 +6,7 @@
 #include
 #include
 #include "llama3.1.h"
-
+#include "utils/logging_utils.h"
 namespace function_calling_utils {
 constexpr auto custom_template_function = "";

@@ -137,9 +137,26 @@ inline std::string
CreateCustomFunctionsString( return "```\n" + customFunctions + "```"; } - -inline void UpdateMessages(const std::string& system_prompt, +inline bool IsValidToolChoiceFormat(const Json::Value& root) { + return root.isObject() && root.isMember("type") && root["type"].isString() && + root["type"].asString() == "function" && root.isMember("function") && + root["function"].isObject() && root["function"].isMember("name") && + root["function"]["name"].isString(); +} +inline void UpdateMessages(std::string& system_prompt, std::shared_ptr request) { + Json::Value tool_choice = request->get("tool_choice", "auto"); + if (tool_choice.isString() && tool_choice.asString() == "required") { + system_prompt += + "\n\nYou must use a function to answer the user's question."; + } else if (!tool_choice.isString()) { + + system_prompt += + "\n\nNow this is your first priority: You must call the function '" + + tool_choice["function"]["name"].asString() + + "' to answer the user's question."; + } + bool original_stream_config = (*request).get("stream", false).asBool(); // (*request)["grammar"] = function_calling_utils::gamma_json; (*request)["stream"] = @@ -190,7 +207,12 @@ inline void PreprocessRequest(std::shared_ptr request) { if (!function_calling_utils::HasTools(request)) { return; // Exit if no tools present } - + if (request->get("tool_choice", "auto").isString()) { + std::string tool_choice = request->get("tool_choice", "auto").asString(); + if (tool_choice == "none") { + return; // Exit if tool_choice is none + } + } std::string customFunctionsString = function_calling_utils::CreateCustomFunctionsString(request); std::string new_system_prompt = @@ -219,7 +241,16 @@ inline void PostProcessResponse(Json::Value& response) { Json::Value toolCall = ParseMultipleFunctionStrings(content); if (toolCall.size() > 0) { // Add tool_calls to the message - firstChoice["finish_reason"] = "tool_calls"; + if (response.get("tool_choice", "auto").isString()) { + std::string tool_choice = + response.get("tool_choice", "auto").asString(); + if (tool_choice == "auto") { + firstChoice["finish_reason"] = "tool_calls"; + } else { + firstChoice["finish_reason"] = "stop"; + } + } + firstChoice["message"]["tool_calls"] = toolCall; // Clear the content as it's now represented in tool_calls diff --git a/engine/utils/function_calling/llama3.1.h b/engine/utils/function_calling/llama3.1.h index d925605f7..5c2e6ffdb 100644 --- a/engine/utils/function_calling/llama3.1.h +++ b/engine/utils/function_calling/llama3.1.h @@ -16,7 +16,7 @@ You have access to the following CUSTOM functions: -If a you choose to call a CUSTOM function ONLY reply in the following format: +If a you choose to call a function ONLY reply in the following format: <{start_tag}={function_name}>{parameters}{end_tag} where From 0eb1369b528d851d9ff9db390023a60592e68b01 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Tue, 15 Oct 2024 13:02:38 +0700 Subject: [PATCH 09/12] feat: add unitest --- .../test/components/test_function_calling.cc | 157 ++++++++++++++++++ engine/utils/function_calling/common.h | 3 +- 2 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 engine/test/components/test_function_calling.cc diff --git a/engine/test/components/test_function_calling.cc b/engine/test/components/test_function_calling.cc new file mode 100644 index 000000000..729775a06 --- /dev/null +++ b/engine/test/components/test_function_calling.cc @@ -0,0 +1,157 @@ +#include +#include "gtest/gtest.h" +#include "json/json.h" +#include "utils/function_calling/common.h" + 
+class FunctionCallingUtilsTest : public ::testing::Test {
+ protected:
+  std::shared_ptr<Json::Value> createTestRequest() {
+    auto request = std::make_shared<Json::Value>();
+    (*request)["tools"] = Json::Value(Json::arrayValue);
+    return request;
+  }
+};
+
+TEST_F(FunctionCallingUtilsTest, ReplaceCustomFunctions) {
+  std::string original = "Test placeholder";
+  std::string replacement = "Custom function";
+  std::string result =
+      function_calling_utils::ReplaceCustomFunctions(original, replacement);
+  EXPECT_EQ(result, "Test Custom function placeholder");
+}
+
+TEST_F(FunctionCallingUtilsTest, HasTools) {
+  auto request = createTestRequest();
+  EXPECT_FALSE(function_calling_utils::HasTools(request));
+
+  (*request)["tools"].append(Json::Value());
+  EXPECT_TRUE(function_calling_utils::HasTools(request));
+
+  (*request)["tools"] = "random";
+  EXPECT_FALSE(function_calling_utils::HasTools(request));
+  
+  (*request)["tools"] = Json::Value::null;
+  EXPECT_FALSE(function_calling_utils::HasTools(request));
+}
+
+TEST_F(FunctionCallingUtilsTest, ProcessTools) {
+  auto request = createTestRequest();
+  Json::Value tool;
+  tool["type"] = "function";
+  tool["function"]["name"] = "test_function";
+  tool["function"]["description"] = "Test description";
+  (*request)["tools"].append(tool);
+
+  std::string result = function_calling_utils::ProcessTools(request);
+  EXPECT_TRUE(
+      result.find("Use the function 'test_function' to: Test description") !=
+      std::string::npos);
+}
+
+TEST_F(FunctionCallingUtilsTest, ParseMultipleFunctionStrings) {
+  std::string input =
+      "<function=func1>{\"arg\":\"value1\"}</function>"
+      "<function=func2>{\"arg\":\"value2\"}</function>";
+  Json::Value result =
+      function_calling_utils::ParseMultipleFunctionStrings(input);
+
+  ASSERT_EQ(result.size(), 2);
+  EXPECT_EQ(result[0]["function"]["name"].asString(), "func1");
+  EXPECT_EQ(result[0]["function"]["arguments"].asString(),
+            "{\"arg\":\"value1\"}");
+  EXPECT_EQ(result[1]["function"]["name"].asString(), "func2");
+  EXPECT_EQ(result[1]["function"]["arguments"].asString(),
+            "{\"arg\":\"value2\"}");
+}
+
+TEST_F(FunctionCallingUtilsTest, ConvertJsonToFunctionStrings) {
+  Json::Value jsonArray(Json::arrayValue);
+  Json::Value function1, function2;
+  function1["function"]["name"] = "func1";
+  function1["function"]["arguments"] = "{\"arg\":\"value1\"}";
+  function2["function"]["name"] = "func2";
+  function2["function"]["arguments"] = "{\"arg\":\"value2\"}";
+  jsonArray.append(function1);
+  jsonArray.append(function2);
+
+  std::string result =
+      function_calling_utils::ConvertJsonToFunctionStrings(jsonArray);
+  EXPECT_EQ(result,
+            "<function=func1>{\"arg\":\"value1\"}</function>"
+            "<function=func2>{\"arg\":\"value2\"}</function>");
+}
+
+TEST_F(FunctionCallingUtilsTest, CreateCustomFunctionsString) {
+  auto request = createTestRequest();
+  Json::Value tool;
+  tool["type"] = "function";
+  tool["function"]["name"] = "test_function";
+  tool["function"]["description"] = "Test description";
+  (*request)["tools"].append(tool);
+
+  std::string result =
+      function_calling_utils::CreateCustomFunctionsString(request);
+  EXPECT_TRUE(result.find("```") != std::string::npos);
+  EXPECT_TRUE(
+      result.find("Use the function 'test_function' to: Test description") !=
+      std::string::npos);
+}
+
+TEST_F(FunctionCallingUtilsTest, IsValidToolChoiceFormat) {
+  Json::Value validTool;
+  validTool["type"] = "function";
+  validTool["function"]["name"] = "test_function";
+  EXPECT_TRUE(function_calling_utils::IsValidToolChoiceFormat(validTool));
+
+  Json::Value invalidTool;
+  EXPECT_FALSE(function_calling_utils::IsValidToolChoiceFormat(invalidTool));
+}
+
+TEST_F(FunctionCallingUtilsTest, UpdateMessages) {
+  auto request = createTestRequest();
+  std::string system_prompt = "Original prompt";
+  (*request)["messages"] = Json::Value(Json::arrayValue);
+
+  function_calling_utils::UpdateMessages(system_prompt, request);
+
+  ASSERT_TRUE((*request)["messages"].isArray());
+  EXPECT_EQ((*request)["messages"][0]["role"].asString(), "system");
+  EXPECT_EQ((*request)["messages"][0]["content"].asString(), system_prompt);
+}
+
+TEST_F(FunctionCallingUtilsTest, PreprocessRequest) {
+  auto request = createTestRequest();
+  Json::Value tool;
+  tool["type"] = "function";
+  tool["function"]["name"] = "test_function";
+  tool["function"]["description"] = "Test description";
+  (*request)["tools"].append(tool);
+
+  function_calling_utils::PreprocessRequest(request);
+
+  ASSERT_TRUE((*request)["messages"].isArray());
+  EXPECT_TRUE((*request)["messages"][0]["content"].asString().find(
+                  "Test description") != std::string::npos);
+}
+
+TEST_F(FunctionCallingUtilsTest, PostProcessResponse) {
+  Json::Value response;
+  response["choices"] = Json::Value(Json::arrayValue);
+  Json::Value choice;
+  choice["message"]["content"] =
+      "<function=test_function>{\"arg\":\"value\"}</function>";
+  response["choices"].append(choice);
+
+  function_calling_utils::PostProcessResponse(response);
+
+  EXPECT_EQ(response["choices"][0]["message"]["content"].asString(), "");
+  EXPECT_TRUE(response["choices"][0]["message"]["tool_calls"].isArray());
+  EXPECT_EQ(
+      response["choices"][0]["message"]["tool_calls"][0]["function"]["name"]
+          .asString(),
+      "test_function");
+  EXPECT_EQ(response["choices"][0]["message"]["tool_calls"][0]["function"]
+                    ["arguments"]
+                .asString(),
+            "{\"arg\":\"value\"}");
+}
\ No newline at end of file
diff --git a/engine/utils/function_calling/common.h b/engine/utils/function_calling/common.h
index b5f682fd1..d70e366c9 100644
--- a/engine/utils/function_calling/common.h
+++ b/engine/utils/function_calling/common.h
@@ -51,7 +51,8 @@ inline std::string ReplaceCustomFunctions(const std::string& original,
 }
 
 inline bool HasTools(const std::shared_ptr<Json::Value>& request) {
-  return request->isMember("tools") && (*request)["tools"].isArray();
+  return request->isMember("tools") && (*request)["tools"].isArray() &&
+         (*request)["tools"].size() > 0;
 }
 
 inline std::string ProcessTools(const std::shared_ptr<Json::Value>& request) {
From 97b8e5d185ec452c8737ff6b820b4ef407b069e4 Mon Sep 17 00:00:00 2001
From: nguyenhoangthuan99
Date: Tue, 15 Oct 2024 17:01:03 +0700
Subject: [PATCH 10/12] chore: format code

---
 engine/test/components/test_function_calling.cc | 2 +-
 engine/utils/function_calling/common.h          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/engine/test/components/test_function_calling.cc b/engine/test/components/test_function_calling.cc
index 729775a06..7a4810b29 100644
--- a/engine/test/components/test_function_calling.cc
+++ b/engine/test/components/test_function_calling.cc
@@ -29,7 +29,7 @@ TEST_F(FunctionCallingUtilsTest, HasTools) {
 
   (*request)["tools"] = "random";
   EXPECT_FALSE(function_calling_utils::HasTools(request));
-  
+
   (*request)["tools"] = Json::Value::null;
   EXPECT_FALSE(function_calling_utils::HasTools(request));
 }
diff --git a/engine/utils/function_calling/common.h b/engine/utils/function_calling/common.h
index d70e366c9..cd47ab529 100644
--- a/engine/utils/function_calling/common.h
+++ b/engine/utils/function_calling/common.h
@@ -173,7 +173,7 @@ inline void UpdateMessages(std::string& system_prompt,
     Json::Value& firstMessage = (*request)["messages"][0];
     if (firstMessage["role"] == "system") {
       bool addCustomPrompt =
-        request->get("add_custom_system_prompt", true).asBool();
+          request->get("add_custom_system_prompt", true).asBool();
       if (addCustomPrompt) {
         firstMessage["content"] =
             system_prompt + "\n" + firstMessage["content"].asString();

From 7c7834473cbde2e9cfbced8226a9dcceb8b93a9b Mon Sep 17 00:00:00 2001
From: nguyenhoangthuan99
Date: Wed, 16 Oct 2024 16:38:09 +0700
Subject: [PATCH 11/12] feat: function calling in user message

---
 engine/controllers/swagger.cc          | 11 ++++++
 engine/services/inference_service.cc   |  1 +
 engine/utils/function_calling/common.h | 48 ++++++++++++++++++--------
 3 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/engine/controllers/swagger.cc b/engine/controllers/swagger.cc
index 527ae60b8..cef339761 100644
--- a/engine/controllers/swagger.cc
+++ b/engine/controllers/swagger.cc
@@ -630,10 +630,21 @@ Json::Value SwaggerController::generateOpenAPISpec() {
       "#/components/schemas/ChatMessage";
   schemas["ChatCompletionRequest"]["properties"]["stream"]["type"] = "boolean";
   schemas["ChatCompletionRequest"]["properties"]["engine"]["type"] = "string";
+  schemas["ChatCompletionRequest"]["properties"]["tools"]["type"] = "array";
+  schemas["ChatCompletionRequest"]["properties"]["tools"]["items"]["$ref"] =
+      "#/components/schemas/ToolsCall";
+  schemas["ChatCompletionRequest"]["properties"]["tools_call_in_user_message"]
+      ["type"] = "boolean";
+  schemas["ChatCompletionRequest"]["properties"]["tools_call_in_user_message"]
+      ["default"] = false;
+  schemas["ToolsCall"]["type"] = "object";
 
   schemas["ChatMessage"]["type"] = "object";
   schemas["ChatMessage"]["properties"]["role"]["type"] = "string";
   schemas["ChatMessage"]["properties"]["content"]["type"] = "string";
+  schemas["ChatMessage"]["properties"]["tools"]["type"] = "array";
+  schemas["ChatMessage"]["properties"]["tools"]["items"]["$ref"] =
+      "#/components/schemas/ToolsCall";
 
   schemas["ChatCompletionResponse"]["type"] = "object";
   // Add properties based on your implementation
diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc
index eebc7e2ee..8634640fb 100644
--- a/engine/services/inference_service.cc
+++ b/engine/services/inference_service.cc
@@ -44,6 +44,7 @@ cpp::result InferenceService::HandleChatCompletion(
   }
 
   function_calling_utils::PreprocessRequest(json_body);
+  std::cout<<"json_body: "<<json_body->toStyledString()<<std::endl;
   Json::Value tool_choice = json_body->get("tool_choice", Json::Value::null);
   std::get<EngineI*>(engines_[ne].engine)
       ->HandleChatCompletion(
diff --git a/engine/utils/function_calling/common.h b/engine/utils/function_calling/common.h
index cd47ab529..d01f22423 100644
--- a/engine/utils/function_calling/common.h
+++ b/engine/utils/function_calling/common.h
@@ -51,8 +51,9 @@ inline std::string ReplaceCustomFunctions(const std::string& original,
 }
 
 inline bool HasTools(const std::shared_ptr<Json::Value>& request) {
-  return request->isMember("tools") && (*request)["tools"].isArray() &&
-         (*request)["tools"].size() > 0;
+  return (request->isMember("tools") && (*request)["tools"].isArray() &&
+          (*request)["tools"].size() > 0) ||
+         request->get("tools_call_in_user_message", false).asBool();
 }
 
 inline std::string ProcessTools(const std::shared_ptr<Json::Value>& request) {
@@ -149,7 +150,7 @@ inline void UpdateMessages(std::string& system_prompt,
   Json::Value tool_choice = request->get("tool_choice", "auto");
   if (tool_choice.isString() && tool_choice.asString() == "required") {
     system_prompt +=
-        "\n\nYou must use a function to answer the user's question.";
+        "\n\nYou must call a function to answer the user's question.";
   } else if (!tool_choice.isString()) {
 
     system_prompt +=
@@ -158,10 +159,14 @@ inline void UpdateMessages(std::string& system_prompt,
         "' to answer the user's question.";
   }
 
+  bool tools_call_in_user_message =
+      request->get("tools_call_in_user_message", false).asBool();
+
   bool original_stream_config = (*request).get("stream", false).asBool();
   //   (*request)["grammar"] = function_calling_utils::gamma_json;
   (*request)["stream"] =
       false;  //when using function calling, disable stream automatically because we need to parse the response to get function name and params
+
   if (!request->isMember("messages") || !(*request)["messages"].isArray() ||
       (*request)["messages"].empty()) {
     // If no messages, add the system prompt as the first message
@@ -170,21 +175,34 @@ inline void UpdateMessages(std::string& system_prompt,
     systemMessage["content"] = system_prompt;
     (*request)["messages"].append(systemMessage);
   } else {
-    Json::Value& firstMessage = (*request)["messages"][0];
-    if (firstMessage["role"] == "system") {
-      bool addCustomPrompt =
-          request->get("add_custom_system_prompt", true).asBool();
-      if (addCustomPrompt) {
-        firstMessage["content"] =
-            system_prompt + "\n" + firstMessage["content"].asString();
+
+    if (tools_call_in_user_message) {
+      for (Json::Value& message : (*request)["messages"]) {
+        if (message["role"] == "user" && message.isMember("tools") &&
+            message["tools"].isArray() && message["tools"].size() > 0) {
+          message["content"] = system_prompt + "\n User question: " +
+                               message["content"].asString();
+        }
       }
     } else {
-      // If the first message is not a system message, prepend the system prompt
-      Json::Value systemMessage;
-      systemMessage["role"] = "system";
-      systemMessage["content"] = system_prompt;
-      (*request)["messages"].insert(0, systemMessage);
+      Json::Value& firstMessage = (*request)["messages"][0];
+      if (firstMessage["role"] == "system") {
+        bool addCustomPrompt =
+            request->get("add_custom_system_prompt", true).asBool();
+        if (addCustomPrompt) {
+          firstMessage["content"] =
+              system_prompt + "\n" + firstMessage["content"].asString();
+        }
+      } else {
+        // If the first message is not a system message, prepend the system prompt
+        Json::Value systemMessage;
+        systemMessage["role"] = "system";
+        systemMessage["content"] = system_prompt;
+        (*request)["messages"].insert(0, systemMessage);
+      }
     }
+
+    // transform last message role to tool if it is a function call
     Json::Value& lastMessage =
         (*request)["messages"][(*request)["messages"].size() - 1];
     if (lastMessage.get("role", "") == "tool") {

From c11aad21e6c80889247fd77d3f6c5ebb7bf76114 Mon Sep 17 00:00:00 2001
From: nguyenhoangthuan99 <35255081+nguyenhoangthuan99@users.noreply.github.com>
Date: Wed, 16 Oct 2024 16:46:02 +0700
Subject: [PATCH 12/12] Update inference_service.cc

---
 engine/services/inference_service.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc
index 8634640fb..a8d9a3166 100644
--- a/engine/services/inference_service.cc
+++ b/engine/services/inference_service.cc
@@ -44,7 +44,6 @@ cpp::result InferenceService::HandleChatCompletion(
   }
 
   function_calling_utils::PreprocessRequest(json_body);
-  std::cout<<"json_body: "<<json_body->toStyledString()<<std::endl;
   Json::Value tool_choice = json_body->get("tool_choice", Json::Value::null);
   std::get<EngineI*>(engines_[ne].engine)
       ->HandleChatCompletion(
@@ -390,4 +389,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
   }
   return true;
 }
-}  // namespace services
\ No newline at end of file
+}  // namespace services
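
For reference, a minimal sketch of how a caller could exercise the function_calling_utils path touched by the patches above. Only PreprocessRequest and the tools/tool_choice request fields come from the changes themselves; the get_weather tool, the question text, and the standalone main() are illustrative assumptions, not code from the patches.

#include <iostream>
#include <memory>
#include "json/json.h"
#include "utils/function_calling/common.h"

int main() {
  auto request = std::make_shared<Json::Value>();

  // One tool, in the same shape the unit tests above construct.
  Json::Value tool;
  tool["type"] = "function";
  tool["function"]["name"] = "get_weather";  // assumed example name
  tool["function"]["description"] = "Get the current weather for a city";
  (*request)["tools"] = Json::Value(Json::arrayValue);
  (*request)["tools"].append(tool);

  // "auto", "none", "required", or a {"type":"function",...} object.
  (*request)["tool_choice"] = "required";

  Json::Value user_msg;
  user_msg["role"] = "user";
  user_msg["content"] = "What is the weather in Hanoi?";
  (*request)["messages"] = Json::Value(Json::arrayValue);
  (*request)["messages"].append(user_msg);

  // Injects the function-calling system prompt and forces stream to false,
  // as UpdateMessages does in the patched common.h.
  function_calling_utils::PreprocessRequest(request);
  std::cout << request->toStyledString() << std::endl;
  return 0;
}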