86 changes: 86 additions & 0 deletions docs/static/openapi/cortex.json
@@ -554,6 +554,46 @@
"tags": ["Models"]
}
},
"/v1/models/import": {
"post": {
"operationId": "ModelsController_importModel",
"summary": "Import model",
"description": "Imports a model from a specified path.",
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ImportModelRequest"
},
"example": {
"model": "model-id",
"modelPath": "/path/to/gguf",
"name": "model display name"
}
}
}
},
"responses": {
"200": {
"description": "Model is imported successfully!",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ImportModelResponse"
},
"example": {
"message": "Model is imported successfully!",
"modelHandle": "model-id",
"result": "OK"
}
}
}
}
},
"tags": ["Models"]
}
},
"/v1/threads": {
"post": {
"operationId": "ThreadsController_create",
@@ -1660,6 +1700,15 @@
"value": "my-custom-model-id"
}
]
},
"name": {
"type": "string",
"description": "The name which will be used to overwrite the model name.",
"examples": [
{
"value": "my-custom-model-name"
}
]
}
}
},
@@ -1803,6 +1852,43 @@
}
}
},
"ImportModelRequest": {
"type": "object",
"properties": {
"model": {
"type": "string",
"description": "The unique identifier of the model."
},
"modelPath": {
"type": "string",
"description": "The file path to the model."
},
"name": {
"type": "string",
"description": "The display name of the model."
}
},
"required": ["model", "modelPath"]
},
"ImportModelResponse": {
"type": "object",
"properties": {
"message": {
"type": "string",
"description": "Success message."
},
"modelHandle": {
"type": "string",
"description": "The unique identifier of the imported model."
},
"result": {
"type": "string",
"description": "Result status.",
"example": "OK"
}
},
"required": ["message", "modelHandle", "result"]
},
"CommonResponseDto": {
"type": "object",
"properties": {
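For reference, a minimal sketch of how a client might exercise the new import endpoint described by this spec. The base URL (localhost:3928, as used in the e2e tests below) and the GGUF path are illustrative assumptions, not values from this PR; note also that the spec registers the route as /v1/models/import while the e2e tests call /models/import without the prefix.

```python
import requests

# Illustrative values only: adjust the base URL and GGUF path for your setup.
BASE_URL = "http://localhost:3928"

payload = {
    "model": "tinyllama:gguf",                 # required: unique model id
    "modelPath": "/path/to/local/model.gguf",  # required: local GGUF file path
    "name": "My TinyLlama",                    # optional: display-name override
}

response = requests.post(f"{BASE_URL}/v1/models/import", json=payload)
response.raise_for_status()

# Expected body per ImportModelResponse: message, modelHandle, result.
body = response.json()
print(body["message"], body["modelHandle"], body["result"])
```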
13 changes: 11 additions & 2 deletions engine/controllers/models.cc
@@ -33,12 +33,19 @@ void Models::PullModel(const HttpRequestPtr& req,
desired_model_id = id;
}

std::optional<std::string> desired_model_name = std::nullopt;
auto name_value = (*(req->getJsonObject())).get("name", "").asString();

if (!name_value.empty()) {
desired_model_name = name_value;
}

auto handle_model_input =
[&, model_handle]() -> cpp::result<DownloadTask, std::string> {
CTL_INF("Handle model input, model handle: " + model_handle);
if (string_utils::StartsWith(model_handle, "https")) {
-return model_service_->HandleDownloadUrlAsync(model_handle,
-desired_model_id);
+return model_service_->HandleDownloadUrlAsync(
+model_handle, desired_model_id, desired_model_name);
} else if (model_handle.find(":") != std::string::npos) {
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
return model_service_->DownloadModelFromCortexsoAsync(
@@ -312,6 +319,7 @@ void Models::ImportModel(
}
auto modelHandle = (*(req->getJsonObject())).get("model", "").asString();
auto modelPath = (*(req->getJsonObject())).get("modelPath", "").asString();
auto modelName = (*(req->getJsonObject())).get("name", "").asString();
config::GGUFHandler gguf_handler;
config::YamlHandler yaml_handler;
cortex::db::Models modellist_utils_obj;
@@ -333,6 +341,7 @@
config::ModelConfig model_config = gguf_handler.GetModelConfig();
model_config.files.push_back(modelPath);
model_config.model = modelHandle;
model_config.name = modelName.empty() ? model_config.name : modelName;
yaml_handler.UpdateModelConfig(model_config);

if (modellist_utils_obj.AddModelEntry(model_entry).value()) {
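The controller change above reads an optional "name" field from the pull request body (an empty string is treated as absent) and threads it through to HandleDownloadUrlAsync; ImportModel applies the same fallback, keeping the GGUF-derived name when "name" is empty. A hedged sketch of a pull request using the new field, assuming a local server on port 3928 as in the e2e tests:

```python
import requests

# Sketch of a direct-URL pull with a custom display name; mirrors the
# e2e test added in this PR. The URL and name are example values.
payload = {
    "model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF"
             "/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf",
    "name": "smol_llama_100m",  # optional; if omitted, the GGUF metadata name is kept
}

response = requests.post("http://localhost:3928/models/pull", json=payload)
assert response.status_code == 200
```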
24 changes: 22 additions & 2 deletions engine/e2e-test/test_api_model_import.py
@@ -18,5 +18,25 @@ def setup_and_teardown(self):
def test_model_import_should_be_success(self):
body_json = {'model': 'tinyllama:gguf',
'modelPath': '/path/to/local/gguf'}
response = requests.post("http://localhost:3928/models/import", json = body_json)
assert response.status_code == 200
response = requests.post("http://localhost:3928/models/import", json=body_json)
assert response.status_code == 200

@pytest.mark.skipif(True, reason="Expensive test. Only run when a local gguf file is available.")
def test_model_import_with_name_should_be_success(self):
body_json = {'model': 'tinyllama:gguf',
'modelPath': '/path/to/local/gguf',
'name': 'test_model'}
response = requests.post("http://localhost:3928/models/import", json=body_json)
assert response.status_code == 200

def test_model_import_with_invalid_path_should_fail(self):
body_json = {'model': 'tinyllama:gguf',
'modelPath': '/invalid/path/to/gguf'}
response = requests.post("http://localhost:3928/models/import", json=body_json)
assert response.status_code == 400

def test_model_import_with_missing_model_should_fail(self):
body_json = {'modelPath': '/path/to/local/gguf'}
response = requests.post("http://localhost:3928/models/import", json=body_json)
print(response)
assert response.status_code == 409
29 changes: 24 additions & 5 deletions engine/e2e-test/test_api_model_pull_direct_url.py
@@ -21,7 +21,7 @@ def setup_and_teardown(self):
[
"models",
"delete",
"TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
],
)
yield
@@ -32,24 +32,43 @@ def setup_and_teardown(self):
[
"models",
"delete",
"TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
],
)
stop_server()

@pytest.mark.asyncio
async def test_model_pull_with_direct_url_should_be_success(self):
myobj = {
"model": "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
"model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf"
}
response = requests.post("http://localhost:3928/models/pull", json=myobj)
assert response.status_code == 200
await wait_for_websocket_download_success_event(timeout=None)
get_model_response = requests.get(
"http://127.0.0.1:3928/models/TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
"http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
)
assert get_model_response.status_code == 200
assert (
get_model_response.json()["model"]
== "TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
== "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
)

@pytest.mark.asyncio
async def test_model_pull_with_direct_url_should_have_desired_name(self):
myobj = {
"model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf",
"name": "smol_llama_100m"
}
response = requests.post("http://localhost:3928/models/pull", json=myobj)
assert response.status_code == 200
await wait_for_websocket_download_success_event(timeout=None)
get_model_response = requests.get(
"http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
)
assert get_model_response.status_code == 200
print(get_model_response.json()["name"])
assert (
get_model_response.json()["name"]
== "smol_llama_100m"
)
22 changes: 14 additions & 8 deletions engine/services/model_service.cc
@@ -17,7 +17,8 @@

namespace {
void ParseGguf(const DownloadItem& ggufDownloadItem,
-std::optional<std::string> author) {
+std::optional<std::string> author,
+std::optional<std::string> name) {
namespace fs = std::filesystem;
namespace fmu = file_manager_utils;
config::GGUFHandler gguf_handler;
@@ -32,6 +33,8 @@ void ParseGguf(const DownloadItem& ggufDownloadItem,
fmu::ToRelativeCortexDataPath(fs::path(ggufDownloadItem.localPath));
model_config.files = {file_rel_path.string()};
model_config.model = ggufDownloadItem.id;
model_config.name =
name.has_value() ? name.value() : gguf_handler.GetModelConfig().name;
yaml_handler.UpdateModelConfig(model_config);

auto yaml_path{ggufDownloadItem.localPath};
@@ -223,7 +226,8 @@ std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
}

cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
-const std::string& url, std::optional<std::string> temp_model_id) {
+const std::string& url, std::optional<std::string> temp_model_id,
+std::optional<std::string> temp_name) {
auto url_obj = url_parser::FromUrlString(url);

if (url_obj.host == kHuggingFaceHost) {
@@ -279,9 +283,9 @@ cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
.localPath = local_path,
}}}};

-auto on_finished = [author](const DownloadTask& finishedTask) {
+auto on_finished = [author, temp_name](const DownloadTask& finishedTask) {
auto gguf_download_item = finishedTask.items[0];
-ParseGguf(gguf_download_item, author);
+ParseGguf(gguf_download_item, author, temp_name);
};

downloadTask.id = unique_model_id;
@@ -346,7 +350,7 @@ cpp::result<std::string, std::string> ModelService::HandleUrl(

auto on_finished = [author](const DownloadTask& finishedTask) {
auto gguf_download_item = finishedTask.items[0];
-ParseGguf(gguf_download_item, author);
+ParseGguf(gguf_download_item, author, std::nullopt);
};

auto result = download_service_->AddDownloadTask(downloadTask, on_finished);
@@ -770,7 +774,7 @@ cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
auto author{url_obj.pathParams[0]};
auto model_id{url_obj.pathParams[1]};
auto file_name{url_obj.pathParams.back()};
if (author == "cortexso") {
if (author == "cortexso") {
return ModelPullInfo{.id = model_id + ":" + url_obj.pathParams[3],
.downloaded_models = {},
.available_models = {},
@@ -787,8 +791,10 @@
if (parsed.size() != 2) {
return cpp::fail("Invalid model handle: " + input);
}
-return ModelPullInfo{
-.id = input, .downloaded_models = {}, .available_models = {}, .download_url = input};
+return ModelPullInfo{.id = input,
+.downloaded_models = {},
+.available_models = {},
+.download_url = input};
}

if (input.find("/") != std::string::npos) {
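To summarize the precedence the service code now applies: ParseGguf takes the request-supplied name when one is present and otherwise keeps the name parsed from GGUF metadata, while the synchronous HandleUrl path always passes std::nullopt. A small Python restatement of that rule, using hypothetical example values:

```python
from typing import Optional

def resolve_display_name(requested: Optional[str], gguf_name: str) -> str:
    """Request-supplied name wins; otherwise fall back to the GGUF metadata name."""
    return requested if requested is not None else gguf_name

# Hypothetical example values:
assert resolve_display_name("smol_llama_100m", "zephyr-smol_llama") == "smol_llama_100m"
assert resolve_display_name(None, "zephyr-smol_llama") == "zephyr-smol_llama"
```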
5 changes: 3 additions & 2 deletions engine/services/model_service.h
@@ -39,7 +39,7 @@ class ModelService {
std::shared_ptr<DownloadService> download_service,
std::shared_ptr<services::InferenceService> inference_service)
: download_service_{download_service},
-inference_svc_(inference_service) {};
+inference_svc_(inference_service){};

/**
* Return model id if download successfully
@@ -81,7 +81,8 @@ class ModelService {
cpp::result<std::string, std::string> HandleUrl(const std::string& url);

cpp::result<DownloadTask, std::string> HandleDownloadUrlAsync(
-const std::string& url, std::optional<std::string> temp_model_id);
+const std::string& url, std::optional<std::string> temp_model_id,
+std::optional<std::string> temp_name);

private:
/**