134 changes: 59 additions & 75 deletions controllers/llamaCPP.cc
@@ -1,20 +1,12 @@
#include "llamaCPP.h"

#include <trantor/utils/SerialTaskQueue.h>

#include "llama.h"
#include "log.h"
#include "utils/nitro_utils.h"

using namespace inferences;
using json = nlohmann::json;

/**
* Queue to handle the inference task, this is to ensure that the inference
* task is handled in a sequential manner
*/
static trantor::SerialTaskQueue queue("worker");

/**
* The state of the inference task
*/
@@ -32,7 +24,6 @@ enum InferenceStatus {
* associated with.
*/
struct inferenceState {
bool is_stopped = false;
int task_id;
InferenceStatus inferenceStatus = PENDING;
llamaCPP *instance;
@@ -150,7 +141,7 @@ std::string create_return_json(const std::string &id, const std::string &model,
return Json::writeString(writer, root);
}

llamaCPP::llamaCPP() {
llamaCPP::llamaCPP(): queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP")) {
// Some default values for now below
log_disable(); // Disable the log to file feature, reduce bloat for
// target
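
The constructor change above swaps the single-threaded `SerialTaskQueue` for a `trantor::ConcurrentTaskQueue` sized by `llama.params.n_parallel`. A minimal standalone sketch of that call shape, assuming `4` stands in for `n_parallel` and that trantor (bundled with drogon) is on the include path:

```cpp
// Sketch only: same queuing API as before, but tasks may now run on
// several worker threads instead of one.
#include <trantor/utils/ConcurrentTaskQueue.h>
#include <chrono>
#include <iostream>
#include <thread>

int main() {
  // n_parallel worker threads named "llamaCPP", mirroring the constructor above.
  trantor::ConcurrentTaskQueue queue(4 /* llama.params.n_parallel */, "llamaCPP");

  queue.runTaskInQueue([] { std::cout << "inference request A\n"; });
  queue.runTaskInQueue([] { std::cout << "inference request B\n"; });

  // Give the toy tasks time to run before the queue is torn down.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  return 0;
}
```

Because `runTaskInQueue` keeps the same signature as the old serial queue, call sites only change by going through the new member pointer (`state->instance->queue->runTaskInQueue`) later in the diff.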
@@ -341,18 +332,17 @@ void llamaCPP::inferenceImpl(

if(state->inferenceStatus == PENDING) {
state->inferenceStatus = RUNNING;
} else if (state->inferenceStatus == FINISHED) {
return 0;
}

if (!pBuffer) {
LOG_INFO << "Connection closed or buffer is null. Reset context";
state->instance->llama.request_cancel(state->task_id);
state->instance->single_queue_is_busy = false;
return 0;
}
if (state->is_stopped) {
state->instance->single_queue_is_busy = false;
state->inferenceStatus = FINISHED;
return 0;
}


task_result result = state->instance->llama.next_result(state->task_id);
if (!result.error) {
@@ -377,31 +367,27 @@ void llamaCPP::inferenceImpl(
std::size_t nRead = std::min(str.size(), nBuffSize);
memcpy(pBuffer, str.data(), nRead);
LOG_INFO << "reached result stop";
state->is_stopped = true;
state->instance->llama.request_cancel(state->task_id);
state->instance->single_queue_is_busy = false;
state->inferenceStatus = FINISHED;
}

// If nRead is zero the stream stops,
// so mark the task as finished
if(!nRead) {
state->instance->single_queue_is_busy = false;
state->inferenceStatus = FINISHED;
}

return nRead;
}
state->instance->single_queue_is_busy = false;
state->inferenceStatus = FINISHED;
return 0;
};

// Run task in serial queue
queue.runTaskInQueue([callback, state, data,
// Queued task
state->instance->queue->runTaskInQueue([callback, state, data,
chunked_content_provider]() {
state->task_id =
state->instance->llama.request_completion(data, false, false, -1);

state->instance->single_queue_is_busy = true;

// Start streaming response
auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
"chat_completions.txt");
@@ -410,16 +396,14 @@ void llamaCPP::inferenceImpl(
int retries = 0;

// Since this is an async task, we will wait for the task to be completed
while (state->instance->single_queue_is_busy && retries < 10) {
while (state->inferenceStatus != FINISHED && retries < 10) {
// Wait up to ~3s (10 x 300ms) for the chunked_content_provider lambda to be called
if(state->inferenceStatus == PENDING) {
retries += 1;
}
LOG_INFO << "Wait for task to be released:" << state->task_id;
std::this_thread::sleep_for(std::chrono::milliseconds(300));
}

state->inferenceStatus = FINISHED;
});
return;
} else {
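
The wait loop above replaces the old `single_queue_is_busy` flag with the `inferenceStatus` handshake: the content provider flips PENDING to RUNNING on its first call and RUNNING to FINISHED when streaming ends, while the queued task polls until FINISHED, counting retries only while the provider is still PENDING. A stripped-down sketch of that handshake, where an `std::atomic` and an `std::thread` stand in for the real provider and the 10 × 300 ms budget is taken from the code above:

```cpp
// Sketch of the PENDING -> RUNNING -> FINISHED handshake; simplified,
// the real state lives in inferenceState and is shared via shared_ptr.
#include <atomic>
#include <chrono>
#include <thread>

enum InferenceStatus { PENDING, RUNNING, FINISHED };

struct inferenceState {
  std::atomic<InferenceStatus> status{PENDING};
};

int main() {
  inferenceState state;

  // Stand-in for chunked_content_provider: starts running, then finishes.
  std::thread provider([&state] {
    state.status = RUNNING;
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    state.status = FINISHED;
  });

  // Stand-in for the queued task: holds its slot until the provider finishes,
  // but only burns retries while the provider has not started yet.
  int retries = 0;
  while (state.status != FINISHED && retries < 10) {
    if (state.status == PENDING) retries += 1;
    std::this_thread::sleep_for(std::chrono::milliseconds(300));
  }

  provider.join();
  return 0;
}
```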
@@ -466,59 +450,51 @@ void llamaCPP::embeddingImpl(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback) {

Json::Value responseData(Json::arrayValue);
// Queue embedding task
auto state = create_inference_state(this);
if (jsonBody->isMember("input")) {
// If single queue is busy, we will wait if not we will just go ahead and
// process and make it busy, and yet i'm aware not DRY, i have the same
// stuff on chatcompletion as well
if (state->instance->llama.params.n_parallel == 1) {
while (state->instance->single_queue_is_busy) {
LOG_INFO << "Waiting for task to be released status:"
<< state->instance->single_queue_is_busy;
std::this_thread::sleep_for(
std::chrono::milliseconds(500)); // Waiting in 500 miliseconds step
}
}
const Json::Value &input = (*jsonBody)["input"];
if (input.isString()) {
// Process the single string input
state->task_id = llama.request_completion(
{{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
state->instance->single_queue_is_busy = true;
task_result result = llama.next_result(state->task_id);
std::vector<float> embedding_result = result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));
} else if (input.isArray()) {
// Process each element in the array input
for (const auto &elem : input) {
if (elem.isString()) {
const int task_id = llama.request_completion(
{{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
task_result result = llama.next_result(task_id);
std::vector<float> embedding_result = result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));

state->instance->queue->runTaskInQueue([this, state, jsonBody, callback]() {
Json::Value responseData(Json::arrayValue);

if (jsonBody->isMember("input")) {
const Json::Value &input = (*jsonBody)["input"];
if (input.isString()) {
// Process the single string input
state->task_id = llama.request_completion(
{{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
task_result result = llama.next_result(state->task_id);
std::vector<float> embedding_result = result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));
} else if (input.isArray()) {
// Process each element in the array input
for (const auto &elem : input) {
if (elem.isString()) {
const int task_id = llama.request_completion(
{{"prompt", elem.asString()}, {"n_predict", 0}}, false, true,
-1);
task_result result = llama.next_result(task_id);
std::vector<float> embedding_result =
result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));
}
}
}
}
}

// We already got result of the embedding so no longer busy
state->instance->single_queue_is_busy = false;

auto resp = nitro_utils::nitroHttpResponse();
Json::Value root;
root["data"] = responseData;
root["model"] = "_";
root["object"] = "list";
Json::Value usage;
usage["prompt_tokens"] = 0;
usage["total_tokens"] = 0;
root["usage"] = usage;

resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
resp->setContentTypeString("application/json");
callback(resp);
auto resp = nitro_utils::nitroHttpResponse();
Json::Value root;
root["data"] = responseData;
root["model"] = "_";
root["object"] = "list";
Json::Value usage;
usage["prompt_tokens"] = 0;
usage["total_tokens"] = 0;
root["usage"] = usage;

resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
resp->setContentTypeString("application/json");
callback(resp);
});
}

void llamaCPP::unloadModel(
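
`embeddingImpl` now runs entirely inside the queued task, including the HTTP callback, so the old busy-wait on `single_queue_is_busy` disappears. For reference, a standalone sketch of the list-style payload that task assembles with jsoncpp; the per-item fields follow OpenAI's embeddings shape and are an assumption here, since `create_embedding_payload` is not part of this diff:

```cpp
// Sketch of the embeddings response body; values are placeholders.
#include <json/json.h>
#include <iostream>

int main() {
  Json::Value responseData(Json::arrayValue);

  // Assumed shape of one create_embedding_payload(...) entry.
  Json::Value entry;
  entry["object"] = "embedding";
  entry["index"] = 0;
  Json::Value emb(Json::arrayValue);
  emb.append(0.1);
  emb.append(0.2);
  entry["embedding"] = emb;
  responseData.append(entry);

  // Envelope built exactly as in the handler above.
  Json::Value root;
  root["data"] = responseData;
  root["model"] = "_";
  root["object"] = "list";
  Json::Value usage;
  usage["prompt_tokens"] = 0;
  usage["total_tokens"] = 0;
  root["usage"] = usage;

  std::cout << Json::writeString(Json::StreamWriterBuilder(), root) << "\n";
  return 0;
}
```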
@@ -539,6 +515,7 @@ void llamaCPP::unloadModel(
callback(resp);
return;
}

void llamaCPP::modelStatus(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -555,6 +532,7 @@ void llamaCPP::modelStatus(
callback(resp);
return;
}

void llamaCPP::loadModel(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -674,6 +652,12 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
}
llama.initialize();

if (queue != nullptr) {
delete queue;
}

queue = new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP");

llama.model_loaded_external = true;

LOG_INFO << "Started background task here!";
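
`loadModelImpl` deletes and re-creates the queue so that a reloaded model picks up its own `n_parallel`. A hedged alternative sketch (not what this PR does) that expresses the same lifetime with `std::unique_ptr`, avoiding the manual delete/new pair; names are illustrative:

```cpp
// Alternative ownership sketch: reassigning the unique_ptr destroys the old
// queue (its destructor joins the worker threads) before work continues.
#include <trantor/utils/ConcurrentTaskQueue.h>
#include <memory>

struct QueueOwner {
  std::unique_ptr<trantor::ConcurrentTaskQueue> queue;

  void resetQueue(int n_parallel) {
    queue = std::make_unique<trantor::ConcurrentTaskQueue>(n_parallel, "llamaCPP");
  }
};

int main() {
  QueueOwner owner;
  owner.resetQueue(1);  // initial load
  owner.resetQueue(4);  // model reloaded with a different n_parallel
  return 0;
}
```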
8 changes: 6 additions & 2 deletions controllers/llamaCPP.h
@@ -26,6 +26,7 @@

#include "common/base.h"
#include "utils/json.hpp"
#include <trantor/utils/ConcurrentTaskQueue.h>

// auto generated files (update with ./deps.sh)

@@ -2562,10 +2563,13 @@ class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {
bool caching_enabled;
std::atomic<int> no_of_chats = 0;
int clean_cache_threshold;
std::atomic<bool> single_queue_is_busy; // This value only used under the
// condition n_parallel is 1
std::string grammar_file_content;

/**
* Queue to handle the inference tasks
*/
trantor::ConcurrentTaskQueue *queue;

bool loadModelImpl(std::shared_ptr<Json::Value> jsonBody);
void inferenceImpl(std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback);