Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3295,6 +3295,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.port = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT"));
add_opt(common_arg(
    {"--allowed-local-media-path"}, "PATH",
    string_format("path from which local media files are allowed to be read from (default: none)"),
    [](common_params & params, const std::string & value) {
        try {
            // canonical() resolves symlinks, "." and ".." and throws if the path does not exist,
            // so the stored path is a safe base for later prefix checks
            params.allowed_local_media_path = std::filesystem::canonical(std::filesystem::path(value));
            if (!std::filesystem::is_directory(params.allowed_local_media_path)) {
                // use .string() so the %s argument is a narrow char* on every platform
                // (std::filesystem::path::c_str() yields wchar_t* on Windows, breaking %s)
                throw std::invalid_argument(string_format("allowed local media path must be a dir: %s", params.allowed_local_media_path.string().c_str()));
            }
        } catch (const std::filesystem::filesystem_error & err) {
            throw std::invalid_argument(string_format("invalid allowed local media path: %s", err.what()));
        }
    }
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALLOWED_LOCAL_MEDIA_PATH"));
add_opt(common_arg(
    {"--local-media-max-size-mb"}, "N",
    // %zu is the portable conversion for size_t; %lu truncates/misreads on LLP64 (Windows)
    string_format("max size in mb for local media files (default: %zu)", params.local_media_max_size_mb),
    [](common_params & params, int value) {
        // reject negatives: casting them to size_t would silently wrap to a huge limit
        if (value < 0) {
            throw std::invalid_argument("local media max size must be non-negative");
        }
        params.local_media_max_size_mb = static_cast<size_t>(value);
    }
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_LOCAL_MEDIA_MAX_SIZE_MB"));
add_opt(common_arg(
{"--path"}, "PATH",
string_format("path to serve static files from (default: %s)", params.public_path.c_str()),
Expand Down
3 changes: 3 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <map>
#include <sstream>
#include <cmath>
#include <filesystem>

#include "ggml-opt.h"
#include "llama-cpp.h"
Expand Down Expand Up @@ -429,9 +430,11 @@ struct common_params {
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
size_t local_media_max_size_mb = 15; // max size of loaded local media files, in MiB; 0 = no limit (default 15 = 15 MiB)

std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::filesystem::path allowed_local_media_path; // dir from which file:// media may be read; empty = disabled // NOLINT
std::string api_prefix = ""; // NOLINT
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
Expand Down
4 changes: 4 additions & 0 deletions tools/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ The project is under active development, and we are [looking for feedback and co
| `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
| `--allowed-local-media-path PATH` | path from which local media files are allowed to be read from (default: none)<br/>(env: LLAMA_ARG_ALLOWED_LOCAL_MEDIA_PATH) |
| `--local-media-max-size-mb N` | max size in mb for local media files (default: 15)<br/>(env: LLAMA_ARG_LOCAL_MEDIA_MAX_SIZE_MB) |
| `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )<br/>(env: LLAMA_ARG_API_PREFIX) |
| `--no-webui` | Disable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_NO_WEBUI) |
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
Expand Down Expand Up @@ -1213,6 +1215,8 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte

If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more.

Local files can also be used as input via `file://` URLs when enabled (see `--allowed-local-media-path` and `--local-media-max-size-mb` for details).

*Options:*

See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
Expand Down
18 changes: 10 additions & 8 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2586,14 +2586,16 @@ struct server_context {
SRV_INF("thinking = %d\n", enable_thinking);

oai_parser_opt = {
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* chat_template_kwargs */ params_base.default_template_kwargs,
/* common_chat_templates */ chat_templates.get(),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* chat_template_kwargs */ params_base.default_template_kwargs,
/* common_chat_templates */ chat_templates.get(),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
/* local_media_max_size_mb */ params_base.local_media_max_size_mb,
/* allowed_local_media_path */ params_base.allowed_local_media_path,
};
}

Expand Down
89 changes: 87 additions & 2 deletions tools/server/tests/unit/test_vision_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
from utils import *
import base64
import requests
from pathlib import Path

server: ServerProcess

def get_img_url(id: str) -> str:

def get_img_url(id: str, tmp_path: str | None = None) -> str:
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
IMG_FILE_2 = "https://picsum.photos/id/237/5000"
if id == "IMG_URL_0":
return IMG_URL_0
elif id == "IMG_URL_1":
Expand All @@ -28,6 +31,46 @@ def get_img_url(id: str) -> str:
response = requests.get(IMG_URL_1)
response.raise_for_status() # Raise an exception for bad status codes
return base64.b64encode(response.content).decode("utf-8")
elif id == "IMG_FILE_0":
if tmp_path is None:
raise RuntimeError("get_img_url must be called with a tmp_path if using local files")
image_name = IMG_URL_0.split('/')[-1]
file_name: Path = Path(tmp_path) / image_name
if file_name.exists():
return f"file://{file_name}"
else:
response = requests.get(IMG_URL_0)
response.raise_for_status() # Raise an exception for bad status codes
with open(file_name, 'wb') as f:
f.write(response.content)
return f"file://{file_name}"
elif id == "IMG_FILE_1":
if tmp_path is None:
raise RuntimeError("get_img_url must be called with a tmp_path if using local files")
image_name = IMG_URL_1.split('/')[-1]
file_name: Path = Path(tmp_path) / image_name
if file_name.exists():
return f"file://{file_name}"
else:
response = requests.get(IMG_URL_1)
response.raise_for_status() # Raise an exception for bad status codes
with open(file_name, 'wb') as f:
f.write(response.content)
return f"file://{file_name}"
elif id == "IMG_FILE_2":
if tmp_path is None:
raise RuntimeError("get_img_url must be called with a tmp_path if using local files")
image_name = "dog.jpg"
file_name: Path = Path(tmp_path) / image_name
if file_name.exists():
return f"file://{file_name}"
else:
response = requests.get(IMG_FILE_2)
response.raise_for_status() # Raise an exception for bad status codes
with open(file_name, 'wb') as f:
f.write(response.content)
return f"file://{file_name}"

else:
return id

Expand Down Expand Up @@ -70,6 +113,9 @@ def test_v1_models_supports_multimodal_capability():
("What is this:\n", "malformed", False, None),
("What is this:\n", "https://google.com/404", False, None), # non-existent image
("What is this:\n", "https://ggml.ai", False, None), # non-image data
("What is this:\n", "IMG_FILE_0", False, None),
("What is this:\n", "IMG_FILE_1", False, None),
("What is this:\n", "IMG_FILE_2", False, None),
# TODO @ngxson : test with multiple images, no images and with audio
]
)
Expand All @@ -83,7 +129,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
{"role": "user", "content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": get_img_url(image_url),
"url": get_img_url(image_url, "./tmp"),
}},
]},
],
Expand All @@ -97,6 +143,45 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
assert res.status_code != 200


