3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -78,8 +78,9 @@ endif()
aux_source_directory(controllers CTL_SRC)
aux_source_directory(common COMMON_SRC)
aux_source_directory(context CONTEXT_SRC)
aux_source_directory(models MODEL_SRC)
# aux_source_directory(filters FILTER_SRC) aux_source_directory(plugins
# PLUGIN_SRC) aux_source_directory(models MODEL_SRC)
# PLUGIN_SRC)

# drogon_create_views(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/views
# ${CMAKE_CURRENT_BINARY_DIR}) use the following line to create views with
8 changes: 4 additions & 4 deletions common/base.h
@@ -1,5 +1,6 @@
#pragma once
#include <drogon/HttpController.h>
#include <models/chat_completion_request.h>

using namespace drogon;

@@ -8,9 +9,8 @@ class BaseModel {
virtual ~BaseModel() {}

// Model management
virtual void LoadModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) = 0;
virtual void LoadModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) = 0;
virtual void UnloadModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) = 0;
@@ -25,7 +25,7 @@ class BaseChatCompletion {

// General chat method
virtual void ChatCompletion(
const HttpRequestPtr& req,
inferences::ChatCompletionRequest &&completion,
std::function<void(const HttpResponsePtr&)>&& callback) = 0;
};

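Note on the new parameter type: `inferences::ChatCompletionRequest` comes from `models/chat_completion_request.h`, which this PR adds but which is not shown in this view. The sketch below is a hypothetical reconstruction, not the actual file — the field names and defaults are inferred from how `controllers/llamaCPP.cc` consumes the object and from the defaults the old `jsonBody` parsing applied, and the `fromRequest` specialization illustrates the Drogon mechanism that lets a handler receive the parsed object instead of an `HttpRequestPtr`.

```cpp
// Hypothetical sketch of models/chat_completion_request.h (not shown in this
// diff); the real header may differ. Fields and defaults are inferred from
// llamaCPP.cc and from the previous jsonBody-based parsing.
#pragma once
#include <drogon/HttpRequest.h>
#include <json/value.h>

namespace inferences {
struct ChatCompletionRequest {
  bool stream = false;
  int max_tokens = 500;
  float top_p = 0.95f;
  float temperature = 0.8f;
  float frequency_penalty = 0.0f;
  float presence_penalty = 0.0f;
  Json::Value messages;  // OpenAI-style "messages" array
  Json::Value stop;      // optional array of stop strings
};
}  // namespace inferences

namespace drogon {
// Drogon allows a controller method to take a custom type in place of
// HttpRequestPtr once fromRequest<T>() is specialized for that type.
template <>
inline inferences::ChatCompletionRequest fromRequest(const HttpRequest& req) {
  inferences::ChatCompletionRequest completion;
  const auto json_body = req.getJsonObject();
  if (json_body) {
    completion.stream = (*json_body).get("stream", false).asBool();
    completion.max_tokens = (*json_body).get("max_tokens", 500).asInt();
    completion.top_p = (*json_body).get("top_p", 0.95).asFloat();
    completion.temperature = (*json_body).get("temperature", 0.8).asFloat();
    completion.frequency_penalty =
        (*json_body).get("frequency_penalty", 0).asFloat();
    completion.presence_penalty =
        (*json_body).get("presence_penalty", 0).asFloat();
    completion.messages = (*json_body)["messages"];
    completion.stop = (*json_body)["stop"];
  }
  return completion;
}
}  // namespace drogon
```

Moving the parsing into `fromRequest` keeps `ChatCompletion` and `InferenceImpl` working on a typed, already-defaulted request instead of raw JSON.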
241 changes: 119 additions & 122 deletions controllers/llamaCPP.cc
@@ -1,10 +1,8 @@
#include "llamaCPP.h"


#include <iostream>
#include <fstream>
#include <iostream>
#include "log.h"
#include "utils/nitro_utils.h"

// External
#include "common.h"
@@ -175,19 +173,18 @@ void llamaCPP::WarmupModel() {
}

void llamaCPP::ChatCompletion(
const HttpRequestPtr& req,
inferences::ChatCompletionRequest&& completion,
std::function<void(const HttpResponsePtr&)>&& callback) {
const auto& jsonBody = req->getJsonObject();
// Check if model is loaded
if (CheckModelLoaded(callback)) {
// Model is loaded
// Do Inference
InferenceImpl(jsonBody, callback);
InferenceImpl(std::move(completion), callback);
}
}

void llamaCPP::InferenceImpl(
std::shared_ptr<Json::Value> jsonBody,
inferences::ChatCompletionRequest&& completion,
std::function<void(const HttpResponsePtr&)>& callback) {
std::string formatted_output = pre_prompt;

@@ -196,131 +193,131 @@ void llamaCPP::InferenceImpl(
int no_images = 0;
// To set default value

if (jsonBody) {
// Increase number of chats received and clean the prompt
no_of_chats++;
if (no_of_chats % clean_cache_threshold == 0) {
LOG_INFO << "Clean cache threshold reached!";
llama.kv_cache_clear();
LOG_INFO << "Cache cleaned";
}

// Default values to enable auto caching
data["cache_prompt"] = caching_enabled;
data["n_keep"] = -1;

// Passing load value
data["repeat_last_n"] = this->repeat_last_n;

data["stream"] = (*jsonBody).get("stream", false).asBool();
data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
data["frequency_penalty"] =
(*jsonBody).get("frequency_penalty", 0).asFloat();
data["presence_penalty"] = (*jsonBody).get("presence_penalty", 0).asFloat();
const Json::Value& messages = (*jsonBody)["messages"];
// Increase number of chats received and clean the prompt
no_of_chats++;
if (no_of_chats % clean_cache_threshold == 0) {
LOG_INFO << "Clean cache threshold reached!";
llama.kv_cache_clear();
LOG_INFO << "Cache cleaned";
}

if (!grammar_file_content.empty()) {
data["grammar"] = grammar_file_content;
};
// Default values to enable auto caching
data["cache_prompt"] = caching_enabled;
data["n_keep"] = -1;

// Passing load value
data["repeat_last_n"] = this->repeat_last_n;

LOG_INFO << "Messages:" << completion.messages.toStyledString();
LOG_INFO << "Stop:" << completion.stop.toStyledString();

data["stream"] = completion.stream;
data["n_predict"] = completion.max_tokens;
data["top_p"] = completion.top_p;
data["temperature"] = completion.temperature;
data["frequency_penalty"] = completion.frequency_penalty;
data["presence_penalty"] = completion.presence_penalty;
const Json::Value& messages = completion.messages;

if (!grammar_file_content.empty()) {
data["grammar"] = grammar_file_content;
};

if (!llama.multimodal) {
for (const auto& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
role = user_prompt;
std::string content = message["content"].asString();
formatted_output += role + content;
} else if (input_role == "assistant") {
role = ai_prompt;
std::string content = message["content"].asString();
formatted_output += role + content;
} else if (input_role == "system") {
role = system_prompt;
std::string content = message["content"].asString();
formatted_output = role + content + formatted_output;

if (!llama.multimodal) {
for (const auto& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
role = user_prompt;
std::string content = message["content"].asString();
formatted_output += role + content;
} else if (input_role == "assistant") {
role = ai_prompt;
std::string content = message["content"].asString();
formatted_output += role + content;
} else if (input_role == "system") {
role = system_prompt;
std::string content = message["content"].asString();
formatted_output = role + content + formatted_output;

} else {
role = input_role;
std::string content = message["content"].asString();
formatted_output += role + content;
}
} else {
role = input_role;
std::string content = message["content"].asString();
formatted_output += role + content;
}
formatted_output += ai_prompt;
} else {
data["image_data"] = json::array();
for (const auto& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
formatted_output += role;
for (auto content_piece : message["content"]) {
role = user_prompt;

json content_piece_image_data;
content_piece_image_data["data"] = "";

auto content_piece_type = content_piece["type"].asString();
if (content_piece_type == "text") {
auto text = content_piece["text"].asString();
formatted_output += text;
} else if (content_piece_type == "image_url") {
auto image_url = content_piece["image_url"]["url"].asString();
std::string base64_image_data;
if (image_url.find("http") != std::string::npos) {
LOG_INFO << "Remote image detected but not supported yet";
} else if (image_url.find("data:image") != std::string::npos) {
LOG_INFO << "Base64 image detected";
base64_image_data = nitro_utils::extractBase64(image_url);
LOG_INFO << base64_image_data;
} else {
LOG_INFO << "Local image detected";
nitro_utils::processLocalImage(
image_url, [&](const std::string& base64Image) {
base64_image_data = base64Image;
});
LOG_INFO << base64_image_data;
}
content_piece_image_data["data"] = base64_image_data;

formatted_output += "[img-" + std::to_string(no_images) + "]";
content_piece_image_data["id"] = no_images;
data["image_data"].push_back(content_piece_image_data);
no_images++;
}
formatted_output += ai_prompt;
} else {
data["image_data"] = json::array();
for (const auto& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
formatted_output += role;
for (auto content_piece : message["content"]) {
role = user_prompt;

json content_piece_image_data;
content_piece_image_data["data"] = "";

auto content_piece_type = content_piece["type"].asString();
if (content_piece_type == "text") {
auto text = content_piece["text"].asString();
formatted_output += text;
} else if (content_piece_type == "image_url") {
auto image_url = content_piece["image_url"]["url"].asString();
std::string base64_image_data;
if (image_url.find("http") != std::string::npos) {
LOG_INFO << "Remote image detected but not supported yet";
} else if (image_url.find("data:image") != std::string::npos) {
LOG_INFO << "Base64 image detected";
base64_image_data = nitro_utils::extractBase64(image_url);
LOG_INFO << base64_image_data;
} else {
LOG_INFO << "Local image detected";
nitro_utils::processLocalImage(
image_url, [&](const std::string& base64Image) {
base64_image_data = base64Image;
});
LOG_INFO << base64_image_data;
}
}
content_piece_image_data["data"] = base64_image_data;

} else if (input_role == "assistant") {
role = ai_prompt;
std::string content = message["content"].asString();
formatted_output += role + content;
} else if (input_role == "system") {
role = system_prompt;
std::string content = message["content"].asString();
formatted_output = role + content + formatted_output;

} else {
role = input_role;
std::string content = message["content"].asString();
formatted_output += role + content;
formatted_output += "[img-" + std::to_string(no_images) + "]";
content_piece_image_data["id"] = no_images;
data["image_data"].push_back(content_piece_image_data);
no_images++;
}
}

} else if (input_role == "assistant") {
role = ai_prompt;
std::string content = message["content"].asString();
formatted_output += role + content;
} else if (input_role == "system") {
role = system_prompt;
std::string content = message["content"].asString();
formatted_output = role + content + formatted_output;

} else {
role = input_role;
std::string content = message["content"].asString();
formatted_output += role + content;
}
formatted_output += ai_prompt;
LOG_INFO << formatted_output;
}
formatted_output += ai_prompt;
LOG_INFO << formatted_output;
}

data["prompt"] = formatted_output;
for (const auto& stop_word : (*jsonBody)["stop"]) {
stopWords.push_back(stop_word.asString());
}
// specify default stop words
// Ensure success case for chatML
stopWords.push_back("<|im_end|>");
stopWords.push_back(nitro_utils::rtrim(user_prompt));
data["stop"] = stopWords;
data["prompt"] = formatted_output;
for (const auto& stop_word : completion.stop) {
stopWords.push_back(stop_word.asString());
}
// specify default stop words
// Ensure success case for chatML
stopWords.push_back("<|im_end|>");
stopWords.push_back(nitro_utils::rtrim(user_prompt));
data["stop"] = stopWords;

bool is_streamed = data["stream"];
// Enable full message debugging
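A side note on the prompt assembly shown above: in the text-only branch, `InferenceImpl` concatenates each message's content behind its role template, prepends system content, leaves the assistant turn open, and then appends the default stop words. The toy program below mirrors that flow in isolation; the ChatML-style role strings are assumptions standing in for the real `pre_prompt`/`system_prompt`/`user_prompt`/`ai_prompt` values, which come from the loaded model's configuration.

```cpp
// Standalone illustration of the text-only prompt assembly; role templates are
// assumed ChatML-style and stand in for the configured prompt strings.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  const std::string system_prompt = "<|im_start|>system\n";
  const std::string user_prompt = "<|im_start|>user\n";
  const std::string ai_prompt = "<|im_start|>assistant\n";

  const std::vector<std::pair<std::string, std::string>> messages = {
      {"system", "You are a helpful assistant."},
      {"user", "What is Nitro?"}};

  std::string formatted_output;  // pre_prompt omitted for brevity
  for (const auto& [role, content] : messages) {
    if (role == "user") {
      formatted_output += user_prompt + content;
    } else if (role == "assistant") {
      formatted_output += ai_prompt + content;
    } else if (role == "system") {
      // System content is prepended, mirroring the branch in InferenceImpl.
      formatted_output = system_prompt + content + formatted_output;
    } else {
      formatted_output += role + content;
    }
  }
  // Leave the assistant turn open so generation continues from here.
  formatted_output += ai_prompt;

  // Default stop words appended by InferenceImpl: "<|im_end|>" plus the
  // right-trimmed user prompt (what nitro_utils::rtrim would yield here).
  std::vector<std::string> stop_words = {"<|im_end|>", "<|im_start|>user"};

  std::cout << formatted_output << '\n';
  for (const auto& stop : stop_words) {
    std::cout << "stop: " << stop << '\n';
  }
  return 0;
}
```

In the actual implementation, any `completion.stop` entries supplied by the client are pushed into the list first, ahead of these defaults.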
30 changes: 12 additions & 18 deletions controllers/llamaCPP.h
@@ -2,42 +2,37 @@
#if defined(_WIN32)
#define NOMINMAX
#endif

#pragma once
#define LOG_TARGET stdout

#include <drogon/HttpController.h>

#include "stb_image.h"
#include "context/llama_server_context.h"

#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
#endif

#include <trantor/utils/ConcurrentTaskQueue.h>
#include "common/base.h"
#include "utils/json.hpp"

// auto generated files (update with ./deps.sh)

#include <cstddef>
#include <string>
#include <thread>

#include <cstddef>
#include <thread>
#include "common/base.h"
#include "context/llama_server_context.h"
#include "stb_image.h"
#include "utils/json.hpp"

#include "models/chat_completion_request.h"

#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif


using json = nlohmann::json;

using namespace drogon;

namespace inferences {

class llamaCPP : public drogon::HttpController<llamaCPP>,
public BaseModel,
public BaseChatCompletion,
@@ -64,14 +59,13 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
// PATH_ADD("/llama/chat_completion", Post);
METHOD_LIST_END
void ChatCompletion(
const HttpRequestPtr& req,
inferences::ChatCompletionRequest &&completion,
std::function<void(const HttpResponsePtr&)>&& callback) override;
void Embedding(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) override;
void LoadModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) override;
void LoadModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) override;
void UnloadModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) override;
@@ -101,7 +95,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
trantor::ConcurrentTaskQueue* queue;

bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
void InferenceImpl(std::shared_ptr<Json::Value> jsonBody,
void InferenceImpl(inferences::ChatCompletionRequest&& completion,
std::function<void(const HttpResponsePtr&)>& callback);
void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr&)>& callback);