diff --git a/common/arg.cpp b/common/arg.cpp
index 5597de121c132..b20996396c05d 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1110,6 +1110,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.prompt_cache_ro = true;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    add_opt(common_arg(
+        {"--dump-activations"}, "FNAME",
+        "file to dump activations to in GGUF format (default: none)",
+        [](common_params & params, const std::string & value) {
+            params.path_dump_activations = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    add_opt(common_arg(
+        {"--load-activations"}, "FNAME",
+        "file to load activations from in GGUF format (default: none)",
+        [](common_params & params, const std::string & value) {
+            params.path_load_activations = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
         {"-r", "--reverse-prompt"}, "PROMPT",
         "halt generation at PROMPT, return control in interactive mode\n",
@@ -1164,6 +1178,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.interactive_first = true;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    add_opt(common_arg(
+        {"--idle-action-interval"}, "N",
+        "auto-submit empty input after N minutes of idle time with no keystrokes (default: 0 = disabled)",
+        [](common_params & params, const std::string & value) {
+            params.idle_action_interval = std::stoi(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
         {"-mli", "--multiline-input"},
         "allows you to write or paste multiple lines without ending each in '\\'",
@@ -2453,6 +2474,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--kv-cache-auto-save"}, "BASE_NAME",
+        "automatically save all KV cache to BASE_NAME_<timestamp>/ directory on server shutdown (default: disabled)",
+        [](common_params & params, const std::string & value) {
+            params.kv_cache_auto_save_base = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--kv-cache-auto-load"}, "DIRNAME",
+        "automatically load KV cache from the specified timestamped directory on server startup (default: disabled)",
+        [](common_params & params, const std::string & value) {
+            params.kv_cache_auto_load = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--jinja"},
         "use jinja template for chat (default: disabled)",
diff --git a/common/common.h b/common/common.h
index 54b7849b17448..3161b0c4c2bb5 100644
--- a/common/common.h
+++ b/common/common.h
@@ -328,6 +328,8 @@ struct common_params {
     std::string system_prompt = ""; // NOLINT
     std::string prompt_file = ""; // store the external prompt file name // NOLINT
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
+    std::string path_dump_activations = ""; // path to GGUF file for dumping activations // NOLINT
+    std::string path_load_activations = ""; // path to GGUF file for loading activations // NOLINT
     std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
     std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
     std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
@@ -370,6 +372,7 @@ struct common_params {
     bool special = false; // enable special token output
     bool interactive = false; // interactive mode
     bool interactive_first = false; // wait for user input immediately
+    int32_t idle_action_interval = 0; // auto-submit empty input after N minutes of idle (0 = disabled)
     bool prompt_cache_all = false; // save user input and generations to prompt cache
     bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
@@ -457,6 +460,10 @@ struct common_params {

     std::string slot_save_path;

+    // Auto KV cache save/load for faster server restarts
+    std::string kv_cache_auto_save_base; // base name for auto-saving KV cache on shutdown (with timestamp)
+    std::string kv_cache_auto_load;      // specific timestamped directory name to load on startup
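+    // (illustrative) --kv-cache-auto-save mycache writes mycache_YYYYMMDD_HHMMSS/slot_<id>.bin
+    // on shutdown; pass that directory back via --kv-cache-auto-load to restore the slots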
+
     float slot_prompt_similarity = 0.1f;

     // batched-bench params
diff --git a/common/console.cpp b/common/console.cpp
index 078a8d678d933..a435504c2adfa 100644
--- a/common/console.cpp
+++ b/common/console.cpp
@@ -501,4 +501,60 @@ namespace console {
         return readline_advanced(line, multiline_input);
     }

+    bool readline_with_timeout(std::string & line, bool multiline_input, int timeout_seconds, bool & timed_out) {
+        timed_out = false;
+
+        if (timeout_seconds <= 0) {
+            // No timeout, use regular readline
+            return readline(line, multiline_input);
+        }
+
+#if defined(_WIN32)
+        // Windows: check if input is available with timeout
+        HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE);
+        DWORD result = WaitForSingleObject(hStdin, timeout_seconds * 1000);
+
+        if (result == WAIT_TIMEOUT) {
+            timed_out = true;
+            line.clear();
+            return false;
+        }
+
+        if (result != WAIT_OBJECT_0) {
+            // Error occurred
+            line.clear();
+            return false;
+        }
+
+        // Input is available, use regular readline
+        return readline(line, multiline_input);
+#else
+        // Unix: use select() to check for input with timeout
+        fd_set readfds;
+        struct timeval tv;
+
+        FD_ZERO(&readfds);
+        FD_SET(STDIN_FILENO, &readfds);
+
+        tv.tv_sec = timeout_seconds;
+        tv.tv_usec = 0;
+
+        int retval = select(STDIN_FILENO + 1, &readfds, NULL, NULL, &tv);
+
+        if (retval == -1) {
+            // Error occurred
+            line.clear();
+            return false;
+        } else if (retval == 0) {
+            // Timeout occurred
+            timed_out = true;
+            line.clear();
+            return false;
+        }
+
+        // Input is available, use regular readline
+        return readline(line, multiline_input);
+#endif
+    }
+
 }
diff --git a/common/console.h b/common/console.h
index ec175269b9d8a..e7a17618563e6 100644
--- a/common/console.h
+++ b/common/console.h
@@ -16,4 +16,5 @@ namespace console {
     void cleanup();
     void set_display(display_t display);
     bool readline(std::string & line, bool multiline_input);
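+    // like readline(), but gives up after timeout_seconds with no input:
+    // on expiry it clears line, sets timed_out = true and returns false;
+    // timeout_seconds <= 0 falls back to a plain readline() call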
+    bool readline_with_timeout(std::string & line, bool multiline_input, int timeout_seconds, bool & timed_out);
 }
diff --git a/common/sampling.cpp b/common/sampling.cpp
index c69d525b5b358..31ea25de1c8d8 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -4,6 +4,7 @@
 #include "log.h"

 #include <cmath>
+#include <cstring>
 #include <unordered_map>
 #include <algorithm>

@@ -599,3 +600,96 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri

     return samplers;
 }
+
+// Get current temperature from the sampler
+float common_sampler_get_temp(const struct common_sampler * gsmpl) {
+    if (!gsmpl) {
+        return 0.0f;
+    }
+    return gsmpl->params.temp;
+}
+
+// Set temperature at runtime by replacing the temperature sampler in the chain
+bool common_sampler_set_temp(struct common_sampler * gsmpl, float new_temp) {
+    if (!gsmpl || !gsmpl->chain) {
+        LOG_ERR("%s: invalid sampler or chain\n", __func__);
+        return false;
+    }
+
+    // Find the temperature sampler in the chain
+    const int n_samplers = llama_sampler_chain_n(gsmpl->chain);
+    int temp_idx = -1;
+
+    LOG_INF("%s: searching for temperature sampler in chain of %d samplers\n", __func__, n_samplers);
+
+    for (int i = 0; i < n_samplers; i++) {
+        struct llama_sampler * s = llama_sampler_chain_get(gsmpl->chain, i);
+        const char * name = llama_sampler_name(s);
+        LOG_INF("%s: sampler[%d] = '%s'\n", __func__, i, name);
+
+        // Look for "temp" or "temp-ext" sampler
+        if (strcmp(name, "temp") == 0 || strcmp(name, "temp-ext") == 0) {
+            temp_idx = i;
+            LOG_INF("%s: found temperature sampler '%s' at index %d\n", __func__, name, i);
+            break;
+        }
+    }
+
+    if (temp_idx == -1) {
+        // No temperature sampler found - this might happen with mirostat
+        LOG_ERR("%s: no temperature sampler found in chain\n", __func__);
+        return false;
+    }
+
+    LOG_INF("%s: removing old temperature sampler at index %d\n", __func__, temp_idx);
+
+    // Remove the old temperature sampler
+    struct llama_sampler * old_temp = llama_sampler_chain_remove(gsmpl->chain, temp_idx);
+    if (old_temp) {
+        llama_sampler_free(old_temp);
+        LOG_INF("%s: freed old temperature sampler\n", __func__);
+    }
+
+    // Collect all samplers that come after the temp position
+    std::vector<struct llama_sampler *> samplers_after;
+    int n_after = llama_sampler_chain_n(gsmpl->chain) - temp_idx;
+    LOG_INF("%s: collecting %d samplers after temp position\n", __func__, n_after);
+
+    for (int i = 0; i < n_after; i++) {
+        struct llama_sampler * s = llama_sampler_chain_remove(gsmpl->chain, temp_idx);
+        const char * name = llama_sampler_name(s);
+        LOG_INF("%s: removed sampler '%s'\n", __func__, name);
+        samplers_after.push_back(s);
+    }
+
+    // Create and add new temperature sampler
+    struct llama_sampler * new_temp_sampler;
+
+    // Use temp_ext if dynamic temperature was enabled, otherwise use simple temp
+    if (gsmpl->params.dynatemp_range > 0.0f) {
+        LOG_INF("%s: creating temp-ext sampler with temp=%.2f, range=%.2f, exp=%.2f\n",
+                __func__, new_temp, gsmpl->params.dynatemp_range, gsmpl->params.dynatemp_exponent);
+        new_temp_sampler = llama_sampler_init_temp_ext(new_temp, gsmpl->params.dynatemp_range, gsmpl->params.dynatemp_exponent);
+    } else {
+        LOG_INF("%s: creating temp sampler with temp=%.2f\n", __func__, new_temp);
+        new_temp_sampler = llama_sampler_init_temp(new_temp);
+    }
+
+    llama_sampler_chain_add(gsmpl->chain, new_temp_sampler);
+    LOG_INF("%s: added new temperature sampler\n", __func__);
+
+    // Add back the samplers that came after
+    for (auto * s : samplers_after) {
+        const char * name = llama_sampler_name(s);
+        llama_sampler_chain_add(gsmpl->chain, s);
+        LOG_INF("%s: re-added sampler '%s'\n", __func__, name);
+    }
+
+    // Update the params to reflect the new temperature
+    gsmpl->params.temp = new_temp;
+
+    LOG_INF("%s: final chain has %d samplers\n", __func__, llama_sampler_chain_n(gsmpl->chain));
+    LOG_INF("%s: temperature update complete\n", __func__);
+
+    return true;
+}
diff --git a/common/sampling.h b/common/sampling.h
index e198eecda3810..b1cadbb133f58 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -99,6 +99,10 @@ std::string common_sampler_print(const struct common_sampler * gsmpl);
 // get a string representation of the last accepted tokens
 std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);

+// get/set temperature at runtime; set replaces the "temp"/"temp-ext" sampler
+// in the chain and returns false if none is present (e.g. with mirostat)
+float common_sampler_get_temp(const struct common_sampler * gsmpl);
+bool common_sampler_set_temp(struct common_sampler * gsmpl, float temp);
+
 char common_sampler_type_to_chr(enum common_sampler_type cnstr);
 std::string common_sampler_type_to_str(enum common_sampler_type cnstr);

diff --git a/tools/main/main-state-save.cpp b/tools/main/main-state-save.cpp
new file mode 100644
index 0000000000000..331570d517125
--- /dev/null
+++ b/tools/main/main-state-save.cpp
@@ -0,0 +1,125 @@
+// State save/load functions for main.cpp
+// This file contains the simplified implementation using llama_state_get_data/set_data
+
+#include "llama.h"
+#include "log.h"
+#include "gguf.h"
+
+#include <string>
+#include <vector>
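+
+// On-disk layout (one GGUF file): KV metadata llm_state.version (u32),
+// llm_state.size (u64) and llm_state.type (string), plus a single 1-D I8
+// tensor "llm_state_data" holding the raw llama_state_get_data() blob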
+
+// Save complete LLM state to GGUF file
+// This includes: KV cache, logits, embeddings, RNG state
+static bool save_llm_state_to_gguf(llama_context * ctx, const std::string & filename) {
+    LOG("\nSaving LLM state to %s...\n", filename.c_str());
+
+    // Get the size of the state
+    const size_t state_size = llama_state_get_size(ctx);
+    LOG("State size: %zu bytes (%.2f MB)\n", state_size, state_size / (1024.0 * 1024.0));
+
+    // Allocate buffer and get state data
+    std::vector<uint8_t> state_data(state_size);
+    const size_t written = llama_state_get_data(ctx, state_data.data(), state_size);
+
+    if (written != state_size) {
+        LOG_ERR("Failed to get state data: got %zu bytes, expected %zu\n", written, state_size);
+        return false;
+    }
+
+    // Create GGUF context
+    struct gguf_context * gguf_ctx = gguf_init_empty();
+
+    // Add metadata
+    gguf_set_val_u32(gguf_ctx, "llm_state.version", 1);
+    gguf_set_val_u64(gguf_ctx, "llm_state.size", state_size);
+    gguf_set_val_str(gguf_ctx, "llm_state.type", "kv_cache_rng_logits_embeddings");
+
+    // For GGUF, we need to add the state as a tensor
+    // Create a ggml context for the tensor
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ state_size + 1024*1024, // Extra space for tensor metadata
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true, // We already have the data
+    };
+
+    struct ggml_context * ggml_ctx = ggml_init(params);
+
+    // Create a 1D tensor to hold the state data
+    int64_t ne[4] = {(int64_t)state_size, 1, 1, 1};
+    struct ggml_tensor * state_tensor = ggml_new_tensor(ggml_ctx, GGML_TYPE_I8, 1, ne);
+    ggml_set_name(state_tensor, "llm_state_data");
+    state_tensor->data = state_data.data();
+
+    // Add tensor to GGUF
+    gguf_add_tensor(gguf_ctx, state_tensor);
+
+    // Write to file
+    gguf_write_to_file(gguf_ctx, filename.c_str(), false);
+
+    LOG("Successfully saved LLM state (%zu bytes)\n", written);
+
+    // Cleanup
+    ggml_free(ggml_ctx);
+    gguf_free(gguf_ctx);
+
+    return true;
+}
+
+// Load complete LLM state from GGUF file
+static bool load_llm_state_from_gguf(llama_context * ctx, const std::string & filename) {
+    LOG("\nLoading LLM state from %s...\n", filename.c_str());
+
+    struct ggml_context * ggml_ctx = NULL;
+
+    struct gguf_init_params params = {
+        /*.no_alloc = */ false,
+        /*.ctx      = */ &ggml_ctx,
+    };
+
+    struct gguf_context * gguf_ctx = gguf_init_from_file(filename.c_str(), params);
+
+    if (!gguf_ctx) {
+        LOG_ERR("Failed to load state file: %s\n", filename.c_str());
+        return false;
+    }
+
+    // Read metadata
+    const int n_kv = gguf_get_n_kv(gguf_ctx);
+    uint32_t version = 0;
+    uint64_t state_size = 0;
+
+    for (int i = 0; i < n_kv; i++) {
+        const char * key = gguf_get_key(gguf_ctx, i);
+        const enum gguf_type type = gguf_get_kv_type(gguf_ctx, i);
+
+        if (strcmp(key, "llm_state.version") == 0 && type == GGUF_TYPE_UINT32) {
+            version = gguf_get_val_u32(gguf_ctx, i);
+        } else if (strcmp(key, "llm_state.size") == 0 && type == GGUF_TYPE_UINT64) {
+            state_size = gguf_get_val_u64(gguf_ctx, i);
+        }
+    }
+
+    LOG("State version: %u, size: %llu bytes (%.2f MB)\n", version, (unsigned long long) state_size, state_size / (1024.0 * 1024.0));
+
+    // Get the state tensor
+    struct ggml_tensor * state_tensor = ggml_get_tensor(ggml_ctx, "llm_state_data");
+    if (!state_tensor) {
+        LOG_ERR("State tensor not found in file\n");
+        gguf_free(gguf_ctx);
+        return false;
+    }
+
+    // Set the state
+    const size_t loaded = llama_state_set_data(ctx, (const uint8_t *) state_tensor->data, ggml_nbytes(state_tensor));
+
+    if (loaded == 0) {
+        LOG_ERR("Failed to set state data\n");
+        gguf_free(gguf_ctx);
+        return false;
+    }
+
+    LOG("Successfully loaded LLM state (%zu bytes)\n", loaded);
+
+    gguf_free(gguf_ctx);
+
+    return true;
+}
diff --git a/tools/main/main.cpp b/tools/main/main.cpp
index 498e00e3a5e58..1beb619a76849 100644
--- a/tools/main/main.cpp
+++ b/tools/main/main.cpp
@@ -5,6 +5,7 @@
 #include "sampling.h"
 #include "llama.h"
 #include "chat.h"
+#include "gguf.h"

 #include <cstdio>
 #include <cstring>
@@ -14,10 +15,14 @@
 #include <sstream>
 #include <string>
 #include <vector>
+#include <ctime>

 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 #include <signal.h>
 #include <unistd.h>
+#include <sys/stat.h>
+#include <dirent.h>
+#include <sys/wait.h>
 #elif defined (_WIN32)
 #define WIN32_LEAN_AND_MEAN
 #ifndef NOMINMAX
@@ -41,6 +46,39 @@ static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 static bool need_insert_eot = false;

+// State save/load flags for interactive commands
+static bool g_save_state_next = false;
+static std::string g_state_save_path = "";
+
+// Tool execution tracking to prevent duplicate executions
+static std::string g_last_executed_tool_signature = "";
+
+// Idle timeout tracking
+static time_t g_last_activity_time = 0;
+
+// Check if idle timeout has elapsed and we should auto-submit empty input
+static bool should_auto_submit_on_idle(int idle_interval_minutes) {
+    if (idle_interval_minutes <= 0) {
+        return false; // Feature disabled
+    }
+
+    time_t current_time = time(nullptr);
+    if (g_last_activity_time == 0) {
+        g_last_activity_time = current_time;
+        return false;
+    }
+
+    int elapsed_seconds = (int)(current_time - g_last_activity_time);
+    int idle_threshold_seconds = idle_interval_minutes * 60;
+
+    return elapsed_seconds >= idle_threshold_seconds;
+}
+
+// Update activity timestamp
+static void update_activity_time() {
+    g_last_activity_time = time(nullptr);
+}
+
 static void print_usage(int argc, char ** argv) {
     (void) argc;
@@ -62,6 +100,317 @@ static bool file_is_empty(const std::string & path) {
     return f.tellg() == 0;
 }

+// Tool calling support functions
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+static bool is_executable(const std::string & path) {
+    struct stat st;
+    if (stat(path.c_str(), &st) != 0) {
+        return false;
+    }
+    return (st.st_mode & S_IXUSR) != 0;
+}
+
+static std::string execute_command(const std::string & command) {
+    std::string result;
+    FILE * pipe = popen(command.c_str(), "r");
+    if (!pipe) {
+        return "Error: Failed to execute command\n";
+    }
+
+    char buffer[256];
+    while (fgets(buffer, sizeof(buffer), pipe) != nullptr) {
+        result += buffer;
+    }
+
+    int status = pclose(pipe);
+    if (WIFEXITED(status) && WEXITSTATUS(status) != 0) {
+        result += "\n[Tool exited with code " + std::to_string(WEXITSTATUS(status)) + "]\n";
+    }
+
+    return result;
+}
+
+static std::vector<std::string> get_tool_executables(const std::string & tools_dir) {
+    std::vector<std::string> executables;
+
+    DIR * dir = opendir(tools_dir.c_str());
+    if (!dir) {
+        return executables;
+    }
+
+    struct dirent * entry;
+    while ((entry = readdir(dir)) != nullptr) {
+        if (entry->d_name[0] == '.') {
+            continue; // Skip hidden files and . / ..
+        }
+
+        std::string full_path = tools_dir + "/" + entry->d_name;
+        if (is_executable(full_path)) {
+            executables.push_back(entry->d_name);
+        }
+    }
+
+    closedir(dir);
+
+    // Sort alphabetically
+    std::sort(executables.begin(), executables.end());
+
+    return executables;
+}
+
+static std::string collect_tools_help(const std::string & tools_dir) {
+    std::vector<std::string> executables = get_tool_executables(tools_dir);
+
+    if (executables.empty()) {
+        return "No executable tools found in the 'tools' directory.\n";
+    }
+
+    std::ostringstream help_text;
+    help_text << "Available tools:\n\n";
+
+    for (const auto & tool_name : executables) {
+        help_text << "=== " << tool_name << " ===\n";
+        std::string command = tools_dir + "/" + tool_name + " help";
+        std::string output = execute_command(command);
+        help_text << output;
+        if (!output.empty() && output.back() != '\n') {
+            help_text << "\n";
+        }
+        help_text << "\nTo use this tool: <tool-launch>" << tool_name << " [arguments]</tool-launch>\n\n";
+    }
+
+    return help_text.str();
+}
+
+static std::string execute_tool(const std::string & tools_dir, const std::string & tool_name, const std::string & args) {
+    std::string full_path = tools_dir + "/" + tool_name;
+
+    if (!is_executable(full_path)) {
+        return "Error: Tool '" + tool_name + "' not found or not executable\n";
+    }
+
+    std::string command = full_path;
+    if (!args.empty()) {
+        // NOTE: args are appended to the command line as-is (no shell escaping)
+        command += " " + args;
+    }
+
+    LOG("\n[Executing tool: %s]\n", command.c_str());
+    std::string output = execute_command(command);
+    LOG("[Tool output follows]\n");
+
+    return output;
+}
+#elif defined (_WIN32)
+// Windows implementations (simplified - no tool support on Windows for now)
+static bool is_executable(const std::string & path) {
+    return false;
+}
+
+static std::string execute_command(const std::string & command) {
+    return "Error: Tool execution not supported on Windows\n";
+}
+
+static std::vector<std::string> get_tool_executables(const std::string & tools_dir) {
+    return std::vector<std::string>();
+}
+
+static std::string collect_tools_help(const std::string & tools_dir) {
+    return "Tool execution is not supported on Windows.\n";
+}
+
+static std::string execute_tool(const std::string & tools_dir, const std::string & tool_name, const std::string & args) {
+    return "Error: Tool execution not supported on Windows\n";
+}
+#endif
+
+// Check if a position in text is inside <think>...</think> tags
+static bool is_inside_think_tags(const std::string & text, size_t pos) {
+    // Find the most recent <think> before pos
+    size_t think_start = text.rfind("<think>", pos);
+    if (think_start == std::string::npos) {
+        return false; // No <think> tag before this position
+    }
+
+    // Check if there's a </think> between think_start and pos
+    size_t think_end = text.find("</think>", think_start);
+    if (think_end == std::string::npos || think_end > pos) {
+        return true; // We're inside an unclosed or currently open think block
+    }
+
+    return false; // The think block was closed before pos
+}
+
+// Check if the recent output contains <tools-help> (outside of think tags)
+static bool check_for_tools_help(const std::string & text) {
+    size_t pos = text.find("<tools-help>");
+    if (pos == std::string::npos) {
+        return false;
+    }
+
+    // Make sure it's not inside think tags
+    return !is_inside_think_tags(text, pos);
+}
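+
+// Tool-call protocol example (with an illustrative tool name): a response containing
+// "<tool-launch>weather Boston</tool-launch>" runs ./tools/weather Boston and the
+// command's stdout is injected back into the model's context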
+
+// Check if the recent output contains <tool-launch>...</tool-launch> and extract tool name and args
+// Returns false if inside think tags or if already processed
+static bool check_for_tool_launch(const std::string & text, std::string & tool_name, std::string & args, size_t search_from = 0) {
+    size_t start = text.find("<tool-launch>", search_from);
+    if (start == std::string::npos) {
+        return false;
+    }
+
+    // Check if this tag is inside think tags
+    if (is_inside_think_tags(text, start)) {
+        // Try to find the next one after this
+        return check_for_tool_launch(text, tool_name, args, start + 1);
+    }
+
+    size_t end = text.find("</tool-launch>", start);
+    if (end == std::string::npos) {
+        return false;
+    }
+
+    // Extract the content between tags
+    start += 13; // length of "<tool-launch>"
+    std::string content = text.substr(start, end - start);
+
+    // Trim whitespace
+    content.erase(0, content.find_first_not_of(" \t\n\r"));
+    content.erase(content.find_last_not_of(" \t\n\r") + 1);
+
+    // Split into tool name and args
+    size_t space_pos = content.find(' ');
+    if (space_pos == std::string::npos) {
+        tool_name = content;
+        args = "";
+    } else {
+        tool_name = content.substr(0, space_pos);
+        args = content.substr(space_pos + 1);
+        // Trim args
+        args.erase(0, args.find_first_not_of(" \t\n\r"));
+        args.erase(args.find_last_not_of(" \t\n\r") + 1);
+    }
+
+    return !tool_name.empty();
+}
+
+// Save complete LLM state (KV cache + RNG + logits + embeddings) to GGUF file
+static bool save_llm_state_to_gguf(llama_context * ctx, const std::string & filename) {
+    LOG("\nSaving LLM state to %s...\n", filename.c_str());
+
+    // Get the size of the state
+    const size_t state_size = llama_state_get_size(ctx);
+    LOG("State size: %zu bytes (%.2f MB)\n", state_size, state_size / (1024.0 * 1024.0));
+
+    // Allocate buffer and get state data
+    std::vector<uint8_t> state_data(state_size);
+    const size_t written = llama_state_get_data(ctx, state_data.data(), state_size);
+
+    if (written != state_size) {
+        LOG_ERR("Failed to get state data: got %zu bytes, expected %zu\n", written, state_size);
+        return false;
+    }
+
+    // Create GGUF context
+    struct gguf_context * gguf_ctx = gguf_init_empty();
+
+    // Add metadata
+    gguf_set_val_u32(gguf_ctx, "llm_state.version", 1);
+    gguf_set_val_u64(gguf_ctx, "llm_state.size", state_size);
+    gguf_set_val_str(gguf_ctx, "llm_state.type", "kv_cache_rng_logits_embeddings");
+
+    // Create a ggml context for the tensor
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ state_size + 1024*1024, // Extra space for tensor metadata
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true, // We already have the data
+    };
+
+    struct ggml_context * ggml_ctx = ggml_init(params);
+
+    // Create a 1D tensor to hold the state data
+    int64_t ne[4] = {(int64_t)state_size, 1, 1, 1};
+    struct ggml_tensor * state_tensor = ggml_new_tensor(ggml_ctx, GGML_TYPE_I8, 1, ne);
+    ggml_set_name(state_tensor, "llm_state_data");
+    state_tensor->data = state_data.data();
+
+    // Add tensor to GGUF
+    gguf_add_tensor(gguf_ctx, state_tensor);
+
+    // Write to file
+    gguf_write_to_file(gguf_ctx, filename.c_str(), false);
+
+    LOG("Successfully saved LLM state (%zu bytes)\n", written);
+
+    // Cleanup
+    ggml_free(ggml_ctx);
+    gguf_free(gguf_ctx);
+
+    return true;
+}
+
+// Load complete LLM state from GGUF file
+static bool load_llm_state_from_gguf(llama_context * ctx, const std::string & filename) {
+    LOG("\nLoading LLM state from %s...\n", filename.c_str());
+
+    struct ggml_context * ggml_ctx = NULL;
+
+    struct gguf_init_params params = {
+        /*.no_alloc = */ false,
+        /*.ctx      = */ &ggml_ctx,
+    };
+
+    struct gguf_context * gguf_ctx = gguf_init_from_file(filename.c_str(), params);
+
+    if (!gguf_ctx) {
+        LOG_ERR("Failed to load state file: %s\n", filename.c_str());
+        return false;
+    }
+
+    // Read metadata
+    const int n_kv = gguf_get_n_kv(gguf_ctx);
+    uint32_t version = 0;
+    uint64_t state_size = 0;
+
+    for (int i = 0; i < n_kv; i++) {
+        const char * key = gguf_get_key(gguf_ctx, i);
+        const enum gguf_type type = gguf_get_kv_type(gguf_ctx, i);
+
+        if (strcmp(key, "llm_state.version") == 0 && type == GGUF_TYPE_UINT32) {
+            version = gguf_get_val_u32(gguf_ctx, i);
+        } else if (strcmp(key, "llm_state.size") == 0 && type == GGUF_TYPE_UINT64) {
+            state_size = gguf_get_val_u64(gguf_ctx, i);
+        }
+    }
+
+    LOG("State version: %u, size: %llu bytes (%.2f MB)\n", version, (unsigned long long) state_size, state_size / (1024.0 * 1024.0));
+
+    // Get the state tensor
+    struct ggml_tensor * state_tensor = ggml_get_tensor(ggml_ctx, "llm_state_data");
+    if (!state_tensor) {
+        LOG_ERR("State tensor not found in file\n");
+        gguf_free(gguf_ctx);
+        return false;
+    }
+
+    // Set the state
+    const size_t loaded = llama_state_set_data(ctx, (const uint8_t *) state_tensor->data, ggml_nbytes(state_tensor));
+
+    if (loaded == 0) {
+        LOG_ERR("Failed to set state data\n");
+        gguf_free(gguf_ctx);
+        return false;
+    }
+
+    LOG("Successfully loaded LLM state (%zu bytes)\n", loaded);
+    LOG("LLM has been restored to the exact state when the save was made\n");
+
+    gguf_free(gguf_ctx);
+
+    return true;
+}
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 static void sigint_handler(int signo) {
     if (signo == SIGINT) {
@@ -147,6 +496,14 @@ int main(int argc, char ** argv) {
         return 1;
     }

+    // Handle state loading
+    if (!params.path_load_activations.empty()) {
+        if (!load_llm_state_from_gguf(ctx, params.path_load_activations)) {
+            LOG_ERR("%s: failed to load LLM state\n", __func__);
+            return 1;
+        }
+    }
+
     auto * mem = llama_get_memory(ctx);

     const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -424,6 +781,17 @@ int main(int argc, char ** argv) {

     if (params.interactive) {
         LOG_INF("%s: interactive mode on.\n", __func__);
+        LOG_INF("Special commands:\n");
+        LOG_INF("  /\\/save <filename>   - Save complete LLM state (KV cache, etc.) to GGUF file\n");
+        LOG_INF("  /\\/load <filename>   - Load LLM state from GGUF file to restore exact conversation state\n");
+        LOG_INF("  /\\/temp              - Show current temperature setting\n");
+        LOG_INF("  /\\/temp <value>      - Set temperature to a new value (e.g., /\\/temp 0.7)\n");
+        LOG_INF("  /\\/timeout           - Show or disable idle timeout (0 = disabled)\n");
+        LOG_INF("  /\\/timeout <minutes> - Set idle timeout to N minutes (e.g., /\\/timeout 5)\n");
+        LOG_INF("\n");
+        LOG_INF("Tool calling (when 'tools' directory exists):\n");
+        LOG_INF("  Model can output <tools-help> to get a list of available tools\n");
+        LOG_INF("  Model can output <tool-launch>tool-name args</tool-launch> to execute a tool\n");

         if (!params.antiprompt.empty()) {
             for (const auto & antiprompt : params.antiprompt) {
@@ -764,10 +1132,73 @@ int main(int argc, char ** argv) {

             // if not currently processing queued inputs;
             if ((int) embd_inp.size() <= n_consumed) {
+                // Check for tool requests in recent output
+                const int n_prev = 128; // Look back further to catch full tags
+                const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev);
+
+                // Check for <tools-help> request
+                // Note: Only one tool action per iteration to prevent help examples from being executed
+                if (check_for_tools_help(last_output)) {
+                    LOG_DBG("Detected <tools-help> request\n");
+
+                    // Check if tools directory exists
+                    if (file_exists("tools")) {
+                        std::string help_text = collect_tools_help("tools");
+
+                        LOG("\n[Tools Help Requested]\n");
+                        LOG("%s", help_text.c_str());
+                        LOG("[End of Tools Help]\n\n");
+
+                        // Inject the help text back into the conversation
+                        auto help_tokens = common_tokenize(ctx, "\n\n" + help_text, false, true);
+                        embd_inp.insert(embd_inp.end(), help_tokens.begin(), help_tokens.end());
+
+                        // Continue generation after injecting help
+                        is_interacting = false;
+                    } else {
+                        LOG("\n[Tools Help Requested but 'tools' directory not found]\n\n");
+                        auto msg_tokens = common_tokenize(ctx, "\n\nNo 'tools' directory found.\n\n", false, true);
+                        embd_inp.insert(embd_inp.end(), msg_tokens.begin(), msg_tokens.end());
+                    }
+                } else {
+                    // Check for <tool-launch>...</tool-launch> request only if we didn't handle tools-help
+                    std::string tool_name, tool_args;
+                    if (check_for_tool_launch(last_output, tool_name, tool_args)) {
+                        // Create signature to check for duplicate execution
+                        std::string tool_signature = tool_name + "|" + tool_args;
+
+                        // Only execute if this is a new tool call (not the same as last execution)
+                        if (tool_signature != g_last_executed_tool_signature) {
+                            LOG_DBG("Detected <tool-launch> request: tool=%s, args=%s\n", tool_name.c_str(), tool_args.c_str());
+
+                            // Execute the tool
+                            std::string tool_output = execute_tool("tools", tool_name, tool_args);
+
+                            LOG("%s", tool_output.c_str());
+                            LOG("[End of Tool Output]\n\n");
+
+                            // Inject the tool output back into the conversation
+                            auto output_tokens = common_tokenize(ctx, "\n\n" + tool_output + "\n\n", false, true);
+
+                            // For large outputs with flash attention and big contexts, inject in chunks
+                            // to avoid KV cache allocation failures
+                            LOG_DBG("Tool output: %zu tokens\n", output_tokens.size());
+                            embd_inp.insert(embd_inp.end(), output_tokens.begin(), output_tokens.end());
+
+                            // Remember this execution to prevent duplicates
+                            g_last_executed_tool_signature = tool_signature;
+
+                            // Continue generation after injecting output
+                            is_interacting = false;
+                        } else {
+                            LOG_DBG("Skipping duplicate tool execution: tool=%s, args=%s\n", tool_name.c_str(), tool_args.c_str());
+                        }
+                    }
+                }
+
                 // check for reverse prompt in the last n_prev tokens
                 if (!params.antiprompt.empty()) {
-                    const int n_prev = 32;
-                    const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev);
+                    const std::string last_output_for_antiprompt = common_sampler_prev_str(smpl, ctx, 32);

                     is_antiprompt = false;
                     // Check if each of the reverse prompts appears at the end of the output.
                     // If we're not running interactively, the reverse prompt might be tokenized with some following characters
                     // so we'll compensate for that by widening the search window a bit.
                     for (std::string & antiprompt : params.antiprompt) {
                         size_t extra_padding = params.interactive ? 0 : 2;
-                        size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
-                            ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
+                        size_t search_start_pos = last_output_for_antiprompt.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
+                            ? last_output_for_antiprompt.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
                             : 0;

-                        if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
+                        if (last_output_for_antiprompt.find(antiprompt, search_start_pos) != std::string::npos) {
                             if (params.interactive) {
                                 is_interacting = true;
                             }
@@ -789,8 +1220,8 @@ int main(int argc, char ** argv) {
                 }

                 // check for reverse prompt using special tokens
-                // avoid calling common_sampler_last() if last_output is empty
-                if (!last_output.empty()) {
+                // avoid calling common_sampler_last() if last_output_for_antiprompt is empty
+                if (!last_output_for_antiprompt.empty()) {
                     llama_token last_token = common_sampler_last(smpl);
                     for (auto token : antiprompt_token) {
                         if (token == last_token) {
@@ -813,6 +1244,15 @@ int main(int argc, char ** argv) {
                 LOG_DBG("found an EOG token\n");

                 if (params.interactive) {
+                    // Save LLM state if requested
+                    if (g_save_state_next && !g_state_save_path.empty()) {
+                        if (!save_llm_state_to_gguf(ctx, g_state_save_path)) {
+                            LOG_ERR("Failed to save LLM state to %s\n", g_state_save_path.c_str());
+                        }
+                        g_save_state_next = false;
+                        g_state_save_path = "";
+                    }
+
                     if (!params.antiprompt.empty()) {
                         // tokenize and inject first reverse prompt
                         const auto first_antiprompt = common_tokenize(ctx, params.antiprompt.front(), false, true);
@@ -838,6 +1278,10 @@ int main(int argc, char ** argv) {
             if ((n_past > 0 || waiting_for_first_input) && is_interacting) {
                 LOG_DBG("waiting for user input\n");

+                // Reset idle timer when we start waiting for user input
+                // This ensures we only count time spent waiting, not time spent generating
+                update_activity_time();
+
                 if (params.conversation_mode) {
                     LOG("\n> ");
                 }
@@ -857,23 +1301,59 @@ int main(int argc, char ** argv) {
                 console::set_display(console::user_input);
                 display = params.display_prompt;

+                // Calculate remaining timeout for readline
+                int timeout_seconds = 0;
+                if (params.idle_action_interval > 0) {
+                    time_t current_time = time(nullptr);
+                    int elapsed_seconds = (int)(current_time - g_last_activity_time);
+                    int idle_threshold_seconds = params.idle_action_interval * 60;
+                    int remaining_seconds = idle_threshold_seconds - elapsed_seconds;
+
+                    if (remaining_seconds > 0) {
+                        timeout_seconds = remaining_seconds;
+                    } else {
+                        timeout_seconds = 1; // Will time out almost immediately
+                    }
+                }
+
+                // Read input with timeout support
                 std::string line;
                 bool another_line = true;
+                bool timed_out = false;
+
                 do {
-                    another_line = console::readline(line, params.multiline_input);
+                    another_line = console::readline_with_timeout(line, params.multiline_input, timeout_seconds, timed_out);
                     buffer += line;
+
+                    if (timed_out) {
+                        // Idle timeout occurred
+                        LOG_DBG("Idle timeout triggered during input wait\n");
+                        LOG("\n[Idle timeout - auto-submitting empty input]\n");
+                        update_activity_time(); // Reset timer for next iteration
+
+                        // Reset tool execution tracking to allow tools during idle thinking
+                        g_last_executed_tool_signature = "";
+
+                        another_line = false; // Stop reading more lines
+                        break;
+                    }
+
+                    // User provided input, update activity time and disable timeout for continuation lines
+                    update_activity_time();
+                    timeout_seconds = 0; // No timeout for continuation lines
                 } while (another_line);

                 // done taking input, reset color
                 console::set_display(console::reset);
                 display = true;

-                if (buffer.empty()) { // Ctrl+D on empty line exits
+                if (buffer.empty() && !timed_out) { // Ctrl+D on empty line exits (but not timeout)
                     LOG("EOF by user\n");
                     break;
                 }

-                if (buffer.back() == '\n') {
+                // Process newline handling only if buffer is not empty
+                if (!buffer.empty() && buffer.back() == '\n') {
                     // Implement #587:
                     // If the user wants the text to end in a newline,
                     // this should be accomplished by explicitly adding a newline by using \ followed by return,
@@ -881,6 +1361,124 @@ int main(int argc, char ** argv) {
                     buffer.pop_back();
                 }

+                // Handle special state save/load commands
+                if (buffer.rfind("/\\/save ", 0) == 0) {
+                    // Extract filename
+                    std::string filename = buffer.substr(8); // Skip "/\/save "
+                    // Trim whitespace
+                    filename.erase(0, filename.find_first_not_of(" \t\n\r\f\v"));
+                    filename.erase(filename.find_last_not_of(" \t\n\r\f\v") + 1);
+
+                    if (!filename.empty()) {
+                        LOG("\n");
+                        LOG("LLM state will be saved to: %s\n", filename.c_str());
+                        LOG("State will be saved after your next prompt and response.\n");
+
+                        g_state_save_path = filename;
+                        g_save_state_next = true;
+                    } else {
+                        LOG_ERR("Error: No filename specified for /\\/save command\n");
+                    }
+                    // Keep is_interacting true and continue to wait for next input
+                    is_interacting = true;
+                    continue;
+                } else if (buffer.rfind("/\\/load ", 0) == 0) {
+                    // Extract filename
+                    std::string filename = buffer.substr(8); // Skip "/\/load "
+                    // Trim whitespace
+                    filename.erase(0, filename.find_first_not_of(" \t\n\r\f\v"));
+                    filename.erase(filename.find_last_not_of(" \t\n\r\f\v") + 1);
+
+                    if (!filename.empty()) {
+                        LOG("\n");
+                        if (!load_llm_state_from_gguf(ctx, filename)) {
+                            LOG_ERR("Failed to load LLM state from: %s\n", filename.c_str());
+                        }
+                    } else {
+                        LOG_ERR("Error: No filename specified for /\\/load command\n");
+                    }
+                    // Keep is_interacting true and continue to wait for next input
+                    is_interacting = true;
+                    continue;
+                } else if (buffer.rfind("/\\/temp", 0) == 0) {
+                    // Handle temperature get/set command
+                    std::string temp_arg = buffer.substr(7); // Skip "/\/temp"
+                    // Trim whitespace
+                    temp_arg.erase(0, temp_arg.find_first_not_of(" \t\n\r\f\v"));
+                    temp_arg.erase(temp_arg.find_last_not_of(" \t\n\r\f\v") + 1);
+
+                    if (temp_arg.empty()) {
+                        // Show current temperature
+                        LOG("\n");
+                        LOG("Current temperature: %.2f\n", common_sampler_get_temp(smpl));
+                    } else {
+                        // Set new temperature
+                        try {
+                            float new_temp = std::stof(temp_arg);
+                            if (new_temp < 0.0f) {
+                                LOG_ERR("Error: Temperature must be >= 0.0\n");
+                            } else {
+                                LOG("\n");
+                                float old_temp = common_sampler_get_temp(smpl);
+                                LOG("Changing temperature from %.2f to %.2f\n", old_temp, new_temp);
+                                if (common_sampler_set_temp(smpl, new_temp)) {
+                                    LOG("Temperature successfully updated to %.2f\n", new_temp);
+                                } else {
+                                    LOG_ERR("Failed to update temperature\n");
+                                }
+                            }
+                        } catch (const std::exception & e) {
+                            LOG_ERR("Error: Invalid temperature value '%s'\n", temp_arg.c_str());
+                        }
+                    }
+                    // Keep is_interacting true and continue to wait for next input
+                    is_interacting = true;
+                    continue;
+                } else if (buffer.rfind("/\\/timeout", 0) == 0) {
+                    // Handle idle timeout get/set command
+                    std::string timeout_arg = buffer.substr(10); // Skip "/\/timeout"
+                    // Trim whitespace
+                    timeout_arg.erase(0, timeout_arg.find_first_not_of(" \t\n\r\f\v"));
+                    timeout_arg.erase(timeout_arg.find_last_not_of(" \t\n\r\f\v") + 1);
+
+                    if (timeout_arg.empty()) {
+                        // Show current timeout or disable it
+                        LOG("\n");
+                        if (params.idle_action_interval > 0) {
+                            LOG("Current idle timeout: %d minutes\n", params.idle_action_interval);
+                            LOG("Disabling idle timeout\n");
+                            params.idle_action_interval = 0;
+                        } else {
+                            LOG("Idle timeout is currently disabled (0 minutes)\n");
+                        }
+                    } else {
+                        // Set new timeout
+                        try {
+                            int new_timeout = std::stoi(timeout_arg);
+                            if (new_timeout < 0) {
+                                LOG_ERR("Error: Timeout must be >= 0\n");
+                            } else {
+                                LOG("\n");
+                                int old_timeout = params.idle_action_interval;
+                                LOG("Changing idle timeout from %d to %d minutes\n", old_timeout, new_timeout);
+                                params.idle_action_interval = new_timeout;
+                                if (new_timeout == 0) {
+                                    LOG("Idle timeout disabled\n");
+                                } else {
+                                    LOG("Idle timeout set to %d minutes\n", new_timeout);
+                                    // Reset timer to start counting from now
+                                    update_activity_time();
+                                }
+                            }
+                        } catch (const std::exception & e) {
+                            LOG_ERR("Error: Invalid timeout value '%s'\n", timeout_arg.c_str());
+                        }
+                    }
+                    // Keep is_interacting true and continue to wait for next input
+                    is_interacting = true;
+                    continue;
+                }
+
                 if (buffer.empty()) { // Enter key on empty line lets the user pass control back
                     LOG_DBG("empty line, passing control back\n");
                 } else { // Add tokens to embd only if the input buffer is non-empty
@@ -938,6 +1536,9 @@ int main(int argc, char ** argv) {
                     // reset assistant message
                     assistant_ss.str("");

+                    // Reset tool execution tracking on new user input
+                    g_last_executed_tool_signature = "";
+
                     n_remain -= line_inp.size();
                     LOG_DBG("n_remain: %d\n", n_remain);
                 }
@@ -979,6 +1580,14 @@ int main(int argc, char ** argv) {
     }

     LOG("\n\n");
+
+    // Save LLM state if dumping was enabled via CLI flag
+    if (!params.path_dump_activations.empty()) {
+        if (!save_llm_state_to_gguf(ctx, params.path_dump_activations)) {
+            LOG_ERR("%s: failed to save LLM state\n", __func__);
+        }
+    }
+
     common_perf_print(ctx, smpl);

     common_sampler_free(smpl);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 164e8cf4e7084..6190d8640952c 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -23,8 +23,11 @@
 #include <atomic>
 #include <chrono>
 #include <condition_variable>
+#include <fstream>
 #include <memory>
 #include <mutex>
+#include <queue>
+#include <regex>
 #include <signal.h>
 #include <thread>
 #include <unordered_map>
@@ -2314,9 +2317,45 @@ struct server_response {
     }
 };

+// Activation capture entry for streaming to disk
+struct activation_entry {
+    uint64_t timestamp_us;     // microseconds since epoch
+    std::string label;         // tensor name like "blk.0.attn_q"
+    enum ggml_type type;       // tensor data type
+    int64_t ne[4];             // tensor dimensions
+    std::vector<uint8_t> data; // tensor data (copied from GPU if needed)
+};
+
+// Activation capture system for streaming intermediate tensors
+struct activation_capture {
+    std::atomic<bool> active{false};
+    std::string output_file;
+    std::vector<std::regex> filters; // regex patterns for tensor names
+    int layer_start = -1; // -1 = all layers
+    int layer_end = -1;
+    size_t max_size_bytes = 0; // 0 = unlimited
+
+    std::mutex queue_mutex;
+    std::condition_variable queue_cv;
+    std::queue<activation_entry> entry_queue;
+    std::thread writer_thread;
+    std::atomic<bool> should_stop{false};
+    std::atomic<size_t> bytes_written{0};
+    std::atomic<size_t> entries_captured{0};
+
+    // Callback user data
+    llama_context * ctx = nullptr;
+};
+
+// Global pointer for activation capture (accessed by callback)
+static activation_capture * g_activation_capture = nullptr;
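+
+// Capture pipeline: the eval callback copies each matching tensor into an
+// activation_entry and pushes it onto entry_queue; activation_writer_thread
+// drains the queue to disk, so graph evaluation never blocks on file I/O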
+
 struct server_context {
     common_params params_base;

+    // Activation capture system
+    std::unique_ptr<activation_capture> act_capture;
+
     // note: keep these alive - they determine the lifetime of the model, context, etc.
     common_init_result llama_init;
     common_init_result llama_init_dft;
@@ -2384,6 +2423,10 @@ struct server_context {

         params_base = params;

+        // Set up activation capture callback (inactive until explicitly started)
+        params_base.cb_eval = activation_capture_callback;
+        params_base.cb_eval_user_data = nullptr;
+
         llama_init = common_init_from_params(params_base);

         model = llama_init.model.get();
@@ -2596,6 +2639,353 @@ struct server_context {
            /* allow_audio      */ mctx ? mtmd_support_audio (mctx) : false,
            /* enable_thinking  */ enable_thinking,
         };
+
+        // Auto-load KV cache if requested
+        if (!params_base.kv_cache_auto_load.empty()) {
+            auto_load_kv_cache();
+        }
+    }
+
+    // Save KV cache to specified directory (or generate timestamped name if empty)
+    // Returns the directory name used, or empty string on failure
+    std::string save_kv_cache_to_dir(const std::string & custom_dir = "") {
+        std::string dir_name;
+
+        if (custom_dir.empty()) {
+            // Auto-generate timestamp directory name
+            if (params_base.kv_cache_auto_save_base.empty()) {
+                SRV_ERR("%s", "no directory specified and no auto-save base configured\n");
+                return "";
+            }
+
+            auto now = std::chrono::system_clock::now();
+            auto now_time = std::chrono::system_clock::to_time_t(now);
+            std::tm tm_time;
+#ifdef _WIN32
+            localtime_s(&tm_time, &now_time);
+#else
+            localtime_r(&now_time, &tm_time);
+#endif
+            char timestamp[64];
+            std::strftime(timestamp, sizeof(timestamp), "%Y%m%d_%H%M%S", &tm_time);
+
+            dir_name = params_base.kv_cache_auto_save_base + "_" + timestamp;
+        } else {
+            dir_name = custom_dir;
+        }
+
+        SRV_INF("auto-saving KV cache to directory: %s\n", dir_name.c_str());
+
+        // Create directory
+#ifdef _WIN32
+        _mkdir(dir_name.c_str());
+#else
+        mkdir(dir_name.c_str(), 0755);
+#endif
+
+        // Save each slot
+        int saved_count = 0;
+        for (const server_slot & slot : slots) {
+            if (slot.prompt.tokens.empty()) {
+                continue; // Skip empty slots
+            }
+
+            std::string filepath = dir_name + DIRECTORY_SEPARATOR + "slot_" + std::to_string(slot.id) + ".bin";
+
+            const llama_tokens & tokens = slot.prompt.tokens.get_text_tokens();
+            const size_t token_count = tokens.size();
+
+            const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot.id, tokens.data(), token_count);
+
+            if (nwrite > 0) {
+                SRV_INF("saved slot %d: %zu tokens, %zu bytes to %s\n", slot.id, token_count, nwrite, filepath.c_str());
+                saved_count++;
+            } else {
+                SRV_WRN("failed to save slot %d to %s\n", slot.id, filepath.c_str());
+            }
+        }
+
+        SRV_INF("KV cache save complete: %d slots saved to %s\n", saved_count, dir_name.c_str());
+        return dir_name;
+    }
+
+    // Auto-save KV cache on shutdown with timestamp (convenience wrapper)
+    void auto_save_kv_cache() {
+        if (params_base.kv_cache_auto_save_base.empty()) {
+            return;
+        }
+        save_kv_cache_to_dir(); // Use default timestamped directory
+    }
+
+    // Auto-load KV cache on startup from specified directory
+    void auto_load_kv_cache() {
+        if (params_base.kv_cache_auto_load.empty()) {
+            return;
+        }
+
+        std::string dir_name = params_base.kv_cache_auto_load;
+
+        SRV_INF("auto-loading KV cache from directory: %s\n", dir_name.c_str());
+
+        int loaded_count = 0;
+
+        // Try to load each slot
+        for (server_slot & slot : slots) {
+            std::string filepath = dir_name + DIRECTORY_SEPARATOR + "slot_" + std::to_string(slot.id) + ".bin";
+
+            // Check if file exists
+            std::ifstream file(filepath);
+            if (!file.good()) {
+                SRV_DBG("slot %d file not found: %s - skipping\n", slot.id, filepath.c_str());
+                continue;
+            }
+            file.close();
+
+            llama_tokens tokens;
+            tokens.resize(slot.n_ctx);
+            size_t token_count = 0;
+
+            const size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot.id, tokens.data(), tokens.size(), &token_count);
+
+            if (nread > 0 && token_count > 0) {
+                tokens.resize(token_count);
+                slot.prompt.tokens.clear();
+                slot.prompt.tokens.insert(tokens);
+
+                SRV_INF("loaded slot %d: %zu tokens, %zu bytes from %s\n", slot.id, token_count, nread, filepath.c_str());
+                loaded_count++;
+            } else {
+                SRV_WRN("failed to load slot %d from %s\n", slot.id, filepath.c_str());
+            }
+        }
+
+        SRV_INF("KV cache auto-load complete: %d slots loaded from %s\n", loaded_count, dir_name.c_str());
+    }
+
+    // Activation capture: background writer thread
+    static void activation_writer_thread(activation_capture * capture) {
+        std::ofstream file(capture->output_file, std::ios::binary);
+        if (!file.is_open()) {
+            SRV_ERR("failed to open activation file: %s\n", capture->output_file.c_str());
+            return;
+        }
+
+        // Write magic header
+        const char magic[9] = "LLMACT01";
+        file.write(magic, 8);
+
+        while (true) {
+            activation_entry entry;
+            {
+                std::unique_lock<std::mutex> lock(capture->queue_mutex);
+                capture->queue_cv.wait(lock, [capture] {
+                    return !capture->entry_queue.empty() || capture->should_stop.load();
+                });
+
+                if (capture->should_stop.load() && capture->entry_queue.empty()) {
+                    break;
+                }
+
+                if (capture->entry_queue.empty()) {
+                    continue;
+                }
+
+                entry = std::move(capture->entry_queue.front());
+                capture->entry_queue.pop();
+            }
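+
+            // Records are appended back-to-back after the 8-byte "LLMACT01"
+            // magic; all fields are written in native byte order (no
+            // endianness conversion), so readers must match the host layout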
+            // Write entry to file
+            // Format: timestamp(8) + label_len(4) + label + type(1) + dims(4*8) + data_size(8) + data
+            file.write(reinterpret_cast<const char *>(&entry.timestamp_us), sizeof(uint64_t));
+
+            uint32_t label_len = entry.label.size();
+            file.write(reinterpret_cast<const char *>(&label_len), sizeof(uint32_t));
+            file.write(entry.label.data(), label_len);
+
+            int8_t type_byte = static_cast<int8_t>(entry.type);
+            file.write(reinterpret_cast<const char *>(&type_byte), sizeof(int8_t));
+
+            file.write(reinterpret_cast<const char *>(entry.ne), sizeof(entry.ne));
+
+            uint64_t data_size = entry.data.size();
+            file.write(reinterpret_cast<const char *>(&data_size), sizeof(uint64_t));
+            file.write(reinterpret_cast<const char *>(entry.data.data()), data_size);
+
+            capture->bytes_written.fetch_add(sizeof(uint64_t) + sizeof(uint32_t) + label_len +
+                                             sizeof(int8_t) + sizeof(entry.ne) + sizeof(uint64_t) + data_size);
+
+            // Check size limit
+            if (capture->max_size_bytes > 0 && capture->bytes_written.load() >= capture->max_size_bytes) {
+                SRV_INF("activation capture reached size limit: %zu bytes\n", capture->bytes_written.load());
+                capture->active.store(false);
+                break;
+            }
+        }
+
+        file.close();
+        SRV_INF("activation writer thread finished: %zu entries, %zu bytes written to %s\n",
+                capture->entries_captured.load(), capture->bytes_written.load(), capture->output_file.c_str());
+    }
+
+    // Decide whether a tensor should be captured, based on name and layer filters
+    static bool activation_capture_wants(const activation_capture * capture, const struct ggml_tensor * t) {
+        const char * name = ggml_get_name(t);
+        if (!name || strlen(name) == 0) {
+            return false; // Skip unnamed tensors
+        }
+
+        std::string tensor_name(name);
+
+        // Apply filters
+        if (!capture->filters.empty()) {
+            bool matches = false;
+            for (const auto & filter : capture->filters) {
+                if (std::regex_match(tensor_name, filter)) {
+                    matches = true;
+                    break;
+                }
+            }
+            if (!matches) {
+                return false; // Doesn't match any filter
+            }
+        }
+
+        // Apply layer range filter (extract layer number from name like "blk.5.attn_q")
+        if (capture->layer_start >= 0) {
+            std::regex layer_regex(R"(blk\.(\d+)\.)");
+            std::smatch match;
+            if (std::regex_search(tensor_name, match, layer_regex)) {
+                int layer_num = std::stoi(match[1]);
+                if (layer_num < capture->layer_start || layer_num > capture->layer_end) {
+                    return false; // Outside layer range
+                }
+            }
+        }
+
+        return true;
+    }
+
+    // Activation capture: callback for tensor evaluation
+    // The scheduler calls this twice per tensor: with ask == true before the tensor is
+    // computed (return whether we want its data) and with ask == false once the data is
+    // available (return true to keep evaluating)
+    static bool activation_capture_callback(struct ggml_tensor * t, bool ask, void * user_data) {
+        (void)user_data; // unused
+
+        activation_capture * capture = g_activation_capture;
+        if (!capture || !capture->active.load()) {
+            return !ask; // not capturing: declare no interest, continue evaluation
+        }
+
+        if (ask) {
+            return activation_capture_wants(capture, t);
+        }
+
+        // Create entry
+        activation_entry entry;
+        entry.timestamp_us = ggml_time_us();
+        entry.label = ggml_get_name(t);
+        entry.type = t->type;
+        for (int i = 0; i < 4; i++) {
+            entry.ne[i] = t->ne[i];
+        }
+
+        // Copy tensor data (handles GPU->CPU transfer automatically)
+        size_t nbytes = ggml_nbytes(t);
+        entry.data.resize(nbytes);
+        ggml_backend_tensor_get(t, entry.data.data(), 0, nbytes);
+
+        // Queue entry for writing
+        {
+            std::lock_guard<std::mutex> lock(capture->queue_mutex);
+            capture->entry_queue.push(std::move(entry));
+            capture->entries_captured.fetch_add(1);
+        }
+        capture->queue_cv.notify_one();
+
+        return true; // Continue graph evaluation
+    }
+
+    // Start activation capture
+    bool start_activation_capture(const std::string & output_file,
+                                  const std::vector<std::string> & filter_patterns,
+                                  int layer_start = -1,
+                                  int layer_end = -1,
+                                  size_t max_size_mb = 0) {
+        if (act_capture && act_capture->active.load()) {
+            SRV_WRN("%s", "activation capture already active\n");
+            return false;
+        }
+
+        act_capture = std::make_unique<activation_capture>();
+        act_capture->output_file = output_file;
+        act_capture->layer_start = layer_start;
+        act_capture->layer_end = layer_end;
+        act_capture->max_size_bytes = max_size_mb * 1024 * 1024;
+        act_capture->ctx = ctx;
+
+        // Compile regex filters
+        for (const auto & pattern : filter_patterns) {
+            try {
+                act_capture->filters.emplace_back(pattern);
+            } catch (const std::regex_error & e) {
+                SRV_ERR("invalid regex pattern '%s': %s\n", pattern.c_str(), e.what());
+                return false;
+            }
+        }
+
+        // Start writer thread
+        act_capture->should_stop.store(false);
+        act_capture->writer_thread = std::thread(activation_writer_thread, act_capture.get());
+
+        // Set global pointer for callback
+        g_activation_capture = act_capture.get();
+
+        act_capture->active.store(true);
+
+        SRV_INF("activation capture started: file=%s, filters=%zu, layers=[%d,%d], max_size=%zu MB\n",
+                output_file.c_str(), filter_patterns.size(), layer_start, layer_end, max_size_mb);
+
+        return true;
+    }
+
+    // Stop activation capture
+    // note: also handles captures that already deactivated themselves (size limit
+    // reached), so the writer thread is always joined before the capture is freed
+    json stop_activation_capture() {
+        if (!act_capture) {
+            return {
+                {"error", "no active capture"}
+            };
+        }
+
+        act_capture->active.store(false);
+
+        // Clear global pointer
+        g_activation_capture = nullptr;
+
+        // Stop writer thread
+        act_capture->should_stop.store(true);
+        act_capture->queue_cv.notify_one();
+        if (act_capture->writer_thread.joinable()) {
+            act_capture->writer_thread.join();
+        }
+
+        json result = {
+            {"success", true},
+            {"file", act_capture->output_file},
+            {"entries_captured", act_capture->entries_captured.load()},
+            {"bytes_written", act_capture->bytes_written.load()},
+            {"message", "Activation capture stopped"}
+        };
+
+        act_capture.reset();
+
+        return result;
+    }
+
+    // Get activation capture status
+    json get_activation_capture_status() const {
+        if (!act_capture) {
+            return {
+                {"active", false}
+            };
+        }
+
+        return {
+            {"active", act_capture->active.load()},
+            {"file", act_capture->output_file},
+            {"entries_captured", act_capture->entries_captured.load()},
+            {"bytes_written", act_capture->bytes_written.load()},
{"queue_size", act_capture->entry_queue.size()} + }; } server_slot * get_slot_by_id(int id) { @@ -4873,6 +5263,34 @@ int main(int argc, char ** argv) { res_ok(res, result->to_json()); }; + const auto handle_kv_cache_save = [&ctx_server, &res_ok, &res_error](const httplib::Request & req, httplib::Response & res) { + std::string dirname; + + // Parse request body if provided + if (!req.body.empty()) { + json request_data = json::parse(req.body); + if (request_data.contains("dirname") && request_data["dirname"].is_string()) { + dirname = request_data["dirname"]; + } + } + // If dirname is empty, save_kv_cache_to_dir will generate a timestamped name + + std::string saved_dir = ctx_server.save_kv_cache_to_dir(dirname); + + if (saved_dir.empty()) { + res_error(res, format_error_response("Failed to save KV cache - check server logs", ERROR_TYPE_SERVER)); + return; + } + + json response = { + {"success", true}, + {"directory", saved_dir}, + {"message", "KV cache saved successfully"} + }; + + res_ok(res, response); + }; + const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { if (params.slot_save_path.empty()) { res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); @@ -5658,6 +6076,49 @@ int main(int argc, char ** argv) { // Save & load slots svr->Get (params.api_prefix + "/slots", handle_slots); svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action); + // Save KV cache on demand + svr->Post(params.api_prefix + "/save-kv-cache", handle_kv_cache_save); + + // Activation capture endpoints + const auto handle_activations_start = [&ctx_server, &res_ok, &res_error](const httplib::Request & req, httplib::Response & res) { + json request_data = json::parse(req.body); + + std::string output_file = request_data.value("output_file", "activations.bin"); + std::vector filters = request_data.value("filters", std::vector()); + int layer_start = request_data.value("layer_start", -1); + int layer_end = request_data.value("layer_end", -1); + size_t max_size_mb = request_data.value("max_size_mb", 0); + + bool success = ctx_server.start_activation_capture(output_file, filters, layer_start, layer_end, max_size_mb); + + if (success) { + json response = { + {"success", true}, + {"message", "Activation capture started"}, + {"output_file", output_file}, + {"filters", filters}, + {"layer_range", {layer_start, layer_end}}, + {"max_size_mb", max_size_mb} + }; + res_ok(res, response); + } else { + res_error(res, format_error_response("Failed to start activation capture", ERROR_TYPE_SERVER)); + } + }; + + const auto handle_activations_stop = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + json result = ctx_server.stop_activation_capture(); + res_ok(res, result); + }; + + const auto handle_activations_status = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + json status = ctx_server.get_activation_capture_status(); + res_ok(res, status); + }; + + svr->Post(params.api_prefix + "/activations/start", handle_activations_start); + svr->Post(params.api_prefix + "/activations/stop", handle_activations_stop); + svr->Get (params.api_prefix + "/activations/status", handle_activations_status); // // Start the server @@ -5672,6 +6133,10 @@ int main(int argc, char ** argv) { // clean up function, to be called before exit auto clean_up = [&svr, &ctx_server]() { 
     //
     // Start the server
     //
@@ -5672,6 +6133,10 @@ int main(int argc, char ** argv) {

     // clean up function, to be called before exit
     auto clean_up = [&svr, &ctx_server]() {
         SRV_INF("%s: cleaning up before exit...\n", __func__);
+
+        // Auto-save KV cache if enabled
+        ctx_server.auto_save_kv_cache();
+
         svr->stop();
         ctx_server.queue_results.terminate();
         llama_backend_free();