controllers/llamaCPP.cc (61 changes: 32 additions, 29 deletions)
@@ -49,7 +49,7 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP* instance) {
  * @param callback the function to return message to user
  */
 bool llamaCPP::CheckModelLoaded(
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    const std::function<void(const HttpResponsePtr&)>& callback) {
   if (!llama.model_loaded_external) {
     LOG_ERROR << "Model has not been loaded";
     Json::Value jsonResp;
@@ -180,13 +180,13 @@ void llamaCPP::ChatCompletion(
   if (CheckModelLoaded(callback)) {
     // Model is loaded
     // Do Inference
-    InferenceImpl(std::move(completion), callback);
+    InferenceImpl(std::move(completion), std::move(callback));
   }
 }
 
 void llamaCPP::InferenceImpl(
     inferences::ChatCompletionRequest&& completion,
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    std::function<void(const HttpResponsePtr&)>&& callback) {
   std::string formatted_output = pre_prompt;
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request";
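
The two signature changes in this hunk turn the callback into a sink: `InferenceImpl` now takes the handler's response callback by rvalue reference, and `ChatCompletion` hands it over with `std::move` instead of passing an lvalue reference around. A minimal standalone sketch of that pattern, using only the standard library (`Callback`, `InferenceSketch`, and `HandlerSketch` are illustrative names, not part of this codebase):

```cpp
// Minimal sketch, not the project's code: a handler "sinks" its response
// callback into an implementation function by rvalue reference, so ownership
// is transferred instead of the callback being shared by lvalue reference.
#include <functional>
#include <iostream>
#include <string>
#include <utility>

using Callback = std::function<void(const std::string&)>;

// Takes the callback by rvalue reference: the caller hands over ownership.
void InferenceSketch(std::string&& prompt, Callback&& callback) {
  Callback cb = std::move(callback);  // take ownership locally
  cb("response for: " + prompt);
}

void HandlerSketch(Callback callback) {
  // Call site mirrors the diff: std::move is required to bind to Callback&&.
  InferenceSketch("hello", std::move(callback));
  // `callback` is moved-from here and must not be used again.
}

int main() {
  HandlerSketch([](const std::string& s) { std::cout << s << "\n"; });
}
```
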
@@ -404,14 +404,14 @@ void llamaCPP::InferenceImpl(
     };
     // Queued task
     state->instance->queue->runTaskInQueue(
-        [callback, state, data, chunked_content_provider, request_id]() {
+        [cb = std::move(callback), state, data, chunked_content_provider, request_id]() {
           state->task_id =
               state->instance->llama.request_completion(data, false, false, -1);
 
           // Start streaming response
           auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
                                                        "chat_completions.txt");
-          callback(resp);
+          cb(resp);
 
           int retries = 0;
 
@@ -434,28 +434,31 @@ void llamaCPP::InferenceImpl(
           LOG_INFO_REQUEST(request_id) << "Inference completed";
         });
   } else {
-    Json::Value respData;
-    int task_id = llama.request_completion(data, false, false, -1);
-    LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
-    if (!json_value(data, "stream", false)) {
-      std::string completion_text;
-      task_result result = llama.next_result(task_id);
-      if (!result.error && result.stop) {
-        int prompt_tokens = result.result_json["tokens_evaluated"];
-        int predicted_tokens = result.result_json["tokens_predicted"];
-        std::string to_send = result.result_json["content"];
-        nitro_utils::ltrim(to_send);
-        respData = create_full_return_json(
-            nitro_utils::generate_random_string(20), "_", to_send, "_",
-            prompt_tokens, predicted_tokens);
-      } else {
-        respData["message"] = "Internal error during inference";
-        LOG_ERROR_REQUEST(request_id) << "Error during inference";
-      }
-      auto resp = nitro_utils::nitroHttpJsonResponse(respData);
-      callback(resp);
-      LOG_INFO_REQUEST(request_id) << "Inference completed";
-    }
+    queue->runTaskInQueue(
+        [this, request_id, cb = std::move(callback), d = std::move(data)]() {
+          Json::Value respData;
+          int task_id = llama.request_completion(d, false, false, -1);
+          LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
+          if (!json_value(d, "stream", false)) {
+            std::string completion_text;
+            task_result result = llama.next_result(task_id);
+            if (!result.error && result.stop) {
+              int prompt_tokens = result.result_json["tokens_evaluated"];
+              int predicted_tokens = result.result_json["tokens_predicted"];
+              std::string to_send = result.result_json["content"];
+              nitro_utils::ltrim(to_send);
+              respData = create_full_return_json(
+                  nitro_utils::generate_random_string(20), "_", to_send, "_",
+                  prompt_tokens, predicted_tokens);
+            } else {
+              respData["message"] = "Internal error during inference";
+              LOG_ERROR_REQUEST(request_id) << "Error during inference";
+            }
+            auto resp = nitro_utils::nitroHttpJsonResponse(respData);
+            cb(resp);
+            LOG_INFO_REQUEST(request_id) << "Inference completed";
+          }
+        });
   }
 }
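
The non-streaming branch above now runs inside a queued task rather than on the calling thread, and the lambda init-captures `cb = std::move(callback)` and `d = std::move(data)` so both outlive the handler frame that created them; capturing the callback by reference instead would risk a dangling reference once the handler returns before the task runs. A rough sketch of the same idea in plain standard C++, with `std::async` standing in for the project's task queue (all names here are illustrative, not from this codebase):

```cpp
// Minimal sketch, standard C++ only: the non-stream work runs on another
// thread, so the callback is init-captured by move to keep it alive after
// the handler that created it has returned.
#include <functional>
#include <future>
#include <iostream>
#include <string>
#include <utility>

using Callback = std::function<void(const std::string&)>;

std::future<void> RunNonStreaming(std::string data, Callback&& callback) {
  // cb = std::move(callback): the lambda owns the callback, so invoking it
  // later is safe even though the caller's std::function no longer holds it.
  return std::async(std::launch::async,
                    [d = std::move(data), cb = std::move(callback)]() {
                      // ... blocking inference work would happen here ...
                      cb("completed: " + d);
                    });
}

int main() {
  auto done = RunNonStreaming(
      "non-stream request",
      [](const std::string& s) { std::cout << s << "\n"; });
  done.get();  // wait so the example exits cleanly
}
```
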

@@ -467,14 +470,14 @@ void llamaCPP::Embedding(
     // Model is loaded
     const auto& jsonBody = req->getJsonObject();
     // Run embedding
-    EmbeddingImpl(jsonBody, callback);
+    EmbeddingImpl(jsonBody, std::move(callback));
     return;
   }
 }
 
 void llamaCPP::EmbeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    std::function<void(const HttpResponsePtr&)>&& callback) {
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for embedding request";
   // Queue embedding task
controllers/llamaCPP.h (6 changes: 3 additions, 3 deletions)
@@ -97,10 +97,10 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
 
   bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
   void InferenceImpl(inferences::ChatCompletionRequest&& completion,
-                     std::function<void(const HttpResponsePtr&)>& callback);
+                     std::function<void(const HttpResponsePtr&)>&& callback);
   void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
-                     std::function<void(const HttpResponsePtr&)>& callback);
-  bool CheckModelLoaded(std::function<void(const HttpResponsePtr&)>& callback);
+                     std::function<void(const HttpResponsePtr&)>&& callback);
+  bool CheckModelLoaded(const std::function<void(const HttpResponsePtr&)>& callback);
   void WarmupModel();
   void BackgroundTask();
   void StopBackgroundTask();
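
In the updated declarations, `CheckModelLoaded` only invokes the callback, so it takes a `const` lvalue reference, while the two `Impl` methods consume the callback and take an rvalue reference; a caller can therefore check first with the intact lvalue and move the callback only afterwards. A small sketch of how the two parameter forms bind (illustrative names only, not this project's API):

```cpp
// Minimal sketch of the two parameter styles: const& for read-only use,
// && for a sink that takes ownership.
#include <functional>
#include <iostream>
#include <utility>

using Callback = std::function<void(int)>;

bool CheckSketch(const Callback& callback) {  // read-only use, no ownership
  callback(-1);
  return true;
}

void ImplSketch(Callback&& callback) {  // sink: takes ownership
  Callback cb = std::move(callback);
  cb(42);
}

int main() {
  Callback callback = [](int code) { std::cout << "code " << code << "\n"; };
  if (CheckSketch(callback)) {        // lvalue binds to const&; callback intact
    ImplSketch(std::move(callback));  // explicit move required to bind to &&
  }
  // ImplSketch(callback);            // would not compile: lvalue cannot bind to &&
}
```
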