3 changes: 3 additions & 0 deletions context/llama_server_context.h
@@ -211,6 +211,8 @@ enum stop_type {
STOP_PARTIAL,
};

enum class ModelType { LLM = 0, EMBEDDING };

static bool ends_with(const std::string& str, const std::string& suffix) {
return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
@@ -502,6 +504,7 @@ struct llama_server_context {
std::condition_variable condition_tasks;
std::mutex mutex_results;
std::condition_variable condition_results;
ModelType model_type = ModelType::LLM;

~llama_server_context() {
if (ctx) {
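The header change introduces a per-context model-type flag: the new ModelType enum plus a model_type member on llama_server_context that defaults to LLM. A minimal sketch of how the "model_type" string from a load request maps onto this enum; ParseModelType is a hypothetical helper name, the PR itself inlines the same if/else in llamaCPP::LoadModelImpl:

#include <string>

enum class ModelType { LLM = 0, EMBEDDING };

// Hypothetical helper (not in the diff): anything other than "llm" is treated as an
// embedding-only model, mirroring the branch added in controllers/llamaCPP.cc.
static ModelType ParseModelType(const std::string& model_type) {
  return model_type == "llm" ? ModelType::LLM : ModelType::EMBEDDING;
}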
18 changes: 16 additions & 2 deletions controllers/llamaCPP.cc
@@ -187,6 +187,16 @@ void llamaCPP::ChatCompletion(
void llamaCPP::InferenceImpl(
inferences::ChatCompletionRequest&& completion,
std::function<void(const HttpResponsePtr&)>&& callback) {
if (llama.model_type == ModelType::EMBEDDING) {
LOG_WARN << "Completion is not supported for an embedding model";
Json::Value jsonResp;
jsonResp["message"] =
"Completion is not supported for an embedding model";
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
return;
}
std::string formatted_output = pre_prompt;
int request_id = ++no_of_requests;
LOG_INFO_REQUEST(request_id) << "Generating response for inference request";
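The guard added at the top of InferenceImpl short-circuits completion requests whenever the loaded model is embedding-only: it logs a warning, returns an HTTP 400 with a JSON message, and stops before any prompt formatting happens. A self-contained sketch of the same pattern, assuming the ModelType enum from llama_server_context.h and using Drogon's stock newHttpJsonResponse instead of the project's nitro_utils wrapper; the helper name is an assumption, not part of the diff:

#include <functional>
#include <json/value.h>
#include <drogon/HttpResponse.h>

enum class ModelType { LLM = 0, EMBEDDING };  // mirrors context/llama_server_context.h

// Hypothetical helper: answers the request with HTTP 400 and returns true when the
// model can only produce embeddings, so the caller can bail out early.
static bool RejectCompletionForEmbeddingModel(
    ModelType model_type,
    const std::function<void(const drogon::HttpResponsePtr&)>& callback) {
  if (model_type != ModelType::EMBEDDING) {
    return false;
  }
  Json::Value body;
  body["message"] = "Completion is not supported for an embedding model";
  auto resp = drogon::HttpResponse::newHttpJsonResponse(body);
  resp->setStatusCode(drogon::k400BadRequest);
  callback(resp);
  return true;
}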
@@ -653,6 +663,11 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
params.embedding = jsonBody->get("embedding", true).asBool();
model_type = jsonBody->get("model_type", "llm").asString();
if (model_type == "llm") {
llama.model_type = ModelType::LLM;
} else {
llama.model_type = ModelType::EMBEDDING;
}
// Check if n_parallel exists in jsonBody, if not, set to drogon_thread
params.n_batch = jsonBody->get("n_batch", 512).asInt();
params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
@@ -712,8 +727,7 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {

// For models like nomic-embed-text-v1.5.f16.gguf, etc., we don't need to warm up the model,
// so we use this variable to differentiate them from other models.
// TODO: in case embedded model only, we should reject completion request from user?
if (model_type == "llm") {
if (llama.model_type == ModelType::LLM) {
WarmupModel();
}
return true;
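Taken together, the LoadModelImpl changes make the model type entirely request-driven: the "model_type" field of the load body selects the enum value, and embedding-only models skip WarmupModel(). A sketch of the kind of load body that flips the server into embedding mode; the llama_model_path field name is an assumption (it is not shown in this diff), while the other fields and defaults come from the jsonBody->get calls above:

#include <json/value.h>

// Hypothetical load-model body for an embedding-only model.
Json::Value MakeEmbeddingLoadBody() {
  Json::Value jsonBody;
  jsonBody["llama_model_path"] =          // field name assumed, not shown in this diff
      "/models/nomic-embed-text-v1.5.f16.gguf";
  jsonBody["model_type"] = "embedding";   // anything other than "llm" => ModelType::EMBEDDING
  jsonBody["embedding"] = true;           // matches the default read by jsonBody->get
  jsonBody["ctx_len"] = 2048;             // matches the default read by jsonBody->get
  return jsonBody;
}

// With such a body, LoadModelImpl sets llama.model_type = ModelType::EMBEDDING, the
// warmup step is skipped, and later completion requests are rejected with HTTP 400.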
10 changes: 6 additions & 4 deletions controllers/llamaCPP.h
@@ -59,13 +59,14 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
// PATH_ADD("/llama/chat_completion", Post);
METHOD_LIST_END
void ChatCompletion(
inferences::ChatCompletionRequest &&completion,
inferences::ChatCompletionRequest&& completion,
std::function<void(const HttpResponsePtr&)>&& callback) override;
void Embedding(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) override;
void LoadModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) override;
void LoadModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) override;
void UnloadModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) override;
@@ -100,7 +101,8 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
std::function<void(const HttpResponsePtr&)>&& callback);
void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr&)>&& callback);
bool CheckModelLoaded(const std::function<void(const HttpResponsePtr&)>& callback);
bool CheckModelLoaded(
const std::function<void(const HttpResponsePtr&)>& callback);
void WarmupModel();
void BackgroundTask();
void StopBackgroundTask();