This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
controllers/llamaCPP.cc: 62 changes (31 additions, 31 deletions)
@@ -10,7 +10,7 @@ using json = nlohmann::json;
/**
* The state of the inference task
*/
- enum InferenceStatus { PENDING, RUNNING, FINISHED };
+ enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };

/**
* There is a need to save state of current ongoing inference status of a
@@ -21,7 +21,7 @@ enum InferenceStatus { PENDING, RUNNING, FINISHED };
*/
struct inferenceState {
int task_id;
- InferenceStatus inferenceStatus = PENDING;
+ InferenceStatus inference_status = PENDING;
llamaCPP *instance;

inferenceState(llamaCPP *inst) : instance(inst) {}
@@ -111,7 +111,6 @@ std::string create_full_return_json(const std::string &id,
std::string create_return_json(const std::string &id, const std::string &model,
const std::string &content,
Json::Value finish_reason = Json::Value()) {
-
Json::Value root;

root["id"] = id;
Expand Down Expand Up @@ -167,7 +166,6 @@ void llamaCPP::warmupModel() {
void llamaCPP::inference(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
-
const auto &jsonBody = req->getJsonObject();
// Check if model is loaded
if (checkModelLoaded(callback)) {
@@ -180,7 +178,6 @@ void llamaCPP::inference(
void llamaCPP::inferenceImpl(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback) {
-
std::string formatted_output = pre_prompt;

json data;
Expand Down Expand Up @@ -218,7 +215,6 @@ void llamaCPP::inferenceImpl(
};

if (!llama.multimodal) {
-
for (const auto &message : messages) {
std::string input_role = message["role"].asString();
std::string role;
@@ -243,7 +239,6 @@ void llamaCPP::inferenceImpl(
}
formatted_output += ai_prompt;
} else {
-
data["image_data"] = json::array();
for (const auto &message : messages) {
std::string input_role = message["role"].asString();
@@ -327,18 +322,33 @@ void llamaCPP::inferenceImpl(
auto state = create_inference_state(this);
auto chunked_content_provider =
[state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
- if (state->inferenceStatus == PENDING) {
- state->inferenceStatus = RUNNING;
- } else if (state->inferenceStatus == FINISHED) {
+ if (state->inference_status == PENDING) {
+ state->inference_status = RUNNING;
+ } else if (state->inference_status == FINISHED) {
return 0;
}

if (!pBuffer) {
LOG_INFO << "Connection closed or buffer is null. Reset context";
- state->inferenceStatus = FINISHED;
+ state->inference_status = FINISHED;
return 0;
}

+ if (state->inference_status == EOS) {
+ LOG_INFO << "End of result";
+ const std::string str =
+ "data: " +
+ create_return_json(nitro_utils::generate_random_string(20), "_", "",
+ "stop") +
+ "\n\n" + "data: [DONE]" + "\n\n";
+
+ LOG_VERBOSE("data stream", {{"to_send", str}});
+ std::size_t nRead = std::min(str.size(), nBuffSize);
+ memcpy(pBuffer, str.data(), nRead);
+ state->inference_status = FINISHED;
+ return nRead;
+ }
+
task_result result = state->instance->llama.next_result(state->task_id);
if (!result.error) {
const std::string to_send = result.result_json["content"];
@@ -352,28 +362,22 @@ void llamaCPP::inferenceImpl(
memcpy(pBuffer, str.data(), nRead);

if (result.stop) {
- const std::string str =
- "data: " +
- create_return_json(nitro_utils::generate_random_string(20), "_",
- "", "stop") +
- "\n\n" + "data: [DONE]" + "\n\n";
-
- LOG_VERBOSE("data stream", {{"to_send", str}});
- std::size_t nRead = std::min(str.size(), nBuffSize);
- memcpy(pBuffer, str.data(), nRead);
LOG_INFO << "reached result stop";
- state->inferenceStatus = FINISHED;
+ state->inference_status = EOS;
return nRead;
}

// Make sure nBuffSize is not zero
// Otherwise it stops streaming
if (!nRead) {
- state->inferenceStatus = FINISHED;
+ state->inference_status = FINISHED;
}

return nRead;
} else {
LOG_INFO << "Error during inference";
}
- state->inferenceStatus = FINISHED;
+ state->inference_status = FINISHED;
return 0;
};
// Queued task
@@ -391,16 +395,17 @@ void llamaCPP::inferenceImpl(

// Since this is an async task, we will wait for the task to be
// completed
- while (state->inferenceStatus != FINISHED && retries < 10) {
+ while (state->inference_status != FINISHED && retries < 10) {
// Should wait chunked_content_provider lambda to be called within
// 3s
- if (state->inferenceStatus == PENDING) {
+ if (state->inference_status == PENDING) {
retries += 1;
}
- if (state->inferenceStatus != RUNNING)
+ if (state->inference_status != RUNNING)
LOG_INFO << "Wait for task to be released:" << state->task_id;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
LOG_INFO << "Task completed, release it";
// Request completed, release it
state->instance->llama.request_cancel(state->task_id);
});
@@ -445,7 +450,6 @@ void llamaCPP::embedding(
void llamaCPP::embeddingImpl(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback) {
-
// Queue embedding task
auto state = create_inference_state(this);

@@ -532,7 +536,6 @@ void llamaCPP::modelStatus(
void llamaCPP::loadModel(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
-
if (llama.model_loaded_external) {
LOG_INFO << "model loaded";
Json::Value jsonResp;
@@ -561,7 +564,6 @@ void llamaCPP::loadModel(
}

bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
-
gpt_params params;
// By default, settings are based on the number of handlers
if (jsonBody) {
@@ -570,11 +572,9 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
params.mmproj = jsonBody->operator[]("mmproj").asString();
}
if (!jsonBody->operator[]("grp_attn_n").isNull()) {
-
params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
}
if (!jsonBody->operator[]("grp_attn_w").isNull()) {
-
params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
}
if (!jsonBody->operator[]("mlock").isNull()) {
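The diff above splits stream termination into two provider calls: when llama.cpp reports result.stop, the lambda flushes the final content chunk and marks the shared state EOS, and only the next invocation emits the "stop"/[DONE] SSE frame before marking FINISHED, which releases the background task waiting on inference_status. Previously the terminating frame was memcpy'd into the same buffer as the last content chunk within a single call. Below is a minimal sketch of that handshake, not the nitro source: FakeResult, next_fake_result, provide_chunk, and the plain main loop are hypothetical stand-ins for llama.cpp's task_result/next_result and Drogon's chunked-stream callback.

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <string>

// Simplified mirror of the controller's state enum.
enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };

// Hypothetical stand-in for llama.cpp's task_result.
struct FakeResult {
  std::string content;
  bool stop;
};

// Hypothetical stand-in for llama.next_result(task_id): two tokens, then stop.
FakeResult next_fake_result(int &calls) {
  if (++calls < 3) {
    return {"token ", false};
  }
  return {"", true};
}

// One call per chunk, mirroring chunked_content_provider's control flow.
std::size_t provide_chunk(InferenceStatus &status, int &calls, char *buffer,
                          std::size_t size) {
  if (status == PENDING) {
    status = RUNNING;
  } else if (status == FINISHED) {
    return 0;
  }

  if (status == EOS) {
    // Second phase: the terminating frame gets its own chunk.
    const std::string done = "data: [DONE]\n\n";
    std::size_t n = std::min(done.size(), size);
    std::memcpy(buffer, done.data(), n);
    status = FINISHED;
    return n;
  }

  FakeResult r = next_fake_result(calls);
  const std::string frame = "data: " + r.content + "\n\n";
  std::size_t n = std::min(frame.size(), size);
  std::memcpy(buffer, frame.data(), n);
  if (r.stop) {
    status = EOS;  // First phase: defer [DONE] to the next call.
  }
  return n;
}

int main() {
  InferenceStatus status = PENDING;
  int calls = 0;
  char buf[256];
  while (true) {
    std::size_t n = provide_chunk(status, calls, buf, sizeof(buf));
    if (n == 0) {
      break;
    }
    std::cout << std::string(buf, n);
  }
  return 0;
}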