diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 4df773e0b..f3ea1623b 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -10,11 +10,7 @@ using json = nlohmann::json;
 /**
  * The state of the inference task
  */
-enum InferenceStatus {
-  PENDING,
-  RUNNING,
-  FINISHED
-};
+enum InferenceStatus { PENDING, RUNNING, FINISHED };
 
 /**
  * There is a need to save state of current ongoing inference status of a
@@ -141,7 +137,9 @@ std::string create_return_json(const std::string &id, const std::string &model,
   return Json::writeString(writer, root);
 }
 
-llamaCPP::llamaCPP(): queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP")) {
+llamaCPP::llamaCPP()
+    : queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel,
+                                             "llamaCPP")) {
   // Some default values for now below
   log_disable(); // Disable the log to file feature, reduce bloat for
                  // target
@@ -172,7 +170,7 @@ void llamaCPP::inference(
   const auto &jsonBody = req->getJsonObject();
 
   // Check if model is loaded
-  if(checkModelLoaded(callback)) {
+  if (checkModelLoaded(callback)) {
     // Model is loaded
     // Do Inference
     inferenceImpl(jsonBody, callback);
@@ -329,8 +327,7 @@ void llamaCPP::inferenceImpl(
     auto state = create_inference_state(this);
     auto chunked_content_provider =
         [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-
-      if(state->inferenceStatus == PENDING) {
+      if (state->inferenceStatus == PENDING) {
         state->inferenceStatus = RUNNING;
       } else if (state->inferenceStatus == FINISHED) {
         return 0;
@@ -341,7 +338,7 @@ void llamaCPP::inferenceImpl(
         state->inferenceStatus = FINISHED;
         return 0;
       }
-      
+
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
         const std::string to_send = result.result_json["content"];
@@ -367,10 +364,10 @@ void llamaCPP::inferenceImpl(
           LOG_INFO << "reached result stop";
           state->inferenceStatus = FINISHED;
         }
-        
+
       // Make sure nBufferSize is not zero
       // Otherwise it stop streaming
-      if(!nRead) {
+      if (!nRead) {
         state->inferenceStatus = FINISHED;
       }
 
@@ -380,31 +377,33 @@ void llamaCPP::inferenceImpl(
       return 0;
     };
     // Queued task
-    state->instance->queue->runTaskInQueue([callback, state, data,
-                                            chunked_content_provider]() {
-      state->task_id =
-          state->instance->llama.request_completion(data, false, false, -1);
-
-      // Start streaming response
-      auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
-                                                   "chat_completions.txt");
-      callback(resp);
-
-      int retries = 0;
-
-      // Since this is an async task, we will wait for the task to be completed
-      while (state->inferenceStatus != FINISHED && retries < 10) {
-        // Should wait chunked_content_provider lambda to be called within 3s
-        if(state->inferenceStatus == PENDING) {
-          retries += 1;
-        }
-        if(state->inferenceStatus != RUNNING)
-          LOG_INFO << "Wait for task to be released:" << state->task_id;
-        std::this_thread::sleep_for(std::chrono::milliseconds(100));
-      }
-      // Request completed, release it
-      state->instance->llama.request_cancel(state->task_id);
-    });
+    state->instance->queue->runTaskInQueue(
+        [callback, state, data, chunked_content_provider]() {
+          state->task_id =
+              state->instance->llama.request_completion(data, false, false, -1);
+
+          // Start streaming response
+          auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
+                                                       "chat_completions.txt");
+          callback(resp);
+
+          int retries = 0;
+
+          // Since this is an async task, we will wait for the task to be
+          // completed
+          while (state->inferenceStatus != FINISHED && retries < 10) {
+            // Should wait chunked_content_provider lambda to be called within
+            // 3s
+            if (state->inferenceStatus == PENDING) {
+              retries += 1;
+            }
+            if (state->inferenceStatus != RUNNING)
+              LOG_INFO << "Wait for task to be released:" << state->task_id;
+            std::this_thread::sleep_for(std::chrono::milliseconds(100));
+          }
+          // Request completed, release it
+          state->instance->llama.request_cancel(state->task_id);
+        });
   } else {
     Json::Value respData;
     auto resp = nitro_utils::nitroHttpResponse();
@@ -434,7 +433,7 @@ void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
   // Check if model is loaded
-  if(checkModelLoaded(callback)) {
+  if (checkModelLoaded(callback)) {
     // Model is loaded
     const auto &jsonBody = req->getJsonObject();
     // Run embedding
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index ad1889be0..f2adea352 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -2526,10 +2526,11 @@ class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {
   // Openai compatible path
   ADD_METHOD_TO(llamaCPP::inference, "/v1/chat/completions", Post);
-  // ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/chat/completions", Options); NOTE: prelight will be added back when browser support is properly planned
+  // ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/chat/completions", Options);
+  // NOTE: prelight will be added back when browser support is properly planned
 
   ADD_METHOD_TO(llamaCPP::embedding, "/v1/embeddings", Post);
-  //ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/embeddings", Options);
+  // ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/embeddings", Options);
 
   // PATH_ADD("/llama/chat_completion", Post);
   METHOD_LIST_END
 
diff --git a/controllers/prelight.cc b/controllers/prelight.cc
new file mode 100644
index 000000000..9c4c63095
--- /dev/null
+++ b/controllers/prelight.cc
@@ -0,0 +1,13 @@
+#include "prelight.h"
+
+void prelight::handlePrelight(
+    const HttpRequestPtr &req,
+    std::function<void(const HttpResponsePtr &)> &&callback) {
+  auto resp = drogon::HttpResponse::newHttpResponse();
+  resp->setStatusCode(drogon::HttpStatusCode::k200OK);
+  resp->addHeader("Access-Control-Allow-Origin", "*");
+  resp->addHeader("Access-Control-Allow-Methods", "POST, OPTIONS");
+  resp->addHeader("Access-Control-Allow-Headers", "*");
+  callback(resp);
+}
+
diff --git a/controllers/prelight.h b/controllers/prelight.h
new file mode 100644
index 000000000..387f5a51b
--- /dev/null
+++ b/controllers/prelight.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <drogon/HttpController.h>
+
+using namespace drogon;
+
+class prelight : public drogon::HttpController<prelight> {
+public:
+  METHOD_LIST_BEGIN
+  ADD_METHOD_TO(prelight::handlePrelight, "/v1/chat/completions", Options);
+  ADD_METHOD_TO(prelight::handlePrelight, "/v1/embeddings", Options);
+  ADD_METHOD_TO(prelight::handlePrelight, "/v1/audio/transcriptions", Options);
+  ADD_METHOD_TO(prelight::handlePrelight, "/v1/audio/translations", Options);
+  METHOD_LIST_END
+
+  void handlePrelight(const HttpRequestPtr &req,
+                      std::function<void(const HttpResponsePtr &)> &&callback);
+};