22 changes: 18 additions & 4 deletions engine/controllers/server.cc
@@ -3,6 +3,7 @@
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"
#include "utils/function_calling/common.h"
#include "utils/http_util.h"

using namespace inferences;

@@ -27,6 +28,15 @@ void server::ChatCompletion(
LOG_DEBUG << "Start chat completion";
auto json_body = req->getJsonObject();
bool is_stream = (*json_body).get("stream", false).asBool();
auto model_id = (*json_body).get("model", "invalid_model").asString();
auto engine_type = [this, &json_body]() -> std::string {
if (!inference_svc_->HasFieldInReq(json_body, "engine")) {
return kLlamaRepo;
} else {
return (*(json_body)).get("engine", kLlamaRepo).asString();
}
}();

LOG_DEBUG << "request body: " << json_body->toStyledString();
auto q = std::make_shared<services::SyncQueue>();
auto ir = inference_svc_->HandleChatCompletion(q, json_body);
@@ -40,7 +50,7 @@
}
LOG_DEBUG << "Wait to chat completion responses";
if (is_stream) {
ProcessStreamRes(std::move(callback), q);
ProcessStreamRes(std::move(callback), q, engine_type, model_id);
} else {
ProcessNonStreamRes(std::move(callback), *q);
}
@@ -121,12 +131,16 @@ void server::LoadModel(const HttpRequestPtr& req,
}

void server::ProcessStreamRes(std::function<void(const HttpResponsePtr&)> cb,
std::shared_ptr<services::SyncQueue> q) {
std::shared_ptr<services::SyncQueue> q,
const std::string& engine_type,
const std::string& model_id) {
auto err_or_done = std::make_shared<std::atomic_bool>(false);
auto chunked_content_provider =
[q, err_or_done](char* buf, std::size_t buf_size) -> std::size_t {
auto chunked_content_provider = [this, q, err_or_done, engine_type, model_id](
char* buf,
std::size_t buf_size) -> std::size_t {
if (buf == nullptr) {
LOG_TRACE << "Buf is null";
inference_svc_->StopInferencing(engine_type, model_id);
return 0;
}

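For context on why the null-buffer branch is where the engine gets signalled: the stream callback used here is invoked with a null buffer when the response is being torn down (presumably because the client disconnected), and returning 0 ends the stream. Below is a self-contained sketch of that callback contract; the `StopHook` alias and the driver in `main` are illustrative only and not part of this PR.

```cpp
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstring>
#include <functional>
#include <iostream>
#include <string>

// Illustrative stand-in for inference_svc_->StopInferencing(engine_type, model_id).
using StopHook = std::function<void()>;

// Mirrors the provider signature in server.cc: copy a chunk into `buf`,
// or, when `buf` is null, signal the engine and end the stream.
std::size_t StreamChunk(char* buf, std::size_t buf_size,
                        const std::string& chunk, const StopHook& stop) {
  if (buf == nullptr) {
    stop();     // the PR calls StopInferencing(engine_type, model_id) here
    return 0;   // returning 0 tells the framework the stream is finished
  }
  const auto n = std::min(buf_size, chunk.size());
  std::memcpy(buf, chunk.data(), n);
  return n;
}

int main() {
  std::atomic_bool stopped{false};
  char buf[64];
  // Normal delivery of one SSE-style chunk.
  StreamChunk(buf, sizeof(buf), "data: {\"choices\":[]}\n\n",
              [&] { stopped = true; });
  // Tear-down: the framework passes a null buffer.
  StreamChunk(nullptr, 0, "", [&] { stopped = true; });
  std::cout << "stopped=" << stopped.load() << "\n";  // prints stopped=1
}
```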
4 changes: 3 additions & 1 deletion engine/controllers/server.h
@@ -72,7 +72,9 @@ class server : public drogon::HttpController<server, false>,

private:
void ProcessStreamRes(std::function<void(const HttpResponsePtr&)> cb,
std::shared_ptr<services::SyncQueue> q);
std::shared_ptr<services::SyncQueue> q,
const std::string& engine_type,
const std::string& model_id);
void ProcessNonStreamRes(std::function<void(const HttpResponsePtr&)> cb,
services::SyncQueue& q);

3 changes: 2 additions & 1 deletion engine/cortex-common/EngineI.h
@@ -68,5 +68,6 @@ class EngineI {
const std::string& log_path) = 0;
virtual void SetLogLevel(trantor::Logger::LogLevel logLevel) = 0;

virtual Json::Value GetRemoteModels() = 0;
// Stop in-flight chat completion in stream mode
virtual void StopInferencing(const std::string& model_id) = 0;
};
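The new pure-virtual method leaves the actual cancellation mechanics to each engine. As a rough sketch only — the llama-cpp engine's real implementation is not part of this diff, and every name below is hypothetical — an engine could keep a per-model stop flag that its token-generation loop polls between tokens:

```cpp
#include <mutex>
#include <string>
#include <unordered_set>

// Hypothetical per-model stop registry an engine might own.
class StopFlags {
 public:
  // Called from StopInferencing(model_id) on the HTTP thread.
  void RequestStop(const std::string& model_id) {
    std::lock_guard<std::mutex> lk(mtx_);
    stopping_.insert(model_id);
  }

  // Polled by the generation loop between tokens; consumes the request.
  bool ShouldStop(const std::string& model_id) {
    std::lock_guard<std::mutex> lk(mtx_);
    return stopping_.erase(model_id) > 0;
  }

 private:
  std::mutex mtx_;
  std::unordered_set<std::string> stopping_;
};
```

A generation loop that checks `ShouldStop(model_id)` between tokens can then finish the stream early once the server side requests a stop.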
67 changes: 34 additions & 33 deletions engine/e2e-test/test_api_docker.py
@@ -40,38 +40,39 @@ async def test_models_on_cortexso_hub(self, model_url):
assert response.status_code == 200
models = [i["id"] for i in response.json()["data"]]
assert model_url in models, f"Model not found in list: {model_url}"

# TODO(sang) bypass for now. Re-enable when we publish a new stable version of the llama-cpp engine
# print("Start the model")
# # Start the model
# response = requests.post(
# "http://localhost:3928/v1/models/start", json=json_body
# )
# print(response.json())
# assert response.status_code == 200, f"status_code: {response.status_code}"

print("Start the model")
# Start the model
response = requests.post(
"http://localhost:3928/v1/models/start", json=json_body
)
print(response.json())
assert response.status_code == 200, f"status_code: {response.status_code}"

print("Send an inference request")
# Send an inference request
inference_json_body = {
"frequency_penalty": 0.2,
"max_tokens": 4096,
"messages": [{"content": "", "role": "user"}],
"model": model_url,
"presence_penalty": 0.6,
"stop": ["End"],
"stream": False,
"temperature": 0.8,
"top_p": 0.95,
}
response = requests.post(
"http://localhost:3928/v1/chat/completions",
json=inference_json_body,
headers={"Content-Type": "application/json"},
)
assert (
response.status_code == 200
), f"status_code: {response.status_code} response: {response.json()}"
# print("Send an inference request")
# # Send an inference request
# inference_json_body = {
# "frequency_penalty": 0.2,
# "max_tokens": 4096,
# "messages": [{"content": "", "role": "user"}],
# "model": model_url,
# "presence_penalty": 0.6,
# "stop": ["End"],
# "stream": False,
# "temperature": 0.8,
# "top_p": 0.95,
# }
# response = requests.post(
# "http://localhost:3928/v1/chat/completions",
# json=inference_json_body,
# headers={"Content-Type": "application/json"},
# )
# assert (
# response.status_code == 200
# ), f"status_code: {response.status_code} response: {response.json()}"

print("Stop the model")
# Stop the model
response = requests.post("http://localhost:3928/v1/models/stop", json=json_body)
assert response.status_code == 200, f"status_code: {response.status_code}"
# print("Stop the model")
# # Stop the model
# response = requests.post("http://localhost:3928/v1/models/stop", json=json_body)
# assert response.status_code == 200, f"status_code: {response.status_code}"
119 changes: 63 additions & 56 deletions engine/services/inference_service.cc
@@ -24,24 +24,18 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
return cpp::fail(std::make_pair(stt, res));
}

auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
if (!tool_choice.isNull()) {
res["tool_choice"] = tool_choice;
}
q->push(std::make_pair(status, res));
};
if (std::holds_alternative<EngineI*>(engine_result.value())) {
std::get<EngineI*>(engine_result.value())
->HandleChatCompletion(
json_body, [q, tool_choice](Json::Value status, Json::Value res) {
if (!tool_choice.isNull()) {
res["tool_choice"] = tool_choice;
}
q->push(std::make_pair(status, res));
});
->HandleChatCompletion(json_body, std::move(cb));
} else {
std::get<RemoteEngineI*>(engine_result.value())
->HandleChatCompletion(
json_body, [q, tool_choice](Json::Value status, Json::Value res) {
if (!tool_choice.isNull()) {
res["tool_choice"] = tool_choice;
}
q->push(std::make_pair(status, res));
});
->HandleChatCompletion(json_body, std::move(cb));
}

return {};
@@ -66,16 +60,15 @@ cpp::result<void, InferResult> InferenceService::HandleEmbedding(
return cpp::fail(std::make_pair(stt, res));
}

auto cb = [q](Json::Value status, Json::Value res) {
q->push(std::make_pair(status, res));
};
if (std::holds_alternative<EngineI*>(engine_result.value())) {
std::get<EngineI*>(engine_result.value())
->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
q->push(std::make_pair(status, res));
});
->HandleEmbedding(json_body, std::move(cb));
} else {
std::get<RemoteEngineI*>(engine_result.value())
->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
q->push(std::make_pair(status, res));
});
->HandleEmbedding(json_body, std::move(cb));
}
return {};
}
@@ -104,18 +97,16 @@ InferResult InferenceService::LoadModel(
// might need mutex here
auto engine_result = engine_service_->GetLoadedEngine(engine_type);

auto cb = [&stt, &r](Json::Value status, Json::Value res) {
stt = status;
r = res;
};
if (std::holds_alternative<EngineI*>(engine_result.value())) {
std::get<EngineI*>(engine_result.value())
->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
stt = status;
r = res;
});
->LoadModel(json_body, std::move(cb));
} else {
std::get<RemoteEngineI*>(engine_result.value())
->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
stt = status;
r = res;
});
->LoadModel(json_body, std::move(cb));
}
return std::make_pair(stt, r);
}
@@ -139,20 +130,16 @@ InferResult InferenceService::UnloadModel(const std::string& engine_name,
json_body["model"] = model_id;

LOG_TRACE << "Start unload model";
auto cb = [&r, &stt](Json::Value status, Json::Value res) {
stt = status;
r = res;
};
if (std::holds_alternative<EngineI*>(engine_result.value())) {
std::get<EngineI*>(engine_result.value())
->UnloadModel(std::make_shared<Json::Value>(json_body),
[&r, &stt](Json::Value status, Json::Value res) {
stt = status;
r = res;
});
->UnloadModel(std::make_shared<Json::Value>(json_body), std::move(cb));
} else {
std::get<RemoteEngineI*>(engine_result.value())
->UnloadModel(std::make_shared<Json::Value>(json_body),
[&r, &stt](Json::Value status, Json::Value res) {
stt = status;
r = res;
});
->UnloadModel(std::make_shared<Json::Value>(json_body), std::move(cb));
}

return std::make_pair(stt, r);
@@ -181,20 +168,16 @@ InferResult InferenceService::GetModelStatus(

LOG_TRACE << "Start to get model status";

auto cb = [&stt, &r](Json::Value status, Json::Value res) {
stt = status;
r = res;
};
if (std::holds_alternative<EngineI*>(engine_result.value())) {
std::get<EngineI*>(engine_result.value())
->GetModelStatus(json_body,
[&stt, &r](Json::Value status, Json::Value res) {
stt = status;
r = res;
});
->GetModelStatus(json_body, std::move(cb));
} else {
std::get<RemoteEngineI*>(engine_result.value())
->GetModelStatus(json_body,
[&stt, &r](Json::Value status, Json::Value res) {
stt = status;
r = res;
});
->GetModelStatus(json_body, std::move(cb));
}

return std::make_pair(stt, r);
@@ -214,15 +197,20 @@ InferResult InferenceService::GetModels(

LOG_TRACE << "Start to get models";
Json::Value resp_data(Json::arrayValue);
auto cb = [&resp_data](Json::Value status, Json::Value res) {
for (auto r : res["data"]) {
resp_data.append(r);
}
};
for (const auto& loaded_engine : loaded_engines) {
auto e = std::get<EngineI*>(loaded_engine);
if (e->IsSupported("GetModels")) {
e->GetModels(json_body,
[&resp_data](Json::Value status, Json::Value res) {
for (auto r : res["data"]) {
resp_data.append(r);
}
});
if (std::holds_alternative<EngineI*>(loaded_engine)) {
auto e = std::get<EngineI*>(loaded_engine);
if (e->IsSupported("GetModels")) {
e->GetModels(json_body, std::move(cb));
}
} else {
std::get<RemoteEngineI*>(loaded_engine)
->GetModels(json_body, std::move(cb));
}
}

@@ -283,6 +271,25 @@ InferResult InferenceService::FineTuning(
return std::make_pair(stt, r);
}

bool InferenceService::StopInferencing(const std::string& engine_name,
const std::string& model_id) {
CTL_DBG("Stop inferencing");
auto engine_result = engine_service_->GetLoadedEngine(engine_name);
if (engine_result.has_error()) {
LOG_WARN << "Engine is not loaded yet";
return false;
}

if (std::holds_alternative<EngineI*>(engine_result.value())) {
auto engine = std::get<EngineI*>(engine_result.value());
if (engine->IsSupported("StopInferencing")) {
engine->StopInferencing(model_id);
CTL_INF("Stopped inferencing");
}
}
return true;
}

bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
const std::string& field) {
if (!json_body || (*json_body)[field].isNull()) {
5 changes: 4 additions & 1 deletion engine/services/inference_service.h
@@ -52,10 +52,13 @@ class InferenceService {

InferResult FineTuning(std::shared_ptr<Json::Value> json_body);

private:
bool StopInferencing(const std::string& engine_name,
const std::string& model_id);

bool HasFieldInReq(std::shared_ptr<Json::Value> json_body,
const std::string& field);

private:
std::shared_ptr<EngineService> engine_service_;
};
} // namespace services