This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
112 changes: 72 additions & 40 deletions controllers/llamaCPP.cc
@@ -1,11 +1,29 @@
#include "llamaCPP.h"

#include <trantor/utils/SerialTaskQueue.h>

#include "llama.h"
#include "log.h"
#include "utils/nitro_utils.h"

using namespace inferences;
using json = nlohmann::json;

/**
* Queue to handle the inference task; this ensures that the inference
* task is handled in a sequential manner
*/
static trantor::SerialTaskQueue queue("worker");

/**
* The state of the inference task
*/
enum InferenceStatus {
PENDING,
RUNNING,
FINISHED
};

/**
* There is a need to save state of current ongoing inference status of a
* handler, this struct is to solve that issue
@@ -15,8 +33,8 @@ using json = nlohmann::json;
*/
struct inferenceState {
bool is_stopped = false;
bool is_streaming = false;
int task_id;
InferenceStatus inferenceStatus = PENDING;
llamaCPP *instance;

inferenceState(llamaCPP *inst) : instance(inst) {}
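The hunks above add the machinery the rest of this diff builds on: a single trantor::SerialTaskQueue named "worker" so that completion requests execute one at a time, an InferenceStatus enum that replaces the old is_streaming flag, and an inferenceState that carries that status instead. The following is a minimal standalone sketch (not part of the PR) of the serialization property being relied on; it assumes only the SerialTaskQueue constructor and runTaskInQueue, both of which appear later in this diff, and everything else in it is invented for illustration.

#include <trantor/utils/SerialTaskQueue.h>

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

enum InferenceStatus { PENDING, RUNNING, FINISHED };

int main() {
  trantor::SerialTaskQueue queue("worker");
  std::atomic<int> finished{0};

  // Tasks pushed onto a SerialTaskQueue run one after another on the queue's
  // own thread, in submission order, so two "inference" tasks can never
  // overlap even though the HTTP handlers that enqueue them are concurrent.
  for (int i = 0; i < 3; ++i) {
    queue.runTaskInQueue([i, &finished] {
      InferenceStatus status = RUNNING;
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      status = FINISHED;
      std::cout << "task " << i << " done, status=" << status << std::endl;
      ++finished;
    });
  }

  // Crude wait so the example exits only after every queued task has run.
  while (finished < 3) {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
  }
  return 0;
}

The one-task-at-a-time guarantee is all that matters here; the real handler additionally needs to know whether the connection's stream provider has started, which is what InferenceStatus tracks.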
@@ -35,7 +53,7 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
* Check if model already loaded if not return message to user
* @param callback the function to return message to user
*/
void llamaCPP::checkModelLoaded(
bool llamaCPP::checkModelLoaded(
std::function<void(const HttpResponsePtr &)> &callback) {
if (!llama.model_loaded_external) {
Json::Value jsonResp;
@@ -44,8 +62,9 @@ void llamaCPP::checkModelLoaded(
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
resp->setStatusCode(drogon::k409Conflict);
callback(resp);
return;
return false;
}
return true;
}

Json::Value create_embedding_payload(const std::vector<float> &embedding,
@@ -70,7 +89,6 @@ std::string create_full_return_json(const std::string &id,
const std::string &system_fingerprint,
int prompt_tokens, int completion_tokens,
Json::Value finish_reason = Json::Value()) {

Json::Value root;

root["id"] = id;
@@ -163,9 +181,11 @@ void llamaCPP::inference(

const auto &jsonBody = req->getJsonObject();
// Check if model is loaded
checkModelLoaded(callback);

inferenceImpl(jsonBody, callback);
if(checkModelLoaded(callback)) {
// Model is loaded
// Do Inference
inferenceImpl(jsonBody, callback);
}
}

void llamaCPP::inferenceImpl(
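Two related changes meet at this call site: checkModelLoaded now returns bool instead of void, and inference() only runs inferenceImpl() when the check passes. With the old void signature the handler sent the 409 Conflict response and then fell through into inferenceImpl() anyway, because the caller could not tell that a reply had already gone out; embedding() gets the same guard further down. The sketch below is not nitro code (the HttpResponse and Callback aliases are stand-ins invented for illustration); it shows the same guard pattern in a self-contained form, written as an early return rather than an if-block.

#include <functional>
#include <iostream>
#include <string>

// Stand-ins for the drogon types, just to make the guard pattern concrete.
using HttpResponse = std::string;
using Callback = std::function<void(const HttpResponse &)>;

bool model_loaded = false;

// Mirrors the new checkModelLoaded(): replies immediately and returns false
// when no model is loaded, returns true otherwise.
bool checkModelLoaded(const Callback &callback) {
  if (!model_loaded) {
    callback("409 Conflict: model has not been loaded");
    return false;
  }
  return true;
}

void inference(const Callback &callback) {
  if (!checkModelLoaded(callback)) {
    // The 409 response has already been sent; continuing to run inference
    // here is exactly the bug the void-returning version allowed.
    return;
  }
  callback("200 OK: inference result");
}

int main() {
  inference([](const HttpResponse &r) { std::cout << r << std::endl; });
  model_loaded = true;
  inference([](const HttpResponse &r) { std::cout << r << std::endl; });
  return 0;
}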
@@ -318,28 +338,24 @@ void llamaCPP::inferenceImpl(
auto state = create_inference_state(this);
auto chunked_content_provider =
[state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
if (!state->is_streaming) {
state->task_id =
state->instance->llama.request_completion(data, false, false, -1);
state->instance->single_queue_is_busy = true;

if(state->inferenceStatus == PENDING) {
state->inferenceStatus = RUNNING;
}

if (!pBuffer) {
LOG_INFO << "Connection closed or buffer is null. Reset context";
state->instance->llama.request_cancel(state->task_id);
state->is_streaming = false;
state->instance->single_queue_is_busy = false;
return 0;
}
if (state->is_stopped) {
state->is_streaming = false;
state->instance->single_queue_is_busy = false;
return 0;
}

task_result result = state->instance->llama.next_result(state->task_id);
if (!result.error) {
// Update streaming state to being streamed
state->is_streaming = true;
const std::string to_send = result.result_json["content"];
const std::string str =
"data: " +
@@ -363,35 +379,48 @@
LOG_INFO << "reached result stop";
state->is_stopped = true;
state->instance->llama.request_cancel(state->task_id);
state->is_streaming = false;
state->instance->single_queue_is_busy = false;

return nRead;
}
return nRead;
} else {
if (state->instance->llama.params.n_parallel == 1) {
while (state->instance->single_queue_is_busy) {
LOG_INFO << "Waiting for task to be released status:"
<< state->instance->single_queue_is_busy;
std::this_thread::sleep_for(std::chrono::milliseconds(
500)); // Waiting in 500 millisecond steps
}

// Make sure nBufferSize is not zero
// Otherwise it stops streaming
if(!nRead) {
state->instance->single_queue_is_busy = false;
}
std::string str = "\n\n";
std::size_t nRead = str.size();
memcpy(pBuffer, str.data(), nRead);
LOG_INFO << "Failing retrying now";

return nRead;
}
state->is_streaming = false;
state->instance->single_queue_is_busy = false;
return 0;
};
auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
"chat_completions.txt");
callback(resp);

// Run task in serial queue
queue.runTaskInQueue([callback, state, data,
chunked_content_provider]() {
state->task_id =
state->instance->llama.request_completion(data, false, false, -1);

state->instance->single_queue_is_busy = true;

// Start streaming response
auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
"chat_completions.txt");
callback(resp);

int retries = 0;

// Since this is an async task, we will wait for the task to be completed
while (state->instance->single_queue_is_busy && retries < 10) {
// Should wait for the chunked_content_provider lambda to be called within 3s
if(state->inferenceStatus == PENDING) {
retries += 1;
}
LOG_INFO << "Wait for task to be released:" << state->task_id;
std::this_thread::sleep_for(std::chrono::milliseconds(300));
}

state->inferenceStatus = FINISHED;
});
return;
} else {
Json::Value respData;
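The streaming path is the core of this hunk. The chunked_content_provider no longer submits the completion request itself; instead the handler pushes a task onto the serial queue that calls request_completion, marks the single slot busy, starts the chunked response, and then polls until either the provider releases the slot or the provider is never invoked within roughly 3 seconds (10 retries of 300 ms, counted only while the status is still PENDING), after which the status is forced to FINISHED. The sketch below models just that handshake with plain threads and atomics; the names InferenceStatus, single_queue_is_busy, and retries mirror the diff, while the fake provider and client thread are invented for illustration and are not drogon or nitro code.

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

enum InferenceStatus { PENDING, RUNNING, FINISHED };

int main() {
  std::atomic<InferenceStatus> status{PENDING};
  std::atomic<bool> single_queue_is_busy{true};

  // Stand-in for the chunked_content_provider lambda: drogon calls the real
  // one repeatedly to pull stream chunks; here a client thread does.
  auto provider = [&](int chunk, int last_chunk) {
    if (status == PENDING) {
      status = RUNNING;  // first invocation: the stream has started
    }
    if (chunk == last_chunk) {
      single_queue_is_busy = false;  // final chunk: release the slot
    }
  };

  std::thread client([&] {
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    for (int i = 0; i <= 4; ++i) {
      provider(i, 4);
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
  });

  // Same shape as the queued task in the diff: wait while the slot is busy,
  // but only count retries while the provider has not yet been called, so a
  // stream that never starts is abandoned after about 3 s (10 x 300 ms).
  int retries = 0;
  while (single_queue_is_busy && retries < 10) {
    if (status == PENDING) {
      retries += 1;
    }
    std::cout << "waiting, retries=" << retries << std::endl;
    std::this_thread::sleep_for(std::chrono::milliseconds(300));
  }
  status = FINISHED;

  client.join();
  return 0;
}

Counting retries only while the stream has not started is what lets a healthy long generation run past the 3-second window, while a connection whose provider is never invoked is still reclaimed.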
@@ -423,11 +452,14 @@ void llamaCPP::inferenceImpl(
void llamaCPP::embedding(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
checkModelLoaded(callback);
const auto &jsonBody = req->getJsonObject();

embeddingImpl(jsonBody, callback);
return;
// Check if model is loaded
if(checkModelLoaded(callback)) {
// Model is loaded
const auto &jsonBody = req->getJsonObject();
// Run embedding
embeddingImpl(jsonBody, callback);
return;
}
}

void llamaCPP::embeddingImpl(
2 changes: 1 addition & 1 deletion controllers/llamaCPP.h
@@ -2571,7 +2571,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {
std::function<void(const HttpResponsePtr &)> &callback);
void embeddingImpl(std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback);
void checkModelLoaded(std::function<void(const HttpResponsePtr &)> &callback);
bool checkModelLoaded(std::function<void(const HttpResponsePtr &)> &callback);
void warmupModel();
void backgroundTask();
void stopBackgroundTask();