152 changes: 108 additions & 44 deletions controllers/llamaCPP.cc
@@ -8,7 +8,6 @@
#include <regex>
#include <string>
#include <thread>
#include <trantor/utils/Logger.h>

using namespace inferences;
using json = nlohmann::json;
@@ -28,6 +27,45 @@ std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {

// --------------------------------------------

std::string create_full_return_json(const std::string &id,
const std::string &model,
const std::string &content,
const std::string &system_fingerprint,
int prompt_tokens, int completion_tokens,
Json::Value finish_reason = Json::Value()) {

Json::Value root;

root["id"] = id;
root["model"] = model;
root["created"] = static_cast<int>(std::time(nullptr));
root["object"] = "chat.completion";
root["system_fingerprint"] = system_fingerprint;

Json::Value choicesArray(Json::arrayValue);
Json::Value choice;

choice["index"] = 0;
Json::Value message;
message["role"] = "assistant";
message["content"] = content;
choice["message"] = message;
choice["finish_reason"] = finish_reason;

choicesArray.append(choice);
root["choices"] = choicesArray;

Json::Value usage;
usage["prompt_tokens"] = prompt_tokens;
usage["completion_tokens"] = completion_tokens;
usage["total_tokens"] = prompt_tokens + completion_tokens;
root["usage"] = usage;

Json::StreamWriterBuilder writer;
writer["indentation"] = ""; // Compact output
return Json::writeString(writer, root);
}
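
For reference, a minimal usage sketch of the new helper (not part of the patch; the id, content, and token counts are made-up values):

// Illustrative call to create_full_return_json as defined above.
std::string body = create_full_return_json(
    "chatcmpl-abc123", "_", "Hello there!", "_",
    /*prompt_tokens=*/10, /*completion_tokens=*/3);
// body is a single compact JSON line, roughly:
// {"choices":[{"finish_reason":null,"index":0,
//   "message":{"content":"Hello there!","role":"assistant"}}],
//  "created":1700000000,"id":"chatcmpl-abc123","model":"_",
//  "object":"chat.completion","system_fingerprint":"_",
//  "usage":{"completion_tokens":3,"prompt_tokens":10,"total_tokens":13}}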

std::string create_return_json(const std::string &id, const std::string &model,
const std::string &content,
Json::Value finish_reason = Json::Value()) {
@@ -82,9 +120,9 @@ void llamaCPP::chatCompletion(
json data;
json stopWords;
// To set default value
data["stream"] = true;

if (jsonBody) {
data["stream"] = (*jsonBody).get("stream", false).asBool();
data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
@@ -119,62 +157,87 @@
data["stop"] = stopWords;
}

bool is_streamed = data["stream"];

const int task_id = llama.request_completion(data, false, false);
LOG_INFO << "Resolved request for task_id:" << task_id;

auto state = createState(task_id, this);
if (is_streamed) {
auto state = createState(task_id, this);

auto chunked_content_provider =
[state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
if (!pBuffer) {
LOG_INFO << "Connection closed or buffer is null. Reset context";
state->instance->llama.request_cancel(state->task_id);
return 0;
}
if (state->isStopped) {
return 0;
}

task_result result = state->instance->llama.next_result(state->task_id);
if (!result.error) {
const std::string to_send = result.result_json["content"];
const std::string str =
"data: " +
create_return_json(nitro_utils::generate_random_string(20), "_",
to_send) +
"\n\n";

std::size_t nRead = std::min(str.size(), nBuffSize);
memcpy(pBuffer, str.data(), nRead);
auto chunked_content_provider =
[state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
if (!pBuffer) {
LOG_INFO << "Connection closed or buffer is null. Reset context";
state->instance->llama.request_cancel(state->task_id);
return 0;
}
if (state->isStopped) {
return 0;
}

if (result.stop) {
task_result result = state->instance->llama.next_result(state->task_id);
if (!result.error) {
const std::string to_send = result.result_json["content"];
const std::string str =
"data: " +
create_return_json(nitro_utils::generate_random_string(20), "_", "",
"stop") +
"\n\n" + "data: [DONE]" + "\n\n";
create_return_json(nitro_utils::generate_random_string(20), "_",
to_send) +
"\n\n";

LOG_VERBOSE("data stream", {{"to_send", str}});
std::size_t nRead = std::min(str.size(), nBuffSize);
memcpy(pBuffer, str.data(), nRead);
LOG_INFO << "reached result stop";
state->isStopped = true;
state->instance->llama.request_cancel(state->task_id);

if (result.stop) {
const std::string str =
"data: " +
create_return_json(nitro_utils::generate_random_string(20), "_",
"", "stop") +
"\n\n" + "data: [DONE]" + "\n\n";

LOG_VERBOSE("data stream", {{"to_send", str}});
std::size_t nRead = std::min(str.size(), nBuffSize);
memcpy(pBuffer, str.data(), nRead);
LOG_INFO << "reached result stop";
state->isStopped = true;
state->instance->llama.request_cancel(state->task_id);
return nRead;
}
return nRead;
} else {
return 0;
}
return nRead;
} else {
return 0;
}
return 0;
};
auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
"chat_completions.txt");
callback(resp);
};
auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
"chat_completions.txt");
callback(resp);

return;
return;
} else {
Json::Value respData;
auto resp = nitro_utils::nitroHttpResponse();
respData["testing"] = "thunghiem value moi";
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
if (!result.error && result.stop) {
int prompt_tokens = result.result_json["tokens_evaluated"];
int predicted_tokens = result.result_json["tokens_predicted"];
std::string full_return =
create_full_return_json(nitro_utils::generate_random_string(20),
"_", result.result_json["content"], "_",
prompt_tokens, predicted_tokens);
resp->setBody(full_return);
} else {
resp->setBody("internal error during inference");
return;
}
callback(resp);
return;
}
}
}
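
Taken together, chatCompletion now branches on the parsed "stream" flag. A rough sketch of the two response shapes a client can expect (illustrative, not part of the patch):

// "stream": true  -> chunked SSE stream built by the provider above:
//   repeated "data: <create_return_json(id, model, token)>\n\n" chunks,
//   then a final chunk with finish_reason "stop" followed by
//   "data: [DONE]\n\n".
// "stream": false -> a single JSON body built by create_full_return_json,
//   carrying the whole completion plus prompt/completion token usage.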

void llamaCPP::embedding(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -262,7 +325,8 @@ void llamaCPP::loadModel(
this->pre_prompt =
(*jsonBody)
.get("pre_prompt",
"A chat between a curious user and an artificial intelligence "
"A chat between a curious user and an artificial "
"intelligence "
"assistant. The assistant follows the given rules no matter "
"what.\\n")
.asString();