From b33c4474f8b56f157ddb056b2b3cd8b43f507e6b Mon Sep 17 00:00:00 2001
From: Vaijanath Rao
Date: Tue, 11 Feb 2025 21:46:25 -0800
Subject: [PATCH 01/51] updating code to match llamacpp tag b4689

---
 CMakeLists.txt                                |    3 +-
 src/main/cpp/jllama.cpp                       |  303 +-
 src/main/cpp/jllama.h                         |   23 +-
 src/main/cpp/server.hpp                       | 4653 ++++++++++-------
 src/main/cpp/utils.hpp                        | 1138 ++--
 .../java/de/kherud/llama/CliParameters.java   |   40 +
 .../de/kherud/llama/InferenceParameters.java  |    6 -
 src/main/java/de/kherud/llama/LlamaModel.java |   20 +-
 .../java/de/kherud/llama/ModelParameters.java | 1495 ++++--
 .../java/de/kherud/llama/args/CacheType.java  |   15 +
 .../de/kherud/llama/args/NumaStrategy.java    |    4 +-
 .../de/kherud/llama/args/PoolingType.java     |   19 +-
 .../de/kherud/llama/args/RopeScalingType.java |   19 +-
 .../java/de/kherud/llama/args/Sampler.java    |   16 +-
 .../java/de/kherud/llama/LlamaModelTest.java  |   18 +-
 src/test/java/examples/GrammarExample.java    |    2 +-
 src/test/java/examples/InfillExample.java     |    4 +-
 src/test/java/examples/MainExample.java       |    4 +-
 18 files changed, 4621 insertions(+), 3161 deletions(-)
 create mode 100644 src/main/java/de/kherud/llama/CliParameters.java
 create mode 100644 src/main/java/de/kherud/llama/args/CacheType.java

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 847465e6..1b5f08f3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -20,10 +20,11 @@ FetchContent_MakeAvailable(json)
 
 #################### llama.cpp ####################
 
+set(LLAMA_BUILD_COMMON ON)
 FetchContent_Declare(
 	llama.cpp
 	GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
-	GIT_TAG        b3534
+	GIT_TAG        b4689
 )
 FetchContent_MakeAvailable(llama.cpp)
 
diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp
index d59f3b77..29568727 100644
--- a/src/main/cpp/jllama.cpp
+++ b/src/main/cpp/jllama.cpp
@@ -1,10 +1,13 @@
 #include "jllama.h"
 
+#include "arg.h"
 #include "llama.h"
+#include "log.h"
 #include "nlohmann/json.hpp"
 #include "server.hpp"
 
 #include
+#include
 #include
 
 // We store some references to Java classes and their fields/methods here to speed up things for later and to fail
@@ -93,6 +96,38 @@ std::string parse_jstring(JNIEnv *env, jstring java_string)
     return string;
 }
 
+char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const jsize length)
+{
+    auto *const result = static_cast<char **>(malloc(length * sizeof(char *)));
+
+    if (result == nullptr)
+    {
+        return nullptr;
+    }
+
+    for (jsize i = 0; i < length; i++)
+    {
+        auto *const javaString = static_cast<jstring>(env->GetObjectArrayElement(string_array, i));
+        const char *cString = env->GetStringUTFChars(javaString, nullptr);
+        result[i] = strdup(cString);
+        env->ReleaseStringUTFChars(javaString, cString);
+    }
+
+    return result;
+}
+
+void free_string_array(char **array, jsize length)
+{
+    if (array != nullptr)
+    {
+        for (jsize i = 0; i < length; i++)
+        {
+            free(array[i]);
+        }
+        free(array);
+    }
+}
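+
+// Editor's note (illustrative sketch, not part of this patch's API surface): the two
+// helpers above are intended to be used as a pair around common_params_parse, e.g.
+//
+//     const jsize argc = env->GetArrayLength(jparams);
+//     char **argv = parse_string_array(env, jparams, argc); // nullptr on allocation failure
+//     if (argv != nullptr) {
+//         // ... consume argv ...
+//         free_string_array(argv, argc); // frees each strdup'ed entry, then the array itself
+//     }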
+
 /**
  * Since Java expects utf16 but std::strings are utf8, we can't directly use `env->NewString` or `env->NewStringUTF`,
  * but we directly send the bytes and do the conversion in Java. Unfortunately, there isn't a nice/standardized way to
@@ -138,6 +173,9 @@ JNIEnv *get_jni_env()
     return env;
 }
 
+bool log_json;
+std::function<void(ggml_log_level, const char *, void *)> log_callback;
+
 /**
  * Invoke the log callback if there is any.
  */
@@ -150,9 +188,6 @@ void log_callback_trampoline(ggml_log_level level, const char *text, void *user_
 }
 } // namespace
 
-bool log_json;
-std::function<void(ggml_log_level, const char *, void *)> log_callback;
-
 /**
  * The VM calls JNI_OnLoad when the native library is loaded (for example, through `System.loadLibrary`).
 * `JNI_OnLoad` must return the JNI version needed by the native library.
@@ -352,55 +387,52 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved)
     llama_backend_free();
 }
 
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jstring jparams)
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jobjectArray jparams)
 {
-    gpt_params params;
-
-    auto *ctx_server = new server_context();
+    common_params params;
 
-    std::string c_params = parse_jstring(env, jparams);
-    json json_params = json::parse(c_params);
-    server_params_parse(json_params, params);
-
-    if (json_value(json_params, "disable_log", false))
+    const jsize argc = env->GetArrayLength(jparams);
+    char **argv = parse_string_array(env, jparams, argc);
+    if (argv == nullptr)
     {
-        log_disable();
-    }
-    else
-    {
-        log_enable();
+        return;
     }
 
-    if (!params.system_prompt.empty())
+    const auto parsed_params = common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER);
+    free_string_array(argv, argc);
+    if (!parsed_params)
     {
-        ctx_server->system_prompt_set(params.system_prompt);
+        return;
    }
+
+    SRV_INF("loading model '%s'\n", params.model.c_str());
 
-    if (params.model_alias == "unknown")
-    {
-        params.model_alias = params.model;
-    }
+    common_init();
 
-    llama_numa_init(params.numa);
+    // struct that contains llama context and inference
+    auto *ctx_server = new server_context();
 
-    LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, {"commit", LLAMA_COMMIT}});
+    llama_backend_init();
+    llama_numa_init(params.numa);
 
-    LOG_INFO("system info", {
-                                {"n_threads", params.n_threads},
-                                {"n_threads_batch", params.n_threads_batch},
-                                {"total_threads", std::thread::hardware_concurrency()},
-                                {"system_info", llama_print_system_info()},
-                            });
+    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads,
+            params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
+    LOG_INF("\n");
+    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+    LOG_INF("\n");
 
     std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
 
     // Necessary similarity of prompt for slot selection
     ctx_server->slot_prompt_similarity = params.slot_prompt_similarity;
 
+    LOG_INF("%s: loading model\n", __func__);
+
     // load the model
     if (!ctx_server->load_model(params))
     {
-        state.store(SERVER_STATE_ERROR);
+        llama_backend_free();
         env->ThrowNew(c_llama_error, "could not load model from given file path");
         return;
     }
@@ -408,51 +440,30 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
     ctx_server->init();
     state.store(SERVER_STATE_READY);
 
-    LOG_INFO("model loaded", {});
+    LOG_INF("%s: model loaded\n", __func__);
 
     const auto model_meta = ctx_server->model_meta();
 
     // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
     if (params.chat_template.empty())
     {
-        if (!ctx_server->validate_model_chat_template())
-        {
-            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. 
This " - "may cause the model to output suboptimal responses", - {}); - params.chat_template = "chatml"; - } - } - - // if a custom chat template is not supplied, we will use the one that comes with the model (if any) - if (params.chat_template.empty()) - { - if (!ctx_server->validate_model_chat_template()) + if (!ctx_server->validate_builtin_chat_template(params.use_jinja)) { - LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This " - "may cause the model to output suboptimal responses", - {}); + LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. " + "This may cause the model to output suboptimal responses\n", + __func__); params.chat_template = "chatml"; } } // print sample chat example to make it clear which template is used - { - LOG_INFO("chat template", - { - {"chat_example", llama_chat_format_example(ctx_server->model, params.chat_template)}, - {"built_in", params.chat_template.empty()}, - }); - } + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(), + common_chat_format_example(*ctx_server->chat_templates.template_default, ctx_server->params_base.use_jinja) .c_str()); ctx_server->queue_tasks.on_new_task( std::bind(&server_context::process_single_task, ctx_server, std::placeholders::_1)); - ctx_server->queue_tasks.on_finish_multitask( - std::bind(&server_context::on_finish_multitask, ctx_server, std::placeholders::_1)); ctx_server->queue_tasks.on_update_slots(std::bind(&server_context::update_slots, ctx_server)); - ctx_server->queue_results.on_multitask_update(std::bind(&server_queue::update_multitask, &ctx_server->queue_tasks, - std::placeholders::_1, std::placeholders::_2, - std::placeholders::_3)); std::thread t([ctx_server]() { JNIEnv *env; @@ -478,22 +489,63 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) std::string c_params = parse_jstring(env, jparams); - json json_params = json::parse(c_params); - const bool infill = json_params.contains("input_prefix") || json_params.contains("input_suffix"); + json data = json::parse(c_params); + + server_task_type type = SERVER_TASK_TYPE_COMPLETION; - if (json_params.value("use_chat_template", false)) + if (data.contains("input_prefix") || data.contains("input_suffix")) { - json chat; - chat.push_back({{"role", "system"}, {"content", ctx_server->system_prompt}}); - chat.push_back({{"role", "user"}, {"content", json_params["prompt"]}}); - json_params["prompt"] = format_chat(ctx_server->model, ctx_server->params.chat_template, chat); + type = SERVER_TASK_TYPE_INFILL; } - const int id_task = ctx_server->queue_tasks.get_new_id(); - ctx_server->queue_results.add_waiting_task_id(id_task); - ctx_server->request_completion(id_task, -1, json_params, infill, false); + auto completion_id = gen_chatcmplid(); + std::vector tasks; + + try + { + const auto & prompt = data.at("prompt"); + + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true); + + tasks.reserve(tokenized_prompts.size()); + for (size_t i = 0; i < tokenized_prompts.size(); i++) + { + server_task task = server_task(type); + + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(tokenized_prompts[i]); + task.params = server_task::params_from_json_cmpl(ctx_server->ctx, 
ctx_server->params_base, data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = OAICOMPAT_TYPE_NONE; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl - return id_task; + tasks.push_back(task); + } + } + catch (const std::exception &e) + { + const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST); + env->ThrowNew(c_llama_error, err.dump().c_str()); + return 0; + } + + ctx_server->queue_results.add_waiting_tasks(tasks); + ctx_server->queue_tasks.post(tasks); + + const auto task_ids = server_task::get_list_id(tasks); + + if (task_ids.size() != 1) + { + env->ThrowNew(c_llama_error, "multitasking currently not supported"); + return 0; + } + + return *task_ids.begin(); } JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *env, jobject obj, jint id_task) @@ -501,26 +553,26 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - server_task_result result = ctx_server->queue_results.recv(id_task); + server_task_result_ptr result = ctx_server->queue_results.recv(id_task); - if (result.error) + if (result->is_error()) { - std::string response = result.data["message"].get(); + std::string response = result->to_json()["message"].get(); ctx_server->queue_results.remove_waiting_task_id(id_task); env->ThrowNew(c_llama_error, response.c_str()); return nullptr; } - - std::string response = result.data["content"].get(); - if (result.stop) + const auto out_res = result->to_json(); + std::string response = out_res["content"].get(); + if (result->is_stop()) { ctx_server->queue_results.remove_waiting_task_id(id_task); } jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); - if (result.data.contains("completion_probabilities")) + if (out_res.contains("completion_probabilities")) { - auto completion_probabilities = result.data["completion_probabilities"]; + auto completion_probabilities = out_res["completion_probabilities"]; for (const auto &entry : completion_probabilities) { auto probs = entry["probs"]; @@ -537,8 +589,10 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE } } + ctx_server->queue_results.remove_waiting_task_id(id_task); + jbyteArray jbytes = parse_jbytes(env, response); - return env->NewObject(c_output, cc_output, jbytes, o_probabilities, result.stop); + return env->NewObject(c_output, cc_output, jbytes, o_probabilities, result->is_stop()); } JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) @@ -546,41 +600,88 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - if (!ctx_server->params.embedding) + if (!ctx_server->params_base.embedding) { env->ThrowNew(c_llama_error, "model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))"); return nullptr; } + + const std::string prompt = parse_jstring(env, jprompt); + + SRV_INF("Calling embedding '%s'\n", prompt.c_str()); + + const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true); + std::vector tasks; + + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = 
ctx_server->queue_tasks.get_new_id();
+    task.index = 0;
+    task.prompt_tokens = std::move(tokens);
 
-    const int id_task = ctx_server->queue_tasks.get_new_id();
-    ctx_server->queue_results.add_waiting_task_id(id_task);
-    ctx_server->request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
+    // OAI-compat
+    task.params.oaicompat = OAICOMPAT_TYPE_NONE;
 
-    server_task_result result = ctx_server->queue_results.recv(id_task);
+    tasks.push_back(task);
+
+    ctx_server->queue_results.add_waiting_tasks(tasks);
+    ctx_server->queue_tasks.post(tasks);
+
+    std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+    const auto id_task = *task_ids.begin();
+
+    server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
     ctx_server->queue_results.remove_waiting_task_id(id_task);
-    if (result.error)
+
+    if (result->is_error())
     {
-        std::string response = result.data["message"].get<std::string>();
+        std::string response = result->to_json()["message"].get<std::string>();
         env->ThrowNew(c_llama_error, response.c_str());
         return nullptr;
     }
 
-    std::vector<float> embedding = result.data["embedding"].get<std::vector<float>>();
-    jsize embedding_size = embedding.size(); // NOLINT(*-narrowing-conversions)
+    const auto out_res = result->to_json();
 
-    jfloatArray j_embedding = env->NewFloatArray(embedding_size);
-    if (j_embedding == nullptr)
-    {
-        env->ThrowNew(c_error_oom, "could not allocate embedding");
-        return nullptr;
-    }
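+    // Editor's note (assumption, inferred from the extraction below): a single
+    // embedding task is expected to report a 2D "embedding" field, one row per
+    // sequence, e.g.
+    //
+    //     {"embedding": [[0.12, -0.03, ...]], ...}
+    //
+    // so only row 0 is copied back to Java as a float[].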
+    // Extract "embedding" as a vector of vectors (2D array)
+    std::vector<std::vector<float>> embedding = out_res["embedding"].get<std::vector<std::vector<float>>>();
+
+    // Get total number of rows in the embedding
+    jsize embedding_rows = embedding.size();
+
+    // Get total number of columns in the first row (assuming all rows are of equal length)
+    jsize embedding_cols = embedding_rows > 0 ? embedding[0].size() : 0;
 
-    env->SetFloatArrayRegion(j_embedding, 0, embedding_size, reinterpret_cast<const jfloat *>(embedding.data()));
+    SRV_INF("Embedding has %d rows and %d columns\n", embedding_rows, embedding_cols);
 
-    return j_embedding;
+    // Ensure embedding is not empty
+    if (embedding.empty() || embedding[0].empty()) {
+        env->ThrowNew(c_error_oom, "embedding array is empty");
+        return nullptr;
+    }
+
+    // Extract only the first row
+    const std::vector<float> &first_row = embedding[0]; // Reference to avoid copying
+
+    // Create a new float array in JNI
+    jfloatArray j_embedding = env->NewFloatArray(embedding_cols);
+    if (j_embedding == nullptr) {
+        env->ThrowNew(c_error_oom, "could not allocate embedding");
+        return nullptr;
+    }
+
+    // Copy the first row into the JNI float array
+    env->SetFloatArrayRegion(j_embedding, 0, embedding_cols, reinterpret_cast<const jfloat *>(first_row.data()));
+
+    return j_embedding;
 }
 
 JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt)
@@ -589,7 +690,8 @@ JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env,
     auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
 
     const std::string c_prompt = parse_jstring(env, jprompt);
-    std::vector<llama_token> tokens = ctx_server->tokenize(c_prompt, false);
+
+    llama_tokens tokens = tokenize_mixed(ctx_server->vocab, c_prompt, false, true);
 
     jsize token_size = tokens.size(); // NOLINT(*-narrowing-conversions)
     jintArray java_tokens = env->NewIntArray(token_size);
@@ -632,7 +734,8 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *
 {
     jlong server_handle = env->GetLongField(obj, f_model_pointer);
     auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
-    ctx_server->request_cancel(id_task);
+    std::unordered_set<int> id_tasks = {id_task};
+    ctx_server->cancel_tasks(id_tasks);
     ctx_server->queue_results.remove_waiting_task_id(id_task);
 }
 
diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h
index 2fd0529e..0ab39ea4 100644
--- a/src/main/cpp/jllama.h
+++ b/src/main/cpp/jllama.h
@@ -7,6 +7,25 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+
+/*
+ * Class:     de_kherud_llama_LlamaModel
+ * Method:    requestEmbedding
+ * Signature: (Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestEmbedding
+  (JNIEnv *, jobject, jstring);
+
+/*
+ * Class:     de_kherud_llama_LlamaModel
+ * Method:    receiveEmbedding
+ * Signature: (I)[F
+ */
+JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_receiveEmbedding
+  (JNIEnv *, jobject, jint);
+
 /*
  * Class:     de_kherud_llama_LlamaModel
  * Method:    embed
@@ -66,10 +85,10 @@ JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes
 /*
  * Class:     de_kherud_llama_LlamaModel
  * Method:    loadModel
- * Signature: (Ljava/lang/String;)V
+ * Signature: ([Ljava/lang/String;)V
  */
 JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel
-  (JNIEnv *, jobject, jstring);
+  (JNIEnv *, jobject, jobjectArray);
 
 /*
  * Class:     de_kherud_llama_LlamaModel
 
diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp
index 0601dac4..70e7236d 100644
--- a/src/main/cpp/server.hpp
+++ b/src/main/cpp/server.hpp
@@ -1,8 +1,11 @@
 #include "utils.hpp"
 
 #include "common.h"
-#include "grammar-parser.h"
+#include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "log.h"
+#include "sampling.h"
+#include "speculative.h"
 
 #include "nlohmann/json.hpp"
 
@@ -10,161 +13,1257 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
-#include
 #include
 #include
+#include
+#include
 
 using json = nlohmann::ordered_json;
 
-enum stop_type
-{
-    STOP_TYPE_FULL,
-    STOP_TYPE_PARTIAL,
-};
+constexpr int HTTP_POLLING_SECONDS = 1;
 
-enum slot_state
-{
-    SLOT_STATE_IDLE,
-    SLOT_STATE_PROCESSING,
+enum stop_type {
+    STOP_TYPE_NONE,
+    STOP_TYPE_EOS,
+    STOP_TYPE_WORD,
+    STOP_TYPE_LIMIT,
 };
 
-enum slot_command
-{
-    SLOT_COMMAND_NONE,
-    SLOT_COMMAND_LOAD_PROMPT,
-    SLOT_COMMAND_RELEASE,
+// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
+enum slot_state {
+    SLOT_STATE_IDLE,
+    SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
+    SLOT_STATE_PROCESSING_PROMPT,
+    SLOT_STATE_DONE_PROMPT,
+    SLOT_STATE_GENERATING,
 };
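+
+// Editor's note (sketch of the intended lifecycle, per the state diagram linked
+// above): a slot normally moves
+//   IDLE -> STARTED -> PROCESSING_PROMPT -> DONE_PROMPT -> GENERATING -> IDLE
+// and release() returns a processing slot to SLOT_STATE_IDLE at any point.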
 
-enum server_state
-{
-    SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
-    SERVER_STATE_READY,         // Server is ready and model is loaded
-    SERVER_STATE_ERROR          // An error occurred, load_model failed
+enum server_state {
+    SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
+    SERVER_STATE_READY,         // Server is ready and model is loaded
 };
 
-enum server_task_type
-{
+enum server_task_type {
     SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_EMBEDDING,
+    SERVER_TASK_TYPE_RERANK,
+    SERVER_TASK_TYPE_INFILL,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
+    SERVER_TASK_TYPE_SET_LORA,
 };
 
-struct server_task
-{
-    int id = -1; // to be filled by server_queue
-    int id_multi = -1;
-    int id_target = -1;
+enum oaicompat_type {
+    OAICOMPAT_TYPE_NONE,
+    OAICOMPAT_TYPE_CHAT,
+    OAICOMPAT_TYPE_COMPLETION,
+    OAICOMPAT_TYPE_EMBEDDING,
+};
+
+// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
+enum error_type {
+    ERROR_TYPE_INVALID_REQUEST,
+    ERROR_TYPE_AUTHENTICATION,
+    ERROR_TYPE_SERVER,
+    ERROR_TYPE_NOT_FOUND,
+    ERROR_TYPE_PERMISSION,
+    ERROR_TYPE_UNAVAILABLE,   // custom error
+    ERROR_TYPE_NOT_SUPPORTED, // custom error
+};
+
+struct slot_params {
+    bool stream = true;
+    bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
+    bool return_tokens = false;
+
+    int32_t n_keep = 0;    // number of tokens to keep from initial prompt
+    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t n_predict = -1; // new tokens to predict
+    int32_t n_indent = 0;   // minimum line indentation for the generated text in number of whitespace characters
+
+    int64_t t_max_prompt_ms = -1;  // TODO: implement
+    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
+
+    std::vector<common_adapter_lora_info> lora;
+
+    std::vector<std::string> antiprompt;
+    std::vector<std::string> response_fields;
+    bool timings_per_token = false;
+    bool post_sampling_probs = false;
+    bool ignore_eos = false;
+
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;
+
+    // OAI-compat fields
+    bool verbose = false;
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string oaicompat_model;
+    std::string oaicompat_cmpl_id;
+    common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+
+    json to_json() const {
+        std::vector<std::string> samplers;
+        samplers.reserve(sampling.samplers.size());
+        for (const auto & sampler : sampling.samplers) {
+            samplers.emplace_back(common_sampler_type_to_str(sampler));
+        }
+
+        json lora = json::array();
+        for (size_t i = 0; i < 
this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + std::vector grammar_trigger_words; + for (const auto & trigger : sampling.grammar_trigger_words) { + grammar_trigger_words.push_back(trigger.word); + } + + return json { + {"n_predict", n_predict}, // Server configured n_predict + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, // User configured n_predict + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"grammar_trigger_words", grammar_trigger_words}, + {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, + {"preserved_tokens", sampling.preserved_tokens}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; + } +}; + +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) server_task_type type; - json data; - bool infill = false; - bool embedding = false; + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + slot_params params; + llama_tokens prompt_tokens; + int id_selected_slot = -1; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task(server_task_type type) : type(type) {} + + static slot_params params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + slot_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) + slot_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + 
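+        // Editor's note (assumption about the json_value helper from utils.hpp):
+        // json_value(data, key, fallback) yields data[key] when the key is present
+        // and type-compatible, otherwise the fallback. Nesting it, as for
+        // "n_predict" below, lets "max_tokens" act as an OAI-style alias, e.g. a
+        // request body of {"max_tokens": 32} resolves to n_predict == 32.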
params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + 
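+        // Editor's note: the order of these clamps matters. n_min is first capped
+        // by n_max (above), then floored at 2, and n_max is floored at 0; e.g. a
+        // request with speculative.n_min = 8 and speculative.n_max = 4 ends up
+        // with n_min = 4 and n_max = 4.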
params.speculative.n_min = std::max(params.speculative.n_min, 2); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? 
"true" : "false"); + } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str()); + } else { + params.oaicompat_chat_format = defaults.oaicompat_chat_format; + } + } + + { + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + common_grammar_trigger trigger; + trigger.word = t.at("word"); + trigger.at_start = t.at("at_start"); + + auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str()); + params.sampling.grammar_trigger_tokens.push_back(ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + continue; + } + SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str()); + params.sampling.grammar_trigger_words.push_back(trigger); + } + } + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto & t : *preserved_tokens) { + auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + SRV_DBG("Preserved token: %d\n", ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + } else { + // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. + SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get().c_str()); + } + } + } + if (params.sampling.grammar_lazy) { + GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0); + } + } + + { + params.sampling.logit_bias.clear(); + params.ignore_eos = json_value(data, "ignore_eos", false); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + } else if (samplers->is_string()){ + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? 
DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; + } + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } }; -struct server_task_result -{ - int id = -1; - int id_multi = -1; +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + json to_json() const { + return { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + } +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return false; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const { + json probs_for_token = json::array(); + for (const auto & p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + }); + } + return probs_for_token; + } + + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ? "top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, + }); + } + return out; + } - json data; + static float logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? 
std::numeric_limits::lowest() : std::log(x); + } - bool stop; - bool error; + static std::vector str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; + } }; -struct server_task_multi -{ - int id = -1; +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + slot_params generation_params; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens {} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? 
"" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + std::string finish_reason = "length"; + common_chat_msg msg; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + SRV_DBG("Parsing chat message: %s\n", content.c_str()); + msg = common_chat_parse(content, oaicompat_chat_format); + finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; + } else { + msg.content = content; + } + + json tool_calls; + if (!msg.tool_calls.empty()) { + tool_calls = json::array(); + for (const auto & tc : msg.tool_calls) { + tool_calls.push_back({ + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + {"id", tc.id}, + }); + } + } + + json message { + {"content", msg.content}, + {"tool_calls", tool_calls}, + {"role", "assistant"}, + }; + if (!msg.tool_plan.empty()) { + message["tool_plan"] = msg.tool_plan; + } + + json choice { + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", message}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json { + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + + json choice = json { + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()} + }; + + json ret = json { + {"choices", json::array({choice})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }; - std::set subtasks_remaining; - std::vector results; + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return ret; + } }; -struct slot_params -{ - bool stream = true; - bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = 
- 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict + std::string content; + llama_tokens tokens; - std::vector antiprompt; + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + completion_token_output prob_output; + result_timings timings; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; - json input_prefix; - json input_suffix; + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + // non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + bool first = n_decoded == 0; + std::time_t t = std::time(0); + json choices; + + if (first) { + if (content.empty()) { + choices = json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}}}); + } else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{{"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + json second_ret = json{ + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json { + {"content", content}}} + }})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + return std::vector({initial_ret, second_ret}); + } + } else { + choices = json::array({json{ + {"finish_reason", nullptr}, + 
{"index", 0}, + {"delta", + json { + {"content", content}, + }}, + }}); + } + + GGML_ASSERT(choices.size() >= 1); + + if (prob_output.probs.size() > 0) { + choices[0]["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + json ret = json { + {"choices", choices}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"} + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return std::vector({ret}); + } }; -struct server_slot -{ +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? to_json_oaicompat() + : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; + } + + json to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +// this function maybe used outside of server_task_result_error +static json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; +} + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override { + return format_error_response(err_msg, err_type); + } +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + int32_t kv_cache_tokens_count; + int32_t kv_cache_used_cells; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // while 
we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); + + virtual json to_json() override { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "kv_cache_tokens_count", kv_cache_tokens_count }, + { "kv_cache_used_cells", kv_cache_used_cells }, + + { "slots", slots_data }, + }; + } +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } else { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; + } + } +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; + } +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override { + return json {{ "success", true }}; + } +}; + +struct server_slot { int id; int id_task = -1; - int id_multi = -1; + + // only used for completion/embedding/infill/rerank + server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; + + llama_batch batch_spec = {}; + + llama_context * ctx = nullptr; + llama_context * ctx_dft = nullptr; + + common_speculative * spec = nullptr; + + std::vector lora; + + // the index relative to completion multi-task request + size_t index = 0; struct slot_params params; slot_state state = SLOT_STATE_IDLE; - slot_command command = SLOT_COMMAND_NONE; // used to determine the slot that has been used the longest int64_t t_last_used = -1; // generation props - int32_t n_ctx = 0; // context size per slot - int32_t n_past = 0; - int32_t n_decoded = 0; + int32_t n_ctx = 0; // context size per slot + int32_t n_past = 0; + int32_t n_decoded = 0; int32_t n_remaining = -1; - int32_t i_batch = -1; - int32_t n_predict = -1; // TODO: disambiguate from params.n_predict + int32_t i_batch = -1; + int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - int32_t n_prompt_tokens = 0; + // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated + int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_processed = 0; - json prompt; + // input prompt tokens + llama_tokens prompt_tokens; - // when a task is submitted, we first tokenize the prompt and store it here - std::vector prompt_tokens; + size_t last_nl_pos = 0; + + std::string generated_text; + 
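+    // Editor's note (assumption based on usage in this file): generated_text
+    // accumulates the detokenized output, generated_tokens holds the raw token ids
+    // (returned when "return_tokens" is set), and cache_tokens mirrors the tokens
+    // currently held in the KV cache so a shared prefix can be reused when
+    // "cache_prompt" is enabled.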
llama_tokens generated_tokens; + + llama_tokens cache_tokens; - std::string generated_text; - std::vector cache_tokens; std::vector generated_token_probs; - bool infill = false; - bool embedding = false; bool has_next_token = true; - bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; - - bool oaicompat = false; + bool has_new_line = false; + bool truncated = false; + stop_type stop; - std::string oaicompat_model; std::string stopping_word; // sampling - llama_token sampled; - struct llama_sampling_params sparams; - llama_sampling_context *ctx_sampling = nullptr; json json_schema; - int32_t ga_i = 0; // group-attention state - int32_t ga_n = 1; // group-attention factor - int32_t ga_w = 512; // group-attention width + struct common_sampler * smpl = nullptr; + + llama_token sampled; - int32_t n_past_se = 0; // self-extend + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; // stats - size_t n_sent_text = 0; // number of sent text character - size_t n_sent_token_probs = 0; + size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; int64_t t_start_generation; @@ -172,115 +1271,113 @@ struct server_slot double t_prompt_processing; // ms double t_token_generation; // ms - void reset() - { - n_prompt_tokens = 0; - generated_text = ""; - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - n_past = 0; - n_sent_text = 0; - n_sent_token_probs = 0; - infill = false; - ga_i = 0; - n_past_se = 0; + std::function callback_on_release; + + void reset() { + SLT_DBG(*this, "%s", "\n"); + + n_prompt_tokens = 0; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + task_type = SERVER_TASK_TYPE_COMPLETION; + generated_tokens.clear(); generated_token_probs.clear(); } - bool has_budget(gpt_params &global_params) - { - if (params.n_predict == -1 && global_params.n_predict == -1) - { + bool is_non_causal() const { + return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; + } + + bool can_batch_with(server_slot & other_slot) { + return is_non_causal() == other_slot.is_non_causal() + && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params & global_params) { + if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } n_remaining = -1; - if (params.n_predict != -1) - { + if (params.n_predict != -1) { n_remaining = params.n_predict - n_decoded; - } - else if (global_params.n_predict != -1) - { + } else if (global_params.n_predict != -1) { n_remaining = global_params.n_predict - n_decoded; } return n_remaining > 0; // no budget } - bool available() const - { - return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE; + bool is_processing() const { + return state != SLOT_STATE_IDLE; } - bool is_processing() const - { - return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING; + bool can_speculate() const { + return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; } - void add_token_string(const completion_token_output &token) - { - if (command == SLOT_COMMAND_RELEASE) - { + void add_token(const completion_token_output & token) { + if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); return; } generated_token_probs.push_back(token); } - void release() - { 
- if (state == SLOT_STATE_PROCESSING) - { + void release() { + if (is_processing()) { + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); + + t_last_used = ggml_time_us(); t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; - command = SLOT_COMMAND_RELEASE; + state = SLOT_STATE_IDLE; + callback_on_release(id); } } - json get_formated_timings() const - { - return json{ - {"prompt_n", n_prompt_tokens_processed}, - {"prompt_ms", t_prompt_processing}, - {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed}, - - {"predicted_n", n_decoded}, - {"predicted_ms", t_token_generation}, - {"predicted_per_token_ms", t_token_generation / n_decoded}, - {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, - }; + result_timings get_timings() const { + result_timings timings; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + return timings; } - size_t find_stopping_strings(const std::string &text, const size_t last_token_size, const stop_type type) - { + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { size_t stop_pos = std::string::npos; - for (const std::string &word : params.antiprompt) - { + for (const std::string & word : params.antiprompt) { size_t pos; - if (type == STOP_TYPE_FULL) - { - const size_t tmp = word.size() + last_token_size; + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; const size_t from_pos = text.size() > tmp ? 
text.size() - tmp : 0; pos = text.find(word, from_pos); - } - else - { + } else { + // otherwise, partial stop pos = find_partial_stop_string(word, text); } - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) - { - if (type == STOP_TYPE_FULL) - { - stopped_word = true; - stopping_word = word; + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stop = STOP_TYPE_WORD; + stopping_word = word; has_next_token = false; } stop_pos = pos; @@ -290,181 +1387,191 @@ struct server_slot return stop_pos; } - void print_timings() const - { - char buffer[512]; - - double t_token = t_prompt_processing / n_prompt_tokens_processed; - double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - snprintf(buffer, 512, - "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", - t_prompt_processing, n_prompt_tokens_processed, t_token, n_tokens_second); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_prompt_processing", t_prompt_processing}, - {"n_prompt_tokens_processed", n_prompt_tokens_processed}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - t_token = t_token_generation / n_decoded; - n_tokens_second = 1e3 / t_token_generation * n_decoded; - - snprintf(buffer, 512, - "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", - t_token_generation, n_decoded, t_token, n_tokens_second); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_token_generation", t_token_generation}, - {"n_decoded", n_decoded}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_prompt_processing", t_prompt_processing}, - {"t_token_generation", t_token_generation}, - {"t_total", t_prompt_processing + t_token_generation}, - }); + void print_timings() const { + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; + + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_token_generation, n_decoded, t_gen, n_gen_second, + t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + } + + json to_json() const { + return json { + {"id", id}, + {"id_task", id_task}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + {"non_causal", is_non_causal()}, + {"params", params.to_json()}, + {"prompt", common_detokenize(ctx, prompt_tokens)}, + {"next_token", + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + {"stopping_word", stopping_word}, + } + }, + }; } }; -struct server_metrics -{ + +struct server_metrics { int64_t t_start = 0; uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t 
t_tokens_generation_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; + uint64_t t_prompt_processing = 0; - uint64_t n_tokens_predicted = 0; + uint64_t n_tokens_predicted = 0; uint64_t t_tokens_generation = 0; - void init() - { + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + void init() { t_start = ggml_time_us(); } - void on_prompt_eval(const server_slot &slot) - { + void on_prompt_eval(const server_slot & slot) { n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; } - void on_prediction(const server_slot &slot) - { - n_tokens_predicted_total += slot.n_decoded; - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; + void on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; } - void reset_bucket() - { + void on_decoded(const std::vector & slots) { + n_decode_total++; + for (const auto & slot : slots) { + if (slot.is_processing()) { + n_busy_slots_total++; + } + } + } + + void reset_bucket() { n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; } }; -struct server_queue -{ +struct server_queue { int id = 0; bool running; // queues - std::vector queue_tasks; - std::vector queue_tasks_deferred; - - std::vector queue_multitasks; + std::deque queue_tasks; + std::deque queue_tasks_deferred; std::mutex mutex_tasks; std::condition_variable condition_tasks; // callback functions - std::function callback_new_task; - std::function callback_finish_multitask; - std::function callback_update_slots; + std::function callback_new_task; + std::function callback_update_slots; // Add a new task to the end of the queue - int post(server_task task) - { + int post(server_task task, bool front = false) { std::unique_lock lock(mutex_tasks); - if (task.id == -1) - { - task.id = id++; - LOG_VERBOSE("new task id", {{"new_id", task.id}}); + GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d, front = %d\n", task.id, front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); } - queue_tasks.push_back(std::move(task)); condition_tasks.notify_one(); return task.id; } + // multi-task version of post() + int post(std::vector & tasks, bool front = false) { + std::unique_lock lock(mutex_tasks); + for (auto & task : tasks) { + if (task.id == -1) { + task.id = id++; + } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, 
(int) tasks.size(), front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + } + condition_tasks.notify_one(); + return 0; + } + // Add a new task, but defer until one slot is available - void defer(server_task task) - { + void defer(server_task task) { std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); queue_tasks_deferred.push_back(std::move(task)); + condition_tasks.notify_one(); } - // Get the next id for creating anew task - int get_new_id() - { + // Get the next id for creating a new task + int get_new_id() { std::unique_lock lock(mutex_tasks); int new_id = id++; - LOG_VERBOSE("new task id", {{"new_id", new_id}}); return new_id; } // Register function to process a new task - void on_new_task(std::function callback) - { + void on_new_task(std::function callback) { callback_new_task = std::move(callback); } - // Register function to process a multitask when it is finished - void on_finish_multitask(std::function callback) - { - callback_finish_multitask = std::move(callback); - } - // Register the function to be called when all slots data is ready to be processed - void on_update_slots(std::function callback) - { + void on_update_slots(std::function callback) { callback_update_slots = std::move(callback); } - // Call when the state of one slot is changed - void notify_slot_changed() - { - // move deferred tasks back to main loop + // Call when the state of one slot is changed, it will move one task from deferred to main queue + void pop_deferred_task() { std::unique_lock lock(mutex_tasks); - for (auto &task : queue_tasks_deferred) - { - queue_tasks.push_back(std::move(task)); + if (!queue_tasks_deferred.empty()) { + queue_tasks.emplace_back(std::move(queue_tasks_deferred.front())); + queue_tasks_deferred.pop_front(); } - queue_tasks_deferred.clear(); + condition_tasks.notify_one(); } // end the start_loop routine - void terminate() - { + void terminate() { std::unique_lock lock(mutex_tasks); running = false; condition_tasks.notify_all(); @@ -477,146 +1584,127 @@ struct server_queue * - Check if multitask is finished * - Update all slots */ - void start_loop() - { + void start_loop() { running = true; - while (true) - { - LOG_VERBOSE("new task may arrive", {}); + while (true) { + QUE_DBG("%s", "processing new tasks\n"); - while (true) - { + while (true) { std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) - { + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { lock.unlock(); break; } server_task task = queue_tasks.front(); - queue_tasks.erase(queue_tasks.begin()); + queue_tasks.pop_front(); lock.unlock(); - LOG_VERBOSE("callback_new_task", {{"id_task", task.id}}); - callback_new_task(task); - } - - LOG_VERBOSE("update_multitasks", {}); - // check if we have any finished multitasks - auto queue_iterator = queue_multitasks.begin(); - while (queue_iterator != queue_multitasks.end()) - { - if (queue_iterator->subtasks_remaining.empty()) - { - // all subtasks done == multitask is done - server_task_multi current_multitask = *queue_iterator; - callback_finish_multitask(current_multitask); - // remove this multitask - queue_iterator = queue_multitasks.erase(queue_iterator); - } - else - { - ++queue_iterator; - } + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); } // all tasks in the current loop is processed, slots data is now ready - LOG_VERBOSE("callback_update_slots", {}); + QUE_DBG("%s", "update 
slots\n"); callback_update_slots(); - LOG_VERBOSE("wait for new task", {}); + QUE_DBG("%s", "waiting for new tasks\n"); { std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) - { - if (!running) - { - LOG_VERBOSE("ending start_loop", {}); - return; - } - condition_tasks.wait(lock, [&] { return (!queue_tasks.empty() || !running); }); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&]{ + return (!queue_tasks.empty() || !running); + }); } } } } - // - // functions to manage multitasks - // - - // add a multitask by specifying the id of all subtask (subtask is a server_task) - void add_multitask(int id_multi, std::vector &sub_ids) - { - std::lock_guard lock(mutex_tasks); - server_task_multi multi; - multi.id = id_multi; - std::copy(sub_ids.begin(), sub_ids.end(), - std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); - queue_multitasks.push_back(multi); - } - - // updatethe remaining subtasks, while appending results to multitask - void update_multitask(int id_multi, int id_sub, server_task_result &result) - { - std::lock_guard lock(mutex_tasks); - for (auto &multitask : queue_multitasks) - { - if (multitask.id == id_multi) - { - multitask.subtasks_remaining.erase(id_sub); - multitask.results.push_back(result); - } - } +private: + void cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task & task) { + return task.id_target == id_target; + }; + queue_tasks.erase( + std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), + queue_tasks.end()); + queue_tasks_deferred.erase( + std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); } }; -struct server_response -{ - typedef std::function callback_multitask_t; - callback_multitask_t callback_update_multitask; - +struct server_response { // for keeping track of all tasks waiting for the result - std::set waiting_task_ids; + std::unordered_set waiting_task_ids; - // the main result queue - std::vector queue_results; + // the main result queue (using ptr for polymorphism) + std::vector queue_results; std::mutex mutex_results; std::condition_variable condition_results; // add the id_task to the list of tasks waiting for response - void add_waiting_task_id(int id_task) - { - LOG_VERBOSE("waiting for task id", {{"id_task", id_task}}); + void add_waiting_task_id(int id_task) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); std::unique_lock lock(mutex_results); waiting_task_ids.insert(id_task); } + void add_waiting_tasks(const std::vector & tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & task : tasks) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); + waiting_task_ids.insert(task.id); + } + } + // when the request is finished, we can remove task associated with it - void remove_waiting_task_id(int id_task) - { - LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}}); + void remove_waiting_task_id(int id_task) { + SRV_DBG("remove task %d from waiting list. 
current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); std::unique_lock lock(mutex_results); waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); } - // This function blocks the thread until there is a response for this id_task - server_task_result recv(int id_task) - { - while (true) - { + void remove_waiting_task_ids(const std::unordered_set & id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & id_task : id_tasks) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } + } + + // This function blocks the thread until there is a response for one of the id_tasks + server_task_result_ptr recv(const std::unordered_set & id_tasks) { + while (true) { std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&] { return !queue_results.empty(); }); + condition_results.wait(lock, [&]{ + return !queue_results.empty(); + }); - for (int i = 0; i < (int)queue_results.size(); i++) - { - if (queue_results[i].id == id_task) - { - assert(queue_results[i].id_multi == -1); - server_task_result res = queue_results[i]; + for (size_t i = 0; i < queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); queue_results.erase(queue_results.begin() + i); return res; } @@ -626,33 +1714,45 @@ struct server_response // should never reach here } - // Register the function to update multitask - void on_multitask_update(callback_multitask_t callback) - { - callback_update_multitask = std::move(callback); + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (cr_res == std::cv_status::timeout) { + return nullptr; + } + } + + // should never reach here + } + + // single-task version of recv() + server_task_result_ptr recv(int id_task) { + std::unordered_set id_tasks = {id_task}; + return recv(id_tasks); } // Send a new result to a waiting id_task - void send(server_task_result result) - { - LOG_VERBOSE("send new result", {{"id_task", result.id}}); + void send(server_task_result_ptr && result) { + SRV_DBG("sending result for task id = %d\n", result->id); std::unique_lock lock(mutex_results); - for (const auto &id_task : waiting_task_ids) - { - // LOG_TEE("waiting task id %i \n", id_task); - // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result - if (result.id_multi == id_task) - { - LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}}); - callback_update_multitask(id_task, result.id, result); - continue; - } + for (const auto & id_task : waiting_task_ids) { + if (result->id == id_task) { + SRV_DBG("task id = %d pushed to result queue\n", 
result->id); - if (result.id == id_task) - { - LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}}); - queue_results.push_back(result); + queue_results.emplace_back(std::move(result)); condition_results.notify_all(); return; } @@ -660,31 +1760,35 @@ struct server_response } }; -struct server_context -{ - llama_model *model = nullptr; - llama_context *ctx = nullptr; +struct server_context { + common_params params_base; - gpt_params params; + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; - llama_batch batch; + llama_model * model = nullptr; + llama_context * ctx = nullptr; - bool clean_kv_cache = true; - bool add_bos_token = true; + const llama_vocab * vocab = nullptr; - int32_t n_ctx; // total context for all clients / slots + llama_model * model_dft = nullptr; + + llama_context_params cparams_dft; + + llama_batch batch = {}; - // system prompt - bool system_need_update = false; + bool clean_kv_cache = true; + bool add_bos_token = true; + bool has_eos_token = false; - std::string system_prompt; - std::vector system_tokens; + int32_t n_ctx; // total context for all clients / slots // slots / clients std::vector slots; json default_generation_settings_for_props; - server_queue queue_tasks; + server_queue queue_tasks; server_response queue_results; server_metrics metrics; @@ -692,1392 +1796,1006 @@ struct server_context // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; - ~server_context() - { - if (ctx) - { - llama_free(ctx); - ctx = nullptr; - } - - if (model) - { - llama_free_model(model); - model = nullptr; - } + common_chat_templates chat_templates; + ~server_context() { // Clear any sampling context - for (server_slot &slot : slots) - { - if (slot.ctx_sampling != nullptr) - { - llama_sampling_free(slot.ctx_sampling); - } - } + for (server_slot & slot : slots) { + common_sampler_free(slot.smpl); + slot.smpl = nullptr; - llama_batch_free(batch); - } - - bool load_model(const gpt_params ¶ms_) - { - params = params_; + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; - // dedicate one sequence to the system prompt - params.n_parallel += 1; + common_speculative_free(slot.spec); + slot.spec = nullptr; - llama_init_result llama_init = llama_init_from_gpt_params(params); - - model = llama_init.model; - ctx = llama_init.context; - params.n_parallel -= 1; // but be sneaky about it - if (model == nullptr) - { - LOG_ERROR("unable to load model", {{"model", params.model}}); - return false; + llama_batch_free(slot.batch_spec); } - n_ctx = llama_n_ctx(ctx); - - add_bos_token = llama_should_add_bos_token(model); - GGML_ASSERT(llama_add_eos_token(model) != 1); - - return true; - } - - bool validate_model_chat_template() const - { - llama_chat_message chat[] = {{"user", "test"}}; - - const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); - - return res > 0; + llama_batch_free(batch); } - void init() - { - const int32_t n_ctx_slot = n_ctx / params.n_parallel; - - LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}}); - - for (int i = 0; i < params.n_parallel; i++) - { - server_slot slot; - - slot.id = i; - slot.n_ctx = n_ctx_slot; - slot.n_predict = params.n_predict; - - LOG_INFO("new slot", {{"id_slot", slot.id}, {"n_ctx_slot", slot.n_ctx}}); - - const int ga_n = params.grp_attn_n; - const int ga_w = params.grp_attn_w; - - if (ga_n != 1) - { - GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT - 
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT - // GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT - // GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT - - LOG_INFO("slot self-extend", {{"id_slot", slot.id}, {"ga_n", ga_n}, {"ga_w", ga_w}}); - } + bool load_model(const common_params & params) { + SRV_INF("loading model '%s'\n", params.model.c_str()); - slot.ga_i = 0; - slot.ga_n = ga_n; - slot.ga_w = ga_w; + params_base = params; - slot.sparams = params.sparams; + llama_init = common_init_from_params(params_base); - slot.reset(); + model = llama_init.model.get(); + ctx = llama_init.context.get(); - slots.push_back(slot); + if (model == nullptr) { + SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); + return false; } - default_generation_settings_for_props = get_formated_generation(slots.front()); - default_generation_settings_for_props["seed"] = -1; - - // the update_slots() logic will always submit a maximum of n_batch tokens - // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not - // used) - { - const int32_t n_batch = llama_n_batch(ctx); + vocab = llama_model_get_vocab(model); - // only a single seq_id per token is needed - batch = llama_batch_init(n_batch, 0, 1); - } + n_ctx = llama_n_ctx(ctx); - metrics.init(); - } + add_bos_token = llama_vocab_get_add_bos(vocab); + has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; - std::vector tokenize(const json &json_prompt, bool add_special) const - { - // TODO: currently, we tokenize using special tokens by default - // this is not always correct (see - // https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) but it's better compared to - // completely ignoring ChatML and other chat templates - const bool TMP_FORCE_SPECIAL = true; + if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - std::vector prompt_tokens; + auto params_dft = params_base; - if (json_prompt.is_array()) - { - bool first = true; - for (const auto &p : json_prompt) - { - if (p.is_string()) - { - auto s = p.template get(); + params_dft.devices = params_base.speculative.devices; + params_dft.hf_file = params_base.speculative.hf_file; + params_dft.hf_repo = params_base.speculative.hf_repo; + params_dft.model = params_base.speculative.model; + params_dft.model_url = params_base.speculative.model_url; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? 
params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; - std::vector p; - if (first) - { - p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); - first = false; - } - else - { - p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); - } + llama_init_dft = common_init_from_params(params_dft); - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } - else - { - if (first) - { - first = false; - } + model_dft = llama_init_dft.model.get(); - prompt_tokens.push_back(p.template get()); - } + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str()); + return false; } - } - else - { - auto s = json_prompt.template get(); - prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); - } - return prompt_tokens; - } + if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str()); - server_slot *get_slot_by_id(int id) - { - for (server_slot &slot : slots) - { - if (slot.id == id) - { - return &slot; + return false; } - } - - return nullptr; - } - - server_slot *get_available_slot(const std::string &prompt) - { - server_slot *ret = nullptr; - - // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) - { - int max_lcp_len = 0; - float similarity = 0; - - for (server_slot &slot : slots) - { - // skip the slot if it is not available - if (!slot.available()) - { - continue; - } - // skip the slot if it does not contains prompt - if (!slot.prompt.is_string()) - { - continue; - } - - // current slot's prompt - std::string slot_prompt = slot.prompt.get(); + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); - // length of the current slot's prompt - int slot_prompt_len = slot_prompt.size(); + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; - // length of the Longest Common Prefix between the current slot's prompt and the input prompt - int lcp_len = common_part(slot_prompt, prompt); - - // fraction of the common substring length compared to the current slot's prompt length - similarity = static_cast(lcp_len) / slot_prompt_len; - - // select the current slot if the criteria match - if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) - { - max_lcp_len = lcp_len; - ret = &slot; - } - } + // force F16 KV cache for the draft model for extra performance + cparams_dft.type_k = GGML_TYPE_F16; + cparams_dft.type_v = GGML_TYPE_F16; - if (ret != nullptr) - { - LOG_VERBOSE("selected slot by lcp similarity", { - {"id_slot", ret->id}, - {"max_lcp_len", max_lcp_len}, - {"similarity", similarity}, - }); - } + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); } - // find the slot that has been least recently used - if (ret == nullptr) - { - int64_t t_last = ggml_time_us(); - for (server_slot &slot : slots) - { - // skip the slot if it is not available - if (!slot.available()) - { - continue; - } - - // select the current slot if the criteria match - if (slot.t_last_used < t_last) - { - t_last = slot.t_last_used; - ret = &slot; - } - } - - if (ret != nullptr) - { - LOG_VERBOSE("selected slot by lru", { - {"id_slot", ret->id}, - {"t_last", t_last}, - }); - } + if 
(params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) { + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + chat_templates = common_chat_templates_from_model(model, "chatml"); + } else { + chat_templates = common_chat_templates_from_model(model, params_base.chat_template); } + GGML_ASSERT(chat_templates.template_default.get() != nullptr); - return ret; + return true; } - bool launch_slot_with_task(server_slot &slot, const server_task &task) - { - slot_params default_params; - // Sampling parameter defaults are loaded from the global server context (but individual requests can still - // override them) - llama_sampling_params default_sparams = params.sparams; - auto &data = task.data; - - slot.oaicompat = false; - slot.oaicompat_model = ""; - - slot.params.stream = json_value(data, "stream", false); - slot.params.cache_prompt = json_value(data, "cache_prompt", false); - slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); - slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); - slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); - slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); - slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); - slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.sparams.seed = json_value(data, "seed", default_sparams.seed); - slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); - slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - - if (slot.params.cache_prompt && slot.ga_n != 1) - { - LOG_WARNING("cache_prompt is not supported with group-attention", {}); - slot.params.cache_prompt = false; - } - - if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) - { - // Might be better to reject the request with a 400 ? 
- LOG_WARNING("Max tokens to predict exceeds server configuration", - { - {"params.n_predict", slot.params.n_predict}, - {"slot.n_predict", slot.n_predict}, - }); - slot.params.n_predict = slot.n_predict; - } - - // infill - slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix); - slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix); - - // get prompt - if (!task.infill) - { - const auto &prompt = data.find("prompt"); - if (prompt == data.end()) - { - send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - if ((prompt->is_string()) || (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) || - (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) - { - slot.prompt = *prompt; - } - else - { - send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - - // penalize user-provided tokens - { - slot.sparams.penalty_prompt_tokens.clear(); - slot.sparams.use_penalty_prompt_tokens = false; - - const auto &penalty_prompt = data.find("penalty_prompt"); - - if (penalty_prompt != data.end()) - { - if (penalty_prompt->is_string()) - { - const auto penalty_prompt_string = penalty_prompt->get(); - slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false); - - if (slot.params.n_predict > 0) - { - slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + - slot.params.n_predict); - } - slot.sparams.use_penalty_prompt_tokens = true; - - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); - } - else if (penalty_prompt->is_array()) - { - const auto n_tokens = penalty_prompt->size(); - slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict)); - - const int n_vocab = llama_n_vocab(model); - for (const auto &penalty_token : *penalty_prompt) - { - if (penalty_token.is_number_integer()) - { - const auto tok = penalty_token.get(); - if (tok >= 0 && tok < n_vocab) - { - slot.sparams.penalty_prompt_tokens.push_back(tok); - } - } - } - slot.sparams.use_penalty_prompt_tokens = true; + bool validate_builtin_chat_template(bool use_jinja) const { + llama_chat_message chat[] = {{"user", "test"}}; - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); + if (use_jinja) { + auto templates = common_chat_templates_from_model(model, ""); + common_chat_inputs inputs; + inputs.messages = json::array({{ + {"role", "user"}, + {"content", "test"}, + }}); + GGML_ASSERT(templates.template_default); + try { + common_chat_params_init(*templates.template_default, inputs); + if (templates.template_tool_use) { + common_chat_params_init(*templates.template_tool_use, inputs); } + return true; + } catch (const std::exception & e) { + SRV_ERR("failed to apply template: %s\n", e.what()); + return false; } + } else { + const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); + const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); + return chat_res > 0; } + } - { - slot.sparams.logit_bias.clear(); + void init() { + const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; - if (json_value(data, "ignore_eos", false)) - { - slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; - } + SRV_INF("initializing slots, n_slots = %d\n", 
params_base.n_parallel); - const auto &logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) - { - const int n_vocab = llama_n_vocab(model); - for (const auto &el : *logit_bias) - { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) - { - float bias; - if (el[1].is_number()) - { - bias = el[1].get(); - } - else if (el[1].is_boolean() && !el[1].get()) - { - bias = -INFINITY; - } - else - { - continue; - } + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot slot; - if (el[0].is_number_integer()) - { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) - { - slot.sparams.logit_bias[tok] = bias; - } - } - else if (el[0].is_string()) - { - auto toks = llama_tokenize(model, el[0].get(), false); - for (auto tok : toks) - { - slot.sparams.logit_bias[tok] = bias; - } - } - } - } - } - } + slot.id = i; + slot.ctx = ctx; + slot.n_ctx = n_ctx_slot; + slot.n_predict = params_base.n_predict; - { - slot.params.antiprompt.clear(); + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - const auto &stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) - { - for (const auto &word : *stop) - { - if (!word.empty()) - { - slot.params.antiprompt.push_back(word); - } + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; } - } - } - { - const auto &samplers_sequence = data.find("samplers"); - if (samplers_sequence != data.end() && samplers_sequence->is_array()) - { - std::vector sampler_names; - for (const auto &sampler_name : *samplers_sequence) - { - if (sampler_name.is_string()) - { - sampler_names.emplace_back(sampler_name); - } + slot.spec = common_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; } - slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false); - } - else - { - slot.sparams.samplers_sequence = default_sparams.samplers_sequence; } + + SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); + + slot.params.sampling = params_base.sampling; + + slot.callback_on_release = [this](int) { + queue_tasks.pop_deferred_task(); + }; + + slot.reset(); + + slots.push_back(slot); } + default_generation_settings_for_props = slots[0].to_json(); + + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens + // note that n_batch can be > n_ctx (e.g. 
for non-causal attention models such as BERT where the KV cache is not used) { - if (slot.ctx_sampling != nullptr) - { - llama_sampling_free(slot.ctx_sampling); - } - slot.ctx_sampling = llama_sampling_init(slot.sparams); - if (slot.ctx_sampling == nullptr) - { - // for now, the only error that may happen here is invalid grammar - send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); - return false; - } + const int32_t n_batch = llama_n_batch(ctx); + + // only a single seq_id per token is needed + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } - slot.command = SLOT_COMMAND_LOAD_PROMPT; - slot.prompt_tokens.clear(); + metrics.init(); + } - LOG_INFO("slot is processing task", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - }); + server_slot * get_slot_by_id(int id) { + for (server_slot & slot : slots) { + if (slot.id == id) { + return &slot; + } + } - return true; + return nullptr; } - void kv_cache_clear() - { - LOG_VERBOSE("clearing KV cache", {}); + server_slot * get_available_slot(const server_task & task) { + server_slot * ret = nullptr; - // clear the entire KV cache - llama_kv_cache_clear(ctx); - clean_kv_cache = false; - } + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + int lcs_len = 0; + float similarity = 0; - void system_prompt_update() - { - LOG_VERBOSE("system prompt update", { - {"system_prompt", system_prompt}, - }); + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } - kv_cache_clear(); - system_tokens.clear(); + // skip the slot if it does not contains cached tokens + if (slot.cache_tokens.empty()) { + continue; + } - if (!system_prompt.empty()) - { - system_tokens = ::llama_tokenize(ctx, system_prompt, true); + // length of the Longest Common Subsequence between the current slot's prompt and the input prompt + int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); - llama_batch_clear(batch); + // fraction of the common subsequence length compared to the current slot's prompt length + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); - for (int i = 0; i < (int)system_tokens.size(); ++i) - { - llama_batch_add(batch, system_tokens[i], i, {0}, false); + // select the current slot if the criteria match + if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { + lcs_len = cur_lcs_len; + similarity = cur_similarity; + ret = &slot; + } } - const int32_t n_batch = llama_n_batch(ctx); + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity); + } + } - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) - { - const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, - 0, - 0, // unused - }; - - if (llama_decode(ctx, batch_view) != 0) - { - LOG_ERROR("llama_decode() failed", {}); - return; + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = ggml_time_us(); + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // select the current slot if the criteria match + if (slot.t_last_used < t_last) { + t_last = slot.t_last_used; + ret = &slot; } } - // assign the system KV 
cache to all parallel sequences - for (int32_t i = 1; i <= params.n_parallel; ++i) - { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last); } } - system_need_update = false; + return ret; } - bool system_prompt_set(const std::string &sys_prompt) - { - system_prompt = sys_prompt; + bool launch_slot_with_task(server_slot & slot, const server_task & task) { + slot.reset(); + slot.id_task = task.id; + slot.index = task.index; + slot.task_type = task.type; + slot.params = std::move(task.params); + slot.prompt_tokens = std::move(task.prompt_tokens); + + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = task.params.lora; + } - LOG_VERBOSE("system prompt process", { - {"system_prompt", system_prompt}, - }); + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + + if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { + // Might be better to reject the request with a 400 ? + slot.params.n_predict = slot.n_predict; + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); + } + + if (slot.params.ignore_eos && has_eos_token) { + slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); + } - // release all slots - for (server_slot &slot : slots) { - slot.release(); + if (slot.smpl != nullptr) { + common_sampler_free(slot.smpl); + } + + slot.smpl = common_sampler_init(model, slot.params.sampling); + if (slot.smpl == nullptr) { + // for now, the only error that may happen here is invalid grammar + send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + return false; + } } - system_need_update = true; + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + } + + slot.state = SLOT_STATE_STARTED; + + SLT_INF(slot, "%s", "processing task\n"); + return true; } - bool process_token(completion_token_output &result, server_slot &slot) - { + void kv_cache_clear() { + SRV_DBG("%s", "clearing KV cache\n"); + + // clear the entire KV cache + llama_kv_cache_clear(ctx); + clean_kv_cache = false; + } + + bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special); + const std::string token_str = result.text_to_send; slot.sampled = result.tok; - // search stop word and delete it slot.generated_text += token_str; - slot.has_next_token = true; - - if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) - { - // we can change penalty_prompt_tokens because it is always created from scratch each request - slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); + if (slot.params.return_tokens) { + slot.generated_tokens.push_back(result.tok); } + slot.has_next_token = true; // check if there is incomplete UTF-8 character at the end - bool incomplete = false; - for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) - { - unsigned char c = slot.generated_text[slot.generated_text.size() - i]; - if ((c & 0xC0) == 0x80) - { - // continuation byte: 10xxxxxx - continue; - } - if ((c & 0xE0) == 0xC0) - { - // 2-byte character: 110xxxxx ... 
- incomplete = i < 2; - } - else if ((c & 0xF0) == 0xE0) - { - // 3-byte character: 1110xxxx ... - incomplete = i < 3; - } - else if ((c & 0xF8) == 0xF0) - { - // 4-byte character: 11110xxx ... - incomplete = i < 4; - } - // else 1-byte character or invalid byte - break; - } + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); - if (!incomplete) - { + // search stop word and delete it + if (!incomplete) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); - bool is_stop_full = false; + bool send_text = true; - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL); - if (stop_pos != std::string::npos) - { - is_stop_full = true; - slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { + slot.generated_text.erase( + slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } - else - { - is_stop_full = false; - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL); + } else if (slot.has_next_token) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; } // check if there is any token to predict - if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) - { + if (send_text) { // no send the stop word in the response result.text_to_send = slot.generated_text.substr(pos, std::string::npos); slot.n_sent_text += result.text_to_send.size(); // add the token to slot queue and cache + } else { + result.text_to_send = ""; } - slot.add_token_string(result); - if (slot.params.stream) - { + slot.add_token(result); + if (slot.params.stream) { send_partial_response(slot, result); } } - if (incomplete) - { + if (incomplete) { slot.has_next_token = true; } // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) - { - slot.stopped_limit = true; + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - LOG_VERBOSE("stopped by limit", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_decoded", slot.n_decoded}, - {"n_predict", slot.params.n_predict}, - }); + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); } - if (llama_token_is_eog(model, result.tok)) - { - slot.stopped_eos = true; + if (slot.has_new_line) { + // if we have already seen a new line, we stop after a certain time limit + if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + } + + // require that each new line has a whitespace prefix (i.e. 
indentation) of at least slot.params.n_indent + if (slot.params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } + } + + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + } + + // if context shift is disabled, we stop when it reaches the context limit + if (slot.n_past >= slot.n_ctx) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - LOG_VERBOSE("eos token found", {}); + SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); } - auto n_ctx_train = llama_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && - slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) - { - LOG_WARNING("n_predict is not set and self-context extend is disabled." - " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", - { - {"id_slot", slot.id}, - {"params.n_predict", slot.params.n_predict}, - {"slot.n_prompt_tokens", slot.n_prompt_tokens}, - {"slot.n_decoded", slot.n_decoded}, - {"slot.n_predict", slot.n_predict}, - {"n_slots", params.n_parallel}, - {"slot.n_ctx", slot.n_ctx}, - {"n_ctx", n_ctx}, - {"n_ctx_train", n_ctx_train}, - {"ga_n", slot.ga_n}, - }); - slot.truncated = true; - slot.stopped_limit = true; + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + + SLT_DBG(slot, "%s", "stopped by EOS\n"); + } + + const auto n_ctx_train = llama_model_n_ctx_train(model); + + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction + + SLT_WRN(slot, + "n_predict (%d) is set for infinite generation. 
" + "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", + slot.params.n_predict, n_ctx_train); } - LOG_VERBOSE("next token", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"token", result.tok}, - {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"n_decoded", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }); + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); return slot.has_next_token; // continue } - json get_formated_generation(const server_slot &slot) const - { - const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); - const bool ignore_eos = - eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { + size_t n_probs = slot.params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + if (post_sampling) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; + + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } - std::vector samplers_sequence; - samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); - for (const auto &sampler_type : slot.sparams.samplers_sequence) - { - samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type)); - } - - return json{{"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, - {"model", params.model_alias}, - {"seed", slot.sparams.seed}, - {"temperature", slot.sparams.temp}, - {"dynatemp_range", slot.sparams.dynatemp_range}, - {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, - {"top_k", slot.sparams.top_k}, - {"top_p", slot.sparams.top_p}, - {"min_p", slot.sparams.min_p}, - {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typical_p}, - {"repeat_last_n", slot.sparams.penalty_last_n}, - {"repeat_penalty", slot.sparams.penalty_repeat}, - {"presence_penalty", slot.sparams.penalty_present}, - {"frequency_penalty", slot.sparams.penalty_freq}, - {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, - {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, - {"mirostat", slot.sparams.mirostat}, - {"mirostat_tau", slot.sparams.mirostat_tau}, - {"mirostat_eta", slot.sparams.mirostat_eta}, - {"penalize_nl", slot.sparams.penalize_nl}, - {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict - {"n_keep", slot.params.n_keep}, - {"n_discard", slot.params.n_discard}, - {"ignore_eos", ignore_eos}, - {"stream", slot.params.stream}, - {"logit_bias", slot.sparams.logit_bias}, - {"n_probs", slot.sparams.n_probs}, - {"min_keep", slot.sparams.min_keep}, - {"grammar", slot.sparams.grammar}, - {"samplers", samplers_sequence}}; - } - - void send_error(const server_task &task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) - { - send_error(task.id, task.id_multi, error, type); + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < 
std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + common_token_to_piece(ctx, cur_p->data[i].id, special), + cur_p->data[i].p + }); + } + } else { + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); + + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { + result.probs.push_back({ + cur[i].id, + common_token_to_piece(ctx, cur[i].id, special), + cur[i].p + }); + } + } } - void send_error(const server_slot &slot, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) - { - send_error(slot.id_task, slot.id_multi, error, type); + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); } - void send_error(const int id_task, const int id_multi, const std::string &error, - const enum error_type type = ERROR_TYPE_SERVER) - { - LOG_ERROR("task error", { - {"id_multi", id_multi}, - {"id_task", id_task}, - {"error", error}, - }); + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(slot.id_task, error, type); + } - server_task_result res; - res.id = id_task; - res.id_multi = id_multi; - res.stop = false; - res.error = true; - res.data = format_error_response(error, type); + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); - queue_results.send(res); + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; + + queue_results.send(std::move(res)); } - void send_partial_response(server_slot &slot, completion_token_output tkn) - { - server_task_result res; - res.id = slot.id_task; - res.id_multi = slot.id_multi; - res.error = false; - res.stop = false; - res.data = json{{"content", tkn.text_to_send}, {"stop", false}, {"id_slot", slot.id}, {"multimodal", false}}; - - if (slot.sparams.n_probs > 0) - { - const std::vector to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false); - const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); - const size_t probs_stop_pos = - std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); + void send_partial_response(server_slot & slot, const completion_token_output & tkn) { + auto res = std::make_unique(); - std::vector probs_output; - if (probs_pos < probs_stop_pos) - { - probs_output = - std::vector(slot.generated_token_probs.begin() + probs_pos, - slot.generated_token_probs.begin() + probs_stop_pos); - } - slot.n_sent_token_probs = probs_stop_pos; + res->id = slot.id_task; + res->index = slot.index; + res->content = tkn.text_to_send; + res->tokens = { tkn.tok }; + + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->post_sampling_probs = slot.params.post_sampling_probs; - res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = 
slot.params.oaicompat_cmpl_id; + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs } - if (slot.oaicompat) - { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + res->timings = slot.get_timings(); } - queue_results.send(res); + queue_results.send(std::move(res)); } - void send_final_response(const server_slot &slot) - { - server_task_result res; - res.id = slot.id_task; - res.id_multi = slot.id_multi; - res.error = false; - res.stop = true; - res.data = json{{"content", !slot.params.stream ? slot.generated_text : ""}, - {"id_slot", slot.id}, - {"stop", true}, - {"model", params.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.n_prompt_tokens}, - {"generation_settings", get_formated_generation(slot)}, - {"prompt", slot.prompt}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()}}; - - if (slot.sparams.n_probs > 0) - { - std::vector probs; - if (!slot.params.stream && slot.stopped_word) - { - const std::vector stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false); + void send_final_response(server_slot & slot) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->id_slot = slot.id; + + res->index = slot.index; + res->content = std::move(slot.generated_text); + res->tokens = std::move(slot.generated_tokens); + res->timings = slot.get_timings(); + res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + res->response_fields = std::move(slot.params.response_fields); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_tokens_cached = slot.n_past; + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->oaicompat_chat_format = slot.params.oaicompat_chat_format; + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); - probs = std::vector(slot.generated_token_probs.begin(), - slot.generated_token_probs.end() - safe_offset); - } - else - { - probs = std::vector(slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - safe_offset); + } else { + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); } - - res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs); } - if (slot.oaicompat) - { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; - } + 
res->generation_params = slot.params; // copy the parameters - queue_results.send(res); + queue_results.send(std::move(res)); } - void send_embedding(const server_slot &slot, const llama_batch &batch) - { - server_task_result res; - res.id = slot.id_task; - res.id_multi = slot.id_multi; - res.error = false; - res.stop = true; + void send_embedding(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + res->oaicompat = slot.params.oaicompat; - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embd_res(n_embd, 0.0f); - for (int i = 0; i < batch.n_tokens; ++i) - { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) - { + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { continue; } - const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) - { + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { embd = llama_get_embeddings_ith(ctx, i); } - if (embd == NULL) - { - LOG_ERROR("failed to get embeddings", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}}); - - res.data = json{ - {"embedding", std::vector(n_embd, 0.0f)}, - }; + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + res->embedding.push_back(std::vector(n_embd, 0.0f)); continue; } - llama_embd_normalize(embd, embd_res.data(), n_embd); - - res.data = json{ - {"embedding", embd_res}, - }; + // normalize only when there is pooling + // TODO: configurable + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + res->embedding.push_back({ embd, embd + n_embd }); + } } - queue_results.send(res); + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); } - void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) - { - server_task task; - task.id = id_task; - task.id_multi = id_multi; - task.id_target = 0; - task.data = std::move(data); - task.infill = infill; - task.embedding = embedding; - task.type = SERVER_TASK_TYPE_COMPLETION; - - // when a completion task's prompt array is not a singleton, we split it into multiple requests - // otherwise, it's a single-prompt task, we actually queue it - // if there's numbers in the prompt array it will be treated as an array of tokens - if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) - { - bool numbers = false; - for (const auto &e : task.data.at("prompt")) - { - if (e.is_number()) - { - numbers = true; - break; - } - } + void send_rerank(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; - // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers, - // it will completely stall the server. I don't know where the bug for this is. - // - // if there are numbers, it needs to be treated like a single prompt, - // queue_tasks handles a mix of strings and numbers just fine. 
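[Editor's sketch] The send_* methods above illustrate the central refactor in this hunk: the old json-carrying server_task_result is replaced by a family of typed result structs that are heap-allocated and moved into the result queue. Below is a minimal sketch of that pattern with simplified stand-in fields; the real structs in server.hpp also carry timings, token probabilities, and OAI-compat metadata.

    #include <memory>
    #include <queue>
    #include <string>

    // Simplified stand-ins for the result types used by send_partial_response()
    // and send_final_response() above (illustrative only).
    struct server_task_result {
        int id = -1;
        virtual bool is_error() const { return false; }
        virtual bool is_stop()  const { return true; }
        virtual ~server_task_result() = default;
    };

    struct server_task_result_cmpl_partial : server_task_result {
        std::string content;
        bool is_stop() const override { return false; } // partial chunks keep the stream open
    };

    using server_task_result_ptr = std::unique_ptr<server_task_result>;

    int main() {
        std::queue<server_task_result_ptr> queue_results; // stand-in for the real thread-safe queue

        auto res = std::make_unique<server_task_result_cmpl_partial>();
        res->id      = 42;
        res->content = "partial text";

        queue_results.push(std::move(res)); // results are move-only; ownership passes to the queue
        return 0;
    }

The virtual is_stop()/is_error() pair is what lets the receive_* helpers further below treat partial, final, and error results uniformly.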
- if (numbers) - { - queue_tasks.post(task); + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; } - else - { - split_multiprompt_task(id_task, task); + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); } - } - else - { - queue_tasks.post(task); - } - } - void request_cancel(int id_task) - { - server_task task; - task.type = SERVER_TASK_TYPE_CANCEL; - task.id_target = id_task; + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - queue_tasks.post(task); - } + res->score = -1e6; + continue; + } - void split_multiprompt_task(int id_multi, const server_task &multiprompt_task) - { - const int prompt_count = multiprompt_task.data.at("prompt").size(); - if (prompt_count <= 1) - { - send_error(multiprompt_task, "error while handling multiple prompts"); - return; + res->score = embd[0]; } - // generate all the ID for subtask - std::vector subtask_ids(prompt_count); - for (int i = 0; i < prompt_count; i++) - { - subtask_ids[i] = queue_tasks.get_new_id(); - } + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); - // queue up the multitask so we can track its subtask progression - queue_tasks.add_multitask(id_multi, subtask_ids); + queue_results.send(std::move(res)); + } - // add subtasks - for (int i = 0; i < prompt_count; i++) - { - json subtask_data = multiprompt_task.data; - subtask_data["prompt"] = subtask_data.at("prompt")[i]; + // + // Functions to create new task(s) and receive result(s) + // + + void cancel_tasks(const std::unordered_set & id_tasks) { + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto & id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); - // subtasks inherit everything else (infill mode, embedding mode, etc.) 
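[Editor's sketch] cancel_tasks(), whose body continues below, builds one CANCEL task per in-flight request id and posts the whole batch to the front of the task queue so cancellation preempts pending work. A self-contained sketch of that flow, using a plain std::deque where the real code uses the mutex-guarded server queue; cancel_all() is a hypothetical name for illustration.

    #include <deque>
    #include <unordered_set>
    #include <vector>

    enum server_task_type { SERVER_TASK_TYPE_CANCEL };

    struct server_task {
        server_task_type type;
        int id_target = -1;
        explicit server_task(server_task_type t) : type(t) {}
    };

    // Post one CANCEL task per id, inserted at the front of the queue so it
    // has the highest priority.
    static void cancel_all(std::deque<server_task> & queue_tasks,
                           const std::unordered_set<int> & id_tasks) {
        std::vector<server_task> cancels;
        cancels.reserve(id_tasks.size());
        for (const int id_task : id_tasks) {
            server_task task(SERVER_TASK_TYPE_CANCEL);
            task.id_target = id_task;
            cancels.push_back(task);
        }
        queue_tasks.insert(queue_tasks.begin(), cancels.begin(), cancels.end());
    }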
- request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, - multiprompt_task.embedding); + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(task); } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(cancel_tasks, true); } - void process_single_task(const server_task &task) - { - switch (task.type) - { - case SERVER_TASK_TYPE_COMPLETION: { - const int id_slot = json_value(task.data, "id_slot", -1); - - server_slot *slot; - - if (id_slot != -1) - { - slot = get_slot_by_id(id_slot); + // receive the results from task(s) + void receive_multi_results( + const std::unordered_set & id_tasks, + const std::function&)> & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + std::vector results(id_tasks.size()); + for (int i = 0; i < (int)id_tasks.size(); i++) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; } - else - { - std::string prompt; - if (task.data.contains("prompt") && task.data.at("prompt").is_string()) - { - prompt = json_value(task.data, "prompt", std::string()); - } - slot = get_available_slot(prompt); + if (result == nullptr) { + i--; // retry + continue; } - if (slot == nullptr) - { - // if no slot is available, we defer this task for processing later - LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } - if (!slot->available()) - { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; } - if (task.data.contains("system_prompt")) - { - std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); - system_prompt_set(sys_prompt); + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + const size_t idx = result->get_index(); + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = std::move(result); + } + result_handler(results); + } - for (server_slot &slot : slots) - { - slot.n_past = 0; - slot.n_past_se = 0; - } + // receive the results from task(s), in stream mode + void receive_cmpl_results_stream( + const std::unordered_set & id_tasks, + const std::function & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + size_t n_finished = 0; + while (true) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; } - slot->reset(); + if (result == nullptr) { + continue; // retry + } - slot->id_task = task.id; - slot->id_multi = task.id_multi; - slot->infill = task.infill; - slot->embedding = task.embedding; + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } - if (!launch_slot_with_task(*slot, task)) - { - LOG_ERROR("error while launching slot", task.data); + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (!result_handler(result)) { + cancel_tasks(id_tasks); break; } - } - break; - case 
SERVER_TASK_TYPE_CANCEL: { - // release slot linked with the task id - for (auto &slot : slots) - { - if (slot.id_task == task.id_target) - { - slot.release(); + + if (result->is_stop()) { + if (++n_finished == id_tasks.size()) { break; } } } - break; - case SERVER_TASK_TYPE_NEXT_RESPONSE: { - // do nothing - } - break; - case SERVER_TASK_TYPE_METRICS: { - json slots_data = json::array(); + } - int n_idle_slots = 0; - int n_processing_slots = 0; + // + // Functions to process the task + // - for (server_slot &slot : slots) - { - json slot_data = get_formated_generation(slot); - slot_data["id"] = slot.id; - slot_data["id_task"] = slot.id_task; - slot_data["state"] = slot.state; - slot_data["prompt"] = slot.prompt; - slot_data["next_token"] = { - {"has_next_token", slot.has_next_token}, {"n_remain", slot.n_remaining}, - {"n_decoded", slot.n_decoded}, {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }; - - if (slot_data["state"] == SLOT_STATE_IDLE) - { - n_idle_slots++; - } - else + void process_single_task(server_task task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: { - n_processing_slots++; - } - - slots_data.push_back(slot_data); - } - LOG_INFO( - "slot data", - {{"id_task", task.id}, {"n_idle_slots", n_idle_slots}, {"n_processing_slots", n_processing_slots}}); - - LOG_VERBOSE("slot data", {{"id_task", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots}, - {"slots", slots_data}}); - - server_task_result res; - res.id = task.id; - res.id_multi = task.id_multi; - res.stop = true; - res.error = false; - res.data = { - {"idle", n_idle_slots}, - {"processing", n_processing_slots}, - {"deferred", queue_tasks.queue_tasks_deferred.size()}, - {"t_start", metrics.t_start}, - - {"n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total}, - {"t_tokens_generation_total", metrics.t_tokens_generation_total}, - {"n_tokens_predicted_total", metrics.n_tokens_predicted_total}, - {"t_prompt_processing_total", metrics.t_prompt_processing_total}, - - {"n_prompt_tokens_processed", metrics.n_prompt_tokens_processed}, - {"t_prompt_processing", metrics.t_prompt_processing}, - {"n_tokens_predicted", metrics.n_tokens_predicted}, - {"t_tokens_generation", metrics.t_tokens_generation}, - - {"kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)}, - {"kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)}, - - {"slots", slots_data}, - }; - - if (json_value(task.data, "reset_bucket", false)) - { - metrics.reset_bucket(); - } - queue_results.send(res); - } - break; - case SERVER_TASK_TYPE_SLOT_SAVE: { - int id_slot = task.data.at("id_slot"); - server_slot *slot = get_slot_by_id(id_slot); - if (slot == nullptr) - { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (!slot->available()) - { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } + const int id_slot = task.id_selected_slot; - const size_t token_count = slot->cache_tokens.size(); - const int64_t t_start = ggml_time_us(); + server_slot * slot = id_slot != -1 ? 
get_slot_by_id(id_slot) : get_available_slot(task); - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } - const size_t nwrite = - llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count); + if (!launch_slot_with_task(*slot, task)) { + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.id_task == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); - const int64_t t_end = ggml_time_us(); - const double t_save_ms = (t_end - t_start) / 1000.0; + int n_idle_slots = 0; + int n_processing_slots = 0; - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{"id_slot", id_slot}, - {"filename", filename}, - {"n_saved", token_count}, // tokens saved - {"n_written", nwrite}, // bytes written - {"timings", {{"save_ms", t_save_ms}}}}; - queue_results.send(result); - } - break; - case SERVER_TASK_TYPE_SLOT_RESTORE: { - int id_slot = task.data.at("id_slot"); - server_slot *slot = get_slot_by_id(id_slot); - if (slot == nullptr) - { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (!slot->available()) - { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } + for (server_slot & slot : slots) { + json slot_data = slot.to_json(); - const int64_t t_start = ggml_time_us(); + if (slot.is_processing()) { + n_processing_slots++; + } else { + n_idle_slots++; + } - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + slots_data.push_back(slot_data); + } + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); + + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->t_start = metrics.t_start; + + res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx); + res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx); + + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; + + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; + + res->n_decode_total 
= metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; + + if (task.metrics_reset_bucket) { + metrics.reset_bucket(); + } + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } - slot->cache_tokens.resize(slot->n_ctx); - size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), - slot->cache_tokens.size(), &token_count); - if (nread == 0) - { - slot->cache_tokens.resize(0); - send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", - ERROR_TYPE_INVALID_REQUEST); - break; - } - slot->cache_tokens.resize(token_count); - - const int64_t t_end = ggml_time_us(); - const double t_restore_ms = (t_end - t_start) / 1000.0; - - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{"id_slot", id_slot}, - {"filename", filename}, - {"n_restored", token_count}, // tokens restored - {"n_read", nread}, // bytes read - {"timings", {{"restore_ms", t_restore_ms}}}}; - queue_results.send(result); - } - break; - case SERVER_TASK_TYPE_SLOT_ERASE: { - int id_slot = task.data.at("id_slot"); - server_slot *slot = get_slot_by_id(id_slot); - if (slot == nullptr) - { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (!slot->available()) - { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); - // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); - slot->cache_tokens.clear(); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{"id_slot", id_slot}, {"n_erased", n_erased}}; - queue_results.send(result); - } - break; - } - } + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); - void on_finish_multitask(const server_task_multi &multitask) - { - // all subtasks done == multitask is done - server_task_result result; - result.id = multitask.id; - result.stop = true; - result.error = false; - - // collect json results into one json result - std::vector result_jsons; - for (const auto &subres : multitask.results) - { - result_jsons.push_back(subres.data); - result.error = result.error && subres.error; - } - result.data = json{{"results", result_jsons}}; + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; - queue_results.send(result); - } + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + 
queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } - void update_slots() - { - if (system_need_update) - { - system_prompt_update(); - } + const int64_t t_start = ggml_time_us(); - // release slots - for (auto &slot : slots) - { - if (slot.command == SLOT_COMMAND_RELEASE) - { - slot.state = SLOT_STATE_IDLE; - slot.command = SLOT_COMMAND_NONE; - slot.t_last_used = ggml_time_us(); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; - LOG_INFO("slot released", {{"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}, - {"truncated", slot.truncated}}); + slot->cache_tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.resize(0); + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + slot->cache_tokens.resize(token_count); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } - queue_tasks.notify_slot_changed(); - } + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); + slot->cache_tokens.clear(); + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SET_LORA: + { + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); + } break; } + } + void update_slots() { // check if all slots are idle { bool all_idle = true; - for (auto &slot : slots) - { - if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) - { + for (auto & slot : slots) { + if (slot.is_processing()) { all_idle = false; break; } } - if (all_idle) - { - LOG_INFO("all slots are idle", {}); - if (system_prompt.empty() && clean_kv_cache) - { + if (all_idle) { + SRV_INF("%s", "all slots are idle\n"); + if (clean_kv_cache) { kv_cache_clear(); } @@ -2086,494 
+2804,358 @@ struct server_context } { - LOG_VERBOSE("posting NEXT_RESPONSE", {}); - - server_task task; - task.type = SERVER_TASK_TYPE_NEXT_RESPONSE; - task.id_target = -1; + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); queue_tasks.post(task); } // apply context-shift if needed // TODO: simplify and improve - for (server_slot &slot : slots) - { - if (slot.ga_n == 1) - { - if (slot.is_processing() && (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) - { - // Shift context - const int n_keep = slot.params.n_keep + add_bos_token; - const int n_left = (int)system_tokens.size() + slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); - - LOG_INFO("slot context shift", {{"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_keep", n_keep}, - {"n_left", n_left}, - {"n_discard", n_discard}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}}); - - llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, - -n_discard); - - if (slot.params.cache_prompt) - { - for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) - { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; - } + for (server_slot & slot : slots) { + if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in process_token() + slot.release(); + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + continue; + } - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); - } + // Shift context + const int n_keep = slot.params.n_keep + add_bos_token; + const int n_left = slot.n_past - n_keep; + const int n_discard = slot.params.n_discard ? 
slot.params.n_discard : (n_left / 2); - slot.n_past -= n_discard; + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - slot.truncated = true; + llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); + + if (slot.params.cache_prompt) { + for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + } + + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); } + + slot.n_past -= n_discard; + + slot.truncated = true; } } // start populating the batch for this iteration - llama_batch_clear(batch); + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); + }; // frist, add sampled tokens from any ongoing sequences - for (auto &slot : slots) - { - if (slot.state == SLOT_STATE_IDLE) - { + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_GENERATING) { continue; } - slot.i_batch = batch.n_tokens; + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } - const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; + slot.i_batch = batch.n_tokens; - // TODO: we always have to take into account the "system_tokens" - // this is not great and needs to be improved somehow - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, {slot.id + 1}, true); + common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; - if (slot.params.cache_prompt) - { + if (slot.params.cache_prompt) { slot.cache_tokens.push_back(slot.sampled); } - LOG_VERBOSE("slot decode token", {{"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}, - {"truncated", slot.truncated}}); + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); } // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); + int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); - // track if this is an embedding or non-embedding batch - // if we've added sampled tokens above, we are in non-embedding mode - // -1: none, 0: non-embedding, 1: embedding - int32_t batch_type = batch.n_tokens > 0 ? 
0 : -1; - // next, batch any pending prompts without exceeding n_batch - if (params.cont_batching || batch.n_tokens == 0) - { - for (auto &slot : slots) - { - // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) - { - auto &prompt_tokens = slot.prompt_tokens; + if (params_base.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + // check if we can batch this slot with the previous one + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + } - // we haven't tokenized the prompt yet - do it now: - if (prompt_tokens.empty()) - { - LOG_VERBOSE("tokenizing prompt", {{"id_slot", slot.id}, {"id_task", slot.id_task}}); + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + auto & prompt_tokens = slot.prompt_tokens; + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; - if (slot.infill) - { - const bool add_bos = llama_should_add_bos_token(model); - bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) - { - params.input_suffix.erase(0, 1); - suff_rm_leading_spc = false; - } - - auto prefix_tokens = tokenize(slot.params.input_prefix, false); - auto suffix_tokens = tokenize(slot.params.input_suffix, false); - - const int space_token = 29871; // TODO: this should not be hardcoded - if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) - { - suffix_tokens.erase(suffix_tokens.begin()); - } + slot.n_past = 0; + slot.n_prompt_tokens = prompt_tokens.size(); + slot.state = SLOT_STATE_PROCESSING_PROMPT; - prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); - suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model)); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); - auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens; - auto embd_end = params.spm_infill ? 
prefix_tokens : suffix_tokens; - if (add_bos) - { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); + // print prompt tokens (for debugging) + if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - - const llama_token middle_token = llama_token_middle(model); - if (middle_token >= 0) - { - embd_inp.push_back(middle_token); + } else { + // all + for (int i = 0; i < (int) prompt_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } - - prompt_tokens = embd_inp; } - else - { - prompt_tokens = - tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt - } - - slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); - - LOG_VERBOSE("prompt tokenized", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), - prompt_tokens.cend())}, - }); // empty prompt passed -> release the slot and send empty response - if (prompt_tokens.empty()) - { - LOG_INFO("empty prompt - releasing slot", - {{"id_slot", slot.id}, {"id_task", slot.id_task}}); + if (prompt_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; slot.release(); slot.print_timings(); send_final_response(slot); continue; } - if (slot.embedding) - { - // this prompt is too large to process - discard it - if (slot.n_prompt_tokens > n_ubatch) - { - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; + if (slot.is_non_causal()) { + if (slot.n_prompt_tokens > n_ubatch) { slot.release(); - send_error(slot, "input is too large to process. increase the physical batch size", - ERROR_TYPE_SERVER); + send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); continue; } - } - else - { - if (slot.params.n_keep < 0) - { + + if (slot.n_prompt_tokens > slot.n_ctx) { + slot.release(); + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); + continue; + } + } else { + if (!params_base.ctx_shift) { + // if context shift is disabled, we make sure prompt size is smaller than KV size + // TODO: there should be a separate parameter that control prompt truncation + // context shift should be applied only during the generation phase + if (slot.n_prompt_tokens >= slot.n_ctx) { + slot.release(); + send_error(slot, "the request exceeds the available context size. 
try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); + continue; + } + } + if (slot.params.n_keep < 0) { slot.params.n_keep = slot.n_prompt_tokens; } slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - // if input prompt is too big, truncate it (if group attention self-extend is disabled) - if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) - { + // if input prompt is too big, truncate it + if (slot.n_prompt_tokens >= slot.n_ctx) { const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; - const int erased_blocks = - (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - std::vector new_tokens(prompt_tokens.begin(), - prompt_tokens.begin() + slot.params.n_keep); + llama_tokens new_tokens( + prompt_tokens.begin(), + prompt_tokens.begin() + slot.params.n_keep); - new_tokens.insert(new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + - erased_blocks * n_block_size, - prompt_tokens.end()); + new_tokens.insert( + new_tokens.end(), + prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + prompt_tokens.end()); prompt_tokens = std::move(new_tokens); slot.truncated = true; slot.n_prompt_tokens = prompt_tokens.size(); - LOG_VERBOSE("input truncated", - { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", - tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, - }); + SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } - llama_sampling_reset(slot.ctx_sampling); + if (slot.params.cache_prompt) { + // reuse any previously computed tokens that are common with the new prompt + slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); - if (!slot.params.cache_prompt) - { - slot.n_past_se = 0; - slot.ga_i = 0; - } - else - { - GGML_ASSERT(slot.ga_n == 1); + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + size_t head_c = slot.n_past; // cache + size_t head_p = slot.n_past; // current prompt - // reuse any previously computed tokens that are common with the new prompt - slot.n_past = common_part(slot.cache_tokens, prompt_tokens); + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); + + while (head_c < slot.cache_tokens.size() && + head_p < prompt_tokens.size()) { - // push the prompt into the sampling context (do not apply grammar) - for (int i = 0; i < slot.n_past; ++i) - { - llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + size_t n_match = 0; + while (head_c + n_match < slot.cache_tokens.size() && + head_p + n_match < prompt_tokens.size() && + slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t) params_base.n_cache_reuse) { + SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); + //for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], 
common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + //} + + const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; + + llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c); + llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift); + + for (size_t i = 0; i < n_match; i++) { + slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; + slot.n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); } } } - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) - { + if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { // we have to evaluate at least 1 token to generate logits. - LOG_INFO("we have to evaluate at least 1 token to generate logits", - {{"id_slot", slot.id}, {"id_task", slot.id_task}}); + SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); slot.n_past--; - if (slot.ga_i > 0) - { - slot.n_past_se--; - } } slot.n_prompt_tokens_processed = 0; } - if (slot.embedding) - { + // non-causal tasks require to fit the entire prompt in the physical batch + if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) - { + if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; } } - // check that we are in the right batch_type, if not defer the slot - bool slot_type = slot.embedding ? 1 : 0; - if (batch_type == -1) - { - batch_type = slot_type; - } - else if (batch_type != slot_type) - { - continue; - } - // keep only the common part - int p0 = (int)system_tokens.size() + slot.n_past; - if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) - { + if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); - p0 = (int)system_tokens.size(); - if (p0 != 0) - { - // copy over the system prompt when there is one - llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); - } - - // there is no common part left (except for the system prompt) + // there is no common part left slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - // TODO: is the system prompt ever in the sampling context? - llama_sampling_reset(slot.ctx_sampling); } + SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); + // remove the non-common part from the cache slot.cache_tokens.resize(slot.n_past); - LOG_INFO("kv cache rm [p0, end)", {{"id_slot", slot.id}, {"id_task", slot.id_task}, {"p0", p0}}); - - int32_t slot_npast = slot.n_past_se > 0 ? 
slot.n_past_se : slot.n_past; - - int32_t ga_i = slot.ga_i; - int32_t ga_n = slot.ga_n; - int32_t ga_w = slot.ga_w; - // add prompt tokens for processing in the current batch - // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow - for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) - { - if (slot.ga_n != 1) - { - while (slot_npast >= ga_i + ga_w) - { - const int bd = (ga_w / ga_n) * (ga_n - 1); - slot_npast -= bd; - ga_i += ga_w / ga_n; - } - } + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // without pooling, we want to output the embeddings for all the tokens in the batch + const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, - {slot.id + 1}, false); + common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); - if (slot.params.cache_prompt) - { + if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); } slot.n_prompt_tokens_processed++; - slot_npast++; + slot.n_past++; } - LOG_VERBOSE("prompt processing progress", - { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - {"progress", (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens}, - }); - - // entire prompt has been processed - start decoding new tokens - if (slot.n_past == slot.n_prompt_tokens) - { - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + + // entire prompt has been processed + if (slot.n_past == slot.n_prompt_tokens) { + slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + common_sampler_accept(slot.smpl, prompt_tokens[i], false); + } + // extract the logits only for the last token batch.logits[batch.n_tokens - 1] = true; slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; - - LOG_VERBOSE("prompt done", { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - }); + slot.i_batch = batch.n_tokens - 1; + + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); } } - if (batch.n_tokens >= n_batch) - { + if (batch.n_tokens >= n_batch) { break; } } } - if (batch.n_tokens == 0) - { - LOG_VERBOSE("no tokens to decode", {}); + if (batch.n_tokens == 0) { + SRV_WRN("%s", "no tokens to decode\n"); return; } - LOG_VERBOSE("decoding batch", { - {"n_tokens", batch.n_tokens}, - }); + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - // make sure we're in the right embedding mode - llama_set_embeddings(ctx, batch_type == 1); + if (slot_batched) { + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, slot_batched->is_non_causal()); + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + } // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) - { + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - for (auto &slot : 
slots) - { - if (slot.ga_n != 1) - { - // context extension via Self-Extend - // TODO: simplify and/or abstract this - while (slot.n_past_se >= slot.ga_i + slot.ga_w) - { - const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; - const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); - const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; - - LOG_TEE("\n"); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, - slot.ga_i + ib * bd, slot.n_past_se + ib * bd); - LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, - slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, - (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, - slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, - slot.n_past_se + ib * bd + dd); - - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, - slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, - slot.n_past_se + ib * bd, dd); - - slot.n_past_se -= bd; - - slot.ga_i += slot.ga_w / slot.ga_n; - - LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, - slot.ga_i); - } - - slot.n_past_se += n_tokens; - } - } - llama_batch batch_view = { n_tokens, - batch.token + i, + batch.token + i, nullptr, - batch.pos + i, + batch.pos + i, batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, - 0, - 0, // unused + batch.seq_id + i, + batch.logits + i, }; const int ret = llama_decode(ctx, batch_view); + metrics.on_decoded(slots); - if (ret != 0) - { - if (n_batch == 1 || ret < 0) - { + if (ret != 0) { + if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size - LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", - { - {"i", i}, - {"n_batch", ret}, - {"ret", ret}, - }); - for (auto &slot : slots) - { - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; + SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + for (auto & slot : slots) { slot.release(); send_error(slot, "Input prompt is too big compared to KV size. 
Please try increasing KV size."); } @@ -2584,127 +3166,245 @@ struct server_context n_batch /= 2; i -= n_batch; - LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try " - "increasing it via the context size or enable defragmentation", - { - {"i", i}, - {"n_batch", n_batch}, - {"ret", ret}, - }); + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); continue; // continue loop of n_batch } - for (auto &slot : slots) - { - if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) - { + for (auto & slot : slots) { + if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { continue; // continue loop of slots } - // prompt evaluated for embedding - if (slot.embedding) - { - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots } - completion_token_output result; - const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i); + const int tok_idx = slot.i_batch - i; + + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; - llama_sampling_accept(slot.ctx_sampling, ctx, id, true); + common_sampler_accept(slot.smpl, id, true); slot.n_decoded += 1; - if (slot.n_decoded == 1) - { - slot.t_start_generation = ggml_time_us(); + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_current; slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; metrics.on_prompt_eval(slot); } - llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false}; - result.tok = id; - - const size_t n_probs = std::min(cur_p.size, (size_t)slot.sparams.n_probs); - if (n_probs > 0) - { - const size_t n_valid = slot.ctx_sampling->n_valid; + slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; - // Make sure at least n_probs top tokens are at the front of the vector: - if (slot.sparams.temp == 0.0f && n_probs > n_valid) - { - llama_sample_top_k(ctx, &cur_p, n_probs, 0); - } + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - if (slot.sparams.temp == 0.0f) - { - // With greedy sampling the probabilities have possibly not been calculated. - for (size_t i = 0; i < n_probs; ++i) - { - result.probs.push_back({cur_p.data[i].id, i == 0 ? 1.0f : 0.0f}); - } - } - else - { - for (size_t i = 0; i < n_probs; ++i) - { - result.probs.push_back({ - cur_p.data[i].id, - i >= n_valid - ? 0.0f - : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. 
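[Editor's sketch] The speculative-decoding block further below drafts up to n_draft_max tokens with a secondary model (common_speculative_gen_draft()), verifies the whole draft against the target model in a single decode, and keeps only the prefix the target agrees with (common_sampler_sample_and_accept_n()). Note how the draft is clamped first: n_draft_max = min(speculative.n_max, n_ctx - n_past - 2), and further to n_remaining - 1 when a token limit is set. The acceptance rule itself reduces to a longest-common-prefix check; a sketch under that assumption, with count_accepted() as a hypothetical helper:

    #include <cstddef>
    #include <vector>

    using llama_token = int; // illustrative; the real typedef lives in llama.h

    // Length of the prefix of `draft` that the target model reproduces
    // verbatim. Accepted tokens advance n_past in one decode; the rest of the
    // draft is discarded and generation resumes from the first disagreement.
    static size_t count_accepted(const std::vector<llama_token> & draft,
                                 const std::vector<llama_token> & target_samples) {
        size_t n = 0;
        while (n < draft.size() && n < target_samples.size() &&
               draft[n] == target_samples[n]) {
            n++;
        }
        return n;
    }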
- }); - } - } + if (slot.params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); } - if (!process_token(result, slot)) - { + if (!process_token(result, slot)) { + // release slot because of stop condition slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + continue; } + } - slot.i_batch = -1; + // do speculative decoding + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; + } + + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + + continue; + } + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + + llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); } } - LOG_VERBOSE("run slots completed", {}); + SRV_DBG("%s", "run slots completed\n"); } - json model_meta() const - { - return json{ - {"vocab_type", llama_vocab_type(model)}, {"n_vocab", llama_n_vocab(model)}, - {"n_ctx_train", llama_n_ctx_train(model)}, {"n_embd", 
llama_n_embd(model)}, - {"n_params", llama_model_n_params(model)}, {"size", llama_model_size(model)}, + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, }; } }; +static void common_params_handle_model_default( + std::string & model, + const std::string & model_url, + std::string & hf_repo, + std::string & hf_file, + const std::string & hf_token) { + if (!hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (hf_file.empty()) { + if (model.empty()) { + auto auto_detected = common_get_hf_file(hf_repo, hf_token); + if (auto_detected.first.empty() || auto_detected.second.empty()) { + exit(1); // built without CURL, error message already printed + } + hf_repo = auto_detected.first; + hf_file = auto_detected.second; + } else { + hf_file = model; + } + } + // make sure model path is present (for caching purposes) + if (model.empty()) { + // this is to avoid different repo having same file name, or same file name in different subdirs + std::string filename = hf_repo + "_" + hf_file; + // to make sure we don't have any slashes in the filename + string_replace_all(filename, "/", "_"); + model = fs_get_cache_file(filename); + } + } else if (!model_url.empty()) { + if (model.empty()) { + auto f = string_split(model_url, '#').front(); + f = string_split(f, '?').front(); + model = fs_get_cache_file(string_split(f, '/').back()); + } + } else if (model.empty()) { + model = DEFAULT_MODEL_PATH; + } +} + // parse the given jparams (see de.kherud.llama.args.ModelParameters#toString()) from JSON to the required C++ struct. 
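[Editor's sketch] The parser below maps the flat JSON produced by ModelParameters#toString() onto the (now nested) common_params struct, key by key, through the json_value() helper. For reference, a sketch that approximates the helper defined in utils.hpp: read the key if present and non-null, otherwise keep the caller's default. The real version additionally logs a warning when the stored type does not match.

    #include <string>
    #include "nlohmann/json.hpp"

    using json = nlohmann::json;

    // Approximation of utils.hpp's json_value(): value if present, else default.
    template <typename T>
    static T json_value(const json & body, const std::string & key, const T & default_value) {
        if (body.contains(key) && !body.at(key).is_null()) {
            try {
                return body.at(key).get<T>();
            } catch (const json::exception &) {
                return default_value; // type mismatch -> fall back to the default
            }
        }
        return default_value;
    }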
-static void server_params_parse(json jparams, gpt_params ¶ms) +static void server_params_parse(json jparams, common_params ¶ms) { - gpt_params default_params; + common_params default_params; - params.seed = json_value(jparams, "seed", default_params.seed); - params.n_threads = json_value(jparams, "n_threads", default_params.n_threads); - params.n_threads_draft = json_value(jparams, "n_threads_draft", default_params.n_threads_draft); - params.n_threads_batch = json_value(jparams, "n_threads_batch", default_params.n_threads_batch); - params.n_threads_batch_draft = json_value(jparams, "n_threads_batch_draft", default_params.n_threads_batch_draft); + params.sampling.seed = json_value(jparams, "seed", default_params.sampling.seed); + params.cpuparams.n_threads = json_value(jparams, "n_threads", default_params.cpuparams.n_threads); + params.speculative.cpuparams.n_threads = json_value(jparams, "n_threads_draft", default_params.speculative.cpuparams.n_threads); + params.cpuparams_batch.n_threads = json_value(jparams, "n_threads_batch", default_params.cpuparams_batch.n_threads); + params.speculative.cpuparams_batch.n_threads = json_value(jparams, "n_threads_batch_draft", default_params.speculative.cpuparams_batch.n_threads ); params.n_predict = json_value(jparams, "n_predict", default_params.n_predict); params.n_ctx = json_value(jparams, "n_ctx", default_params.n_ctx); params.n_batch = json_value(jparams, "n_batch", default_params.n_batch); params.n_ubatch = json_value(jparams, "n_ubatch", default_params.n_ubatch); params.n_keep = json_value(jparams, "n_keep", default_params.n_keep); - params.n_draft = json_value(jparams, "n_draft", default_params.n_draft); + + params.speculative.n_max = json_value(jparams, "n_draft", default_params.speculative.n_max); + params.speculative.n_min = json_value(jparams, "n_draft_min", default_params.speculative.n_min); + params.n_chunks = json_value(jparams, "n_chunks", default_params.n_chunks); params.n_parallel = json_value(jparams, "n_parallel", default_params.n_parallel); params.n_sequences = json_value(jparams, "n_sequences", default_params.n_sequences); - params.p_split = json_value(jparams, "p_split", default_params.p_split); + params.speculative.p_split = json_value(jparams, "p_split", default_params.speculative.p_split); params.grp_attn_n = json_value(jparams, "grp_attn_n", default_params.grp_attn_n); params.grp_attn_w = json_value(jparams, "grp_attn_w", default_params.grp_attn_w); params.n_print = json_value(jparams, "n_print", default_params.n_print); @@ -2720,7 +3420,7 @@ static void server_params_parse(json jparams, gpt_params ¶ms) params.rope_scaling_type = json_value(jparams, "rope_scaling_type", default_params.rope_scaling_type); params.pooling_type = json_value(jparams, "pooling_type", default_params.pooling_type); params.model = json_value(jparams, "model", default_params.model); - params.model_draft = json_value(jparams, "model_draft", default_params.model_draft); + params.speculative.model = json_value(jparams, "model_draft", default_params.speculative.model); params.model_alias = json_value(jparams, "model_alias", default_params.model_alias); params.model_url = json_value(jparams, "model_url", default_params.model_url); params.hf_repo = json_value(jparams, "hf_repo", default_params.hf_repo); @@ -2734,17 +3434,16 @@ static void server_params_parse(json jparams, gpt_params ¶ms) params.lookup_cache_static = json_value(jparams, "lookup_cache_static", default_params.lookup_cache_static); params.lookup_cache_dynamic = json_value(jparams, 
"lookup_cache_dynamic", default_params.lookup_cache_dynamic); params.logits_file = json_value(jparams, "logits_file", default_params.logits_file); - params.lora_adapter = json_value(jparams, "lora_adapter", default_params.lora_adapter); + // params.lora_adapters = json_value(jparams, "lora_adapter", default_params.lora_adapters); params.embedding = json_value(jparams, "embedding", default_params.embedding); params.escape = json_value(jparams, "escape", default_params.escape); params.cont_batching = json_value(jparams, "cont_batching", default_params.cont_batching); params.flash_attn = json_value(jparams, "flash_attn", default_params.flash_attn); params.input_prefix_bos = json_value(jparams, "input_prefix_bos", default_params.input_prefix_bos); - params.ignore_eos = json_value(jparams, "ignore_eos", default_params.ignore_eos); + params.sampling.ignore_eos = json_value(jparams, "ignore_eos", default_params.sampling.ignore_eos); params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap); params.use_mlock = json_value(jparams, "use_mlock", default_params.use_mlock); params.no_kv_offload = json_value(jparams, "no_kv_offload", default_params.no_kv_offload); - params.system_prompt = json_value(jparams, "system_prompt", default_params.system_prompt); params.chat_template = json_value(jparams, "chat_template", default_params.chat_template); if (jparams.contains("n_gpu_layers")) @@ -2752,13 +3451,13 @@ static void server_params_parse(json jparams, gpt_params ¶ms) if (llama_supports_gpu_offload()) { params.n_gpu_layers = json_value(jparams, "n_gpu_layers", default_params.n_gpu_layers); - params.n_gpu_layers_draft = json_value(jparams, "n_gpu_layers_draft", default_params.n_gpu_layers_draft); + params.speculative.n_gpu_layers = json_value(jparams, "n_gpu_layers_draft", default_params.speculative.n_gpu_layers); } else { - LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " - "See main README.md for information on enabling GPU BLAS support", - {{"n_gpu_layers", params.n_gpu_layers}}); + SRV_WRN("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " + "See main README.md for information on enabling GPU BLAS support: %s = %d", + "n_gpu_layers", params.n_gpu_layers); } } @@ -2789,7 +3488,7 @@ static void server_params_parse(json jparams, gpt_params ¶ms) } } #else - LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n", {}); + SRV_WRN("%s","llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n"); #endif // GGML_USE_CUDA } @@ -2798,9 +3497,9 @@ static void server_params_parse(json jparams, gpt_params ¶ms) #if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL) params.main_gpu = json_value(jparams, "main_gpu", default_params.main_gpu); #else - LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a main GPU.", {}); + SRV_WRN("%s","llama.cpp was compiled without CUDA. 
It is not possible to set a main GPU."); #endif } - gpt_params_handle_model_default(params); + common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token); } diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp index 7de7eac4..5ff886da 100644 --- a/src/main/cpp/utils.hpp +++ b/src/main/cpp/utils.hpp @@ -1,202 +1,389 @@ #pragma once #include "common.h" +#include "log.h" #include "llama.h" +#include "base64.hpp" + +#ifndef NDEBUG +// crash the server in debug mode, otherwise send an http 500 error +#define CPPHTTPLIB_NO_EXCEPTIONS 1 +#endif +// increase max payload length to allow use of larger context size +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 +//#include "httplib.h" + +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT #include "json.hpp" +#include "chat.hpp" +#include "chat-template.hpp" + #include #include #include #include +#include -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" using json = nlohmann::ordered_json; -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type -{ - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error -}; - -extern bool log_json; -extern std::function log_callback; - -#if SERVER_VERBOSE -#define LOG_VERBOSE(MSG, ...) \ - do \ - { \ - server_log(GGML_LOG_LEVEL_DEBUG, __func__, __LINE__, MSG, __VA_ARGS__); \ - } while (0) -#else -#define LOG_VERBOSE(MSG, ...) -#endif +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define LOG_ERROR(MSG, ...) server_log(GGML_LOG_LEVEL_ERROR, __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_WARNING(MSG, ...) server_log(GGML_LOG_LEVEL_WARN, __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_INFO(MSG, ...) server_log(GGML_LOG_LEVEL_INFO, __func__, __LINE__, MSG, __VA_ARGS__) +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -static inline void server_log(ggml_log_level level, const char *function, int line, const char *message, - const json &extra); +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) 
LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -template static T json_value(const json &body, const std::string &key, const T &default_value) -{ +template +static T json_value(const json & body, const std::string & key, const T & default_value) { // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) - { - try - { + if (body.contains(key) && !body.at(key).is_null()) { + try { return body.at(key); - } - catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) - { - std::stringstream ss; - ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() - << "', using default value."; - LOG_WARNING(ss.str().c_str(), body); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name()); return default_value; } - } - else - { + } else { return default_value; } } -static const char *log_level_to_string(ggml_log_level level) -{ - switch (level) - { - case GGML_LOG_LEVEL_ERROR: - return "ERROR"; - case GGML_LOG_LEVEL_WARN: - return "WARN"; - default: - case GGML_LOG_LEVEL_INFO: - return "INFO"; - case GGML_LOG_LEVEL_DEBUG: - return "DEBUG"; +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + +// +// tokenizer and input processing utils +// + +static bool json_is_array_of_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number_integer()) { + return false; + } + } + return true; } + return false; } -static inline void server_log(ggml_log_level level, const char *function, int line, const char *message, - const json &extra) -{ - std::stringstream ss_tid; - ss_tid << std::this_thread::get_id(); - - if (log_json) - { - json log = json{ - {"msg", message}, -#if SERVER_VERBOSE - {"ts", time(nullptr)}, {"level", log_level_to_string(level)}, {"tid", ss_tid.str()}, {"function", function}, - {"line", line}, -#endif - }; - - if (!extra.empty()) - { - log.merge_patch(extra); +// is array having BOTH numbers & strings? 
+static bool json_is_array_of_mixed_numbers_strings(const json & data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto & e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } } + } + return false; +} - auto dump = log.dump(-1, ' ', false, json::error_handler_t::replace); - if (log_callback == nullptr) - { - printf("%s\n", dump.c_str()); +// get value by path(key1 / key2) +static json json_get_nested_values(const std::vector & paths, const json & js) { + json result = json::object(); + + for (const std::string & path : paths) { + json current = js; + const auto keys = string_split(path, /*separator*/ '/'); + bool valid_path = true; + for (const std::string & k : keys) { + if (valid_path && current.is_object() && current.contains(k)) { + current = current[k]; + } else { + valid_path = false; + } } - else - { - log_callback(level, dump.c_str(), nullptr); + if (valid_path) { + result[path] = current; } } - else - { - std::stringstream ss; - ss << message; - - if (!extra.empty()) - { - for (const auto &el : extra.items()) - { - const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); - ss << " " << el.key() << "=" << value; + return result; +} + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; + + if (json_prompt.is_array()) { + bool first = true; + for (const auto & p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + + llama_tokens p; + if (first) { + p = common_tokenize(vocab, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(vocab, s, false, parse_special); + } + + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); } } + } else { + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); + } -#if SERVER_VERBOSE - ss << " | ts " << time(nullptr) << " | tid " << ss_tid.str() << " | " << function << " line " << line; -#endif + return prompt_tokens; +} - const std::string str = ss.str(); - if (log_callback == nullptr) - { - printf("[%4s] %.*s\n", log_level_to_string(level), (int)str.size(), str.data()); +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] + */ +static std::vector tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special)); + } else if 
(json_is_array_of_numbers(json_prompt)) { + // array of tokens + result.push_back(json_prompt.get()); + } else if (json_prompt.is_array()) { + // array of prompts + result.reserve(json_prompt.size()); + for (const auto & p : json_prompt) { + if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { + result.push_back(tokenize_mixed(vocab, p, add_special, parse_special)); + } else if (json_is_array_of_numbers(p)) { + // array of tokens + result.push_back(p.get()); + } else { + throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens"); + } } - else - { - log_callback(level, str.c_str(), nullptr); + } else { + throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); + } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } + return result; +} + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +static size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; } } - fflush(stdout); + + // If no cut-off multi-byte character is found, return full length + return len; } // -// chat template utils +// template utils // -// Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const struct llama_model *model, const std::string &tmpl, - const std::vector &messages) -{ - std::vector chat; +// format rerank task: [BOS]query[EOS][SEP]doc[EOS] +static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) { + llama_tokens result; + + result.reserve(doc.size() + query.size() + 4); + result.push_back(llama_vocab_bos(vocab)); + result.insert(result.end(), query.begin(), query.end()); + result.push_back(llama_vocab_eos(vocab)); + result.push_back(llama_vocab_sep(vocab)); + result.insert(result.end(), doc.begin(), doc.end()); + result.push_back(llama_vocab_eos(vocab)); + + return result; +} + +// format infill task +static llama_tokens format_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement + + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... 
+ // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); + + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); + + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + } + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } else { + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); + + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + } + + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + } - for (size_t i = 0; i < messages.size(); ++i) - { - const auto &curr_msg = messages[i]; + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } + + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); + const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); + + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); + + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); + + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); + + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); + + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? 
tokens_prefix : tokens_suffix; + + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); + } + + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); + + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); + + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); + + return embd_inp; +} + +/// Format given chat. If tmpl is empty, we take the template from model metadata +inline std::string format_chat(const common_chat_template & tmpl, const std::vector & messages) { + std::vector chat; + + for (size_t i = 0; i < messages.size(); ++i) { + const auto & curr_msg = messages[i]; std::string role = json_value(curr_msg, "role", std::string("")); std::string content; - if (curr_msg.contains("content")) - { - if (curr_msg["content"].is_string()) - { + if (curr_msg.contains("content")) { + if (curr_msg["content"].is_string()) { content = curr_msg["content"].get(); - } - else if (curr_msg["content"].is_array()) - { - for (const auto &part : curr_msg["content"]) - { - if (part.contains("text")) - { + } else if (curr_msg["content"].is_array()) { + for (const auto & part : curr_msg["content"]) { + if (part.contains("text")) { content += "\n" + part["text"].get(); } } + } else { + throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } - else - { - throw std::runtime_error( - "Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); - } - } - else - { + } else { throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } - chat.push_back({role, content}); + chat.push_back({role, content, /* tool_calls= */ {}}); } - auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true); - LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); + const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); + LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); + return formatted_chat; } @@ -204,17 +391,16 @@ inline std::string format_chat(const struct llama_model *model, const std::strin // base64 utils (TODO: move to common in the future) // -static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; -static inline bool is_base64(uint8_t c) -{ +static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static inline std::vector base64_decode(const std::string &encoded_string) -{ +static inline std::vector base64_decode(const std::string & encoded_string) { int i = 0; int j = 0; int in_ = 0; @@ -226,23 +412,18 @@ static inline std::vector base64_decode(const std::string &encoded_stri std::vector ret; - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) - { - char_array_4[i++] = encoded_string[in_]; - in_++; - if (i == 4) - { - for (i = 0; i < 4; i++) - { + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { char_array_4[i] = base64_chars.find(char_array_4[i]); } - char_array_3[0] = ((char_array_4[0]) << 2) 
+ ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (i = 0; (i < 3); i++) - { + for (i = 0; (i < 3); i++) { ret.push_back(char_array_3[i]); } @@ -250,24 +431,20 @@ static inline std::vector base64_decode(const std::string &encoded_stri } } - if (i) - { - for (j = i; j < 4; j++) - { + if (i) { + for (j = i; j < 4; j++) { char_array_4[j] = 0; } - for (j = 0; j < 4; j++) - { + for (j = 0; j < 4; j++) { char_array_4[j] = base64_chars.find(char_array_4[j]); } - char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (j = 0; j < i - 1; j++) - { + for (j = 0; j < i - 1; j++) { ret.push_back(char_array_3[j]); } } @@ -279,8 +456,7 @@ static inline std::vector base64_decode(const std::string &encoded_stri // random string / id // -static std::string random_string() -{ +static std::string random_string() { static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); std::random_device rd; @@ -288,63 +464,32 @@ static std::string random_string() std::string result(32, ' '); - for (int i = 0; i < 32; ++i) - { + for (int i = 0; i < 32; ++i) { result[i] = str[generator() % str.size()]; } return result; } -static std::string gen_chatcmplid() -{ - std::stringstream chatcmplid; - chatcmplid << "chatcmpl-" << random_string(); - - return chatcmplid.str(); +static std::string gen_chatcmplid() { + return "chatcmpl-" + random_string(); } // // other common utils // -static size_t common_part(const std::vector &a, const std::vector &b) -{ - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) - { - } - - return i; -} - -static size_t common_part(const std::string &a, const std::string &b) -{ - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) - { - } - - return i; -} - -static bool ends_with(const std::string &str, const std::string &suffix) -{ +static bool ends_with(const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } -static size_t find_partial_stop_string(const std::string &stop, const std::string &text) -{ - if (!text.empty() && !stop.empty()) - { +static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { + if (!text.empty() && !stop.empty()) { const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) - { - if (stop[char_index] == text_last_char) - { + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { + if (stop[char_index] == text_last_char) { const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) - { + if (ends_with(text, current_partial)) { return text.size() - char_index - 1; } } @@ -355,26 +500,23 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin } // TODO: reuse llama_detokenize -template static 
std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) -{ +template +static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { std::string ret; - for (; begin != end; ++begin) - { - ret += llama_token_to_piece(ctx, *begin); + for (; begin != end; ++begin) { + ret += common_token_to_piece(ctx, *begin); } return ret; } // format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) -{ - std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token); +static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); // if the size is 1 and first bit is 1, meaning it's a partial character // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) - { + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { std::stringstream ss; ss << std::hex << (out[0] & 0xff); std::string res(ss.str()); @@ -384,126 +526,160 @@ static std::string tokens_to_output_formatted_string(const llama_context *ctx, c return out; } -struct completion_token_output -{ - llama_token tok; - std::string text_to_send; +// +// OAI utils +// - struct token_prob - { - llama_token tok; - float prob; - }; +static json oaicompat_completion_params_parse(const json & body) { + json llama_params; - std::vector probs; -}; + if (!body.contains("prompt")) { + throw std::runtime_error("\"prompt\" is required"); + } -// convert a vector of completion_token_output to json -static json probs_vector_to_json(const llama_context *ctx, const std::vector &probs) -{ - json out = json::array(); - - for (const auto &prob : probs) - { - json probs_for_token = json::array(); - - for (const auto &p : prob.probs) - { - const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); - probs_for_token.push_back(json{ - {"tok_str", tok_str}, - {"prob", p.prob}, - }); + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Params supported by OAI but unsupported by llama.cpp + static const std::vector unsupported_params { "best_of", "echo", "suffix" }; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); } + } - const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); - out.push_back(json{ - {"content", tok_str}, - {"probs", probs_for_token}, - }); + // Copy remaining properties to llama_params + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } } - return out; + return llama_params; } -// -// OAI utils -// - -static json oaicompat_completion_params_parse(const struct llama_model *model, - const json &body, /* openai api json semantics */ - const std::string &chat_template) +static json oaicompat_completion_params_parse( + const json & body, /* openai api json semantics */ + bool use_jinja, + 
const common_chat_templates & chat_templates) { json llama_params; + const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use + ? *chat_templates.template_tool_use + : *chat_templates.template_default; - llama_params["__oaicompat"] = true; + auto tools = json_value(body, "tools", json()); + auto stream = json_value(body, "stream", false); - // Apply chat template to the list of messages - llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); + if (tools.is_array() && !tools.empty()) { + if (stream) { + throw std::runtime_error("Cannot use tools with stream"); + } + if (!use_jinja) { + throw std::runtime_error("tools param requires --jinja flag"); + } + } + if (!use_jinja) { + if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) { + throw std::runtime_error("Unsupported param: tool_choice"); + } + } // Handle "stop" field - if (body.contains("stop") && body.at("stop").is_string()) - { + if (body.contains("stop") && body.at("stop").is_string()) { llama_params["stop"] = json::array({body.at("stop").get()}); - } - else - { + } else { llama_params["stop"] = json_value(body, "stop", json::array()); } // Handle "response_format" field - if (body.contains("response_format")) - { - json response_format = json_value(body, "response_format", json::object()); + if (body.contains("response_format")) { + json response_format = json_value(body, "response_format", json::object()); std::string response_type = json_value(response_format, "type", std::string()); - if (response_type == "json_object") - { + if (response_type == "json_object") { llama_params["json_schema"] = json_value(response_format, "schema", json::object()); + } else if (response_type == "json_schema") { + json json_schema = json_value(response_format, "json_schema", json::object()); + llama_params["json_schema"] = json_value(json_schema, "schema", json::object()); + } else if (!response_type.empty() && response_type != "text") { + throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); + } + } + + // Apply chat template to the list of messages + if (use_jinja) { + auto tool_choice = json_value(body, "tool_choice", std::string("auto")); + if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") { + throw std::runtime_error("Invalid tool_choice: " + tool_choice); + } + if (tool_choice != "none" && llama_params.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + common_chat_inputs inputs; + inputs.messages = body.at("messages"); + inputs.tools = tools; + inputs.tool_choice = tool_choice; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { + LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); + inputs.parallel_tool_calls = false; } - else if (!response_type.empty() && response_type != "text") - { - throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + - response_type); + inputs.stream = stream; + // TODO: support mixing schema w/ tools beyond generic format. 
+ inputs.json_schema = json_value(llama_params, "json_schema", json()); + auto chat_params = common_chat_params_init(tmpl, inputs); + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + grammar_triggers.push_back({ + {"word", trigger.word}, + {"at_start", trigger.at_start}, + }); + } + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); } + } else { + llama_params["prompt"] = format_chat(tmpl, body.at("messages")); } // Handle "n" field int n_choices = json_value(body, "n", 1); - if (n_choices != 1) - { + if (n_choices != 1) { throw std::runtime_error("Only one completion choice is allowed"); } // Handle "logprobs" field - // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may - // need to fix it in the future - if (body.contains("logprobs")) - { + // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future + if (json_value(body, "logprobs", false)) { llama_params["n_probs"] = json_value(body, "top_logprobs", 20); - } - else if (body.contains("top_logprobs")) - { + } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { throw std::runtime_error("top_logprobs requires logprobs to be set to true"); } - // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params{"tools", "tool_choice"}; - for (auto ¶m : unsupported_params) - { - if (body.contains(param)) - { - throw std::runtime_error("Unsupported param: " + param); - } - } - // Copy remaining properties to llama_params - // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint. + // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp - for (const auto &item : body.items()) - { + for (const auto & item : body.items()) { // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" - if (!llama_params.contains(item.key()) || item.key() == "n_predict") - { + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { llama_params[item.key()] = item.value(); } } @@ -511,219 +687,205 @@ static json oaicompat_completion_params_parse(const struct llama_model *model, return llama_params; } -static json format_final_response_oaicompat(const json &request, json result, const std::string &completion_id, - bool streaming = false) -{ - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason = "length"; - if (stopped_word || stopped_eos) - { - finish_reason = "stop"; - } - - json choices = streaming - ? 
json::array({json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}}) - : json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, {"role", "assistant"}}}}}); - - std::time_t t = std::time(0); - - json res = json{{"choices", choices}, - {"created", t}, - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, - {"id", completion_id}}; - -#if SERVER_VERBOSE - res["__verbose"] = result; -#endif +static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto & elem : embeddings) { + json embedding_obj; + + if (use_base64) { + const auto& vec = json_value(elem, "embedding", json::array()).get>(); + const char* data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = { + {"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"} + }; + } else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + } + data.push_back(embedding_obj); - if (result.contains("completion_probabilities")) - { - res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); + n_tokens += json_value(elem, "tokens_evaluated", 0); } + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"data", data} + }; + return res; } -// return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(json result, const std::string &completion_id) -{ - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) - { - return std::vector({result}); +static json format_response_rerank(const json & request, const json & ranks) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto & rank : ranks) { + data.push_back(json{ + {"index", i++}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + + n_tokens += json_value(rank, "tokens_evaluated", 0); } - bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; - std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"results", data} + }; - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); + return res; +} - std::string finish_reason; - if (stopped_word || stopped_eos) - { - finish_reason = "stop"; - } - if (stopped_limit) - { - finish_reason = "length"; +static bool is_valid_utf8(const std::string & str) { + const unsigned char* bytes = 
reinterpret_cast(str.data()); + const unsigned char* end = bytes + str.length(); + + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || + (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte + return false; + } } - std::time_t t = std::time(0); - - json choices; + return true; +} - if (!finish_reason.empty()) - { - choices = json::array({json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}}); - } - else - { - if (first) - { - if (content.empty()) - { - choices = json::array( - {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"role", "assistant"}}}}}); - } - else - { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = - json{{"choices", - json::array( - {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"content", content}}}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } - } - else - { - // Some idiosyncrasy in task processing logic makes several trailing calls - // with empty content, we ignore these at the calee site. 
- if (content.empty()) - { - return std::vector({json::object()}); - } +static json format_tokenizer_response(const json & tokens) { + return json { + {"tokens", tokens} + }; +} - choices = json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - }}); - } - } +static json format_detokenized_response(const std::string & content) { + return json { + {"content", content} + }; +} - json ret = json{{"choices", choices}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - if (!finish_reason.empty()) - { - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - ret.push_back({"usage", json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}); +static json format_logit_bias(const std::vector & logit_bias) { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); } + return data; +} - return std::vector({ret}); +static std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); } -static json format_embeddings_response_oaicompat(const json &request, const json &embeddings) -{ - json data = json::array(); - int i = 0; - for (auto &elem : embeddings) - { - data.push_back( - json{{"embedding", json_value(elem, "embedding", json::array())}, {"index", i++}, {"object", "embedding"}}); +static std::vector get_token_probabilities(llama_context * ctx, int idx) { + std::vector cur; + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } - json res = json{{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json{{"prompt_tokens", 0}, {"total_tokens", 0}}}, - {"data", data}}; + // sort tokens by logits + std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); - return res; -} + // apply softmax + float max_l = cur[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < cur.size(); ++i) { + float p = expf(cur[i].logit - max_l); + cur[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < cur.size(); ++i) { + cur[i].p /= cum_sum; + } -static json format_tokenizer_response(const std::vector &tokens) -{ - return json{{"tokens", tokens}}; + return cur; } -static json format_detokenized_response(const std::string &content) -{ - return json{{"content", content}}; +static bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; } -static json format_error_response(const std::string &message, const enum error_type type) -{ - std::string type_str; - int code = 500; - switch (type) - { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - 
code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - } - return json{ - {"code", code}, - {"message", message}, - {"type", type_str}, - }; +// parse lora config from JSON request, returned a copy of lora_base with updated scale +static std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; } diff --git a/src/main/java/de/kherud/llama/CliParameters.java b/src/main/java/de/kherud/llama/CliParameters.java new file mode 100644 index 00000000..4142628e --- /dev/null +++ b/src/main/java/de/kherud/llama/CliParameters.java @@ -0,0 +1,40 @@ +package de.kherud.llama; + +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +abstract class CliParameters { + + final Map parameters = new HashMap<>(); + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + for (String key : parameters.keySet()) { + String value = parameters.get(key); + builder.append(key).append(" "); + if (value != null) { + builder.append(value).append(" "); + } + } + return builder.toString(); + } + + public String[] toArray() { + List result = new ArrayList<>(); + result.add(""); // c args contain the program name as the first argument, so we add an empty entry + for (String key : parameters.keySet()) { + result.add(key); + String value = parameters.get(key); + if (value != null) { + result.add(value); + } + } + return result.toArray(new String[0]); + } + +} diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java index d2698753..2c494c8c 100644 --- a/src/main/java/de/kherud/llama/InferenceParameters.java +++ b/src/main/java/de/kherud/llama/InferenceParameters.java @@ -459,12 +459,6 @@ public InferenceParameters setSamplers(Sampler... samplers) { case TOP_K: builder.append("\"top_k\""); break; - case TFS_Z: - builder.append("\"tfs_z\""); - break; - case TYPICAL_P: - builder.append("\"typical_p\""); - break; case TOP_P: builder.append("\"top_p\""); break; diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index b78e056e..1e8878c0 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -16,7 +16,7 @@ *
 * <ul>
 *     <li>Streaming answers (and probabilities) via {@link #generate(InferenceParameters)}</li>
 *     <li>Creating whole responses to prompts via {@link #complete(InferenceParameters)}</li>
- *     <li>Creating embeddings via {@link #embed(String)} (make sure to configure {@link ModelParameters#setEmbedding(boolean)}</li>
+ *     <li>Creating embeddings via {@link #embed(String)} (make sure to configure {@link ModelParameters#enableEmbedding()}</li>
 *     <li>Accessing the tokenizer via {@link #encode(String)} and {@link #decode(int[])}</li>
 * </ul>
 */
@@ -32,16 +32,16 @@ public class LlamaModel implements AutoCloseable {
 	/**
 	 * Load with the given {@link ModelParameters}. Make sure to either set
 	 * <ul>
-	 *     <li>{@link ModelParameters#setModelFilePath(String)}</li>
+	 *     <li>{@link ModelParameters#setModel(String)}</li>
 	 *     <li>{@link ModelParameters#setModelUrl(String)}</li>
-	 *     <li>{@link ModelParameters#setHuggingFaceRepository(String)}}, {@link ModelParameters#setHuggingFaceFile(String)}</li>
+	 *     <li>{@link ModelParameters#setHfRepo(String)}, {@link ModelParameters#setHfFile(String)}</li>
 	 * </ul>
* * @param parameters the set of options * @throws LlamaException if no model could be loaded from the given file path */ public LlamaModel(ModelParameters parameters) { - loadModel(parameters.toString()); + loadModel(parameters.toArray()); } /** @@ -66,17 +66,19 @@ public String complete(InferenceParameters parameters) { public LlamaIterable generate(InferenceParameters parameters) { return () -> new LlamaIterator(this, parameters); } - + + + /** * Get the embedding of a string. Note, that the prompt isn't preprocessed in any way, nothing like * "User: ", "###Instruction", etc. is added. * * @param prompt the string to embed * @return an embedding float array - * @throws IllegalStateException if embedding mode was not activated (see - * {@link ModelParameters#setEmbedding(boolean)}) + * @throws IllegalStateException if embedding mode was not activated (see {@link ModelParameters#enableEmbedding()}) */ - public native float[] embed(String prompt); + public native float[] embed(String prompt); + /** * Tokenize a prompt given the native tokenizer @@ -124,7 +126,7 @@ public void close() { native byte[] decodeBytes(int[] tokens); - private native void loadModel(String parameters) throws LlamaException; + private native void loadModel(String... parameters) throws LlamaException; private native void delete(); diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java index 3b34d3f3..91587001 100644 --- a/src/main/java/de/kherud/llama/ModelParameters.java +++ b/src/main/java/de/kherud/llama/ModelParameters.java @@ -1,557 +1,954 @@ package de.kherud.llama; -import java.util.Map; - -import de.kherud.llama.args.GpuSplitMode; -import de.kherud.llama.args.NumaStrategy; -import de.kherud.llama.args.PoolingType; -import de.kherud.llama.args.RopeScalingType; +import de.kherud.llama.args.*; /*** * Parameters used for initializing a {@link LlamaModel}. 
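 * <p>
 * A minimal usage sketch (illustrative only: the file path is a placeholder, and the new setters are
 * assumed to keep the fluent style of the previous API):
 * <pre>{@code
 * ModelParameters params = new ModelParameters()
 *         .setModel("models/model.gguf")  // placeholder path
 *         .enableEmbedding();             // required before calling LlamaModel#embed(String)
 * try (LlamaModel model = new LlamaModel(params)) {
 *     float[] embedding = model.embed("Hello, world!");
 * }
 * }</pre>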
*/ -public final class ModelParameters extends JsonParameters { - - private static final String PARAM_SEED = "seed"; - private static final String PARAM_N_THREADS = "n_threads"; - private static final String PARAM_N_THREADS_DRAFT = "n_threads_draft"; - private static final String PARAM_N_THREADS_BATCH = "n_threads_batch"; - private static final String PARAM_N_THREADS_BATCH_DRAFT = "n_threads_batch_draft"; - private static final String PARAM_N_PREDICT = "n_predict"; - private static final String PARAM_N_CTX = "n_ctx"; - private static final String PARAM_N_BATCH = "n_batch"; - private static final String PARAM_N_UBATCH = "n_ubatch"; - private static final String PARAM_N_KEEP = "n_keep"; - private static final String PARAM_N_DRAFT = "n_draft"; - private static final String PARAM_N_CHUNKS = "n_chunks"; - private static final String PARAM_N_PARALLEL = "n_parallel"; - private static final String PARAM_N_SEQUENCES = "n_sequences"; - private static final String PARAM_P_SPLIT = "p_split"; - private static final String PARAM_N_GPU_LAYERS = "n_gpu_layers"; - private static final String PARAM_N_GPU_LAYERS_DRAFT = "n_gpu_layers_draft"; - private static final String PARAM_SPLIT_MODE = "split_mode"; - private static final String PARAM_MAIN_GPU = "main_gpu"; - private static final String PARAM_TENSOR_SPLIT = "tensor_split"; - private static final String PARAM_GRP_ATTN_N = "grp_attn_n"; - private static final String PARAM_GRP_ATTN_W = "grp_attn_w"; - private static final String PARAM_ROPE_FREQ_BASE = "rope_freq_base"; - private static final String PARAM_ROPE_FREQ_SCALE = "rope_freq_scale"; - private static final String PARAM_YARN_EXT_FACTOR = "yarn_ext_factor"; - private static final String PARAM_YARN_ATTN_FACTOR = "yarn_attn_factor"; - private static final String PARAM_YARN_BETA_FAST = "yarn_beta_fast"; - private static final String PARAM_YARN_BETA_SLOW = "yarn_beta_slow"; - private static final String PARAM_YARN_ORIG_CTX = "yarn_orig_ctx"; - private static final String PARAM_DEFRAG_THOLD = "defrag_thold"; - private static final String PARAM_NUMA = "numa"; - private static final String PARAM_ROPE_SCALING_TYPE = "rope_scaling_type"; - private static final String PARAM_POOLING_TYPE = "pooling_type"; - private static final String PARAM_MODEL = "model"; - private static final String PARAM_MODEL_DRAFT = "model_draft"; - private static final String PARAM_MODEL_ALIAS = "model_alias"; - private static final String PARAM_MODEL_URL = "model_url"; - private static final String PARAM_HF_REPO = "hf_repo"; - private static final String PARAM_HF_FILE = "hf_file"; - private static final String PARAM_LOOKUP_CACHE_STATIC = "lookup_cache_static"; - private static final String PARAM_LOOKUP_CACHE_DYNAMIC = "lookup_cache_dynamic"; - private static final String PARAM_LORA_ADAPTER = "lora_adapter"; - private static final String PARAM_EMBEDDING = "embedding"; - private static final String PARAM_CONT_BATCHING = "cont_batching"; - private static final String PARAM_FLASH_ATTENTION = "flash_attn"; - private static final String PARAM_INPUT_PREFIX_BOS = "input_prefix_bos"; - private static final String PARAM_IGNORE_EOS = "ignore_eos"; - private static final String PARAM_USE_MMAP = "use_mmap"; - private static final String PARAM_USE_MLOCK = "use_mlock"; - private static final String PARAM_NO_KV_OFFLOAD = "no_kv_offload"; - private static final String PARAM_SYSTEM_PROMPT = "system_prompt"; - private static final String PARAM_CHAT_TEMPLATE = "chat_template"; - - /** - * Set the RNG seed - */ - public ModelParameters setSeed(int seed) { - 
parameters.put(PARAM_SEED, String.valueOf(seed)); - return this; - } - - /** - * Set the number of threads to use during generation (default: 8) - */ - public ModelParameters setNThreads(int nThreads) { - parameters.put(PARAM_N_THREADS, String.valueOf(nThreads)); - return this; - } - - /** - * Set the number of threads to use during draft generation (default: same as {@link #setNThreads(int)}) - */ - public ModelParameters setNThreadsDraft(int nThreadsDraft) { - parameters.put(PARAM_N_THREADS_DRAFT, String.valueOf(nThreadsDraft)); - return this; - } - - /** - * Set the number of threads to use during batch and prompt processing (default: same as {@link #setNThreads(int)}) - */ - public ModelParameters setNThreadsBatch(int nThreadsBatch) { - parameters.put(PARAM_N_THREADS_BATCH, String.valueOf(nThreadsBatch)); - return this; - } - - /** - * Set the number of threads to use during batch and prompt processing (default: same as - * {@link #setNThreadsDraft(int)}) - */ - public ModelParameters setNThreadsBatchDraft(int nThreadsBatchDraft) { - parameters.put(PARAM_N_THREADS_BATCH_DRAFT, String.valueOf(nThreadsBatchDraft)); - return this; - } - - /** - * Set the number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled) - */ - public ModelParameters setNPredict(int nPredict) { - parameters.put(PARAM_N_PREDICT, String.valueOf(nPredict)); - return this; - } - - /** - * Set the size of the prompt context (default: 512, 0 = loaded from model) - */ - public ModelParameters setNCtx(int nCtx) { - parameters.put(PARAM_N_CTX, String.valueOf(nCtx)); - return this; - } - - /** - * Set the logical batch size for prompt processing (must be >=32 to use BLAS) - */ - public ModelParameters setNBatch(int nBatch) { - parameters.put(PARAM_N_BATCH, String.valueOf(nBatch)); - return this; - } - - /** - * Set the physical batch size for prompt processing (must be >=32 to use BLAS) - */ - public ModelParameters setNUbatch(int nUbatch) { - parameters.put(PARAM_N_UBATCH, String.valueOf(nUbatch)); - return this; - } - - /** - * Set the number of tokens to keep from the initial prompt (default: 0, -1 = all) - */ - public ModelParameters setNKeep(int nKeep) { - parameters.put(PARAM_N_KEEP, String.valueOf(nKeep)); - return this; - } - - /** - * Set the number of tokens to draft for speculative decoding (default: 5) - */ - public ModelParameters setNDraft(int nDraft) { - parameters.put(PARAM_N_DRAFT, String.valueOf(nDraft)); - return this; - } - - /** - * Set the maximal number of chunks to process (default: -1, -1 = all) - */ - public ModelParameters setNChunks(int nChunks) { - parameters.put(PARAM_N_CHUNKS, String.valueOf(nChunks)); - return this; - } - - /** - * Set the number of parallel sequences to decode (default: 1) - */ - public ModelParameters setNParallel(int nParallel) { - parameters.put(PARAM_N_PARALLEL, String.valueOf(nParallel)); - return this; - } - - /** - * Set the number of sequences to decode (default: 1) - */ - public ModelParameters setNSequences(int nSequences) { - parameters.put(PARAM_N_SEQUENCES, String.valueOf(nSequences)); - return this; - } - - /** - * Set the speculative decoding split probability (default: 0.1) - */ - public ModelParameters setPSplit(float pSplit) { - parameters.put(PARAM_P_SPLIT, String.valueOf(pSplit)); - return this; - } - - /** - * Set the number of layers to store in VRAM (-1 - use default) - */ - public ModelParameters setNGpuLayers(int nGpuLayers) { - parameters.put(PARAM_N_GPU_LAYERS, String.valueOf(nGpuLayers)); - return this; - } - - /** - * Set 
the number of layers to store in VRAM for the draft model (-1 - use default) - */ - public ModelParameters setNGpuLayersDraft(int nGpuLayersDraft) { - parameters.put(PARAM_N_GPU_LAYERS_DRAFT, String.valueOf(nGpuLayersDraft)); - return this; - } - - /** - * Set how to split the model across GPUs - */ - public ModelParameters setSplitMode(GpuSplitMode splitMode) { -// switch (splitMode) { -// case NONE: parameters.put(PARAM_SPLIT_MODE, "\"none\""); break; -// case ROW: parameters.put(PARAM_SPLIT_MODE, "\"row\""); break; -// case LAYER: parameters.put(PARAM_SPLIT_MODE, "\"layer\""); break; -// } - parameters.put(PARAM_SPLIT_MODE, String.valueOf(splitMode.ordinal())); - return this; - } - - /** - * Set the GPU that is used for scratch and small tensors - */ - public ModelParameters setMainGpu(int mainGpu) { - parameters.put(PARAM_MAIN_GPU, String.valueOf(mainGpu)); - return this; - } - - /** - * Set how split tensors should be distributed across GPUs - */ - public ModelParameters setTensorSplit(float[] tensorSplit) { - if (tensorSplit.length > 0) { - StringBuilder builder = new StringBuilder(); - builder.append("["); - for (int i = 0; i < tensorSplit.length; i++) { - builder.append(tensorSplit[i]); - if (i < tensorSplit.length - 1) { - builder.append(", "); - } - } - builder.append("]"); - parameters.put(PARAM_TENSOR_SPLIT, builder.toString()); - } - return this; - } - - /** - * Set the group-attention factor (default: 1) - */ - public ModelParameters setGrpAttnN(int grpAttnN) { - parameters.put(PARAM_GRP_ATTN_N, String.valueOf(grpAttnN)); - return this; - } - - /** - * Set the group-attention width (default: 512.0) - */ - public ModelParameters setGrpAttnW(int grpAttnW) { - parameters.put(PARAM_GRP_ATTN_W, String.valueOf(grpAttnW)); - return this; - } - - /** - * Set the RoPE base frequency, used by NTK-aware scaling (default: loaded from model) - */ - public ModelParameters setRopeFreqBase(float ropeFreqBase) { - parameters.put(PARAM_ROPE_FREQ_BASE, String.valueOf(ropeFreqBase)); - return this; - } - - /** - * Set the RoPE frequency scaling factor, expands context by a factor of 1/N - */ - public ModelParameters setRopeFreqScale(float ropeFreqScale) { - parameters.put(PARAM_ROPE_FREQ_SCALE, String.valueOf(ropeFreqScale)); - return this; - } - - /** - * Set the YaRN extrapolation mix factor (default: 1.0, 0.0 = full interpolation) - */ - public ModelParameters setYarnExtFactor(float yarnExtFactor) { - parameters.put(PARAM_YARN_EXT_FACTOR, String.valueOf(yarnExtFactor)); - return this; - } - - /** - * Set the YaRN scale sqrt(t) or attention magnitude (default: 1.0) - */ - public ModelParameters setYarnAttnFactor(float yarnAttnFactor) { - parameters.put(PARAM_YARN_ATTN_FACTOR, String.valueOf(yarnAttnFactor)); - return this; - } - - /** - * Set the YaRN low correction dim or beta (default: 32.0) - */ - public ModelParameters setYarnBetaFast(float yarnBetaFast) { - parameters.put(PARAM_YARN_BETA_FAST, String.valueOf(yarnBetaFast)); - return this; - } - - /** - * Set the YaRN high correction dim or alpha (default: 1.0) - */ - public ModelParameters setYarnBetaSlow(float yarnBetaSlow) { - parameters.put(PARAM_YARN_BETA_SLOW, String.valueOf(yarnBetaSlow)); - return this; - } - - /** - * Set the YaRN original context size of model (default: 0 = model training context size) - */ - public ModelParameters setYarnOrigCtx(int yarnOrigCtx) { - parameters.put(PARAM_YARN_ORIG_CTX, String.valueOf(yarnOrigCtx)); - return this; - } - - /** - * Set the KV cache defragmentation threshold (default: -1.0, < 0 - 
disabled) - */ - public ModelParameters setDefragmentationThreshold(float defragThold) { - parameters.put(PARAM_DEFRAG_THOLD, String.valueOf(defragThold)); - return this; - } - - /** - * Set optimization strategies that help on some NUMA systems (if available) - *
-	 * <ul>
-	 * <li>distribute: spread execution evenly over all nodes</li>
-	 * <li>isolate: only spawn threads on CPUs on the node that execution started on</li>
-	 * <li>numactl: use the CPU map provided by numactl</li>
-	 * </ul>
- * If run without this previously, it is recommended to drop the system page cache before using this - * (see #1437). - */ - public ModelParameters setNuma(NumaStrategy numa) { -// switch (numa) { -// case DISTRIBUTE: -// parameters.put(PARAM_NUMA, "\"distribute\""); -// break; -// case ISOLATE: -// parameters.put(PARAM_NUMA, "\"isolate\""); -// break; -// case NUMA_CTL: -// parameters.put(PARAM_NUMA, "\"numactl\""); -// break; -// case MIRROR: -// parameters.put(PARAM_NUMA, "\"mirror\""); -// break; -// } - parameters.put(PARAM_NUMA, String.valueOf(numa.ordinal())); - return this; - } - - /** - * Set the RoPE frequency scaling method, defaults to linear unless specified by the model - */ - public ModelParameters setRopeScalingType(RopeScalingType ropeScalingType) { -// switch (ropeScalingType) { -// case LINEAR: -// parameters.put(PARAM_ROPE_SCALING_TYPE, "\"linear\""); -// break; -// case YARN: -// parameters.put(PARAM_ROPE_SCALING_TYPE, "\"yarn\""); -// break; -// } - parameters.put(PARAM_ROPE_SCALING_TYPE, String.valueOf(ropeScalingType.ordinal())); - return this; - } - - /** - * Set the pooling type for embeddings, use model default if unspecified - */ - public ModelParameters setPoolingType(PoolingType poolingType) { -// switch (poolingType) { -// case MEAN: -// parameters.put(PARAM_POOLING_TYPE, "\"mean\""); -// break; -// case CLS: -// parameters.put(PARAM_POOLING_TYPE, "\"cls\""); -// break; -// } - parameters.put(PARAM_POOLING_TYPE, String.valueOf(poolingType.ordinal())); - return this; - } - - /** - * Set the model file path to load (default: models/7B/ggml-model-f16.gguf) - */ - public ModelParameters setModelFilePath(String model) { - parameters.put(PARAM_MODEL, toJsonString(model)); - return this; - } - - /** - * Set the draft model for speculative decoding (default: unused) - */ - public ModelParameters setModelDraft(String modelDraft) { - parameters.put(PARAM_MODEL_DRAFT, toJsonString(modelDraft)); - return this; - } - - /** - * Set a model alias - */ - public ModelParameters setModelAlias(String modelAlias) { - parameters.put(PARAM_MODEL_ALIAS, toJsonString(modelAlias)); - return this; - } - - /** - * Set a URL to download a model from (default: unused). - * Note, that this requires the library to be built with CURL (-DLLAMA_CURL=ON). 
- */ - public ModelParameters setModelUrl(String modelUrl) { - parameters.put(PARAM_MODEL_URL, toJsonString(modelUrl)); - return this; - } - - /** - * Set a Hugging Face model repository to use a model from (default: unused, see - * {@link #setHuggingFaceFile(String)}) - */ - public ModelParameters setHuggingFaceRepository(String hfRepo) { - parameters.put(PARAM_HF_REPO, toJsonString(hfRepo)); - return this; - } - - /** - * Set a Hugging Face model file to use (default: unused, see {@link #setHuggingFaceRepository(String)}) - */ - public ModelParameters setHuggingFaceFile(String hfFile) { - parameters.put(PARAM_HF_FILE, toJsonString(hfFile)); - return this; - } - - /** - * Set path to static lookup cache to use for lookup decoding (not updated by generation) - */ - public ModelParameters setLookupCacheStaticFilePath(String lookupCacheStatic) { - parameters.put(PARAM_LOOKUP_CACHE_STATIC, toJsonString(lookupCacheStatic)); - return this; - } - - /** - * Set path to dynamic lookup cache to use for lookup decoding (updated by generation) - */ - public ModelParameters setLookupCacheDynamicFilePath(String lookupCacheDynamic) { - parameters.put(PARAM_LOOKUP_CACHE_DYNAMIC, toJsonString(lookupCacheDynamic)); - return this; - } - - /** - * Set LoRA adapters to use (implies --no-mmap). - * The key is expected to be a file path, the values are expected to be scales. - */ - public ModelParameters setLoraAdapters(Map loraAdapters) { - if (!loraAdapters.isEmpty()) { - StringBuilder builder = new StringBuilder(); - builder.append("{"); - int i = 0; - for (Map.Entry entry : loraAdapters.entrySet()) { - String key = entry.getKey(); - Float value = entry.getValue(); - builder.append(toJsonString(key)) - .append(": ") - .append(value); - if (i++ < loraAdapters.size() - 1) { - builder.append(", "); - } - } - builder.append("}"); - parameters.put(PARAM_LORA_ADAPTER, builder.toString()); - } - return this; - } - - /** - * Whether to load model with embedding support - */ - public ModelParameters setEmbedding(boolean embedding) { - parameters.put(PARAM_EMBEDDING, String.valueOf(embedding)); - return this; - } - - /** - * Whether to enable continuous batching (also called "dynamic batching") (default: disabled) - */ - public ModelParameters setContinuousBatching(boolean contBatching) { - parameters.put(PARAM_CONT_BATCHING, String.valueOf(contBatching)); - return this; - } - - /** - * Whether to enable Flash Attention (default: disabled) - */ - public ModelParameters setFlashAttention(boolean flashAttention) { - parameters.put(PARAM_FLASH_ATTENTION, String.valueOf(flashAttention)); - return this; - } - - /** - * Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string - */ - public ModelParameters setInputPrefixBos(boolean inputPrefixBos) { - parameters.put(PARAM_INPUT_PREFIX_BOS, String.valueOf(inputPrefixBos)); - return this; - } - - /** - * Whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf) - */ - public ModelParameters setIgnoreEos(boolean ignoreEos) { - parameters.put(PARAM_IGNORE_EOS, String.valueOf(ignoreEos)); - return this; - } - - /** - * Whether to use memory-map model (faster load but may increase pageouts if not using mlock) - */ - public ModelParameters setUseMmap(boolean useMmap) { - parameters.put(PARAM_USE_MMAP, String.valueOf(useMmap)); - return this; - } - - /** - * Whether to force the system to keep model in RAM rather than swapping or compressing - */ - public ModelParameters setUseMlock(boolean useMlock) { - 
parameters.put(PARAM_USE_MLOCK, String.valueOf(useMlock)); - return this; - } - - /** - * Whether to disable KV offload - */ - public ModelParameters setNoKvOffload(boolean noKvOffload) { - parameters.put(PARAM_NO_KV_OFFLOAD, String.valueOf(noKvOffload)); - return this; - } - - /** - * Set a system prompt to use - */ - public ModelParameters setSystemPrompt(String systemPrompt) { - parameters.put(PARAM_SYSTEM_PROMPT, toJsonString(systemPrompt)); - return this; - } - - /** - * The chat template to use (default: empty) - */ - public ModelParameters setChatTemplate(String chatTemplate) { - parameters.put(PARAM_CHAT_TEMPLATE, toJsonString(chatTemplate)); - return this; - } +@SuppressWarnings("unused") +public final class ModelParameters extends CliParameters { + + /** + * Set the number of threads to use during generation (default: -1). + */ + public ModelParameters setThreads(int nThreads) { + parameters.put("--threads", String.valueOf(nThreads)); + return this; + } + + /** + * Set the number of threads to use during batch and prompt processing (default: same as --threads). + */ + public ModelParameters setThreadsBatch(int nThreads) { + parameters.put("--threads-batch", String.valueOf(nThreads)); + return this; + } + + /** + * Set the CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: ""). + */ + public ModelParameters setCpuMask(String mask) { + parameters.put("--cpu-mask", mask); + return this; + } + + /** + * Set the range of CPUs for affinity. Complements --cpu-mask. + */ + public ModelParameters setCpuRange(String range) { + parameters.put("--cpu-range", range); + return this; + } + + /** + * Use strict CPU placement (default: 0). + */ + public ModelParameters setCpuStrict(int strictCpu) { + parameters.put("--cpu-strict", String.valueOf(strictCpu)); + return this; + } + + /** + * Set process/thread priority: 0-normal, 1-medium, 2-high, 3-realtime (default: 0). + */ + public ModelParameters setPriority(int priority) { + if (priority < 0 || priority > 3) { + throw new IllegalArgumentException("Invalid value for priority"); + } + parameters.put("--prio", String.valueOf(priority)); + return this; + } + + /** + * Set the polling level to wait for work (0 - no polling, default: 0). + */ + public ModelParameters setPoll(int poll) { + parameters.put("--poll", String.valueOf(poll)); + return this; + } + + /** + * Set the CPU affinity mask for batch processing: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask). + */ + public ModelParameters setCpuMaskBatch(String mask) { + parameters.put("--cpu-mask-batch", mask); + return this; + } + + /** + * Set the ranges of CPUs for batch affinity. Complements --cpu-mask-batch. + */ + public ModelParameters setCpuRangeBatch(String range) { + parameters.put("--cpu-range-batch", range); + return this; + } + + /** + * Use strict CPU placement for batch processing (default: same as --cpu-strict). + */ + public ModelParameters setCpuStrictBatch(int strictCpuBatch) { + parameters.put("--cpu-strict-batch", String.valueOf(strictCpuBatch)); + return this; + } + + /** + * Set process/thread priority for batch processing: 0-normal, 1-medium, 2-high, 3-realtime (default: 0). 
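+	 * Values outside the 0-3 range are rejected with an IllegalArgumentException.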
+ */ + public ModelParameters setPriorityBatch(int priorityBatch) { + if (priorityBatch < 0 || priorityBatch > 3) { + throw new IllegalArgumentException("Invalid value for priority batch"); + } + parameters.put("--prio-batch", String.valueOf(priorityBatch)); + return this; + } + + /** + * Set the polling level for batch processing (default: same as --poll). + */ + public ModelParameters setPollBatch(int pollBatch) { + parameters.put("--poll-batch", String.valueOf(pollBatch)); + return this; + } + + /** + * Set the size of the prompt context (default: 0, 0 = loaded from model). + */ + public ModelParameters setCtxSize(int ctxSize) { + parameters.put("--ctx-size", String.valueOf(ctxSize)); + return this; + } + + /** + * Set the number of tokens to predict (default: -1 = infinity, -2 = until context filled). + */ + public ModelParameters setPredict(int nPredict) { + parameters.put("--predict", String.valueOf(nPredict)); + return this; + } + + /** + * Set the logical maximum batch size (default: 0). + */ + public ModelParameters setBatchSize(int batchSize) { + parameters.put("--batch-size", String.valueOf(batchSize)); + return this; + } + + /** + * Set the physical maximum batch size (default: 0). + */ + public ModelParameters setUbatchSize(int ubatchSize) { + parameters.put("--ubatch-size", String.valueOf(ubatchSize)); + return this; + } + + /** + * Set the number of tokens to keep from the initial prompt (default: -1 = all). + */ + public ModelParameters setKeep(int keep) { + parameters.put("--keep", String.valueOf(keep)); + return this; + } + + /** + * Disable context shift on infinite text generation (default: enabled). + */ + public ModelParameters disableContextShift() { + parameters.put("--no-context-shift", null); + return this; + } + + /** + * Enable Flash Attention (default: disabled). + */ + public ModelParameters enableFlashAttn() { + parameters.put("--flash-attn", null); + return this; + } + + /** + * Disable internal libllama performance timings (default: false). + */ + public ModelParameters disablePerf() { + parameters.put("--no-perf", null); + return this; + } + + /** + * Process escape sequences (default: true). + */ + public ModelParameters enableEscape() { + parameters.put("--escape", null); + return this; + } + + /** + * Do not process escape sequences (default: false). + */ + public ModelParameters disableEscape() { + parameters.put("--no-escape", null); + return this; + } + + /** + * Enable special tokens output (default: true). + */ + public ModelParameters enableSpecial() { + parameters.put("--special", null); + return this; + } + + /** + * Skip warming up the model with an empty run (default: false). + */ + public ModelParameters skipWarmup() { + parameters.put("--no-warmup", null); + return this; + } + + /** + * Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. + * (default: disabled) + */ + public ModelParameters setSpmInfill() { + parameters.put("--spm-infill", null); + return this; + } + + /** + * Set samplers that will be used for generation in the order, separated by ';' (default: all). + */ + public ModelParameters setSamplers(Sampler... 
samplers) { + if (samplers.length > 0) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < samplers.length; i++) { + Sampler sampler = samplers[i]; + builder.append(sampler.name().toLowerCase()); + if (i < samplers.length - 1) { + builder.append(";"); + } + } + parameters.put("--samplers", builder.toString()); + } + return this; + } + + /** + * Set RNG seed (default: -1, use random seed). + */ + public ModelParameters setSeed(long seed) { + parameters.put("--seed", String.valueOf(seed)); + return this; + } + + /** + * Ignore end of stream token and continue generating (implies --logit-bias EOS-inf). + */ + public ModelParameters ignoreEos() { + parameters.put("--ignore-eos", null); + return this; + } + + /** + * Set temperature for sampling (default: 0.8). + */ + public ModelParameters setTemp(float temp) { + parameters.put("--temp", String.valueOf(temp)); + return this; + } + + /** + * Set top-k sampling (default: 40, 0 = disabled). + */ + public ModelParameters setTopK(int topK) { + parameters.put("--top-k", String.valueOf(topK)); + return this; + } + + /** + * Set top-p sampling (default: 0.95, 1.0 = disabled). + */ + public ModelParameters setTopP(float topP) { + parameters.put("--top-p", String.valueOf(topP)); + return this; + } + + /** + * Set min-p sampling (default: 0.05, 0.0 = disabled). + */ + public ModelParameters setMinP(float minP) { + parameters.put("--min-p", String.valueOf(minP)); + return this; + } + + /** + * Set xtc probability (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setXtcProbability(float xtcProbability) { + parameters.put("--xtc-probability", String.valueOf(xtcProbability)); + return this; + } + + /** + * Set xtc threshold (default: 0.1, 1.0 = disabled). + */ + public ModelParameters setXtcThreshold(float xtcThreshold) { + parameters.put("--xtc-threshold", String.valueOf(xtcThreshold)); + return this; + } + + /** + * Set locally typical sampling parameter p (default: 1.0, 1.0 = disabled). + */ + public ModelParameters setTypical(float typP) { + parameters.put("--typical", String.valueOf(typP)); + return this; + } + + /** + * Set last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size). + */ + public ModelParameters setRepeatLastN(int repeatLastN) { + if (repeatLastN < -1) { + throw new RuntimeException("Invalid repeat-last-n value"); + } + parameters.put("--repeat-last-n", String.valueOf(repeatLastN)); + return this; + } + + /** + * Set penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled). + */ + public ModelParameters setRepeatPenalty(float repeatPenalty) { + parameters.put("--repeat-penalty", String.valueOf(repeatPenalty)); + return this; + } + + /** + * Set repeat alpha presence penalty (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setPresencePenalty(float presencePenalty) { + parameters.put("--presence-penalty", String.valueOf(presencePenalty)); + return this; + } + + /** + * Set repeat alpha frequency penalty (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setFrequencyPenalty(float frequencyPenalty) { + parameters.put("--frequency-penalty", String.valueOf(frequencyPenalty)); + return this; + } + + /** + * Set DRY sampling multiplier (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setDryMultiplier(float dryMultiplier) { + parameters.put("--dry-multiplier", String.valueOf(dryMultiplier)); + return this; + } + + /** + * Set DRY sampling base value (default: 1.75). 
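+	 * With multiplier m and base b, the penalty for a repeated sequence of length n grows roughly as m * b^(n - allowed length).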
+	 */
+	public ModelParameters setDryBase(float dryBase) {
+		parameters.put("--dry-base", String.valueOf(dryBase));
+		return this;
+	}
+
+	/**
+	 * Set allowed length for DRY sampling (default: 2).
+	 */
+	public ModelParameters setDryAllowedLength(int dryAllowedLength) {
+		parameters.put("--dry-allowed-length", String.valueOf(dryAllowedLength));
+		return this;
+	}
+
+	/**
+	 * Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size).
+	 */
+	public ModelParameters setDryPenaltyLastN(int dryPenaltyLastN) {
+		if (dryPenaltyLastN < -1) {
+			throw new RuntimeException("Invalid dry-penalty-last-n value");
+		}
+		parameters.put("--dry-penalty-last-n", String.valueOf(dryPenaltyLastN));
+		return this;
+	}
+
+	/**
+	 * Add sequence breaker for DRY sampling, clearing out default breakers (default: none).
+	 */
+	public ModelParameters setDrySequenceBreaker(String drySequenceBreaker) {
+		parameters.put("--dry-sequence-breaker", drySequenceBreaker);
+		return this;
+	}
+
+	/**
+	 * Set dynamic temperature range (default: 0.0, 0.0 = disabled).
+	 */
+	public ModelParameters setDynatempRange(float dynatempRange) {
+		parameters.put("--dynatemp-range", String.valueOf(dynatempRange));
+		return this;
+	}
+
+	/**
+	 * Set dynamic temperature exponent (default: 1.0).
+	 */
+	public ModelParameters setDynatempExponent(float dynatempExponent) {
+		parameters.put("--dynatemp-exp", String.valueOf(dynatempExponent));
+		return this;
+	}
+
+	/**
+	 * Use Mirostat sampling (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).
+	 */
+	public ModelParameters setMirostat(MiroStat mirostat) {
+		parameters.put("--mirostat", String.valueOf(mirostat.ordinal()));
+		return this;
+	}
+
+	/**
+	 * Set Mirostat learning rate, parameter eta (default: 0.1).
+	 */
+	public ModelParameters setMirostatLR(float mirostatLR) {
+		parameters.put("--mirostat-lr", String.valueOf(mirostatLR));
+		return this;
+	}
+
+	/**
+	 * Set Mirostat target entropy, parameter tau (default: 5.0).
+	 */
+	public ModelParameters setMirostatEnt(float mirostatEnt) {
+		parameters.put("--mirostat-ent", String.valueOf(mirostatEnt));
+		return this;
+	}
+
+	/**
+	 * Modify the likelihood of a token appearing in the completion.
+	 */
+	public ModelParameters setLogitBias(String tokenIdAndBias) {
+		parameters.put("--logit-bias", tokenIdAndBias);
+		return this;
+	}
+
+	/**
+	 * Set BNF-like grammar to constrain generations (default: empty).
+	 */
+	public ModelParameters setGrammar(String grammar) {
+		parameters.put("--grammar", grammar);
+		return this;
+	}
+
+	/**
+	 * Specify the file to read grammar from.
+	 */
+	public ModelParameters setGrammarFile(String fileName) {
+		parameters.put("--grammar-file", fileName);
+		return this;
+	}
+
+	/**
+	 * Specify the JSON schema to constrain generations (default: empty).
+	 */
+	public ModelParameters setJsonSchema(String schema) {
+		parameters.put("--json-schema", schema);
+		return this;
+	}
+
+	/**
+	 * Set pooling type for embeddings (default: model default if unspecified).
+	 */
+	public ModelParameters setPoolingType(PoolingType type) {
+		parameters.put("--pooling", String.valueOf(type.getId()));
+		return this;
+	}
+
+	/**
+	 * Set RoPE frequency scaling method (default: linear unless specified by the model).
+	 */
+	public ModelParameters setRopeScaling(RopeScalingType type) {
+		parameters.put("--rope-scaling", String.valueOf(type.getId()));
+		return this;
+	}
+
+	/**
+	 * Set RoPE context scaling factor, expands context by a factor of N.
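+	 * For example, a rope scale of 2 on a model trained with a 4096-token context targets roughly 8192 tokens.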
+	 */
+	public ModelParameters setRopeScale(float ropeScale) {
+		parameters.put("--rope-scale", String.valueOf(ropeScale));
+		return this;
+	}
+
+	/**
+	 * Set RoPE base frequency, used by NTK-aware scaling (default: loaded from model).
+	 */
+	public ModelParameters setRopeFreqBase(float ropeFreqBase) {
+		parameters.put("--rope-freq-base", String.valueOf(ropeFreqBase));
+		return this;
+	}
+
+	/**
+	 * Set RoPE frequency scaling factor, expands context by a factor of 1/N.
+	 */
+	public ModelParameters setRopeFreqScale(float ropeFreqScale) {
+		parameters.put("--rope-freq-scale", String.valueOf(ropeFreqScale));
+		return this;
+	}
+
+	/**
+	 * Set YaRN: original context size of model (default: model training context size).
+	 */
+	public ModelParameters setYarnOrigCtx(int yarnOrigCtx) {
+		parameters.put("--yarn-orig-ctx", String.valueOf(yarnOrigCtx));
+		return this;
+	}
+
+	/**
+	 * Set YaRN: extrapolation mix factor (default: 0.0 = full interpolation).
+	 */
+	public ModelParameters setYarnExtFactor(float yarnExtFactor) {
+		parameters.put("--yarn-ext-factor", String.valueOf(yarnExtFactor));
+		return this;
+	}
+
+	/**
+	 * Set YaRN: scale sqrt(t) or attention magnitude (default: 1.0).
+	 */
+	public ModelParameters setYarnAttnFactor(float yarnAttnFactor) {
+		parameters.put("--yarn-attn-factor", String.valueOf(yarnAttnFactor));
+		return this;
+	}
+
+	/**
+	 * Set YaRN: high correction dim or alpha (default: 1.0).
+	 */
+	public ModelParameters setYarnBetaSlow(float yarnBetaSlow) {
+		parameters.put("--yarn-beta-slow", String.valueOf(yarnBetaSlow));
+		return this;
+	}
+
+	/**
+	 * Set YaRN: low correction dim or beta (default: 32.0).
+	 */
+	public ModelParameters setYarnBetaFast(float yarnBetaFast) {
+		parameters.put("--yarn-beta-fast", String.valueOf(yarnBetaFast));
+		return this;
+	}
+
+	/**
+	 * Set group-attention factor (default: 1).
+	 */
+	public ModelParameters setGrpAttnN(int grpAttnN) {
+		parameters.put("--grp-attn-n", String.valueOf(grpAttnN));
+		return this;
+	}
+
+	/**
+	 * Set group-attention width (default: 512).
+	 */
+	public ModelParameters setGrpAttnW(int grpAttnW) {
+		parameters.put("--grp-attn-w", String.valueOf(grpAttnW));
+		return this;
+	}
+
+	/**
+	 * Enable verbose printing of the KV cache.
+	 */
+	public ModelParameters enableDumpKvCache() {
+		parameters.put("--dump-kv-cache", null);
+		return this;
+	}
+
+	/**
+	 * Disable KV offload.
+	 */
+	public ModelParameters disableKvOffload() {
+		parameters.put("--no-kv-offload", null);
+		return this;
+	}
+
+	/**
+	 * Set KV cache data type for K (allowed values: F32, F16, BF16, Q8_0, Q4_0, Q4_1, IQ4_NL, Q5_0, Q5_1).
+	 */
+	public ModelParameters setCacheTypeK(CacheType type) {
+		parameters.put("--cache-type-k", type.name().toLowerCase());
+		return this;
+	}
+
+	/**
+	 * Set KV cache data type for V (allowed values: F32, F16, BF16, Q8_0, Q4_0, Q4_1, IQ4_NL, Q5_0, Q5_1).
+	 */
+	public ModelParameters setCacheTypeV(CacheType type) {
+		parameters.put("--cache-type-v", type.name().toLowerCase());
+		return this;
+	}
+
+	/**
+	 * Set KV cache defragmentation threshold (default: 0.1, < 0 - disabled).
+	 */
+	public ModelParameters setDefragThold(float defragThold) {
+		parameters.put("--defrag-thold", String.valueOf(defragThold));
+		return this;
+	}
+
+	/**
+	 * Set the number of parallel sequences to decode (default: 1).
+	 */
+	public ModelParameters setParallel(int nParallel) {
+		parameters.put("--parallel", String.valueOf(nParallel));
+		return this;
+	}
+
+	/**
+	 * Enable continuous batching (a.k.a. dynamic batching) (default: enabled).
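+	 * When enabled, incoming requests are merged into the running decode batch instead of waiting for it to finish.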
+ */ + public ModelParameters enableContBatching() { + parameters.put("--cont-batching", null); + return this; + } + + /** + * Disable continuous batching. + */ + public ModelParameters disableContBatching() { + parameters.put("--no-cont-batching", null); + return this; + } + + /** + * Force system to keep model in RAM rather than swapping or compressing. + */ + public ModelParameters enableMlock() { + parameters.put("--mlock", null); + return this; + } + + /** + * Do not memory-map model (slower load but may reduce pageouts if not using mlock). + */ + public ModelParameters disableMmap() { + parameters.put("--no-mmap", null); + return this; + } + + /** + * Set NUMA optimization type for system. + */ + public ModelParameters setNuma(NumaStrategy numaStrategy) { + parameters.put("--numa", numaStrategy.name().toLowerCase()); + return this; + } + + /** + * Set comma-separated list of devices to use for offloading (none = don't offload). + */ + public ModelParameters setDevices(String devices) { + parameters.put("--device", devices); + return this; + } + + /** + * Set the number of layers to store in VRAM. + */ + public ModelParameters setGpuLayers(int gpuLayers) { + parameters.put("--gpu-layers", String.valueOf(gpuLayers)); + return this; + } + + /** + * Set how to split the model across multiple GPUs (none, layer, row). + */ + public ModelParameters setSplitMode(GpuSplitMode splitMode) { + parameters.put("--split-mode", splitMode.name().toLowerCase()); + return this; + } + + /** + * Set fraction of the model to offload to each GPU, comma-separated list of proportions N0,N1,N2,.... + */ + public ModelParameters setTensorSplit(String tensorSplit) { + parameters.put("--tensor-split", tensorSplit); + return this; + } + + /** + * Set the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row). + */ + public ModelParameters setMainGpu(int mainGpu) { + parameters.put("--main-gpu", String.valueOf(mainGpu)); + return this; + } + + /** + * Enable checking model tensor data for invalid values. + */ + public ModelParameters enableCheckTensors() { + parameters.put("--check-tensors", null); + return this; + } + + /** + * Override model metadata by key. This option can be specified multiple times. + */ + public ModelParameters setOverrideKv(String keyValue) { + parameters.put("--override-kv", keyValue); + return this; + } + + /** + * Add a LoRA adapter (can be repeated to use multiple adapters). + */ + public ModelParameters addLoraAdapter(String fname) { + parameters.put("--lora", fname); + return this; + } + + /** + * Add a LoRA adapter with user-defined scaling (can be repeated to use multiple adapters). + */ + public ModelParameters addLoraScaledAdapter(String fname, float scale) { + parameters.put("--lora-scaled", fname + "," + scale); + return this; + } + + /** + * Add a control vector (this argument can be repeated to add multiple control vectors). + */ + public ModelParameters addControlVector(String fname) { + parameters.put("--control-vector", fname); + return this; + } + + /** + * Add a control vector with user-defined scaling (can be repeated to add multiple scaled control vectors). + */ + public ModelParameters addControlVectorScaled(String fname, float scale) { + parameters.put("--control-vector-scaled", fname + "," + scale); + return this; + } + + /** + * Set the layer range to apply the control vector(s) to (start and end inclusive). 
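+	 * For example, setControlVectorLayerRange(10, 20) applies the vector(s) to layers 10 through 20.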
+ */ + public ModelParameters setControlVectorLayerRange(int start, int end) { + parameters.put("--control-vector-layer-range", start + "," + end); + return this; + } + + /** + * Set the model path from which to load the base model. + */ + public ModelParameters setModel(String model) { + parameters.put("--model", model); + return this; + } + + /** + * Set the model download URL (default: unused). + */ + public ModelParameters setModelUrl(String modelUrl) { + parameters.put("--model-url", modelUrl); + return this; + } + + /** + * Set the Hugging Face model repository (default: unused). + */ + public ModelParameters setHfRepo(String hfRepo) { + parameters.put("--hf-repo", hfRepo); + return this; + } + + /** + * Set the Hugging Face model file (default: unused). + */ + public ModelParameters setHfFile(String hfFile) { + parameters.put("--hf-file", hfFile); + return this; + } + + /** + * Set the Hugging Face model repository for the vocoder model (default: unused). + */ + public ModelParameters setHfRepoV(String hfRepoV) { + parameters.put("--hf-repo-v", hfRepoV); + return this; + } + + /** + * Set the Hugging Face model file for the vocoder model (default: unused). + */ + public ModelParameters setHfFileV(String hfFileV) { + parameters.put("--hf-file-v", hfFileV); + return this; + } + + /** + * Set the Hugging Face access token (default: value from HF_TOKEN environment variable). + */ + public ModelParameters setHfToken(String hfToken) { + parameters.put("--hf-token", hfToken); + return this; + } + + /** + * Enable embedding use case; use only with dedicated embedding models. + */ + public ModelParameters enableEmbedding() { + parameters.put("--embedding", null); + return this; + } + + /** + * Enable reranking endpoint on server. + */ + public ModelParameters enableReranking() { + parameters.put("--reranking", null); + return this; + } + + /** + * Set minimum chunk size to attempt reusing from the cache via KV shifting. + */ + public ModelParameters setCacheReuse(int cacheReuse) { + parameters.put("--cache-reuse", String.valueOf(cacheReuse)); + return this; + } + + /** + * Set the path to save the slot kv cache. + */ + public ModelParameters setSlotSavePath(String slotSavePath) { + parameters.put("--slot-save-path", slotSavePath); + return this; + } + + /** + * Set custom jinja chat template. + */ + public ModelParameters setChatTemplate(String chatTemplate) { + parameters.put("--chat-template", chatTemplate); + return this; + } + + /** + * Set how much the prompt of a request must match the prompt of a slot in order to use that slot. + */ + public ModelParameters setSlotPromptSimilarity(float similarity) { + parameters.put("--slot-prompt-similarity", String.valueOf(similarity)); + return this; + } + + /** + * Load LoRA adapters without applying them (apply later via POST /lora-adapters). + */ + public ModelParameters setLoraInitWithoutApply() { + parameters.put("--lora-init-without-apply", null); + return this; + } + + /** + * Disable logging. + */ + public ModelParameters disableLog() { + parameters.put("--log-disable", null); + return this; + } + + /** + * Set the log file path. + */ + public ModelParameters setLogFile(String logFile) { + parameters.put("--log-file", logFile); + return this; + } + + /** + * Set verbosity level to infinity (log all messages, useful for debugging). + */ + public ModelParameters setVerbose() { + parameters.put("--verbose", null); + return this; + } + + /** + * Set the verbosity threshold (messages with a higher verbosity will be ignored). 
+ */ + public ModelParameters setLogVerbosity(int verbosity) { + parameters.put("--log-verbosity", String.valueOf(verbosity)); + return this; + } + + /** + * Enable prefix in log messages. + */ + public ModelParameters enableLogPrefix() { + parameters.put("--log-prefix", null); + return this; + } + + /** + * Enable timestamps in log messages. + */ + public ModelParameters enableLogTimestamps() { + parameters.put("--log-timestamps", null); + return this; + } + + /** + * Set the number of tokens to draft for speculative decoding. + */ + public ModelParameters setDraftMax(int draftMax) { + parameters.put("--draft-max", String.valueOf(draftMax)); + return this; + } + + /** + * Set the minimum number of draft tokens to use for speculative decoding. + */ + public ModelParameters setDraftMin(int draftMin) { + parameters.put("--draft-min", String.valueOf(draftMin)); + return this; + } + + /** + * Set the minimum speculative decoding probability for greedy decoding. + */ + public ModelParameters setDraftPMin(float draftPMin) { + parameters.put("--draft-p-min", String.valueOf(draftPMin)); + return this; + } + + /** + * Set the size of the prompt context for the draft model. + */ + public ModelParameters setCtxSizeDraft(int ctxSizeDraft) { + parameters.put("--ctx-size-draft", String.valueOf(ctxSizeDraft)); + return this; + } + + /** + * Set the comma-separated list of devices to use for offloading the draft model. + */ + public ModelParameters setDeviceDraft(String deviceDraft) { + parameters.put("--device-draft", deviceDraft); + return this; + } + + /** + * Set the number of layers to store in VRAM for the draft model. + */ + public ModelParameters setGpuLayersDraft(int gpuLayersDraft) { + parameters.put("--gpu-layers-draft", String.valueOf(gpuLayersDraft)); + return this; + } + + /** + * Set the draft model for speculative decoding. 
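+	 * The draft model must be vocabulary-compatible with the main model; a small, fast model of the same family is the usual choice.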
+ */ + public ModelParameters setModelDraft(String modelDraft) { + parameters.put("--model-draft", modelDraft); + return this; + } } diff --git a/src/main/java/de/kherud/llama/args/CacheType.java b/src/main/java/de/kherud/llama/args/CacheType.java new file mode 100644 index 00000000..8404ed75 --- /dev/null +++ b/src/main/java/de/kherud/llama/args/CacheType.java @@ -0,0 +1,15 @@ +package de.kherud.llama.args; + +public enum CacheType { + + F32, + F16, + BF16, + Q8_0, + Q4_0, + Q4_1, + IQ4_NL, + Q5_0, + Q5_1 + +} diff --git a/src/main/java/de/kherud/llama/args/NumaStrategy.java b/src/main/java/de/kherud/llama/args/NumaStrategy.java index 35b24e19..fa7a61b0 100644 --- a/src/main/java/de/kherud/llama/args/NumaStrategy.java +++ b/src/main/java/de/kherud/llama/args/NumaStrategy.java @@ -2,9 +2,7 @@ public enum NumaStrategy { - DISABLED, DISTRIBUTE, ISOLATE, - NUMA_CTL, - MIRROR + NUMACTL } diff --git a/src/main/java/de/kherud/llama/args/PoolingType.java b/src/main/java/de/kherud/llama/args/PoolingType.java index e9b441d4..a9c9dbae 100644 --- a/src/main/java/de/kherud/llama/args/PoolingType.java +++ b/src/main/java/de/kherud/llama/args/PoolingType.java @@ -2,7 +2,20 @@ public enum PoolingType { - UNSPECIFIED, - MEAN, - CLS + UNSPECIFIED(-1), + NONE(0), + MEAN(1), + CLS(2), + LAST(3), + RANK(4); + + private final int id; + + PoolingType(int value) { + this.id = value; + } + + public int getId() { + return id; + } } diff --git a/src/main/java/de/kherud/llama/args/RopeScalingType.java b/src/main/java/de/kherud/llama/args/RopeScalingType.java index a69596f5..eed939a1 100644 --- a/src/main/java/de/kherud/llama/args/RopeScalingType.java +++ b/src/main/java/de/kherud/llama/args/RopeScalingType.java @@ -2,7 +2,20 @@ public enum RopeScalingType { - UNSPECIFIED, - LINEAR, - YARN + UNSPECIFIED(-1), + NONE(0), + LINEAR(1), + YARN2(2), + LONGROPE(3), + MAX_VALUE(3); + + private final int id; + + RopeScalingType(int value) { + this.id = value; + } + + public int getId() { + return id; + } } diff --git a/src/main/java/de/kherud/llama/args/Sampler.java b/src/main/java/de/kherud/llama/args/Sampler.java index 0864e91b..564a2e6f 100644 --- a/src/main/java/de/kherud/llama/args/Sampler.java +++ b/src/main/java/de/kherud/llama/args/Sampler.java @@ -2,10 +2,14 @@ public enum Sampler { - TOP_K, - TFS_Z, - TYPICAL_P, - TOP_P, - MIN_P, - TEMPERATURE + DRY, + TOP_K, + TOP_P, + TYP_P, + MIN_P, + TEMPERATURE, + XTC, + INFILL, + PENALTIES + } diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index b5481cef..f4fbb0d6 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -15,7 +15,7 @@ public class LlamaModelTest { private static final String prefix = "def remove_non_ascii(s: str) -> str:\n \"\"\" "; private static final String suffix = "\n return result\n"; - private static final int nPredict = 10; + private static final int nPredict = 1024; private static LlamaModel model; @@ -24,11 +24,11 @@ public static void setup() { // LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> System.out.println(level + ": " + msg)); model = new LlamaModel( new ModelParameters() - .setNCtx(128) - .setModelFilePath("models/codellama-7b.Q2_K.gguf") +// .setCtxSize(128) + .setModel("/Users/vrao/Work/ml/llm_models/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf") // .setModelUrl("https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf") - .setNGpuLayers(43) - .setEmbedding(true) + 
.setGpuLayers(43) + .enableEmbedding().enableLogTimestamps().enableLogPrefix() ); } @@ -155,7 +155,7 @@ public void testCancelGenerating() { @Test public void testEmbedding() { float[] embedding = model.embed(prefix); - Assert.assertEquals(4096, embedding.length); + Assert.assertEquals(1536, embedding.length); } @Test @@ -164,10 +164,10 @@ public void testTokenization() { int[] encoded = model.encode(prompt); String decoded = model.decode(encoded); // the llama tokenizer adds a space before the prompt - Assert.assertEquals(" " + prompt, decoded); + Assert.assertEquals(prompt, decoded); } - @Test + @Ignore public void testLogText() { List messages = new ArrayList<>(); LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> messages.add(new LogMessage(level, msg))); @@ -186,7 +186,7 @@ public void testLogText() { } } - @Test + @Ignore public void testLogJSON() { List messages = new ArrayList<>(); LlamaModel.setLogger(LogFormat.JSON, (level, msg) -> messages.add(new LogMessage(level, msg))); diff --git a/src/test/java/examples/GrammarExample.java b/src/test/java/examples/GrammarExample.java index a2fec2fb..d90de206 100644 --- a/src/test/java/examples/GrammarExample.java +++ b/src/test/java/examples/GrammarExample.java @@ -13,7 +13,7 @@ public static void main(String... args) { "expr ::= term ([-+*/] term)*\n" + "term ::= [0-9]"; ModelParameters modelParams = new ModelParameters() - .setModelFilePath("models/mistral-7b-instruct-v0.2.Q2_K.gguf"); + .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf"); InferenceParameters inferParams = new InferenceParameters("") .setGrammar(grammar); try (LlamaModel model = new LlamaModel(modelParams)) { diff --git a/src/test/java/examples/InfillExample.java b/src/test/java/examples/InfillExample.java index b73eeb0f..e13ecb7c 100644 --- a/src/test/java/examples/InfillExample.java +++ b/src/test/java/examples/InfillExample.java @@ -9,8 +9,8 @@ public class InfillExample { public static void main(String... args) { ModelParameters modelParams = new ModelParameters() - .setModelFilePath("models/codellama-7b.Q2_K.gguf") - .setNGpuLayers(43); + .setModel("models/codellama-7b.Q2_K.gguf") + .setGpuLayers(43); String prefix = "def remove_non_ascii(s: str) -> str:\n \"\"\" "; String suffix = "\n return result\n"; diff --git a/src/test/java/examples/MainExample.java b/src/test/java/examples/MainExample.java index 92581144..2b5150a5 100644 --- a/src/test/java/examples/MainExample.java +++ b/src/test/java/examples/MainExample.java @@ -16,8 +16,8 @@ public class MainExample { public static void main(String... 
args) throws IOException { ModelParameters modelParams = new ModelParameters() - .setModelFilePath("models/mistral-7b-instruct-v0.2.Q2_K.gguf") - .setNGpuLayers(43); + .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf") + .setGpuLayers(43); String system = "This is a conversation between User and Llama, a friendly chatbot.\n" + "Llama is helpful, kind, honest, good at writing, and never fails to answer any " + "requests immediately and with precision.\n\n" + From a718e2e1c5613309ea51e8a225e10e0c7887136a Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Wed, 12 Feb 2025 11:08:13 -0800 Subject: [PATCH 02/51] replacing local model with modelWithUri --- src/main/cpp/jllama.cpp | 8 +++--- .../java/de/kherud/llama/LlamaModelTest.java | 25 +++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 29568727..c5dbfa17 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -554,7 +554,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) server_task_result_ptr result = ctx_server->queue_results.recv(id_task); - + if (result->is_error()) { std::string response = result->to_json()["message"].get(); @@ -563,6 +563,9 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE return nullptr; } const auto out_res = result->to_json(); + + + std::string response = out_res["content"].get(); if (result->is_stop()) { @@ -588,9 +591,6 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE } } } - - ctx_server->queue_results.remove_waiting_task_id(id_task); - jbyteArray jbytes = parse_jbytes(env, response); return env->NewObject(c_output, cc_output, jbytes, o_probabilities, result->is_stop()); } diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index f4fbb0d6..ae8ada74 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -15,7 +15,7 @@ public class LlamaModelTest { private static final String prefix = "def remove_non_ascii(s: str) -> str:\n \"\"\" "; private static final String suffix = "\n return result\n"; - private static final int nPredict = 1024; + private static final int nPredict = 10; private static LlamaModel model; @@ -24,9 +24,8 @@ public static void setup() { // LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> System.out.println(level + ": " + msg)); model = new LlamaModel( new ModelParameters() -// .setCtxSize(128) - .setModel("/Users/vrao/Work/ml/llm_models/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf") -// .setModelUrl("https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf") + .setCtxSize(128) + .setModelUrl("https://huggingface.co/bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q2_K.gguf") .setGpuLayers(43) .enableEmbedding().enableLogTimestamps().enableLogPrefix() ); @@ -43,7 +42,7 @@ public static void tearDown() { public void testGenerateAnswer() { Map logitBias = new HashMap<>(); logitBias.put(2, 2.0f); - InferenceParameters params = new InferenceParameters(prefix) + InferenceParameters params = new InferenceParameters("<|User|> " + prefix +" <|Assistant|> ") .setTemperature(0.95f) .setStopStrings("\"\"\"") .setNPredict(nPredict) @@ -62,8 +61,8 @@ public void testGenerateInfill() { Map logitBias = new HashMap<>(); 
 		logitBias.put(2, 2.0f);
 		InferenceParameters params = new InferenceParameters("")
-				.setInputPrefix(prefix)
-				.setInputSuffix(suffix)
+				.setInputPrefix("<|User|> " + prefix + " <|Assistant|> ")
+				.setInputSuffix(suffix)
 				.setTemperature(0.95f)
 				.setStopStrings("\"\"\"")
 				.setNPredict(nPredict)
@@ -97,7 +96,7 @@ public void testCompleteAnswer() {
 		Map<Integer, Float> logitBias = new HashMap<>();
 		logitBias.put(2, 2.0f);
-		InferenceParameters params = new InferenceParameters(prefix)
+		InferenceParameters params = new InferenceParameters("<|User|> " + prefix + " <|Assistant|> ")
 				.setTemperature(0.95f)
 				.setStopStrings("\"\"\"")
 				.setNPredict(nPredict)
@@ -113,7 +112,7 @@ public void testCompleteInfillCustom() {
 		Map<Integer, Float> logitBias = new HashMap<>();
 		logitBias.put(2, 2.0f);
 		InferenceParameters params = new InferenceParameters("")
-				.setInputPrefix(prefix)
+				.setInputPrefix("<|User|> " + prefix + " <|Assistant|> ")
 				.setInputSuffix(suffix)
 				.setTemperature(0.95f)
 				.setStopStrings("\"\"\"")
@@ -138,7 +137,7 @@ public void testCompleteGrammar() {
 
 	@Test
 	public void testCancelGenerating() {
-		InferenceParameters params = new InferenceParameters(prefix).setNPredict(nPredict);
+		InferenceParameters params = new InferenceParameters("<|User|> " + prefix + " <|Assistant|> ").setNPredict(nPredict);
 
 		int generated = 0;
 		LlamaIterator iterator = model.generate(params).iterator();
@@ -172,7 +171,7 @@ public void testLogText() {
 		List<LogMessage> messages = new ArrayList<>();
 		LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> messages.add(new LogMessage(level, msg)));
 
-		InferenceParameters params = new InferenceParameters(prefix)
+		InferenceParameters params = new InferenceParameters("<|User|> " + prefix + " <|Assistant|> ")
 				.setNPredict(nPredict)
 				.setSeed(42);
 		model.complete(params);
@@ -191,7 +190,7 @@ public void testLogJSON() {
 		List<LogMessage> messages = new ArrayList<>();
 		LlamaModel.setLogger(LogFormat.JSON, (level, msg) -> messages.add(new LogMessage(level, msg)));
 
-		InferenceParameters params = new InferenceParameters(prefix)
+		InferenceParameters params = new InferenceParameters("<|User|> " + prefix + " <|Assistant|> ")
 				.setNPredict(nPredict)
 				.setSeed(42);
 		model.complete(params);
@@ -208,7 +207,7 @@ public void testLogStdout() {
 		// Unfortunately, `printf` can't easily be redirected to Java, so this test only works manually.
-		InferenceParameters params = new InferenceParameters(prefix)
+		InferenceParameters params = new InferenceParameters("<|User|> " + prefix + " <|Assistant|> ")
 				.setNPredict(nPredict)
 				.setSeed(42);
 
From 5745611ce90e63a159e7718895cec4e91d541cdd Mon Sep 17 00:00:00 2001
From: Vaijanath Rao
Date: Wed, 12 Feb 2025 12:34:20 -0800
Subject: [PATCH 03/51] updating version and readme and parameter.

---
 README.md | 6 +++---
 pom.xml   | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 718ec4be..341e740c 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ Access this library via Maven:
 <dependency>
 	<groupId>de.kherud</groupId>
 	<artifactId>llama</artifactId>
-	<version>3.4.1</version>
+	<version>3.4.2</version>
 </dependency>
 ```
 
@@ -37,7 +37,7 @@ By default the library artifact is built only with CPU inference support
 <dependency>
 	<groupId>de.kherud</groupId>
 	<artifactId>llama</artifactId>
-	<version>3.4.1</version>
+	<version>3.4.2</version>
 	<classifier>cuda12-linux-x86-64</classifier>
 </dependency>
 ```
 
@@ -78,7 +78,7 @@ cmake --build build --config Release
 ```
 
 > [!TIP]
-> Use `-DGGML_CURL=ON` to download models via Java code using `ModelParameters#setModelUrl(String)`.
+> Use `-DLLAMA_CURL=ON` to download models via Java code using `ModelParameters#setModelUrl(String)`.
All compiled libraries will be put in a resources directory matching your platform, which will appear in the cmake output. For example something like: diff --git a/pom.xml b/pom.xml index 68674de9..a086bef1 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ de.kherud llama - 3.4.1 + 3.4.2 jar ${project.groupId}:${project.artifactId} From 091337388595a007285b04a2f1433084e04aba06 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Thu, 13 Feb 2025 09:24:19 -0800 Subject: [PATCH 04/51] adding releaseTask and updated test to match workflow --- src/main/cpp/jllama.cpp | 10 ++++++++++ src/main/cpp/jllama.h | 9 +++++++++ src/main/java/de/kherud/llama/LlamaModel.java | 3 +++ src/test/java/de/kherud/llama/LlamaModelTest.java | 6 +++--- 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index c5dbfa17..00eccbb7 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -548,6 +548,13 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv return *task_ids.begin(); } +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *env, jobject obj, jint id_task) +{ + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + ctx_server->queue_results.remove_waiting_task_id(id_task); +} + JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *env, jobject obj, jint id_task) { jlong server_handle = env->GetLongField(obj, f_model_pointer); @@ -722,6 +729,9 @@ JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv return parse_jbytes(env, text); } + + + JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *env, jobject obj) { jlong server_handle = env->GetLongField(obj, f_model_pointer); diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h index 0ab39ea4..39048686 100644 --- a/src/main/cpp/jllama.h +++ b/src/main/cpp/jllama.h @@ -97,6 +97,15 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel */ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete (JNIEnv *, jobject); + + +/* + * Class: de_kherud_llama_LlamaModel + * Method: releaseTask + * Signature: ()V + */ +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask + (JNIEnv *, jobject, jint); #ifdef __cplusplus } diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index 1e8878c0..fc0e70fa 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -54,6 +54,7 @@ public String complete(InferenceParameters parameters) { parameters.setStream(false); int taskId = requestCompletion(parameters.toString()); LlamaOutput output = receiveCompletion(taskId); + releaseTask(taskId); return output.text; } @@ -129,5 +130,7 @@ public void close() { private native void loadModel(String... 
parameters) throws LlamaException; private native void delete(); + + private native void releaseTask(int taskId); } diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index ae8ada74..35f3b092 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -25,7 +25,7 @@ public static void setup() { model = new LlamaModel( new ModelParameters() .setCtxSize(128) - .setModelUrl("https://huggingface.co/bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q2_K.gguf") + .setModelUrl("https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf") .setGpuLayers(43) .enableEmbedding().enableLogTimestamps().enableLogPrefix() ); @@ -154,7 +154,7 @@ public void testCancelGenerating() { @Test public void testEmbedding() { float[] embedding = model.embed(prefix); - Assert.assertEquals(1536, embedding.length); + Assert.assertEquals(4096, embedding.length); } @Test @@ -163,7 +163,7 @@ public void testTokenization() { int[] encoded = model.encode(prompt); String decoded = model.decode(encoded); // the llama tokenizer adds a space before the prompt - Assert.assertEquals(prompt, decoded); + Assert.assertEquals(" " +prompt, decoded); } @Ignore From 7c54bd386257a0e9200c2c5d1459195b3f38957b Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Thu, 13 Feb 2025 10:33:25 -0800 Subject: [PATCH 05/51] replacing the modelPath --- src/test/java/de/kherud/llama/LlamaModelTest.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index 35f3b092..c757d0c3 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -25,7 +25,8 @@ public static void setup() { model = new LlamaModel( new ModelParameters() .setCtxSize(128) - .setModelUrl("https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf") + .setModel("models/codellama-7b.Q2_K.gguf") + //.setModelUrl("https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf") .setGpuLayers(43) .enableEmbedding().enableLogTimestamps().enableLogPrefix() ); From d87a103a3c8d5288382ba373b093a7d09be25b66 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Thu, 13 Feb 2025 14:38:52 -0800 Subject: [PATCH 06/51] adding chat format and LLAMA_CURL=ON to build --- .github/workflows/ci.yml | 8 ++++---- src/main/cpp/jllama.h | 18 ------------------ src/main/cpp/server.hpp | 1 + .../de/kherud/llama/InferenceParameters.java | 7 ++++++- .../java/de/kherud/llama/ModelParameters.java | 8 ++++++++ 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1db8b696..a13f5b4a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ jobs: # cmake should figure out OS and ARCH automatically when running build.sh (but we need mvn compile for it) run: | mvn compile - .github/build.sh -DLLAMA_VERBOSE=ON + .github/build.sh -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON - name: Download model run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} - name: Run tests @@ -43,11 +43,11 @@ jobs: target: - { runner: macos-13, - cmake: '-DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON' + cmake: '-DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON' } - { runner: macos-14, - cmake: '-DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_METAL=OFF 
-DLLAMA_VERBOSE=ON' + cmake: '-DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON' } steps: - uses: actions/checkout@v4 @@ -82,7 +82,7 @@ jobs: - name: Build libraries run: | mvn compile - .github\build.bat -DLLAMA_VERBOSE=ON + .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h index 39048686..fcc01486 100644 --- a/src/main/cpp/jllama.h +++ b/src/main/cpp/jllama.h @@ -8,24 +8,6 @@ extern "C" { #endif -/* - * Class: de_kherud_llama_LlamaModel - * Method: requestEmbedding - * Signature: (Ljava/lang/String;)[F - */ -JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestEmbedding - (JNIEnv *, jobject, jstring); - - -/* - * Class: de_kherud_llama_LlamaModel - * Method: receiveEmbedding - * Signature: (Ljava/lang/Int;)[F - */ -JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_receiveEmbedding - (JNIEnv *, jobject, jint); - - /* * Class: de_kherud_llama_LlamaModel * Method: embed diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index 70e7236d..beed793d 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -164,6 +164,7 @@ struct slot_params { {"grammar_trigger_words", grammar_trigger_words}, {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_format)}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java index 2c494c8c..0ac1b1dc 100644 --- a/src/main/java/de/kherud/llama/InferenceParameters.java +++ b/src/main/java/de/kherud/llama/InferenceParameters.java @@ -46,6 +46,7 @@ public final class InferenceParameters extends JsonParameters { private static final String PARAM_SAMPLERS = "samplers"; private static final String PARAM_STREAM = "stream"; private static final String PARAM_USE_CHAT_TEMPLATE = "use_chat_template"; + private static final String PARAM_USE_JINJA = "use_jinja"; public InferenceParameters(String prompt) { // we always need a prompt @@ -488,8 +489,12 @@ InferenceParameters setStream(boolean stream) { * Set whether or not generate should apply a chat template (default: false) */ public InferenceParameters setUseChatTemplate(boolean useChatTemplate) { - parameters.put(PARAM_USE_CHAT_TEMPLATE, String.valueOf(useChatTemplate)); + parameters.put(PARAM_USE_JINJA, String.valueOf(useChatTemplate)); return this; } + + + + } diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java index 91587001..8615bd50 100644 --- a/src/main/java/de/kherud/llama/ModelParameters.java +++ b/src/main/java/de/kherud/llama/ModelParameters.java @@ -950,5 +950,13 @@ public ModelParameters setModelDraft(String modelDraft) { parameters.put("--model-draft", modelDraft); return this; } + + /** + * Enable jinja for templating + */ + public ModelParameters enableJinja() { + parameters.put("--jinja", null); + return this; + } } From b7962aa0188e8e1b059e6c03170ebf3de9c35429 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Thu, 13 Feb 2025 23:47:35 -0800 Subject: [PATCH 07/51] updating version to latest. 
--- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1b5f08f3..64d3d0dc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ set(LLAMA_BUILD_COMMON ON) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b4689 + GIT_TAG b4702 ) FetchContent_MakeAvailable(llama.cpp) From dcb14ff567619ddbb076d9a0a28ad18971db0ac4 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 14 Feb 2025 01:33:24 -0800 Subject: [PATCH 08/51] reverting to older version of llamacpp --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 64d3d0dc..1b5f08f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ set(LLAMA_BUILD_COMMON ON) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b4702 + GIT_TAG b4689 ) FetchContent_MakeAvailable(llama.cpp) From e9b3d52e59ba5b15431539c675efc92e4b9f78b4 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 14 Feb 2025 22:32:15 -0800 Subject: [PATCH 09/51] adding tool support --- CMakeLists.txt | 2 +- src/main/cpp/server.hpp | 26 ++++++++++++++------------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1b5f08f3..3cf89dc6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ set(LLAMA_BUILD_COMMON ON) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b4689 + GIT_TAG b4719 ) FetchContent_MakeAvailable(llama.cpp) diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index beed793d..b435c3d4 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -705,7 +705,7 @@ struct server_task_result_cmpl_final : server_task_result { return res; } - json to_json_oaicompat_chat() { +json to_json_oaicompat_chat() { std::string finish_reason = "length"; common_chat_msg msg; if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { @@ -716,9 +716,19 @@ struct server_task_result_cmpl_final : server_task_result { msg.content = content; } - json tool_calls; + json message { + {"role", "assistant"}, + }; + if (!msg.reasoning_content.empty()) { + message["reasoning_content"] = msg.reasoning_content; + } + if (msg.content.empty() && !msg.tool_calls.empty()) { + message["content"] = json(); + } else { + message["content"] = msg.content; + } if (!msg.tool_calls.empty()) { - tool_calls = json::array(); + auto tool_calls = json::array(); for (const auto & tc : msg.tool_calls) { tool_calls.push_back({ {"type", "function"}, @@ -729,15 +739,7 @@ struct server_task_result_cmpl_final : server_task_result { {"id", tc.id}, }); } - } - - json message { - {"content", msg.content}, - {"tool_calls", tool_calls}, - {"role", "assistant"}, - }; - if (!msg.tool_plan.empty()) { - message["tool_plan"] = msg.tool_plan; + message["tool_calls"] = tool_calls; } json choice { From ea1327a0a2548f75aa3266a4d2dfc05e55a27385 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Sat, 15 Feb 2025 13:54:51 -0800 Subject: [PATCH 10/51] adding condition for Grammar --- src/main/cpp/utils.hpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp index 5ff886da..1c5e276a 100644 --- a/src/main/cpp/utils.hpp +++ b/src/main/cpp/utils.hpp @@ -642,16 +642,18 @@ static json oaicompat_completion_params_parse( llama_params["chat_format"] = static_cast(chat_params.format); 
llama_params["prompt"] = chat_params.prompt; - llama_params["grammar"] = chat_params.grammar; - llama_params["grammar_lazy"] = chat_params.grammar_lazy; - auto grammar_triggers = json::array(); - for (const auto & trigger : chat_params.grammar_triggers) { - grammar_triggers.push_back({ - {"word", trigger.word}, - {"at_start", trigger.at_start}, - }); + if (inputs.json_schema == nullptr) { + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + grammar_triggers.push_back({ + {"word", trigger.word}, + {"at_start", trigger.at_start}, + }); + } + llama_params["grammar_triggers"] = grammar_triggers; } - llama_params["grammar_triggers"] = grammar_triggers; llama_params["preserved_tokens"] = chat_params.preserved_tokens; for (const auto & stop : chat_params.additional_stops) { llama_params["stop"].push_back(stop); From 9fbebbab17c047b875eb3666ee5ada843ab4926a Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Sat, 15 Feb 2025 13:57:48 -0800 Subject: [PATCH 11/51] fixing code for apply template --- src/main/cpp/server.hpp | 46 ++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index b435c3d4..332c1edc 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -1893,31 +1893,43 @@ struct server_context { return true; } + + bool validate_jinja_templates() const { + auto templates = common_chat_templates_from_model(model, ""); + common_chat_inputs inputs; + inputs.messages = json::array({ + { + { "role", "user" }, + { "content", "test" }, + } + }); + GGML_ASSERT(templates.template_default); + try { + common_chat_params_init(*templates.template_default, inputs); + if (templates.template_tool_use) { + common_chat_params_init(*templates.template_tool_use, inputs); + } + + return true; + } catch (const std::exception & e) { + SRV_ERR("failed to apply template: %s\n", e.what()); + + return false; + } + } + bool validate_builtin_chat_template(bool use_jinja) const { llama_chat_message chat[] = {{"user", "test"}}; if (use_jinja) { - auto templates = common_chat_templates_from_model(model, ""); - common_chat_inputs inputs; - inputs.messages = json::array({{ - {"role", "user"}, - {"content", "test"}, - }}); - GGML_ASSERT(templates.template_default); - try { - common_chat_params_init(*templates.template_default, inputs); - if (templates.template_tool_use) { - common_chat_params_init(*templates.template_tool_use, inputs); - } - return true; - } catch (const std::exception & e) { - SRV_ERR("failed to apply template: %s\n", e.what()); - return false; - } + return validate_jinja_templates(); } else { const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); + if (chat_res < 0) { + return validate_jinja_templates(); + } return chat_res > 0; } } From 22cefc5c279683357866d4e3feebbcdedc3c2c56 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sun, 16 Feb 2025 13:27:52 +0100 Subject: [PATCH 12/51] install libcurl in github workflows --- .github/workflows/ci.yml | 6 +++++- .github/workflows/release.yaml | 6 +++++- .gitignore | 1 + 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a13f5b4a..d8db1a21 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,6 +18,8 
@@ jobs: with: distribution: 'zulu' java-version: '11' + - name: Install libcurl + run: sudo apt-get install -y libcurl4-openssl-dev - name: Build libraries # cmake should figure out OS and ARCH automatically when running build.sh (but we need mvn compile for it) run: | @@ -79,10 +81,12 @@ jobs: with: distribution: 'zulu' java-version: '11' + - name: Install libcurl + run: vcpkg install curl - name: Build libraries run: | mvn compile - .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON + .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON -DCURL_LIBRARY=C:/vcpkg/packages/curl_x64-windows/lib/libcurl.lib -DCURL_INCLUDE_DIR=C:/vcpkg/packages/curl_x64-windows/include - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 85829ed9..2e60bffc 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -18,6 +18,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + - name: Install libcurl + run: sudo apt-get install -y libcurl4-openssl-dev - name: Build libraries shell: bash run: | @@ -121,10 +123,12 @@ jobs: } steps: - uses: actions/checkout@v4 + - name: Install curl + run: vcpkg install curl - name: Build libraries shell: cmd run: | - .github\build.bat ${{ matrix.target.cmake }} -DOS_NAME=${{ matrix.target.os }} -DOS_ARCH=${{ matrix.target.arch }} + .github\build.bat ${{ matrix.target.cmake }} -DOS_NAME=${{ matrix.target.os }} -DOS_ARCH=${{ matrix.target.arch }} -DCURL_LIBRARY=C:/vcpkg/packages/curl_x64-windows/lib/libcurl.lib -DCURL_INCLUDE_DIR=C:/vcpkg/packages/curl_x64-windows/include - name: Upload artifacts uses: actions/upload-artifact@v4 with: diff --git a/.gitignore b/.gitignore index 8857fd04..274f8687 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .idea target build +cmake-build-* .DS_Store .directory .vscode From 2f8d2b0a0fb7671b399876109bdd8275a4ff130b Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 21 Feb 2025 10:09:16 -0800 Subject: [PATCH 13/51] updating test case to make codellama model --- .../java/de/kherud/llama/LlamaModelTest.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index c757d0c3..6fbe2e43 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -43,7 +43,7 @@ public static void tearDown() { public void testGenerateAnswer() { Map logitBias = new HashMap<>(); logitBias.put(2, 2.0f); - InferenceParameters params = new InferenceParameters("<|User|> " + prefix +" <|Assistant|> ") + InferenceParameters params = new InferenceParameters(prefix) .setTemperature(0.95f) .setStopStrings("\"\"\"") .setNPredict(nPredict) @@ -62,7 +62,7 @@ public void testGenerateInfill() { Map logitBias = new HashMap<>(); logitBias.put(2, 2.0f); InferenceParameters params = new InferenceParameters("") - .setInputPrefix("<|User|> " + prefix + " <|Assistant|> ") + .setInputPrefix(prefix) .setInputSuffix(suffix ) .setTemperature(0.95f) .setStopStrings("\"\"\"") @@ -97,7 +97,7 @@ public void testGenerateGrammar() { public void testCompleteAnswer() { Map logitBias = new HashMap<>(); logitBias.put(2, 2.0f); - InferenceParameters params = new InferenceParameters("<|User|> " + prefix +" <|Assistant|> ") + InferenceParameters params = new InferenceParameters(prefix) .setTemperature(0.95f) 
.setStopStrings("\"\"\"") .setNPredict(nPredict) @@ -113,7 +113,7 @@ public void testCompleteInfillCustom() { Map logitBias = new HashMap<>(); logitBias.put(2, 2.0f); InferenceParameters params = new InferenceParameters("") - .setInputPrefix("<|User|> " + prefix +" <|Assistant|> ") + .setInputPrefix(prefix) .setInputSuffix(suffix) .setTemperature(0.95f) .setStopStrings("\"\"\"") @@ -138,7 +138,7 @@ public void testCompleteGrammar() { @Test public void testCancelGenerating() { - InferenceParameters params = new InferenceParameters("<|User|> " + prefix +" <|Assistant|> ").setNPredict(nPredict); + InferenceParameters params = new InferenceParameters(prefix).setNPredict(nPredict); int generated = 0; LlamaIterator iterator = model.generate(params).iterator(); @@ -172,7 +172,7 @@ public void testLogText() { List messages = new ArrayList<>(); LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> messages.add(new LogMessage(level, msg))); - InferenceParameters params = new InferenceParameters("<|User|> " + prefix +" <|Assistant|> ") + InferenceParameters params = new InferenceParameters(prefix) .setNPredict(nPredict) .setSeed(42); model.complete(params); @@ -191,7 +191,7 @@ public void testLogJSON() { List messages = new ArrayList<>(); LlamaModel.setLogger(LogFormat.JSON, (level, msg) -> messages.add(new LogMessage(level, msg))); - InferenceParameters params = new InferenceParameters("<|User|> " + prefix +" <|Assistant|> ") + InferenceParameters params = new InferenceParameters(prefix) .setNPredict(nPredict) .setSeed(42); model.complete(params); @@ -208,7 +208,7 @@ public void testLogJSON() { @Test public void testLogStdout() { // Unfortunately, `printf` can't be easily re-directed to Java. This test only works manually, thus. - InferenceParameters params = new InferenceParameters("<|User|> " + prefix +" <|Assistant|> ") + InferenceParameters params = new InferenceParameters(prefix) .setNPredict(nPredict) .setSeed(42); From 54bf4bd58ed47010369c1ecbf2e17bc4456914ce Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 21 Feb 2025 11:09:10 -0800 Subject: [PATCH 14/51] updating to add speculative execution. 
--- CMakeLists.txt | 2 +- src/main/cpp/jllama.cpp | 66 +++++++++++--- src/main/cpp/server.hpp | 65 +++----------- src/main/cpp/utils.hpp | 189 ++++++++++++++++++---------------------- 4 files changed, 150 insertions(+), 172 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3cf89dc6..216faed6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ set(LLAMA_BUILD_COMMON ON) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b4719 + GIT_TAG b4753 ) FetchContent_MakeAvailable(llama.cpp) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 00eccbb7..b719a551 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -443,23 +443,63 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo LOG_INF("%s: model loaded\n", __func__); const auto model_meta = ctx_server->model_meta(); + + if (!params.speculative.model.empty() || !params.speculative.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str()); + auto params_dft = params; - // if a custom chat template is not supplied, we will use the one that comes with the model (if any) - if (params.chat_template.empty()) - { - if (!ctx_server->validate_builtin_chat_template(params.use_jinja)) - { - LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. " - "This may cause the model to output suboptimal responses\n", - __func__); - params.chat_template = "chatml"; + params_dft.devices = params.speculative.devices; + params_dft.hf_file = params.speculative.hf_file; + params_dft.hf_repo = params.speculative.hf_repo; + params_dft.model = params.speculative.model; + params_dft.model_url = params.speculative.model_url; + params_dft.n_ctx = params.speculative.n_ctx == 0 ? params.n_ctx / params.n_parallel : params.speculative.n_ctx; + params_dft.n_gpu_layers = params.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + + common_init_result llama_init_dft = common_init_from_params(params_dft); + + llama_model * model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str()); + } + + if (!common_speculative_are_compatible(ctx_server->ctx, llama_init_dft.context.get())) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.speculative.model.c_str(), params.model.c_str()); + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + ctx_server->cparams_dft = common_context_params_to_llama(params_dft); + ctx_server->cparams_dft.n_batch = n_ctx_dft; + + // force F16 KV cache for the draft model for extra performance + ctx_server->cparams_dft.type_k = GGML_TYPE_F16; + ctx_server->cparams_dft.type_v = GGML_TYPE_F16; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); } - } - // print sample chat example to make it clear which template is used + ctx_server->chat_templates = common_chat_templates_init(ctx_server->model, params.chat_template); + try { + common_chat_format_example(ctx_server->chat_templates.get(), params.use_jinja); + } catch (const std::exception & e) { + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. 
This may cause the model to output suboptimal responses\n", __func__); + ctx_server->chat_templates = common_chat_templates_init(ctx_server->model, "chatml"); + } + + // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(), - common_chat_format_example(*ctx_server->chat_templates.template_default, ctx_server->params_base.use_jinja) .c_str()); + common_chat_templates_source(ctx_server->chat_templates.get()), + common_chat_format_example(ctx_server->chat_templates.get(), ctx_server->params_base.use_jinja).c_str()); + + + // print sample chat example to make it clear which template is used +// LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + // common_chat_templates_source(ctx_server->chat_templates.get()), + // common_chat_format_example(*ctx_server->chat_templates.template_default, ctx_server->params_base.use_jinja) .c_str()); ctx_server->queue_tasks.on_new_task( std::bind(&server_context::process_single_task, ctx_server, std::placeholders::_1)); diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index 332c1edc..40c65889 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -265,7 +265,7 @@ struct server_task { params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); - params.speculative.n_min = std::max(params.speculative.n_min, 2); + params.speculative.n_min = std::max(params.speculative.n_min, 0); params.speculative.n_max = std::max(params.speculative.n_max, 0); // Use OpenAI API logprobs only if n_probs wasn't provided @@ -320,9 +320,6 @@ struct server_task { } // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { - throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); - } if (data.contains("json_schema") && !data.contains("grammar")) { try { auto schema = json_value(data, "json_schema", json::object()); @@ -705,7 +702,7 @@ struct server_task_result_cmpl_final : server_task_result { return res; } -json to_json_oaicompat_chat() { + json to_json_oaicompat_chat() { std::string finish_reason = "length"; common_chat_msg msg; if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { @@ -984,6 +981,7 @@ struct server_task_result_cmpl_partial : server_task_result { } }; + struct server_task_result_embd : server_task_result { int index = 0; std::vector> embedding; @@ -1430,7 +1428,6 @@ struct server_slot { } }; - struct server_metrics { int64_t t_start = 0; @@ -1483,6 +1480,7 @@ struct server_metrics { } }; + struct server_queue { int id = 0; bool running; @@ -1799,7 +1797,7 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; - common_chat_templates chat_templates; + common_chat_templates_ptr chat_templates; ~server_context() { // Clear any sampling context @@ -1883,55 +1881,15 @@ struct server_context { llama_init_dft.context.reset(); } - if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) { - SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. 
This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_from_model(model, "chatml"); - } else { - chat_templates = common_chat_templates_from_model(model, params_base.chat_template); - } - GGML_ASSERT(chat_templates.template_default.get() != nullptr); - - return true; - } - - bool validate_jinja_templates() const { - auto templates = common_chat_templates_from_model(model, ""); - common_chat_inputs inputs; - inputs.messages = json::array({ - { - { "role", "user" }, - { "content", "test" }, - } - }); - GGML_ASSERT(templates.template_default); + chat_templates = common_chat_templates_init(model, params_base.chat_template); try { - common_chat_params_init(*templates.template_default, inputs); - if (templates.template_tool_use) { - common_chat_params_init(*templates.template_tool_use, inputs); - } - - return true; + common_chat_format_example(chat_templates.get(), params.use_jinja); } catch (const std::exception & e) { - SRV_ERR("failed to apply template: %s\n", e.what()); - - return false; + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + chat_templates = common_chat_templates_init(model, "chatml"); } - } - - - bool validate_builtin_chat_template(bool use_jinja) const { - llama_chat_message chat[] = {{"user", "test"}}; - if (use_jinja) { - return validate_jinja_templates(); - } else { - const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); - const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); - if (chat_res < 0) { - return validate_jinja_templates(); - } - return chat_res > 0; - } + return true; } void init() { @@ -2080,8 +2038,8 @@ struct server_context { if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict); slot.params.n_predict = slot.n_predict; - SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); } if (slot.params.ignore_eos && has_eos_token) { @@ -3358,6 +3316,7 @@ struct server_context { } }; + static void common_params_handle_model_default( std::string & model, const std::string & model_url, diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp index 1c5e276a..b454465f 100644 --- a/src/main/cpp/utils.hpp +++ b/src/main/cpp/utils.hpp @@ -18,8 +18,7 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -#include "chat.hpp" -#include "chat-template.hpp" +#include "chat.h" #include #include @@ -352,41 +351,6 @@ static llama_tokens format_infill( return embd_inp; } -/// Format given chat. 
If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
-    std::vector<common_chat_msg> chat;
-
-    for (size_t i = 0; i < messages.size(); ++i) {
-        const auto & curr_msg = messages[i];
-
-        std::string role = json_value(curr_msg, "role", std::string(""));
-
-        std::string content;
-        if (curr_msg.contains("content")) {
-            if (curr_msg["content"].is_string()) {
-                content = curr_msg["content"].get<std::string>();
-            } else if (curr_msg["content"].is_array()) {
-                for (const auto & part : curr_msg["content"]) {
-                    if (part.contains("text")) {
-                        content += "\n" + part["text"].get<std::string>();
-                    }
-                }
-            } else {
-                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
-            }
-        } else {
-            throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
-        }
-
-        chat.push_back({role, content, /* tool_calls= */ {}});
-    }
-
-    const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
-    LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
-
-    return formatted_chat;
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@@ -572,12 +536,10 @@ static json oaicompat_completion_params_parse(const json & body) {
 static json oaicompat_completion_params_parse(
     const json & body, /* openai api json semantics */
     bool use_jinja,
-    const common_chat_templates & chat_templates)
+    common_reasoning_format reasoning_format,
+    const struct common_chat_templates * tmpls)
 {
     json llama_params;
-    const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
-        ? *chat_templates.template_tool_use
-        : *chat_templates.template_default;
 
     auto tools = json_value(body, "tools", json());
     auto stream = json_value(body, "stream", false);
@@ -603,63 +565,58 @@ static json oaicompat_completion_params_parse(
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
+    auto json_schema = json_value(body, "json_schema", json());
+    auto grammar = json_value(body, "grammar", std::string());
+    if (!json_schema.is_null() && !grammar.empty()) {
+        throw std::runtime_error("Cannot use both json_schema and grammar");
+    }
+
     // Handle "response_format" field
     if (body.contains("response_format")) {
         json response_format = json_value(body, "response_format", json::object());
         std::string response_type = json_value(response_format, "type", std::string());
         if (response_type == "json_object") {
-            llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+            json_schema = json_value(response_format, "schema", json::object());
         } else if (response_type == "json_schema") {
             json json_schema = json_value(response_format, "json_schema", json::object());
-            llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
+            json_schema = json_value(json_schema, "schema", json::object());
         } else if (!response_type.empty() && response_type != "text") {
             throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
         }
     }
 
+    common_chat_templates_inputs inputs;
+    inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
+    inputs.tools = common_chat_tools_parse_oaicompat(tools);
+    inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
+    inputs.json_schema = json_schema.is_null() ? 
"" : json_schema.dump(); + inputs.grammar = grammar; + inputs.add_generation_prompt = true; + inputs.use_jinja = use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + // Apply chat template to the list of messages - if (use_jinja) { - auto tool_choice = json_value(body, "tool_choice", std::string("auto")); - if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") { - throw std::runtime_error("Invalid tool_choice: " + tool_choice); - } - if (tool_choice != "none" && llama_params.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); - } - common_chat_inputs inputs; - inputs.messages = body.at("messages"); - inputs.tools = tools; - inputs.tool_choice = tool_choice; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { - LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); - inputs.parallel_tool_calls = false; - } - inputs.stream = stream; - // TODO: support mixing schema w/ tools beyond generic format. - inputs.json_schema = json_value(llama_params, "json_schema", json()); - auto chat_params = common_chat_params_init(tmpl, inputs); - - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; - if (inputs.json_schema == nullptr) { - llama_params["grammar"] = chat_params.grammar; - llama_params["grammar_lazy"] = chat_params.grammar_lazy; - auto grammar_triggers = json::array(); - for (const auto & trigger : chat_params.grammar_triggers) { - grammar_triggers.push_back({ - {"word", trigger.word}, - {"at_start", trigger.at_start}, - }); - } - llama_params["grammar_triggers"] = grammar_triggers; - } - llama_params["preserved_tokens"] = chat_params.preserved_tokens; - for (const auto & stop : chat_params.additional_stops) { - llama_params["stop"].push_back(stop); - } - } else { - llama_params["prompt"] = format_chat(tmpl, body.at("messages")); + auto chat_params = common_chat_templates_apply(tmpls, inputs); + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + grammar_triggers.push_back({ + {"word", trigger.word}, + {"at_start", trigger.at_start}, + }); + } + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); } // Handle "n" field @@ -731,29 +688,51 @@ static json format_embeddings_response_oaicompat(const json & request, const jso return res; } -static json format_response_rerank(const json & request, const json & ranks) { - json data = json::array(); - int32_t n_tokens = 0; - int i = 0; - for (const auto & rank : ranks) { - data.push_back(json{ - {"index", i++}, - {"relevance_score", json_value(rank, "score", 0.0)}, - }); +static json format_response_rerank( + const json & request, + 
const json & ranks, + bool is_tei_format, + std::vector & texts) { + json res; + if (is_tei_format) { + // TEI response format + res = json::array(); + bool return_text = json_value(request, "return_text", false); + for (const auto & rank : ranks) { + int index = json_value(rank, "index", 0); + json elem = json{ + {"index", index}, + {"score", json_value(rank, "score", 0.0)}, + }; + if (return_text) { + elem["text"] = std::move(texts[index]); + } + res.push_back(elem); + } + } else { + // Jina response format + json results = json::array(); + int32_t n_tokens = 0; + for (const auto & rank : ranks) { + results.push_back(json{ + {"index", json_value(rank, "index", 0)}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + + n_tokens += json_value(rank, "tokens_evaluated", 0); + } - n_tokens += json_value(rank, "tokens_evaluated", 0); + res = json{ + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json{ + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"results", results} + }; } - json res = json { - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json { - {"prompt_tokens", n_tokens}, - {"total_tokens", n_tokens} - }}, - {"results", data} - }; - return res; } From 15dbe6857767ab84939f0721b0d21e542e546ac5 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Wed, 5 Mar 2025 16:53:50 -0800 Subject: [PATCH 15/51] updating dependency to latest llamacpp version --- CMakeLists.txt | 2 +- src/main/cpp/server.hpp | 71 +++++++++++++++++++++++------------------ src/main/cpp/utils.hpp | 32 +++++++++++++------ 3 files changed, 64 insertions(+), 41 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 216faed6..6fe8778b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ set(LLAMA_BUILD_COMMON ON) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b4753 + GIT_TAG b4831 ) FetchContent_MakeAvailable(llama.cpp) diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index 40c65889..da2b410b 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -26,6 +26,7 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; + enum stop_type { STOP_TYPE_NONE, STOP_TYPE_EOS, @@ -33,7 +34,7 @@ enum stop_type { STOP_TYPE_LIMIT, }; -// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283 +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future @@ -122,9 +123,9 @@ struct slot_params { lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); } - std::vector grammar_trigger_words; - for (const auto & trigger : sampling.grammar_trigger_words) { - grammar_trigger_words.push_back(trigger.word); + auto grammar_triggers = json::array(); + for (const auto & trigger : sampling.grammar_triggers) { + grammar_triggers.push_back(trigger.to_json()); } return json { @@ -161,8 +162,8 @@ struct slot_params { {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, {"grammar", sampling.grammar}, - {"grammar_trigger_words", grammar_trigger_words}, - {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, + {"grammar_lazy", sampling.grammar_lazy}, + {"grammar_triggers", grammar_triggers}, {"preserved_tokens", sampling.preserved_tokens}, {"chat_format", 
common_chat_format_name(oaicompat_chat_format)}, {"samplers", samplers}, @@ -347,24 +348,6 @@ struct server_task { } { - const auto grammar_triggers = data.find("grammar_triggers"); - if (grammar_triggers != data.end()) { - for (const auto & t : *grammar_triggers) { - common_grammar_trigger trigger; - trigger.word = t.at("word"); - trigger.at_start = t.at("at_start"); - - auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str()); - params.sampling.grammar_trigger_tokens.push_back(ids[0]); - params.sampling.preserved_tokens.insert(ids[0]); - continue; - } - SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str()); - params.sampling.grammar_trigger_words.push_back(trigger); - } - } const auto preserved_tokens = data.find("preserved_tokens"); if (preserved_tokens != data.end()) { for (const auto & t : *preserved_tokens) { @@ -374,12 +357,38 @@ struct server_task { params.sampling.preserved_tokens.insert(ids[0]); } else { // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. - SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get().c_str()); + SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); } } } - if (params.sampling.grammar_lazy) { - GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0); + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + auto ct = common_grammar_trigger::from_json(t); + if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = ct.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = (llama_token) token; + params.sampling.grammar_triggers.push_back(trigger); + } else { + SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + } else { + params.sampling.grammar_triggers.push_back(ct); + } + } + } + if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); } } @@ -981,7 +990,6 @@ struct server_task_result_cmpl_partial : server_task_result { } }; - struct server_task_result_embd : server_task_result { int index = 0; std::vector> embedding; @@ -1480,7 +1488,6 @@ struct server_metrics { } }; - struct server_queue { int id = 0; bool running; @@ -2038,7 +2045,7 @@ struct server_context { if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? 
- SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict); + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); slot.params.n_predict = slot.n_predict; } @@ -2996,7 +3003,7 @@ struct server_context { const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c); - llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift); + llama_kv_cache_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; @@ -3317,6 +3324,8 @@ struct server_context { }; + + static void common_params_handle_model_default( std::string & model, const std::string & model_url, diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp index b454465f..cc384d96 100644 --- a/src/main/cpp/utils.hpp +++ b/src/main/cpp/utils.hpp @@ -490,6 +490,17 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx, return out; } +//static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) { +// const std::string str = +// std::string(event) + ": " + +// data.dump(-1, ' ', false, json::error_handler_t::replace) + +// "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). +// +// LOG_DBG("data stream, to_send: %s", str.c_str()); +// +// return sink.write(str.c_str(), str.size()); +//} + // // OAI utils // @@ -514,8 +525,13 @@ static json oaicompat_completion_params_parse(const json & body) { throw std::runtime_error("Only one completion choice is allowed"); } + // Handle "echo" field + if (json_value(body, "echo", false)) { + throw std::runtime_error("Only no echo is supported"); + } + // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "best_of", "echo", "suffix" }; + static const std::vector unsupported_params { "best_of", "suffix" }; for (const auto & param : unsupported_params) { if (body.contains(param)) { throw std::runtime_error("Unsupported param: " + param); @@ -578,8 +594,8 @@ static json oaicompat_completion_params_parse( if (response_type == "json_object") { json_schema = json_value(response_format, "schema", json::object()); } else if (response_type == "json_schema") { - json json_schema = json_value(response_format, "json_schema", json::object()); - json_schema = json_value(json_schema, "schema", json::object()); + auto schema_wrapper = json_value(response_format, "json_schema", json::object()); + json_schema = json_value(schema_wrapper, "schema", json::object()); } else if (!response_type.empty() && response_type != "text") { throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } @@ -591,10 +607,11 @@ static json oaicompat_completion_params_parse( inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); inputs.json_schema = json_schema.is_null() ? 
"" : json_schema.dump(); inputs.grammar = grammar; - inputs.add_generation_prompt = true; + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); inputs.use_jinja = use_jinja; inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { throw std::runtime_error("Cannot use custom grammar constraints with tools."); } @@ -608,10 +625,7 @@ static json oaicompat_completion_params_parse( llama_params["grammar_lazy"] = chat_params.grammar_lazy; auto grammar_triggers = json::array(); for (const auto & trigger : chat_params.grammar_triggers) { - grammar_triggers.push_back({ - {"word", trigger.word}, - {"at_start", trigger.at_start}, - }); + grammar_triggers.push_back(trigger.to_json()); } llama_params["grammar_triggers"] = grammar_triggers; llama_params["preserved_tokens"] = chat_params.preserved_tokens; @@ -869,4 +883,4 @@ static std::vector parse_lora_request( } return lora; -} +} \ No newline at end of file From c00de24bd632b0a7d804bcaf5ba1e306e2ad777c Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Wed, 5 Mar 2025 20:26:20 -0800 Subject: [PATCH 16/51] removed releaseTask --- src/main/java/de/kherud/llama/LlamaModel.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index fc0e70fa..43bf0772 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -54,7 +54,6 @@ public String complete(InferenceParameters parameters) { parameters.setStream(false); int taskId = requestCompletion(parameters.toString()); LlamaOutput output = receiveCompletion(taskId); - releaseTask(taskId); return output.text; } From 7a3f6726bf40cba45fe6370e42a7801577fe2583 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Wed, 5 Mar 2025 21:11:59 -0800 Subject: [PATCH 17/51] updated to remove unused and duplicate imports --- src/main/cpp/jllama.cpp | 2 +- src/main/cpp/server.hpp | 5 ----- src/main/cpp/utils.hpp | 2 +- src/test/java/de/kherud/llama/LlamaModelTest.java | 3 ++- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index b719a551..3a547bc8 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -777,7 +777,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *env, jobje jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) ctx_server->queue_tasks.terminate(); - delete ctx_server; + //delete ctx_server; } JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *env, jobject obj, jint id_task) diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index da2b410b..031c4a6b 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -1,14 +1,9 @@ #include "utils.hpp" -#include "common.h" #include "json-schema-to-grammar.h" -#include "llama.h" -#include "log.h" #include "sampling.h" #include "speculative.h" -#include "nlohmann/json.hpp" - #include #include #include diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp index cc384d96..e9498014 100644 --- a/src/main/cpp/utils.hpp +++ b/src/main/cpp/utils.hpp @@ -16,7 +16,7 @@ // Change 
JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT -#include "json.hpp" +#include "nlohmann/json.hpp" #include "chat.h" diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index 6fbe2e43..9e5b767b 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -133,7 +133,8 @@ public void testCompleteGrammar() { String output = model.complete(params); Assert.assertTrue(output + " doesn't match [ab]+", output.matches("[ab]+")); int generated = model.encode(output).length; - Assert.assertTrue(generated > 0 && generated <= nPredict + 1); + Assert.assertTrue("generated count is: " + generated, generated > 0 && generated <= nPredict + 1); + } @Test From cc8f1327b1d6c8c62281a0055a33861dc3b90d98 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Thu, 6 Mar 2025 23:13:24 -0800 Subject: [PATCH 18/51] adding x64 arch for windows --- src/main/java/de/kherud/llama/OSInfo.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/main/java/de/kherud/llama/OSInfo.java b/src/main/java/de/kherud/llama/OSInfo.java index a62861bf..772aeaef 100644 --- a/src/main/java/de/kherud/llama/OSInfo.java +++ b/src/main/java/de/kherud/llama/OSInfo.java @@ -32,6 +32,7 @@ @SuppressWarnings("UseOfSystemOutOrSystemErr") class OSInfo { public static final String X86 = "x86"; + public static final String X64 = "x64"; public static final String X86_64 = "x86_64"; public static final String IA64_32 = "ia64_32"; public static final String IA64 = "ia64"; @@ -78,6 +79,9 @@ class OSInfo { archMapping.put("power_rs64", PPC64); archMapping.put("ppc64el", PPC64); archMapping.put("ppc64le", PPC64); + + // TODO: Adding X64 support + archMapping.put(X64, X64); } public static void main(String[] args) { From 27dacab3438b3953b0b66262b2b875b22a6a3bf9 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 13:39:11 -0800 Subject: [PATCH 19/51] updating windows workflow to copy all the dlls --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d8db1a21..54a9435c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -87,6 +87,10 @@ jobs: run: | mvn compile .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON -DCURL_LIBRARY=C:/vcpkg/packages/curl_x64-windows/lib/libcurl.lib -DCURL_INCLUDE_DIR=C:/vcpkg/packages/curl_x64-windows/include + - name: Copy DLL to Java resources + run: | + mkdir -Force "target/classes/Windows/x86_64" + Copy-Item ".\build\Release\*.dll" "target/classes/Windows/x86_64/" - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests From 036e020e6a9201ee6a6bdd1291afca8623932753 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 13:59:00 -0800 Subject: [PATCH 20/51] updating windows workflow. 
--- .github/workflows/ci.yml | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 54a9435c..0ebdd7bc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -84,17 +84,25 @@ jobs: - name: Install libcurl run: vcpkg install curl - name: Build libraries - run: | + run: | mvn compile .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON -DCURL_LIBRARY=C:/vcpkg/packages/curl_x64-windows/lib/libcurl.lib -DCURL_INCLUDE_DIR=C:/vcpkg/packages/curl_x64-windows/include + - name: Copy DLL to Java resources - run: | + run: | mkdir -Force "target/classes/Windows/x86_64" - Copy-Item ".\build\Release\*.dll" "target/classes/Windows/x86_64/" + Copy-Item ".\build\Release\llama.dll" "target/classes/Windows/x86_64/" + + - name: Verify DLL placement (debug step) + run: dir target\classes\Windows\x86_64\ + - name: Download model - run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - - name: Run tests - run: mvn test + run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME + + - name: Run tests with explicit DLL path + run: | + mvn test -Djava.library.path="${{ github.workspace }}\target\classes\Windows\x86_64" + - if: failure() uses: actions/upload-artifact@v4 with: From aef5b69a9691294bfca2dc1931599c33170cdabf Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 14:13:45 -0800 Subject: [PATCH 21/51] validated yml file using lint --- .github/workflows/ci.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0ebdd7bc..0d7be03f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -84,23 +84,23 @@ jobs: - name: Install libcurl run: vcpkg install curl - name: Build libraries - run: | + run: | mvn compile .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON -DCURL_LIBRARY=C:/vcpkg/packages/curl_x64-windows/lib/libcurl.lib -DCURL_INCLUDE_DIR=C:/vcpkg/packages/curl_x64-windows/include - name: Copy DLL to Java resources - run: | + run: | mkdir -Force "target/classes/Windows/x86_64" Copy-Item ".\build\Release\llama.dll" "target/classes/Windows/x86_64/" - name: Verify DLL placement (debug step) - run: dir target\classes\Windows\x86_64\ + run: dir target\classes\Windows\x86_64\ - name: Download model - run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME + run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests with explicit DLL path - run: | + run: | mvn test -Djava.library.path="${{ github.workspace }}\target\classes\Windows\x86_64" - if: failure() From 6ea33c3a6b386fb40c1b641ffc1545c71cb86c79 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 14:31:57 -0800 Subject: [PATCH 22/51] trying few suggestion --- .github/workflows/ci.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0d7be03f..0a91e787 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,17 +91,18 @@ jobs: - name: Copy DLL to Java resources run: | mkdir -Force "target/classes/Windows/x86_64" - Copy-Item ".\build\Release\llama.dll" "target/classes/Windows/x86_64/" + Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" - name: Verify DLL placement (debug step) - run: dir target\classes\Windows\x86_64\ + run: | + dir target\classes\de\kherud\llama\Windows\x86_64\ - name: Download 
model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests with explicit DLL path run: | - mvn test -Djava.library.path="${{ github.workspace }}\target\classes\Windows\x86_64" + mvn test - if: failure() uses: actions/upload-artifact@v4 From 230b72f5ddc817e466816d3b9f51722ed1f16606 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 14:50:02 -0800 Subject: [PATCH 23/51] update the workflow path --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0a91e787..820fa397 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -90,7 +90,7 @@ jobs: - name: Copy DLL to Java resources run: | - mkdir -Force "target/classes/Windows/x86_64" + mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" - name: Verify DLL placement (debug step) From 746c31ab27fd9f3b471ce17dfe32bb2c934af693 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 14:58:52 -0800 Subject: [PATCH 24/51] trying to find which library we are missing --- src/main/java/de/kherud/llama/LlamaLoader.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/de/kherud/llama/LlamaLoader.java b/src/main/java/de/kherud/llama/LlamaLoader.java index a0239d20..6bb6ace2 100644 --- a/src/main/java/de/kherud/llama/LlamaLoader.java +++ b/src/main/java/de/kherud/llama/LlamaLoader.java @@ -152,7 +152,8 @@ private static void loadNativeLibrary(String name) { throw new UnsatisfiedLinkError( String.format( - "No native library found for os.name=%s, os.arch=%s, paths=[%s]", + "No native library found for name=%s os.name=%s, os.arch=%s, paths=[%s]", + name, OSInfo.getOSName(), OSInfo.getArchName(), String.join(File.pathSeparator, triedPaths) From 8b5de74948c358dcb4e8c300340eae3de12ebb16 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 17:15:23 -0800 Subject: [PATCH 25/51] update the workflow path --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 820fa397..6b694ca9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -102,7 +102,7 @@ jobs: - name: Run tests with explicit DLL path run: | - mvn test + mvn test "-Djava.library.path=${env:PATH};target/classes/de/kherud/llama/Windows/x86_64" - if: failure() uses: actions/upload-artifact@v4 From e0efe9f40b920beb1051d863f13aae92887cbad2 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 17:28:47 -0800 Subject: [PATCH 26/51] update the workflow path --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6b694ca9..ad6606a2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -92,6 +92,7 @@ jobs: run: | mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" + Copy-Item "C:\vcpkg\installed\x64-windows\bin\curl.dll" "target/classes/de/kherud/llama/Windows/x86_64/" - name: Verify DLL placement (debug step) run: | From 12220ea579fa0cb9d831e69c93eb45cd9715906d Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 17:42:38 -0800 Subject: [PATCH 27/51] update the workflow path --- .github/workflows/ci.yml | 3 ++- 1 
file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad6606a2..96b78d6a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -92,7 +92,8 @@ jobs: run: | mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" - Copy-Item "C:\vcpkg\installed\x64-windows\bin\curl.dll" "target/classes/de/kherud/llama/Windows/x86_64/" + Get-ChildItem "C:/vcpkg/packages/curl_x64-windows" -Filter *.dll -Recurse ` + | ForEach-Object { Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose } - name: Verify DLL placement (debug step) run: | From 859844f6d807c2162d306fb745c9eefacf5c1ca5 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 17:57:39 -0800 Subject: [PATCH 28/51] update the workflow path --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 96b78d6a..e99e510e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -104,6 +104,7 @@ jobs: - name: Run tests with explicit DLL path run: | + $env:PATH = "C:\vcpkg\installed\x64-windows\bin;${env:PATH}" mvn test "-Djava.library.path=${env:PATH};target/classes/de/kherud/llama/Windows/x86_64" - if: failure() From ed2421cc01a513422713afbf768f9da58c14daf8 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 20:56:59 -0800 Subject: [PATCH 29/51] update the workflow path --- .github/workflows/ci.yml | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e99e510e..3929ae63 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -88,12 +88,30 @@ jobs: mvn compile .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON -DCURL_LIBRARY=C:/vcpkg/packages/curl_x64-windows/lib/libcurl.lib -DCURL_INCLUDE_DIR=C:/vcpkg/packages/curl_x64-windows/include - - name: Copy DLL to Java resources + - name: Prepare DLL directory run: | mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" - Get-ChildItem "C:/vcpkg/packages/curl_x64-windows" -Filter *.dll -Recurse ` - | ForEach-Object { Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose } + + # Copy curl and all its dependencies to our directory + Get-ChildItem "C:/vcpkg/installed/x64-windows/bin" -Filter *.dll | ForEach-Object { + Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose + } + + # Also copy from curl packages directory for completeness + Get-ChildItem "C:/vcpkg/packages/curl_x64-windows/bin" -Filter *.dll -Recurse | ForEach-Object { + Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose + } + + # Copy OpenSSL DLLs if needed by curl + Get-ChildItem "C:/vcpkg/packages/openssl_x64-windows/bin" -Filter *.dll -Recurse | ForEach-Object { + Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose + } + + # Copy zlib DLLs if needed + Get-ChildItem "C:/vcpkg/packages/zlib_x64-windows/bin" -Filter *.dll -Recurse | ForEach-Object { + Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose + } - name: Verify DLL placement (debug step) run: | From 
605c600a25d689975006a5e919f605f97e718b55 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 21:06:07 -0800 Subject: [PATCH 30/51] update the workflow path --- .github/workflows/ci.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3929ae63..3ec192e1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -81,8 +81,12 @@ jobs: with: distribution: 'zulu' java-version: '11' - - name: Install libcurl - run: vcpkg install curl + - name: Install libcurl and dependencies + run: | + vcpkg install curl:x64-windows + vcpkg install openssl:x64-windows + vcpkg install zlib:x64-windows + - name: Build libraries run: | mvn compile From d2677762e6116325d5e3dd9c3e54ead15246660e Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 21:30:40 -0800 Subject: [PATCH 31/51] update the workflow path --- .github/workflows/ci.yml | 82 ++++--- .../java/de/kherud/llama/LlamaLoader.java | 225 +++++++++++++----- 2 files changed, 221 insertions(+), 86 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3ec192e1..b3275047 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -86,40 +86,68 @@ jobs: vcpkg install curl:x64-windows vcpkg install openssl:x64-windows vcpkg install zlib:x64-windows + vcpkg install boost-filesystem:x64-windows # Often needed for C++ projects + vcpkg install boost-system:x64-windows # Often needed for C++ projects + + - name: Download Dependency Walker + run: | + Invoke-WebRequest -Uri "https://www.dependencywalker.com/depends22_x64.zip" -OutFile "depends.zip" + Expand-Archive -Path "depends.zip" -DestinationPath "depends" - name: Build libraries run: | mvn compile .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON -DCURL_LIBRARY=C:/vcpkg/packages/curl_x64-windows/lib/libcurl.lib -DCURL_INCLUDE_DIR=C:/vcpkg/packages/curl_x64-windows/include - - name: Prepare DLL directory - run: | - mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" - Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" - - # Copy curl and all its dependencies to our directory - Get-ChildItem "C:/vcpkg/installed/x64-windows/bin" -Filter *.dll | ForEach-Object { - Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose - } - - # Also copy from curl packages directory for completeness - Get-ChildItem "C:/vcpkg/packages/curl_x64-windows/bin" -Filter *.dll -Recurse | ForEach-Object { - Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose - } - - # Copy OpenSSL DLLs if needed by curl - Get-ChildItem "C:/vcpkg/packages/openssl_x64-windows/bin" -Filter *.dll -Recurse | ForEach-Object { - Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose - } - - # Copy zlib DLLs if needed - Get-ChildItem "C:/vcpkg/packages/zlib_x64-windows/bin" -Filter *.dll -Recurse | ForEach-Object { - Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose + - name: Prepare DLL directory + run: | + mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" + Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" + + # Copy ALL DLLs from vcpkg directories to ensure we have everything + Get-ChildItem "C:/vcpkg/installed/x64-windows/bin" -Filter *.dll | ForEach-Object { + Copy-Item $_.FullName 
-Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose + } + + # Also from the packages directory + Get-ChildItem "C:/vcpkg/packages" -Recurse -Filter "*.dll" | Where-Object { $_.Directory -like "*bin*" } | ForEach-Object { + Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose + } + + # Copy Visual C++ Redistributable DLLs + $vcredistPath = "C:\Windows\System32" + @( + "msvcp140.dll", + "vcruntime140.dll", + "vcruntime140_1.dll", + "msvcp140_1.dll", + "msvcp140_2.dll", + "concrt140.dll" + ) | ForEach-Object { + if (Test-Path "$vcredistPath\$_") { + Copy-Item "$vcredistPath\$_" -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose } + } + + - name: Analyze DLL dependencies + run: | + # Run dependency walker on ggml.dll to see what's missing + .\depends\depends.exe -c -oc:deps_ggml.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\ggml.dll" + # Also analyze jllama.dll and llama.dll + .\depends\depends.exe -c -oc:deps_jllama.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\jllama.dll" + .\depends\depends.exe -c -oc:deps_llama.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\llama.dll" + + # Display the results + Get-Content deps_ggml.txt + echo "--------------------" + Get-Content deps_jllama.txt + echo "--------------------" + Get-Content deps_llama.txt + + - name: Verify DLL placement + run: | + dir target\classes\de\kherud\llama\Windows\x86_64\ - - name: Verify DLL placement (debug step) - run: | - dir target\classes\de\kherud\llama\Windows\x86_64\ - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME @@ -127,7 +155,7 @@ jobs: - name: Run tests with explicit DLL path run: | $env:PATH = "C:\vcpkg\installed\x64-windows\bin;${env:PATH}" - mvn test "-Djava.library.path=${env:PATH};target/classes/de/kherud/llama/Windows/x86_64" + mvn test "-Djava.library.path=${env:PATH};target/classes/de/kherud/llama/Windows/x86_64" -Ddebug.native.loading=true - if: failure() uses: actions/upload-artifact@v4 diff --git a/src/main/java/de/kherud/llama/LlamaLoader.java b/src/main/java/de/kherud/llama/LlamaLoader.java index 6bb6ace2..2605d96e 100644 --- a/src/main/java/de/kherud/llama/LlamaLoader.java +++ b/src/main/java/de/kherud/llama/LlamaLoader.java @@ -26,6 +26,7 @@ import java.nio.file.StandardCopyOption; import java.util.LinkedList; import java.util.List; +import java.util.UUID; import java.util.stream.Stream; import org.jetbrains.annotations.Nullable; @@ -95,70 +96,176 @@ private static void cleanPath(Path path) { } private static void loadNativeLibrary(String name) { - List triedPaths = new LinkedList<>(); + List triedPaths = new LinkedList<>(); + boolean isDebug = System.getProperty("debug.native.loading", "false").equals("true"); + + if (isDebug) { + System.out.println("[DEBUG] Attempting to load native library: " + name); + System.out.println("[DEBUG] Current working directory: " + System.getProperty("user.dir")); + System.out.println("[DEBUG] java.library.path: " + System.getProperty("java.library.path", "")); + System.out.println("[DEBUG] PATH environment: " + System.getenv("PATH")); + } - String nativeLibName = System.mapLibraryName(name); - String nativeLibPath = System.getProperty("de.kherud.llama.lib.path"); - if (nativeLibPath != null) { - Path path = Paths.get(nativeLibPath, nativeLibName); - if (loadNativeLibrary(path)) { - return; - } - else { - triedPaths.add(nativeLibPath); - } - } + String nativeLibName 
= System.mapLibraryName(name); + if (isDebug) { + System.out.println("[DEBUG] Mapped library name: " + nativeLibName); + } + + String nativeLibPath = System.getProperty("de.kherud.llama.lib.path"); + if (nativeLibPath != null) { + Path path = Paths.get(nativeLibPath, nativeLibName); + if (isDebug) { + System.out.println("[DEBUG] Trying custom lib path: " + path); + } + if (loadNativeLibraryWithDebug(path, isDebug)) { + return; + } else { + triedPaths.add(nativeLibPath); + } + } - if (OSInfo.isAndroid()) { - try { - // loadLibrary can load directly from packed apk file automatically - // if java-llama.cpp is added as code source - System.loadLibrary(name); - return; - } - catch (UnsatisfiedLinkError e) { - triedPaths.add("Directly from .apk/lib"); - } - } + if (OSInfo.isAndroid()) { + try { + if (isDebug) { + System.out.println("[DEBUG] Android detected, trying System.loadLibrary directly"); + } + // loadLibrary can load directly from packed apk file automatically + // if java-llama.cpp is added as code source + System.loadLibrary(name); + return; + } catch (UnsatisfiedLinkError e) { + if (isDebug) { + System.out.println("[DEBUG] Failed to load from APK: " + e.getMessage()); + } + triedPaths.add("Directly from .apk/lib"); + } + } - // Try to load the library from java.library.path - String javaLibraryPath = System.getProperty("java.library.path", ""); - for (String ldPath : javaLibraryPath.split(File.pathSeparator)) { - if (ldPath.isEmpty()) { - continue; - } - Path path = Paths.get(ldPath, nativeLibName); - if (loadNativeLibrary(path)) { - return; - } - else { - triedPaths.add(ldPath); - } - } + // Try to load the library from java.library.path + String javaLibraryPath = System.getProperty("java.library.path", ""); + for (String ldPath : javaLibraryPath.split(File.pathSeparator)) { + if (ldPath.isEmpty()) { + continue; + } + Path path = Paths.get(ldPath, nativeLibName); + if (isDebug) { + System.out.println("[DEBUG] Trying java.library.path entry: " + path); + if (Files.exists(path)) { + System.out.println("[DEBUG] File exists at path: " + path); + } else { + System.out.println("[DEBUG] File does NOT exist at path: " + path); + } + } + if (loadNativeLibraryWithDebug(path, isDebug)) { + return; + } else { + triedPaths.add(ldPath); + } + } - // As a last resort try load the os-dependent library from the jar file - nativeLibPath = getNativeResourcePath(); - if (hasNativeLib(nativeLibPath, nativeLibName)) { - // temporary library folder - String tempFolder = getTempDir().getAbsolutePath(); - // Try extracting the library from jar - if (extractAndLoadLibraryFile(nativeLibPath, nativeLibName, tempFolder)) { - return; - } - else { - triedPaths.add(nativeLibPath); - } - } + // As a last resort try load the os-dependent library from the jar file + nativeLibPath = getNativeResourcePath(); + if (isDebug) { + System.out.println("[DEBUG] Trying to extract from JAR, native resource path: " + nativeLibPath); + } + + if (hasNativeLib(nativeLibPath, nativeLibName)) { + // temporary library folder + String tempFolder = getTempDir().getAbsolutePath(); + if (isDebug) { + System.out.println("[DEBUG] Extracting library to temp folder: " + tempFolder); + } + + // Try extracting the library from jar + if (extractAndLoadLibraryFileWithDebug(nativeLibPath, nativeLibName, tempFolder, isDebug)) { + return; + } else { + triedPaths.add(nativeLibPath); + } + } else if (isDebug) { + System.out.println("[DEBUG] Native library not found in JAR at path: " + nativeLibPath + "/" + nativeLibName); + } + + throw new 
UnsatisfiedLinkError( + String.format( + "No native library found for name=%s os.name=%s, os.arch=%s, paths=[%s]", + name, + OSInfo.getOSName(), + OSInfo.getArchName(), + String.join(File.pathSeparator, triedPaths) + ) + ); + } + + // Add these helper methods + + private static boolean loadNativeLibraryWithDebug(Path path, boolean isDebug) { + try { + if (isDebug) { + System.out.println("[DEBUG] Attempting to load: " + path.toAbsolutePath()); + } + + if (!Files.exists(path)) { + if (isDebug) System.out.println("[DEBUG] File doesn't exist: " + path); + return false; + } + + System.load(path.toAbsolutePath().toString()); + if (isDebug) System.out.println("[DEBUG] Successfully loaded: " + path); + return true; + } catch (UnsatisfiedLinkError e) { + if (isDebug) { + System.out.println("[DEBUG] Failed to load " + path + ": " + e.getMessage()); + e.printStackTrace(); + } + return false; + } + } - throw new UnsatisfiedLinkError( - String.format( - "No native library found for name=%s os.name=%s, os.arch=%s, paths=[%s]", - name, - OSInfo.getOSName(), - OSInfo.getArchName(), - String.join(File.pathSeparator, triedPaths) - ) - ); + private static boolean extractAndLoadLibraryFileWithDebug(String libFolderForCurrentOS, String libraryFileName, + String targetFolder, boolean isDebug) { + String nativeLibraryFilePath = libFolderForCurrentOS + "/" + libraryFileName; + + // Include architecture name in temporary filename to avoid naming conflicts + String uuid = UUID.randomUUID().toString(); + String extractedLibFileName = String.format("%s-%s-%s", libraryFileName, uuid, OSInfo.getArchName()); + File extractedLibFile = new File(targetFolder, extractedLibFileName); + + try (InputStream reader = LlamaLoader.class.getResourceAsStream(nativeLibraryFilePath)) { + if (isDebug) { + System.out.println("[DEBUG] Extracting native library from JAR: " + nativeLibraryFilePath); + } + + if (reader == null) { + if (isDebug) System.out.println("[DEBUG] Cannot find native library in JAR: " + nativeLibraryFilePath); + return false; + } + + Files.copy(reader, extractedLibFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + + if (isDebug) { + System.out.println("[DEBUG] Extracted to: " + extractedLibFile.getAbsolutePath()); + System.out.println("[DEBUG] Attempting to load extracted file"); + } + + try { + System.load(extractedLibFile.getAbsolutePath()); + if (isDebug) System.out.println("[DEBUG] Successfully loaded: " + extractedLibFile.getAbsolutePath()); + return true; + } catch (UnsatisfiedLinkError e) { + if (isDebug) { + System.out.println("[DEBUG] Failed to load extracted library: " + e.getMessage()); + e.printStackTrace(); + } + return false; + } + } catch (IOException e) { + if (isDebug) { + System.out.println("[DEBUG] Failed to extract library: " + e.getMessage()); + e.printStackTrace(); + } + return false; + } } /** From 2e8be8a40b4ebb89f26013195e0756e735de8eac Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 21:35:00 -0800 Subject: [PATCH 32/51] update the workflow path --- .github/workflows/ci.yml | 174 +++++++++++++++++++++------------------ 1 file changed, 94 insertions(+), 80 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b3275047..7d9fe776 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,14 +1,12 @@ -# This work flow runs all Java tests for continuous integration. -# Since it has to build llama.cpp first, for speed, it only runs / tests on the natively supported GitHub runners. 
- +--- name: Continuous Integration -on: [ "pull_request", "workflow_dispatch" ] +on: + - pull_request + - workflow_dispatch env: - MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf" - MODEL_NAME: "codellama-7b.Q2_K.gguf" + MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf + MODEL_NAME: codellama-7b.Q2_K.gguf jobs: - - # don't split build and test jobs to keep the workflow simple build-and-test-linux: name: ubuntu-latest runs-on: ubuntu-latest @@ -16,12 +14,11 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 with: - distribution: 'zulu' - java-version: '11' + distribution: zulu + java-version: "11" - name: Install libcurl run: sudo apt-get install -y libcurl4-openssl-dev - name: Build libraries - # cmake should figure out OS and ARCH automatically when running build.sh (but we need mvn compile for it) run: | mvn compile .github/build.sh -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON @@ -35,7 +32,6 @@ jobs: name: error-log-linux path: ${{ github.workspace }}/hs_err_pid*.log if-no-files-found: warn - build-and-test-macos: name: ${{ matrix.target.runner }} runs-on: ${{ matrix.target.runner }} @@ -43,20 +39,17 @@ jobs: fail-fast: false matrix: target: - - { - runner: macos-13, - cmake: '-DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON' - } - - { - runner: macos-14, - cmake: '-DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON' - } + - runner: macos-13 + cmake: -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON + - runner: macos-14 + cmake: -DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON + -DLLAMA_CURL=ON steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 with: - distribution: 'zulu' - java-version: '11' + distribution: zulu + java-version: "11" - name: Build libraries run: | mvn compile @@ -71,7 +64,6 @@ jobs: name: error-log-macos path: ${{ github.workspace }}/hs_err_pid*.log if-no-files-found: warn - build-and-test-windows: name: windows-latest runs-on: windows-latest @@ -79,87 +71,109 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 with: - distribution: 'zulu' - java-version: '11' + distribution: zulu + java-version: "11" - name: Install libcurl and dependencies - run: | + run: > vcpkg install curl:x64-windows + vcpkg install openssl:x64-windows + vcpkg install zlib:x64-windows + vcpkg install boost-filesystem:x64-windows # Often needed for C++ projects - vcpkg install boost-system:x64-windows # Often needed for C++ projects - - name: Download Dependency Walker - run: | - Invoke-WebRequest -Uri "https://www.dependencywalker.com/depends22_x64.zip" -OutFile "depends.zip" - Expand-Archive -Path "depends.zip" -DestinationPath "depends" + vcpkg install boost-system:x64-windows # Often needed for C++ projects + - name: Download Dependency Walker + run: > + Invoke-WebRequest -Uri "https://www.dependencywalker.com/depends22_x64.zip" + -OutFile "depends.zip" + Expand-Archive -Path "depends.zip" -DestinationPath "depends" - name: Build libraries - run: | + run: > mvn compile + .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON -DCURL_LIBRARY=C:/vcpkg/packages/curl_x64-windows/lib/libcurl.lib -DCURL_INCLUDE_DIR=C:/vcpkg/packages/curl_x64-windows/include + - name: Prepare DLL directory + run: > + mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" + + Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" + + + #Copy ALL DLLs from vcpkg 
directories to ensure we have everything + + Get-ChildItem "C:/vcpkg/installed/x64-windows/bin" -Filter *.dll | ForEach-Object { - - name: Prepare DLL directory - run: | - mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" - Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" - - # Copy ALL DLLs from vcpkg directories to ensure we have everything - Get-ChildItem "C:/vcpkg/installed/x64-windows/bin" -Filter *.dll | ForEach-Object { Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose - } - - # Also from the packages directory - Get-ChildItem "C:/vcpkg/packages" -Recurse -Filter "*.dll" | Where-Object { $_.Directory -like "*bin*" } | ForEach-Object { + + } + + + # Also from the packages directory + + Get-ChildItem "C:/vcpkg/packages" -Recurse -Filter "*.dll" | Where-Object { $_.Directory -like "*bin*" } | ForEach-Object { + Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose - } - - # Copy Visual C++ Redistributable DLLs - $vcredistPath = "C:\Windows\System32" - @( - "msvcp140.dll", - "vcruntime140.dll", - "vcruntime140_1.dll", - "msvcp140_1.dll", - "msvcp140_2.dll", - "concrt140.dll" - ) | ForEach-Object { - if (Test-Path "$vcredistPath\$_") { - Copy-Item "$vcredistPath\$_" -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose + } - } - - - name: Analyze DLL dependencies - run: | - # Run dependency walker on ggml.dll to see what's missing - .\depends\depends.exe -c -oc:deps_ggml.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\ggml.dll" - # Also analyze jllama.dll and llama.dll - .\depends\depends.exe -c -oc:deps_jllama.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\jllama.dll" - .\depends\depends.exe -c -oc:deps_llama.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\llama.dll" - - # Display the results - Get-Content deps_ggml.txt - echo "--------------------" - Get-Content deps_jllama.txt - echo "--------------------" - Get-Content deps_llama.txt - - - name: Verify DLL placement - run: | - dir target\classes\de\kherud\llama\Windows\x86_64\ + # Copy Visual C++ Redistributable DLLs + + $vcredistPath = "C:\Windows\System32" + + @( + "msvcp140.dll", + "vcruntime140.dll", + "vcruntime140_1.dll", + "msvcp140_1.dll", + "msvcp140_2.dll", + "concrt140.dll" + ) | ForEach-Object { + if (Test-Path "$vcredistPath\$_") { + Copy-Item "$vcredistPath\$_" -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose + } + } + - name: Analyze DLL dependencies + run: > + # Run dependency walker on ggml.dll to see what's missing + + .\depends\depends.exe -c -oc:deps_ggml.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\ggml.dll" + + # Also analyze jllama.dll and llama.dll + + .\depends\depends.exe -c -oc:deps_jllama.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\jllama.dll" + + .\depends\depends.exe -c -oc:deps_llama.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\llama.dll" + + + # Display the results + + Get-Content deps_ggml.txt + + echo "--------------------" + + Get-Content deps_jllama.txt + + echo "--------------------" + + Get-Content deps_llama.txt + - name: Verify DLL placement + run: | + dir target\classes\de\kherud\llama\Windows\x86_64\ - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - - name: Run tests with explicit 
DLL path - run: | + run: > $env:PATH = "C:\vcpkg\installed\x64-windows\bin;${env:PATH}" - mvn test "-Djava.library.path=${env:PATH};target/classes/de/kherud/llama/Windows/x86_64" -Ddebug.native.loading=true + mvn test "-Djava.library.path=${env:PATH};target/classes/de/kherud/llama/Windows/x86_64" -Ddebug.native.loading=true - if: failure() uses: actions/upload-artifact@v4 with: name: error-log-windows path: ${{ github.workspace }}\hs_err_pid*.log if-no-files-found: warn + From f7bc392c4153946a003e2b4ee1db64e72bc8dcea Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 22:13:51 -0800 Subject: [PATCH 33/51] update the workflow path --- .github/workflows/ci.yml | 79 ++++++++++++++++++++++++++++------------ 1 file changed, 56 insertions(+), 23 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d9fe776..112d7216 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -90,6 +90,8 @@ jobs: -OutFile "depends.zip" Expand-Archive -Path "depends.zip" -DestinationPath "depends" + # Verify it was extracted correctly + dir depends - name: Build libraries run: > mvn compile @@ -137,29 +139,60 @@ jobs: } } - name: Analyze DLL dependencies - run: > - # Run dependency walker on ggml.dll to see what's missing - - .\depends\depends.exe -c -oc:deps_ggml.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\ggml.dll" - - # Also analyze jllama.dll and llama.dll - - .\depends\depends.exe -c -oc:deps_jllama.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\jllama.dll" - - .\depends\depends.exe -c -oc:deps_llama.txt "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\llama.dll" - - - # Display the results - - Get-Content deps_ggml.txt - - echo "--------------------" - - Get-Content deps_jllama.txt - - echo "--------------------" - - Get-Content deps_llama.txt + run: | + # Create directory for outputs + mkdir -Force "dependency_reports" + + # Get paths to DLLs for analysis + $ggmlPath = "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\ggml.dll" + $llamaPath = "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\llama.dll" + $jllamaPath = "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\jllama.dll" + + # Verify files exist before analysis + Write-Host "Verifying DLL files exist:" + if (Test-Path $ggmlPath) { Write-Host "ggml.dll exists at $ggmlPath" } else { Write-Host "ERROR: ggml.dll NOT FOUND at $ggmlPath" } + if (Test-Path $llamaPath) { Write-Host "llama.dll exists at $llamaPath" } else { Write-Host "ERROR: llama.dll NOT FOUND at $llamaPath" } + if (Test-Path $jllamaPath) { Write-Host "jllama.dll exists at $jllamaPath" } else { Write-Host "ERROR: jllama.dll NOT FOUND at $jllamaPath" } + + # Alternative approach using dumpbin (available on Windows) + Write-Host "Analyzing dependencies with dumpbin..." 
+ + # Create a function to extract dependencies + function Get-Dependencies { + param([string]$dllPath, [string]$outputPath) + + if (Test-Path $dllPath) { + Write-Host "Running dumpbin on $dllPath" + & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\bin\Hostx64\x64\dumpbin.exe" /DEPENDENTS $dllPath > $outputPath + if ($LASTEXITCODE -eq 0) { + Write-Host "Successfully wrote dependencies to $outputPath" + Get-Content $outputPath | Select-String -Pattern "Image has the following dependencies" + Get-Content $outputPath | Select-String -Pattern "\.dll" + } else { + Write-Host "Error running dumpbin: $LASTEXITCODE" + } + } else { + Write-Host "ERROR: File not found: $dllPath" + } + } + + # Run dependency analysis + Get-Dependencies -dllPath $ggmlPath -outputPath "dependency_reports\deps_ggml.txt" + Get-Dependencies -dllPath $llamaPath -outputPath "dependency_reports\deps_llama.txt" + Get-Dependencies -dllPath $jllamaPath -outputPath "dependency_reports\deps_jllama.txt" + + # List files in the output directory + Write-Host "Files in dependency_reports directory:" + dir dependency_reports + + - name: Upload dependency reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: dependency-reports + path: dependency_reports\* + if-no-files-found: warn + - name: Verify DLL placement run: | dir target\classes\de\kherud\llama\Windows\x86_64\ From 932fac3fd992cf5102f7688cc23ec7c6e3324365 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 7 Mar 2025 23:58:58 -0800 Subject: [PATCH 34/51] removing curl support from windows --- .github/workflows/ci.yml | 136 ++------------------------------------- 1 file changed, 7 insertions(+), 129 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 112d7216..5d8e290a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,142 +71,20 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 with: - distribution: zulu - java-version: "11" - - name: Install libcurl and dependencies - run: > - vcpkg install curl:x64-windows - - vcpkg install openssl:x64-windows - - vcpkg install zlib:x64-windows - - vcpkg install boost-filesystem:x64-windows # Often needed for C++ projects - - vcpkg install boost-system:x64-windows # Often needed for C++ projects - - name: Download Dependency Walker - run: > - Invoke-WebRequest -Uri "https://www.dependencywalker.com/depends22_x64.zip" - -OutFile "depends.zip" - - Expand-Archive -Path "depends.zip" -DestinationPath "depends" - # Verify it was extracted correctly - dir depends + distribution: 'zulu' + java-version: '11' - name: Build libraries - run: > - mvn compile - - .github\build.bat -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON -DCURL_LIBRARY=C:/vcpkg/packages/curl_x64-windows/lib/libcurl.lib -DCURL_INCLUDE_DIR=C:/vcpkg/packages/curl_x64-windows/include - - name: Prepare DLL directory - run: > - mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" - - Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" - - - #Copy ALL DLLs from vcpkg directories to ensure we have everything - - Get-ChildItem "C:/vcpkg/installed/x64-windows/bin" -Filter *.dll | ForEach-Object { - - Copy-Item $_.FullName -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose - - } - - - # Also from the packages directory - - Get-ChildItem "C:/vcpkg/packages" -Recurse -Filter "*.dll" | Where-Object { $_.Directory -like "*bin*" } | ForEach-Object { - - Copy-Item $_.FullName 
-Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose - - } - - - # Copy Visual C++ Redistributable DLLs - - $vcredistPath = "C:\Windows\System32" - - @( - "msvcp140.dll", - "vcruntime140.dll", - "vcruntime140_1.dll", - "msvcp140_1.dll", - "msvcp140_2.dll", - "concrt140.dll" - ) | ForEach-Object { - if (Test-Path "$vcredistPath\$_") { - Copy-Item "$vcredistPath\$_" -Destination "target/classes/de/kherud/llama/Windows/x86_64/" -Verbose - } - } - - name: Analyze DLL dependencies run: | - # Create directory for outputs - mkdir -Force "dependency_reports" - - # Get paths to DLLs for analysis - $ggmlPath = "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\ggml.dll" - $llamaPath = "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\llama.dll" - $jllamaPath = "${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64\jllama.dll" - - # Verify files exist before analysis - Write-Host "Verifying DLL files exist:" - if (Test-Path $ggmlPath) { Write-Host "ggml.dll exists at $ggmlPath" } else { Write-Host "ERROR: ggml.dll NOT FOUND at $ggmlPath" } - if (Test-Path $llamaPath) { Write-Host "llama.dll exists at $llamaPath" } else { Write-Host "ERROR: llama.dll NOT FOUND at $llamaPath" } - if (Test-Path $jllamaPath) { Write-Host "jllama.dll exists at $jllamaPath" } else { Write-Host "ERROR: jllama.dll NOT FOUND at $jllamaPath" } - - # Alternative approach using dumpbin (available on Windows) - Write-Host "Analyzing dependencies with dumpbin..." - - # Create a function to extract dependencies - function Get-Dependencies { - param([string]$dllPath, [string]$outputPath) - - if (Test-Path $dllPath) { - Write-Host "Running dumpbin on $dllPath" - & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.29.30133\bin\Hostx64\x64\dumpbin.exe" /DEPENDENTS $dllPath > $outputPath - if ($LASTEXITCODE -eq 0) { - Write-Host "Successfully wrote dependencies to $outputPath" - Get-Content $outputPath | Select-String -Pattern "Image has the following dependencies" - Get-Content $outputPath | Select-String -Pattern "\.dll" - } else { - Write-Host "Error running dumpbin: $LASTEXITCODE" - } - } else { - Write-Host "ERROR: File not found: $dllPath" - } - } - - # Run dependency analysis - Get-Dependencies -dllPath $ggmlPath -outputPath "dependency_reports\deps_ggml.txt" - Get-Dependencies -dllPath $llamaPath -outputPath "dependency_reports\deps_llama.txt" - Get-Dependencies -dllPath $jllamaPath -outputPath "dependency_reports\deps_jllama.txt" - - # List files in the output directory - Write-Host "Files in dependency_reports directory:" - dir dependency_reports - - - name: Upload dependency reports - if: always() - uses: actions/upload-artifact@v4 - with: - name: dependency-reports - path: dependency_reports\* - if-no-files-found: warn - - - name: Verify DLL placement - run: | - dir target\classes\de\kherud\llama\Windows\x86_64\ + mvn compile + .github\build.bat -DLLAMA_VERBOSE=ON - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - - name: Run tests with explicit DLL path - run: > - $env:PATH = "C:\vcpkg\installed\x64-windows\bin;${env:PATH}" - - mvn test "-Djava.library.path=${env:PATH};target/classes/de/kherud/llama/Windows/x86_64" -Ddebug.native.loading=true + - name: Run tests + run: mvn test - if: failure() uses: actions/upload-artifact@v4 with: name: error-log-windows path: ${{ github.workspace }}\hs_err_pid*.log - if-no-files-found: warn + if-no-files-found: warn From 
894262891928f9ea9e707586bcedf07adcd51fb2 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Sat, 8 Mar 2025 00:13:44 -0800 Subject: [PATCH 35/51] adding copy and verify step --- .github/workflows/ci.yml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5d8e290a..a27cb5c8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,10 +77,17 @@ jobs: run: | mvn compile .github\build.bat -DLLAMA_VERBOSE=ON + - name: Copy DLLs (including curl.dll) from vcpkg explicitly + run: | + mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" + Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" + - name: Verify DLL placement + run: | + dir target\classes\de\kherud\llama\Windows\x86_64\ - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests - run: mvn test + run: mvn test "-Djava.library.path=${env:PATH};${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64" -Ddebug.native.loading=true - if: failure() uses: actions/upload-artifact@v4 with: From 28c17b825e63b5bdaf549685198e199f9b4a470d Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Sat, 8 Mar 2025 00:25:17 -0800 Subject: [PATCH 36/51] adding copy and verify step --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a27cb5c8..5891f90b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -87,7 +87,7 @@ jobs: - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests - run: mvn test "-Djava.library.path=${env:PATH};${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64" -Ddebug.native.loading=true + run: mvn test "-Djava.library.path=${env:PATH};${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64 -Ddebug.native.loading=true" - if: failure() uses: actions/upload-artifact@v4 with: From 0b304b8e65d1d5b0b8937cb9fb630389e827ba51 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 20:54:16 +0100 Subject: [PATCH 37/51] statically link dependencies --- CMakeLists.txt | 3 ++- src/main/java/de/kherud/llama/LlamaLoader.java | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6fe8778b..2851774b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,7 @@ include(FetchContent) set(BUILD_SHARED_LIBS ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(BUILD_SHARED_LIBS OFF) option(LLAMA_VERBOSE "llama: verbose output" OFF) @@ -103,7 +104,7 @@ target_compile_definitions(jllama PRIVATE ) if(OS_NAME STREQUAL "Windows") - set_target_properties(jllama llama ggml PROPERTIES + set_target_properties(jllama llama ggml PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${JLLAMA_DIR} RUNTIME_OUTPUT_DIRECTORY_RELEASE ${JLLAMA_DIR} ) diff --git a/src/main/java/de/kherud/llama/LlamaLoader.java b/src/main/java/de/kherud/llama/LlamaLoader.java index 2605d96e..a083a1ec 100644 --- a/src/main/java/de/kherud/llama/LlamaLoader.java +++ b/src/main/java/de/kherud/llama/LlamaLoader.java @@ -63,8 +63,6 @@ static synchronized void initialize() throws UnsatisfiedLinkError { System.err.println("'ggml-metal.metal' not found"); } } - loadNativeLibrary("ggml"); - loadNativeLibrary("llama"); loadNativeLibrary("jllama"); extracted = true; } From a93a79e305284bfdc8bee662865b40a507bf45ce Mon Sep 17 
00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 20:54:44 +0100 Subject: [PATCH 38/51] ci workflow disable curl build --- .github/workflows/ci.yml | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5891f90b..f4e351c0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,6 +7,7 @@ env: MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf MODEL_NAME: codellama-7b.Q2_K.gguf jobs: + build-and-test-linux: name: ubuntu-latest runs-on: ubuntu-latest @@ -16,12 +17,10 @@ jobs: with: distribution: zulu java-version: "11" - - name: Install libcurl - run: sudo apt-get install -y libcurl4-openssl-dev - name: Build libraries run: | mvn compile - .github/build.sh -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON + .github/build.sh -DLLAMA_VERBOSE=ON - name: Download model run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} - name: Run tests @@ -32,6 +31,7 @@ jobs: name: error-log-linux path: ${{ github.workspace }}/hs_err_pid*.log if-no-files-found: warn + build-and-test-macos: name: ${{ matrix.target.runner }} runs-on: ${{ matrix.target.runner }} @@ -40,10 +40,9 @@ jobs: matrix: target: - runner: macos-13 - cmake: -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON -DLLAMA_CURL=ON + cmake: -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON - runner: macos-14 cmake: -DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON - -DLLAMA_CURL=ON steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 @@ -64,6 +63,7 @@ jobs: name: error-log-macos path: ${{ github.workspace }}/hs_err_pid*.log if-no-files-found: warn + build-and-test-windows: name: windows-latest runs-on: windows-latest @@ -77,21 +77,13 @@ jobs: run: | mvn compile .github\build.bat -DLLAMA_VERBOSE=ON - - name: Copy DLLs (including curl.dll) from vcpkg explicitly - run: | - mkdir -Force "target/classes/de/kherud/llama/Windows/x86_64" - Copy-Item ".\src\main\resources\de\kherud\llama\Windows\x86_64\*.dll" "target/classes/de/kherud/llama/Windows/x86_64/" - - name: Verify DLL placement - run: | - dir target\classes\de\kherud\llama\Windows\x86_64\ - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests - run: mvn test "-Djava.library.path=${env:PATH};${{ github.workspace }}\target\classes\de\kherud\llama\Windows\x86_64 -Ddebug.native.loading=true" + run: mvn test - if: failure() uses: actions/upload-artifact@v4 with: name: error-log-windows path: ${{ github.workspace }}\hs_err_pid*.log - if-no-files-found: warn - + if-no-files-found: warn From 01c202b0e0d7c6eadc2fb8d4a1237aae29e149e7 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 21:26:10 +0100 Subject: [PATCH 39/51] ci workflow enable llama metal --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f4e351c0..2e1e743c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,7 +42,7 @@ jobs: - runner: macos-13 cmake: -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON - runner: macos-14 - cmake: -DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON + cmake: -DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_VERBOSE=ON steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 From 6c70a31d79036e0d21c495eeeda5648529e9d6fa Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 21:26:22 +0100 Subject: [PATCH 40/51] ignore logging test --- 
src/test/java/de/kherud/llama/LlamaModelTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index 9e5b767b..39b4e0d7 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -206,6 +206,7 @@ public void testLogJSON() { } } + @Ignore @Test public void testLogStdout() { // Unfortunately, `printf` can't be easily re-directed to Java. This test only works manually, thus. From be6e34a693b798a4e4d9422a2afcb32c28b251a5 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 21:33:12 +0100 Subject: [PATCH 41/51] ci workflow disable native ggml windows build --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2e1e743c..906a58fd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,7 +76,7 @@ jobs: - name: Build libraries run: | mvn compile - .github\build.bat -DLLAMA_VERBOSE=ON + .github\build.bat -DGGML_NATIVE=OFF -DLLAMA_VERBOSE=ON - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests From e9df628fcf10096cbd6595bf5da4818e4c4ddd40 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 22:00:33 +0100 Subject: [PATCH 42/51] ci workflow upload windows libraries --- .github/workflows/ci.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 906a58fd..b0d63c8c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,7 +76,7 @@ jobs: - name: Build libraries run: | mvn compile - .github\build.bat -DGGML_NATIVE=OFF -DLLAMA_VERBOSE=ON + .github\build.bat -DLLAMA_VERBOSE=ON - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Run tests @@ -84,6 +84,8 @@ jobs: - if: failure() uses: actions/upload-artifact@v4 with: - name: error-log-windows - path: ${{ github.workspace }}\hs_err_pid*.log + name: windows-output + path: | + ${{ github.workspace }}\hs_err_pid*.log + ${{ github.workspace }}/src/main/resources/de/kherud/llama/**/* if-no-files-found: warn From 20a7df4b4f512814ae9d339a4a5bdf8ee1e99ed1 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 22:06:43 +0100 Subject: [PATCH 43/51] ci workflow build windows in release-debug mode --- .github/build.bat | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/build.bat b/.github/build.bat index a904405e..5cfa26c0 100755 --- a/.github/build.bat +++ b/.github/build.bat @@ -2,6 +2,6 @@ mkdir build cmake -Bbuild %* -cmake --build build --config Release +cmake --build build --config RelWithDebInfo if errorlevel 1 exit /b %ERRORLEVEL% \ No newline at end of file From b9bc6f3167a9c3c0c1669280cbd23281320c5da6 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 22:19:28 +0100 Subject: [PATCH 44/51] cmakelists add windows relwithdebinfo output path --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2851774b..2278d454 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -107,6 +107,7 @@ if(OS_NAME STREQUAL "Windows") set_target_properties(jllama llama ggml PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${JLLAMA_DIR} RUNTIME_OUTPUT_DIRECTORY_RELEASE ${JLLAMA_DIR} + RUNTIME_OUTPUT_DIRECTORY_RELWITHDEBINFO ${JLLAMA_DIR} ) else() 
set_target_properties(jllama llama ggml PROPERTIES From 3c5b489c53b14fa35ea1d24a19f22be4c952998b Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 22:31:17 +0100 Subject: [PATCH 45/51] ci workflow build windows in debug mode --- .github/build.bat | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/build.bat b/.github/build.bat index 5cfa26c0..2fefa247 100755 --- a/.github/build.bat +++ b/.github/build.bat @@ -2,6 +2,6 @@ mkdir build cmake -Bbuild %* -cmake --build build --config RelWithDebInfo +cmake --build build --config Debug if errorlevel 1 exit /b %ERRORLEVEL% \ No newline at end of file From 50129c9c316fe4546d1b56b2be079f513f368fc8 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 8 Mar 2025 22:36:43 +0100 Subject: [PATCH 46/51] add debug statements to jni load --- .github/build.bat | 2 +- src/main/cpp/jllama.cpp | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.github/build.bat b/.github/build.bat index 2fefa247..5cfa26c0 100755 --- a/.github/build.bat +++ b/.github/build.bat @@ -2,6 +2,6 @@ mkdir build cmake -Bbuild %* -cmake --build build --config Debug +cmake --build build --config RelWithDebInfo if errorlevel 1 exit /b %ERRORLEVEL% \ No newline at end of file diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 3a547bc8..cad3ca43 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -326,8 +326,12 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) goto error; } + printf("loaded JNI symbols\n"); fflush(stdout); + llama_backend_init(); + printf("loaded llama.cpp backend\n"); fflush(stdout); + goto success; error: @@ -391,6 +395,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo { common_params params; + printf("load model\n"); fflush(stdout); const jsize argc = env->GetArrayLength(jparams); char **argv = parse_string_array(env, jparams, argc); if (argv == nullptr) @@ -398,22 +403,25 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo return; } + printf("loaded jargs\n"); fflush(stdout); const auto parsed_params = common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER); free_string_array(argv, argc); if (!parsed_params) { return; } - + + printf("parsed params\n"); fflush(stdout); SRV_INF("loading model '%s'\n", params.model.c_str()); common_init(); + printf("initialized common\n"); fflush(stdout); // struct that contains llama context and inference auto *ctx_server = new server_context(); - llama_backend_init(); llama_numa_init(params.numa); + printf("created ctx\n"); fflush(stdout); LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); From 4481c1c71d24115023a0edd877c4e69bb72f550d Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sun, 9 Mar 2025 14:30:57 +0100 Subject: [PATCH 47/51] ci workflow windows use zulu 17 --- .github/build.bat | 2 +- .github/workflows/ci.yml | 2 +- src/main/cpp/jllama.cpp | 10 +--------- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/.github/build.bat b/.github/build.bat index 5cfa26c0..a904405e 100755 --- a/.github/build.bat +++ b/.github/build.bat @@ -2,6 +2,6 @@ mkdir build cmake -Bbuild %* -cmake --build build --config RelWithDebInfo +cmake --build build --config Release if errorlevel 1 exit /b %ERRORLEVEL% \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 
b0d63c8c..74151b9b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,7 +72,7 @@ jobs: - uses: actions/setup-java@v4 with: distribution: 'zulu' - java-version: '11' + java-version: '17' - name: Build libraries run: | mvn compile diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index cad3ca43..0e70e624 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -326,12 +326,8 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) goto error; } - printf("loaded JNI symbols\n"); fflush(stdout); - llama_backend_init(); - printf("loaded llama.cpp backend\n"); fflush(stdout); - goto success; error: @@ -395,7 +391,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo { common_params params; - printf("load model\n"); fflush(stdout); const jsize argc = env->GetArrayLength(jparams); char **argv = parse_string_array(env, jparams, argc); if (argv == nullptr) @@ -403,7 +398,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo return; } - printf("loaded jargs\n"); fflush(stdout); const auto parsed_params = common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER); free_string_array(argv, argc); if (!parsed_params) @@ -411,17 +405,15 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo return; } - printf("parsed params\n"); fflush(stdout); SRV_INF("loading model '%s'\n", params.model.c_str()); common_init(); - printf("initialized common\n"); fflush(stdout); // struct that contains llama context and inference auto *ctx_server = new server_context(); + llama_backend_init(); llama_numa_init(params.numa); - printf("created ctx\n"); fflush(stdout); LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); From d549764f6158b8d8a100e3d11f4a945b905037ec Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sun, 9 Mar 2025 14:41:57 +0100 Subject: [PATCH 48/51] defer llama backend initialization --- .github/workflows/ci.yml | 2 +- src/main/cpp/jllama.cpp | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 74151b9b..b0d63c8c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,7 +72,7 @@ jobs: - uses: actions/setup-java@v4 with: distribution: 'zulu' - java-version: '17' + java-version: '11' - name: Build libraries run: | mvn compile diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 0e70e624..5eb688ce 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -326,8 +326,6 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) goto error; } - llama_backend_init(); - goto success; error: From 66b31d9013aba18014cf17ecbb33dac6b10f8cec Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sun, 9 Mar 2025 15:06:54 +0100 Subject: [PATCH 49/51] statically link windows system libraries --- CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2278d454..83f5906a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,11 @@ set(BUILD_SHARED_LIBS OFF) option(LLAMA_VERBOSE "llama: verbose output" OFF) +if(MSVC) + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") + add_compile_options(/MT) +endif() + #################### json #################### FetchContent_Declare( From 5e6c5c9a4eb93992981230b37b5cf396bb0b7e4b Mon Sep 17 00:00:00 2001 From: 
Konstantin Herud Date: Sun, 9 Mar 2025 15:14:39 +0100 Subject: [PATCH 50/51] remove static linking and use older msvc in release workflow --- .github/workflows/ci.yml | 4 ++-- CMakeLists.txt | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b0d63c8c..631fc86d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,8 +65,8 @@ jobs: if-no-files-found: warn build-and-test-windows: - name: windows-latest - runs-on: windows-latest + name: windows-2019 + runs-on: windows-2019 steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 diff --git a/CMakeLists.txt b/CMakeLists.txt index 83f5906a..bfca2cc1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,10 +10,8 @@ set(BUILD_SHARED_LIBS OFF) option(LLAMA_VERBOSE "llama: verbose output" OFF) -if(MSVC) - set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") - add_compile_options(/MT) -endif() +message(STATUS "C++ Compiler: ${CMAKE_CXX_COMPILER}") +message(STATUS "C++ Compiler Version: ${CMAKE_CXX_COMPILER_VERSION}") #################### json #################### From f6ca909178a9af1837d6defd74a0261bcf8d3e0e Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sun, 9 Mar 2025 15:20:27 +0100 Subject: [PATCH 51/51] initialize llama backend on jni load and remove cmake debug statements --- CMakeLists.txt | 3 --- src/main/cpp/jllama.cpp | 3 ++- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bfca2cc1..2278d454 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,9 +10,6 @@ set(BUILD_SHARED_LIBS OFF) option(LLAMA_VERBOSE "llama: verbose output" OFF) -message(STATUS "C++ Compiler: ${CMAKE_CXX_COMPILER}") -message(STATUS "C++ Compiler Version: ${CMAKE_CXX_COMPILER_VERSION}") - #################### json #################### FetchContent_Declare( diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 5eb688ce..3e17e5dc 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -326,6 +326,8 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) goto error; } + llama_backend_init(); + goto success; error: @@ -410,7 +412,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo // struct that contains llama context and inference auto *ctx_server = new server_context(); - llama_backend_init(); llama_numa_init(params.numa); LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads,