Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 40edd99

Browse files
committed
feat: uplift pull and run cmd
1 parent d3e886b commit 40edd99

File tree

8 files changed

+154
-19
lines changed

8 files changed

+154
-19
lines changed

engine/commands/model_pull_cmd.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ void ModelPullCmd::Exec(const std::string& input) {
66
auto result = model_service_.DownloadModel(input);
77
if (result.has_error()) {
88
CLI_LOG(result.error());
9-
}
9+
}
1010
}
1111
}; // namespace commands

engine/database/models.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,29 @@ cpp::result<bool, std::string> Models::DeleteModelEntry(
273273
}
274274
}
275275

276+
cpp::result<std::vector<std::string>, std::string> Models::FindRelatedModel(
277+
const std::string& identifier) const {
278+
// TODO (namh): add check for alias as well
279+
try {
280+
std::vector<std::string> related_models;
281+
SQLite::Statement query(
282+
db_,
283+
"SELECT model_id FROM models WHERE model_id LIKE ? OR model_id LIKE ? "
284+
"OR model_id LIKE ? OR model_id LIKE ?");
285+
query.bind(1, identifier + ":%");
286+
query.bind(2, "%:" + identifier);
287+
query.bind(3, "%:" + identifier + ":%");
288+
query.bind(4, identifier);
289+
290+
while (query.executeStep()) {
291+
related_models.push_back(query.getColumn(0).getString());
292+
}
293+
return related_models;
294+
} catch (const std::exception& e) {
295+
return cpp::fail(e.what());
296+
}
297+
}
298+
276299
bool Models::HasModel(const std::string& identifier) const {
277300
try {
278301
SQLite::Statement query(

engine/database/models.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class Models {
2424
const std::string& model_id,
2525
const std::string& model_alias) const;
2626

27-
cpp::result<std::vector<ModelEntry>, std::string> LoadModelListNoLock() const;
27+
cpp::result<std::vector<ModelEntry>, std::string> LoadModelListNoLock() const;
2828

2929
public:
3030
static const std::string kModelListPath;
@@ -35,15 +35,19 @@ class Models {
3535
std::string GenerateShortenedAlias(
3636
const std::string& model_id,
3737
const std::vector<ModelEntry>& entries) const;
38-
cpp::result<ModelEntry, std::string> GetModelInfo(const std::string& identifier) const;
38+
cpp::result<ModelEntry, std::string> GetModelInfo(
39+
const std::string& identifier) const;
3940
void PrintModelInfo(const ModelEntry& entry) const;
4041
cpp::result<bool, std::string> AddModelEntry(ModelEntry new_entry,
4142
bool use_short_alias = false);
42-
cpp::result<bool, std::string> UpdateModelEntry(const std::string& identifier,
43-
const ModelEntry& updated_entry);
44-
cpp::result<bool, std::string> DeleteModelEntry(const std::string& identifier);
45-
cpp::result<bool, std::string> UpdateModelAlias(const std::string& model_id,
46-
const std::string& model_alias);
43+
cpp::result<bool, std::string> UpdateModelEntry(
44+
const std::string& identifier, const ModelEntry& updated_entry);
45+
cpp::result<bool, std::string> DeleteModelEntry(
46+
const std::string& identifier);
47+
cpp::result<bool, std::string> UpdateModelAlias(
48+
const std::string& model_id, const std::string& model_alias);
49+
cpp::result<std::vector<std::string>, std::string> FindRelatedModel(
50+
const std::string& identifier) const;
4751
bool HasModel(const std::string& identifier) const;
4852
};
4953
} // namespace cortex::db

engine/services/model_service.cc

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,17 +151,46 @@ cpp::result<std::string, std::string> ModelService::HandleCortexsoModel(
151151
return cpp::fail(branches.error());
152152
}
153153

154-
std::vector<std::string> options{};
154+
auto default_model_branch = huggingface_utils::GetDefaultBranch(modelName);
155+
156+
cortex::db::Models modellist_handler;
157+
auto downloaded_model_ids =
158+
modellist_handler.FindRelatedModel(modelName).value_or(
159+
std::vector<std::string>{});
160+
161+
std::vector<std::string> avai_download_opts{};
155162
for (const auto& branch : branches.value()) {
156-
if (branch.second.name != "main") {
157-
options.emplace_back(branch.second.name);
163+
if (branch.second.name == "main") { // main branch only have metadata. skip
164+
continue;
165+
}
166+
auto model_id = modelName + ":" + branch.second.name;
167+
if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(),
168+
model_id) !=
169+
downloaded_model_ids.end()) { // if downloaded, we skip it
170+
continue;
158171
}
172+
avai_download_opts.emplace_back(model_id);
159173
}
160-
if (options.empty()) {
174+
175+
if (avai_download_opts.empty()) {
176+
// TODO: only with pull, we return
161177
return cpp::fail("No variant available");
162178
}
163-
auto selection = cli_selection_utils::PrintSelection(options);
164-
return DownloadModelFromCortexso(modelName, selection.value());
179+
std::optional<std::string> normalized_def_branch = std::nullopt;
180+
if (default_model_branch.has_value()) {
181+
normalized_def_branch = modelName + ":" + default_model_branch.value();
182+
}
183+
string_utils::SortStrings(downloaded_model_ids);
184+
string_utils::SortStrings(avai_download_opts);
185+
auto selection = cli_selection_utils::PrintModelSelection(
186+
downloaded_model_ids, avai_download_opts, normalized_def_branch);
187+
if (!selection.has_value()) {
188+
return cpp::fail("Invalid selection");
189+
}
190+
191+
CLI_LOG("Selected: " << selection.value());
192+
auto branch_name = selection.value().substr(modelName.size() + 1);
193+
return DownloadModelFromCortexso(modelName, branch_name, false);
165194
}
166195

167196
std::optional<config::ModelConfig> ModelService::GetDownloadedModel(

engine/utils/cli_selection_utils.h

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,67 @@
22
#include <optional>
33
#include <string>
44
#include <vector>
5+
#include "utils/logging_utils.h"
56

67
namespace cli_selection_utils {
7-
inline void PrintMenu(const std::vector<std::string>& options) {
8-
auto index{1};
8+
const std::string indent = std::string(4, ' ');
9+
inline void PrintMenu(
10+
const std::vector<std::string>& options,
11+
const std::optional<std::string> default_option = std::nullopt,
12+
const int start_index = 1) {
13+
auto index{start_index};
914
for (const auto& option : options) {
10-
std::cout << index << ". " << option << "\n";
15+
bool is_default = false;
16+
if (default_option.has_value() && option == default_option.value()) {
17+
is_default = true;
18+
}
19+
std::string selection{std::to_string(index) + ". " + option +
20+
(is_default ? " (default)" : "") + "\n"};
21+
std::cout << indent << selection;
1122
index++;
1223
}
1324
std::endl(std::cout);
1425
}
1526

27+
inline std::optional<std::string> PrintModelSelection(
28+
const std::vector<std::string>& downloaded,
29+
const std::vector<std::string>& availables,
30+
const std::optional<std::string> default_selection = std::nullopt) {
31+
32+
std::string selection{""};
33+
if (!downloaded.empty()) {
34+
std::cout << "Downloaded models:\n";
35+
for (const auto& option : downloaded) {
36+
std::cout << indent << option << "\n";
37+
}
38+
std::endl(std::cout);
39+
}
40+
41+
if (!availables.empty()) {
42+
std::cout << "Available to download:\n";
43+
PrintMenu(availables, default_selection, 1);
44+
}
45+
46+
std::cout << "Select a model (" << 1 << "-" << availables.size() << "): ";
47+
std::getline(std::cin, selection);
48+
49+
// if selection is empty and default selection is inside availables, return default_selection
50+
if (selection.empty()) {
51+
if (default_selection.has_value() &&
52+
std::find(availables.begin(), availables.end(),
53+
default_selection.value()) != availables.end()) {
54+
return default_selection;
55+
}
56+
return std::nullopt;
57+
}
58+
59+
if (std::stoi(selection) > availables.size() || std::stoi(selection) < 1) {
60+
return std::nullopt;
61+
}
62+
63+
return availables[std::stoi(selection) - 1];
64+
}
65+
1666
inline std::optional<std::string> PrintSelection(
1767
const std::vector<std::string>& options,
1868
const std::string& title = "Select an option") {

engine/utils/curl_utils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <curl/curl.h>
22
#include <nlohmann/json.hpp>
33
#include <string>
4-
#include "utils/logging_utils.h"
54
#include "utils/result.hpp"
65
#include "yaml-cpp/yaml.h"
76

@@ -74,4 +73,4 @@ inline cpp::result<nlohmann::json, std::string> SimpleGetJson(
7473
" parsing error: " + std::string(e.what()));
7574
}
7675
}
77-
} // namespace curl_utils
76+
} // namespace curl_utils

engine/utils/huggingface_utils.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ GetHuggingFaceModelRepoInfo(const std::string& author,
140140
return model_repo_info;
141141
}
142142

143+
inline std::string GetMetadataUrl(const std::string& model_id) {
144+
auto url_obj = url_parser::Url{
145+
.protocol = "https",
146+
.host = kHuggingfaceHost,
147+
.pathParams = {"cortexso", model_id, "resolve", "main", "metadata.yml"}};
148+
149+
return url_obj.ToFullPath();
150+
}
151+
143152
inline std::string GetDownloadableUrl(const std::string& author,
144153
const std::string& modelName,
145154
const std::string& fileName,
@@ -151,4 +160,21 @@ inline std::string GetDownloadableUrl(const std::string& author,
151160
};
152161
return url_parser::FromUrl(url_obj);
153162
}
163+
164+
inline std::optional<std::string> GetDefaultBranch(
165+
const std::string& model_name) {
166+
auto default_model_branch =
167+
curl_utils::ReadRemoteYaml(GetMetadataUrl(model_name));
168+
169+
if (default_model_branch.has_error()) {
170+
return std::nullopt;
171+
}
172+
173+
auto metadata = default_model_branch.value();
174+
auto default_branch = metadata["default"];
175+
if (default_branch.IsDefined()) {
176+
return default_branch.as<std::string>();
177+
}
178+
return std::nullopt;
179+
}
154180
} // namespace huggingface_utils

engine/utils/string_utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ inline bool StartsWith(const std::string& str, const std::string& prefix) {
1313
return str.rfind(prefix, 0) == 0;
1414
}
1515

16+
inline void SortStrings(std::vector<std::string>& strings) {
17+
std::sort(strings.begin(), strings.end());
18+
}
19+
1620
inline bool EndsWith(const std::string& str, const std::string& suffix) {
1721
if (str.length() >= suffix.length()) {
1822
return (0 == str.compare(str.length() - suffix.length(), suffix.length(),

0 commit comments

Comments
 (0)