diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 57e456768..24b5c9718 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -77,7 +77,8 @@ void llamaCPP::chatCompletion(
 
   const auto &jsonBody = req->getJsonObject();
   std::string formatted_output =
-      "Below is a conversation between an AI system named ASSISTANT and USER\n";
+      "Below is a conversation between an AI system named " + ai_prompt +
+      " and " + user_prompt + "\n";
 
   json data;
   json stopWords;
@@ -94,9 +95,19 @@
 
     const Json::Value &messages = (*jsonBody)["messages"];
     for (const auto &message : messages) {
-      std::string role = message["role"].asString();
+      std::string input_role = message["role"].asString();
+      std::string role;
+      if (input_role == "user") {
+        role = user_prompt;
+      } else if (input_role == "assistant") {
+        role = ai_prompt;
+      } else if (input_role == "system") {
+        role = system_prompt;
+      } else {
+        role = input_role;
+      }
       std::string content = message["content"].asString();
-      formatted_output += role + ": " + content + "\n";
+      formatted_output += role + content + "\n";
     }
     formatted_output += "assistant:";
 
@@ -105,8 +116,7 @@
       stopWords.push_back(stop_word.asString());
     }
     // specify default stop words
-    stopWords.push_back("user:");
-    stopWords.push_back("### USER:");
+    stopWords.push_back(user_prompt);
     data["stop"] = stopWords;
   }
 
@@ -202,19 +212,19 @@ void llamaCPP::loadModel(
   LOG_INFO << "Drogon thread is:" << drogon_thread;
   if (jsonBody) {
     params.model = (*jsonBody)["llama_model_path"].asString();
-    params.n_gpu_layers = (*jsonBody)["ngl"].asInt();
-    params.n_ctx = (*jsonBody)["ctx_len"].asInt();
-    params.embedding = (*jsonBody)["embedding"].asBool();
+    params.n_gpu_layers = (*jsonBody).get("ngl", 100).asInt();
+    params.n_ctx = (*jsonBody).get("ctx_len", 2048).asInt();
+    params.embedding = (*jsonBody).get("embedding", true).asBool();
     // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
-    if ((*jsonBody).isMember("n_parallel")) {
-      params.n_parallel = (*jsonBody)["n_parallel"].asInt();
-    } else {
-      params.n_parallel = drogon_thread;
-    }
+
+    params.n_parallel = (*jsonBody).get("n_parallel", drogon_thread).asInt();
     params.cont_batching = (*jsonBody)["cont_batching"].asBool();
-    // params.n_threads = (*jsonBody)["n_threads"].asInt();
-    // params.n_threads_batch = params.n_threads;
+
+    this->user_prompt = (*jsonBody).get("user_prompt", "USER: ").asString();
+    this->ai_prompt = (*jsonBody).get("ai_prompt", "ASSISTANT: ").asString();
+    this->system_prompt =
+        (*jsonBody).get("system_prompt", "ASSISTANT's RULE: ").asString();
   }
 #ifdef GGML_USE_CUBLAS
 LOG_INFO << "Setting up GGML CUBLAS PARAMS";
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index 192ec358f..037cae926 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -2142,5 +2142,8 @@ class llamaCPP : public drogon::HttpController {
   size_t sent_count = 0;
   size_t sent_token_probs_index = 0;
   std::thread backgroundThread;
+  std::string user_prompt;
+  std::string ai_prompt;
+  std::string system_prompt;
 };
 }; // namespace inferences
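
With this change the prompt prefixes become configurable at model-load time, and most load parameters fall back to a default via Json::Value::get() instead of relying on the key being present. A minimal load-model request body might then look like the sketch below; this is for illustration only, the model path and the chosen prompt strings are placeholders, while the field names and defaults come from the diff above (omitting user_prompt, ai_prompt, or system_prompt keeps "USER: ", "ASSISTANT: ", and "ASSISTANT's RULE: " respectively):

    {
      "llama_model_path": "/path/to/model.gguf",
      "ctx_len": 2048,
      "ngl": 100,
      "embedding": true,
      "cont_batching": false,
      "user_prompt": "### USER: ",
      "ai_prompt": "### ASSISTANT: ",
      "system_prompt": "### SYSTEM: "
    }

The same strings are then reused by chatCompletion to label each message when building the prompt and as the default stop word, so a model's preferred chat template can be passed through without code changes.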