Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/static/openapi/jan.json
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,15 @@
"value": "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf"
}
]
},
"id": {
"type": "string",
"description": "The id which will be used to register the model.",
"examples": [
{
"value": "my-custom-model-id"
}
]
}
}
},
Expand Down
7 changes: 4 additions & 3 deletions engine/cli/commands/engine_get_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include "engine_get_cmd.h"
#include <json/reader.h>
#include <json/value.h>
#include <iostream>

#include "httplib.h"
#include "json/json.h"
#include "server_start_cmd.h"
#include "utils/logging_utils.h"

Expand All @@ -29,7 +30,6 @@ void EngineGetCmd::Exec(const std::string& host, int port,
auto res = cli.Get("/v1/engines/" + engine_name);
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
// CLI_LOG(res->body);
Json::Value v;
Json::Reader reader;
reader.parse(res->body, v);
Expand All @@ -39,7 +39,8 @@ void EngineGetCmd::Exec(const std::string& host, int port,
v["status"].asString()});

} else {
CLI_LOG_ERROR("Failed to get engine list with status code: " << res->status);
CLI_LOG_ERROR(
"Failed to get engine list with status code: " << res->status);
return;
}
} else {
Expand Down
5 changes: 3 additions & 2 deletions engine/cli/commands/model_list_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ using namespace tabulate;
using Row_t =
std::vector<variant<std::string, const char*, string_view, Table>>;

void ModelListCmd::Exec(const std::string& host, int port, std::string filter,
bool display_engine, bool display_version) {
void ModelListCmd::Exec(const std::string& host, int port,
const std::string& filter, bool display_engine,
bool display_version) {
// Start server if server is not started yet
if (!commands::IsServerAlive(host, port)) {
CLI_LOG("Starting server ...");
Expand Down
2 changes: 1 addition & 1 deletion engine/cli/commands/model_list_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace commands {

class ModelListCmd {
public:
void Exec(const std::string& host, int port, std::string filter,
void Exec(const std::string& host, int port, const std::string& filter,
bool display_engine = false, bool display_version = false);
};
} // namespace commands
6 changes: 6 additions & 0 deletions engine/common/download_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
#include <string>

enum class DownloadType { Model, Engine, Miscellaneous, CudaToolkit, Cortex };

using namespace nlohmann;

struct DownloadItem {

std::string id;

std::string downloadUrl;
Expand Down Expand Up @@ -54,8 +56,12 @@ inline std::string DownloadTypeToString(DownloadType type) {
}

struct DownloadTask {
enum class Status { Pending, InProgress, Completed, Cancelled, Error };

std::string id;

Status status;

DownloadType type;

std::vector<DownloadItem> items;
Expand Down
77 changes: 77 additions & 0 deletions engine/common/download_task_queue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include <condition_variable>
#include <deque>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include "common/download_task.h"

class DownloadTaskQueue {
private:
std::deque<DownloadTask> taskQueue;
std::unordered_map<std::string, typename std::deque<DownloadTask>::iterator>
taskMap;
mutable std::shared_mutex mutex;
std::condition_variable_any cv;

public:
void push(DownloadTask task) {
std::unique_lock lock(mutex);
taskQueue.push_back(std::move(task));
taskMap[taskQueue.back().id] = std::prev(taskQueue.end());
cv.notify_one();
}

std::optional<DownloadTask> pop() {
std::unique_lock lock(mutex);
if (taskQueue.empty()) {
return std::nullopt;
}
DownloadTask task = std::move(taskQueue.front());
taskQueue.pop_front();
taskMap.erase(task.id);
return task;
}

bool cancelTask(const std::string& taskId) {
std::unique_lock lock(mutex);
auto it = taskMap.find(taskId);
if (it != taskMap.end()) {
it->second->status = DownloadTask::Status::Cancelled;
taskQueue.erase(it->second);
taskMap.erase(it);
return true;
}
return false;
}

bool updateTaskStatus(const std::string& taskId,
DownloadTask::Status newStatus) {
std::unique_lock lock(mutex);
auto it = taskMap.find(taskId);
if (it != taskMap.end()) {
it->second->status = newStatus;
if (newStatus == DownloadTask::Status::Cancelled ||
newStatus == DownloadTask::Status::Error) {
taskQueue.erase(it->second);
taskMap.erase(it);
}
return true;
}
return false;
}

std::optional<DownloadTask> getNextPendingTask() {
std::shared_lock lock(mutex);
auto it = std::find_if(
taskQueue.begin(), taskQueue.end(), [](const DownloadTask& task) {
return task.status == DownloadTask::Status::Pending;
});

if (it != taskQueue.end()) {
return *it;
}
return std::nullopt;
}
};
108 changes: 38 additions & 70 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "database/models.h"
#include <drogon/HttpTypes.h>
#include <optional>
#include "config/gguf_parser.h"
#include "config/yaml_config.h"
#include "models.h"
Expand All @@ -26,15 +27,22 @@ void Models::PullModel(const HttpRequestPtr& req,
return;
}

std::optional<std::string> desired_model_id = std::nullopt;
auto id = (*(req->getJsonObject())).get("id", "").asString();
if (!id.empty()) {
desired_model_id = id;
}

auto handle_model_input =
[&, model_handle]() -> cpp::result<DownloadTask, std::string> {
CTL_INF("Handle model input, model handle: " + model_handle);
if (string_utils::StartsWith(model_handle, "https")) {
return model_service_->HandleDownloadUrlAsync(model_handle);
return model_service_->HandleDownloadUrlAsync(model_handle,
desired_model_id);
} else if (model_handle.find(":") != std::string::npos) {
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
return model_service_->DownloadModelFromCortexsoAsync(
model_and_branch[0], model_and_branch[1]);
model_and_branch[0], model_and_branch[1], desired_model_id);
}

return cpp::fail("Invalid model handle or not supported!");
Expand Down Expand Up @@ -107,7 +115,6 @@ void Models::ListModel(
auto list_entry = modellist_handler.LoadModelList();
if (list_entry) {
for (const auto& model_entry : list_entry.value()) {
// auto model_entry = modellist_handler.GetModelInfo(model_handle);
try {
yaml_handler.ModelConfigFromFile(
fmu::ToAbsoluteCortexDataPath(
Expand All @@ -116,7 +123,6 @@ void Models::ListModel(
auto model_config = yaml_handler.GetModelConfig();
Json::Value obj = model_config.ToJson();
obj["id"] = model_entry.model;
obj["model_alias"] = model_entry.model_alias;
obj["model"] = model_entry.model;
data.append(std::move(obj));
yaml_handler.Reset();
Expand Down Expand Up @@ -156,7 +162,6 @@ void Models::GetModel(const HttpRequestPtr& req,
config::YamlHandler yaml_handler;
auto model_entry = modellist_handler.GetModelInfo(model_id);
if (model_entry.has_error()) {
// CLI_LOG("Error: " + model_entry.error());
ret["id"] = model_id;
ret["object"] = "model";
ret["result"] = "Fail to get model information";
Expand Down Expand Up @@ -333,71 +338,6 @@ void Models::ImportModel(
}
}

void Models::SetModelAlias(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const {
if (!http_util::HasFieldInReq(req, callback, "model") ||
!http_util::HasFieldInReq(req, callback, "modelAlias")) {
return;
}
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
auto model_alias = (*(req->getJsonObject())).get("modelAlias", "").asString();
LOG_DEBUG << "GetModel, Model handle: " << model_handle
<< ", Model alias: " << model_alias;

cortex::db::Models modellist_handler;
try {
auto result = modellist_handler.UpdateModelAlias(model_handle, model_alias);
if (result.has_error()) {
std::string message = result.error();
LOG_ERROR << message;
Json::Value ret;
ret["result"] = "Set alias failed!";
ret["modelHandle"] = model_handle;
ret["message"] = message;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
} else {
if (result.value()) {
std::string message = "Successfully set model alias '" + model_alias +
"' for modeID '" + model_handle + "'.";
LOG_INFO << message;
Json::Value ret;
ret["result"] = "OK";
ret["modelHandle"] = model_handle;
ret["message"] = message;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k200OK);
callback(resp);
} else {
std::string message = "Unable to set model alias for modelID '" +
model_handle + "': model alias '" + model_alias +
"' is not unique!";
LOG_ERROR << message;
Json::Value ret;
ret["result"] = "Set alias failed!";
ret["modelHandle"] = model_handle;
ret["message"] = message;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
}
}
} catch (const std::exception& e) {
std::string message = "Error when setting model alias ('" + model_alias +
"') for modelID '" + model_handle + "':" + e.what();
LOG_ERROR << message;
Json::Value ret;
ret["result"] = "Set alias failed!";
ret["modelHandle"] = model_handle;
ret["message"] = message;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
}
}

void Models::StartModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
Expand All @@ -407,6 +347,34 @@ void Models::StartModel(
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
auto custom_prompt_template =
(*(req->getJsonObject())).get("prompt_template", "").asString();
auto model_entry = model_service_->GetDownloadedModel(model_handle);
if (!model_entry.has_value()) {
Json::Value ret;
ret["message"] = "Cannot find model: " + model_handle;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
return;
}
auto engine_name = model_entry.value().engine;
auto engine_entry = engine_service_->GetEngineInfo(engine_name);
if (engine_entry.has_error()) {
Json::Value ret;
ret["message"] = "Cannot find engine: " + engine_name;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
return;
}
if (engine_entry->status != "Ready") {
Json::Value ret;
ret["message"] = "Engine is not ready! Please install first!";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
return;
}

auto result = model_service_->StartModel(
config.apiServerHost, std::stoi(config.apiServerPort), model_handle,
custom_prompt_template);
Expand Down
9 changes: 5 additions & 4 deletions engine/controllers/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <drogon/HttpController.h>
#include <trantor/utils/Logger.h>
#include "services/engine_service.h"
#include "services/model_service.h"

using namespace drogon;
Expand All @@ -16,7 +17,6 @@ class Models : public drogon::HttpController<Models, false> {
METHOD_ADD(Models::UpdateModel, "/{1}", Patch);
METHOD_ADD(Models::ImportModel, "/import", Post);
METHOD_ADD(Models::DeleteModel, "/{1}", Delete);
METHOD_ADD(Models::SetModelAlias, "/alias", Post);
METHOD_ADD(Models::StartModel, "/start", Post);
METHOD_ADD(Models::StopModel, "/stop", Post);
METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get);
Expand All @@ -28,14 +28,14 @@ class Models : public drogon::HttpController<Models, false> {
ADD_METHOD_TO(Models::UpdateModel, "/v1/models/{1}", Patch);
ADD_METHOD_TO(Models::ImportModel, "/v1/models/import", Post);
ADD_METHOD_TO(Models::DeleteModel, "/v1/models/{1}", Delete);
ADD_METHOD_TO(Models::SetModelAlias, "/v1/models/alias", Post);
ADD_METHOD_TO(Models::StartModel, "/v1/models/start", Post);
ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Post);
ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get);
METHOD_LIST_END

explicit Models(std::shared_ptr<ModelService> model_service)
: model_service_{model_service} {}
explicit Models(std::shared_ptr<ModelService> model_service,
std::shared_ptr<EngineService> engine_service)
: model_service_{model_service}, engine_service_{engine_service} {}

void PullModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback);
Expand Down Expand Up @@ -71,4 +71,5 @@ class Models : public drogon::HttpController<Models, false> {

private:
std::shared_ptr<ModelService> model_service_;
std::shared_ptr<EngineService> engine_service_;
};
2 changes: 1 addition & 1 deletion engine/database/models.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include <SQLiteCpp/Database.h>
#include <trantor/utils/Logger.h>
#include <string>
#include <vector>
#include "SQLiteCpp/SQLiteCpp.h"
#include "utils/result.hpp"

namespace cortex::db {
Expand Down
1 change: 0 additions & 1 deletion engine/e2e-test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from test_api_model_start import TestApiModelStart
from test_api_model_stop import TestApiModelStop
from test_api_model_get import TestApiModelGet
from test_api_model_alias import TestApiModelAlias
from test_api_model_list import TestApiModelList
from test_api_model_update import TestApiModelUpdate
from test_api_model_delete import TestApiModelDelete
Expand Down
Loading
Loading