Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2488,12 +2488,29 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"path to save slot kv cache (default: disabled)",
[](common_params & params, const std::string & value) {
params.slot_save_path = value;
if (!fs_is_directory(params.slot_save_path)) {
throw std::invalid_argument("not a directory: " + value);
}
// if doesn't end with DIRECTORY_SEPARATOR, add it
if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
params.slot_save_path += DIRECTORY_SEPARATOR;
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--media-path"}, "PATH",
"directory for loading local media files; files can be accessed via file:// URLs using relative paths (default: disabled)",
[](common_params & params, const std::string & value) {
params.media_path = value;
if (!fs_is_directory(params.media_path)) {
throw std::invalid_argument("not a directory: " + value);
}
// if doesn't end with DIRECTORY_SEPARATOR, add it
if (!params.media_path.empty() && params.media_path[params.media_path.size() - 1] != DIRECTORY_SEPARATOR) {
params.media_path += DIRECTORY_SEPARATOR;
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--models-dir"}, "PATH",
"directory containing models for the router server (default: disabled)",
Expand Down
13 changes: 11 additions & 2 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over

// Validate if a filename is safe to use
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
bool fs_validate_filename(const std::string & filename) {
bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
if (!filename.length()) {
// Empty filename invalid
return false;
Expand Down Expand Up @@ -754,10 +754,14 @@ bool fs_validate_filename(const std::string & filename) {
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
|| c == ':' || c == '*' // Illegal characters
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
return false;
}
if (!allow_subdirs && (c == '/' || c == '\\')) {
// Subdirectories not allowed, reject path separators
return false;
}
}

// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
Expand Down Expand Up @@ -859,6 +863,11 @@ bool fs_create_directory_with_parents(const std::string & path) {
#endif // _WIN32
}

bool fs_is_directory(const std::string & path) {
std::filesystem::path dir(path);
return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
}

std::string fs_get_cache_directory() {
std::string cache_directory = "";
auto ensure_trailing_slash = [](std::string p) {
Expand Down
4 changes: 3 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ struct common_params {
bool log_json = false;

std::string slot_save_path;
std::string media_path; // path to directory for loading media files

float slot_prompt_similarity = 0.1f;

Expand Down Expand Up @@ -635,8 +636,9 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
// Filesystem utils
//

bool fs_validate_filename(const std::string & filename);
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
bool fs_create_directory_with_parents(const std::string & path);
bool fs_is_directory(const std::string & path);

std::string fs_get_cache_directory();
std::string fs_get_cache_file(const std::string & filename);
Expand Down
99 changes: 64 additions & 35 deletions tools/server/server-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <random>
#include <sstream>
#include <fstream>

json format_error_response(const std::string & message, const enum error_type type) {
std::string type_str;
Expand Down Expand Up @@ -774,6 +775,65 @@ json oaicompat_completion_params_parse(const json & body) {
return llama_params;
}

// media_path always end with '/', see arg.cpp
static void handle_media(
std::vector<raw_buffer> & out_files,
json & media_obj,
const std::string & media_path) {
std::string url = json_value(media_obj, "url", std::string());
if (string_starts_with(url, "http")) {
// download remote image
// TODO @ngxson : maybe make these params configurable
common_remote_params params;
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
params.max_size = 1024 * 1024 * 10; // 10MB
params.timeout = 10; // seconds
SRV_INF("downloading image from '%s'\n", url.c_str());
auto res = common_remote_get_content(url, params);
if (200 <= res.first && res.first < 300) {
SRV_INF("downloaded %ld bytes\n", res.second.size());
raw_buffer data;
data.insert(data.end(), res.second.begin(), res.second.end());
out_files.push_back(data);
} else {
throw std::runtime_error("Failed to download image");
}

} else if (string_starts_with(url, "file://")) {
if (media_path.empty()) {
throw std::invalid_argument("file:// URLs are not allowed unless --media-path is specified");
}
// load local image file
std::string file_path = url.substr(7); // remove "file://"
raw_buffer data;
if (!fs_validate_filename(file_path, true)) {
throw std::invalid_argument("file path is not allowed: " + file_path);
}
SRV_INF("loading image from local file '%s'\n", (media_path + file_path).c_str());
std::ifstream file(media_path + file_path, std::ios::binary);
if (!file) {
throw std::invalid_argument("file does not exist or cannot be opened: " + file_path);
}
data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
out_files.push_back(data);

} else {
// try to decode base64 image
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
if (parts.size() != 2) {
throw std::runtime_error("Invalid url value");
} else if (!string_starts_with(parts[0], "data:image/")) {
throw std::runtime_error("Invalid url format: " + parts[0]);
} else if (!string_ends_with(parts[0], "base64")) {
throw std::runtime_error("url must be base64 encoded");
} else {
auto base64_data = parts[1];
auto decoded_data = base64_decode(base64_data);
out_files.push_back(decoded_data);
}
}
}

// used by /chat/completions endpoint
json oaicompat_chat_params_parse(
json & body, /* openai api json semantics */
Expand Down Expand Up @@ -860,41 +920,8 @@ json oaicompat_chat_params_parse(
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}

json image_url = json_value(p, "image_url", json::object());
std::string url = json_value(image_url, "url", std::string());
if (string_starts_with(url, "http")) {
// download remote image
// TODO @ngxson : maybe make these params configurable
common_remote_params params;
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
params.max_size = 1024 * 1024 * 10; // 10MB
params.timeout = 10; // seconds
SRV_INF("downloading image from '%s'\n", url.c_str());
auto res = common_remote_get_content(url, params);
if (200 <= res.first && res.first < 300) {
SRV_INF("downloaded %ld bytes\n", res.second.size());
raw_buffer data;
data.insert(data.end(), res.second.begin(), res.second.end());
out_files.push_back(data);
} else {
throw std::runtime_error("Failed to download image");
}

} else {
// try to decode base64 image
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
if (parts.size() != 2) {
throw std::invalid_argument("Invalid image_url.url value");
} else if (!string_starts_with(parts[0], "data:image/")) {
throw std::invalid_argument("Invalid image_url.url format: " + parts[0]);
} else if (!string_ends_with(parts[0], "base64")) {
throw std::invalid_argument("image_url.url must be base64 encoded");
} else {
auto base64_data = parts[1];
auto decoded_data = base64_decode(base64_data);
out_files.push_back(decoded_data);
}
}
json image_url = json_value(p, "image_url", json::object());
handle_media(out_files, image_url, opt.media_path);

// replace this chunk with a marker
p["type"] = "text";
Expand All @@ -916,6 +943,8 @@ json oaicompat_chat_params_parse(
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);

// TODO: add audio_url support by reusing handle_media()

// replace this chunk with a marker
p["type"] = "text";
p["text"] = mtmd_default_marker();
Expand Down
1 change: 1 addition & 0 deletions tools/server/server-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ struct oaicompat_parser_options {
bool allow_image;
bool allow_audio;
bool enable_thinking = true;
std::string media_path;
};

// used by /chat/completions endpoint
Expand Down
1 change: 1 addition & 0 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,7 @@ struct server_context_impl {
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
/* media_path */ params_base.media_path,
};

// print sample chat example to make it clear which template is used
Expand Down
2 changes: 2 additions & 0 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
try {
return func(req);
} catch (const std::invalid_argument & e) {
// treat invalid_argument as invalid request (400)
error = ERROR_TYPE_INVALID_REQUEST;
message = e.what();
} catch (const std::exception & e) {
// treat other exceptions as server error (500)
error = ERROR_TYPE_SERVER;
message = e.what();
} catch (...) {
Expand Down
31 changes: 31 additions & 0 deletions tools/server/tests/unit/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,34 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
assert res.status_code == 200
assert cors_header in res.headers
assert res.headers[cors_header] == cors_header_value


@pytest.mark.parametrize(
"media_path, image_url, success",
[
(None, "file://mtmd/test-1.jpeg", False), # disabled media path, should fail
("../../../tools", "file://mtmd/test-1.jpeg", True),
("../../../tools", "file:////mtmd//test-1.jpeg", True), # should be the same file as above
("../../../tools", "file://mtmd/notfound.jpeg", False), # non-existent file
("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal
]
)
def test_local_media_file(media_path, image_url, success,):
server = ServerPreset.tinygemma3()
server.media_path = media_path
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 1,
"messages": [
{"role": "user", "content": [
{"type": "text", "text": "test"},
{"type": "image_url", "image_url": {
"url": image_url,
}},
]},
],
})
if success:
assert res.status_code == 200
else:
assert res.status_code == 400
3 changes: 3 additions & 0 deletions tools/server/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class ServerProcess:
chat_template_file: str | None = None
server_path: str | None = None
mmproj_url: str | None = None
media_path: str | None = None

# session variables
process: subprocess.Popen | None = None
Expand Down Expand Up @@ -217,6 +218,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.mmproj_url:
server_args.extend(["--mmproj-url", self.mmproj_url])
if self.media_path:
server_args.extend(["--media-path", self.media_path])

args = [str(arg) for arg in [server_path, *server_args]]
print(f"tests: starting server with: {' '.join(args)}")
Expand Down
Loading