@pytest.mark.parametrize(
    "allowed_mb_size, allowed_path, img_dir_path, prompt, image_url, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        (0, "./tmp", "./tmp", "What is this:\n", "IMG_FILE_0", True, "(cat)+"),
        (0, "./tmp", "./tmp", "What is this:\n", "IMG_FILE_1", True, "(frog)+"),
        (1, "./tmp", "./tmp", "What is this:\n", "IMG_FILE_2", False, None),
        (0, "./tmp/allowed", "./tmp", "What is this:\n", "IMG_FILE_0", False, None),
        (0, "./tm", "./tmp", "What is this:\n", "IMG_FILE_0", False, None),
        (0, "./tmp/allowed", "./tmp/allowed/..", "What is this:\n", "IMG_FILE_0", False, None),
        (0, "./tmp/allowed", "./tmp/allowed/../.", "What is this:\n", "IMG_FILE_0", False, None),
    ]
)
def test_vision_chat_completion_local_files(allowed_mb_size, allowed_path, img_dir_path, prompt, image_url, success, re_content):
    """Exercise file:// image inputs against --allowed-local-media-path /
    --local-media-max-size-mb: in-tree files succeed, oversized files and
    paths escaping the allowed dir are rejected with a non-200 status."""
    global server
    server.local_media_max_size_mb = allowed_mb_size
    server.allowed_local_media_path = allowed_path
    # parents=True: nested allowed dirs (e.g. ./tmp/allowed) must not fail when
    # the parent directory has not been created by an earlier test
    Path(allowed_path).mkdir(parents=True, exist_ok=True)
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "temperature": 0.0,
        "top_k": 1,
        "messages": [
            {"role": "user", "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {
                    "url": get_img_url(image_url, img_dir_path),
                }},
            ]},
        ],
    })
    if success:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert "assistant" == choice["message"]["role"]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code != 200

