Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit f5fbad6

Browse files
authored
chore: add model name as a parameter support during import via API (#1600)
1 parent 166cdb5 commit f5fbad6

File tree

6 files changed

+160
-19
lines changed

6 files changed

+160
-19
lines changed

docs/static/openapi/cortex.json

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,46 @@
554554
"tags": ["Models"]
555555
}
556556
},
557+
"/v1/models/import": {
558+
"post": {
559+
"operationId": "ModelsController_importModel",
560+
"summary": "Import model",
561+
"description": "Imports a model from a specified path.",
562+
"requestBody": {
563+
"required": true,
564+
"content": {
565+
"application/json": {
566+
"schema": {
567+
"$ref": "#/components/schemas/ImportModelRequest"
568+
},
569+
"example": {
570+
"model": "model-id",
571+
"modelPath": "/path/to/gguf",
572+
"name": "model display name"
573+
}
574+
}
575+
}
576+
},
577+
"responses": {
578+
"200": {
579+
"description": "Model is imported successfully!",
580+
"content": {
581+
"application/json": {
582+
"schema": {
583+
"$ref": "#/components/schemas/ImportModelResponse"
584+
},
585+
"example": {
586+
"message": "Model is imported successfully!",
587+
"modelHandle": "model-id",
588+
"result": "OK"
589+
}
590+
}
591+
}
592+
}
593+
},
594+
"tags": ["Models"]
595+
}
596+
},
557597
"/v1/threads": {
558598
"post": {
559599
"operationId": "ThreadsController_create",
@@ -1660,6 +1700,15 @@
16601700
"value": "my-custom-model-id"
16611701
}
16621702
]
1703+
},
1704+
"name": {
1705+
"type": "string",
1706+
"description": "The name which will be used to overwrite the model name.",
1707+
"examples": [
1708+
{
1709+
"value": "my-custom-model-name"
1710+
}
1711+
]
16631712
}
16641713
}
16651714
},
@@ -1803,6 +1852,43 @@
18031852
}
18041853
}
18051854
},
1855+
"ImportModelRequest": {
1856+
"type": "object",
1857+
"properties": {
1858+
"model": {
1859+
"type": "string",
1860+
"description": "The unique identifier of the model."
1861+
},
1862+
"modelPath": {
1863+
"type": "string",
1864+
"description": "The file path to the model."
1865+
},
1866+
"name": {
1867+
"type": "string",
1868+
"description": "The display name of the model."
1869+
}
1870+
},
1871+
"required": ["model", "modelPath"]
1872+
},
1873+
"ImportModelResponse": {
1874+
"type": "object",
1875+
"properties": {
1876+
"message": {
1877+
"type": "string",
1878+
"description": "Success message."
1879+
},
1880+
"modelHandle": {
1881+
"type": "string",
1882+
"description": "The unique identifier of the imported model."
1883+
},
1884+
"result": {
1885+
"type": "string",
1886+
"description": "Result status.",
1887+
"example": "OK"
1888+
}
1889+
},
1890+
"required": ["message", "modelHandle", "result"]
1891+
},
18061892
"CommonResponseDto": {
18071893
"type": "object",
18081894
"properties": {

engine/controllers/models.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,19 @@ void Models::PullModel(const HttpRequestPtr& req,
3333
desired_model_id = id;
3434
}
3535

36+
std::optional<std::string> desired_model_name = std::nullopt;
37+
auto name_value = (*(req->getJsonObject())).get("name", "").asString();
38+
39+
if (!name_value.empty()) {
40+
desired_model_name = name_value;
41+
}
42+
3643
auto handle_model_input =
3744
[&, model_handle]() -> cpp::result<DownloadTask, std::string> {
3845
CTL_INF("Handle model input, model handle: " + model_handle);
3946
if (string_utils::StartsWith(model_handle, "https")) {
40-
return model_service_->HandleDownloadUrlAsync(model_handle,
41-
desired_model_id);
47+
return model_service_->HandleDownloadUrlAsync(
48+
model_handle, desired_model_id, desired_model_name);
4249
} else if (model_handle.find(":") != std::string::npos) {
4350
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
4451
return model_service_->DownloadModelFromCortexsoAsync(
@@ -312,6 +319,7 @@ void Models::ImportModel(
312319
}
313320
auto modelHandle = (*(req->getJsonObject())).get("model", "").asString();
314321
auto modelPath = (*(req->getJsonObject())).get("modelPath", "").asString();
322+
auto modelName = (*(req->getJsonObject())).get("name", "").asString();
315323
config::GGUFHandler gguf_handler;
316324
config::YamlHandler yaml_handler;
317325
cortex::db::Models modellist_utils_obj;
@@ -333,6 +341,7 @@ void Models::ImportModel(
333341
config::ModelConfig model_config = gguf_handler.GetModelConfig();
334342
model_config.files.push_back(modelPath);
335343
model_config.model = modelHandle;
344+
model_config.name = modelName.empty() ? model_config.name : modelName;
336345
yaml_handler.UpdateModelConfig(model_config);
337346

338347
if (modellist_utils_obj.AddModelEntry(model_entry).value()) {

engine/e2e-test/test_api_model_import.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,25 @@ def setup_and_teardown(self):
1818
def test_model_import_should_be_success(self):
1919
body_json = {'model': 'tinyllama:gguf',
2020
'modelPath': '/path/to/local/gguf'}
21-
response = requests.post("http://localhost:3928/models/import", json = body_json)
22-
assert response.status_code == 200
21+
response = requests.post("http://localhost:3928/models/import", json=body_json)
22+
assert response.status_code == 200
23+
24+
@pytest.mark.skipif(True, reason="Expensive test. Only test when you have local gguf file.")
25+
def test_model_import_with_name_should_be_success(self):
26+
body_json = {'model': 'tinyllama:gguf',
27+
'modelPath': '/path/to/local/gguf',
28+
'name': 'test_model'}
29+
response = requests.post("http://localhost:3928/models/import", json=body_json)
30+
assert response.status_code == 200
31+
32+
def test_model_import_with_invalid_path_should_fail(self):
33+
body_json = {'model': 'tinyllama:gguf',
34+
'modelPath': '/invalid/path/to/gguf'}
35+
response = requests.post("http://localhost:3928/models/import", json=body_json)
36+
assert response.status_code == 400
37+
38+
def test_model_import_with_missing_model_should_fail(self):
39+
body_json = {'modelPath': '/path/to/local/gguf'}
40+
response = requests.post("http://localhost:3928/models/import", json=body_json)
41+
print(response)
42+
assert response.status_code == 409

engine/e2e-test/test_api_model_pull_direct_url.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def setup_and_teardown(self):
2121
[
2222
"models",
2323
"delete",
24-
"TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
24+
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
2525
],
2626
)
2727
yield
@@ -32,24 +32,43 @@ def setup_and_teardown(self):
3232
[
3333
"models",
3434
"delete",
35-
"TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
35+
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
3636
],
3737
)
3838
stop_server()
3939

4040
@pytest.mark.asyncio
4141
async def test_model_pull_with_direct_url_should_be_success(self):
4242
myobj = {
43-
"model": "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
43+
"model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf"
4444
}
4545
response = requests.post("http://localhost:3928/models/pull", json=myobj)
4646
assert response.status_code == 200
4747
await wait_for_websocket_download_success_event(timeout=None)
4848
get_model_response = requests.get(
49-
"http://127.0.0.1:3928/models/TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
49+
"http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
5050
)
5151
assert get_model_response.status_code == 200
5252
assert (
5353
get_model_response.json()["model"]
54-
== "TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
54+
== "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
55+
)
56+
57+
@pytest.mark.asyncio
58+
async def test_model_pull_with_direct_url_should_have_desired_name(self):
59+
myobj = {
60+
"model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf",
61+
"name": "smol_llama_100m"
62+
}
63+
response = requests.post("http://localhost:3928/models/pull", json=myobj)
64+
assert response.status_code == 200
65+
await wait_for_websocket_download_success_event(timeout=None)
66+
get_model_response = requests.get(
67+
"http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
68+
)
69+
assert get_model_response.status_code == 200
70+
print(get_model_response.json()["name"])
71+
assert (
72+
get_model_response.json()["name"]
73+
== "smol_llama_100m"
5574
)

engine/services/model_service.cc

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
namespace {
1919
void ParseGguf(const DownloadItem& ggufDownloadItem,
20-
std::optional<std::string> author) {
20+
std::optional<std::string> author,
21+
std::optional<std::string> name) {
2122
namespace fs = std::filesystem;
2223
namespace fmu = file_manager_utils;
2324
config::GGUFHandler gguf_handler;
@@ -32,6 +33,8 @@ void ParseGguf(const DownloadItem& ggufDownloadItem,
3233
fmu::ToRelativeCortexDataPath(fs::path(ggufDownloadItem.localPath));
3334
model_config.files = {file_rel_path.string()};
3435
model_config.model = ggufDownloadItem.id;
36+
model_config.name =
37+
name.has_value() ? name.value() : gguf_handler.GetModelConfig().name;
3538
yaml_handler.UpdateModelConfig(model_config);
3639

3740
auto yaml_path{ggufDownloadItem.localPath};
@@ -223,7 +226,8 @@ std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
223226
}
224227

225228
cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
226-
const std::string& url, std::optional<std::string> temp_model_id) {
229+
const std::string& url, std::optional<std::string> temp_model_id,
230+
std::optional<std::string> temp_name) {
227231
auto url_obj = url_parser::FromUrlString(url);
228232

229233
if (url_obj.host == kHuggingFaceHost) {
@@ -279,9 +283,9 @@ cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
279283
.localPath = local_path,
280284
}}}};
281285

282-
auto on_finished = [author](const DownloadTask& finishedTask) {
286+
auto on_finished = [author, temp_name](const DownloadTask& finishedTask) {
283287
auto gguf_download_item = finishedTask.items[0];
284-
ParseGguf(gguf_download_item, author);
288+
ParseGguf(gguf_download_item, author, temp_name);
285289
};
286290

287291
downloadTask.id = unique_model_id;
@@ -346,7 +350,7 @@ cpp::result<std::string, std::string> ModelService::HandleUrl(
346350

347351
auto on_finished = [author](const DownloadTask& finishedTask) {
348352
auto gguf_download_item = finishedTask.items[0];
349-
ParseGguf(gguf_download_item, author);
353+
ParseGguf(gguf_download_item, author, std::nullopt);
350354
};
351355

352356
auto result = download_service_->AddDownloadTask(downloadTask, on_finished);
@@ -770,7 +774,7 @@ cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
770774
auto author{url_obj.pathParams[0]};
771775
auto model_id{url_obj.pathParams[1]};
772776
auto file_name{url_obj.pathParams.back()};
773-
if (author == "cortexso") {
777+
if (author == "cortexso") {
774778
return ModelPullInfo{.id = model_id + ":" + url_obj.pathParams[3],
775779
.downloaded_models = {},
776780
.available_models = {},
@@ -787,8 +791,10 @@ cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
787791
if (parsed.size() != 2) {
788792
return cpp::fail("Invalid model handle: " + input);
789793
}
790-
return ModelPullInfo{
791-
.id = input, .downloaded_models = {}, .available_models = {}, .download_url = input};
794+
return ModelPullInfo{.id = input,
795+
.downloaded_models = {},
796+
.available_models = {},
797+
.download_url = input};
792798
}
793799

794800
if (input.find("/") != std::string::npos) {

engine/services/model_service.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class ModelService {
3939
std::shared_ptr<DownloadService> download_service,
4040
std::shared_ptr<services::InferenceService> inference_service)
4141
: download_service_{download_service},
42-
inference_svc_(inference_service) {};
42+
inference_svc_(inference_service){};
4343

4444
/**
4545
* Return model id if download successfully
@@ -81,7 +81,8 @@ class ModelService {
8181
cpp::result<std::string, std::string> HandleUrl(const std::string& url);
8282

8383
cpp::result<DownloadTask, std::string> HandleDownloadUrlAsync(
84-
const std::string& url, std::optional<std::string> temp_model_id);
84+
const std::string& url, std::optional<std::string> temp_model_id,
85+
std::optional<std::string> temp_name);
8586

8687
private:
8788
/**

0 commit comments

Comments
 (0)