@pytest.mark.parametrize(
"prompt, image_data, success, re_content",
[
Expand Down
6 changes: 6 additions & 0 deletions tools/server/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class ServerProcess:
chat_template_file: str | None = None
server_path: str | None = None
mmproj_url: str | None = None
local_media_max_size_mb: int | None = None
allowed_local_media_path: str | None = None

# session variables
process: subprocess.Popen | None = None
Expand Down Expand Up @@ -215,6 +217,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.mmproj_url:
server_args.extend(["--mmproj-url", self.mmproj_url])
# NOTE(review): truthiness skips an explicit 0 ("no limit" per common.h), so the
# server silently falls back to its default in that case — confirm `is not None`
# was not intended here
if self.local_media_max_size_mb:
    server_args.extend(["--local-media-max-size-mb", self.local_media_max_size_mb])
if self.allowed_local_media_path:
    server_args.extend(["--allowed-local-media-path", self.allowed_local_media_path])

args = [str(arg) for arg in [server_path, *server_args]]
print(f"tests: starting server with: {' '.join(args)}")
Expand Down
35 changes: 35 additions & 0 deletions tools/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <vector>
#include <memory>
#include <cinttypes>
#include <filesystem>
#include <fstream>

#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"

Expand Down Expand Up @@ -530,6 +532,8 @@ struct oaicompat_parser_options {
bool allow_image;
bool allow_audio;
bool enable_thinking = true;
size_t local_media_max_size_mb;                    // max size for file:// media, in MiB (mirrors common_params::local_media_max_size_mb)
std::filesystem::path allowed_local_media_path;    // dir from which file:// media may be read; empty = file:// input disabled
};

// used by /chat/completions endpoint
Expand Down Expand Up @@ -638,6 +642,37 @@ static json oaicompat_chat_params_parse(
throw std::runtime_error("Failed to download image");
}

} else if (string_starts_with(url, "file://")) {
    if (opt.allowed_local_media_path.empty()) {
        throw std::runtime_error("Local media paths are not enabled");
    }
    // Strip off the leading "file://"
    const std::string fname = url.substr(7);
    // resolve symlinks, "." and ".." so the prefix check below cannot be bypassed;
    // the error_code overload avoids an uncaught filesystem_error for missing paths
    std::error_code ec;
    const std::filesystem::path input_path = std::filesystem::canonical(std::filesystem::path(fname), ec);
    if (ec) {
        throw std::runtime_error("Local media file does not exist: " + fname);
    }
    // component-wise containment check: input_path must lie under the allowed dir.
    // the 4-iterator std::mismatch overload is required — the 3-iterator form has
    // undefined behavior when input_path has fewer components than the allowed path
    const auto & allowed = opt.allowed_local_media_path;
    const auto mm = std::mismatch(allowed.begin(), allowed.end(), input_path.begin(), input_path.end());
    if (mm.first != allowed.end()) {
        throw std::runtime_error("Local media file path not allowed: " + fname);
    }
    if (!std::filesystem::is_regular_file(input_path)) {
        throw std::runtime_error("Local media file does not exist: " + fname);
    }
    const auto file_size = std::filesystem::file_size(input_path);
    // 0 = no limit, matching the documented contract of local_media_max_size_mb
    if (opt.local_media_max_size_mb > 0 && file_size > opt.local_media_max_size_mb * 1024 * 1024) {
        throw std::runtime_error("Local media file exceeds maximum allowed size");
    }
    // load local file path
    std::ifstream f(input_path, std::ios::binary);
    if (!f) {
        SRV_ERR("Unable to open file %s: %s\n", fname.c_str(), strerror(errno));
        throw std::runtime_error("Unable to open local media file: " + fname);
    }
    raw_buffer buf((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
    if (buf.size() != file_size) {
        SRV_ERR("Failed to read entire file %s", fname.c_str());
        throw std::runtime_error("Failed to read entire image file");
    }
    out_files.push_back(std::move(buf));

} else {
// try to decode base64 image
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
Expand Down